近年來,新興起的圖神經網路在很多應用領域都取得了非常出色的表現,如今年用於Google 地圖的到達時間估計(Estimated Time of Arrival,ETA),在紐約、洛杉磯、東京、新加坡等國際大都市都獲得了很大的提升,該結果對其他地區也具有通用性[1]。圖神經網路以圖結構為核心組成部分,這與結構因果模型有著相似的結構形式。鑑於此,DeepMind最新的研究工作[2]以圖神經網路為網路結構,設計了一種基於圖神經網路的變分圖自編碼器,用於近似Pearl因果層次結構中的因果計算問題。與以往用因果推斷思想提升深度學習效能不同的是,該研究工作在圖神經結構與結構因果模型之間建立了轉換機制,為Pearl因果層次結構中的因果計算提供了一種新型計算方法和思路,是深度學習在因果推斷領域應用的一項開創性的嘗試性工作。
結構因果模型(Structural Causal Model,SCM)用於描述現實世界關聯特徵及其相互作用,是一種能夠形式化表述資料背後因果假設的方法。結構因果模型含有兩個變數集
和
以及一組函式
,
即函式
根據模型中其他變數的值給變數
賦值。若
存在於
的定義域中,則變數
是變數
的直接原因。
中的變數稱為外生變數,屬於模型的外部,不必解釋其變化的原因。
中的變數為內生變數,模型中的每一個內生變數都至少有一個外生變數作為直接原因。
圖1:SCM例項和Pearl Causal Hierarchy
每一個結構因果模型都對應一個圖模型
,
中的每個變數都表示為一個節點,對於變數
,如果
的定義域中含有變數
,那麼在
中,會有一條從
到
的有向邊。同時,透過例項化外生變數
,也可以生成包含所有內生變數
的資料,因此SCM也是一種資料生成模型。在SCM中,可以回答Pearl因果層次結構(Pearl Causal Hierarchy,PCH)關於基於觀察資料(L1)、干預(L2)以及反事實(L3)三個層次的因果計算, 如圖1所示。在PCH因果計算中,其核心思想是需要解決干預(L2)的問題,這是回答反事實問題(L3)的基礎。在SCM中,干預主要體現在do-演算操作,即將干預變數強制設定為某一固定值,使得干預變數不會隨其他變數影響。下面以一個例子來說明do-演算在SCM中的計算過程。
Example:考慮飲食(D)對血壓(B)的影響,
表示高纖維的飲食習慣,
表示高血壓,假設給定的SCM為M,外生變數
,內生變數
以及函式
其中⊕是XOR邏輯操作,該SCM相對應的圖結構
如圖2所示。
此外,因為所有變數都取二元值,因此我們也可以列舉推演出其真值表,如表1所示。
現在我們對變數D進行干預,設定為
,那麼根據表1我們可以計算
的機率,就是將表1中所有藍色行的數字求和,即
圖結構是SCM模型的一個重要組成部分,能直觀地表達變數之間的互動資訊。但是在實際問題中,這種互動關係的確定往往嚴重依賴於領域專家知識,無可避免地引入了人為誤差。個人認為,GNN作為深度學習方法的衍生體,能有效地近似任何函式,是模擬SCM中機率分佈的一個可行方法。
Do-演算也可在圖上以更直觀的方式呈現。當對變數進行干預時,意味著削弱了該變數響應其他變數而變化的自然趨勢。在SCM對應的圖結構表示上,就需要刪除指向該變數的所有邊,如圖3中右邊紅色錘子對應的邊。
按照上述思想,文[2]對GNN也定義了類似的干預操作,主要體現在GNN的訊息傳遞(message-passing)中。在標準的GNN資訊聚合操作中,圖中節點
透過聚合其父節點
的資訊完成當前節點
的資訊更新,如式(2)所示。但是在干預的GNN中,如果當前節點
為干預變數,則忽略其父節點的資訊(將
替換為
),如式(3)所示。
在上述圖神經網路(Graph Neural Networks, GNN)的do-演算基礎上,文[2]定義了用於近似PCH因果推斷的干預變分圖自編碼器(Interventional Variational Graph Auto-Encoder, iVGAE),如圖4所示。
乍一看,圖4上部分描述的是標準變分圖自編碼器模型結構。但為了能近似SCM在L2層次的因果推斷,文[2]將編碼器
函式、解碼器
函式都設計為以給定SCM為圖結構的GNN的聚合函式。在進行L2層次因果推理時,根據給定的查詢變數和干預變數,動態地對圖結構進行do-演算調整,即忽略/不計算來自干預變數的父節點的資訊,而模型的輸出可近似成該干預變數下的機率分佈,即完成L2層次的因果計算。在訓練iVGAE,主要採用了變分方法,其中目標函式也需要考慮干預變數。
與以往用因果推斷改進深度學習方法效果不同的是,文[2]側重於用基於GNN的深度學習來完成SCM中的PCH因果計算,側重於基於觀察資料(L1)、干預的推斷(L2)。由於圖神經網路與SCM都是基於圖結構,一種簡單、直接的方法就是在給定SCM圖結構上,設計一種合適引數轉化機制,以確保SCM和深度學習模型表達同個分佈,這也是文[2]的主要設計思路。同時,文[2]也指出,SCM需要對每個變數都定義各自相應的對映函式。相反的,在iVGAE中,可以找到單個共享聚合函式,用於聚合圖中所有節點的訊息。然而將單個聚合函式轉換成多個結構方程的最佳化過程是異常困難,而這也是實現反事實推理需要解決的問題,這也是文[2]沒有考慮L3層次推理的一個原因。
雖然文[2]、[3]都試圖在SCM與深度學習之間建立聯絡,目前主要側重於將深度學習看成一種近似方法來完成PCH中的因果計算。當然,對PCH因果計算的支援是實現因果推斷的重要內容,也可以看成深度學習在因果表達上邁出了重要的一步。不同方法有不同程度的相容性,如文[2]不支援L3層次的計算。這些研究也引出了更深層次的問題,如基於神經網路的因果計算優勢體現在哪裡,例如,推理計算是否更高效?文中尚未提供明確的答案。
除了近似分佈,也有將深度學習用於因果發現中的研究工作,如文[4]中提出了連續最佳化(continuous optimization)的思想,重新定義了因果圖發現的一種求解方式。與其在圖空間進行搜尋,轉化為尋找一個包含圖結構的鄰接矩陣的函式,從而可以使用深度學習方法進行梯度下降求解。當然,這與文[2]有著不同的研究目標。不可否認,如何使得深度學習和因果推斷相得益彰,是一個非常值得探索的方向,相信在不久的將來兩者能碰觸更多的火花。
最後,個人覺得文[2]的亮點在於採用現流行的GNN模型來模擬SCM的資料生成機制,雖然這種資料生成過程是一種黑盒子方法(這也是深度學習廣為爭議的一個特徵),但如果僅從資料模擬效果的角度來看,未嘗不可?
由於水平有限,文中存在不足的地方,請各位讀者批評指正,也歡迎大家參與我們的討論。
[1] Austin Derrow-Pinion, Jennifer She, David Wong, et al. ETA Predictionwith Graph Neural Networks in Google Maps. 2021
[2] Matej Zecevi,Devendra Singh Dhami, Petar Velickovi,Kristian Kersting. Relating Graph Neural Networks to Structural Causal Models. 2021
[3]Kevin Xia, Kai-Zhan Lee, Yoshua Bengio, Elias Bareinboim. The Causal-Neural Connection: Expressiveness, Learnability, and inference. 2021
[4] Xun Zheng, Bryon Aragam, Pradeep Ravikumar, Eric P Xing. Dags with no tears: continuous optimization for structure learning. 2018
壁仞科技研究院作為壁仞科技的前沿研究部門,旨在研究新型智慧計算系統的關鍵技術,重點關注新型架構,先進編譯技術和設計方法學,並將逐漸拓展研究方向,探索未來智慧系統的各種可能。壁仞科技研究院秉持開放的原則,將積極投入各類產學研合作並參與開源社群的建設,為相關領域的技術進步做出自己的貢獻。