New harder task #405
Conversation
younik left a comment
```python
    self._n_modes_via_ids_estimate = float(torch.unique(ids).numel())
    self._mode_stats_kind = "approx"
except Exception:
    warnings.warn("+ Warning: Failed to compute mode_stats (skipping).")
```
Better to use `logger.exception` here, to print the exception as well.
Also, it would be better to avoid catching `Exception` in general. Why can this fail?
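A minimal sketch of what that could look like inside the existing method, assuming a module-level logger; the names `sampled_states` and the choice of `RuntimeError` are illustrative stand-ins, not the PR's actual API:

```python
import logging

logger = logging.getLogger(__name__)

try:
    # hypothetical stand-in for whatever computation actually raises
    ids = self.mode_ids(sampled_states)
    self._n_modes_via_ids_estimate = float(torch.unique(ids).numel())
    self._mode_stats_kind = "approx"
except RuntimeError:
    # Catch only what we expect to raise (not a bare Exception), so e.g. the
    # ValueError from the "exact" branch discussed below is not swallowed.
    # logger.exception logs at ERROR level and appends the full traceback.
    logger.exception("Failed to compute mode_stats; skipping.")
```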
This would catch the `ValueError` in the "exact" branch as well. Is this what we want? Should we catch at all?
```python
# Cheap exact threshold (up to ~200k states)
if self.n_states <= 200_000:
    axes = [
        torch.arange(self.height, dtype=torch.long) for _ in range(self.ndim)
    ]
    grid = torch.cartesian_prod(*axes)
    rewards = self.reward_fn(grid)
```
How did you come up with this number? Doing the cartesian product seems memory-intensive.
This number might need to be lowered; it was arbitrary.
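For reference, a back-of-the-envelope check (my own sketch, not code from this PR): `torch.cartesian_prod` materializes an `(n_states, ndim)` int64 tensor, so the threshold could be derived from a byte budget instead of being hard-coded.

```python
# At 200_000 states and ndim = 4, the enumerated grid costs about
# 200_000 * 4 * 8 bytes ≈ 6.4 MB; the reward evaluation over the full
# grid is likely the more expensive step.
MAX_GRID_BYTES = 64 * 1024 * 1024  # hypothetical 64 MB budget


def can_enumerate_exactly(n_states: int, ndim: int, itemsize: int = 8) -> bool:
    """True if the (n_states, ndim) cartesian-product tensor fits the budget."""
    return n_states * ndim * itemsize <= MAX_GRID_BYTES
```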
```python
except Exception:
    # Fall back to heuristic paths below
    pass
```
Maybe add a logger.
In general, I don't think it is a good idea to mask this much from the user: sometimes we compute the exact mode existence, sometimes we use a heuristic.
yes, agreed
```python
for col in range(m):
    # Find pivot
    piv = None
    for r in range(row, k):
        if A[r, col]:
            piv = r
            break
    if piv is None:
        continue
    # Swap
    if piv != row:
        A[[row, piv]] = A[[piv, row]]
        c[[row, piv]] = c[[piv, row]]
    # Eliminate below
    for r in range(row + 1, k):
        if A[r, col]:
            A[r, :] ^= A[row, :]
            c[r] ^= c[row]
    row += 1
    if row == k:
        break
# Check for inconsistency: 0 = 1 rows
for r in range(k):
    if not A[r, :].any() and c[r]:
        return False
return True
```
I didn't check the details, tbh, but it seems quite inefficient and not easily readable. Can we rely on SciPy for this?
I'll look into it
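One caveat: SciPy has no GF(2) linear algebra built in. An alternative (my sketch, not part of the PR) is the Rouché–Capelli criterion over GF(2): `A x = c` is solvable iff rank(A) equals rank([A | c]) mod 2, which a compact NumPy elimination can compute.

```python
import numpy as np


def gf2_rank(M: np.ndarray) -> int:
    """Rank of a 0/1 matrix over GF(2), via row reduction with XOR."""
    M = (M.astype(np.uint8) & 1).copy()
    rank = 0
    rows, cols = M.shape
    for col in range(cols):
        pivot_rows = np.nonzero(M[rank:, col])[0]
        if pivot_rows.size == 0:
            continue
        piv = rank + pivot_rows[0]
        M[[rank, piv]] = M[[piv, rank]]            # move the pivot row up
        below = np.nonzero(M[rank + 1:, col])[0] + rank + 1
        M[below] ^= M[rank]                        # clear every 1 below the pivot
        rank += 1
        if rank == rows:
            break
    return rank


def gf2_has_solution(A: np.ndarray, c: np.ndarray) -> bool:
    """A x = c is consistent over GF(2) iff augmenting c does not raise the rank."""
    return gf2_rank(A) == gf2_rank(np.column_stack([A, c]))
```

With this check, the hand-rolled elimination above would reduce to two rank computations (the tensors would need to be moved to CPU and converted with `.numpy()` first).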
| """ | ||
| with torch.no_grad(): | ||
| device = torch.device("cpu") | ||
| B = min(2048, max(128, 8 * self.ndim)) |
What are these numbers? Maybe use constants to improve clarity.
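For example (the constant names here are mine, purely illustrative):

```python
# Illustrative names for the bounds in B = min(2048, max(128, 8 * self.ndim))
MAX_SAMPLE_BATCH = 2048   # cap on the number of states sampled per batch
MIN_SAMPLE_BATCH = 128    # floor so low-dimensional grids still get a useful sample
SAMPLES_PER_DIM = 8       # scale the batch with the grid dimensionality

B = min(MAX_SAMPLE_BATCH, max(MIN_SAMPLE_BATCH, SAMPLES_PER_DIM * self.ndim))
```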
```python
try:
    all_states = self.all_states
    if all_states is not None:
        mask = self.mode_mask(all_states)
        ids = self.mode_ids(all_states)
        ids = ids[mask]
        ids = ids[ids >= 0]
        return int(torch.unique(ids).numel())
except Exception:
    pass
if self._mode_stats_kind == "exact" and self._n_modes_via_ids_exact is not None:
    return int(self._n_modes_via_ids_exact)
if (
    self._mode_stats_kind == "approx"
    and self._n_modes_via_ids_estimate is not None
):
    return int(self._n_modes_via_ids_estimate)

return 2**self.ndim
```
do we need to recompute this every time?
No, you're right, it should be stored.
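One way to store it would be a cached property (a sketch, assuming the count is stable after construction; the class body shown is only an illustrative excerpt):

```python
from functools import cached_property


class HyperGrid:  # illustrative excerpt, not the full class
    ndim: int = 4  # placeholder so the sketch is self-contained

    @cached_property
    def n_modes(self) -> int:
        """Computed on first access, then cached on the instance."""
        # In the real class this would run the exact/approx logic above;
        # the 2**ndim fallback is used here just to complete the example.
        return 2 ** self.ndim
```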
```python
except Exception:
    pass
```
Similar to the other comment: this is not nice for debuggability.
Hi @younik - I hear you, this is a big PR. The "splits" would have to be along tasks, though, so the resulting PRs would still be large. I appreciate your comments on the code. I think it would also make sense to look at the tasks (the stuff that's plotted in the notebook) to see if they make sense. I'm not convinced by all of the tasks and would be open to removing a task or two. I think the one that works best for its intended purpose is the coprime reward.
In the above commit, I fixed the comments of the Deceptive Reward and also fixed a pyright error.
I do think
saleml left a comment
This is high-quality research code with excellent mathematical foundations and thorough testing. The main concerns are:
- Complexity barrier for new users
- Performance documentation gaps
- Some missing edge-case handling
A few questions and suggestions:
- The new reward functions are mathematically sophisticated (GF(2) algebra, prime factorization, etc.). While excellent for research, the barrier to entry is high.
  Suggestion: Add a "Quick Start" section to the documentation showing simple use cases before diving into the mathematical details.
- The `_solve_gf2_has_solution` method uses Gaussian elimination, which could be slow for large constraint systems.
  Suggestion: Add performance warnings in the docstrings.
- Why GF(2)? The choice is elegant but not obvious. Could you add a paragraph in the documentation explaining why linear algebra over GF(2) is natural for compositional structure?
- What happens if a user picks the "impossible" preset with ndim=2, height=16? Should the factory functions validate compatibility?
- Can you consider adding type hints for kwargs? Something like:
```python
class BitwiseXORRewardKwargs(TypedDict, total=False):
    R0: float
    tier_weights: list[float]
    dims_constrained: list[int]
    bits_per_tier: list[tuple[int, int]]
    parity_checks: list[dict] | None
```
- What do you think of adding visualization helpers? Something like:
```python
def visualize_mode_structure(env: HyperGrid, sample_size: int = 10000):
    """Generate 2D/3D plots of mode distribution."""
    # Auto-generate plots similar to notebook but as API
```
I'd also like to suggest a structural consideration:
The original HyperGrid has become a pedagogical entrypoint of the GFlowNets library:
- It's the first environment new users encounter
- Its simplicity (grid + distance-based reward) makes it ideal for teaching core concepts
- Tutorial code often uses it as the "Hello World" of GFlowNets
- The cognitive load is intentionally minimal: "navigate a grid, reach high-reward corners"
Can we consider creating a separate file src/gfn/gym/compositional_hypergrid.py (a rough skeleton is sketched after this list) that:
- Inherits from HyperGrid to reuse the core grid mechanics
- Houses the new reward families (BitwiseXOR, MultiplicativeCoprime, TemplateMinkowski)
- Includes the sophisticated mode validation and statistics machinery
- Keeps the original simple and focused on accessibility
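A rough skeleton of what that split could look like (my sketch; the import path for the existing `HyperGrid` and the `reward_family` parameter are assumptions, not part of this PR):

```python
# src/gfn/gym/compositional_hypergrid.py  (hypothetical layout)
from gfn.gym import HyperGrid  # assumed import path for the existing environment


class CompositionalHyperGrid(HyperGrid):
    """HyperGrid variant hosting the compositional reward families
    (BitwiseXOR, MultiplicativeCoprime, TemplateMinkowski) together with the
    mode-validation and mode-statistics machinery, so the base class stays a
    minimal teaching environment."""

    def __init__(self, *args, reward_family: str = "bitwise_xor", **kwargs):
        super().__init__(*args, **kwargs)
        # Dispatch to the new reward functions here; `reward_family` is an
        # illustrative parameter name only.
        self.reward_family = reward_family
```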
```python
    mode_stats_samples: Number of random samples used when
        `mode_stats="approx"`.
"""
if height <= 4:
```
This was removed, but the condition is still relevant. Should this warning be reinstated, or is it now handled by `validate_modes`?
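If it is reinstated, a minimal version inside `__init__` might look like this (the original warning text isn't shown in the diff, so the message below is only a guess at intent):

```python
import warnings

if height <= 4:  # `height` is the constructor argument shown in the diff above
    warnings.warn(
        "height <= 4 leaves very little room for well-separated modes; "
        "mode statistics may be unreliable.",
        UserWarning,
        stacklevel=2,
    )
```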
```python
ax = (idx / Hm1 - 0.5).abs()
pdf = (1.0 / sqrt(2 * pi)) * torch.exp(-0.5 * (5 * ax) ** 2)
per_dim_discrete = float(((torch.cos(50 * ax) + 1.0) * pdf).max())
per_dim_base = per_dim_discrete if self.height > 4 else per_dim_peak
```
Magic number: the `height > 4` threshold. Should this be documented or parameterized?
```python
    return bool((rr >= thr - EPS_REWARD_CMP).any().item())

@staticmethod
def _solve_gf2_has_solution(A: torch.Tensor, c: torch.Tensor) -> bool:
```
Could you explain the GF(2) algorithm in the docstring?
Also, this could be slow for large constraint systems.
Suggestion: Add a complexity note in the docstring (O(k·m²) for a k×m matrix).
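For illustration, a docstring along those lines might read as follows (the wording is mine, not from the PR; the signature is the one shown above):

```python
@staticmethod
def _solve_gf2_has_solution(A: torch.Tensor, c: torch.Tensor) -> bool:
    """Return True iff the linear system A x = c is solvable over GF(2).

    Runs Gaussian elimination with XOR row operations on the boolean matrix
    ``A`` (shape k x m) and right-hand side ``c`` (shape k). After
    elimination, the system is inconsistent exactly when some all-zero row
    of ``A`` has a corresponding entry of 1 in ``c``.

    Complexity: O(k * m^2) time in the worst case for a k x m matrix, so
    large constraint systems may be noticeably slow.
    """
    # ... existing elimination code follows ...
```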

Description