ICLR盲審階段被審稿人讚不絕口的論文:會是Transformer架構的一大創新嗎?

首次!無殘差連線或歸一化層,也能成功訓練深度transformer。
儘管取得了很多顯著的成就,但訓練深度神經網路(DNN)的實踐進展在很大程度上獨立於理論依據。大多數成功的現代 DNN 依賴殘差連線和歸一化層的特定排列,但如何在新架構中使用這些元件的一般原則仍然未知,並且它們在現有架構中的作用也依然未能完全搞清楚。
殘差架構是最流行和成功的,最初是在卷積神經網路(CNN)的背景下開發的,後來自注意力網路中產生了無處不在的 transformer 架構。殘差架構之所以取得成功,一種原因是與普通 DNN 相比具有更好的訊號傳播能力,其中訊號傳播指的是幾何資訊透過 DNN 層的傳輸,並由核心函式表示。
最近,使用訊號傳播原則來訓練更深度的 DNN 並且殘差架構中沒有殘差連線和 / 或歸一化層的參與,成為了社群感興趣的領域。原因有兩個:首先驗證了殘差架構有效性的訊號傳播假設,從而闡明對 DNN 可解釋性的理解;其次這可能會實現超越殘差正規化的 DNN 可訓練性的一般原則和方法。
對於 CNN,Xiao et al. (2018)的工作表明,透過更好初始化提升的訊號傳播能夠高效地訓練普通深度網路,儘管與殘差網路比速度顯著降低。Martens et al. (2021) 的工作提出了 Deep Kernel Shaping (DKS),使用啟用函式轉換來控制訊號傳播,使用 K-FAC 等強二階最佳化器在 ImageNet 上實現了普通網路和殘差網路的訓練速度相等。Zhang et al. (2022) 的工作將 DKS 擴充套件到了更大類的啟用函式,在泛化方面也實現了接近相等。
訊號傳播中需要分析的關鍵量是 DNN 的初始化時間核心,或者更準確地說,是無限寬度限制下的近似核心。對於多層感知機(MLP)以及使用 Delta 初始化的 CNN,該核心可以編寫為僅包含 2D 函式的簡單層遞迴,以便於進行直接分析。跨層 transformer 的核心演化更加複雜,因此 DKS 等現有方法不適用 transformer 或實際上任何包含自注意力層的架構。
在 MLP 中,訊號傳播是透過檢視(一維)核心的行為來判斷的,而 transformer 中的訊號傳播可以透過檢視(高維)核心矩陣在網路層中的演化來判斷。
該研究必須避免一種情況:對角線元素隨深度增加快速增長或收縮,這與不受控制的啟用範數有關,可能導致飽和損失或數值問題。避免秩崩潰(rank collapse)對於深度 transformer 的可訓練性是必要的,而是否可以訓練深度無殘差 transformer 仍是一個懸而未決的問題。
ICLR 2023 盲審階段的這篇論文解決了這個問題,首次證明了無需殘差連線或歸一化層時也可能成功訓練深度 transformer。為此,他們研究了深度無殘差 transformer 中的訊號傳播和秩崩潰問題,並推匯出三種方法來阻止它們。具體而言,方法中使用了以下組合:引數初始化、偏置矩陣和位置相關的重縮放,並強調了 transformer 中訊號傳播特有的幾種複雜性,包括與位置編碼和因果掩蔽的互動。研究者實證證明了他們的方法可以生成可訓練的深度無殘差 transformer。
在實驗部分,在 WikiText-103 和 C4 資料集上,研究者展示了使用他們主要的方法——指數訊號保持注意力(Exponential Signal Preserving Attention, E-SPA),可以透過延長大約五倍的訓練時間使得標準 transformer 與文中無殘差 transformer 的訓練損失相當。此外透過將這一方法與殘差連線結合,研究者還表明無歸一化層的 transformer 能夠實現與標準 transformer 相當的訓練速度。
論文地址:https://openreview.net/pdf?id=NPrsUQgMjKK
對於這篇論文,Google AI 首席工程師 Rohan Anil 認為是 Transformer 架構向前邁出的一大步,還是一個基礎性的改進。
構造無捷徑可訓練的深層 Transformer
迄今為止,糾正 Transformer 秩崩潰(rank collapse)的唯一策略依賴於殘差連線,該方式跳過了自注意力層固有的可訓練性問題。與此相反,該研究直接解決這個問題。首先透過注意力層更好地理解訊號傳播,然後根據見解(insights)進行修改,以在深度 transformer 中實現對忠實訊號的傳輸,無論是否使用殘差連線,都可以對訊號進行訓練。
具體而言,首先,該研究對僅存在注意力的深度 vanilla transformer 進行了一下簡單設定,之後他們假設該 transformer 具有單一頭(h = 1)設定或具有多頭設定,其中注意力矩陣 A 在不同頭之間不會變化。如果塊 l≤L 初始化時有注意力矩陣 A_l,則最終塊的表示形式為 X_L:
對於上式而言,如果

採用正交初始化,那麼

就可以在初始化時正交。

在上述假設下,如果採用

表示跨位置輸入核矩陣,經過一些簡化處理後,可以得到如下公式:

從這個簡化公式(深度僅注意力 transformer 中的核矩陣)中,可以確定對 (A_l)_l 的三個要求: 
  1. 必須在每個塊中表現良好,避免退化情況,如秩崩潰和爆炸 / 消失的對角線值;

  2. A_l 必須是元素非負 ∀l;
  3. A_l 應該是下三角∀l,以便與因果掩碼注意力相容。
在接下來的 3.1 和 3.2 節中,該研究專注於尋找滿足上述需求的注意力矩陣,他們提出了 3 種方法 E-SPA、U-SPA 和 Value-Skipinit,每種方法都用來控制 transformer 的注意力矩陣,即使在很深的深度也能實現忠實的訊號傳播。此外,3.3 節演示瞭如何修改 softmax 注意力以實現這些注意力矩陣。
下圖中,該研究對提出的兩個 SPA 方案進行了驗證,U-SPA 和 E-SPA,結果顯示即使在網路較深時也能成功地避免僅注意力 vanilla transformers 中的秩崩潰現象。
實驗
WikiText-103 基線:首先,該研究驗證了沒有殘差連線的標準深度 transformer 是不可訓練的,即使它們有歸一化層 (LN) 和 transformed 啟用,但本文的方法可以解決這個問題。如圖 2 所示,可以清楚地看到,從標準 transformer 中移除殘差連線使其不可訓練,訓練損失穩定在 7.5 左右。正如圖 1 所示,標準 transformer 遭受了秩崩潰。
另一方面,該研究提出的 E-SPA 方法優於 U-SPA 和 Value-Skipinit。然而,與本文無殘差方法相比,帶有殘差和 LN 的預設 transformer 仍然保持訓練速度優勢。
在表 1 中,該研究使用提出的方法評估了 MLP 塊中不同啟用函式的影響,以及 LN 在無殘差 transformer 的使用。可以看到在深度為 36 處,本文方法針對一系列啟用實現了良好的訓練效能:DKS-transformed GeLU、TAT-transformed Leaky ReLU 以及 untransformed GeLU ,但不是 untransformed Sigmoid。透過實驗還看到,層歸一化對於訓練速度而言相對不重要,甚至在使用 SPA 時對 transformed activation 的啟用有害,因為 SPA 已經具有控制啟用規範的內建機制。
在圖 3 中,我們看到一種不需要更多迭代就能匹配預設 transformer 訓練損失的方法是使用歸一化殘差連線。
表 2 顯示帶有歸一化殘差和 LN 的 E-SPA 優於預設的 PreLN transformer。
下圖 4(a)表明 E-SPA 再次優於其他方法;4(b)表明訓練損失差距可以透過簡單地增加訓練時間來消除。
CVPR/ECCV 2022論文和程式碼下載
後臺回覆:CVPR2022,即可下載CVPR 2022論文和程式碼開源的論文合集
後臺回覆:ECCV2022,即可下載ECCV 2022論文和程式碼開源的論文合集
後臺回覆:Transformer綜述,即可下載最新的3篇Transformer綜述PDF
目標檢測和Transformer交流群成立
掃描下方二維碼,或者新增微信:CVer222,即可新增CVer小助手微信,便可申請加入CVer-目標檢測或者Transformer 微信交流群。另外其他垂直方向已涵蓋:目標檢測、影像分割、目標跟蹤、人臉檢測&識別、OCR、姿態估計、超解析度、SLAM、醫療影像、Re-ID、GAN、NAS、深度估計、自動駕駛、強化學習、車道線檢測、模型剪枝&壓縮、去噪、去霧、去雨、風格遷移、遙感影像、行為識別、影片理解、影像融合、影像檢索、論文投稿&交流、PyTorch、TensorFlow和Transformer等。
一定要備註:研究方向+地點+學校/公司+暱稱(如目標檢測或者Transformer+上海+上交+卡卡),根據格式備註,可更快被透過且邀請進群
▲掃碼或加微訊號: CVer222,進交流群
CVer學術交流群(知識星球)來了!想要了解最新最快最好的CV/DL/ML論文速遞、優質開源專案、學習教程和實戰訓練等資料,歡迎掃描下方二維碼,加入CVer學術交流群,已彙集數千人!

掃碼進群
▲點選上方卡片,關注CVer公眾號
整理不易,請點贊和在看


相關文章