
On-Device Training with ONNX Runtime: A deep dive

This post continues our multi-part blog series on On-Device Training with ONNX Runtime (ORT). In the first part of the series, On-Device Training: Efficient training on the edge with ONNX Runtime, we explored the fundamentals of On-Device Training, covering several use cases and the advantages of combining it with ORT. Building on that foundation, this post presents the underlying details of training models directly on user devices with ORT. Equipped with these technical details, we encourage you to try On-Device Training with ONNX Runtime for your own scenario.

How it works

On-Device Training with ORT is a framework-agnostic solution built on the existing ONNX Runtime inference engine. Here are the high-level steps for training on a device with ORT:

  • Export your model to ONNX format.
  • Generate all the prerequisites for training (models, checkpoints) and validate training recipes.
  • Deploy on device and train.
  • Based on your scenario, choose a post-training option.
Figure 1: E2E Flow for On-Device Training

Next, we go over each of these steps in detail.

Export to ONNX

The very first step is to convert the model from its original framework to ONNX (if it is not already in ONNX). The ONNX ecosystem provides tools to export models from many popular machine learning frameworks. Here is a simple example illustrating how to export a model to ONNX:

Figure 2: Example to convert PyTorch model to ONNX format.
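To make concrete what an exported model contains, here is a minimal sketch in plain Python (not the `onnx` package API) of an ONNX-style forward graph: a list of nodes, each with an operator type and named inputs and outputs, plus initializers that hold the weights. The `Node`, `Graph`, and `run` names are hypothetical; an exporter such as PyTorch's `torch.onnx.export` writes an equivalent structure to a protobuf file.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    op_type: str   # operator name, e.g. "MatMul", "Add", "Relu"
    inputs: list   # names of the tensors this node reads
    outputs: list  # names of the tensors this node produces

@dataclass
class Graph:
    nodes: list                                       # topologically sorted
    initializers: dict = field(default_factory=dict)  # weights, stored by name
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)

def matmul(x, w):
    # x: batch x in_features, w: in_features x out_features (lists of rows)
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

OPS = {
    "MatMul": matmul,
    "Add": lambda x, b: [[v + bv for v, bv in zip(row, b)] for row in x],
    "Relu": lambda x: [[max(0.0, v) for v in row] for row in x],
}

def run(graph, feeds):
    """Evaluate nodes in order; every tensor is looked up by name."""
    env = dict(graph.initializers)
    env.update(feeds)
    for node in graph.nodes:
        args = [env[name] for name in node.inputs]
        env[node.outputs[0]] = OPS[node.op_type](*args)  # single-output ops only
    return {name: env[name] for name in graph.outputs}

# A Linear + ReLU layer, the kind of forward graph an exporter emits.
forward = Graph(
    nodes=[
        Node("MatMul", ["input", "fc.weight"], ["mm_out"]),
        Node("Add", ["mm_out", "fc.bias"], ["add_out"]),
        Node("Relu", ["add_out"], ["output"]),
    ],
    initializers={"fc.weight": [[1.0, -1.0], [2.0, 0.5]], "fc.bias": [0.5, -3.0]},
    inputs=["input"],
    outputs=["output"],
)
```

For example, `run(forward, {"input": [[1.0, 2.0]]})` evaluates the three nodes and returns `{"output": [[5.5, 0.0]]}`. Because parameters are referenced by name rather than embedded in nodes, later tooling can attach gradient and loss nodes to the same graph.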

Offline preparation: Artifact Generation and Recipe Validation

App developers and data scientists generate the training artifacts that are prerequisites for training: the training, evaluation, and optimizer ONNX models, and the checkpoint state. These artifacts can then be used in offline experiments to finalize the training recipes, ensuring that both the artifacts and the recipes are optimized and ready for on-device training.

Artifact Generation

The Python-based frontend tools provide the flexibility to experiment offline with elements such as loss functions, gradient accumulation techniques, and optimizers. This makes it easy to explore and fine-tune the training process and to discover the best configuration for a given model.

Refer to the Artifact Generation step in the Android sample for more details. Here is a simple example of generating training artifacts with the frontend tools:

Figure 3: Code snippet illustrating Artifact Generation step
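In ORT's Python frontend, this step is handled by `onnxruntime.training.artifacts.generate_artifacts`, which takes the exported model together with the names of the trainable (`requires_grad`) and frozen parameters, a loss type, and an optimizer type. The sketch below is not that function. It is a toy stand-in, with hypothetical names throughout, for the bookkeeping the step performs: validating the recipe choices, partitioning parameters, deciding which gradients the training model must maintain, and building an initial checkpoint.

```python
from dataclasses import dataclass

# Illustrative subsets only; the real frontend defines its own enums.
SUPPORTED_LOSSES = {"MSELoss", "CrossEntropyLoss"}
SUPPORTED_OPTIMIZERS = {"AdamW", "SGD"}

@dataclass
class ToyArtifacts:
    training_outputs: list  # what the training model returns
    gradients: list         # gradient buffers kept for trainable parameters
    eval_outputs: list      # the evaluation model returns the loss, no gradients
    optimizer: str          # update rule applied by the optimizer model
    checkpoint: dict        # initial parameter values, keyed by name

def generate_toy_artifacts(params, requires_grad, frozen_params, loss, optimizer):
    """params maps each parameter name to its initial values (a list of floats)."""
    if loss not in SUPPORTED_LOSSES:
        raise ValueError(f"unsupported loss: {loss}")
    if optimizer not in SUPPORTED_OPTIMIZERS:
        raise ValueError(f"unsupported optimizer: {optimizer}")
    trainable, frozen = set(requires_grad), set(frozen_params)
    if trainable & frozen:
        raise ValueError(f"both trainable and frozen: {sorted(trainable & frozen)}")
    if trainable | frozen != set(params):
        raise ValueError("every parameter must be listed exactly once")
    return ToyArtifacts(
        training_outputs=["loss"],
        gradients=[f"{name}_grad" for name in requires_grad],
        eval_outputs=["loss"],
        optimizer=optimizer,
        checkpoint={
            "trainable": {n: list(params[n]) for n in requires_grad},
            "non_trainable": {n: list(params[n]) for n in frozen_params},
        },
    )

# Fine-tune only the classifier head and keep the backbone frozen.
arts = generate_toy_artifacts(
    params={"backbone.weight": [0.1, 0.2], "head.weight": [0.3], "head.bias": [0.0]},
    requires_grad=["head.weight", "head.bias"],
    frozen_params=["backbone.weight"],
    loss="CrossEntropyLoss",
    optimizer="AdamW",
)
```

Freezing most parameters, as in this example, keeps the number of gradient buffers (and therefore on-device memory) small, which is a common choice when fine-tuning on constrained hardware.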

Now, let’s take a closer look at each of these artifacts to delve deeper into their significance and role in the training workflow.

Training, Optimizer, and Evaluation ONNX models

The Training and Optimizer graphs are integral components of the training process. The Training graph captures the model's computations and gradients, while the Optimizer graph uses those gradients to update the parameters.

The Training graph is derived from the forward graph. It is constructed by appending a loss function to the forward graph and then adding the backward pass that computes gradients, which enable the model to learn iteratively. These gradients represent the sensitivity of the loss to changes in each trainable parameter.

Figure 4: Training ONNX Model demonstrating forward and backward passes for a simple forward ONNX model 
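To illustrate what augmenting the forward graph with a loss and gradients computes, here is a hand-written sketch in plain Python: a forward pass for a linear model, a mean-squared-error loss, and the backward pass an automatic-differentiation step would derive, checked against finite differences. The model, the loss, and all function names are illustrative assumptions, not what ORT generates for any particular model.

```python
def forward(w, b, xs):
    """Forward graph: y_hat_i = sum_j w_j * x_ij + b."""
    return [sum(wj * xj for wj, xj in zip(w, x)) + b for x in xs]

def mse_loss(y_hat, ys):
    """Loss node appended to the forward graph for training."""
    return sum((p - t) ** 2 for p, t in zip(y_hat, ys)) / len(ys)

def backward(w, b, xs, ys):
    """Forward + loss + backward: returns the loss and dL/dw, dL/db."""
    y_hat = forward(w, b, xs)
    n = len(ys)
    # dL/dy_hat_i = 2 * (y_hat_i - y_i) / n
    d_yhat = [2.0 * (p - t) / n for p, t in zip(y_hat, ys)]
    # Chain rule through y_hat_i = w . x_i + b
    grad_w = [sum(d * x[j] for d, x in zip(d_yhat, xs)) for j in range(len(w))]
    grad_b = sum(d_yhat)
    return mse_loss(y_hat, ys), grad_w, grad_b

def numeric_grad_w(w, b, xs, ys, eps=1e-6):
    """Central finite differences, used only to check backward()."""
    grads = []
    for j in range(len(w)):
        plus = w[:j] + [w[j] + eps] + w[j + 1:]
        minus = w[:j] + [w[j] - eps] + w[j + 1:]
        lp = mse_loss(forward(plus, b, xs), ys)
        lm = mse_loss(forward(minus, b, xs), ys)
        grads.append((lp - lm) / (2 * eps))
    return grads

xs = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
ys = [1.0, 2.0, 0.0]
w, b = [0.2, -0.4], 0.1
loss, grad_w, grad_b = backward(w, b, xs, ys)
```

The backward function reuses the forward activations (`y_hat`) rather than recomputing them, which is the same reason a training graph keeps forward outputs alive until the backward nodes consume them.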

The Optimizer Graph includes operations such as gradient normalization and parameter updates. Gradient normalization ensures that the magnitude of the gradients remains within a desired range, preventing them from becoming too large or too small, which can impede the training process. Parameter updates adjust the model’s weights based on the gradients. 

Figure 5: Optimizer ONNX model demonstrating parameter update step 
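The two operations of the Optimizer graph can be sketched in plain Python as follows. As an assumption, gradient normalization is shown as global-norm clipping (rescaling all gradients when their combined L2 norm exceeds a threshold), and the update rule is plain SGD; ORT's optimizer models implement their own update rules (for example AdamW), which this sketch does not reproduce. The function names are hypothetical.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for grad in grads.values() for g in grad))
    if total > max_norm:
        scale = max_norm / total
        for name in grads:
            grads[name] = [g * scale for g in grads[name]]
    return total  # norm before clipping, useful for logging

def sgd_step(params, grads, lr):
    """Parameter update p <- p - lr * g, written into the existing lists."""
    for name, grad in grads.items():
        p = params[name]
        for i, g in enumerate(grad):
            p[i] -= lr * g

params = {"w": [1.0, 1.0], "b": [0.0]}
grads = {"w": [3.0, 0.0], "b": [4.0]}  # global norm = 5
norm_before = clip_by_global_norm(grads, max_norm=1.0)
sgd_step(params, grads, lr=0.5)
```

Here the gradients are scaled by 1/5 before the update, so `w[0]` moves from 1.0 to about 0.7 and `b[0]` from 0.0 to about -0.4. Updating the parameter lists in place avoids allocating a second copy of the weights.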

Together, these graphs let the model undergo backpropagation and parameter updates, so that it continuously adapts to the training data.

The Evaluation graph is optional. It shares the structure of the forward graph but differs in that it incorporates a loss function at the end of the graph. By evaluating this loss on a separate set of data, such as a validation dataset, developers can gain insight into how well the model generalizes to unseen examples.

Figure 6: Evaluation ONNX model with a loss function to carry out the evaluation task 
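Running the Evaluation graph amounts to a forward pass followed by the loss, with no backward pass. The plain-Python sketch below assumes a linear model with a mean-squared-error loss and averages the loss over batches of held-out data, weighting each batch by its size; the function names and batching scheme are illustrative.

```python
def forward(w, b, xs):
    return [sum(wj * xj for wj, xj in zip(w, x)) + b for x in xs]

def eval_step(w, b, xs, ys):
    """Evaluation graph: forward pass + loss, no gradients."""
    y_hat = forward(w, b, xs)
    return sum((p - t) ** 2 for p, t in zip(y_hat, ys)) / len(ys)

def evaluate(w, b, batches):
    """Average loss over a validation set, weighted by batch size."""
    total, count = 0.0, 0
    for xs, ys in batches:
        total += eval_step(w, b, xs, ys) * len(ys)
        count += len(ys)
    return total / count

validation = [
    ([[1.0], [2.0]], [2.0, 4.0]),  # batch 1: predicted exactly
    ([[3.0]], [7.0]),              # batch 2: off by one
]
val_loss = evaluate([2.0], 0.0, validation)
```

With `w = [2.0]` and `b = 0.0`, the first batch has zero error and the single example in the second batch has squared error 1, so `val_loss` is 1/3. Weighting by batch size keeps a small final batch from skewing the average.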

Checkpoint State

The Checkpoint State captures the essential training state needed to start or resume training. The checkpoint is a FlatBuffers message that includes the following data:

  • Model Parameters: This includes both trainable and non-trainable parameters.
  • Optimizer State (Optional): In cases where pausing and resuming training is necessary, the Checkpoint State can also capture the state of the optimizer, allowing for seamless continuation of training.
  • Other User-Defined Parameters (Optional): Additional user-defined parameters, such as epoch number, learning rate, last recorded loss, and more, can also be included in the Checkpoint State.

All the ONNX models (training, evaluation, and optimizer) refer to the parameters in the shared checkpoint state. This eliminates the need to include the parameters within each model and thus results in a significant reduction in the model size. 
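The checkpoint contents listed above can be sketched as a plain Python data structure. As a stand-in for the FlatBuffers encoding ORT actually uses, this sketch serializes to JSON, and the field names, `MODELS`, and `resolve` are hypothetical (ORT's Python API exposes checkpoints through `onnxruntime.training.api.CheckpointState`). It also mirrors the sharing described above: each model lists only parameter names, and the values live once, in the checkpoint.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field

@dataclass
class Checkpoint:
    trainable: dict                                      # name -> values
    non_trainable: dict = field(default_factory=dict)    # name -> values
    optimizer_state: dict = field(default_factory=dict)  # optional, for resuming
    properties: dict = field(default_factory=dict)       # optional user data

    def save(self, path):
        with open(path, "w") as f:
            json.dump(asdict(self), f)

    @staticmethod
    def load(path):
        with open(path) as f:
            return Checkpoint(**json.load(f))

# Each model refers to parameters by name only; values are stored once.
MODELS = {
    "training": ["fc.weight", "fc.bias"],
    "eval": ["fc.weight", "fc.bias"],
    "optimizer": ["fc.weight", "fc.bias"],
}

def resolve(model, ckpt):
    """Look up the parameter values a model needs in the shared checkpoint."""
    pool = {**ckpt.non_trainable, **ckpt.trainable}
    return {name: pool[name] for name in MODELS[model]}

ckpt = Checkpoint(
    trainable={"fc.weight": [0.5, -0.5], "fc.bias": [0.1]},
    optimizer_state={"step": 10},
    properties={"epoch": 3, "last_loss": 0.42},
)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.json")
ckpt.save(path)
restored = Checkpoint.load(path)
```

After the round trip, `resolve("training", restored)` and `resolve("eval", restored)` return the very same list objects, so the weights occupy memory once no matter how many models use them.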

Recipe Validation

ORT provides Python bindings for training, which let users run offline experiments in a familiar and convenient environment, making it easy to tune and refine the training recipes. Once the desired configurations and models prove successful, they can be deployed to the devices. Refer to the recipe validation step in the Android sample for more details.
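Recipe validation is essentially a series of offline experiments. The sketch below, in plain Python with hypothetical names, shows one such experiment: sweeping the learning rate for a toy linear model trained with full-batch SGD and keeping the value with the lowest validation loss. In practice each trial would drive ORT's Python training bindings with the generated artifacts instead of this hand-written model.

```python
def train_and_validate(lr, train, val, epochs=50):
    """Fit y = w * x + b with full-batch SGD on MSE; return validation loss."""
    w, b = 0.0, 0.0
    n = len(train)
    for _ in range(epochs):
        gw = sum(2 * (w * x + b - y) * x for x, y in train) / n
        gb = sum(2 * (w * x + b - y) for x, y in train) / n
        w, b = w - lr * gw, b - lr * gb
    return sum((w * x + b - y) ** 2 for x, y in val) / len(val)

train = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]  # y = 2x + 1
val = [(4.0, 9.0), (5.0, 11.0)]

results = {lr: train_and_validate(lr, train, val) for lr in (0.001, 0.01, 0.1, 0.5)}
best_lr = min(results, key=results.get)
```

In this toy setup the small rates converge too slowly within 50 epochs and 0.5 diverges, so the sweep selects 0.1. The same pattern extends to other recipe choices, such as loss functions or gradient accumulation steps.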

On-Device training

Once the training artifacts are ready and the recipes are validated, on-device training can begin. Refer to the Android sample to see how these artifacts and recipes are deployed to the device.

ORT supports a wide range of edge devices, from powerful devices such as PCs and gaming consoles to resource-constrained devices such as mobile phones. The following platforms and language bindings are currently supported:

Figure 7: Available support for different platforms and language bindings 
Swift and Objective-C bindings will be available starting with ORT release 1.16.

All language bindings serve as extensions to the existing ORT Inference APIs. This approach provides developers with a consistent and unified API for both training and inference tasks using ORT. 

Figure 8: ORT Architecture stack. *Coming soon

The entry point for the training workflow is the TrainingSession, which orchestrates the training process and handles the associated business logic. Developers provide the required training artifacts to initialize the training session. The training session loads these artifacts and shares the parameters between the training and evaluation graphs, so the parameters are held in memory only once, reducing the overall memory footprint. Furthermore, during the optimizer step, parameters are updated in place, avoiding the memory cost of unnecessary copies.

TrainingSession exposes APIs to run the training step, evaluation step, and optimizer step, to reset gradient buffers, and other utilities. To understand the usage in detail, refer to the training step in the Android sample. Here is a simple snippet illustrating a typical training loop:

Figure 9: Simple training loop in C++ 
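The training-loop pattern described above (train step, optimizer step, reset gradients, evaluation step) can be mocked in plain Python. `ToyTrainingSession` is a hypothetical class, not ORT's API, and models `y = w * x + b`. In this mock, the training and evaluation steps read the same parameter dictionary, the optimizer step updates it in place, and gradients keep accumulating across train steps until they are explicitly reset, which is how the loop implements gradient accumulation over several samples.

```python
class ToyTrainingSession:
    """Mock of the train/eval/optimizer-step pattern for y = w * x + b."""

    def __init__(self, params, lr):
        self.params = params  # one dict, shared by train and eval steps
        self.grads = {name: 0.0 for name in params}
        self.lr = lr

    def _loss(self, x, y):
        err = self.params["w"] * x + self.params["b"] - y
        return err ** 2, err

    def train_step(self, x, y):
        loss, err = self._loss(x, y)
        # Gradients accumulate until reset_grad() is called.
        self.grads["w"] += 2 * err * x
        self.grads["b"] += 2 * err
        return loss

    def eval_step(self, x, y):
        return self._loss(x, y)[0]  # forward + loss only

    def optimizer_step(self):
        for name, g in self.grads.items():
            self.params[name] -= self.lr * g  # in-place update of shared dict

    def reset_grad(self):
        for name in self.grads:
            self.grads[name] = 0.0

params = {"w": 0.0, "b": 0.0}
session = ToyTrainingSession(params, lr=0.05)
data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]  # y = 2x + 1
accumulation_steps = 2  # hypothetical recipe setting

initial = sum(session.eval_step(x, y) for x, y in data)
for epoch in range(200):
    for i, (x, y) in enumerate(data, start=1):
        session.train_step(x, y)
        if i % accumulation_steps == 0:
            session.optimizer_step()
            session.reset_grad()
final = sum(session.eval_step(x, y) for x, y in data)
```

Because `params` is the same object the session updates, the caller sees the trained values (close to `w = 2`, `b = 1`) without any copy being made. ORT's own C++ and Python APIs differ from this mock in names and details.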

ORT enables users to optimize their deployments on edge devices by creating scenario-specific custom minimal builds. These builds use mechanisms such as disabling RTTI (run-time type information), operator type reduction, and removing runtime optimizations. The most common approach is to reduce the set of operators supported by the runtime to only those used by the models that run in the target environment. Users choose which mechanisms to enable at build time. For details on all the available options, refer to the build instructions.
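The operator-reduction idea can be sketched in plain Python: collect the (domain, operator type) pairs used across every model that will run on the device, since only kernels for those operators need to be compiled in. The node lists below are hypothetical, and the operator names in the `com.microsoft` domain are illustrative; ORT's build tooling derives this kind of configuration from the actual model files, as described in the build instructions.

```python
from collections import defaultdict

def required_operators(models):
    """Union of operator types, grouped by domain, over all deployed models."""
    ops = defaultdict(set)
    for nodes in models.values():
        for domain, op_type in nodes:
            ops[domain].add(op_type)
    return {domain: sorted(types) for domain, types in sorted(ops.items())}

# Illustrative (domain, op_type) pairs for the artifacts of a small classifier.
models = {
    "training_model": [("ai.onnx", "Gemm"), ("ai.onnx", "Relu"),
                       ("ai.onnx", "SoftmaxCrossEntropyLoss"),
                       ("com.microsoft", "ReluGrad")],
    "eval_model": [("ai.onnx", "Gemm"), ("ai.onnx", "Relu"),
                   ("ai.onnx", "SoftmaxCrossEntropyLoss")],
    "optimizer_model": [("com.microsoft", "AdamWOptimizer")],
}

required = required_operators(models)
for domain, op_types in required.items():
    print(f"{domain}: {', '.join(op_types)}")
```

This prints one line per domain (`ai.onnx: Gemm, Relu, SoftmaxCrossEntropyLoss` and `com.microsoft: AdamWOptimizer, ReluGrad`). Every operator outside this set can be left out of the binary, which is the idea behind operator type reduction.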

Post-training options

Upon completion of the training process, app developers have several options depending on their specific scenario: 

  1. Export an inference graph on the device: This option allows developers to generate an optimized inference graph directly on the device, which can then be used for personalized, efficient predictions on that same device.
  2. Extract updated weights in memory for aggregation on the server: In scenarios involving federated learning, the updated weights can be extracted from the TrainingSession in memory. These weights can then be securely aggregated on a server, allowing for collaborative and privacy-preserving model training across multiple devices. 
  3. Save a trained checkpoint for future training resumption: Developers also have the option to save a trained checkpoint, which captures the updated model parameters and training state. This checkpoint can be used to resume training later, ensuring continuity in the training process. 
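The first two options involve small, well-defined computations that can be sketched in plain Python. Exporting an inference graph amounts to keeping only the nodes needed to compute the requested outputs; in ORT's Python API this is exposed as `Module.export_model_for_inferencing`, which takes the output names, and `prune_to_outputs` below is a hypothetical stand-in. For federated learning, the server must combine the weights extracted from many devices; as an assumption, the sketch uses federated averaging (FedAvg), weighting each client by its number of training examples, which is one common choice rather than a requirement.

```python
def prune_to_outputs(nodes, outputs):
    """Keep only the nodes the requested outputs depend on.

    nodes: list of (name, inputs, outputs) tuples in topological order.
    """
    needed = set(outputs)
    kept = []
    for name, node_inputs, node_outputs in reversed(nodes):
        if needed & set(node_outputs):
            kept.append(name)
            needed.update(node_inputs)
    return list(reversed(kept))

def federated_average(client_weights, client_sizes):
    """FedAvg: average each parameter, weighted by each client's example count."""
    total = sum(client_sizes)
    return {
        name: [
            sum(w[name][i] * n for w, n in zip(client_weights, client_sizes)) / total
            for i in range(len(values))
        ]
        for name, values in client_weights[0].items()
    }

# Toy training graph: forward nodes, then the loss, then backward nodes.
training_nodes = [
    ("matmul", ["input", "fc.weight"], ["mm_out"]),
    ("add", ["mm_out", "fc.bias"], ["output"]),
    ("loss", ["output", "labels"], ["loss"]),
    ("loss_grad", ["loss", "output", "labels"], ["output_grad"]),
    ("weight_grad", ["output_grad", "input"], ["fc.weight_grad"]),
]
inference_nodes = prune_to_outputs(training_nodes, ["output"])

merged = federated_average(
    [{"fc.bias": [1.0, 0.0]}, {"fc.bias": [4.0, 3.0]}],
    client_sizes=[100, 200],
)
```

Pruning leaves only `matmul` and `add`, since the loss and gradient nodes do not feed `output`. The averaged bias is `[3.0, 2.0]`, pulled toward the client that trained on more examples.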

Next steps

Throughout this multi-part blog series, we have explored the fundamentals of On-Device Training with ORT, including the offline preparation stage, the training artifacts, and the on-device training step.

If you’re eager to dive deeper and start experimenting with the code yourself, we highly recommend this comprehensive tutorial: On-Device Training: Building an Android Application. It guides you step by step through setting up your environment, generating training artifacts, and performing training, giving you hands-on experience and practical insights into implementing On-Device Training with ORT.

Looking forward

We have quite a few exciting features in the pipeline. Here’s a glimpse into what is coming next: 

iOS Support: We are working on expanding platform support to include iOS devices. We will add Objective-C and Swift bindings and make these packages available through CocoaPods as well as Swift Package Manager.

Training on Web Browsers: The web platform continues to evolve into a powerful environment for running machine learning models. ORT-Web is already available for our partners to run inference on the web with ORT, and we are working on enabling on-device training in web browsers as well.

Exciting developments are underway, but what’s next in our blog series? In the next installment, we will dive deep into the Minimal Build support feature. We will explore how developers can create leaner builds tailored to their specific models, and we will also compare the resulting binary sizes.