Skip to content

Add Acelerator API to ddp_tutorial #3508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions intermediate_source/ddp_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,12 @@ be found in
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU.
acc = torch.accelerator.current_accelerator()
backend = torch.distributed.get_default_backend_for_device(acc)
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
dist.destroy_process_group()
Expand Down Expand Up @@ -216,8 +220,11 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU.
acc = torch.accelerator.current_accelerator()
# configure map_location properly
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
map_location = {f'{acc}:0': f'{acc}:{rank}'}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location, weights_only=True))

Expand Down Expand Up @@ -295,7 +302,7 @@ either the application or the model ``forward()`` method.


if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
n_gpus = torch.accelerator.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_basic, world_size)
Expand Down Expand Up @@ -331,12 +338,14 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.


def demo_basic():
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")
torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
acc = torch.accelerator.current_accelerator()
backend = torch.distributed.get_default_backend_for_device(acc)
dist.init_process_group(backend)
rank = dist.get_rank()
print(f"Start running basic DDP example on rank {rank}.")
# create model and move it to GPU with id rank
device_id = rank % torch.cuda.device_count()
device_id = rank % torch.accelerator.device_count()
model = ToyModel().to(device_id)
ddp_model = DDP(model, device_ids=[device_id])
loss_fn = nn.MSELoss()
Expand Down