Description
Describe the bug
TL;DR:
In the UNet1DModel constructor (line 104, seen here), the embedding size passed to GaussianFourierProjection is hardcoded to 8 for no apparent reason. It should instead be block_out_channels[0].
Explanation:
The embedding_size argument is passed to GaussianFourierProjection and determines the output dimension of self.time_proj. When timestep embeddings are enabled, the output of self.time_proj is fed into the timestep-embedding MLP. The input dimensionality of that feed-forward network is defined as 2 * block_out_channels[0], but the input actually fed into it always has dimension 2 * self.time_proj.embedding_dim, i.e. 2 * 8 = 16 because of the hardcoded value. With the default config (block_out_channels[0] = 32) the MLP therefore expects 64 features but receives 16, which is exactly the 1x16 vs 64x128 mismatch shown in the logs below. You can see that the positional time embedding initialized just below is sized from block_out_channels[0]; only the Gaussian Fourier path is hardcoded. I think this is a very simple, easily fixable bug.
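To make the mismatch concrete, here is a quick inspection of a freshly constructed model (a minimal sketch; the printed sizes assume the default config, where block_out_channels[0] = 32):

import torch
import diffusers

unet = diffusers.UNet1DModel(use_timestep_embedding=True, act_fn='silu')
# The Gaussian Fourier projection emits 2 * 8 = 16 features because of the hardcoded embedding size...
print(unet.time_proj(torch.tensor([0.0])).shape)  # torch.Size([1, 16])
# ...but the timestep-embedding MLP expects 2 * block_out_channels[0] input features.
print(unet.time_mlp.linear_1.in_features)         # 64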
Reproduction
import diffusers
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
unet = diffusers.UNet1DModel(use_timestep_embedding=True, act_fn='silu').to(device)
unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)
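As a sanity check on the proposed fix, swapping in a projection sized from block_out_channels[0] after construction makes the time-conditioning path line up (a hedged sketch of the idea, not the actual patch; it assumes GaussianFourierProjection accepts the same keyword arguments that unet_1d.py passes to it):

import torch
import diffusers
from diffusers.models.embeddings import GaussianFourierProjection

unet = diffusers.UNet1DModel(use_timestep_embedding=True, act_fn='silu')
# Hypothetical post-hoc swap: size the projection from block_out_channels[0] instead of the hardcoded 8.
unet.time_proj = GaussianFourierProjection(
    embedding_size=unet.config.block_out_channels[0],
    set_W_to_weight=False,
    log=False,
    flip_sin_to_cos=unet.config.flip_sin_to_cos,
)
proj = unet.time_proj(torch.tensor([0.0]))  # now (1, 64), matching time_mlp.linear_1.in_features
emb = unet.time_mlp(proj)                   # no longer raises
print(proj.shape, emb.shape)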
Logs
RuntimeError Traceback (most recent call last)
Cell In[272], line 1
----> 1 unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/unets/unet_1d.py:228, in UNet1DModel.forward(self, sample, timestep, return_dict)
226 timestep_embed = self.time_proj(timesteps)
227 if self.config.use_timestep_embedding:
--> 228 timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
229 else:
230 timestep_embed = timestep_embed[..., None]
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/embeddings.py:1308, in TimestepEmbedding.forward(self, sample, condition)
1306 if condition is not None:
1307 sample = sample + self.cond_proj(condition)
-> 1308 sample = self.linear_1(sample)
1310 if self.act is not None:
1311 sample = self.act(sample)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
124 def forward(self, input: Tensor) -> Tensor:
--> 125 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 64x128)
System Info
- 🤗 Diffusers version: 0.34.0
- Platform: Linux-6.6.56+-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.12.11
- PyTorch version (GPU?): 2.5.1.post303 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.33.1
- Transformers version: not installed
- Accelerate version: not installed
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: Tesla T4, 15360 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
No response