
Accelerate PyTorch training with torch-ort

With a simple change to your PyTorch training script, you can now speed up training of large language models with torch_ort.ORTModule on the target hardware of your choice.

Training deep learning models requires ever-increasing compute and memory resources. Today we are releasing torch_ort.ORTModule to accelerate distributed training of PyTorch models, reducing the time and resources needed for training. To give developers flexibility, torch-ort is available for both NVIDIA and AMD GPUs. The torch-ort package can also be combined with other deep learning optimization libraries, such as DeepSpeed, for additional performance gains on training tasks.

Delivered via the torch-ort package from https://github.com/pytorch/ort, the ORTModule class is a simple wrapper for torch.nn.Module. ORTModule supports transformer models such as the GPT and BERT series, with support for other modalities coming soon. Today, you can fine-tune the most popular language models on a labeled dataset for a target task, augment self-supervised training of a model with a specific corpus, or experiment with pre-training new models from scratch.

Performance

In addition to using torch-ort for large workloads inside Microsoft, we benchmarked fine-tuning of the most popular Hugging Face models: ORTModule alone improved throughput by up to 37 percent, and ORTModule combined with DeepSpeed improved it by up to 86 percent.

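[Bar charts: training throughput, in samples per second, for fine-tuning popular Hugging Face models with ORTModule alone and with ORTModule plus DeepSpeed]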

These experiments were run on Azure ND A100 v4 infrastructure, which provides high-bandwidth connections between GPUs both within and across machines.

The graphs above show throughput in training samples per second. How long your training job actually takes depends on the number of training samples and the type of CPU or GPU you use. Before training starts, ORTModule performs a one-time optimization of your model; this fixed cost is amortized across the run.
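To make the amortization point concrete, here is a small Python sketch that compares end-to-end training time with and without a one-time setup cost. The throughput figures and the 60-second setup time are hypothetical values chosen for illustration, not measurements, and the function names are our own.

```python
def end_to_end_seconds(num_samples: int, samples_per_second: float,
                       one_time_cost_s: float = 0.0) -> float:
    """Wall-clock time: a fixed setup cost plus steady-state training time."""
    return one_time_cost_s + num_samples / samples_per_second


def end_to_end_speedup(num_samples: int, baseline_sps: float,
                       ort_sps: float, ort_setup_s: float) -> float:
    """Baseline time divided by ORTModule time, counting ORTModule's setup cost."""
    baseline = end_to_end_seconds(num_samples, baseline_sps)
    ort = end_to_end_seconds(num_samples, ort_sps, ort_setup_s)
    return baseline / ort


def break_even_samples(baseline_sps: float, ort_sps: float,
                       ort_setup_s: float) -> float:
    """Number of samples after which per-sample savings repay the setup cost."""
    saved_per_sample = 1.0 / baseline_sps - 1.0 / ort_sps
    if saved_per_sample <= 0:
        raise ValueError("ORTModule throughput must exceed the baseline")
    return ort_setup_s / saved_per_sample


if __name__ == "__main__":
    # Hypothetical inputs: 100 samples/s baseline, 137 samples/s with
    # ORTModule (a 37 percent gain), and a 60 s one-time optimization.
    print(break_even_samples(100.0, 137.0, 60.0))
    for n in (1_000, 100_000, 1_000_000):
        print(n, round(end_to_end_speedup(n, 100.0, 137.0, 60.0), 3))
```

With these hypothetical inputs, the setup cost is recovered after roughly 22,200 samples. A shorter run is slower overall, while a long run approaches the full steady-state gain.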

The combination of ORTModule and DeepSpeed also enabled Ask Here First to train the 2.7B-parameter GPT-Neo model for a custom natural language task; previously, this large model could not be trained on the available hardware. Ask Here First, a Columbia University spin-off, builds AI-based natural language query solutions for structured data stores that simplify search in industries such as finance, media, marketing, and sports.


“Training GPT-Neo for our custom natural language task was not possible before we employed ORTModule and DeepSpeed. We have now produced fine-tuned 2.7B-parameter GPT-Neo models to map natural language inputs into structured queries for a number of our applications.” (Professor Vishal Misra, Columbia University, and founder of Ask Here First)

Hardware portability

Distributed training workloads can run on several hardware platforms. torch_ort.ORTModule works with both NVIDIA and AMD GPUs.

We are releasing the torch-ort package for NVIDIA GPUs with CUDA 10.2 or CUDA 11.1. It can be used to accelerate PyTorch training on NVIDIA GPUs either on Azure or in your own on-premises environment.

We are also releasing the preview package for torch-ort with ROCm 4.2 for use on AMD GPUs.
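Because separate torch-ort builds target CUDA and ROCm, it can help to confirm which GPU stack your installed PyTorch build uses before choosing a package, assuming you pick the torch-ort variant that matches it. The sketch below reads PyTorch's standard torch.version.cuda and torch.version.hip attributes; the helper function is hypothetical and not part of torch-ort.

```python
import torch


def pytorch_gpu_backend() -> str:
    """Report which GPU stack this PyTorch build targets (hypothetical helper).

    torch.version.hip is set on ROCm builds and torch.version.cuda on CUDA
    builds; both are None on CPU-only builds.
    """
    hip = getattr(torch.version, "hip", None)  # getattr guards older builds
    if hip:
        return f"ROCm {hip}"
    if torch.version.cuda:
        return f"CUDA {torch.version.cuda}"
    return "CPU-only"


if __name__ == "__main__":
    backend = pytorch_gpu_backend()
    # torch.cuda.is_available() is also True on ROCm builds with an AMD GPU.
    print(f"PyTorch build: {backend}; GPU visible: {torch.cuda.is_available()}")
```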

Simple developer experience

Getting started with ORTModule is simple. You download and install the torch-ort package and wrap your model with ORTModule, as demonstrated in the following code example.

Your PyTorch training loop is unmodified except for wrapping the torch.nn.Module in ORTModule.

[Code example: wrapping a torch.nn.Module in ORTModule inside an otherwise unchanged PyTorch training loop]
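A minimal sketch of that change, assuming torch-ort has been installed with pip install torch-ort and configured with python -m torch_ort.configure. The toy model and random data are hypothetical stand-ins for a transformer and its dataset. If torch-ort is not installed, the sketch falls back to the unwrapped module so the same loop still runs.

```python
import torch
from torch import nn

# Assumption: torch-ort is installed and configured. Without it, fall back
# to plain PyTorch so the identical training loop still executes.
try:
    from torch_ort import ORTModule
except ImportError:
    ORTModule = None


def build_model() -> nn.Module:
    # Hypothetical toy model standing in for a transformer such as BERT or GPT.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))


def train(steps: int = 5) -> list:
    torch.manual_seed(0)
    model = build_model()
    if ORTModule is not None:
        model = ORTModule(model)  # the only change to the training script
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    inputs = torch.randn(64, 16)            # hypothetical batch of features
    targets = torch.randint(0, 2, (64,))    # hypothetical class labels
    losses = []
    for _ in range(steps):
        # Standard PyTorch loop: unchanged whether or not the model is wrapped.
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    print(train())
```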

Because the PyTorch training loop is unmodified, ORTModule can be seamlessly integrated with other libraries in the PyTorch ecosystem, such as torch.autocast and NVIDIA apex.
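As one example of that composition, the sketch below runs the forward pass of an ORTModule-wrapped model under torch.autocast and uses PyTorch's GradScaler, the standard recipe for float16 mixed precision. It relies on the statement above that ORTModule works with torch.autocast; mixed precision is enabled only when a CUDA device is present, and the model and data are hypothetical.

```python
import torch
from torch import nn

try:
    from torch_ort import ORTModule  # assumption: torch-ort is installed
except ImportError:
    ORTModule = None  # fall back to plain PyTorch so the sketch still runs

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # this sketch enables mixed precision on GPU only

# Hypothetical toy model; wrapped in ORTModule when available.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2)).to(device)
if ORTModule is not None:
    model = ORTModule(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # passthrough when disabled
loss_fn = nn.CrossEntropyLoss()


def train_step(inputs: torch.Tensor, targets: torch.Tensor) -> float:
    optimizer.zero_grad()
    # autocast wraps only the forward pass and loss computation.
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = loss_fn(model(inputs), targets)
    # The scaler guards float16 gradients against underflow.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```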

How does it work?

On the first call to forward, two optimized computation graphs are generated: one for the forward prediction pass and one for the backward gradient calculation pass. All other parts of the training loop are executed by native PyTorch. The speedup comes from the optimizations applied to these graphs, such as optimized kernels, fusion of subgraph operators, and fewer memory copies between CPU and GPU.
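The lazy, first-call pattern can be illustrated with a short sketch. This is not how ORTModule is implemented: ORTModule builds optimized forward and backward graphs with ONNX Runtime, whereas this hypothetical LazyTracedModule substitutes TorchScript tracing (torch.jit.trace) to show a one-time graph capture on the first forward call that later calls reuse.

```python
import torch
from torch import nn


class LazyTracedModule(nn.Module):
    """Hypothetical illustration of 'prepare once on the first forward call'.

    Uses TorchScript tracing as a stand-in; ORTModule itself uses ONNX Runtime.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self.traced = None
        self.trace_count = 0  # counts how often the one-time step has run

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.traced is None:
            # One-time cost on the first call: record the computation as a graph.
            self.traced = torch.jit.trace(self.module, (x,))
            self.trace_count += 1
        # Later calls reuse the cached graph.
        return self.traced(x)


if __name__ == "__main__":
    torch.manual_seed(0)
    wrapped = LazyTracedModule(nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)))
    x = torch.randn(3, 4)
    wrapped(x)
    wrapped(x)
    print("graph captured", wrapped.trace_count, "time(s)")
```

Here gradients still come from PyTorch's autograd machinery running over the traced graph, whereas ORTModule generates a separate optimized graph for the backward pass.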

For more information

Read our technical deep dive into ORTModule and learn about our partnership with AMD. You can also browse the documentation and samples, or reach out to the team.