CVPR2025|DiG:具有門控線性注意力的高效擴散模型

點選下方卡片,關注“CVer”公眾號

AI/CV重磅乾貨,第一時間送達

新增微訊號:CVer2233,小助手會拉你進群!
掃描下方二維碼,加入CVer學術星球可以獲得最新頂會/頂刊上的論文ideaCV從入門到精通資料,及最前沿應用!發論文/搞科研/漲薪,強烈推薦!

作者:科技猛獸 | 轉載自:極市平臺

導讀

在相同的模型尺寸下,DiG-XL/2 比基於 Mamba 的擴散模型在 1024 的解析度下快 4.2 倍,在 2048 的解析度下比帶有 CUDA 最佳化的 FlashAttention2 的 DiT 快 1.8 倍。這些結果都證明了其優越效能。 

本文目錄

1 DiG:使用門控線性注意力機制的高效可擴充套件 Diffusion Transformer(來自華科,字節跳動)1 DiM 論文解讀1.1 DiG:一種輕量級 DiT 架構1.2 門控線性注意力 Transformer1.3 擴散模型1.4 Diffusion GLA 模型1.5 DiG Block 架構1.6 複雜度分析1.7 實驗結果

太長不看版

Diffusion Transformer 模型面臨的一個問題是計算複雜度與序列長度呈二次方關係,這不利於擴散模型的縮放。本文透過門控線性注意力機制 (Gated Linear Attention) 的長序列建模能力來應對這個問題,來提升擴散模型的適用性。
本文提出的模型稱為 Diffusion Gated Linear Attention Transformers (DiG),是一種基於門控線性注意力機制和 DiT[1]的簡單高效的擴散 Transformer 模型。除了比 DiT 更好的效能外,DiG-S/2 的訓練速度比 DiT-S/2 高 2.5 倍,並在 1792 的解析度節省 75.7% 的 GPU 視訊記憶體。此外,作者分析了 DiG 在各種計算複雜度下的可擴充套件性。結果是隨著模型的縮放,DiG 模型始終表現出更優的 FID。作者還將 DiG 與其他次 subquadratic-time 的擴散模型進行了比較。在相同的模型尺寸下,DiG-XL/2 比基於 Mamba 的擴散模型在 1024 的解析度下快 4.2 倍,在 2048 的解析度下比帶有 CUDA 最佳化的 FlashAttention2 的 DiT 快 1.8 倍。這些結果都證明了其優越效能。

本文做了哪些具體的工作

  1. 提出了 Diffusion GLA (DiG),透過分層掃描和區域性視覺感知進行全域性視覺上下文建模。DiG 使用線性注意力 Transformer 來實現 diffusion backbone。
  2. DiG 在訓練速度和 GPU 視訊記憶體成本方面都表現出更高的效率,同時保持與 DiT 相似的建模能力。具體而言,DiG 比 DiT 快 2.5 倍,並在 1792×1792 的解析度中節省 75.7% 的 GPU 視訊記憶體,如圖1所示。
  3. 作者在 ImageNet 資料集上進行了廣泛的實驗。結果表明,與 DiT 相比,DiG 表現出可擴充套件的能力並實現了卓越的效能。在大規模長序列生成的背景下,DiG 有望成為下一代 Backbone。

圖1:DiT、DiS 和 DiG 模型的效率比較。DiG 在處理高解析度影像時實現了更高的訓練速度,同時成本更低的 GPU 視訊記憶體

圖2:DiS、DiT、帶有Flash Attention-2 (Flash-DiT) 的 DiT 和不同模型大小的 DiG 模型之間的 FPS 對比

1 DiG:使用門控線性注意力機制的高效可擴充套件 Diffusion Transformer

論文名稱:DiG: Scalable and Efficient Diffusion Models with Gated Linear Attention (Arxiv 2024.05)
論文地址:
http://arxiv.org/pdf/2405.18428
程式碼連結:
http://github.com/hustvl/DiG

1.1 DiG:一種輕量級 DiT 架構

擴散模型以其生成高質量的影像生成能力而聞名。隨著取樣演算法的快速發展,主要技術根據其 Backbone 架構演變為2個主要類別:基於 U-Net 的方法[2]和基於 ViT 的方法[3]。基於 U-Net 的方法繼續利用卷積神經網路 (CNN) 架構,其分層特徵建模能力有利於視覺生成任務。另一方面,基於 ViT 的方法結合注意力機制。由於其出色的效能與可擴充套件性,基於 ViT 的方法已被用作最先進的擴散工作中的 Backbone,包括 PixArt、Sora、Stable Diffusion 3 等。然而,基於 ViT 的架構的 Self-attention 機制與輸入序列長度呈二次方關係,使得它們在處理長序列生成任務 (例如高解析度影像生成、影片生成等) 時資源消耗較大。最近的架構 Mamba[4]、RWKV[5]和 Gated Linear Attention Transformer (GLA)[6],試圖透過整合 RNN 類的架構,以及硬體感知演算法來提高長序列處理效率。其中,GLA 將依賴於資料的門控操作和硬體高效的實現結合到線性注意力 Transformer 中,顯示出具有競爭力的效能,但吞吐量更高。
受 GLA 在自然語言處理領域的成功的啟發,作者將這種成功從語言生成轉移到視覺內容生成領域,即使用高階線性注意力設計可擴充套件且高效的 Diffusion Backbone。然而,使用 GLA 進行視覺生成面臨兩個挑戰,即單向掃描建模和缺乏區域性資訊。為了應對這些挑戰,本文提出了 Diffusion GLA (DiG) 模型,該模型結合了一個輕量級的空間重定向和增強模組 (Spatial Reorient & Enhancement Module, SREM),用於分層掃描方向控制和區域性感知。掃描方向包含四個基本模式,並使序列中的每個 Patch 能夠感知沿縱橫方向的其他 Patch。此外,作者還在 SREM 中加入了深度卷積 (DWConv),使用很少的引數為模型注入區域性資訊。

1.2 門控線性注意力 Transformer

Gated Linear Attention Transformer (GLA) 結合依賴於資料的門控機制和線性注意力, 實現了卓越的迴圈建模效能。給定輸入  (  是序列長度,  是維度),GLA 計算 Query、Key 和 Value 向量:

式中  是線性投影權重。 和  是維度數。接下來, GLA 計算門控矩陣 ,如下所示:

其中  是 token 的索引,  是 sigmoid 函式,  是偏置項,  是溫度項。如圖3所示, 最終輸出  如下:

圖3:GLA Pipeline
其中, Swish 是 Swish 啟用函式,  是逐元素乘法運算。在接下來的部分中, 使用 來指代輸入序列的門控線性注意力計算。

1.3 擴散模型

DDPM[7]透過迭代去噪輸入將噪聲作為輸入和取樣影像。DDPM 的前向過程是隨機過程,其中初始影像  逐漸被噪聲破壞,最後轉化為更簡單、噪聲主導的狀態。前向噪聲過程可以表示如下:

其中  是從時間  到  的噪聲影像序列。然後, DDPM 使用可學習的  和  恢復原始影像的反向過程:

其中,  是去噪模型的引數, 使用 variational lower bound 在觀測資料  的分佈下訓練:

其中,  是總的損失函式。為了進一步簡化 DDPM 的訓練過程, 研究人員將  重引數化為噪聲預測網路 , 使  與真實高斯噪聲  之間的均方誤差損失  做最小化:

然而, 為了訓練能夠學習反向過程協方差 Σ𝜃 的擴散模型, 就需要最佳化完整的 𝐷𝐾𝐿 項。本文作者遵循 DiT 訓練網路, 其中使用損失 𝐿simple  來訓練噪聲預測網路 𝜖𝜃, 並使用全損失 𝐿 來訓練協方差預測網路 Σ𝜃 。

1.4 Diffusion GLA 模型

本文提出了 Diffusion GLA (DiG),一種用於生成任務的新架構。本文的目標是儘可能忠實於標準的 GLA 架構,以保持其縮放能力和高效率的特性。GLA 的概述如圖 3 所示。
標準 GLA 一般用於一維序列的因果語言建模。為了適配影像的 DDPM 訓練, 本文遵循 ViT 架構的實踐。DiG 以 VAE 編碼器的輸出的空間表徵  作為輸入。對於  的影像, VAE 編碼器的空間表徵  的形狀為  。DiG 隨後透過 Patchify 層將空間輸入轉換為 token 序列 , 其中  為序列的長度,  為空間表示通道數,  為影像補丁的大小, 因此  的減半將使得  變為 4 倍。接下來, 將  線性投影到維度為  的向量上, 並將基於頻率的位置嵌入  新增到所有投影 token 中, 如下所示:

其中  是  的第  個 Patch,  是可學習的投影矩陣。至於噪聲時間步 和類標籤  等條件資訊, 作者分別採用多層感知 (MLP) 和嵌入層作為 timestep embedder 和 label embedder。

其中  是 time Embedding,  是 label Embedding。然後, 作者將令牌序列  傳送到 DiG 編碼器的第  層, 得到輸出  。最後, 對輸出標記序列  進行歸一化, 並將其饋送到線性投影頭以獲得最終預測的噪聲  和預測的協方差 , 如下所示:

其中,  是第  個擴散  Block,  是層數, Norm 是歸一化層。 和預測的協方差  與輸入空間表示具有相同的形狀, 即  。

1.5 DiG Block 架構

原始的 GLA Block 以迴圈格式處理輸入序列,這隻能對 1-D 序列進行因果建模。本文提出的 DiG 的 Block 架構集成了一種空間重定向和增強模組 (Spatial Reorient & Enhancement Module, SREM),用於控制逐層掃描方向。DiG Block 架構如下圖4所示。
圖4:DiG 模型架構
作者透過調整迴歸自適應層範數 (adaLN) 引數來啟動門控線性注意 (GLA) 和前饋網路 (FFN)。

圖5:DiG 演算法流程
然後,作者把序列改為 2D 的形狀,並使用一個輕量級的 3×3 深度卷積來感知區域性空間資訊。但使用傳統的 DWConv2d 初始化會導致收斂速度慢,因為卷積權重分散在周圍。為了解決這個問題,作者提出了 Identity 初始化,將卷積核中心設定為1,將周圍其他設定為0。最後,每兩個塊執行轉置 2D token 矩陣,並翻轉展平的序列,來控制下一個 Block 的掃描方向。如圖4右側所示,每層只處理一個方向的掃描。

1.6 複雜度分析

DiG 架構共有4種尺寸,分別是 DiG-S, DiG-B, DiG-L, 和 DiG-XL,配置如下圖6所示。其引數量從 31.5M 到 644.6M,計算量從 1.09GFLOPs 到 22.53GFLOPs。值得注意的是,與相同大小的基線模型 (即 DiT) 相比,DiG 只消耗 77.0% 到 78.9% 的 GFLOPs。

圖6:DiG 架構配置
GPU 包含兩個重要的元件, 即高頻寬記憶體 (HBM) 和 SRAM。HBM 具有更大的記憶體大小, 但 SRAM 具有更大的頻寬。為了以並行形式充分利用 SRAM 和建模序列, GLA 將整個序列拆分為許多塊, 可以在 SRAM 上完成計算。定義塊的尺寸為 𝑀, 訓練複雜度是

𝑂(𝑇𝑀(𝑀2𝐷+𝑀𝐷2))=𝑂(𝑇𝑀𝐷+𝑇𝐷2) 。當 𝑇<𝐷 時, 略小於傳統注意力機制的計算複雜度 𝑂(𝑇2𝐷) 。此外, DiT Block 中的 Depth-wise 卷積和高效矩陣運算也保證了效率。

1.7 實驗結果

作者使用 ImageNet 進行 class-conditional 影像生成任務的訓練,解析度為 256×256。作者使用水平翻轉作為資料增強,使用Frechet Inception Distance (FID)、Inception Score、sFID 和 Precision/Recall 來衡量生成效能。
使用恆定學習率為 1e-4 的 AdamW 最佳化器。遵循 DiT 的做法在訓練期間對 DiG 權重進行指數移動平均 (EMA),衰減率為 0.9999。使用 EMA 模型生成影像。對於 ImageNet 的訓練,使用現成的預訓練的 VAE。
如下圖7所示,作者分析了所提出的空間重定向和增強模組 (SREM) 的有效性。作者將 DiT-S/2 作為基線方法。原始的 DiG 模型只有 causal modeling,計算量和引數量都很少。但是因為缺乏全域性上下文,因此 FID 很差。作者首先向 DiG 新增雙向掃描,並觀察到了顯著的改進,證明了全域性上下文的重要性。而且,使用 Identity 初始化的 DWConv2d 也可以大大提高效能。DWConv2d 的實驗證明了 Identity 初始化和區域性資訊的重要性。最後一行的實驗表明,完整的 SREM 可以實現最佳的效能,且同時關注區域性和全域性上下文。

圖7:SREM 模組的消融實驗結果
縮放模型尺寸
作者研究了 DiG 在 ImageNet 上的四種不同模型尺度之間的縮放能力。如圖 8(a) 所示,隨著模型從 S/2 擴充套件到 XL/2,效能有所提高。結果表明了 DiG 的縮放能力,以及作為基礎擴散模型的潛力。
Patch Size 的影響
作者在 ImageNet 上訓練了 Patch Size 從 2、4 和 8 不等的 DiG-S。如圖 8(b) 所示,透過減少 DiG 的 Patch Size,可以在整個訓練過程中觀察到明顯的 FID 最佳化。因此,最佳效能需要更小的 Patch Size 和更長的序列長度。與 DiT 基線相比,DiG 在處理長序列生成任務方面更有效。

圖8:DiG 模型大小和 Patch Size 的縮放分析
作者將所提出的 DiG 與基線方法 DiT 進行比較,二者具有相同的超引數,結果如下圖9所示。所提出的 DiG 在 400K 訓練迭代的4個模型尺度上優於 DiT。此外,與以前的最先進方法相比,classifier-free guidance 的 DiG-XL/2-1200K 也顯示出具有競爭力的結果。

圖9:ImageNet 256×256 class-conditional 影像生成任務實驗結果
圖10展示了從 DiG-XL/2 中取樣的結果,這些結果來自 ImageNet 訓練的模型,解析度為 256×256。結果表明,DiG 生成結果的正確的語義和精確的空間關係。

圖10:DiG-XL/2 模型生成結果
參考
  1. ^Scalable Diffusion Models with Transformers
  2. ^Denoising Diffusion Probabilistic Models
  3. ^An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale
  4. ^Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  5. ^RWKV: Reinventing RNNs for the Transformer Era
  6. ^abGated Linear Attention Transformers with Hardware-Efficient Training
  7. ^Denoising Diffusion Probabilistic Models

何愷明在MIT授課的課件PPT下載

在CVer公眾號後臺回覆:何愷明,即可下載本課程的所有566頁課件PPT!趕緊學起來!

CVPR 2025 論文和程式碼下載

在CVer公眾號後臺回覆:CVPR2025,即可下載CVPR 2025論文和程式碼開源的論文合集

ECCV 2024 論文和程式碼下載

在CVer公眾號後臺回覆:ECCV2024,即可下載ECCV 2024論文和程式碼開源的論文合集
CV垂直方向和論文投稿交流群成立
掃描下方二維碼,或者新增微訊號:CVer2233,即可新增CVer小助手微信,便可申請加入CVer-垂直方向和論文投稿微信交流群。另外其他垂直方向已涵蓋:目標檢測、影像分割、目標跟蹤、人臉檢測&識別、OCR、姿態估計、超解析度、SLAM、醫療影像、Re-ID、GAN、NAS、深度估計、自動駕駛、強化學習、車道線檢測、模型剪枝&壓縮、去噪、去霧、去雨、風格遷移、遙感影像、行為識別、影片理解、影像融合、影像檢索、論文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要備註:研究方向+地點+學校/公司+暱稱(如Mamba、多模態學習或者論文投稿+上海+上交+卡卡),根據格式備註,可更快被透過且邀請進群

▲掃碼或加微訊號: CVer2233,進交流群
CVer計算機視覺(知識星球)來了!想要了解最新最快最好的CV/DL/AI論文速遞、優質實戰專案、AI行業前沿、從入門到精通學習教程等資料,歡迎掃描下方二維碼,加入CVer計算機視覺(知識星球),已彙集上萬人!

掃碼加入星球學習
▲點選上方卡片,關注CVer公眾號
整理不易,請點贊和在看


相關文章