
Conversation

@josephdviviano (Collaborator)

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

  • Added 3 new hypergrid tasks which should be more challenging. Note that the specifics are very much up for debate. I tried to identify environments that are easy to divide and conquer versus those that require compositional knowledge (and therefore some amount of knowledge sharing among agents in a multi-agent setting).
  • Added mode verification logic (to ensure that your particular configuration actually contains modes to find).
  • Added lots of tests around these new rewards.
  • Added visualizations of the reward landscape for these various rewards.

@josephdviviano josephdviviano self-assigned this Oct 3, 2025
@josephdviviano josephdviviano added the enhancement New feature or request label Oct 3, 2025
@younik (Collaborator) left a comment:

I am not able to review 1,000+ math-dense LOC for hypergrid.py :(
If you want a careful review, consider splitting this.


```python
        self._n_modes_via_ids_estimate = float(torch.unique(ids).numel())
        self._mode_stats_kind = "approx"
    except Exception:
        warnings.warn("+ Warning: Failed to compute mode_stats (skipping).")
```
Collaborator:

Better to use logger.exception here, to print the exception as well.

Also, it would be better to avoid catching Exception in general. Why can this fail?
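A minimal sketch of that pattern (the helper name and the narrow TypeError choice are illustrative; the real code computes torch.unique over mode ids):

```python
import logging

logger = logging.getLogger(__name__)

def compute_mode_stats(ids):
    """Hypothetical helper: log the traceback instead of a bare warning."""
    try:
        return float(len(set(ids)))
    except TypeError:  # catch only the narrow error we actually expect
        logger.exception("Failed to compute mode_stats (skipping).")
        return None
```

logger.exception records the full traceback at ERROR level, so the failure stays visible without re-raising.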

Collaborator:

This would catch the ValueError in the "exact" branch as well. Is this what we want? Should we catch at all?

Comment on lines +564 to +570
```python
# Cheap exact threshold (up to ~200k states)
if self.n_states <= 200_000:
    axes = [
        torch.arange(self.height, dtype=torch.long) for _ in range(self.ndim)
    ]
    grid = torch.cartesian_prod(*axes)
    rewards = self.reward_fn(grid)
```
Collaborator:

How did you come up with this number? Doing the cartesian product seems memory-intensive.

Collaborator (Author):

This number might need to be lowered; it was arbitrary.
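For reference, a back-of-envelope helper (hypothetical, not in the PR) for sizing the exact branch: torch.cartesian_prod over the full grid materializes an (n_states, ndim) int64 tensor.

```python
def cartesian_grid_bytes(height: int, ndim: int) -> int:
    """Bytes for the (n_states, ndim) int64 grid tensor, 8 bytes per entry."""
    n_states = height ** ndim
    return n_states * ndim * 8
```

At the 200_000-state threshold with ndim=4, the grid itself is 200_000 * 4 * 8 = 6.4 MB, which is modest; peak memory is dominated by whatever reward_fn allocates on top.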

Comment on lines +572 to +574
```python
except Exception:
    # Fall back to heuristic paths below
    pass
```
Collaborator:

Maybe add a logger. In general, I don't think it is a good idea to mask this much from the user: sometimes we compute exact mode existence, sometimes we use a heuristic.

Collaborator (Author):

Yes, agreed.

Comment on lines +849 to +874
```python
for col in range(m):
    # Find pivot
    piv = None
    for r in range(row, k):
        if A[r, col]:
            piv = r
            break
    if piv is None:
        continue
    # Swap
    if piv != row:
        A[[row, piv]] = A[[piv, row]]
        c[[row, piv]] = c[[piv, row]]
    # Eliminate below
    for r in range(row + 1, k):
        if A[r, col]:
            A[r, :] ^= A[row, :]
            c[r] ^= c[row]
    row += 1
    if row == k:
        break
# Check for inconsistency: 0 = 1 rows
for r in range(k):
    if not A[r, :].any() and c[r]:
        return False
return True
```
Collaborator:

I didn't check the details, to be honest, but it seems quite inefficient and not easily readable. Can we rely on scipy for this?

https://stackoverflow.com/questions/15638650/is-there-a-standard-solution-for-gauss-elimination-in-python

Collaborator (Author):

I'll look into it.
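One caveat worth noting: scipy's dense solvers work over the reals, so they cannot replace mod-2 elimination directly; the third-party `galois` package provides GF(2) linear algebra, or a compact numpy version can do the job. A hedged sketch (function names are mine, not the PR's) using the standard criterion that A x = c is solvable over GF(2) iff rank(A) equals rank of the augmented matrix:

```python
import numpy as np

def gf2_rank(M: np.ndarray) -> int:
    """Rank of a 0/1 matrix over GF(2) via XOR elimination."""
    M = M.astype(np.uint8) % 2
    row = 0
    for col in range(M.shape[1]):
        pivots = np.nonzero(M[row:, col])[0]
        if pivots.size == 0:
            continue
        r = row + pivots[0]
        M[[row, r]] = M[[r, row]]       # swap pivot row into place
        mask = M[:, col].astype(bool)
        mask[row] = False
        M[mask] ^= M[row]               # XOR pivot row into all other 1-rows
        row += 1
        if row == M.shape[0]:
            break
    return row

def gf2_has_solution(A: np.ndarray, c: np.ndarray) -> bool:
    """A x = c is solvable over GF(2) iff rank(A) == rank([A | c])."""
    A = np.asarray(A)
    aug = np.hstack([A, np.asarray(c).reshape(-1, 1)])
    return gf2_rank(A) == gf2_rank(aug)
```

The vectorized `M[mask] ^= M[row]` eliminates a whole column in one numpy operation, which avoids the nested Python loops the reviewer flagged.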

```python
"""
with torch.no_grad():
    device = torch.device("cpu")
    B = min(2048, max(128, 8 * self.ndim))
```
Collaborator:

What are these numbers? Maybe use named constants to improve clarity.
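A hedged sketch of the named-constants suggestion (the constant names and rationale comments are mine, mirroring the `min(2048, max(128, 8 * self.ndim))` expression above):

```python
# Hypothetical names for the values in min(2048, max(128, 8 * ndim)).
MIN_SAMPLE_BATCH = 128   # floor: keep batches large enough to amortize overhead
MAX_SAMPLE_BATCH = 2048  # ceiling: bound peak memory
SAMPLES_PER_DIM = 8      # scale the batch with dimensionality

def sample_batch_size(ndim: int) -> int:
    return min(MAX_SAMPLE_BATCH, max(MIN_SAMPLE_BATCH, SAMPLES_PER_DIM * ndim))
```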

Comment on lines +470 to 488
```python
try:
    all_states = self.all_states
    if all_states is not None:
        mask = self.mode_mask(all_states)
        ids = self.mode_ids(all_states)
        ids = ids[mask]
        ids = ids[ids >= 0]
        return int(torch.unique(ids).numel())
except Exception:
    pass
if self._mode_stats_kind == "exact" and self._n_modes_via_ids_exact is not None:
    return int(self._n_modes_via_ids_exact)
if (
    self._mode_stats_kind == "approx"
    and self._n_modes_via_ids_estimate is not None
):
    return int(self._n_modes_via_ids_estimate)

return 2**self.ndim
```
Collaborator:

Do we need to recompute this every time?

Collaborator (Author):

No, you're right; it should be stored.
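Caching could be as simple as functools.cached_property; a minimal sketch (class and attribute names are illustrative, not the PR's):

```python
from functools import cached_property

class ModeStats:
    """Hypothetical sketch: compute the mode count once and cache it."""

    def __init__(self, ids):
        self._ids = ids

    @cached_property
    def n_modes(self) -> int:
        # Runs on first access only; subsequent reads hit the cached value.
        return len(set(self._ids))
```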

Comment on lines +478 to +479
```python
except Exception:
    pass
```
Collaborator:

Similar to the other comment: silently swallowing the exception is not nice for debuggability.

@josephdviviano (Collaborator, Author)

Hi @younik - I hear you, this is a big PR. The "splits" would have to be along tasks, though, so the resulting PRs would still be large.

I appreciate your comments on the code. I think it would make sense to also look at the tasks (the stuff that's plotted in the notebook) to see if they make sense. I'm not convinced by all of the tasks.

I would be open to removing a task or two. I think the one that works best for its intended purpose is the coprime reward.

@hyeok9855 (Collaborator)

In the above commit, I fixed the comments of Deceptive Reward and also fixed a pyright error.

@hyeok9855 (Collaborator)

> I would be open to removing a task or two. I think the one that works best for its intended purpose is the coprime reward.

I do think the Template Minkowski and Bitwise/XOR rewards are not very interesting to benchmark, especially if you care about mode coverage. Multiplicative/Coprime seems challenging, but you may want to increase the reward for modes farther from the origin.

@saleml (Collaborator) left a comment:

This is high-quality research code with excellent mathematical foundations and thorough testing. The main concerns are:

  • Complexity barrier for new users
  • Performance documentation gaps
  • Some missing edge-case handling

A few questions and suggestions

  • The new reward functions are mathematically sophisticated (GF(2) algebra, prime factorization, etc.). While excellent for research, the barrier to entry is high
    Suggestion: Add a "Quick Start" section to the documentation showing simple use cases before diving into the mathematical details.
  • The _solve_gf2_has_solution method uses Gaussian elimination which could be slow for large constraint systems
    Suggestion: Add performance warnings in docstrings
  • Why GF(2)? The choice is elegant but not obvious. Could you add a paragraph in the documentation explaining why linear algebra over GF(2) is natural for compositional structure?
  • What happens if a user picks "impossible" preset with ndim=2, height=16? Should the factory functions validate compatibility?
  • Can you consider adding type hints for kwargs? Something like:

```python
class BitwiseXORRewardKwargs(TypedDict, total=False):
    R0: float
    tier_weights: list[float]
    dims_constrained: list[int]
    bits_per_tier: list[tuple[int, int]]
    parity_checks: list[dict] | None
```
  • What do you think of adding visualization helpers? Something like:

```python
def visualize_mode_structure(env: HyperGrid, sample_size: int = 10000):
    """Generate 2D/3D plots of mode distribution."""
    # Auto-generate plots similar to notebook but as API
```

I'd also like to suggest a structural consideration:
The original HyperGrid has become a pedagogical entrypoint of the GFlowNets library:

  • It's the first environment new users encounter
  • Its simplicity (grid + distance-based reward) makes it ideal for teaching core concepts
  • Tutorial code often uses it as the "Hello World" of GFlowNets
  • The cognitive load is intentionally minimal: "navigate a grid, reach high-reward corners"

Can we consider creating a separate file src/gfn/gym/compositional_hypergrid.py that:

  • Inherits from HyperGrid to reuse the core grid mechanics
  • Houses the new reward families (BitwiseXOR, MultiplicativeCoprime, TemplateMinkowski)
  • Includes the sophisticated mode validation and statistics machinery
  • Keeps the original simple and focused on accessibility

```python
        mode_stats_samples: Number of random samples used when
            `mode_stats="approx"`.
    """
    if height <= 4:
```
Collaborator:

This was removed, but the condition is still relevant. Should this warning be reinstated, or is it now handled by validate_modes?
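If the guard is reinstated, it might look like this (the function name and min_height cutoff are illustrative; the original condition was `height <= 4`):

```python
import warnings

def check_min_height(height: int, min_height: int = 5) -> None:
    """Hypothetical guard: warn when the grid may be too small for distinct modes."""
    if height < min_height:
        warnings.warn(
            f"height={height} is below {min_height}: modes may coincide; "
            "run the mode-verification logic to confirm this configuration."
        )
```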

```python
ax = (idx / Hm1 - 0.5).abs()
pdf = (1.0 / sqrt(2 * pi)) * torch.exp(-0.5 * (5 * ax) ** 2)
per_dim_discrete = float(((torch.cos(50 * ax) + 1.0) * pdf).max())
per_dim_base = per_dim_discrete if self.height > 4 else per_dim_peak
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Magic number: height > 4 threshold. Should this be documented or parameterized?

```python
    return bool((rr >= thr - EPS_REWARD_CMP).any().item())

@staticmethod
def _solve_gf2_has_solution(A: torch.Tensor, c: torch.Tensor) -> bool:
```
Collaborator:

Could you explain the GF(2) algorithm in the docstring? Also, this could be slow for large constraint systems.
Suggestion: add a complexity note in the docstring (O(k·m²) for a k×m matrix).
