
On-Device Training: Efficient training on the edge with ONNX Runtime

We are introducing On-Device Training, a new capability in ONNX Runtime (ORT) that enables training models on edge devices without the data ever leaving the device. Edge devices can be any compute-enabled devices, such as laptops, smartphones, gaming consoles, or other embedded devices. This capability opens new opportunities for application developers, who can now personalize experiences for users without compromising privacy. This blog post provides a quick overview of On-Device Training with ORT and resources to help you get started.

ONNX Runtime at a glance

ORT is a high-performance, cross-platform inference and training engine that can run a variety of machine learning models. It provides an easy-to-use experience for AI developers to run models on multiple hardware and software platforms. Beyond accelerating server-side inference and training, ORT is also available for inferencing on mobile devices and in web browsers.

The new On-Device Training capability extends the ORT-Mobile inference offering to enable training on edge devices. The goal is to make it easy for developers to take an inference model and train it locally, with data present on-device, to provide an improved user experience for end customers.
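
To train on-device, the inference model is first prepared offline into training artifacts (the step shown in Figure 1 below). Here is a minimal sketch of that preparation using the Python onnxruntime-training package; the model file mobilenetv2.onnx and the classifier parameter names are hypothetical:

```python
# Offline step (developer machine): turn an inference ONNX model into
# the training artifacts that the on-device trainer consumes.
import onnx
from onnxruntime.training import artifacts

model = onnx.load("mobilenetv2.onnx")  # hypothetical inference model

# Fine-tune only the classifier layer; freeze every other parameter.
requires_grad = ["classifier.weight", "classifier.bias"]  # assumed names
frozen_params = [
    p.name for p in model.graph.initializer if p.name not in requires_grad
]

artifacts.generate_artifacts(
    model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
)
```

This emits the training, eval, and optimizer graphs plus a checkpoint into the training_artifacts directory, ready to ship to the device alongside the ORT binaries.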

On-Device Training with ONNX Runtime

In contrast to traditional deep learning (DL) model training, On-Device Training requires efficient use of compute and memory resources. Additionally, edge devices vary greatly in compute and memory configurations. To support these unique needs of edge-device training, we built the On-Device Training capability to be framework-agnostic and layered on top of the existing C++ ORT core functionality.

With On-Device Training, application developers can now infer and train using the same binaries. At the end of a training session, the runtime produces optimized, inference-ready models, which can then be used for a more personalized experience on the device. For scenarios like federated learning, the runtime provides model differences, since aggregation happens on the server side.
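
As a rough sketch of such a training session, assuming the artifacts generated above and a hypothetical local_dataloader that yields NumPy image/label batches:

```python
# On-device step: train with the generated artifacts, then export an
# optimized, inference-ready model from the trained state.
from onnxruntime.training.api import CheckpointState, Module, Optimizer

state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")
module = Module(
    "training_artifacts/training_model.onnx",
    state,
    "training_artifacts/eval_model.onnx",
)
optimizer = Optimizer("training_artifacts/optimizer_model.onnx", module)

num_epochs = 5  # illustrative value
module.train()
for _ in range(num_epochs):
    for images, labels in local_dataloader:  # data stays on the device
        loss = module(images, labels)
        optimizer.step()
        module.lazy_reset_grad()

# Produce the inference-ready model; the output name is an assumption.
module.export_model_for_inferencing(
    "personalized_model.onnx", graph_output_names=["output"]
)
```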

Figure 1: High-level workflow for personalization with ONNX Runtime, starting from a model converted to ONNX, to generating training artifacts, to using the locally trained model for inference.

Key benefits

  • A memory- and performance-efficient local trainer that lowers resource consumption on the device (battery life, power usage, and training across multiple apps).
  • An optimized binary size that fits the strict constraints of edge devices.
  • Simple APIs and multiple language bindings that make it easy to scale across multiple platform targets (available now: C, C++, Python, C#, and Java; upcoming: JS, Objective-C, and Swift).
  • Developers can extend their existing ORT inference solutions to enable training on the edge.
  • The same ONNX model and runtime optimizations run across desktop, edge, and mobile devices, without having to redesign the training solution for each platform.

Applications of On-Device Training

The applications of On-Device Training fall into two broad categories:

Federated learning: This technique can be used to train global models on decentralized data without sacrificing user privacy. Federated learning involves updating a global model based on training that happens on edge devices. Each edge device trains its version of the global model on data local to that device and returns the model difference to the server. The server then aggregates these model differences from the various devices to update the global model, and the process repeats until the desired model quality is achieved. On-Device Training provides the local trainer that runs on each individual device; the federated learning infrastructure provides the orchestration, managing the outputs of the local trainers across a large number of devices to update the global model.
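
To make the device-side half of one federated round concrete, here is a hedged sketch. The transport helpers fetch_global_weights and send_to_server are hypothetical placeholders for the federated learning infrastructure, and module is a trained Module like the one in the earlier training loop:

```python
# One device-side federated round: pull global weights, train locally,
# send back only the model difference. Raw training data never leaves.
import onnxruntime as ort

global_weights = fetch_global_weights()  # hypothetical; flat float32 array

# Load the current global weights into the local trainer.
module.copy_buffer_to_parameters(
    ort.OrtValue.ortvalue_from_numpy(global_weights), trainable_only=True
)

# ... run a few local training steps on on-device data (see loop above) ...

# Compute and return the model difference for server-side aggregation.
local_weights = module.get_contiguous_parameters(trainable_only=True).numpy()
send_to_server(local_weights - global_weights)  # hypothetical transport call
```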

For instance, the healthcare industry can use federated learning to train models on data from different hospitals, with the data always staying on location, to provide better predictions for health conditions. Patient privacy is maintained because the data never leaves the devices or hospitals, and model quality improves because the global model is updated with the model changes contributed by individual hospitals. This should lead to a comprehensive global model with better overall performance for the end customer.

Personalized learning: This technique involves fine-tuning models on-device to create new, personalized models. Training is based on data on the device, which produces a model personalized for the end user locally. On-Device Training again acts as the local trainer, updating the model present on-device. This personalized model is then used for inference to provide an improved experience for the end customer.
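
Once exported, the personalized model is loaded with the standard ORT inference API. For example (the input name and placeholder image are assumptions):

```python
import numpy as np
import onnxruntime as ort

# Run the personalized model with the standard ORT inference API.
session = ort.InferenceSession("personalized_model.onnx")
image = np.random.rand(1, 3, 224, 224).astype(np.float32)  # placeholder input
(scores,) = session.run(None, {"input": image})  # "input" is an assumed name
print(scores.argmax(axis=1))  # predicted class per image
```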

Personalization has a variety of applications. For instance, it can be used to train text prediction, image detection, or image classification models locally on the device. In prediction or detection scenarios, the model is tuned to the individual user's behavior or data, so results are more customized for the end customer. For an image classification scenario like photo tagging, the customer can leverage on-device data to customize their photo-tagging experience, for their family and friends, without the data leaving their device.

Looking forward

We are continuously working on improving the feature set and platform support. In the next release, we will add support for iOS and web browsers, along with more optimizations to make On-Device Training even more efficient. We will also publish deep dives and tutorials in the coming months. We would love to hear your feedback and feature requests; please use our GitHub repository to leave comments and feedback.

Getting started

Curious to learn more, or want to see how your app can include On-Device Training? Check out these links to get started: