Skip to content

Fix group offloading synchronization bug for parameter-only GroupModule's #12077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 6, 2025

Conversation

a-r-r-o-w
Copy link
Member

Fixes #11981.

Requires #11990 to be merged first.

code
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.hooks import apply_group_offloading, ModelHook, HookRegistry
from diffusers.models import ModelMixin
from diffusers.utils.logging import set_verbosity_debug
from torch.profiler import profile, record_function, ProfilerActivity

set_verbosity_debug()

class LayerOutputTrackerHook(ModelHook):
    def __init__(self):
        super().__init__()
        self.outputs = []
    
    def post_forward(self, module, output):
        self.outputs.append(output)
        return output


class Model(ModelMixin):
    def __init__(self, d_model=1024, num_layers=1):
        super().__init__()
        self.d_model = d_model
        
        self.input_proj = nn.Linear(1024, d_model)
        # self.norm = nn.LayerNorm(d_model, elementwise_affine=True)
        self.blocks = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
        
        # This is problematic
        self.norm = nn.LayerNorm(d_model, elementwise_affine=True)
        # This works
        # self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        
        self.output_proj = nn.Linear(d_model, 1024)
    
    def forward(self, x):
        x = self.input_proj(x)
        # x = self.norm(x)
        for block in self.blocks:
            x = block(x)
            x = F.relu(x)
        x = self.norm(x)
        x = self.output_proj(x)
        return x


def apply_layer_output_tracker_hook(model: Model):
    for name, module in model.named_modules():
        if not isinstance(module, (torch.nn.Linear, torch.nn.LayerNorm)):
            continue
        registry = HookRegistry.check_if_exists_or_initialize(module)
        hook = LayerOutputTrackerHook()
        registry.register_hook(hook, "layer_output_tracker")


def print_output_diffs(ref_model: Model, model: Model):
    for (ref_name, ref_module), (name, module) in zip(ref_model.named_modules(), model.named_modules()):
        assert ref_name == name
        if not isinstance(ref_module, (torch.nn.Linear, torch.nn.LayerNorm)):
            continue
        ref_outputs = HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
        outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
        cumulated_absmax = 0.0
        for i in range(len(outputs)):
            diff = ref_outputs[0] - outputs[i]
            absdiff = diff.abs()
            absmax = absdiff.max().item()
            cumulated_absmax += absmax
            if ref_name == "output_proj":
                print(f"{ref_name} absmax {i}: {absmax}")
        print(f"{name}: cumulated_absmax={cumulated_absmax:.5f}, num_outputs={len(outputs)}")


torch.manual_seed(42)
model_ref = Model()
model1 = Model()
model2 = Model()

model1.load_state_dict(model_ref.state_dict())
model2.load_state_dict(model_ref.state_dict())

model_ref.eval()
model1.eval()
model2.eval()

onload_device = torch.device("cuda:0")
offload_device = torch.device("cpu")

model_ref = model_ref.to(onload_device)
apply_group_offloading(
    model1,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
apply_group_offloading(
    model2,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

apply_layer_output_tracker_hook(model_ref)
apply_layer_output_tracker_hook(model1)
apply_layer_output_tracker_hook(model2)

x = torch.randn(1, 512, 1024).to("cuda")
out_ref = model_ref(x)

def compare_outputs(out1, out2):
    diff = out1 - out2
    absdiff = diff.abs()
    absmax = absdiff.max()
    mae = absdiff.mean()
    mse = (absdiff ** 2).mean()
    cossim = F.cosine_similarity(out1.flatten(), out2.flatten(), dim=0)
    print(f"{absmax=:.5f}, {mae=:.5f}, {mse=:.5f}, {cossim=:.5f}")

for _ in range(2):
    model1(x)
    print("=" * 80)
    model2(x)

do_profile = False
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
# context = profile(
#     activities=activities,
#     record_shapes=True,
#     profile_memory=True,
#     with_stack=True,
# ) if do_profile else contextlib.nullcontext()
context = contextlib.nullcontext()

with context as prof:
    with torch.inference_mode():
        for i in range(10):
            with record_function(f"model_1_run_{i}"):
                output1 = model1(x)
            print(i)
            compare_outputs(out_ref, output1)
            print()

        print("=" * 80)
        
        for i in range(10):
            with record_function(f"model_2_run_{i}"):
                output2 = model2(x)

            print(i)
            compare_outputs(out_ref, output2)
            print()


print_output_diffs(model_ref, model1)
print()
print_output_diffs(model_ref, model2)

# prof.export_chrome_trace("dump_trace.json")
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=1000))

Tested for 100 rounds with:

seq 100 | xargs -Iz python3 dump12.py

Testing with profiling is not helpful because the problem never shows up. See heisenbug thread: https://huggingface.slack.com/archives/C065E480NN9/p1754035222558869

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w requested review from sayakpaul and DN6 August 5, 2025 21:10
@a-r-r-o-w
Copy link
Member Author

cc @seed93, this seems to resolve many different tests I tried that were previously causing outputs to be different. Could you verify on your end if everything works well? Thanks 🤗

@sayakpaul
Copy link
Member

Let's quickly merge the cleaning PR so that it's easier to review this one :)

@seed93
Copy link

seed93 commented Aug 6, 2025 via email

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ridiculously simple fix yet ridiculously critical heisenbug.

Comment on lines 314 to 316
# If this group didn't onload itself, it means it was asynchronously onloaded by the
# previous group. We need to synchronize the side stream to ensure parameters
# are completely loaded to proceed with forward pass.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): It would be beneficial to comment on the consequences of not performing this synchronization.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@a-r-r-o-w
Copy link
Member Author

Sometimes the hardest of problems have the simplest solutions :)

Failing tests are unrelated

@a-r-r-o-w a-r-r-o-w merged commit 69cdc25 into main Aug 6, 2025
14 of 15 checks passed
@a-r-r-o-w a-r-r-o-w deleted the fix-group-offload-sync-bug branch August 6, 2025 15:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Groupoffloading introduce bad results
4 participants