Describe the bug
There is an issue with attention masking in the Chroma pipeline. With the prompt from your example at https://huggingface.co/docs/diffusers/main/api/pipelines/chroma, the difference is not very large, probably because there are enough meaningful tokens carrying some weight.
Short prompts, however, fail because of the incorrect masking.
Below are three sample pairs: the first pair uses the prompt from your example, the second and third use just "man" as the positive prompt (negative prompt unchanged). The first sample of each pair was generated with the current code, the second with correct masking.
The issue is the data type of the attention mask and how it is interpreted. It is created as a floating-point mask, which is fine for its first use in T5: https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5ForSequenceClassification.forward.attention_mask
There, 1.0 means not masked and 0.0 means masked.
However, the mask is then passed to the transformer in the same dtype, and it eventually ends up in https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
There, a float mask is added to the attention scores, so 0.0 means not masked and -inf means masked; a 1.0/0.0 float mask therefore masks nothing at all. (A boolean mask, by contrast, attends where True and masks where False.)
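A minimal standalone sketch (not Chroma code) of the mismatch: the same 1.0/0.0 mask that T5 accepts is treated by scaled_dot_product_attention as an additive float mask, so the padding tokens are never actually masked, while a boolean version of the mask behaves as intended.

```python
import torch
import torch.nn.functional as F

# Toy tensors: batch 1, 1 head, 4 tokens, head_dim 8.
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Tokenizer-style mask: two real tokens followed by two padding tokens,
# broadcast to the (batch, heads, q_len, kv_len) shape SDPA expects.
float_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])[:, None, None, :]
bool_mask = float_mask.bool()

# Float mask: values are *added* to the attention scores (+1.0 / +0.0),
# so the padding tokens still take part in attention.
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# Bool mask: True = attend, False = ignore, so padding is actually excluded.
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

print(torch.allclose(out_float, out_bool))  # False: the outputs differ
```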
I generated the good samples below by changing the dtype here to torch.bool:
attention_mask = attention_mask.to(dtype)
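In other words, the workaround amounts to changing that line to:
attention_mask = attention_mask.to(torch.bool)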
But a good solution should probably convert the tokenizer output to bool directly rather than taking the detour via a floating-point type; a rough sketch follows.
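Sketch of what that could look like, assuming a plain T5 tokenizer call; the checkpoint name, prompt and max length are placeholders rather than the pipeline's actual values:

```python
from transformers import AutoTokenizer

# Placeholders; the Chroma pipeline holds its own tokenizer and settings.
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")
prompt = "man"
max_sequence_length = 512

text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=max_sequence_length,
    truncation=True,
    return_tensors="pt",
)

# T5 is happy with the 1/0 integer mask (1 = keep, 0 = padding) ...
t5_attention_mask = text_inputs.attention_mask

# ... but everything downstream that feeds scaled_dot_product_attention
# should receive a boolean mask: True = attend, False = masked out.
transformer_attention_mask = text_inputs.attention_mask.bool()
```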
Reproduction
Run this example, but with a short prompt:
https://huggingface.co/docs/diffusers/main/api/pipelines/chroma
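For convenience, a hedged sketch of that reproduction; the checkpoint id, negative prompt and seed below are placeholders (take the real values from the linked docs page), and only the short positive prompt matters:

```python
import torch
from diffusers import ChromaPipeline

# Placeholder checkpoint id; use the one from the linked Chroma docs.
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="man",  # short prompt: almost all tokens are padding, so the masking bug dominates
    negative_prompt="low quality, blurry",  # placeholder; keep the docs' negative prompt
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("chroma_short_prompt.png")
```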
Logs
System Info
diffusers HEAD, python 3.11.11
Who can help?
No response