Bug in initialization of UNet1DModel GaussianFourier time projection #12110

@SammyAgrawal

Description

Describe the bug

TL;DR:
In the UNet1DModel constructor (line 104 of unet_1d.py), the embedding size of the Gaussian Fourier time projection is hardcoded to 8 with no apparent justification. It should instead be block_out_channels[0].
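A minimal sketch of the shape bookkeeping behind the proposed fix. The helper fourier_time_dims is hypothetical (it is not diffusers code) and only restates the arithmetic described in this issue; block_out_channels=(32, 32, 64) is assumed to be the UNet1DModel default.

```python
# Hypothetical helper (not part of diffusers): restates the widths involved in
# the "fourier" branch of the UNet1DModel constructor as described in this issue.
def fourier_time_dims(block_out_channels, embedding_size):
    # GaussianFourierProjection concatenates sin and cos features.
    proj_out = 2 * embedding_size
    # The time MLP's first Linear layer is built for this many input features.
    mlp_in = 2 * block_out_channels[0]
    return proj_out, mlp_in


block_out_channels = (32, 32, 64)  # assumed UNet1DModel default

# Current constructor: embedding_size hardcoded to 8.
print(fourier_time_dims(block_out_channels, embedding_size=8))  # (16, 64): mismatch
# Proposed fix: embedding_size=block_out_channels[0].
print(fourier_time_dims(block_out_channels, embedding_size=block_out_channels[0]))  # (64, 64)
```

Because the MLP input width already scales with block_out_channels[0], tying embedding_size to the same value keeps the two in sync for any channel configuration.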

Explanation:
The embedding_size argument is passed to GaussianFourierProjection and determines the output dimension of self.time_proj. Because the projection concatenates sine and cosine features, its output always has 2 * embedding_size = 16 features. When use_timestep_embedding=True, this output is fed into the timestep-embedding MLP (self.time_mlp), whose input dimension is defined as 2 * block_out_channels[0]. With the default block_out_channels, that is 2 * 32 = 64, which produces the shape error in the logs below (1x16 vs. 64x128). By contrast, the positional time embedding initialized just below it is sized from block_out_channels[0]; only the Gaussian Fourier branch is hardcoded. This should be a simple, easily fixable bug.
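To show why the projection yields 2 * embedding_size features, here is a dependency-free sketch of a Gaussian Fourier feature projection. gaussian_fourier_features is a hypothetical stand-in modeled on the usual construction (fixed random frequencies, sines and cosines concatenated, no log transform), not the diffusers implementation, and block_out_channels=(32, 32, 64) is again assumed to be the default.

```python
import math
import random


def gaussian_fourier_features(t, embedding_size, scale=1.0, seed=0):
    """Hypothetical stand-in: project a scalar timestep t onto random frequencies.

    Returns the sines followed by the cosines, so the result has
    2 * embedding_size entries.
    """
    rng = random.Random(seed)
    # Fixed random frequencies drawn from N(0, scale**2).
    weights = [rng.gauss(0.0, scale) for _ in range(embedding_size)]
    x_proj = [2 * math.pi * w * t for w in weights]
    return [math.sin(x) for x in x_proj] + [math.cos(x) for x in x_proj]


block_out_channels = (32, 32, 64)  # assumed UNet1DModel default
mlp_in_features = 2 * block_out_channels[0]  # width the time MLP is built for

hardcoded = gaussian_fourier_features(t=0.5, embedding_size=8)
fixed = gaussian_fourier_features(t=0.5, embedding_size=block_out_channels[0])

print(len(hardcoded), mlp_in_features)  # 16 64 -> the 1x16 vs 64x128 mismatch
print(len(fixed), mlp_in_features)      # 64 64 -> shapes agree
```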

Reproduction

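import diffusers
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Defaults: time_embedding_type="fourier", block_out_channels=(32, 32, 64)
unet = diffusers.UNet1DModel(use_timestep_embedding=True, act_fn='silu').to(device)
# Raises RuntimeError: the time MLP expects 64 input features but receives 16
unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)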

Logs

RuntimeError                              Traceback (most recent call last)
Cell In[272], line 1
----> 1 unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/unets/unet_1d.py:228, in UNet1DModel.forward(self, sample, timestep, return_dict)
    226 timestep_embed = self.time_proj(timesteps)
    227 if self.config.use_timestep_embedding:
--> 228     timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
    229 else:
    230     timestep_embed = timestep_embed[..., None]

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/embeddings.py:1308, in TimestepEmbedding.forward(self, sample, condition)
   1306 if condition is not None:
   1307     sample = sample + self.cond_proj(condition)
-> 1308 sample = self.linear_1(sample)
   1310 if self.act is not None:
   1311     sample = self.act(sample)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
    124 def forward(self, input: Tensor) -> Tensor:
--> 125     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 64x128)

System Info


  • 🤗 Diffusers version: 0.34.0
  • Platform: Linux-6.6.56+-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.12.11
  • PyTorch version (GPU?): 2.5.1.post303 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.33.1
  • Transformers version: not installed
  • Accelerate version: not installed
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: Tesla T4, 15360 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no


Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)
