# ONNX Export Process
To enable efficient inference on CPU-based edge devices and web backends, we export the trained PyTorch model to the ONNX (Open Neural Network Exchange) format.
## Why ONNX?
- Interoperability: the same model file can be run in many environments (Python, C++, JavaScript).
- Optimization: ONNX Runtime applies hardware-specific graph optimizations such as operator fusion and constant folding (see the snippet after this list).
- Speed: CPU inference is typically significantly faster than eager-mode PyTorch.
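To illustrate the optimization point, ONNX Runtime exposes the graph-optimization level through `SessionOptions`; a minimal sketch (the model path here is a placeholder, not a file produced by this repo):

```python
import onnxruntime as ort

# Explicitly request the full set of graph optimizations
# (operator fusion, constant folding, etc.) for this session.
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# "model.onnx" is a placeholder for an exported model file.
session = ort.InferenceSession(
    "model.onnx", sess_options=opts, providers=["CPUExecutionProvider"]
)
```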
## Export Workflow
The script `src/modelling/export.py` handles the conversion (a condensed sketch follows the list):

- Load PyTorch model: loads the `.pth` checkpoint.
- Validation: runs a dummy inference in PyTorch to get baseline outputs.
- Export: uses `torch.onnx.export` to trace the graph.
  - Dynamic axes: the batch dimension is configured as dynamic (`{0: 'batch_size'}`) so the exported model can handle any batch size.
- Verification:
  - Checker: runs `onnx.checker.check_model` to validate the schema.
  - Numerical check: runs the exported ONNX model with `onnxruntime` and compares its output to the PyTorch baseline using `torch.testing.assert_close`.
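The sketch below condenses these steps end to end. The model class, input shape, and file names are illustrative stand-ins, not the actual contents of `src/modelling/export.py`:

```python
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort

# Stand-in model for illustration; the real script loads the project's
# own architecture before restoring the checkpoint weights.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
# Real script: model.load_state_dict(torch.load(<checkpoint>, map_location="cpu"))
model.eval()

# Dummy input; the shape is illustrative only.
dummy = torch.randn(1, 16)

# PyTorch baseline output for the numerical check below.
with torch.no_grad():
    baseline = model(dummy)

# Trace the graph; axis 0 (the batch dimension) is declared dynamic.
torch.onnx.export(
    model,
    dummy,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

# Schema validation.
onnx.checker.check_model(onnx.load("model.onnx"))

# Numerical check: ONNX Runtime output vs. the PyTorch baseline.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
(ort_out,) = session.run(None, {"input": dummy.numpy()})
torch.testing.assert_close(torch.from_numpy(ort_out), baseline)
```

Because the batch axis is dynamic, the same `model.onnx` can later be fed a batch of any size, not just the shape used for tracing.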
## Usage
```bash
python -m modelling.export --checkpoint_path checkpoints/best_model.pth
```
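Since the module is invoked as `modelling.export` while the file lives at `src/modelling/export.py`, the command presumably needs to be run from inside `src/` (or with `src/` on `PYTHONPATH`).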