更新时间:2021-12-17 GMT+08:00
分享

持续学习算法

背景信息

当终端数据结构非独立同分布(non-i.i.d), 终端模型存在差异。随着联邦轮次增多,联邦效果下降,全局模型通用性变差,所以引入持续学习。

算法介绍

采用弹性权重固化(Elastic Weight Consolidation, EWC)算法,旨在防止多轮联邦任务过程中,出现网络灾难性遗忘,提升联邦性能。

应用场景:终端数据分布差异较大的联邦场景,同时终端具备计算Fisher信息的能力。

实现方案:终端在损失函数中增加一个惩罚项,通过对影响联邦性能较大的网络参数减小惩罚力度,对影响较小的网络参数增大惩罚力度,使得局部模型能收敛到一个共享的最优值。其中,网络参数对联邦性能的影响程度用Fisher Information表示。

损失函数的惩罚项获取方式:

  1. 终端训练结束后,根据训练数据计算Fisher信息阵,得到本轮网络参数重要程度并存于ewc.h5文件。
  2. 终端上传ewc.h5文件至服务端。
  3. 服务端对Fisher信息平均化,再下发给终端以此作为损失函数的惩罚项。

算法优势

算法能在终端数据分布差异较大的联邦场景下,适当地提高联邦性能。同时终端设备需要有强大的计算能力。

参考论文

Shoham N, Avidor T, Keren A, et al. Overcoming Forgetting in Federated Learning on Non-IID Data[J]. arXiv preprint arXiv:1910.07796, 2019.

分享:

    相关文档

    相关产品

close