Overview

This example is based on the PyTorch Multinode Training tutorial.

Test Matrix

PyTorch 2.4.1

                 GPU Device
CUDA Version     A30        L40s       H100
11.8             works      works      works
12.1             works      works      works
12.4             works      works      works
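
Before submitting a job it can be worth confirming which PyTorch and CUDA build your environment actually provides. The snippet below is a minimal check, run from a Python session inside the conda environment you plan to use; the __main__ block of multi_node.py prints the same information in more detail.

import torch

print("PyTorch:", torch.__version__)                # 2.4.1 for the matrix above
print("CUDA:", torch.version.cuda)                  # e.g. 11.8, 12.1 or 12.4
print("CUDA available:", torch.cuda.is_available())
print("NCCL:", torch.cuda.nccl.version())           # NCCL bundled with this PyTorch build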

Example

The following example is based heavily on the code within the ddp-tutorial-series GitHub repo.

You will need to update the Slurm submission script, replacing the placeholder values (project name, node and GPU counts, email address, partition, and conda environment path) to match your cluster and allocation.

 multi_node.py
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

import argparse
import os
import socket


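# Toy dataset of random (feature, target) pairs; swap in a real Dataset for actual training.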
class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]


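# torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every worker process; bind this
# process to its local GPU and join the NCCL process group.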
def ddp_setup():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    init_process_group(backend="nccl")


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

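        # Wrap with DDP only after any snapshot has been loaded, so the plain state_dict keys still match.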
        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[{socket.gethostname()}] [GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz}"
              f" | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
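            # Note: local rank 0 on every node saves; with a shared filesystem, checking
            # self.global_rank == 0 instead avoids redundant writes of the same file.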
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)


def load_train_objs():
    train_set = MyTrainDataset(2048)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
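        # DistributedSampler shards the dataset across ranks; keep shuffle=False and let the sampler shuffle via set_epoch().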
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
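    # Print per-node diagnostics (PyTorch build, CUDA, NCCL, visible GPUs) before training
    # starts; useful when debugging multi-node NCCL issues.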
    device_count = torch.cuda.device_count()
    hostname = socket.gethostname()
    device_str = ("[" + hostname + "] PyTorch Version: " + str(torch.__version__) + "\n" +
                  "Torch Distributed: " + str(torch.distributed.is_available()) + "\n" +
                  "Cuda Available: " + str(torch.cuda.is_available()) + "\n" +
                  "Cuda Version: " + str(torch.version.cuda) + "\n" +
                  "ArchList: " + "\n" + str(torch.cuda.get_arch_list()) + "\n" +
                  "NCCL: Version: " + str(torch.cuda.nccl.version()) + "\n" +
                  "Device Count: " + str(device_count))
    print(device_str)

    for device_id in range(device_count):
        device = torch.device(f"cuda:{device_id}")
        major, minor = torch.cuda.get_device_capability(device)
        gpu_str = (f"[{hostname}] Device ID: {device_id}"
                   f" Device Name: {torch.cuda.get_device_name(device_id)}\n"
                   f"  CUDA compute capability: {major}.{minor}\n"
                   f"  Properties: {torch.cuda.get_device_properties(device_id)}\n"
                   f"  NCCL Available: {torch.cuda.nccl.is_available([torch.rand(1, device=device)])}")
        print(gpu_str)

    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    main(args.save_every, args.total_epochs, args.batch_size)
 Slurm submission script
#!/bin/bash

#SBATCH --job-name=pyt-multi-node
#SBATCH --account=<project-name>
#SBATCH --time=5:00
#SBATCH --nodes=<num-of-nodes>
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-task=<num-of-gpus-per-node>
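# One Slurm task per node; torchrun (below) starts one training process per GPU inside each task.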
#SBATCH --cpus-per-task=4
#SBATCH --mail-type=ALL
#SBATCH --mail-user=<email_address>
#SBATCH --output=pyt_multi_node_%A.out
#SBATCH --partition=<gpu-partition>

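# Keep each worker to a single OpenMP thread to avoid oversubscribing the CPUs.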
export OMP_NUM_THREADS=1
# Uncomment for NCCL related logging.
# export NCCL_DEBUG=INFO

echo "SLURM_JOB_ID:" $SLURM_JOB_ID
echo "SLURM_JOB_NUM_NODES:" $SLURM_JOB_NUM_NODES
echo "SLURM_JOB_NODELIST:" $SLURM_JOB_NODELIST
echo "- - - - - - - - - - - -"
echo "SLURM_GPUS:" $SLURM_GPUS
echo "SLURM_GPUS_PER_NODE" $SLURM_GPUS_PER_NODE
echo "SLURM_GPUS_ON_NODE:" $SLURM_GPUS_ON_NODE
echo "SLURM_JOB_GPUS:" $SLURM_JOB_GPUS
echo "CUDA_VISIBLE_DEVICES:" $CUDA_VISIBLE_DEVICES
echo "- - - - - - - - - - - -"

# Derive a job-specific rendezvous port from the job ID so that different jobs scheduled
# on the same node do not collide on the same port.
export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
echo "MASTER_PORT="$MASTER_PORT

# List GPU devices allocated on head node.
nvidia-smi -L

module purge
module load miniconda3/24.3.0
conda activate <path-to-conda-environment>

nodes_array=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo "Node IP: "$head_node_ip

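# srun launches one torchrun per node; the c10d rendezvous on the head node wires the
# workers together. Positional arguments to multi_node.py: total_epochs=50, save_every=10.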
srun torchrun \
--nnodes "$SLURM_JOB_NUM_NODES" \
--nproc_per_node "$SLURM_GPUS_ON_NODE" \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:$MASTER_PORT \
multi_node.py 50 10

echo "Done."
