AAAI2024|上交等提出自適應間距強化對比學習,增強多個模型的分類能力

©PaperWeekly 原創 · 作者 | 張劍清
單位 | 上海交通大學、清華大學(AIR)
研究方向 | 聯邦學習
本文介紹的是我們的一篇收錄於 AAAI 2024 的論文,主要考慮的是資料異質和模型異構場景下的聯邦學習框架。在異構聯邦學習中,由於模型架構不同,傳統聯邦學習中的引數聚合方法不再適用,取而代之的是基於知識蒸餾的知識共享方法。
在這些方法中,我們關注不引入額外資料集的(data-free)這一類方法。這類方法普遍透過共享類別表徵向量(prototype)實現,但在模型架構差異較大的場景,每個客戶機生成的表徵向量差異懸殊,直接在伺服器端聚合表徵向量會造成表徵能力的下降。於是,我們提出一種在伺服器端基於自適應間距強化的對比學習來提高表徵向量的表徵能力的方法 FedTGP,進一步提升客戶端模型的分類能力。
論文標題:
FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning
論文連結:
https://arxiv.org/abs/2401.03230
程式碼連結:
https://github.com/TsingZ0/FedTGP(含有PPT和Poster)
執行實驗所需倉庫-個性化聯邦學習演算法庫:
https://github.com/TsingZ0/PFLlib
執行實驗所需倉庫-異構聯邦學習演算法庫:
https://github.com/TsingZ0/HtFLlib

異構聯邦學習背景

傳統聯邦學習透過在每一次迭代中傳遞模型引數的方式實現知識共享,但該方式存在侷限,無法適應更廣泛的場景,尤其是不易尋找到參與聯邦學習的客戶機。客戶機在參與聯邦學習之前,有自己本地的模型訓練任務,也有自研的模型架構和訓練得到的模型引數。每個客戶機參加聯邦學習的動機是為了透過聯邦學習增強自己模型的表現能力。若強制要求參與的客戶機都使用相同的模型結構且進行模型引數共享,則需要每個客戶機重新訓練模型。
另一方面,每個客戶機訓練得到的模型引數也是一種數字資產,尤其是在大模型時代保護模型引數的智慧財產權尤為重要。此外,共享模型引數也有通訊量大的問題。透過允許異構模型參與聯邦學習,並共享輕量化的知識載體,異構聯邦學習拓展了傳統聯邦學習的邊界,變得更加實用。
▲ 圖1:異構聯邦學習技術
目前異構聯邦學習技術還未形成統一的知識共享機制,我們考慮一種輕量化且不需要額外資料的知識共享機制:共享 prototype。本文考慮的是面向影像的多分類任務,其 prototype 的定義就是每個類別的代表性特徵向量,可透過平均該類所有的特徵向量獲得。現有工作中,FedProto [1] 是這方面最具代表性的方法之一,如下圖所示。
▲ 圖2:異構聯邦學習中使用prototype作為知識載體

FedProto的侷限性

雖然 FedProto 得到了廣泛使用,但之前的工作要麼將其用在傳統聯邦學習場景(異構聯邦學習技術在傳統場景也都適用),要麼採用異構性不強的異構模型(比如增減全連線層數和改變 CNN 網路的卷積核等)。在這些場景下,透過加權平均聚合 prototype 的方式確實具有不錯的表現。 
但當我們考慮更一般的場景:參與聯邦學習的客戶機訓練的模型的架構差異巨大,比如兩層 CNN 模型和 ResNet-152 模型。此時 FedProto 的 prototype 聚合方法就出現了一些問題。我們觀察到,由於模型架構相差巨大,不同模型的特徵提取能力也天差地別,它們生成的 prototype 也天差地別。
當我們透過加權平均去計算全域性 prototype(global prototype)時,具有較好表徵能力(不同 prototype 之間的間距(margin)較大)的 prototype 會被較差表徵能力的 prototype 影響,導致最終得到的 global prototype 表徵能力弱於最好的客戶機模型。我們稱這種現象為間距收縮(margin shrink),如下圖所示。進一步地,當這個特徵提取能力最好的客戶機模型使用了 global prototype 之後,其表徵能力則會下降。
▲ 圖3:FedProto在模型異構性較大場景下的間距收縮現象(Cifar10)

自適應間距強化的對比學習(ACL)

為了解決上述間距收縮的問題,我們提出了一種自適應間距強化的對比學習方法(ACL),如下圖所示。
▲ 圖4:FedProto與FedTGP的對比。其中圓形代表客戶機上傳的prototype,三角形代表global prototype。
該方法的核心思想是訓練一個 global prototype,使其能夠最大限度地保留最強客戶機模型生成的 prototype 的表徵能力,同時也汲取來自其他客戶機的 prototype 資訊。為了實現這一點,我們首先給傳統對比學習方法加上一個間距限制,即儘可能保證 prototype 之間的間距不低於所設定的閾值 。考慮類別 對應的 trainable global prototype(TGP),我們定義其訓練時候的損失函式為:
其中, 是在第 輪參與聯邦學習的客戶機集合, 是客戶機 上生成的對應類別 的 prototype, 是間距計算函式。
但在聯邦學習的過程中,各個客戶機模型的特徵提取能力不斷變化,若設定一個固定的閾值,則會導致間距過大或過小。於是我們考慮將 設定為一個自適應的值,其計算細節如下,其描述的就是每一輪不同類別之間的最大間距,且具有最大值 。
從而我們得到最終的對比學習目標:
使用 ACL 之後,我們便可以消除間距收縮的問題:
 ▲ 圖5:我們的FedTGP在使用ACL之後,消除了間距收縮的問題(Cifar10)

部分實驗

由於篇幅原因,我們只展示部分實驗結果,更多實驗結果和分析詳見論文。
▲ 表1:在4個數據集和8種異構模型場景下的測試準確率
▲ 表2:在Cifar100資料集和不同模型異構級別情況下的測試準確率
參考文獻
[1] Tan Y, Long G, Liu L, et al. Fedproto: Federated prototype learning across heterogeneous clients[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2022.
更多閱讀
#投 稿 通 道#
 讓你的文字被更多人看到 
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋樑,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。 
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析科研心得競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝 稿件基本要求:
• 文章確係個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標註 
• 稿件建議以 markdown 格式撰寫,文中配圖以附件形式傳送,要求圖片清晰,無版權問題
• PaperWeekly 尊重原作者署名權,並將為每篇被採納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
📬 投稿通道:
• 投稿郵箱:[email protected] 
• 來稿請備註即時聯絡方式(微信),以便我們在稿件選用的第一時間聯絡作者
• 您也可以直接新增小編微信(pwbot02)快速投稿,備註:姓名-投稿
△長按新增PaperWeekly小編
🔍
現在,在「知乎」也能找到我們了
進入知乎首頁搜尋「PaperWeekly」
點選「關注」訂閱我們的專欄吧
·
·
·

相關文章