ICML2024|具有O(L)訓練儲存和O(1)推理功耗的時間可逆脈衝神經網路

©PaperWeekly 原創 · 作者 | 李國齊課題組
單位 | 中國科學院自動化研究所
研究方向 | 類腦計算
脈衝神經網路(Spike Neural Network,SNN)因其受大腦啟發的神經元動態和基於脈衝的計算模式,被認為是一種低功耗的人工神經網路(Artifical Neural Network,ANN)替代方案。然而受限於 SNN 中的神經元的時空動態特性,SNN 的訓練視訊記憶體開銷與運算時間均遠遠大於 ANN [1,2,3,4]
為解決此問題,本文提出一種時間可逆計算正規化,並基於此開發了 T-RevSNN 模型。與現有的 Spike-driven Transformer [5] 相比,T-RevSNN 的記憶體效率、訓練時間加速和推理能效分別具有 8.6 倍、2 倍和 1.6 倍的顯著提高。
論文標題:
High-Performance Temporal Reversible Spiking Neural Networks with O(L) Training Memory and O(1) Inference Cost
論文地址:
https://openreview.net/forum?id=s4h6nyjM9H
程式碼地址:
https://github.com/BICLab/T-RevSNN

背景

當前 SNN 模型的任務效能已在 ImageNet 上達到 80% 準確率 [6],已能夠滿足絕大多數實際任務場景,但是其訓練難度仍然遠高於同架構下的 ANN。如何降低 SNN 的訓練難度是目前 SNN 領域的重點難題。SNN 的訓練困難來源於其使用的 BPTT 訓練演算法。在訓練時需儲存每一層、每一時間步的神經元的啟用值,即在訓練時視訊記憶體複雜度為 O(LT),其中 L 是層數,T 是時間步。例如,訓練 10 時間步的脈衝 ResNet-19 比 ANN-ResNet-19 多需約 20 倍視訊記憶體 [1]。 
為解決這個問題,目前主流方法是解耦 SNN 訓練過程與時間步。然而,它們中沒有一個能夠同時實現低廉的訓練記憶體和低推理能耗,因為它們往往只在一個方向上進行最佳化。同時最近的研究顯示,SNN 的時間反向傳播對最終梯度影響小。既然如此,我們是否可以僅在關鍵位置保留時間前向,而關閉其他神經元的時間動態呢? 
基於此,我們考慮僅在關鍵位置保留時間前向,關閉其他神經元的時間動態。我們設計了時間可逆的 SNN (T-RevSNN)。首先,為減少訓練記憶體,僅在每個階段的輸出脈衝層啟用時間動態,並實現時間傳遞的可逆性,避免儲存所有神經元的膜電位和啟用。
其次,關閉的脈衝神經元不進行時間動態,簡化為不重用時間維度的引數,同時透過一次編碼輸入,將特徵和網路分為 T 組,避免增加引數和能耗。第三,為提升效能,採用多級資訊傳遞,重新設計 SNN 塊,並調整殘差連線以確保有效性。

本文貢獻

我們的貢獻包括:
1. 我們重新設計了 SNN 的前向傳播,簡單直觀地同時實現了低訓練記憶體、低功耗和高效能。
2. 我們在三個方面進行了系統設計,以實現提出的想法,包括關鍵脈衝神經元的多級時間可逆前向資訊傳遞、輸入編碼和網路架構的分組設計,以及SNN塊和殘差連線的改進。
3. 在 ImageNet-1k 上,我們的模型在基於 CNN 的 SNN 上實現了最先進的準確性,同時具有最小的記憶體和推理成本,並且訓練速度最快。與當前基於 Transformer 的 SNN 相比,即基於脈衝驅動的 Transformer,我們的模型在準確性上接近,而記憶體效率、訓練時間加速和推理能效可以顯著提高分別為 8.6×、2.0× 和 1.6×。

動機:儘可能少的梯度反傳

我們設定了以下實驗來分析哪些脈衝神經元的關鍵哪些不關鍵。同時由於遍歷工作量太大,一般認為個典型的神經網路被分為四個階段,每個階段的特徵層次各不相同。所以我們假設 SNN 中兩個階段交界處的時間資訊傳遞很重要。 為此我們設計瞭如下實驗,首先在 CIFAR-10 上訓練 Spiking Resnet,並將其設為基線,為了確定哪些神經元的時間梯度會對模型的訓練過程產生顯著影響,我們移除了可疑神經元的時間梯度,分為如下兩種情況: 
1. 案例1:我們保留每個階段最後一層的時間梯度,並去除其他神經元的時間梯度;
2. 案例2:我們採取相反的方法,去除每個階段的最終脈衝神經元的時間梯度,但保留其他神經元。
▲ 圖1. 基線與案例1的餘弦相似度隨訓練過程變化圖
▲ 圖2. 基線與案例2的餘弦相似度隨訓練過程變化圖
之後我們計算隨著 epoch 增加基線和案例 1,2 之間的餘弦相似度變化。相似度高表明該條件下的神經元的時間動態重要,相似度低則反之。 最終結果如圖1和圖2所示。隨著訓練週期的增加,圖 1(基線與案例 1 的比較)的相似度始終保持在高水平。相比之下,圖 2(基線與案例 2 的比較)的相似度始終較小。這說明每個階段最後一層的時間梯度比前面階段的脈衝神經元更重要,我們稱這些脈衝神經元為"關鍵神經元"。

方法

基於動機中的發現,我們設計了 Turn off / on 兩種脈衝神經元,分別對應於在動機中找到的不關鍵或關鍵的神經元。
▲ 圖3. 所提出的T-RevSNN的時間前向傳播的示意圖和網路結構細節
4.1 時間可逆脈衝神經元
T-RevSNN 中的脈衝神經元分為兩種關鍵和不關鍵神經元。 對於不關鍵神經元,即上圖中的綠色部分神經元,我們將其時間維度的連線進行關閉。我們稱這種關閉後的神經元為 Turn off 神經元。Turn off 神經元與一般的脈衝神經元唯一的區別是喪失了時間維度上膜電勢的資訊傳遞。其中 Turn off 脈衝神經元的前向傳播可描述如下:
可以看到 Turn off 脈衝神經元的權重更新依賴於空間和時間梯度。 對於 Turn on 脈衝神經元,其權重更新依賴於空間和時間梯度。受可逆性概念的啟發,我們觀察到其是自然可逆的。因此其前向傳播可描述如下:
隨後,可以在 和 之間建立可逆變換。這意味著在計算第一個時間步的梯度時,無需儲存所有時間步的膜電位和啟用值。我們只需要儲存 。這減少了 SNNs 多時間步訓練所需的記憶體。Turn on 神經元的時間複雜度與傳統的 SNN 訓練一致為 O(T)。
4.2 高效能的SNN訓練框架
為了提升 SNN 的效能, 研究者們提出了許多方法,然而,上述方法不足以實現高精度的 SNN,為此我們首先引入了多層次連線訓練框架 首先我們在相鄰時間步的 SNNs 之間建立了更強的多層次連線(如圖 3 所示)。我們將前一時間步的更深層次的高階特徵納入到當前時間步的資訊融合中。通常,我們可以按照以下方式構建前向資訊傳遞:
其次我們重寫設計的基本的 SNN 模組。它由兩個深度可分離卷積(DWConv/PWConv)和一個殘差連線組成。我們去掉了所有批次歸一化(BN)模組,轉而去使用將網路中所有層的權重都進行了歸一化的方法來穩定訓練。 之後我們使用了 ReZero 技術來增強網路在初始化後滿足動態等距的能力和促進高效的網路訓練。為了保證在推理中只發生加法運算,我們使用重引數化,將 ReZero 的縮放比例(即圖 4 中的 α)合併到上一層的權重中。
▲ 圖4. 遵循ConvNext正規化的基本的SNN模組

結果

5.1 不同訓練方法複雜度分析

▲ 圖5. T-RevSNN和其他SNN訓練最佳化方法的前傳和反傳示意圖
傳統的 SNN 訓練演算法(STBP)在計算從最後一層的最後一個時間步的輸出到第一層的第一個時間步的輸入的梯度時所需的記憶和計算構成了訓練 SNNs(脈衝神經網路)的記憶和時間複雜度。我們在表 1 和圖 5 中分析了所提出的 T-RevSNN 和其他 SNN 訓練最佳化方法 [2,3,4] 的訓練記憶體和時間複雜度。
▲ 表1. 不同演算法訓練和推理的計算複雜度

5.2 消融實驗

我們對 T-RevSNN 的不同時間步長和是否使用縮放殘差連線進行了消融實驗。
時間步長:在我們的設計中,我們將整個網路的引數分為T組子網路。在下表中,我們分析了不同的時間步長 T 對準確度、訓練速度和記憶體的影響。由於我們固定了總引數數量,增加T意味著每個時間步的子網路變得更小。相應地,訓練所需的記憶體會減少,但訓練時間會相應增加。此外,可以看到準確性與時間步之間的關係並不是線性的。
▲ 表2. 關於時間步長的消融實驗
縮放殘差連線:可以看到使用該技術有助於提高模型的收斂速度和最終準確度,如下表所示。
▲ 表3. 關於殘差連線的消融實驗

5.3 主要實驗結果

T-RevSNN 在 ImageNet 上的結果如下所示。本文取得了 SNN 域中最快的訓練速度和最低的記憶體消耗。
▲ 表4. 在大型ImageNet資料集上的實驗
如上表所示,本文所提出的 T-RevSNN 以 85.7 MB/圖片的記憶體消耗,和 9.1 分鐘/週期的訓練時間遠低於脈衝 Transformer 和脈衝卷積模型。體現了 T-RevSNN 在訓練速度、記憶體需求和推理功耗方面的顯著優勢,同時在效能上也具有競爭力。儘管準確率低於 Spike-driven Transformer,但我們認為這是由架構引起的差距,並且未來可以解決。 全文到此結束,更多細節建議檢視原文。
參考文獻
[1] Fang W, Chen Y, Ding J, et al. Spikingjelly: An open-source machine learning infrastructure platform for spike-based intelligence[J]. Science Advances, 2023, 9(40): eadi1480.
[2] Zhang H, Zhang Y. Memory-efficient reversible spiking neural networks[C]. Proceedings of the AAAI Conference on Artificial Intelligence. 2024, 38(15): 16759-16767.
[3] Meng Q, Xiao M, Yan S, et al. Towards memory-and time-efficient backpropagation for training spiking neural networks[C]. Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023: 6166-6176.
[4] Xiao M, Meng Q, Zhang Z, et al. Online training through time for spiking neural networks[J]. Advances in neural information processing systems, 2022, 35: 20717-20730.
[5] Yao M, Hu J, Zhou Z, et al. Spike-driven transformer[J]. Advances in neural information processing systems, 2024, 36:64043–64058.
[6] Yao M, Hu J K, Hu T, et al. Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips[C]. The Twelfth International Conference on Learning Representations.
更多閱讀
#投 稿 通 道#
 讓你的文字被更多人看到 
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋樑,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。 
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析科研心得競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝 稿件基本要求:
• 文章確係個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標註 
• 稿件建議以 markdown 格式撰寫,文中配圖以附件形式傳送,要求圖片清晰,無版權問題
• PaperWeekly 尊重原作者署名權,並將為每篇被採納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
📬 投稿通道:
• 投稿郵箱:[email protected] 
• 來稿請備註即時聯絡方式(微信),以便我們在稿件選用的第一時間聯絡作者
• 您也可以直接新增小編微信(pwbot02)快速投稿,備註:姓名-投稿
△長按新增PaperWeekly小編
🔍
現在,在「知乎」也能找到我們了
進入知乎首頁搜尋「PaperWeekly」
點選「關注」訂閱我們的專欄吧
·
·
·

相關文章