From 356b4741d99317e473d8389e25916c2fdf9dac27 Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 11:51:51 +0800
Subject: [PATCH 01/13] [Fix] Move attention mask padding after T5 embedding

---
 .../pipelines/chroma/pipeline_chroma.py       | 27 +++++++++++--------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 5482035b3afb..2b67cd20f5ce 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -216,13 +216,13 @@ def _get_t5_prompt_embeds(
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
-
+
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
-
+
         if isinstance(self, TextualInversionLoaderMixin):
             prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
+
         text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
@@ -232,21 +232,25 @@ def _get_t5_prompt_embeds(
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
+
         text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask.clone()
+        tokenizer_mask = text_inputs.attention_mask  # keep the raw tokenizer mask
 
-        # Chroma requires the attention mask to include one padding token
-        seq_lengths = attention_mask.sum(dim=1)
-        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
+        tokenizer_mask_device = tokenizer_mask.to(device)
 
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=tokenizer_mask_device,
         )[0]
 
-        dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(device=device)
+
+        seq_lengths = tokenizer_mask_device.sum(dim=1)
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(
+            batch_size, -1
+        )
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
@@ -259,6 +263,7 @@ def _get_t5_prompt_embeds(
 
         return prompt_embeds, attention_mask
 
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

From 51931edda113ebbb1963b8210bb686c5cca0c36d Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 11:55:11 +0800
Subject: [PATCH 02/13] [Fix] Move attention mask padding after T5 embedding

---
 .../chroma/pipeline_chroma_img2img.py         | 58 ++++++------------
 1 file changed, 18 insertions(+), 40 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 9afd4b9e1577..9546fa1e35ed 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -25,7 +25,6 @@
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
-    deprecate,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -247,20 +246,23 @@ def _get_t5_prompt_embeds(
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        attention_mask = text_inputs.attention_mask.clone()
+        tokenizer_mask = text_inputs.attention_mask
 
-        # Chroma requires the attention mask to include one padding token
-        seq_lengths = attention_mask.sum(dim=1)
-        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
-        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+        tokenizer_mask_device = tokenizer_mask.to(device)
 
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+            text_input_ids.to(device),
+            output_hidden_states=False,
+            attention_mask=tokenizer_mask_device,
        )[0]
 
-        dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-        attention_mask = attention_mask.to(dtype=dtype, device=device)
+
+        seq_lengths = tokenizer_mask_device.sum(dim=1)
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(
+            batch_size, -1
+        )
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
@@ -543,12 +545,6 @@ def enable_vae_slicing(self):
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
         """
-        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
-        deprecate(
-            "enable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.enable_slicing()
 
     def disable_vae_slicing(self):
@@ -556,12 +552,6 @@ def disable_vae_slicing(self):
         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
         computing decoding in one step.
         """
-        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
-        deprecate(
-            "disable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.disable_slicing()
 
     def enable_vae_tiling(self):
@@ -570,12 +560,6 @@ def enable_vae_tiling(self):
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute
         decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.
         """
-        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
-        deprecate(
-            "enable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.enable_tiling()
 
     def disable_vae_tiling(self):
@@ -583,12 +567,6 @@ def disable_vae_tiling(self):
         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
         computing decoding in one step.
         """
-        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
-        deprecate(
-            "disable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
         self.vae.disable_tiling()
 
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
@@ -749,12 +727,12 @@ def __call__(
                 Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                 will be used.
-            guidance_scale (`float`, *optional*, defaults to 3.5):
-                Guidance scale as defined in [Classifier-Free Diffusion
-                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
-                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
-                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
-                the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+                a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
+                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
             strength (`float, *optional*, defaults to 0.9):
                 Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image
                 will be used as a starting point, adding more noise to it the larger the strength. The number of denoising
@@ -769,7 +747,7 @@ def __call__(
             latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will be generated by sampling using the supplied random `generator`.
+                tensor will ge generated by sampling using the supplied random `generator`.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.

From 0debc37a413afdce788b0ad5623f24b27d14dd95 Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 11:55:35 +0800
Subject: [PATCH 03/13] Clean up whitespace in pipeline_chroma.py

Removed unnecessary blank lines for cleaner code.
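For reviewer context: after the first three patches, `_get_t5_prompt_embeds` encodes with the raw
tokenizer mask and only afterwards builds the transformer-side mask, which keeps the first padding
token visible. A minimal sketch of that second step, using made-up tensors rather than the actual
pipeline variables (shapes and lengths chosen purely for illustration):

```python
import torch

# Raw tokenizer mask for two prompts padded to length 6:
# prompt 0 has 3 real tokens, prompt 1 has 5.
tokenizer_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

# T5 runs with the raw mask; the Chroma transformer instead receives a mask
# that leaves everything up to and including the FIRST padding token unmasked.
# (The pipeline additionally casts this to the embedding dtype/device.)
seq_lengths = tokenizer_mask.sum(dim=1)  # tensor([3, 5])
mask_indices = torch.arange(tokenizer_mask.size(1)).unsqueeze(0).expand(2, -1)
attention_mask = mask_indices <= seq_lengths.unsqueeze(1)
# tensor([[ True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True,  True]])
```

So the encoder attends only to real tokens, while the downstream transformer additionally sees one
padding token — the behavior the removed comment ("Chroma requires the attention mask to include
one padding token") described.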
---
 src/diffusers/pipelines/chroma/pipeline_chroma.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 2b67cd20f5ce..a01a6c0a3398 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -216,13 +216,13 @@ def _get_t5_prompt_embeds(
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
-
+
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
-
+
         if isinstance(self, TextualInversionLoaderMixin):
             prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
+
         text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
@@ -232,7 +232,7 @@ def _get_t5_prompt_embeds(
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
-
+
         text_input_ids = text_inputs.input_ids
         tokenizer_mask = text_inputs.attention_mask  # keep the raw tokenizer mask
 
@@ -263,7 +263,6 @@ def _get_t5_prompt_embeds(
 
         return prompt_embeds, attention_mask
 
-
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

From 13529b93d68217324ccd661a17fd77c018186b52 Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 12:01:06 +0800
Subject: [PATCH 04/13] Clean up mask construction in pipeline_chroma.py

---
 src/diffusers/pipelines/chroma/pipeline_chroma.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index a01a6c0a3398..1441b7600576 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -232,9 +232,8 @@ def _get_t5_prompt_embeds(
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
-
         text_input_ids = text_inputs.input_ids
-        tokenizer_mask = text_inputs.attention_mask  # keep the raw tokenizer mask
+        tokenizer_mask = text_inputs.attention_mask
 
         tokenizer_mask_device = tokenizer_mask.to(device)
 
@@ -246,9 +246,7 @@ def _get_t5_prompt_embeds(
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         seq_lengths = tokenizer_mask_device.sum(dim=1)
-        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(
-            batch_size, -1
-        )
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
         attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape

From f819059c416be7baa7c5691a10582b757284c099 Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 12:01:38 +0800
Subject: [PATCH 05/13] Restore VAE deprecation warnings and revert docstring changes in pipeline_chroma_img2img.py

---
 .../chroma/pipeline_chroma_img2img.py         | 39 +++++++++++++++----
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 9546fa1e35ed..50fa7a5e5273 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -25,6 +25,7 @@
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
+    deprecate,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -545,6 +546,12 @@ def enable_vae_slicing(self):
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
         """
+        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+        deprecate(
+            "enable_vae_slicing",
+            "0.40.0",
+            depr_message,
+        )
         self.vae.enable_slicing()
 
     def disable_vae_slicing(self):
@@ -552,6 +559,12 @@ def disable_vae_slicing(self):
         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
         computing decoding in one step.
         """
+        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+        deprecate(
+            "disable_vae_slicing",
+            "0.40.0",
+            depr_message,
+        )
         self.vae.disable_slicing()
 
     def enable_vae_tiling(self):
@@ -560,6 +573,12 @@ def enable_vae_tiling(self):
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute
         decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images.
         """
+        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+        deprecate(
+            "enable_vae_tiling",
+            "0.40.0",
+            depr_message,
+        )
         self.vae.enable_tiling()
 
     def disable_vae_tiling(self):
@@ -567,6 +586,12 @@ def disable_vae_tiling(self):
         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
         computing decoding in one step.
         """
+        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+        deprecate(
+            "disable_vae_tiling",
+            "0.40.0",
+            depr_message,
+        )
         self.vae.disable_tiling()
 
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
@@ -727,12 +752,12 @@ def __call__(
                 Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                 will be used.
-            guidance_scale (`float`, *optional*, defaults to 5.0):
-                Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
-                a model to generate images more aligned with `prompt` at the expense of lower image quality.
-
-                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
-                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+            guidance_scale (`float`, *optional*, defaults to 3.5):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+                the text `prompt`, usually at the expense of lower image quality.
             strength (`float, *optional*, defaults to 0.9):
                 Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image
                 will be used as a starting point, adding more noise to it the larger the strength. The number of denoising
@@ -747,7 +772,7 @@ def __call__(
             latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
+                tensor will be generated by sampling using the supplied random `generator`.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.

From 4e2ca2c1d11f8542cdb9fbaf3d56c5b00afa7ec9 Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 12:09:04 +0800
Subject: [PATCH 06/13] Update model to final Chroma1-HD checkpoint

---
 src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 50fa7a5e5273..ec7de15042eb 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -53,8 +53,8 @@
         >>> import torch
         >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
 
-        >>> model_id = "lodestones/Chroma"
-        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+        >>> model_id = "lodestones/Chroma1-HD"
+        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
         >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
         ...     model_id,
         ...     transformer=transformer,
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
     r"""
     The Chroma pipeline for image-to-image generation.
 
-    Reference: https://huggingface.co/lodestones/Chroma/
+    Reference: https://huggingface.co/lodestones/Chroma1-HD/
 
     Args:
         transformer ([`ChromaTransformer2DModel`]):

From b3753bdc70f9e06ab96cb234b8130af41438b5ad Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 12:09:50 +0800
Subject: [PATCH 07/13] Update to Chroma1-HD

---
 src/diffusers/models/transformers/transformer_chroma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 5823ae9d3da6..2ef3643dafbd 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
     """
     The Transformer model introduced in Flux, modified for Chroma.
 
-    Reference: https://huggingface.co/lodestones/Chroma
+    Reference: https://huggingface.co/lodestones/Chroma1-HD
 
     Args:
         patch_size (`int`, defaults to `1`):

From 23a8518d77b505bc296a8cf759dd3796a9885cec Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 12:10:46 +0800
Subject: [PATCH 08/13] Update model to Chroma1-HD

---
 src/diffusers/pipelines/chroma/pipeline_chroma.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 1441b7600576..e8a2a90aab53 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -53,8 +53,8 @@
         >>> import torch
         >>> from diffusers import ChromaPipeline
 
-        >>> model_id = "lodestones/Chroma"
-        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+        >>> model_id = "lodestones/Chroma1-HD"
+        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
         >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
         >>> pipe = ChromaPipeline.from_pretrained(
         ...     model_id,
@@ -158,7 +158,7 @@ class ChromaPipeline(
     r"""
     The Chroma pipeline for text-to-image generation.
 
-    Reference: https://huggingface.co/lodestones/Chroma/
+    Reference: https://huggingface.co/lodestones/Chroma1-HD/
 
     Args:
         transformer ([`ChromaTransformer2DModel`]):

From a93e017d2649bc79493ad3db7e10e51d6fd4b47e Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 12:11:23 +0800
Subject: [PATCH 09/13] Update model to Chroma1-HD

---
 docs/source/en/api/models/chroma_transformer.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/en/api/models/chroma_transformer.md b/docs/source/en/api/models/chroma_transformer.md
index 681e81f7a584..1ef24cda3925 100644
--- a/docs/source/en/api/models/chroma_transformer.md
+++ b/docs/source/en/api/models/chroma_transformer.md
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
 
 # ChromaTransformer2DModel
 
-A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
+A modified Flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
 
 ## ChromaTransformer2DModel

From 5855bda20ea46849edece1f4510d5e27fc17aa49 Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 12:13:03 +0800
Subject: [PATCH 10/13] Update Chroma model links to Chroma1-HD

---
 docs/source/en/api/pipelines/chroma.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index df03fbb325d7..6bf7858b74f9 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -19,20 +19,20 @@ specific language governing permissions and limitations under the License.
 
 Chroma is a text to image generation model based on Flux.
 
-Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
+Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma1-HD).
 
 > [!TIP]
 > Chroma can use all the same optimizations as Flux.
 
 ## Inference
 
-The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
+The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma1-HD).
 
 ```python
 import torch
 from diffusers import ChromaPipeline
 
-pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
+pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
 pipe.enable_model_cpu_offload()
 
 prompt = [
@@ -63,10 +63,10 @@ Then run the following example
 import torch
 from diffusers import ChromaTransformer2DModel, ChromaPipeline
 
-model_id = "lodestones/Chroma"
+model_id = "lodestones/Chroma1-HD"
 dtype = torch.bfloat16
 
-transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
+transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
 pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
 pipe.enable_model_cpu_offload()

From 9dce45183f8d72ffa41e25441df6786e656d862e Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 16:56:12 +0800
Subject: [PATCH 11/13] Add comment about padding/masking

---
 src/diffusers/pipelines/chroma/pipeline_chroma.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index e8a2a90aab53..ed6c2c2105b6 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -237,6 +237,7 @@ def _get_t5_prompt_embeds(
 
         tokenizer_mask_device = tokenizer_mask.to(device)
 
+        # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
         prompt_embeds = self.text_encoder(
             text_input_ids.to(device),
             output_hidden_states=False,
@@ -245,6 +246,7 @@ def _get_t5_prompt_embeds(
 
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
+        # for the text tokens, Chroma requires that all padding tokens except the first are masked out during the forward pass through the transformer
         seq_lengths = tokenizer_mask_device.sum(dim=1)
         mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
         attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

From 1fe64c2f6c215d4a2d8e73a1245bf2f462fabf34 Mon Sep 17 00:00:00 2001
From: josephrocca <1167575+josephrocca@users.noreply.github.com>
Date: Sun, 19 Oct 2025 17:03:55 +0800
Subject: [PATCH 12/13] Fix checkpoint/repo references

---
 docs/source/en/api/pipelines/chroma.md | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index 6bf7858b74f9..cc52ffa09a6d 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -19,15 +19,16 @@
 
 Chroma is a text to image generation model based on Flux.
 
-Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma1-HD).
+Original model checkpoints for Chroma can be found here:
+* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
+* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
+* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
 
 > [!TIP]
 > Chroma can use all the same optimizations as Flux.
 
 ## Inference
 
-The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma1-HD).
-
 ```python
 import torch
 from diffusers import ChromaPipeline

From ac4b099d3c620a173c19e1ee7aa71bc6539d18ab Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Tue, 21 Oct 2025 05:03:24 +0000
Subject: [PATCH 13/13] Apply style fixes

---
 src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index ec7de15042eb..470c746e4146 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -260,9 +260,7 @@ def _get_t5_prompt_embeds(
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         seq_lengths = tokenizer_mask_device.sum(dim=1)
-        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(
-            batch_size, -1
-        )
+        mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
         attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape