
Conversation

@galbria (Contributor) commented on Oct 26, 2025

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul (Member) left a comment

Thanks a lot for the PR. Excited for FIBO to make strides!

I have left a bunch of comments, most of which should be easily resolvable. If not, please let me know.

Additionally, I think:

  • It'd be nice to include a code snippet for folks to test it out (@linoytsaban @asomoza); a rough usage sketch follows below.
  • Remove the custom block implementations from the PR, host them on the Hub (just like this one), and guide users on how to use them alongside the pipeline.

```bash
hf auth login
```
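
For instance, a quick-start snippet along these lines could work (the checkpoint id and call arguments below are placeholders, not necessarily what this PR exposes):

```python
# Hypothetical quick-start sketch: repo id and arguments are placeholders and
# may not match what this PR actually ships.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "briaai/FIBO",  # placeholder repo id; run `hf auth login` first if the checkpoint is gated
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    num_inference_steps=30,
    height=1024,
    width=1024,
).images[0]
image.save("fibo_sample.png")
```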

Member:

Feel free to talk a little more about how "control" is interfaced in the pipeline, i.e., what users can do with the pipeline to take "control".

@@ -0,0 +1,446 @@
from typing import Any, Dict, List, Optional, Tuple, Union
Member:

Feel free to add the licensing header.
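
For reference, new diffusers files typically start with the Apache 2.0 header, roughly like this (the exact copyright line may differ for this contribution):

```python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```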

from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...models.transformers.transformer_flux import FluxTransformerBlock
Member:

@DN6, is this a pattern we want to avoid? 👀

Collaborator:

Yup. It would be ideal if we could use `# Copied from` and define a BriaTransformerBlock inside this file.
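
A rough sketch of the direction (the target path and the `with Flux->Bria` substitution are illustrative):

```python
import torch.nn as nn

# Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock with Flux->Bria
class BriaTransformerBlock(nn.Module):
    # Body duplicated from FluxTransformerBlock; the `# Copied from` marker
    # lets the copy-checking tooling (`make fix-copies`) keep both in sync.
    ...
```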

max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6

def test_image_output_shape(self):
Member:

Seems like we're only testing for a single (height, width) pair. Is that expected for this test?
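
For instance, the test could sweep a few (height, width) pairs; a sketch assuming the usual dummy-component helpers from the pipeline test mixin:

```python
def test_image_output_shape(self):
    # Assumes the standard get_dummy_components / get_dummy_inputs helpers and
    # torch_device from diffusers' pipeline test utilities, with output_type="np".
    pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
    inputs = self.get_dummy_inputs(torch_device)

    # Check more than one resolution instead of a single pair.
    for height, width in [(32, 32), (32, 64), (64, 32)]:
        inputs.update({"height": height, "width": width})
        image = pipe(**inputs).images[0]
        assert image.shape == (height, width, 3)
```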

@sayakpaul requested a review from DN6 on Oct 27, 2025, 08:30

processor = BriaAttnProcessor()

self.attn = Attention(
Collaborator:

Can be done in a follow-up, but we're moving towards defining all components of a model within a single file (with some exceptions for timestep embeddings and norms). This means defining a dedicated Attention class per model, e.g. BriaAttention.

If it's the same as FluxAttention, we can use `# Copied from`.

Reference:

class FluxAttention(torch.nn.Module, AttentionModuleMixin):
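
An illustrative sketch (the import path for AttentionModuleMixin and the rename substitution are assumptions):

```python
import torch
from diffusers.models.attention import AttentionModuleMixin  # import path may differ across diffusers versions

# Copied from diffusers.models.transformers.transformer_flux.FluxAttention with Flux->Bria
class BriaAttention(torch.nn.Module, AttentionModuleMixin):
    # Same implementation as FluxAttention, but defined in the Bria transformer
    # file so the model stays self-contained.
    ...
```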

latents_scaled = [latent / latents_std + latents_mean for latent in latents]
latents_scaled = torch.cat(latents_scaled, dim=0)
image = []
for scaled_latent in latents_scaled:
Collaborator:

I think we can just use self.vae.decode on the latents here directly. Instance-level (sliced) decoding can be enabled by setting pipe.vae.enable_slicing().
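
Roughly, assuming the variable names from the surrounding PR code and the usual image_processor attribute (exact scaling/unpacking may differ):

```python
# Un-scale the whole latent batch at once instead of looping per sample.
latents = latents / latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

# Memory-constrained users can opt into per-sample decoding via:
# pipe.vae.enable_slicing()
```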

return noise_scheduler, timesteps, num_inference_steps, mu

@staticmethod
def create_attention_matrix(attention_mask):
Collaborator:

Could we name this _prepare_attention_mask for consistency with other pipelines?

def _prepare_attention_mask(

return latents, latent_image_ids

@staticmethod
def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
Collaborator:

Would it be possible to just have these steps in the __call__ method? Similar to:

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
    sigmas = None
image_seq_len = latents.shape[1]
mu = calculate_shift(
    image_seq_len,
    self.scheduler.config.get("base_image_seq_len", 256),
    self.scheduler.config.get("max_image_seq_len", 4096),
    self.scheduler.config.get("base_shift", 0.5),
    self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
    num_inference_steps,
    device,
    sigmas=sigmas,
    mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

assert height % 16 == 0 and width % 16 == 0
Collaborator:

Checks should be placed under check_inputs
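
A sketch of the shape this could take (the signature is illustrative; the actual check_inputs may take more arguments):

```python
def check_inputs(self, prompt, height, width):
    # Replace the bare assert with a descriptive error inside check_inputs.
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by 16 but are {height} and {width}."
        )
```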
