# train.py

Tags: `source`, `modelling`, `training`, `pytorch`

**File Path:** `src/modelling/train.py`

**Purpose:** Training loop implementation supporting mixed precision and distributed training.
## Features

- **Automatic Mixed Precision (AMP):** Uses `torch.amp.autocast` with `bfloat16` and `GradScaler`.
- **Distributed Data Parallel (DDP):** Supports multi-GPU training via the `rank` parameter.
- **Checkpointing:** Saves the model only when validation loss improves.
- **Scheduling:** `ReduceLROnPlateau` for learning rate adjustment.
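For reference, the standard PyTorch pattern behind the AMP feature looks roughly like this. This is a minimal sketch, not the file's exact code; the helper name `train_step` is invented for illustration, and its arguments reuse names from the `train(...)` signature below.

```python
import torch

def train_step(model, loss, optimizer, scaler, batch, target):
    """One AMP optimizer step; argument names mirror train(...) below."""
    optimizer.zero_grad(set_to_none=True)
    # Run the forward pass in bfloat16 where autocast deems it safe
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        batch_loss = loss(model(batch), target)
    # GradScaler matters mainly for float16; with bfloat16 it is close to
    # a pass-through but keeps a single code path for both dtypes
    scaler.scale(batch_loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return batch_loss.detach()
```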
## Training Pipeline

```mermaid
sequenceDiagram
    participant Train as Model (Train)
    participant Valid as Model (Eval)
    participant Check as Checkpointer
    loop Every Epoch
        rect rgb(20, 20, 30)
            %% Optional: Dark background for Validation
            Note over Valid: Validation Phase
            Valid->>Valid: Forward (No Grad)
        end
        opt Val Loss < Best Loss
            Check-->>Check: Save State Dict
        end
    end
```
## Functions

### `train(...)`

```python
def train(model, loss, optimizer, scheduler, train_dl, val_dl, num_epochs, ...):
```

**Workflow** (a sketch of the full loop follows below):

- **Setup:** Initializes the `GradScaler` and checkpoint directory.
- **Epoch Loop:**
  - **Train Phase:** Iterates `train_dl`; updates weights; logs loss.
  - **Eval Phase:** Iterates `val_dl`; aggregates global metrics (if DDP).
  - **Checkpoint:** Saves the `state_dict` if `val_loss < best_val_loss`.
  - **Scheduler:** Steps based on `val_loss`.
**Returns:** Path to the best checkpoint file.
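Taken together, the workflow corresponds roughly to the following sketch. The checkpoint filename, the `ckpt_dir` parameter, the metric aggregation details, and the optimizer-step internals are assumptions for illustration, not the file's code.

```python
import os
import torch
import torch.distributed as dist

def train(model, loss, optimizer, scheduler, train_dl, val_dl, num_epochs,
          rank=None, ckpt_dir="checkpoints"):
    # Setup: GradScaler and checkpoint directory
    scaler = torch.amp.GradScaler("cuda")
    os.makedirs(ckpt_dir, exist_ok=True)
    best_val_loss = float("inf")
    best_path = os.path.join(ckpt_dir, "best.pt")

    for epoch in range(num_epochs):
        # Train phase: iterate train_dl and update weights
        model.train()
        for batch, target in train_dl:
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                batch_loss = loss(model(batch), target)
            scaler.scale(batch_loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Eval phase: iterate val_dl without gradients
        model.eval()
        val_loss, batches = 0.0, 0
        with torch.no_grad():
            for batch, target in val_dl:
                val_loss += loss(model(batch), target).item()
                batches += 1
        val_loss /= max(batches, 1)

        # Aggregate the metric across ranks when running under DDP
        if rank is not None and dist.is_initialized():
            t = torch.tensor(val_loss, device=f"cuda:{rank}")
            dist.all_reduce(t, op=dist.ReduceOp.AVG)
            val_loss = t.item()

        # Checkpoint only on improvement; rank 0 writes under DDP
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            if rank in (None, 0):
                torch.save(model.state_dict(), best_path)

        # ReduceLROnPlateau steps on the monitored validation loss
        scheduler.step(val_loss)

    return best_path
```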
## Execution (`__main__`)
Allows executing training directly as a script.
- **Defaults:** `num_epochs: 1`, `lr: 1e-3`, `weight_decay: 1e-4`
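With those defaults, the entry point plausibly wires together the dependencies listed below. Everything here except the defaults, `DEVICE`, and `ReduceLROnPlateau` is a hypothetical sketch: the `Model` and `get_dataloaders` names, the optimizer, and the loss function are invented for illustration.

```python
if __name__ == "__main__":
    import torch
    from src.modelling.constants import DEVICE
    from src.modelling.dataloader import get_dataloaders  # hypothetical name
    from src.modelling.model import Model                 # hypothetical name

    model = Model().to(DEVICE)
    # Optimizer and loss choices are assumptions; defaults match the section above
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    train_dl, val_dl = get_dataloaders()

    best_ckpt = train(model, torch.nn.functional.cross_entropy, optimizer,
                      scheduler, train_dl, val_dl, num_epochs=1)
    print(f"Best checkpoint: {best_ckpt}")
```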
```bash
python -m src.modelling.train
```
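Since DDP is supported via the `rank` parameter, multi-GPU runs would typically be launched with `torchrun`; this is an assumption, as the exact launch mechanism is not specified here:

```bash
torchrun --nproc_per_node=4 -m src.modelling.train
```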
## Related Documentation

**Depends On:**

- `dataloader.py` - Data providers
- `model.py` - Model architecture
- `constants.py` - `DEVICE`
**Conceptual:**