自适应迁移鲁棒特征的个性化联邦医学图像分类

2024-03-20 10:32陆森良冯宝徐坤财陈业航陈相猛
中国图象图形学报 2024年3期
关键词:用户端联邦全局

陆森良,冯宝,徐坤财,陈业航,陈相猛

1.桂林电子科技大学电子工程与自动化学院,桂林 541004;2.桂林航天工业学院智能检测与信息处理实验室,桂林 541004;3.江门市中心医院医学影像智能计算及应用实验室,江门 529000

0 引言

在医学图像任务中,由于医学图像数据获取难度大(Li 等,2019;陈弘扬 等,2021),同时因患者隐私保密规定,各医疗机构间不能共享患者数据,极大限制了数据规模。为打破数据壁垒,联邦学习的提出(McMahan 等,2023)保证了各用户端在不共享数据的情况下,以一种去中心化的方式训练本地模型和聚合全局模型,其组织各用户端分享本地模型参数,协作训练一个全局模型,并将该模型部署到用户端。然而各个医院所拥有的数据存在因数据采集设备、搜集规格和图像质量差异而导致的特征异质性分布等问题(Jiang 等,2022),使得数据总体呈现非独立同分布(non independent identically distribution,Non-IID)(Kairouz 等,2021),从而导致各本地模型间参数一致性差,聚合得到的全局模型难以在所有用户端都具有良好的性能(McMahan 等,2023;Jiang 等,2022;Kairouz 等,2021;Li 等,2020;Li 等,2021a)。因此,有效地全局聚合和部署到用户端是联邦学习研究中的一个热点问题。

Li 等人(2020)基于联邦平均算法(federated averaging,Fedavg)框架(McMahan 等,2023)提出在模型训练中添加一正则化项以减少本地模型与全局模型间的参数差异;模型对比联邦学习算法(modelcontrastive federated learning,MOON)(Li 等,2021a)基于全局模型与本地模型间的特征余弦相似度构建对比损失以加强本地模型与全局模型参数的一致性。Jiang 等人(2022)提出使用幅值归一化技术调和图像低频特征以减轻异质性特征影响,并在本地模型训练中添加扰动以拓宽模型收敛区域。尽管这些方法可以加强本地模型参数一致性,但是仍采用固定权重比例(基于数据规模比例或平均分配)合成全局模型,对模型整体性能产生影响。而且将全局模型直接部署到用户端,对个性化问题(Huang 等,2021)缺乏考虑。

为此,部分研究者提出个性化联邦学习,致力于构建个性化本地模型以提升本地模型的泛化性能。Deng 等人(2020)提出自适应联邦学习算法:自适应个性化联邦学习算法(adaptive personalized federated learning,APFL),该算法基于混合权重α实现本地模型和全局模型的参数合成,构建个性化本地模型;Li等人(2021b)提出一种保持用户端本地模型的批归一化(batch normalization,BN),层参数不变,共享和聚合其他模型层参数的个性化联邦学习算法:本地批处理归一化联邦平均算法(federated averaging with local batch normalization,FedBN);元联邦学习算法(meta federated learning,MetaFed)(Chen 等,2023)基于无服务端和分组用户端的联邦应用场景,提出一种环形信息传播结构,并选择全连接层进行模型间的知识蒸馏以构建个性化模型。Wang 等人(2019)将迁移学习用于联邦学习,提出将训练好的全局模型的部分或所有参数在局部数据上重新训练以实现个性化。Schneider 和Vlachos(2020)提出后期塑造法,让训练过的模型使用来自不同分布的数据进行更新。上述方法可以实现个性化模型的构建,但是需要基于大量试验或以往经验选择部分模型层来继承全局模型参数。

因此,提出一种基于特征迁移的自适应个性化联邦学习算法(adaptive personalized federated learning via feature transfer,APFFT)。该算法主要包括构建个性化本地模型和聚合全局模型两个部分。1)在个性化本地模型构建中,为降低全局模型中异质性特征信息的影响,提出鲁棒特征选择网络(robust feature selection network,RFS-Net)自动计算全局模型向本地模型迁移特征时的特征通道迁移权重和特征迁移量,以决定迁移哪些特征和迁移去哪,然后基于特征通道迁移权重和特征迁移量构建特征迁移损失项以加强本地模型对有效特征的注意力,实现自动识别和迁移有效特征的目的。2)考虑本地模型中存在的异质性特征信息,提出自适应聚合网络(adaptive aggregation network,AA-Net)基于本地模型特征计算各本地模型向全局模型迁移特征时的特征迁移权重,并构建聚合损失项以引导全局模型从各本地模型中识别和迁移鲁棒特征,然后基于全局模型交叉熵的变化实时更新特征迁移权重,从而逐步过滤各本地模型中的异质性特征信息。

1 基于特征迁移的自适应个性化联邦学习算法APFFT

如图1所示,提出的APFFT 算法共包含N+1个用户端,以其中1 个用户端为服务端,接收其余N个用户端上传的本地模型参数后进行全局模型的参数聚合。

图1 APFFT算法框架Fig.1 The framework of APFFT algorithm

该算法主要包括两个部分:1)本地模型的自适应训练。首先,由用户端接收服务端下发的全局模型参数,基于全局模型和本地模型k在本地数据的特征差异计算二范数,然后由鲁棒特征选择网络(robust feature selection network,RFS-Net)基于全局模型特征生成特征通道迁移权重和迁移量,并结合二范数值计算特征迁移损失,最后更新本地模型参数。2)全局模型自适应聚合。服务端接收各用户端上传的本地模型参数后,基于全局模型和各本地模型在服务端数据的特征差异逐一计算二范数,最后由AA-Net基于各本地模型特征计算迁移权重,构建聚合损失以更新全局模型参数。

1.1 基于RFS-Net的本地模型自适应训练

在联邦学习应用场景中,由于各中心间的数据分布存在差异,本地模型如以Fine-tune 方式(Razavian 等,2014)进行训练,会将全局模型中的异质性特征信息一并引入,影响本地模型性能。因此,为降低异质性特征影响,让模型注意力更多地集中在有效特征,使用RFS-Net 引导全局模型到本地模型的特征迁移,基于特征通道迁移权重和迁移量构建迁移损失函数以约束本地模型进行自适应训练,实现本地模型个性化构建。

1.1.1 特征迁移

设有N+1 个数据中心,记为{Dglobal,D1,…,Dk,…,DN},全局模型和本地模型分别记为S(server)和C1,…,Cn,{x}为图像数据,{y}为图像标签,({x},{y}) ∈Dk。模型间具有高度一致性的特征,其适用价值更高(Romero 等,2015;Jang 等,2019;黎英和宋佩华,2022)。因此,基于全局模型和本地模型在Dk上的特征计算特征差异程度,具体为

式中,(m,n) ∈∂,∂为预定义的模型层匹配对,匹配对总数记为P。Sm(x)为全局模型第m模型层的中间特征图(x)为本地模型k第n模型层的中间特征图为被θk参数化可逐点卷积的线性变换,其保证了两模型间特征图的数量一致,此外,该参数仅在本地模型训练时被更新和使用,并与组成了本地模型参数θk。在计算特征图间的L2值时,如存在特征图尺寸不一致,则使用双线性插值法修正特征图尺寸。然后,如图2(a)所示,特征图间的像素点逐一对应计算差异度diffi,v,最后取平均值得到L2值。

图2 特征通道匹配Fig.2 Feature channels matching((a)compute L2;(b)feature channels matching)

进行特征迁移时,首先,定义RFS-Netk的子网络,其是由P组全连接层、池化层和softmax 层构成的全连接网络。该网络以全局模型的特征为输入,以softmax 层为输出,计算特征通道迁移权重并由该权重来选择迁移特征参与本地模型的训练。具体为

其次,定义RFS-Netk的子网络,是由P组全连接层、池化层和ReLU 6 激活层构成的全连接网络,其以全局模型的特征为输入,以ReLU 6 激活层为输出,计算全局模型第m层到本地模型第n层的迁移量具体为

1.1.2 个性化本地模型的自适应训练

基于L2和构建特征迁移损失函数以约束本地模型训练,该损失函数定义为

式中,H×W为和Sm(x)输出特征图的尺寸,i∈{1,2,⋅⋅⋅,H},v∈{1,2,⋅⋅⋅,W}。由式(2)计算得到。

如图3 所示,本地模型进行自适应训练的主要步骤如下:

图3 基于特征迁移的本地模型自适应训练示意图Fig.3 Local model adaptive training based on feature transfer

1)按式(1)计算L2;

2)计算特征通道迁移权重,由k的子网络按式(2)计算

3)计算迁移量,由k的子网 络按 式(3)计算

4)根据式(6)计算Lwfm,更新本地模型参数θk。

由式(5)(6)可知,当本地模型参数与全局模型参数存在较大差异时,两者在Dk上的特征差异,即L2值也随之变大,此时由Lwfm约束本地模型加强对该模型层特征的注意力,如该特征无法降低本地模型交叉熵,则其对应的和会在更新RFS-Net参数后降低。

此外,为加强本地模型与全局模型之间的参数一致性,减少后续全局模型参数的聚合难度,计算对比损失,具体为

最后,得到训练本地模型Ck参数的总损失函数Ltotal,具体为

式中,Lorg(θk|x,y)为Ck的交叉熵损失,β>0 为一个控制迁移损失权重的超参数,μ为一个控制对比损失权重的超参数。本地模型Ck的优化目标为

1.2 基于AA-Net的全局模型自适应聚合

由于多中心数据存在的Non-IID问题,各本地模型参数受自身数据的异质性特征影响,模型间的参数一致性变差(Jiang 等,2022)。若采用传统聚合方法,直接为每个本地模型参数分配一个权重,按权重直接合成全局模型参数(McMahan 等,2023;Jiang等,2022;Li 等,2020;Li 等,2021a),可能会分散模型对有效特征的注意力,难以提升全局模型的泛化能力。因此,在聚合全局模型参数时,为了加强全局模型对鲁棒特征的注意力,从各本地模型中迁移特征时,由AA-Net基于全局模型交叉熵变化对特征进行筛选,从而逐步过滤各本地模型中的异质性特征信息,让全局模型参数在迁移中自动完成聚合,由此定义一个全局模型自适应聚合优化项,具体为

式中,δ为AA-Net 的模型参数。AA-Net 由N组RFSNet构成,使用全局模型交叉熵更新网络参数。

如图4 所示,在提出的全局模型自适应聚合方法中,将其中一个用户端作为服务端,其模型作为全局模型,并将其数据记为Dglobal。当服务端接收各个本地模型参数后,首先,依据式(2)(3)基于Ck在Dglobal上的特征计算得到和再由式(6)(11)计算Lagg(θglobal|x,δ),以更新server参数。

图4 基于特征迁移的全局模型自适应聚合Fig.4 Global model adaptive aggregation based on feature transfer

根据式(10)优化项进行全局模型参数聚合时,由于不同本地模型对Dglobal的特征会存在差异,对于能降低server交叉熵的相关特征,AA-Net为其提高和因此在迁移中持续发挥作用的特征,会被AA-Net持续赋予较高的和。相应地,全局模型在聚合过程中,也会始终对这些特征保持较高的注意力。

APFFT 联邦学习训练框架如算法1 所示。在每个信息交互的循环中,server既要接收来自各个用户端的本地模型参数,又要将全局模型参数传输给各个用户端。Ck和server 均使用随机梯度下降法(stochastic gradient descent,SGD)更新模型权重参数、RFS-Net和AA-Net参数。

算法1:APFFT框架。

2 实验结果

为了评估APFFT 方法在Non-IID 数据集上的表现,在3 种医学图像分类任务上开展实验:肺结核肺腺癌分类任务、乳腺癌组织学图像分类任务和肺结节良恶性分类任务,并将实验结果与联邦平均算法(federated averaging,Fedavg)(McMahan 等,2023)、联邦近端算法(federated proximal,FedProx)(Li 等,2020)、MOON(Li 等,2021a)、协调局部和全局漂移的联邦学习算法(harmonizing local and global drifts in federated learning,HarmoFL)(Jiang 等,2022)4 种联邦学习方法相比较。实验中所采用的卷积神经网络(convolutional neural networks,CNN)模型结构为ResNet18(He 等,2016),并且以在ImageNet(Li 等,2010)数据集上训练好的网络参数为CNN 模型初始化参数。

2.1 数据集与实验设置

2.1.1 肺结核肺腺癌分类任务

使用的数据源自5 家医院,其临床数据信息如表1 所示,样本标签分为肺结核和肺腺癌,个别中心正负样本比例悬殊、各中心间结节特征分布差异大。在APFFT 框架下,以中心1 为服务端,其他4 个中心作为用户端。本文使用Focalloss(Lin 等,2020)计算交叉熵损失,其中alpha 设置为0.76,gamma 设置为2;使用SGD 算法更新RFS-Net、AA-Net 和CNN 网络模型参数,动量为0.9,其中RFS-Net 和AA-Net 的学习率和衰减率均为0.000 1,CNN 网络的学习率和衰减率为0.001和0.000 5。

表1 肺结核肺腺癌分类任务中各中心数据临床信息分布Table 1 The distribution of data and clinical information of each center in the TBG and LAC classification

2.1.2 乳腺癌组织学图像分类任务

本文使用公开数据集Camelyon17(the cancer metastasis in lymph nodes challenge 2017),其由5 家医院的组织学图像组成,共450 000 幅图像(Bandi等,2019),每家医院的图像标签均分为正常组织和肿瘤组织。以中心1 为服务端,其他中心为用户端。本文使用标准交叉熵计算交叉熵损失,CNN 网络、RFS-Net 和AA-Net 的学习率和衰减率均为0.000 1,使用SGD算法更新网络参数,动量为0.9。

2.1.3 肺结节良恶性分类任务

使用公开数据集肺影像数据库联盟(the lung image database consortium,LIDC),该数据集由美国国家癌症研究所发起收集,其数据来源于7 家研究机构和8 家医学图像公司,共有1 018 个病例,其中≥3 mm 的肺结节,其恶性程度分为1~5 级,其中等级3为不确定恶性程度(McNitt-Gray 等,2007),将恶性程度为1~2 级的病灶归为良性,将恶性程度为4~5 级的病灶归为恶性,最后将1 746 个病灶纳入数据集。为模拟联邦学习应用场景,将1 018 个病例随机分为4 个中心,其标签分布情况如表2 所示。以中心1 为服务端,其他中心为用户端。本文使用标准交叉熵计算交叉熵损失,使用SGD 优化法更新RFS-Net、AA-Net 和CNN 网络模 型参数,动量为0.9,其中RFS-Net 和AA-Net 的学习率和衰减率均为0.000 1,CNN 网络的学习率和衰减率为0.001 和0.000 1。

表2 肺结节良恶性分类任务中各中心数据分布表Table 2 The data distribution of each center in the classification of benign and malignant pulmonary nodules

本文使用联邦学习框架超参数的默认设置为:本地模型训练循环E为1,特征迁移循环G为2。做迁移层对(m,n)匹配时,以ResNet18 的4 个block 的输出层来两两匹配,预先设定的匹配层对(m,n)共16对,λm,n初始值均为1。

2.2 实验结果及对比分析

本文将AUC(area under the curve)、准确率(accuracy)、阳性预测值(positive predictive value,PPV)、阴性预测值(negative predictive value,NPV)和召回值(recall)用于模型的对比和评价。其中,AUC反映了算法分类的综合性能;准确率为识别结果正确的样本(包括正负样本)占所有样本的比例;PPV为真阳性样本在识别结果为阳性的样本中的比例;NPV为识别为真阴性样本在识别结果为阴性的样本中的比例;recall为识别结果为正确的正样本占所有正样本的比例。

2.2.1 肺结核肺腺癌分类任务

肺结核(tuberculous granuloma,TBG)肺腺癌(lung adenocarcinoma,LAC)分类任务,如表3所示,展示了各方法在5个数据中心测试集上的端对端准确率和AUC,各个模型的ROC(receiver operating characteristic)曲线如图5所示(其曲线下面积为AUC数值)。与Fedavg相比较,FedProx、MOON、HarmoFL在5个中心上的AUC都有一定的提升,但在中心2和中心4上有明显的性能限制。相较于上述4 种算法,提出的APFFT方法在5个中心测试集上AUC都有明显提升,AUC 分 别为0.791 5,0.798 1,0.760 0,0.705 7,0.806 9。此外,从PPV、NPV 和recall 指标的比较来看,提出方法在具有高PPV 的情况下较大提升了NPV。

表3 不同联邦学习算法在肺结核肺腺癌分类任务上的结果Table 3 The results for TBG and LAC classification of different federated learning methods

图5 肺结核肺腺癌分类任务各中心测试集的ROC曲线Fig.5 The ROC curves of each center test cohort in TBG and LAC classification((a)center 1 testset;(b)center 2 testset;(c)center 3 testset;(d)center 4 testset)

FedProx、MOON 和HarmoFL 3 种算法在中心2和中心4 上出现性能限制的主要原因是:1)正负样本比例悬殊。由表1 知,在中心2 中,训练集的正负比例达到10∶1,测试集达到26∶1。在中心4中,训练集约3∶1,测试集4∶1。由表3 中3 种算法在两个数据中心的准确率和AUC 表现可知,模型对正样本过度学习,故准确率尚可而AUC 不佳;2)结节特征存在较大差异。在肺结核和肺腺癌诊断中,临床信息中的性别、年龄、结节大小具有较大参考价值,由表1可知,在中心2 中,总体的患者年龄和结节大小偏小,而中心4 中结节大小总体偏大,与全局平均水平具有一定差距。

2.2.2 乳腺癌组织学图像分类任务

对于乳腺癌组织学图像分类任务,如表4 所示,对于准确率指标,在各中心上,相比于其他联邦学习算法,所提方法准确率取得了约1%~9%的提升,5 个中心的测试集准确率分别为0.984 9,0.980 8,0.983 5,0.982 6,0.983 4。此外,在PPV、NPV 和recall指标的比较上也均为最优。

表4 不同联邦学习算法在乳腺癌组织学图像分类任务上的结果Table 4 The results for breast cancer histology images classification of different federated learning methods

Fedavg 方法以固定权重合成全局模型参数,影响了全局模型对有效特征的反应,再以fine-tune 方式训练本地模型,全局模型中存在的异质性特征信息可能无益于本地任务;MOON 使用对比损失约束本地与全局模型的参数差异,各中心准确率有一定提升,但中心2 的结果明显较低;而FedProx 使用正则项约束本地模型训练以降低本地模型与全局模型间的参数差异,中心2 的准确率有明显提升;HarmoFL通过调和图像低频特征和添加扰动项,使模型参数一致性得到加强,各中心准确率有一定提升。

2.2.3 肺结节良恶性分类任务

对于肺结节良恶性分类任务,如表5 所示,Fedavg 在中心2 和中心3 上得到良好的AUC 指标,但在中心1 上AUC 指标有很大的下降;与其相比较而言,FedProx 在中心1 上有较大性能提升,但在中心3 上AUC 有所下降;MOON 相较Fedavg、FedProx、HarmoFL 而言,在4 个中心上均得到了较好的AUC指标,但中心3 上的AUC 相对较低。与4 种方法相比,提出的APFFT 在4个中心都有全面的性能提升,AUC 分别为0.809 7,0.849 8,0.784 8,0.792 3,各个模型的ROC 曲线对比如图6所示。此外,在PPV、NPV 和recall 指标上,提出方法在中心1、2、4 和5 均有较好的性能,在中心3上,在具有较高NPV 的情况下PPV得到一定的提升。

图6 肺结节良恶性分类任务各中心测试集ROC曲线Fig.6 The ROC curves of each center test cohort in classification of pulmonary nodules and malignant pulmonary nodule((a)center 1 testset;(b)center 2 testset;(c)center 3 testset;(d)center 4 testset)

Fedavg、FedProx、MOON 和HarmoFL 4 种方法采用fine-tune 方式训练本地模型,引入了全局模型参数中的异质性特征信息,导致个别中心的结果较好,个别中心的结果很差,难以实现全面提升。

4 结论

面对多中心医学图像数据中存在的异质性特征,已有联邦学习算法的全局模型聚合和本地模型个性化存在灵活性较低、适应性差的问题。为此,本文提出一种基于特征迁移的自适应个性化联邦学习算法APFFT。在个性化本地模型构建中,提出RFSNet 自动识别和选择由全局模型向本地模型迁移的有效特征,其根据全局模型在本地数据的特征计算特征通道迁移权重和迁移量,最后结合特征间差异性构建迁移损失函数以加强本地模型对有效迁移特征的注意力;在聚合全局模型中,提出AA-Net 从多个本地模型向全局模型迁移特征,其基于本地模型特征和全局模型交叉熵变化更新迁移权重,然后构建聚合损失以过滤各本地模型中的异质性特征信息。实验中,对肺结核肺腺癌分类任务中的5 个中心临床信息进行统计分析,发现部分中心的临床特征与整体数据集水平有较大差异,对比算法在这些中心上得到的端对端AUC 结果较差。为进一步验证APFFT 算法性能,在两个公开医学图像数据集上进行测试,结果表明,提出的方法在医学数据集中各中心有较好的适应性,并在同类算法中保持较高的模型分类性能。基于特征迁移的思想构建个性化本地模型和聚合全局模型,不仅可以降低和过滤异质性特征信息,还能使模型将注意力放在鲁棒特征上,提升模型在本地医学数据上的泛化能力。医学图像样本规模有限,采用这种基于特征迁移的自适应方法聚合全局模型和构建个性化本地模型,可以为未来形成一种联邦学习在医学图像应用的自动化机制提供有益经验。但目前迁移学习与联邦学习的结合还停留在较浅的层面,后续将考虑从图像卷积特征入手,寻找迁移特征与本地数据深层次结合的方法。

猜你喜欢
用户端联邦全局
Android用户端东北地区秸秆焚烧点监测系统开发与应用
Cahn-Hilliard-Brinkman系统的全局吸引子
量子Navier-Stokes方程弱解的全局存在性
一“炮”而红 音联邦SVSound 2000 Pro品鉴会完满举行
303A深圳市音联邦电气有限公司
落子山东,意在全局
基于三层结构下机房管理系统的实现分析
基于三层结构下机房管理系统的实现分析
一种太阳能户外自动花架电气系统简介
一种改进型算法在用户端的性能评估的应用