parallel_train.py
Tags: source, modelling, distributed, gpu
File Path: `src/modelling/parallel_train.py`
Purpose: High-performance multi-GPU training launcher using PyTorch Distributed Data Parallel (DDP).
Overview
This script orchestrates distributed training across all available GPUs on a single node, managing process spawning, environment setup, and data partitioning.
Logic Workflow
- Initialization: Detects the number of available GPUs (`world_size`).
- Process Spawning: Uses `torch.multiprocessing.spawn` to launch `training_wrapper` across processes (one per GPU).
- Environment Setup:
  - Initializes the process group using the `nccl` backend.
  - Sets the local `rank` for each process.
  - Configures `MASTER_ADDR` and `MASTER_PORT`.
- Data Distribution:
  - Uses `DistributedSampler` to ensure each GPU sees a unique subset of the data.
  - Wraps the dataset in a `DataLoader` with `pin_memory=True`.
- Model Parallelism:
  - Wraps the model in `DistributedDataParallel` (DDP).
  - Converts Batch Norm layers to `SyncBatchNorm` for consistent statistics across GPUs.
- Dynamic Scaling:
  - Scales the learning rate by the square root of the `world_size`: `lr = 1e-3 * sqrt(world_size)`.
- Training: Invokes the shared `train()` function from `train.py`.
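The initialization, spawning, and learning-rate-scaling steps above can be sketched as follows. This is a minimal sketch, not the script's actual code: `scale_lr` is an illustrative helper name, and `run_training` is assumed to be defined elsewhere in the file.

```python
import math

import torch
import torch.multiprocessing as mp

def scale_lr(base_lr: float, world_size: int) -> float:
    """Square-root LR scaling: lr = base_lr * sqrt(world_size)."""
    return base_lr * math.sqrt(world_size)

def main() -> None:
    # One process per visible GPU.
    world_size = torch.cuda.device_count()
    # spawn() prepends the process index (the rank) to args,
    # so each worker receives (rank, world_size).
    mp.spawn(run_training, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
```

For example, with the base rate of `1e-3` on 4 GPUs this yields `1e-3 * sqrt(4) = 2e-3`.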
Key Functions
`setup(rank, world_size)`
Initializes the distributed environment and the communication backend.
`cleanup()`
Destroys the process group after training completes.
`run_training(rank, world_size, ...)`
The core training logic executed on each GPU: sets the device, wraps the model, and handles result visualization (on rank 0 only).
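The per-process logic can be sketched roughly as below. This is an assumption-laden sketch, not the script's real signature: the model and dataset are passed in explicitly here, visualization is omitted, and the snippet falls back to CPU when no GPU is present so it stays illustrative.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def run_training(rank, world_size, model, dataset, num_epochs=1, base_lr=1e-3):
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(rank)
        # Keep BatchNorm statistics consistent across GPUs.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    ddp_model = DDP(model, device_ids=[rank] if use_cuda else None)

    # Each rank sees a unique shard of the data.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                        pin_memory=use_cuda)

    # Square-root LR scaling, as described above.
    lr = base_lr * (world_size ** 0.5)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(x), y)
            loss.backward()
            optimizer.step()
    return ddp_model
```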
Usage
```bash
# Automatically detects and uses all available GPUs
python src/modelling/parallel_train.py --selected_signs_to 502 --num_epochs 20
```

Related Documentation
- `train.py` - Shared training loop.
- `dataloader.py` - Distributed sampling logic.