
Attention masking in Chroma pipeline #12116

@dxqbYD

Description


Describe the bug

There is an issue with attention masking in the Chroma pipeline. With the prompt from your example at https://huggingface.co/docs/diffusers/main/api/pipelines/chroma, the difference is not very large, probably because there are enough meaningful tokens carrying some weight.

But short prompts fail because of incorrect masking.

Below are 3 sample pairs: the first pair uses the prompt from your example, the second and third use just "man" as the positive prompt (negative prompt unchanged). In each pair, the first sample was generated with the current code and the second with correct masking.

The issue is the data type of the attention mask and how it is interpreted. It's created as a floating point mask, which is fine for its first use in T5: https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5ForSequenceClassification.forward.attention_mask

There, 1.0 is not masked and 0.0 is masked.

However, it is then passed to the transformer in the same dtype, and this attention mask eventually ends up at https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

There, 0.0 is not masked and -inf is masked.
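
To illustrate the mismatch, here is a small self-contained PyTorch sketch (not diffusers code): a 1.0/0.0 float mask passed to scaled_dot_product_attention is treated as an additive bias and masks nothing, while the same mask cast to bool masks the padded positions as intended.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Tokenizer-style float mask: 1.0 = keep, 0.0 = pad. SDPA treats a float mask
# as an additive bias, so 0.0 leaves the padded tokens fully visible and 1.0
# merely adds +1 to the attention logits, so nothing is actually masked.
float_mask = torch.tensor([1.0, 1.0, 0.0, 0.0]).view(1, 1, 1, 4)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# Boolean mask: True = attend, False = mask. This is what SDPA expects.
bool_mask = float_mask.bool()
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

print(torch.allclose(out_float, out_bool))  # False: the float mask masked nothing
```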

I generated the good samples below by changing the dtype at this line to torch.bool, i.e. replacing

attention_mask = attention_mask.to(dtype)

with

attention_mask = attention_mask.to(torch.bool)

But a proper fix should probably convert the tokenizer output directly to bool rather than taking a detour through a floating-point type.
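
As a rough sketch of that direction (the function and argument names below are illustrative, not the actual diffusers code): keep the tokenizer's 1/0 mask for the T5 call, and convert it directly to bool before it reaches the transformer's SDPA attention.

```python
import torch

def encode_prompt(tokenizer, text_encoder, prompt, max_sequence_length=512):
    # Illustrative sketch only, not the actual diffusers implementation.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    attention_mask = text_inputs.attention_mask  # int64 tensor of 1s and 0s

    # T5 interprets 1 = keep, 0 = pad, so the tokenizer mask is fine here.
    prompt_embeds = text_encoder(
        text_inputs.input_ids, attention_mask=attention_mask
    ).last_hidden_state

    # The transformer's SDPA attention expects a bool mask (True = attend),
    # so convert the tokenizer output directly instead of casting it to the
    # embedding dtype.
    return prompt_embeds, attention_mask.to(torch.bool)
```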

[Three image pairs: current masking vs. corrected masking for each prompt]

Reproduction

Run this example, but with a short prompt:
https://huggingface.co/docs/diffusers/main/api/pipelines/chroma
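
A minimal reproduction sketch along those lines, assuming the ChromaPipeline class from diffusers; the checkpoint id and sampling parameters below are placeholders and should be replaced with the values from the linked docs example:

```python
import torch
from diffusers import ChromaPipeline

# Placeholder checkpoint id; use the model id from the linked docs example.
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# A very short positive prompt makes the incorrect masking clearly visible.
image = pipe(
    prompt="man",
    negative_prompt="low quality, blurry",
    num_inference_steps=26,
    guidance_scale=4.0,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("man.png")
```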

Logs

System Info

diffusers HEAD, Python 3.11.11

Who can help?

No response
