The increased usage of Internet of Things (IoT) devices has pushed Machine Learning (ML) to the edge. Traditionally, each IoT device collects and transmits data to the cloud, where it is analyzed and used to train ML models. However, this approach poses a privacy challenge when users are reluctant to share their personal data. Federated learning is a novel privacy-preserving paradigm for training ML models (such as neural networks) in a decentralized fashion. It aims to identify a well-performing model by transmitting only models between the server and the IoT devices, while leaving the users' data on device.
Let us consider the next-word prediction problem in a federated learning framework. There are a massive number of mobile phones in the system, and our goal is to train a neural network that learns language patterns and predicts the next word from the users' text data. Since this data is highly privacy-sensitive, it is better to keep it on the device rather than centralize it on a cloud server. Federated learning enables training by optimizing neural networks on the mobile phones and transmitting only the models to the server for aggregation. However, sharing the models between the phones and the server consumes significant amounts of device energy. The main goal is therefore to reduce the communication cost of federated learning as much as possible while still training a good predictive model.
A naive solution to the federated learning problem is to adopt the Stochastic Gradient Descent (SGD) procedure to train neural networks locally on each device (Local SGD). Specifically, with Local SGD, devices receive the model from the server, perform one SGD step on their local data, and transmit the model update (for example, the gradient) back to the server. The server averages the received gradients and performs one gradient descent update with the average gradient. Local SGD repeats this process over many iterations until convergence is achieved.
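One round of Local SGD can be sketched as follows; this is a minimal NumPy illustration in which the quadratic per-device losses, learning rate, and round count are toy assumptions, not the actual federated workload:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy per-device losses: L_k(theta) = 0.5 * ||theta - c_k||^2,
# so the gradient on device k is simply (theta - c_k).
centers = [rng.normal(size=3) for _ in range(4)]  # one optimum per device

def local_gradient(theta, c):
    return theta - c

theta = np.zeros(3)  # server model
lr = 0.5             # learning rate (placeholder value)

for _ in range(100):
    # Each device receives theta, computes one gradient on its own data,
    # and sends the gradient back to the server.
    grads = [local_gradient(theta, c) for c in centers]
    # The server averages the gradients and takes one descent step.
    theta = theta - lr * np.mean(grads, axis=0)

# theta converges to the minimizer of the averaged loss, mean(centers).
```

Note that every round costs one full model exchange per device, which is exactly the communication overhead the methods below try to reduce.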
Federated Averaging (FedAvg) extends the Local SGD approach by leveraging more local computation to decrease the communication cost. Specifically, the devices receive the server model, perform a fixed number of SGD updates, and transmit the trained model back to the server, where the number of local SGD updates is a hyperparameter. The server forms the new model by averaging the parameters of the device models. FedAvg empirically demonstrates substantial communication savings on benchmark datasets, including MNIST and CIFAR10.
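A FedAvg round can be sketched as follows, again as a hedged NumPy illustration with toy quadratic device losses; the number of local steps and the learning rate are placeholder hyperparameters:

```python
import numpy as np

rng = np.random.default_rng(1)
centers = [rng.normal(size=3) for _ in range(4)]  # per-device optima (toy)

def local_train(theta, c, steps=10, lr=0.1):
    """Run a fixed number of local SGD steps on one device.

    The loss is 0.5 * ||theta - c||^2, so each gradient is (theta - c)."""
    theta = theta.copy()
    for _ in range(steps):
        theta -= lr * (theta - c)
    return theta

server_model = np.zeros(3)
for _ in range(50):
    # Devices train locally for several steps, then send their models back.
    device_models = [local_train(server_model, c) for c in centers]
    # The server averages the parameters of the received device models.
    server_model = np.mean(device_models, axis=0)
```

Each round now carries many local updates per model exchange, which is where FedAvg's communication savings come from; the heterogeneity problem discussed next appears when the per-device losses are less symmetric than in this toy setup.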
An important problem with FedAvg is that it is not robust to data heterogeneity. Consider the next-word prediction problem mentioned above, and assume that the users have different word choices depending on their location, political ideology, and age. Due to these differences in preferences, the dataset of one user will differ significantly from the dataset of another. In other words, the data is non-identically distributed across the users. Thus, the stationary points of the device objectives do not align with those of the global objective, and the way FedAvg aggregates the device models on the server generates bias in the final model. As a result, the accuracy of FedAvg degrades significantly as the data heterogeneity grows. FedAvg mitigates this problem by reducing the number of local training steps on the device, which in turn significantly increases the communication cost.
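As an aside, a common way to simulate such non-identically distributed device datasets in experiments is a Dirichlet label partition, where a small concentration parameter yields highly skewed label distributions per device. The sketch below is one plausible implementation; the function name and the concentration value 0.3 are our own choices, not details from this post:

```python
import numpy as np

rng = np.random.default_rng(0)

def dirichlet_partition(labels, num_devices, alpha):
    """Split sample indices across devices with per-class proportions drawn
    from Dirichlet(alpha); small alpha -> highly non-IID splits."""
    device_indices = [[] for _ in range(num_devices)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Fraction of this class assigned to each device.
        props = rng.dirichlet(alpha * np.ones(num_devices))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for dev, part in enumerate(np.split(idx, cuts)):
            device_indices[dev].extend(part.tolist())
    return device_indices

labels = rng.integers(0, 10, size=1000)  # toy labels (10 classes)
splits = dirichlet_partition(labels, num_devices=5, alpha=0.3)
```

With `alpha=0.3`, most devices end up dominated by a few classes, mimicking the user-preference skew described above; `alpha` large recovers a near-IID split.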
We propose solving the heterogeneity problem in federated learning by modifying the local objective function on the device side with a novel dynamic regularizer in each communication round. This dynamic regularizer aligns the stationary points of the modified device objectives with that of the global objective, so that minimizing the modified device-level objectives gives a globally stationary solution.
Specifically, we modify the device-level objectives with a linear term and a quadratic penalty term. The added terms depend on the device state and the received server model. These terms explicitly debias the local loss, so that when we fully optimize the problem, we are no longer pulled towards the device minima. Debiasing allows us to undertake full minimization and reach a trained federated model with fewer model transmissions. As a result, our method yields substantial communication savings.
Although the dynamic regularization requires additional state terms compared to FedAvg, it does not increase the per-round communication cost. This is because the state terms are stored on the devices and the server and are never transmitted; only the device and server models are exchanged, as in FedAvg.
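In symbols (our shorthand here; the exact coefficients and the state update rule are spelled out in the paper), in round t device k minimizes a modified objective of roughly the form

```latex
\min_{\theta}\; \mathcal{L}_k(\theta) \;-\; \langle h_k,\, \theta \rangle \;+\; \frac{\alpha}{2}\,\lVert \theta - \theta^{\,t-1} \rVert^2
```

where \(\mathcal{L}_k\) is the local loss, \(\theta^{\,t-1}\) is the received server model, \(h_k\) is the state stored on device k (updated each round from the drift between the device model and the server model), and \(\alpha\) controls the strength of the quadratic penalty. At a fixed point, the linear term exactly cancels the local gradient, so a common stationary point of the modified device objectives is a stationary point of the global objective.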
To provide more insight into dynamic regularization, let us visualize the optimization process in a simplified scenario. Here, there are two devices in the system, and each device objective is a quadratic function of a single parameter θ. Figure 1 shows the loss curve over the model parameter θ for each device (L1(θ), L2(θ)), as well as the averaged loss curve (also called the global loss) from the server. The optimal θ that minimizes each loss is marked in the top-left part of the figure. Due to the non-identical data distributions on the two devices, the optimal θ differs across the devices and the server. This results in a misalignment between the stationary points of the device objectives and that of the server objective.
Figure 1: Toy example with two devices. Device losses L1, L2 and the global loss ½[L1 + L2] are shown along with their minimizers.
In our solution, we modify the device objective functions with dynamic regularization to align the device and server models. Figure 2 shows the optimization process of our algorithm on the same simplified problem. As mentioned earlier, we add a linear term and a quadratic term on top of the local losses and optimize the modified objective. Since these terms depend on the received server model as well as a local state, we get a dynamic loss surface in each communication round. Namely, our algorithm dynamically modifies the local losses in each round so that they eventually align with the global loss. Consequently, our solution reaches the global optimum.
Figure 2: The trajectory of local losses and device models using our solution.
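To make the alignment concrete, the sketch below runs a dynamically regularized variant of the two-device quadratic example. The closed-form local solve, the specific state update rule, and the penalty strength 0.1 are one plausible instantiation of the description above, not the exact algorithm; the precise update rules are in the paper:

```python
import numpy as np

# Two devices with quadratic losses 0.5 * a_k * (theta - c_k)^2 whose
# curvatures differ, so the average of the local minimizers,
# (c_1 + c_2) / 2 = 1.0, is NOT the global minimizer.
a = np.array([1.0, 4.0])              # curvatures (toy values)
c = np.array([0.0, 2.0])              # local minimizers
global_opt = (a * c).sum() / a.sum()  # minimizer of the averaged loss: 1.6

alpha = 0.1           # strength of the quadratic penalty (assumed)
server = 0.0
states = np.zeros(2)  # one scalar state per device, never transmitted

for _ in range(1000):
    # Each device fully minimizes its modified objective
    #   0.5*a_k*(theta - c_k)^2 - states_k*theta + 0.5*alpha*(theta - server)^2
    # (closed form here because everything is quadratic),
    models = (a * c + states + alpha * server) / (a + alpha)
    # then folds the drift from the server model into its state.
    states -= alpha * (models - server)
    # The server combines the device models and states into the new model.
    server = models.mean() - states.mean() / alpha

# server converges to the global minimizer (1.6), not the naive
# average of the local minimizers (1.0).
```

At convergence, each state equals the local gradient at the global minimizer, so the states sum to zero and the linear terms exactly cancel the pull towards the device minima, which is the alignment shown in Figure 2.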
Figure 3: Smoothed CIFAR10 convergence curves for the identical and non-identical data distribution settings. Our solution results in 2.9X and 4.2X communication savings compared to the best competitor to reach the same target accuracy level in the IID and non-IID settings, respectively.
In our recently published work, we compare our algorithm against alternative methods on benchmark datasets such as MNIST, E-MNIST, CIFAR10, CIFAR100, and Shakespeare. Figure 3 shows the convergence curves of our method and the alternatives on the CIFAR10 dataset with 100 devices. The plot on the left shows the case where the user datasets are identically distributed (IID), so each user has a similar dataset. The plot on the right shows the non-IID setting, where the user datasets are non-identically distributed across the devices. As seen in Figure 3, our solution converges faster than the alternative methods in both settings. For instance, in the non-IID setting, our solution reaches the target accuracy of 82.5% with 4.2X less communication than the best competitor.
In summary, we identified the inconsistency between the local losses and the global loss in federated learning when the device datasets are non-identically distributed, and we proposed modifying the device objectives to debias the local losses. Debiasing allows our solution to achieve significant communication savings even in heterogeneous settings.
We believe that our method is especially helpful for industry applications of federated learning. In a real-world setting, user preferences can differ widely, so the local datasets are non-identically distributed among the users. Our method accounts for this heterogeneity and performs better than the alternatives in such scenarios.
Our paper was presented at the ninth International Conference on Learning Representations (ICLR) as an oral presentation. You can find our presentation of the paper at this link. Feel free to reach out to me by email with any questions.
Learn more about our ML research, from robust object detection with Stochastic-YOLO to SCALE-Sim, a cycle-accurate NPU simulator for your research experiments.
 Brendan McMahan and Daniel Ramage. Federated learning: Collaborative machine learning without centralized training data, Apr 2017.
Sebastian U. Stich. Local SGD converges fast and communicates little. International Conference on Learning Representations (ICLR), 2019. arXiv:1805.09767.
Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics, pages 1273–1282, 2017.
Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
 Alex Krizhevsky et al. Learning multiple layers of features from tiny images. Technical report, 2009.
Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, and Vikas Chandra. Federated learning with non-iid data. arXiv preprint arXiv:1806.00582, 2018.