點選下方卡片,關注“CVer”公眾號
點選下方卡片,關注“CVer”公眾號
AI/CV重磅乾貨,第一時間送達
AI/CV重磅乾貨,第一時間送達
新增微訊號:CVer2233,小助手會拉你進群!
掃描下方二維碼,加入CVer學術星球!可以獲得最新頂會/頂刊上的論文idea和CV從入門到精通資料,及最前沿應用!發論文/搞科研/漲薪,強烈推薦!
新增微訊號:CVer2233,小助手會拉你進群!
掃描下方二維碼,加入CVer學術星球!可以獲得最新頂會/頂刊上的論文idea和CV從入門到精通資料,及最前沿應用!發論文/搞科研/漲薪,強烈推薦!

導讀
在相同的模型尺寸下,DiG-XL/2 比基於 Mamba 的擴散模型在 1024 的解析度下快 4.2 倍,在 2048 的解析度下比帶有 CUDA 最佳化的 FlashAttention2 的 DiT 快 1.8 倍。這些結果都證明了其優越效能。
本文目錄
1 DiG:使用門控線性注意力機制的高效可擴充套件 Diffusion Transformer(來自華科,字節跳動)1 DiM 論文解讀1.1 DiG:一種輕量級 DiT 架構1.2 門控線性注意力 Transformer1.3 擴散模型1.4 Diffusion GLA 模型1.5 DiG Block 架構1.6 複雜度分析1.7 實驗結果
太長不看版
Diffusion Transformer 模型面臨的一個問題是計算複雜度與序列長度呈二次方關係,這不利於擴散模型的縮放。本文透過門控線性注意力機制 (Gated Linear Attention) 的長序列建模能力來應對這個問題,來提升擴散模型的適用性。
本文提出的模型稱為 Diffusion Gated Linear Attention Transformers (DiG),是一種基於門控線性注意力機制和 DiT[1]的簡單高效的擴散 Transformer 模型。除了比 DiT 更好的效能外,DiG-S/2 的訓練速度比 DiT-S/2 高 2.5 倍,並在 1792 的解析度節省 75.7% 的 GPU 視訊記憶體。此外,作者分析了 DiG 在各種計算複雜度下的可擴充套件性。結果是隨著模型的縮放,DiG 模型始終表現出更優的 FID。作者還將 DiG 與其他次 subquadratic-time 的擴散模型進行了比較。在相同的模型尺寸下,DiG-XL/2 比基於 Mamba 的擴散模型在 1024 的解析度下快 4.2 倍,在 2048 的解析度下比帶有 CUDA 最佳化的 FlashAttention2 的 DiT 快 1.8 倍。這些結果都證明了其優越效能。
本文做了哪些具體的工作
-
提出了 Diffusion GLA (DiG),透過分層掃描和區域性視覺感知進行全域性視覺上下文建模。DiG 使用線性注意力 Transformer 來實現 diffusion backbone。 -
DiG 在訓練速度和 GPU 視訊記憶體成本方面都表現出更高的效率,同時保持與 DiT 相似的建模能力。具體而言,DiG 比 DiT 快 2.5 倍,並在 1792×1792 的解析度中節省 75.7% 的 GPU 視訊記憶體,如圖1所示。 -
作者在 ImageNet 資料集上進行了廣泛的實驗。結果表明,與 DiT 相比,DiG 表現出可擴充套件的能力並實現了卓越的效能。在大規模長序列生成的背景下,DiG 有望成為下一代 Backbone。


1 DiG:使用門控線性注意力機制的高效可擴充套件 Diffusion Transformer
論文名稱:DiG: Scalable and Efficient Diffusion Models with Gated Linear Attention (Arxiv 2024.05)
論文地址:
http://arxiv.org/pdf/2405.18428
程式碼連結:
http://github.com/hustvl/DiG
1.1 DiG:一種輕量級 DiT 架構
擴散模型以其生成高質量的影像生成能力而聞名。隨著取樣演算法的快速發展,主要技術根據其 Backbone 架構演變為2個主要類別:基於 U-Net 的方法[2]和基於 ViT 的方法[3]。基於 U-Net 的方法繼續利用卷積神經網路 (CNN) 架構,其分層特徵建模能力有利於視覺生成任務。另一方面,基於 ViT 的方法結合注意力機制。由於其出色的效能與可擴充套件性,基於 ViT 的方法已被用作最先進的擴散工作中的 Backbone,包括 PixArt、Sora、Stable Diffusion 3 等。然而,基於 ViT 的架構的 Self-attention 機制與輸入序列長度呈二次方關係,使得它們在處理長序列生成任務 (例如高解析度影像生成、影片生成等) 時資源消耗較大。最近的架構 Mamba[4]、RWKV[5]和 Gated Linear Attention Transformer (GLA)[6],試圖透過整合 RNN 類的架構,以及硬體感知演算法來提高長序列處理效率。其中,GLA 將依賴於資料的門控操作和硬體高效的實現結合到線性注意力 Transformer 中,顯示出具有競爭力的效能,但吞吐量更高。
受 GLA 在自然語言處理領域的成功的啟發,作者將這種成功從語言生成轉移到視覺內容生成領域,即使用高階線性注意力設計可擴充套件且高效的 Diffusion Backbone。然而,使用 GLA 進行視覺生成面臨兩個挑戰,即單向掃描建模和缺乏區域性資訊。為了應對這些挑戰,本文提出了 Diffusion GLA (DiG) 模型,該模型結合了一個輕量級的空間重定向和增強模組 (Spatial Reorient & Enhancement Module, SREM),用於分層掃描方向控制和區域性感知。掃描方向包含四個基本模式,並使序列中的每個 Patch 能夠感知沿縱橫方向的其他 Patch。此外,作者還在 SREM 中加入了深度卷積 (DWConv),使用很少的引數為模型注入區域性資訊。
1.2 門控線性注意力 Transformer
Gated Linear Attention Transformer (GLA) 結合依賴於資料的門控機制和線性注意力, 實現了卓越的迴圈建模效能。給定輸入 ( 是序列長度, 是維度),GLA 計算 Query、Key 和 Value 向量:

式中 是線性投影權重。 和 是維度數。接下來, GLA 計算門控矩陣 ,如下所示:

其中 是 token 的索引, 是 sigmoid 函式, 是偏置項, 是溫度項。如圖3所示, 最終輸出 如下:


其中, Swish 是 Swish 啟用函式, 是逐元素乘法運算。在接下來的部分中, 使用 來指代輸入序列的門控線性注意力計算。
1.3 擴散模型
DDPM[7]透過迭代去噪輸入將噪聲作為輸入和取樣影像。DDPM 的前向過程是隨機過程,其中初始影像 逐漸被噪聲破壞,最後轉化為更簡單、噪聲主導的狀態。前向噪聲過程可以表示如下:

其中 是從時間 到 的噪聲影像序列。然後, DDPM 使用可學習的 和 恢復原始影像的反向過程:

其中, 是去噪模型的引數, 使用 variational lower bound 在觀測資料 的分佈下訓練:

其中, 是總的損失函式。為了進一步簡化 DDPM 的訓練過程, 研究人員將 重引數化為噪聲預測網路 , 使 與真實高斯噪聲 之間的均方誤差損失 做最小化:

然而, 為了訓練能夠學習反向過程協方差 Σ 𝜃 的擴散模型, 就需要最佳化完整的 𝐷 𝐾 𝐿 項。本文作者遵循 DiT 訓練網路, 其中使用損失 𝐿 simple 來訓練噪聲預測網路 𝜖 𝜃 , 並使用全損失 𝐿 來訓練協方差預測網路 Σ 𝜃 。
1.4 Diffusion GLA 模型
本文提出了 Diffusion GLA (DiG),一種用於生成任務的新架構。本文的目標是儘可能忠實於標準的 GLA 架構,以保持其縮放能力和高效率的特性。GLA 的概述如圖 3 所示。
標準 GLA 一般用於一維序列的因果語言建模。為了適配影像的 DDPM 訓練, 本文遵循 ViT 架構的實踐。DiG 以 VAE 編碼器的輸出的空間表徵 作為輸入。對於 的影像, VAE 編碼器的空間表徵 的形狀為 。DiG 隨後透過 Patchify 層將空間輸入轉換為 token 序列 , 其中 為序列的長度, 為空間表示通道數, 為影像補丁的大小, 因此 的減半將使得 變為 4 倍。接下來, 將 線性投影到維度為 的向量上, 並將基於頻率的位置嵌入 新增到所有投影 token 中, 如下所示:

其中 是 的第 個 Patch, 是可學習的投影矩陣。至於噪聲時間步 和類標籤 等條件資訊, 作者分別採用多層感知 (MLP) 和嵌入層作為 timestep embedder 和 label embedder。

其中 是 time Embedding, 是 label Embedding。然後, 作者將令牌序列 傳送到 DiG 編碼器的第 層, 得到輸出 。最後, 對輸出標記序列 進行歸一化, 並將其饋送到線性投影頭以獲得最終預測的噪聲 和預測的協方差 , 如下所示:

其中, 是第 個擴散 Block, 是層數, Norm 是歸一化層。 和預測的協方差 與輸入空間表示具有相同的形狀, 即 。
1.5 DiG Block 架構
原始的 GLA Block 以迴圈格式處理輸入序列,這隻能對 1-D 序列進行因果建模。本文提出的 DiG 的 Block 架構集成了一種空間重定向和增強模組 (Spatial Reorient & Enhancement Module, SREM),用於控制逐層掃描方向。DiG Block 架構如下圖4所示。

作者透過調整迴歸自適應層範數 (adaLN) 引數來啟動門控線性注意 (GLA) 和前饋網路 (FFN)。

然後,作者把序列改為 2D 的形狀,並使用一個輕量級的 3×3 深度卷積來感知區域性空間資訊。但使用傳統的 DWConv2d 初始化會導致收斂速度慢,因為卷積權重分散在周圍。為了解決這個問題,作者提出了 Identity 初始化,將卷積核中心設定為1,將周圍其他設定為0。最後,每兩個塊執行轉置 2D token 矩陣,並翻轉展平的序列,來控制下一個 Block 的掃描方向。如圖4右側所示,每層只處理一個方向的掃描。
1.6 複雜度分析
DiG 架構共有4種尺寸,分別是 DiG-S, DiG-B, DiG-L, 和 DiG-XL,配置如下圖6所示。其引數量從 31.5M 到 644.6M,計算量從 1.09GFLOPs 到 22.53GFLOPs。值得注意的是,與相同大小的基線模型 (即 DiT) 相比,DiG 只消耗 77.0% 到 78.9% 的 GFLOPs。

GPU 包含兩個重要的元件, 即高頻寬記憶體 (HBM) 和 SRAM。HBM 具有更大的記憶體大小, 但 SRAM 具有更大的頻寬。為了以並行形式充分利用 SRAM 和建模序列, GLA 將整個序列拆分為許多塊, 可以在 SRAM 上完成計算。定義塊的尺寸為 𝑀 , 訓練複雜度是
𝑂 ( 𝑇 𝑀 ( 𝑀 2 𝐷 + 𝑀 𝐷 2 ) ) = 𝑂 ( 𝑇 𝑀 𝐷 + 𝑇 𝐷 2 ) 。當 𝑇 < 𝐷 時, 略小於傳統注意力機制的計算複雜度 𝑂 ( 𝑇 2 𝐷 ) 。此外, DiT Block 中的 Depth-wise 卷積和高效矩陣運算也保證了效率。

1.7 實驗結果
作者使用 ImageNet 進行 class-conditional 影像生成任務的訓練,解析度為 256×256。作者使用水平翻轉作為資料增強,使用Frechet Inception Distance (FID)、Inception Score、sFID 和 Precision/Recall 來衡量生成效能。
使用恆定學習率為 1e-4 的 AdamW 最佳化器。遵循 DiT 的做法在訓練期間對 DiG 權重進行指數移動平均 (EMA),衰減率為 0.9999。使用 EMA 模型生成影像。對於 ImageNet 的訓練,使用現成的預訓練的 VAE。
如下圖7所示,作者分析了所提出的空間重定向和增強模組 (SREM) 的有效性。作者將 DiT-S/2 作為基線方法。原始的 DiG 模型只有 causal modeling,計算量和引數量都很少。但是因為缺乏全域性上下文,因此 FID 很差。作者首先向 DiG 新增雙向掃描,並觀察到了顯著的改進,證明了全域性上下文的重要性。而且,使用 Identity 初始化的 DWConv2d 也可以大大提高效能。DWConv2d 的實驗證明了 Identity 初始化和區域性資訊的重要性。最後一行的實驗表明,完整的 SREM 可以實現最佳的效能,且同時關注區域性和全域性上下文。

縮放模型尺寸
作者研究了 DiG 在 ImageNet 上的四種不同模型尺度之間的縮放能力。如圖 8(a) 所示,隨著模型從 S/2 擴充套件到 XL/2,效能有所提高。結果表明了 DiG 的縮放能力,以及作為基礎擴散模型的潛力。
Patch Size 的影響
作者在 ImageNet 上訓練了 Patch Size 從 2、4 和 8 不等的 DiG-S。如圖 8(b) 所示,透過減少 DiG 的 Patch Size,可以在整個訓練過程中觀察到明顯的 FID 最佳化。因此,最佳效能需要更小的 Patch Size 和更長的序列長度。與 DiT 基線相比,DiG 在處理長序列生成任務方面更有效。

作者將所提出的 DiG 與基線方法 DiT 進行比較,二者具有相同的超引數,結果如下圖9所示。所提出的 DiG 在 400K 訓練迭代的4個模型尺度上優於 DiT。此外,與以前的最先進方法相比,classifier-free guidance 的 DiG-XL/2-1200K 也顯示出具有競爭力的結果。

圖10展示了從 DiG-XL/2 中取樣的結果,這些結果來自 ImageNet 訓練的模型,解析度為 256×256。結果表明,DiG 生成結果的正確的語義和精確的空間關係。

參考
-
^Scalable Diffusion Models with Transformers -
^Denoising Diffusion Probabilistic Models -
^An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale -
^Mamba: Linear-Time Sequence Modeling with Selective State Spaces -
^RWKV: Reinventing RNNs for the Transformer Era -
^abGated Linear Attention Transformers with Hardware-Efficient Training -
^Denoising Diffusion Probabilistic Models
何愷明在MIT授課的課件PPT下載
何愷明在MIT授課的課件PPT下載
CVPR 2025 論文和程式碼下載
CVPR 2025 論文和程式碼下載
ECCV 2024 論文和程式碼下載
ECCV 2024 論文和程式碼下載
CV垂直方向和論文投稿交流群成立
一定要備註:研究方向+地點+學校/公司+暱稱(如Mamba、多模態學習或者論文投稿+上海+上交+卡卡),根據格式備註,可更快被透過且邀請進群
▲掃碼或加微訊號: CVer2233,進交流群
CVer計算機視覺(知識星球)來了!想要了解最新最快最好的CV/DL/AI論文速遞、優質實戰專案、AI行業前沿、從入門到精通學習教程等資料,歡迎掃描下方二維碼,加入CVer計算機視覺(知識星球),已彙集上萬人!
▲掃碼加入星球學習
▲點選上方卡片,關注CVer公眾號
整理不易,請點贊和在看

▲掃碼或加微訊號: CVer2233,進交流群
CVer計算機視覺(知識星球)來了!想要了解最新最快最好的CV/DL/AI論文速遞、優質實戰專案、AI行業前沿、從入門到精通學習教程等資料,歡迎掃描下方二維碼,加入CVer計算機視覺(知識星球),已彙集上萬人!

▲掃碼加入星球學習
▲點選上方卡片,關注CVer公眾號
整理不易,請點贊和在看
