TPAMI 2024 | 广义线性因果网络的联邦学习

文摘   2024-11-08 19:00   辽宁  

点击下方PaperEveryday”,每天获得顶刊论文解读

点击加入论文投稿、写作、阅读分享交流群

题目:Federated Learning of Generalized Linear Causal Networks

广义线性因果网络的联邦学习

作者:Qiaoling Ye; Arash A. Amini; Qing Zhou


摘要

因果发现,即从数据中推断变量之间的因果关系,是科学中的一个基本问题。如今,由于对数据隐私问题的日益关注,分布式数据收集、处理和存储发生了转变。为了满足分布式因果发现的迫切需求,我们提出了一种新的联合有向无环图(DAG)学习方法,称为分布式退火正则化似然分数(DARLS),用于从存储在多个客户端上的数据中学习因果图。DARLS模拟了一个退火过程来搜索拓扑排序的空间,其中与排序兼容的最优图形结构是通过分布式优化找到的。这种分布式优化依赖于本地客户端和中央服务器之间的多轮通信来估计图形结构。我们建立了它收敛到可以访问所有数据的预言机得到的解的保证。据我们所知,DARLS是第一个具有如此有限样本预言机保证的分布式学习方法来学习因果图。为了建立DARLS的一致性,我们还推导出了因果图参数化的新可识别性结果,这些结果可能具有独立的兴趣。通过广泛的模拟研究和现实世界的应用,我们展示了DARLS在估计分布式数据的因果网络方面优于现有的联合学习方法,并且与在汇总数据上的预言机方法相当,展示了它在估计分布式数据的因果网络方面的巨大优势。

关键词

因果图,联合学习,广义线性模型,模拟退火,拓扑排序

I. 引言

推断变量之间的因果关系是许多应用中的基本问题,如计算生物学、医学科学、社会科学等。它与统计学中的多个领域紧密相连,包括随机实验[1]、[2]、回顾性反事实推理[3]、潜在结果[4]、[5]和概率图模型[6]、[7]。通过图形模型进行因果推断的一个关键而具有挑战性的步骤是从数据中识别因果效应关系,即通常表述为因果图或其马尔可夫等价类的结构学习。作为主要的因果模型,因果图已在流行病学[8]、[9]、病理生理学[10]、经济学[11]和风险分析[12]等多个领域得到广泛应用。因果图通常由有向无环图(DAG)表示,其中边缘编码变量(节点)之间的因果效应。一组随机变量在因果DAG 中的概率密度分解为
其中的父集,是它的值。当通过实验将的父集设置为时,条件分布也解释为的干预分布,即[13]。由于随机实验并不总是可行或可用的,已经提出了各种方法来从观测数据中学习因果DAG;参见[14]、[15]、[16]的最新综述。

A. 联合学习

在这项工作中,我们专注于联合DAG学习的任务,即从分布式数据中学习由DAG编码的因果关系,特别关注非高斯情况,这包括了广泛的数据类型。此外,通过其充分统计量的简单更新就可以实现高斯线性DAG模型的分布式学习(备注1)。分布式数据存储已被用于隐私保护,用于管理政府机构、研究机构、医疗中心、技术公司等每天产生的大量数据[17]。这些组织经常收集相似的数据,或来自同一人群的数据,并在社会、科学和商业领域进行合作。著名的例子包括Google和Apple在2021年为追踪COVID-19接触而开发的暴露通知系统,以及用于健康指标收集的隐私保护分析平台[18],以及小型诊所汇集其他设施的数据以实现统计上显著的结果,因为它们自己的数据集有限。这种合作需要严格的隐私披露,确保每个机构持有的敏感数据不会对外共享。
联合学习被证明是保护数据隐私或合并多个来源的数据不可行时的有效工具。这种方法使本地实体能够协作地从分散的数据中学习,而无需直接共享数据。因此,联合学习不仅解决了隐私和后勤问题,而且还通过从多样化的数据源中学习来促进强大和泛化模型的发展。
获得全局估计的一个简单方法是对本地估计进行平均,这种技术称为一次性参数平均[19]。然而,这种方法与基于所有数据的全局估计相比,无法达到任何期望的准确性水平[20]、[21]。为了解决这些限制,开发了更复杂的通信高效算法,这些算法采用本地客户端和中央服务器之间的多轮通信[22]、[23]。这种算法在多智能体系统的分布式优化中越来越重要,包括电力系统、传感器网络和智能制造[24]、[25]。

B. 本工作的贡献

尽管上述方法取得了进展,但从多个本地客户端分布的数据中学习因果DAG仍然是一项具有挑战性的任务。一个主要困难是如何整合本地信息,以有效估计全局因果图,就好像统计学家可以访问所有本地客户端的数据一样。有一些简单的方法会迭代本地数据集(一次),然后组合本地图或本地p值以形成一个全局图[26]、[27]、[28]。然而,使用这种单次迭代方法的本地估计的简单聚合,不会导致接近于如果访问汇总数据则构建的相应全局估计的估计。此后,我们将这种全局估计称为预言机解。
另一个特定于我们问题的挑战来自DAG的无环约束。显然,简单的平均图可能会失败这个基本约束,因此不能提供有意义的因果解释。为了克服这些困难,我们提出了一种基于分数的学习方法,该方法执行多轮通信以估计分布式数据的DAG。我们的目标函数等同于整体数据的正则化对数似然,已经证明在学习连续和离散DAG方面是有效的[29]、[30]、[31]。中央服务器提出了一个候选排序π,其中π的分数是通过分布式优化在与π兼容的DAG上评估的。然后,通过模拟退火选择候选排序π。因为每个DAG至少有一个排序,搜索排序空间确保始终满足无环约束。
我们的方法是一个联合学习的例子,其中中央服务器与分布式本地客户端通信以学习DAG。我们展示了我们联合估计收敛到整体数据上的预言机估计的速率为,对于一个固定的真DAG,其中n是所有本地客户端中的总样本大小,m是最小的本地样本大小(定理1,第IV-A节)。因此,即使对于有限的样本,只要很小,我们的分布式估计将基本上与预言机解相同,实现了理想的效率,同时保护了数据隐私。据我们所知,我们的方法是第一个具有如此良好理论保证的联合因果发现方法。当这项工作在审查时,两种联合DAG学习方法被发表[32]、[33]。这两篇论文都没有建立这样的预言机估计保证。我们将在第III-D节中进一步阐述其他技术差异。
我们工作的另一个贡献是使用广义线性模型(GLMs)用于(1)中的局部条件分布,这为因果结构学习带来了几个优势。我们提出的GLM DAG模型是一个灵活的家庭,适用于各种数据类型,超越了具有等方差和多对数模型的线性高斯模型[31]、[34]。常用的GLMs下的负对数似然函数是凸的,这有助于优化任务。此外,我们展示了GLM DAG模型在温和条件下是可识别的(命题1,第II节),而其他模型,如离散网络的多项式和高斯线性DAGs(具有异方差)通常不是可识别的[35]、[36]。在这种可识别性下,我们建立了我们正则化似然分数的全局最大化器DAG的一致性(定理2,第IV-B节)。

C. 组织

本文的组织如下。第II节定义了广义线性DAG模型,并在该模型下建立了一些可识别性结果。在第III节,我们建立了学习因果图的优化问题,并开发了结合模拟退火搜索和迭代优化方法的DARLS算法,以学习因果DAG结构。然后我们在第IV节建立了分布式优化算法的收敛性和估计一致性的定理。第V节包括模拟研究,比较我们的方法与现有方法在分布式和汇总数据上的性能,测试DARLS对其底层模型假设违规的鲁棒性,并检查我们分布式学习算法的准确性损失和计算效率。我们还在第VI节将分布式学习方法应用于ChIP测序数据,以模拟蛋白质-DNA结合网络。本文在第VII节以讨论结束。所有证明都放在补充材料中。

II. 广义线性DAG模型

是变量的一个实现,其中对于数值对于具有类的分类变量,使用一位有效编码。让表示与边缘相关的参数,如果,则。设
其中。这里和其他地方,表示两个向量或矩阵的垂直连接。我们通过以下方式定义广义线性DAG (GLDAG):
其中都是从的函数。注意只依赖于父集。GLDAG模型允许通过选择对数划分函数来实现许多常见分布。示例包括伯努利分布对于,恒定方差高斯分布对于,泊松分布对于,伽马分布对于,以及多项式分布对于。注意,在多项式情况下,是多变量函数,操作向量,与其它示例中是标量函数形成对比。上述伯努利和多项式选择分别产生了每个节点的逻辑和多对数回归模型。
我们通过水平连接,每个如(2)中定义,在矩阵中收集模型(4)的所有参数。我们说GLDAG (4)是连续的,如果所有变量都是连续的。回想一下,在这种情况下,对于所有,且矩阵。我们重写(4)的对数概率密度函数,在连续情况下为:
其中且如果且仅如果,则。接下来,我们定义DAG模型的可识别性,按照[37]、[38]、[39],并展示连续GLDAGs是可识别的。
定义1(可识别性):假设我们有一个由未知GLDAG模型(4)生成的联合分布。如果分布不能由任何具有不同图的GLDAG模型生成,则我们说可以从中识别出来。
众所周知,线性高斯DAG(具有异方差)和一般多项式DAG通常不是可识别的[35]、[36]。相比之下,连续GLDAG模型(5)在温和的假设下是可识别的:
命题1:假设联合分布由对数概率密度函数定义,具有DAG 根据(5),使得如果且仅如果中,则。如果关于是二阶可微的,且对于所有的一阶导数存在且不是常数,则可以从中识别出来。
命题1建立了连续GLDAG模型(5)的可识别性,部分证明了我们学习因果图的目标。这个结果扩展了文献中可识别DAG模型的类别。另一类可识别的DAG模型是加性噪声模型,在是非线性的[37]、[38]、[40]或误差非高斯的假设下[41]。

III. 联合DAG学习

在本节中,我们使用分布式数据构建目标函数,并提出结合模拟退火搜索和迭代优化方法来学习因果DAG结构。我们从DAG的拓扑排序的定义开始。给定上的一个排列,我们根据排列向量以获得重标记的向量。DAG的一个拓扑排序是节点的一个排列,使得如果,则在由定义的顺序中先于,记作。根据定义(1),每个DAG至少有一个拓扑排序。
是从模型(4)中的i.i.d样本大小为。我们也用表示第个数据点中第个变量()的观测值。考虑一个子集。子样本的归一化负对数似然由下式给出,忽略一个加性常数,
注意在这里,表示整个样本大小为的归一化负对数似然。

A. 全局目标函数和退火

我们考虑整体数据存储在个不同服务器上的情况,每个本地客户端持有其私有数据,并与中央服务器通信。设中的样本大小,使得
基于整个数据的归一化负对数似然可以分解为。设上所有排列的集合,是与排列兼容的DAG集合。注意中的一个线性子空间。我们理想地希望通过最小化正则化损失函数的形式来估计
其中
是促进中稀疏性的适当正则化器。我们称为全局目标函数,因为它使用所有本地客户端之间的数据定义。
回想一下,如果且仅当,则。为了学习稀疏DAG,我们应用形式为
的分组正则化,其中是非负且非递减的组正则化器,且是调优参数。在的限制下,正则化器可以进一步简化为。在本文中,我们考虑组Lasso(即组)惩罚,选择
其中是矩阵的Frobenius范数。作为一个凸惩罚和Lasso正则化的自然扩展,组Lasso在分组变量选择中表现出色。
为了在分布式数据上搜索如(7)中所述,我们提出了分布式退火正则化似然分数(DARLS)算法,它将退火策略应用于排列空间的搜索,并结合分布式优化方法。这种联合优化排列空间和DAG空间的方式在结构学习中显示出了巨大的有效性;参见例如[43]、[44]、[45]、[46]及其参考文献。
DARLS的主要步骤概述在算法1中。在每次退火迭代中,基于当前(第5行)提出排列,并根据模拟退火给定递减的温度计划接受。为了计算给定排列的最优DAG结构的分数,我们使用在第III-B节中讨论的分布式优化方法,概述在算法2中。这种方法允许本地客户端和中央服务器之间进行多轮通信以更新和综合信息。注意DARLS可以应用于任何目标函数,只要关于的梯度有封闭形式表达式。其他步骤(第1行和第9行)在第III-C节中讨论。

备注1:对于从多个独立数据集估计高斯DAG,我们可以使用一阶方法,如随机梯度或近端梯度算法[46]。给定,这种类型的算法使用本地数据集的样本协方差来计算全局估计。因此,高斯DAG的分布式学习只需要平均本地充分统计量,不需要分布式优化(算法1中的第6行),因此不是这项工作的重点。

B. 本地目标和分布式优化

对于任何固定的,我们使用分布式计算来评估,因为样本并未在本地客户端之间共享。也就是说,而不是直接使用(7)中的目标函数,我们依赖于它的本地版本来指导分布式算法,该算法将计算的任务分配给个本地客户端。特别地,我们考虑本地目标函数
其中
全局版本(7)可以重写为,其中。通常,由于正则化器的存在,每个都是非光滑的,但差异通常是光滑的。梯度用于指导每个本地客户端中的迭代。也就是说,给定当前(全局)估计,本地客户端执行更新
本地正则化损失指导,表示为,是全局正则化损失的一阶近似,直到一个加性常数。让是算法在迭代的全局估计。在下一次迭代中,我们获得本地估计对于
这些本地估计然后传递给中央服务器以通过平均值计算下一个全局估计,即。这种分布式优化方法的主要步骤概述在算法2中。
上述方法本质上是DANE算法的一个版本[19]、[21]、[22]、[23]。注意为了计算本地更新(算法2中的第5行),只需要将当前全局估计和全局梯度传达给每个本地客户端。在第IV-A节中,我们展示了对于足够大的每个客户端的最小样本大小,即,产生的序列将收敛到上的全局最小化器

分布式优化中的另一个部分是计算本地更新(10)(算法2中的第5行),我们使用近端梯度算法(算法3)。给定当前全局估计,优化本地目标(10)相当于
定义,作为全局似然的替代。为了解决(10),我们使用迭代近端梯度下降。在每次迭代中,我们最小化围绕当前解的二次近似,加上一个正则化项,
其中扮演步长的角色,且是我们对解的下一个估计。等价地,更新(12)可以重写为
是应用于缩放函数的近端算子。方程(13)被称为近端梯度更新,对于我们的正则化器选择,由(8)和(9)给出,具有以下封闭形式表达式:
这通常被称为块软阈值运算符。为了确定步长的值,我们使用后退线搜索,缩小初始值直到找到合适的步长(算法3中的第7行)。算法3的收敛性在(11)在上的凸性给定时是保证的[47]、[48]。然而,为了避免可能的缓慢收敛,我们设置了最大迭代次数以提前停止(第2行)。我们推荐读者阅读[49]以获取更多关于近端算法的细节。
计算复杂性:现在我们来估计DARLS的整体计算复杂性。为简单起见,考虑数值,使得对于所有。那么,中的向量,且可以被认为是一个维度为的向量。从(5)中,对于单个样本
因此需要次操作来计算。计算需要次操作对于单个样本,因此具有计算复杂性。对于后续计算,注意到本身是一个维向量。为简单起见,假设所有个本地客户端具有相同数量的本地样本,即。那么,算法2的第3步的复杂性是,我们将并行计算视为一个。让是算法3中的最大迭代次数。由于计算也需要次操作,算法2(即整个算法3)的复杂性是。算法2的步骤4和6的复杂性是。因此,算法2的整体复杂性是。这给出了DARLS的整体复杂性

C. 调优参数和结构估计

给定初始排列,我们使用BIC网格搜索来选择在组Lasso惩罚(8)中使用的调优参数(算法1中的第1行)。为了构建网格,我们选择了20个等间隔的点,对数刻度上从区间,其中足够大以产生空图在我们的测试中。我们选择最小化BIC得分的调优参数,BIC(i) = ,其中上使用惩罚参数的最小化器,由算法2计算,且中的自由参数数量。注意我们的调优参数选择是在算法1中估计之前完成的,这与通常在获得解路径后选择的常见做法不同。我们的策略大大减少了计算成本,并且在实践中效果很好,如我们之前的工作[46]中所示。
DARLS算法结束时提供的估计GLDAG参数,从中我们可以估计因果结构(算法1中的第9行)。设是一个DAG的加权邻接矩阵,权重。使用组Lasso正则化有助于产生稀疏估计DAG,但通常会呈现假阳性边缘。因此,我们通过设置为零,如果,来进一步细化估计结构。可以根据需要调整的值以实现期望的稀疏度水平,特别是当具有先验知识时。在我们的模拟测试中,我们固定以移除权重相对较小的边缘。

D. 与其他联合学习方法的比较

在本节中,我们阐明了我们的方法与两种最近的联合方法FedDAG[33]和NOTEARS-ADMM[32]之间的区别。首先,也是最重要的,我们的工作作为唯一一种具有关键理论保证的方法脱颖而出,确保联合估计收敛到基于所有本地数据的预言机估计(见第IV节)。第二,在我们的工作中施加的无环性约束与其他方法不同。在我们的算法中,我们依赖于中央服务器的基于顺序的搜索,而其他两种方法使用连续的代数约束[50]。然而,最近的工作[51]表明,代数约束不能被精确满足,因此需要后处理,例如阈值化,以获得估计的DAG。第三,相关的分布式优化也大不相同。在我们的算法中,本地客户端和中央服务器通信梯度信息。FedDAG[32]采用策略,中央服务器在每轮通信中平均来自本地客户端的代理邻接矩阵,并将平均值广播回本地客户端。在NOTEARS-ADMM[33]中,本地客户端和中央服务器的模型参数(例如加权邻接矩阵)通过交替方向乘子法(ADMM)的迭代更新规则进行交换。最后,我们的工作专注于广义线性模型,用于各种类别的变量,而其他两篇论文专注于连续变量在线性和非高斯高斯情况下。

IV. 理论保证

在本节中,我们研究了分布式优化(算法2)收敛到预言机解的收敛性,并建立了(7)全局最小化器的一致性。由于我们的方法主要受涉及大量分布式数据的应用的驱动,我们在很大且变量数量保持固定的设置下开发理论结果。我们的分析重点将是非高斯情况。在高斯情况下,样本协方差矩阵是一个充分统计量,本地客户端可以将其版本通信给中央服务器,在一轮通信中,然后可以形成完整的矩阵并计算全局DAG估计。换句话说,在高斯情况下不需要像算法1这样的复杂分布式算法。

A. 预言机保证

回想本地迭代函数定义在(10)中。分布式算法的总体迭代函数可以写成(算法2中的第6行)。设是模型的总体二阶矩矩阵,且是其最小特征值。对于矩阵,让表示以为中心,半径为的Frobenius球。我们考虑数值变量的情况,即对于所有,这包括连续和二元离散随机变量。以下定理为任何固定的提供了分布式优化算法由表示的收敛保证。让是任何全局最小化器,即
其中是一个凸正则化器。在分布式数据设置中,是一个可以访问多个本地客户端跨所有数据的预言机解。设是GLDAGs的参数空间。我们回忆一下,对于表示第列,且是从GLDAG模型(4)中的i.i.d样本。
定理1(收敛到预言机):假设的坐标被限制,即对于所有。设是任何GLDAG参数和,并设。设,且假设上是-Lipschitz的。定义
其中。进一步假设。存在常数使得如果
对于所有
定理1适用于任何。自然地取,全局目标函数的最小化器,即
由于的一致估计对于任何(定理2,第IV-B节),则随着增长而趋向于零。因此,高概率下,迭代算子将是一个压缩:序列由分布式算法产生,几何级数收敛到预言机估计器如果。对于固定的,对于足够大的使得,总能通过取每客户端的最小样本大小足够大来满足的条件。因此,定理1为几何级数收敛提供了量化的最小下界。注意-有界假设对于二元和序数数据是trivially满足的,这是这项工作的主要关注点。
定理1是通过建立GLDAG模型(4)的Hessian的均匀集中于其期望在参数空间中某些球上,然后引用我们在补充材料中推导出的DANE算法的一般收敛结果(参见定理S2)来证明的。在GLDAG模型中建立这种均匀集中是具有挑战性的,因为之间存在高度依赖和非线性关系。建立Hessian的集中的Ledoux-Talagrand收缩定理的技术工具。为了将论点扩展到多对数和一般向量值DAG模型,需要一个目前文献中尚不可用的收缩定理的多变量扩展。原则上,这种扩展是可能的,我们将其留作未来的工作。

B. 一致性

在本节中,我们在(4)中的模型类下建立一致性结果,没有限制在数值变量。设为单个样本来自模型(4)的负对数似然。我们将视为向量通过连接列来处理,以便
回想是全局正则化负对数似然,且是GLDAG参数空间。优化问题(7)等同于。让我们用表示的全局最小化器,且表示具有真实DAG 的真实参数。对于任何,考虑(交叉)费舍尔信息矩阵
我们注意到对于指数族是凸函数,因此总是半正定的。为了建立以及对于任何的一致性,我们提出以下假设:
(A1) 真实的DAG 是可识别的。
(A2) 对于每一个,存在一个邻域和函数使得几乎必然地
对于所有
(A3) 对于每一个,我们有
在(A2)中,隐含地假设几乎必然地在中是有限的。
在我们陈述理论结果之前,我们定义,这正是与一致的排列集合,即的拓扑排序,特别是非空的。看到这一点,我们首先注意到对于任何与一致的,我们有。KL散度论证然后表明是优化问题定义的唯一解。也就是说,任何与。反之,如果,则,因此一致。有了这个观察,我们在以下定理中建立了所需的一致性结果。
定理2:假设(A1)-(A3)和。则,a) 对于每一个上有一个唯一的最小化器,且
b) (DAGs的空间)上有一个唯一的最小化器,且
c) 当时,以概率收敛到1,对于某个(序列的)
定理2确认了组Lasso正则化估计器,定义为的全局最小化器,是-一致的,并且它将在大样本极限下识别正确的拓扑排序。此外,该定理还建立了所有的受限最小化器的一致性,即定理1中的预言机估计器。假设(A1)根据命题1在温和条件下成立,假设(A2)是标准正则性条件,假设(A3)与第二矩矩阵的非奇异性有关。例如,考虑对于所有的情况,假设的元素被限制,设,将视为矩阵,其第列是。那么如果对于所有,非奇异性对于(A3)成立是足够的。

V. 模拟数据上的结果

个节点图上边的数量。我们从贝叶斯网络存储库[52]下载了以下网络()以模拟数据:Asia (8, 8), Sachs (11, 17), Child (20, 25), Insurance (27, 52), Alarm (37, 46), Hailfinder (56, 66) 和 Hepar2 (70, 123)。我们在GLDAG模型(4)和其他常见DAG模型下生成数据,在第V-C节和V-D节中,后者是为了测试我们方法对其模型假设违规的鲁棒性。

A. 方法

我们将DARLS算法与以下DAG结构学习方法进行了比较:标准的贪心爬山(HC)算法[53],Peter-Clark (PC)算法[54],最大-最小爬山(MMHC)算法[55],快速贪心等价搜索(FGES)[56],[57],[58],NOTEARS算法[50],和DAG-GNN方法[59]。在这些方法中,PC是约束型方法,MMHC是混合方法。其他三种方法是基于分数的,其中HC搜索DAG,FGES搜索等价类,NOTEARS使用连续优化来估计DAG结构,DAG-GNN应用图神经网络架构来学习DAGs。
按照分布式数据上DAG学习的做法[26]、[27]、[28],我们结合了由竞争方法生成的本地估计来获得全局图估计。将本地客户端上的本地数据集记为。我们在每个本地数据集上应用竞争方法来获得一个完成的部分有向无环图(CPDAG),然后使用构建全局图。我们在这里使用CPDAG,因为所有竞争方法都是在非可识别DAG模型下开发的。在五种竞争方法中,只有PC和FGES输出CPDAG,因此我们将其他方法估计的DAG转换为CPDAG以获得。给定,我们统计了每一对(i, j)之间三种可能的方向:i → j, i ← j或i − j(无向)。然后我们按照它们在本地图中的出现次数降序排列这些边缘方向,并依次将这些边缘方向添加到空图中,只要它们不会引入有向循环(由所有有向边组成的循环)。在这个进程结束时,我们得到了一个部分有向图。最后,我们应用Meek的规则[56]、[60]来最大化无向边缘的方向,从而构建了全局CPDAG估计。
备注2:HC的全局估计在这种方法中边太多,因为它的本地CPDAG缺乏共识,导致候选边缘数量大,假阳性(FP)边缘更多。为了解决这个问题,如果本地图中的大多数没有它们之间的边缘,则我们不添加任何边缘在节点对(i, j)之间。通过这种方式,由HC估计的全局图与其它方法相比变得合理地稀疏。此外,我们提供DARLS和NOTEARS-ADMM[32]在二元数据上的数值结果的补充材料部分S2,因为NOTEARS-ADMM不是为具有多于两个水平的分类数据设计的。数值结果表明,DARLS始终优于NOTEARS-ADMM,一致地实现更低的SHD在各种数据生成场景中。
我们用MATLAB实现了DARLS算法,并使用以下程序包运行竞争方法:bnlearn[61]用于MMHC和HC算法,pcalg[62]用于PC算法,rcausal[57]用于FGES。NOTEARS和DAG-GNN方法是用它们的在线Python代码[63]、[64]运行的。竞争方法被应用到每个本地数据集上,使用2016 MacBook Pro(2.9 GHz Intel Core i5, 16 GB内存)。由于DARLS是为分布式计算设计的,它在计算机集群上运行。
在这项研究中,DARLS算法(算法1)是用一个随机排列初始化的。根据目标函数的景观,我们将初始退火温度设置为,并在总共次迭代中逐渐降低到。注意由于对数似然已经通过样本大小归一化,如(6)所示,目标函数的范围相当小。对于PC算法,使用了0.01的显著性水平来生成具有期望稀疏度的图。FGES应用了默认值0.1的显著性水平。对于MMHC和HC方法,一个节点的最大父数被设置为三。对于NOTEARS,我们使用了默认的损失和默认的阈值。对于DAG-GNN,我们将其阈值设置为0.15,并使用了其他默认参数值,其中默认阈值是0.3,用于细化最终的DAG结构,类似于我们后处理步骤(第III-C节)中的参数。在我们的数值预实验中,使用DAG-GNN的所有默认值,它总是输出空图。因此,我们手动调整了它的输入参数一个接一个,这种方法只有在降低阈值时才生成非空结构。

B. 准确性指标

给定上述方法生成的估计,我们使用一些指标来评估它们的结构准确性。为了标准化性能指标,我们在计算以下指标之前将估计的DAG转换为CPDAG,当真实的DAG不可识别时。
设P、TP、FP、M、R分别为估计边缘、真阳性边缘、假阳性边缘、缺失边缘和反向边缘的数量。更具体地说,P是估计图中边缘的数量,FP是估计图骨架中但不在真实骨架中的边缘数量,M计算真实骨架中但不在估计图骨架中的边缘数量。TP报告估计DAG/CPDAG和真实DAG/CPDAG之间一致的边缘数量,其中一致的边缘必须在两个节点之间具有相同的方向。在DAG中有2种可能的边缘方向,在CPDAG中有3种。最后,反向边缘数量R = P − TP − FP。然后我们定义结构Hamming距离,SHD = R + FP + M,作为组合指标。如果一个方法实现了更低的SHD,则它具有更高的结构学习准确性。

C. GLDAG数据

我们使用逻辑GLDAG模型(4)与对于所有,生成二元数据,其中系数参数中均匀采样。我们在两种设置下为每个网络模拟20个数据集,,其中将总共个观测值随机分配给个本地客户端。由于GLDAG是可识别的,我们通过计算SHD将DARLS估计的DAG与真实DAG进行比较。对于所有其他方法,我们比较估计的和真实的CPDAG,因为它们不假设可识别的DAG。
表I报告了每个七个图的20个数据集的平均性能指标SHD,使用六种方法。包括TP、FP、R和M的平均值以及SHD的标准差更详细的分解,见补充材料中的表S1。由于NOTEARS估计过于稀疏,使用默认设置,我们将惩罚调优参数降低到,从建议值。我们只为两个小图,Asia和Sachs提供了它在补充材料中的结果。它的SHD落在与其他方法相比的中位数性能范围内。DAG-GNN(表I中的GNN)在计算图时需要时间,在时每个数据集超过一小时,而其他竞争方法最多需要5分钟。因此,我们只为较小的前四个图提供了它在表I中的结果。PC也在最后两个网络,Hailfinder和Hepar2上生成估计时遇到了困难,并且因此在这两个图的比较中被移除。

表I显示,在两种情况下,DARLS在每个网络中始终实现了所有方法中的最低SHD,展示了在估计图形结构方面的更高准确性。DARLS的相对效能保持稳健,随着图的大小增加没有显示出减少的迹象。在四个最大的图,Insurance、Alarm、Hailfinder和Hepar2中,DARLS一致地实现了相对于第二佳方法约40%的SHD降低。在补充材料第S4节中提供了关于DARLS性能随着图大小从76增长到223的额外数值研究,显示出可比的甚至更大幅度的改进。DARLS在几乎所有情况下都识别出了比其他方法更多的TP边缘。通过阈值化的细化步骤也有助于通过减少FP边缘来降低SHD。与竞争方法的一个关键区别是DARLS的联合学习特性,它在所有本地客户端之间协调结构学习。数值比较表明,这比简单地通过对本地客户端单独学习得到的图进行投票或平均得到的共识要准确得多。由于缺乏协调,本地估计可能代表一组具有显著不同图形结构的局部最优结构。由这样一组图构建的共识可能与真实结构不接近。这个问题通常在图较大时更为严重。
为了检查在分布式数据上估计网络结构的准确性损失,我们计算了每个方法的预言机解,假设完全访问所有个本地数据集(汇总数据)。为了简洁地报告结果,我们选择了每个网络中表现最好的方法,称为最佳竞争方法,与DARLS进行比较。图1显示了DARLS和最佳竞争方法在分布式和汇总数据上的SHD性能。

首先,我们观察到,无论是应用于分布式还是汇总数据,DARLS都实现了类似的SHD值。DARLS在分布式和汇总数据上的SHD差异远小于最佳竞争方法,这表明DARLS在利用分布式数据方面更为有效。其次,与表I中的结果一致,DARLS始终显著优于应用于分布式数据的最佳竞争方法(最佳-分布式)。此外,DARLS在分布式数据上的表现与最佳竞争方法在汇总数据上的表现(最佳-预言机)有相当大的重叠,后者是HC或FGES对于和HC、MMHC或FGES对于。这种重叠表明我们分布式学习算法的高度竞争力。DARLS的SHD变异性通常小于最佳竞争方法,显示出在不同数据集之间的更高一致性。
为了量化使用算法2的分布式优化的准确性,我们为固定的计算了排列分数(7)在不同的值下。对于每个值,我们固定了调优参数、排列和所有内部计算参数,以确保只有在计算时变化。设是使用个本地客户端计算的的值,且是损失的相对增加。图2(a)显示了所有网络的的值在的顺序,验证了使用整体()或分布式数据()计算的本质上是相同的。由于迭代次数是固定的,随着网络大小的增加而增加。

我们还测试了在20个Insurance数据集上使用计算的计算时间。在每次测试中,相同的数据被分割并分发到不同数量的本地客户端,以解决优化问题(7),所有其他参数固定。图2(b)显示了计算时间与的关系。正如预期的那样,当分配一个复杂的任务时,如果使用更多的客户端,计算需要的时间更少。然而,随着的增加,减少的计算时间达到大约一个稳定水平之后,表明了从并行计算中获得的收益与通信开销之间的权衡。此外,最小的本地样本大小大时减小,定理1表明这会减慢算法2的收敛速度。联合学习主要旨在使多个本地客户端之间的协作学习成为可能,而不是通过数据分配来加速过程。然而,图2(b)表明,数据分布实际上可以减少DARLS的计算时间。具体来说,我们的发现表明,获得汇总数据上的预言机估计的运行时间(即)大约是时的四倍。

D. 其他数据生成模型

为了测试DARLS对其模型假设违规的鲁棒性,我们还从不同的DAG模型生成数据。特别地,我们使用阈值高斯和多项式DAG来生成离散数据,然后比较DARLS与其他方法。
对于阈值高斯DAG模型,我们首先使用高斯结构方程模型生成连续变量,其中是从中的高斯噪声,每个系数参数在中从中均匀采样,其中。然后,我们阈值化这些连续值以生成二元数据,其中的样本均值。我们使用两个组分混合高斯来模拟中的连续数据,每个组分从中以相等的概率抽取。在这种设计下,大多数的分布在分布上是双峰的,这大大增加了阈值化的鲁棒性。这种连续和分类变量之间的转换已在先前的工作[65]、[66]、[67]中用于贝叶斯网络的建模,但阈值高斯DAG模型不是我们比较的六种方法中的任何一种的潜在模型。
我们还从贝叶斯网络存储库[52]中提供的列联表模拟多项式数据。我们对列联表进行了一些修改,以确保(1)每个变量的状态数最多为三个,且(2)每个变量的边缘概率至少为0.1,通过合并状态来实现。由于Hailfinder和Hepar2的高结构复杂性,一些节点的原始列联表在没有修改的情况下使用,导致一些状态的边缘概率小于0.1。我们注意到,多项式DAG模型是大多数竞争方法的潜在模型,包括HC、PC、FGES和MMHC。因此,在这些数据上的比较还将测试GLDAG(4)能否很好地近似通常使用的多项式DAG模型。
表II报告了每种方法在由阈值高斯和多项式DAG模型生成的数据上的结构估计准确性,其中。包括TP、FP、R、M和SHD的标准差的更详细指标,见补充材料中的表S2。当底层模型是阈值高斯DAG时,DARLS在4个网络中实现了最低的SHD,即Asia、Sachs、Child和Insurance。对于由多项式模型模拟的数据,DARLS仍然可以在Sachs和Child中估计网络,实现最低的SHD。在大多数其他情况下,DARLS仅比HC差,但比其它方法更好或至少相当。值得注意的是,我们特别为HC调整了组合本地估计并构建全局图的过程,由于其本地估计中缺乏共识(备注2)。看到DARLS在所有情况下都优于FGES,一种在多项式DAG模型下的一致分数方法[56],这突出了DARLS在分布式数据中使用的有效性。这也表明GLDAG(4)可以很好地近似通常用于离散数据的多项式DAG模型。这项研究证实了DARLS在不同DAG模型生成的数据上确实表现相对较好,这对于其实际应用很重要。这在下一节的现实世界数据应用中得到了进一步证明。

备注3:DAG-GNN的源代码没有展示它如何处理具有三个或更多类别的分类数据,因此我们无法直接将其应用于多项式数据集。为了简洁,我们在补充材料中提供了它在阈值高斯情况下的结果,该情况是二元的。DAG-GNN的SHD值接近或略高于MMHC。

VI. 真实数据应用

在本节中,我们将我们的方法应用于由[68]生成的ChIP-Seq数据。数据集包含小鼠胚胎干细胞中12个转录因子(TFs)的DNA结合位点:Smad1、Stat3、Sox2、Pou5f1、Nanog、Esrrb、Tcfcp2l1、Klf4、Zfx、E2f1、Myc和Mycn。对于每个TF,计算了相应的ChIP-Seq信号强度的加权和作为每个基因的关联强度得分[69]。粗略地说,这个得分可以被理解为TF与基因结合的强度度量。按照[70]中的相同预处理,从我们的分析中移除了具有零关联得分的基因。因此,我们观察到的数据矩阵大小为,包含了12个TF在8462个基因上的关联得分。我们的目标是构建一个因果网络,揭示这些12个TF如何影响彼此对基因的结合。TF的关联得分通常是双峰的,它们在网络估计之前被离散化;见补充材料中的图S1,以示例离散化。
我们将这个数据集分布在个本地客户端上,因此每个本地客户端包含大约400个样本。然后我们应用DARLS、HC、MMHC、PC和FGES到分布式数据上,以学习蛋白质-DNA结合网络。NOTEARS和DAG-GNN被排除在这次比较之外,因为它们在模拟研究中的性能并不具有竞争力。本地估计的每个竞争方法被组合以构建全局图,就像我们在第V节中所做的那样。为了便于比较,我们控制了估计网络的稀疏度,使得每种方法在分布式数据上产生了两个图,大约有和29条边,除了FGES难以生成接近17条边的输出。我们还应用了每种方法到汇总数据上(即,),使用了第V节中的相同参数。在这种情况下,每个估计的图有大约条边。唯一的例外是PC,即使将其显著性水平降低到,它的估计也只有21条边。每种方法的关键参数见补充材料第S5节。
由于真实的网络结构未知,所以我们使用多项式DAG模型下的十折交叉验证中的测试数据似然来评估估计网络的准确性。设分别为在十折交叉验证中使用训练和测试数据集在多项式DAG模型下的似然值(见补充材料第S5节关于计算)。我们还计算了BIC = 用于模型比较,其中是训练样本大小,是估计图的多项式参数数量。我们选择一些基准来方便比较。设分别为最高的测试数据似然值和最低的BIC值。定义对数似然差异和BIC差异。注意,由于是测试数据对数似然,而BIC是使用训练数据计算的,所以的量级远大于。我们还计算了作为归一化边际似然比率(NLR)的近似值(),其中表示训练数据,是由BIC确定的基准模型。
表III总结了在三种比较设置下每种方法的和NLR,分别是稀疏、中等和预言机。第一种和第二种设置报告了使用分布式数据(客户端)估计的具有不同稀疏度的图的结果,最后一种显示了在汇总数据上(即)的相应预言机结果。在稀疏和中等设置中,DARLS实现了最高的测试数据似然值和最小的BIC,显著优于所有其他方法,这再次证明了DARLS在分布式数据中DAG学习方面的有效性。预言机方法,除了PC,具有可比的测试数据似然值,所有这些都高于它们在分布式数据上的相应结果。比较每种方法在汇总和分布式数据上的似然值,我们看到DARLS显示出最小的差异。换句话说,在所有方法中,DARLS在应用于分布式数据时与其在汇总数据上的预言机结果相比损失最小。

值得注意的是,每种方法的似然值是在多项式DAG模型下计算的,而不是在GLDAG模型下。因此,DARLS在现实世界数据上的优越性能表明,我们提出的GLDAG模型是潜在数据生成机制的良好近似。
为了获得更多的科学见解,我们在图3中展示了由DARLS从分布在个本地客户端上的完整数据集()中学习到的更稀疏的DAG()及其转换的CPDAG。一个有趣的观察是估计的CPDAG中Nanog→Pou5f1→Sox2的有向路径,这是小鼠胚胎干细胞中基因调控网络的三个核心调节因子[68]、[71]。众所周知,许多基因是由Pou5f1、Sox2和Nanog共同调控的。估计的路径表明Nanog结合会引起Pou5f1结合,然后可能导致Sox2结合。这为这三个TF如何协同调控下游基因提供了新的线索。数据分析[68],即生成ChIP-Seq数据的原始工作,表明有两个TF群倾向于共同结合:第一组包括Nanog、Sox2、Oct4、Smad1和STAT3,而第二组包括Mycn、Myc、Zfx和E2f1。这两个组在估计的CPDAG中被清楚地恢复,其中第二组TF有一个密集的无向子图,第一组有一个完全有向的子图。此外,有向边Myc→Pou5f1表明第二组可能在因果上游的第一组,这是一个潜在的实验调查新假设。

VII. 讨论

在本文中,我们开发了DARLS算法,它结合了模拟退火中的分布式优化方法,用于从分布式数据中学习因果图。基于模拟研究和现实世界数据应用,我们已经展示了DARLS即使在其模型假设被违反时也具有高度竞争力。在给定排序的分布式优化中,DARLS通过优化一个凸惩罚似然来学习因果图。在实践中,可以考虑使用[30]、[46]中的凹面惩罚来提高学习DAG的准确性,尽管可能缺乏分布式学习中凹面惩罚收敛性的理论保证。这肯定是我们方法未来发展的一个有希望的方向。
我们提出的GLDAG模型包括一系列灵活的分布,除了具有等方差的线性高斯模型(与多对数模型)之外,因此可以应用于不同类型的数据。也可以将GLDAGs(4)的框架推广到模拟变量之间的非线性因果关系。考虑标量变量以简化。对于每个边缘,我们关联一个非线性函数。然后(4)中的被替换为,导致[Xj | PAj]的广义加性模型。这种推广预计将以更高的准确性近似真实的因果关系。我们已经建立了连续GLDAGs的可识别性,证明了它们在因果发现中的使用,并且将一般GLDAGs的可识别性研究留作未来的工作。
本文的主要重点是大型分布式数据,其中大但适中。在这种设置下,我们建立了分布式优化得到的解收敛到全局最小化器(即预言机解)和全局最小化器作为真实DAG参数估计的一致性。然而,将收敛和一致性结果推广到允许发散的在理论上是有趣的,并且留作未来的工作。

声明

本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。

#论  文  推  广#

 让你的论文工作被更多人看到 


你是否有这样的苦恼:自己辛苦的论文工作,几乎没有任何的引用。为什么会这样?主要是自己的工作没有被更多的人了解。


计算机书童为各位推广自己的论文搭建一个平台,让更多的人了解自己的工作,同时促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 计算机书童 鼓励高校实验室或个人,在我们的平台上分享自己论文的介绍、解读等。


稿件基本要求:

• 文章确系个人论文的解读,未曾在公众号平台标记原创发表, 

• 稿件建议以 markdown 格式撰写,文中配图要求图片清晰,无版权问题


投稿通道:

• 添加小编微信协商投稿事宜,备注:姓名-投稿

△长按添加 PaperEveryday 小编



PaperEveryday
为大家分享计算机和机器人领域顶级期刊
 最新文章