通向機率分佈之路:盤點Softmax及其替代品

©PaperWeekly 原創 · 作者 | 蘇劍林
單位 | 科學空間
研究方向 |

NLP、神經網路

論是在基礎的分類任務中,還是如今無處不在的注意力機制中,機率分佈的構建都是一個關鍵步驟。具體來說,就是將一個 維的任意向量,轉換為一個 元的離散型機率分佈。眾所周知,這個問題的標準答案是 Softmax,它是指數歸一化的形式,相對來說比較簡單直觀,同時也伴有很多優良性質,從而成為大部分場景下的“標配”。
儘管如此,Softmax 在某些場景下也有一些不如人意之處,比如不夠稀疏、無法絕對等於零等,因此很多替代品也應運而生。在這篇文章中,我們將簡單總結一下 Softmax 的相關性質,並盤點和對比一下它的部分替代方案。
Softmax回顧
首先引入一些通用記號: 是需要轉為機率分佈的 n 維向量,它的分量可正可負,也沒有限定的上下界。 定義為全體 元離散機率分佈的集合,即
之所以標註 而不是 ,是因為約義了 維空間中的一個 維子平面,再加上 的約束, 的集合就只是該平面的一個子集,即實際維度只有 。
基於這些記號,本文的主題就可以簡單表示為探討 的對映,其中 我們習慣稱之為 Logits 或者 Scores。
基本定義
Softmax 的定義很簡單:
Softmax 的來源和詮釋都太多了,比如能量模型、統計力學或者單純作為 的光滑近似等,所以我們很難考證它的最早出處,也不去做這個嘗試了。很多時候我們也會加上一個溫度引數,即考慮 ,但 tau 本身也可以整合到 的定義之中,因此這裡不特別分離出 引數。
Softmax 的分母我們通常記為 ,它的對數就是大多數深度學習框架都自帶的 運算,他它是 的一個光滑近似:
當 取1時,就可以寫出 , 方差越大近似程度越高,更進一步的討論可以參考《尋求一個光滑的最大值函式》[1]
兩點性質
除了將任意向量轉換為機率分佈外,Softmax 還滿足兩點性質
單調性意味著 Softmax 是保序的, 的最大值/最小值跟 的最大值/最小值相對應;不變性說的是 的每個分量都加上同一個常數,Softmax 的結果不變,這跟 的性質是一樣的,即同樣有 。
因此,根據這兩點性質我們可以得出,Softmax實際是 一個光滑近似(更準確來說是 的光滑近似),更具體地我們有
這大概就是 Softmax 這個名字的來源。注意不要混淆了,Softmax 是 而不是 的光滑近似, 的光滑近似是 才對。
梯度計算
對於深度學習來說,瞭解一個函式的性質重要方式之一是瞭解它的梯度,對於 Softmax,我們在《從梯度最大化看Attention的Scale操作》曾經算過:
這樣排列成的矩陣也稱為 Softmax 的雅可比矩陣(Jacobian Matrix)[2],它的 L1 範數有一個簡單的形式
當 是 one hot 分佈時,上式等於 0,這意味著 Softmax 的結果越接近 one hot,它的梯度消失現象越嚴重,所以至少初始化階段,我們不能將 Softmax 初始化得接近 one hot。同時上式最右端也聯絡到了 Rényi 熵的概念,它跟常見的香儂熵類似。
參考實現
Softmax 的直接實現很簡單,直接取 然後歸一化就行,Numpy 的參考程式碼為:
1defsoftmax(x):
2

    y = np.exp(x)

3return

 y / y.sum()

然而,如果 中存在較大的分量,那麼算 時很容易溢位,因此我們通常都要利用 Softmax 的不變性,先將每個分量減去所有分量的最大值,然後再算 Softmax,這樣每個取 的分量都不大於 0,確保不會溢位:
1defsoftmax(x):
2

    y = np.exp(x - x.max())

3return

 y / y.sum()

損失函式
構建機率分佈的主要用途之一是用於構建單標籤多分類任務的輸出,即假設有一個 分類任務, 是模型的輸出,那麼我們希望透過 來預測每個類的機率。為了訓練這個模型,我們需要一個損失函式,假設目標類別是 ,常見的選擇是交叉熵損失:
我們可以求得它的梯度:
注意 是給定的,所以 實際表達的是目標分佈 ,而全體 就是 本身,所以上式可以更直觀地寫成:
也就是說,它的梯度正好是目標分佈與預測分佈之差,只要兩者不相等,那麼梯度就一直存在,最佳化就可以持續下去,這是交叉熵的優點。當然,某些情況下這也是缺點,因為 Softmax 只有在 才會得到 one hot,換言之正常情況下都不會出現 one hot,即最佳化一直不會完全停止,那麼就有可能導致過度最佳化,這也是後面的一些替代品的動機。
除了交叉熵之外,還有其他一些損失可用,比如 ,這可以理解為準確率的光滑近似的相反數,但它可能會有梯度消失問題,所以它的最佳化效率往往不如交叉熵,一般只適用於微調而不是從零訓練,更多討論可以參考《如何訓練你的準確率?》
Softmax變體
介紹完 Softmax,我們緊接著總結一下本部落格以往討論過 Softmax 的相關變體工作,比如 Margin Softmax、Taylor Softmax、Sparse Softmax 等,它們都是在 Softmax 基礎上的衍生品,側重於不同方面的改進,比如損失函式、、稀疏性、長尾性等。
Margin Softmax
首先我們介紹起源於人臉識別的一系列 Softmax 變體,它們可以統稱為 Margin Softmax,後來也被應用到 NLP 的 Sentence Embedding 訓練之中,本站曾在《基於GRU和am-softmax的句子相似度模型》[3] 討論過其中的一個變體 AM-Softmax,後來則在《從三角不等式到Margin Softmax》有過更一般的討論。
儘管 Margin Softmax 被冠以 Softmax 之名,但它實際上更多是一種損失函式改進。以 AM-Softmax 為例,它有兩個特點:第一,以 形式構造 Logits,即 的形式,此時的溫度引數 是必須的,因為單純的 值域為 ,不能拉開類機率之間的差異;第二,它並不是簡單地以 為損失,而是做了加強:
直觀來看,就是交叉熵希望 是 所有分量中最大的一個,而 AM-Softmax 則不僅希望 最大,還希望它至少比第二大的分量多出 ,這裡的 就稱為 Margin。
為什麼要增加對目標類的要求呢?這是應用場景導致的。剛才說了,Margin Softmax 起源於人臉識別,放到 NLP 中則可以用於語義檢索,也就是說它的應用場景是檢索,但訓練方式是分類。如果單純用分類任務的交叉熵來訓練模型,模型編碼出來的特徵不一定能很好地滿足檢索要求,所以要加上 Margin 使得特徵更加緊湊一些。更具體的討論請參考《從三角不等式到Margin Softmax》一文,或者查閱相關論文。
Taylor Softmax
接下來要介紹的,是在《exp(x)在x=0處的偶次泰勒展開式總是正的》討論過的 Taylor Softmax,它利用了 的泰勒展開式的一個有趣性質:
對於任意實數 及偶數 ,總即 在 處的偶次泰勒展開式總是正的。
利用這個恆正性,我們可以構建一個 Softmax 變體(是任意偶數):
由於是基於 的泰勒展開式構建的,所以在一定範圍內 Taylor Softmax 與 Softmax 有一定的近似關於,某些場景下可以用 Taylor Softmax 替換  Softmax。那麼 Taylor Softmax 有什麼特點呢?答案是更加長尾,因為 Taylor Softmax 是多項式函式歸一化,相比指數函式衰減得更慢,所以對於尾部的類別,Taylor Softmax 往往能夠給其分配更高的機率,可能有助於緩解 Softmax 的過度自信現象。
Taylor Softmax 的最新應用,是用來替換 Attention 中的 Softmax,使得原本的平方複雜度降低為線性複雜度,相關理論推導可以參考《Transformer升級之路:作為無限維的線性Attention》[4]
該思路的最新實踐是一個名為 Based 的模型,它利用 來線性化 Attention,聲稱比 Attention 高效且比 Mamba 效果更好,詳細介紹可以參考部落格《Zoology(Blogpost 2): Simple, Input-Dependent, and Sub-Quadratic Sequence Mixers》[5] 和《BASED: Simple linear attention language models balance the recall-throughput tradeoff》[6]
Sparse Softmax
Sparse Softmax 是筆者參加 2020 年法研杯時提出的一個簡單的 Softmax 稀疏變體,首先發表在《SPACES:“抽取-生成”式長文字摘要(法研杯總結)》,後來也補充了相關實驗,寫了篇簡單的論文《Sparse-softmax: A Simpler and Faster Alternative Softmax Transformation》[7]
我們知道,在文字生成中,我們常用確定性的 Beam Search 解碼,或者隨機性的 TopK/TopP Sampling 取樣,這些演算法的特點都是隻保留了預測機率最大的若干個 Token 進行遍歷或者取樣,也就等價於將剩餘的 Token 機率視為零,而訓練時如果直接使用 Softmax 來構建機率分佈的話,那麼就不存在絕對等於零的可能,這就讓訓練和預測出現了不一致性。Sparse Softmax 就是希望能處理這種不一致性,思路很簡單,就是在訓練的時候也把 Top- 以外的 Token 機率置零:
其中 是將 從大到小排列後前 個元素的原始下標集合。簡單來說,就是在訓練階段就進行與預測階段一致的階段操作。這裡的 選取方式也可以按照 Nucleus Sampling 的 Top- 方式來操作,看具體需求而定。但要注意的是,Sparse Softmax 強行截斷了剩餘部分的機率,意味著這部分 Logits 無法進行反向傳播了,因此 Sparse Softmax 的訓練效率是不如 Softmax 的,所以它一般只適用於微調場景,而不適用於從零訓練。
Perturb Max
這一節我們介紹一種新的構建機率分佈的方式,這裡稱之為 Perturb Max,它是 Gumbel Max 的一般化,首次介紹於《從重引數的角度看離散機率分佈的構建》,此外在論文《EXACT: How to Train Your Accuracy》[8] 也有過相關討論,至於更早的出處筆者則沒有進一步考究了。
問題反思
首先我們知道,構建一個 的對映並不是難事,只要 是 (實數到非負實數)的對映,如 ,那麼只需要讓
就是滿足條件的映射了。如果要加上“兩點性質” [9] 中的“單調性”呢?那麼也不難,只需要保證 的單調遞增函式,這樣的函式也有很多,比如 。但如果再加上“不變性”呢?我們還能隨便寫出一個滿足不變性的 對映嗎?(反正我不能)
可能有讀者疑問:為什麼非要保持單調性和不變性呢?的確,單純從擬合機率分佈的角度來看,這兩點似乎都沒什麼必要,反正都是“力大磚飛”,只要模型足夠大,那麼沒啥不能擬合的。但從 “Softmax 替代品”這個角度來看,我們希望新定義的機率分佈同樣能作為 的光滑近似,那麼就要儘可能多保持跟 相同的性質,這是我們希望保持單調性和不變性的主要原因。
Gumbel Max
Perturb Max 藉助於Gumbel Max 的一般化來構造這樣的一類分佈。不熟悉 Gumbel Max 的讀者,可以先到《漫談重引數:從正態分佈到Gumbel Softmax》[10] 瞭解一下 Gumbel Max。簡單來說,Gumbel Max 就是發現:
怎麼理解這個結果呢?首先,這裡的 是指 的每個分量都是從 Gumbel 分佈獨立 [11] 重複取樣出來的;接著,我們知道給定向量 ,本來 是確定的結果,但加了隨機噪聲 之後, 的結果也帶有隨機性了,於是每個 都有自己的機率;最後,Gumbel Max 告訴我們,如果加的是 Gumbel 噪聲,那麼i的出現機率正好是 。
Gumbel Max 最直接的作用,就是提供了一種從分佈 中取樣的方式,當然如果單純取樣還有更簡單的方法,沒必要“殺雞用牛刀”。Gumbel Max 最大的價值是“重引數(Reparameterization)”,它將問題的隨機性從帶引數 的離散分佈轉移到了不帶引數的 上,再結合 Softmax 是 的光滑近似,我們得到 是 Gumbel Max 的光滑近似,這便是 Gumbel Softmax,是訓練“離散取樣模組中帶有可學引數”的模型的常用技巧。
一般噪聲
Perturb Max 直接源自 Gumbel Max:既然 Softmax 可以從 Gumbel 分佈中匯出,那麼如果將 Gumbel 分佈換為一般的分佈,比如正態分佈,不就可以匯出新的機率分佈形式了?也就是說直接定義
重複 Gumbel Max 的推導,我們可以得到
其中 是 的累積機率函式。對於一般的分佈,哪怕是簡單的標準正態分佈,上式都很難得出解析解,所以只能數值估計。為了得到確定性的計算結果,我們可以用逆累積機率函式的方式進行均勻取樣,即先從 均勻選取 ,然後透過求解 來得到 。
從 Perturb Max 的定義或者最後 的形式我們都可以斷言 Perturb Max 滿足單調性和不變性,這裡就不詳細展開了。那它在什麼場景下有獨特作用呢?說實話,還真不知道,《EXACT: How to Train Your Accuracy》[8] 一文用它來構建新的機率分佈並最佳化準確率的光滑近似,但筆者自己的實驗顯示沒有特別的效果。個人感覺,可能在某些需要重引數的場景能夠表現出特殊的作用吧。
Sparsemax
接下來要登場的是名為 Sparsemax 的機率對映,出自 2016 年的論文《From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification》[12],它跟筆者提出的 Sparse Softmax 一樣,都是面向稀疏性的改動,但作者的動機是用在 Attention 中提供更好的可解釋性。跟 Sparse Softmax 直接強行截斷 Top- 個分量不同,Sparsemax 提供了一個更為自適應的稀疏型機率分佈構造方式。
基本定義
原論文將 Sparsemax 定義為如下最佳化問題的解:
透過拉格朗日乘數法就可以求出精確解的表示式。然而,這種方式並不直觀,而且也不容易揭示它跟 Softmax 的聯絡。下面提供筆者構思的一種私以為更加簡明的引出方式。
首先,我們可以發現,Softmax 可以等價地表示為
其中 是使得 的各分量之和為 1 的常數,對於 Softmax 我們可以求出 。
然後,在 Taylor Softmax 那一節我們說了, 的偶次泰勒展開總是正的,因此可以用偶次泰勒展開來構建 Softmax 變體。但如果是奇數次呢?比如 ,它並不總是非負的,但我們可以加個 強行讓它變成非負的,即 ,用這個近似替換掉式(20)的 ,就得到了 Sparsemax:
其中 依然是使得 的各分量之和為1的常數,並且常數 1 也可以整合到 之中,所以上式也等價於
求解演算法
到目前為止,Sparsemax 還只是一個形式化的定義,因為 的具體計算方法尚不清楚,這就是本節需要探討的問題。不過即便如此,單靠這個定義我們也不難看出 Sparsemax 滿足單調性和不變性兩點性質,如果還覺得不大肯定的讀者,可以自行嘗試證明一下它。
現在我們轉向 的計算。不失一般性,我們假設 各分量已經從大到小排序好,即 ,接著我們不妨先假設已知 ,那麼很顯然
根據 的定義,我們有
這就可以求出 。當然,我們無法事先知道 ,但我們可以遍歷 ,利用上式求一遍 ,取滿足 ,這也可以等價地表示為求滿足 的最大的 ,然後返回對應的
參考實現:
1defsparsemax(x):
2

    x_sort = np.sort(x)[::

-1

]

3

    x_lamb = (np.cumsum(x_sort) - 

1

) / np.arange(

1

, len(x) + 

1

)

4

    lamb = x_lamb[(x_sort >= x_lamb).argmin() - 

1

]

5return

 np.maximum(x - lamb, 

0

)

梯度計算
方便起見,我們引入記號
那麼可以寫出
從這個等價形式可以看出,跟 Sparse Softmax 一樣,Sparsemax 同樣也只對部分類別有梯度,可以直接算出雅可比矩陣:
由此可以看出,對於在 裡邊的類別,Sparsemax 倒是不會梯度消失,因為此時它的梯度恆為常數,但它總的梯度大小,取決於 的元素個數,它越少則越稀疏,意味著梯度也越稀疏。
損失函式
最後我們來討論 Sparsemax 作為分類輸出時的損失函式。比較直觀的想法就是跟 Softmax 一樣用交叉熵 ,但 Sparsemax 的輸出可能是嚴格等於0的,所以為了防止 錯誤,還要給每個分量都加上 ,最終的交叉熵形式為 ,但這一來有點醜,二來它還不是凸函式,所以並不是一個理想選擇。
事實上,交叉熵在 Softmax 中之所以好用,是因為它的梯度恰好有(12)的形式,所以對於 Sparsemax,我們不妨同樣假設損失函式的梯度為 ,然後反推出損失函式該有的樣子,即:
從右往左驗證比較簡單,從左往右推可能會有些困難,但不多,反覆拼湊一下應該就能出來了。第一個 常數是為了保證損失函式的非負性,我們可以取一個極端來驗證一下:假設最佳化到完美,那麼 應該也是 one hot,此時 ,並且 ,於是
所以要多加上常數 。
Entmax-α
Entmax- 是 Sparsemax 的一般化,它的動機是 Sparsemax 往往會過度稀疏,這可能會導致學習效率偏低,導致最終效果下降的問題,所以 Entmax- 引入了 引數,提供了 Softmax()到Sparsemax()的平滑過度。
Entmax- 出自論文《Sparse Sequence-to-Sequence Models》[13],作者跟 Sparsemax 一樣是 Andre F. T. Martins,這位大佬圍繞著稀疏 Softmax、稀疏 Attention 做了不少工作,有興趣的讀者可以在他的主頁 [14] 查閱相關工作。
基本定義
跟 Sparsemax 一樣,原論文將 Entmax- 定義為類似(19)的最佳化問題的解,但這個定義涉及到 Tsallis entropy [15] 的概念(也是 Entmax 的 Ent 的來源),求解還需要用到拉格朗日乘數法,相對來說比較複雜,這裡不採用這種引入方式。
我們的介紹同樣是基於上一節的近似 ,對於 Softmax 和 Sparsemax,我們有
Sparsemax 太稀疏,背後的原因也可以理解為 近似精度不夠高,我們可以從中演化出更高精度的近似
只要 ,那麼最右端就是一個比 更好的近似(想想為什麼)。利用這個新近似,我們就可以構建
這裡 是為了對齊原論文的表達方式,事實上用 表示更簡潔一些。同樣地,常數 1 也可以收入到 定義之中,所以最終定義可以簡化為
求解演算法
對於一般的 ,求解 是比較麻煩的事情,通常只能用二分法求解。
首先我們記 ,並且不失一般性假設 ,然後我們可以發現 Entmax- 是滿足單調性和不變性的,藉助不變性我們可以不失一般性地設 (如果不是,每個 都減去 即可)。
現在可以檢驗,當 時, 的所有分量之和大於等於 1,當 時, 的所有分量之和等於 0,所以最終能使分量之和等於 1 的 必然在 內,然後我們就可以使用二分法來逐步逼近最優的 。
對於某些特殊的 ,我們可以得到一個求精確解的演算法。Sparsemax 對應 ,我們前面已經給出了求解過程,另外一個能給解析解的例子是 ,這也是原論文主要關心的例子,如果不加標註,那麼 Entmax 預設就是 Entmax-1.5。跟 Sparsemax 一樣的思路,我們先假設已知 ,於是有
這只不過是關於 的一元二次方程,可以解得
當我們無法事先知道 時,可以遍歷 ,利用上式求一遍 ,取滿足 那一個 ,但注意這時候不等價於求滿足 的最大的 。
完整的參考實現:
1defentmat(x):
2

    x_sort = np.sort(x / 

2

)[::

-1

]

3

    k = np.arange(

1

, len(x) + 

1

)

4

    x_mu = np.cumsum(x_sort) / k

5

    x_sigma2 = np.cumsum(x_sort**

2

) / k  - x_mu**

2
6

    x_lamb = x_mu - np.sqrt(np.maximum(

1.

 / k - x_sigma2, 

0

))

7

    x_sort_shift = np.pad(x_sort[

1

:], (

0

1

), constant_values=-np.inf)

8

    lamb = x_lamb[(x_sort > x_lamb) & (x_lamb > x_sort_shift)]

9return

 np.maximum(x / 

2

 - lamb, 

0

)**

2

其他內容
Entmax- 的梯度跟 Sparsemax 大同小異,這裡就不展開討論了,讀者自行推導一下或者參考原論文就行。至於損失函式,同樣從梯度 出發反推出損失函式也是可以的,但其形式有點複雜,有興趣瞭解的讀者可以參考原論文《Sparse Sequence-to-Sequence Models》[13] 以及《Learning with Fenchel-Young Losses》[16]
不過就筆者看來,直接用 運算元來定義損失函式更為簡單通用,可以避免求原函式的複雜過程:
這裡的 是向量內積,這樣定義出來的損失,其梯度正好是 ,但要注意這個損失函式只有梯度是有效的,它本身的數值是沒有參考意義的,比如它可正可負,也不一定越小越好,所以要評估訓練進度和效果的話,得另外建立指標(比如交叉熵或者準確率)。
文章小結
本文簡單回顧和整理了 Softmax 及其部分替代品,其中包含的工作有 Softmax、Margin Softmax、Taylor Softmax、Sparse Softmax、Perturb Max、Sparsemax、Entmax- 的定義、性質等內容。
參考文獻
[1] https://kexue.fm/archives/3290
[2] https://en.wikipedia.org/wiki/Jacobian_matrix_and_neterminant
[3] https://kexue.fm/archives/5743
[4] https://kexue.fm/archives/8601
[5] https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
[6] https://www.together.ai/blog/based
[7] https://papers.cool/arxiv/2112.12433
[8] https://papers.cool/arxiv/2205.09615
[9] https://kexue.fm/archives/10145#兩點性質
[10] https://kexue.fm/archives/6705
[11] https://en.wikipedia.org/wiki/Gumbel_nistribution
[12] https://papers.cool/arxiv/1602.02068
[13] https://papers.cool/arxiv/1905.05702
[14] https://andre-martins.github.io
[15] https://en.wikipedia.org/wiki/Tsallis_entropy
[16] https://papers.cool/arxiv/1901.02324
更多閱讀
#投 稿 通 道#
 讓你的文字被更多人看到 
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋樑,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。 
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析科研心得競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝 稿件基本要求:
• 文章確係個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標註 
• 稿件建議以 markdown 格式撰寫,文中配圖以附件形式傳送,要求圖片清晰,無版權問題
• PaperWeekly 尊重原作者署名權,並將為每篇被採納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
📬 投稿通道:
• 投稿郵箱:[email protected] 
• 來稿請備註即時聯絡方式(微信),以便我們在稿件選用的第一時間聯絡作者
• 您也可以直接新增小編微信(pwbot02)快速投稿,備註:姓名-投稿
△長按新增PaperWeekly小編
🔍
現在,在「知乎」也能找到我們了
進入知乎首頁搜尋「PaperWeekly」
點選「關注」訂閱我們的專欄吧
·
·
·

相關文章