Bria fibo #12545
base: main
Conversation
Thanks a lot for the PR. Excited for FIBO to make strides!
I have left a bunch of comments, most of which should be easily resolvable. If not, please let me know.
Additionally, I think:
- It'd be nice to include a code snippet for folks to test it out (@linoytsaban @asomoza); a rough placeholder sketch follows this list.
- Remove the custom block implementations from the PR, host them on the Hub (just like this one), and guide users on how to use them alongside the pipeline.
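Something along these lines could work as a starting point; the checkpoint id and the call arguments below are placeholders rather than the names actually introduced by this PR:

```python
# Placeholder sketch of a quick test snippet. The repo id "briaai/FIBO" and the call
# arguments are assumptions, not the final names from this PR.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("briaai/FIBO", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A photo of a coastal village at golden hour",
    num_inference_steps=30,
).images[0]
image.save("fibo_sample.png")
```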
```bash
hf auth login
```
Feel free to talk a little more about how "control" is interfaced in the pipeline, i.e., what users can do with the pipeline to take "control".
```diff
@@ -0,0 +1,446 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
```
Feel free to add the licensing header.
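For reference, new files in the library usually carry the standard Apache-2.0 header below (year and copyright holder adjusted as appropriate); this is a sketch of the expected boilerplate, not the exact text from this PR:

```python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```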
```python
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZeroSingle
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...models.transformers.transformer_flux import FluxTransformerBlock
```
@DN6, is this a pattern we want to avoid? 👀
Yup. It would be ideal if we could use `# Copied from` and define a `BriaTransformerBlock` inside this file.
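As an illustration, the sketch below shows the intended pattern; the `# Copied from` target path is an assumption based on the imports quoted above:

```python
# Sketch only: the body would be duplicated verbatim from FluxTransformerBlock so the
# `make fix-copies` tooling can keep the two classes in sync.
import torch.nn as nn


# Copied from diffusers.models.transformers.transformer_flux.FluxTransformerBlock with FluxTransformerBlock->BriaTransformerBlock
class BriaTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int):
        super().__init__()
        # ... same implementation as FluxTransformerBlock ...
```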
```python
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
assert max_diff > 1e-6

def test_image_output_shape(self):
```
Seems like we're only testing for a single image pair. Is that expected for this test?
```python
processor = BriaAttnProcessor()

self.attn = Attention(
```
Can be done in a follow-up, but we're moving towards defining all components of a model within a single file (with some exceptions for timestep embeddings and norms). This means defining a dedicated attention class per model, e.g. `BriaAttention`.
If it's the same as `FluxAttention`, we can use `# Copied from`.
Reference:
```python
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
```
```python
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
latents_scaled = torch.cat(latents_scaled, dim=0)
image = []
for scaled_latent in latents_scaled:
```
Think we can just use `self.vae.decode()` on the latents here directly. Instance-level decoding can be done by calling `pipe.vae.enable_slicing()`.
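A rough sketch of that suggestion, keeping the variable names from the quoted snippet (the post-processing step is an assumption about the surrounding pipeline code):

```python
# Scale the whole batch at once instead of looping over individual latents.
latents_scaled = torch.cat([latent / latents_std + latents_mean for latent in latents], dim=0)

# Decode in a single call; per-latent decoding can be delegated to the VAE itself
# via `pipe.vae.enable_slicing()` when memory is a concern.
image = self.vae.decode(latents_scaled, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)  # assumed helper
```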
```python
return noise_scheduler, timesteps, num_inference_steps, mu

@staticmethod
def create_attention_matrix(attention_mask):
```
Could we name this `_prepare_attention_mask` for consistency with other pipelines?
```python
def _prepare_attention_mask(
```
```python
return latents, latent_image_ids

@staticmethod
def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
```
Possible to just have these steps in the `__call__` method? Similar to `diffusers/src/diffusers/pipelines/flux/pipeline_flux.py`, lines 868 to 887 (at dc6bd15):
```python
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
    sigmas = None
image_seq_len = latents.shape[1]
mu = calculate_shift(
    image_seq_len,
    self.scheduler.config.get("base_image_seq_len", 256),
    self.scheduler.config.get("max_image_seq_len", 4096),
    self.scheduler.config.get("base_shift", 0.5),
    self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
    num_inference_steps,
    device,
    sigmas=sigmas,
    mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
```
```python
def init_inference_scheduler(height, width, device, image_seq_len, num_inference_steps=1000, noise_scheduler=None):
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

    assert height % 16 == 0 and width % 16 == 0
```
Checks like this should be placed under `check_inputs`.
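For example, something along these lines (the signature is an assumption; pipelines generally raise `ValueError` from `check_inputs` instead of using asserts):

```python
# Sketch only: the argument list is an assumption about the final pipeline API.
def check_inputs(self, prompt, height, width):
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
```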
What does this PR do?
Fixes # (issue)
Before submitting, see the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.