王亦洲课题组 CVPR 2021 入选论文解读:时间序列疾病预测的因果隐马尔可夫模型
本文是对发表于计算机视觉和模式识别领域的顶级会议 CVPR 2021的论文“Causal Hidden Markov Model for Time Series Disease Forecasting(时间序列疾病预测的因果隐马尔可夫模型)”的解读。
该论文由北京大学王亦洲课题组与深睿医疗等单位合作,针对时间序列疾病预测的问题,提出了因果隐马尔可夫模型描述疾病的动态发展过程,并使用基于 VAE 的变分框架进行学习。通过对图像隐空间进行解耦,去除疾病无关因子与疾病预测的伪相关关系,从而提高预测的准确率和鲁棒性。
项目主页:https://sites.google.com/view/causal-hmm
论文链接:https://arxiv.org/abs/2103.16391
01
研究背景
在医学诊断中,对不可逆型疾病(如视盘萎缩症)进行时间序列的疾病预测非常重要,对未来疾病发展的预测可以帮助患者进行提前干预,对于疾病的有效控制有很大的意义。
但是这类预测目前存在两个大问题。首先,目前的很多时间序列疾病预测的方法都是提取所有的时序图像特征来进行未来疾病的预测。然而图像中通常存在很多与该类疾病无关的信息或特征,当加入这些疾病无关的信息进行训练时,会引入伪相关关系,即它们本身与疾病无关,但训练时使用它们参与了疾病的预测,引入了统计上的相关关系。当模型使用这些伪相关的信息对新的分布下测试集的样本进行疾病预测时,很容易导致失败。其次,很多时间序列的预测方法建立在时序标签完整的情况下。然而由于医学标签标注成本高昂,实际情况下很多时候过去时间步的疾病标签是缺乏的,标签的缺乏也给未来阶段的疾病预测带来了很大的挑战。
02
方法介绍
图1. 因果隐马尔可夫模型
为了解决上述问题,本文对时间序列疾病预测建立了一个基于时序的因果隐马尔可夫模型(Causal-HMM),即针对现有的观测数据,包括每个时间步的图像数据 xt,临床测量数据 At(如角膜厚度、角膜曲率等),个体属性数据 Bt(如年龄、性别等)以及在未来阶段的疾病标签 YT,建立一个描述其相关关系的因果图(如图1),用该因果图去刻画每个时间步从隐空间到观测值的生成过程。其中隐空间中一部分是与疾病无关的因子 zt,一部分是疾病相关的因子 st 和 vt 。 代表与疾病相关的临床测量数据的隐变量因子, st 代表其他与疾病相关的参与图像生成的因子。而个人属性 Bt 会对所有的隐变量带来影响。本文对图像在隐空间内进行解耦,旨在通过分离疾病无关因子来去除训练中所带来的伪相关关系。在理论上作者通过可识别性定理(如图2所示)对监督场景下的时间序列数据给出了隐变量解耦的可识别性保证。
图2. 可识别性定理
为了学习本文所提出的因果隐马尔可夫模型,作者使用了基于 VAE 的变分框架去学习时间序列下的图像及临床属性的生成过程以及进行疾病标签的预测(如图3所示)。具体来说,在每个时间步下先验网络接受个人属性特征及上一个时间步的隐变量作为输入,得到当前时间步的隐变量先验;而后验网络的编码器接受当前时间步的图像及临床属性特征,及上一步的个人属性特征进行输入,相应地得到隐变量后验。同时每一步的解码器会对隐变量进行解码,完成对当前步的图像及临床属性的重构。隐变量的后验和先验通过 KL 距离进行约束。在最后一个时间步下,通过所提取出的疾病相关的因子进行未来时间步的疾病预测。
图3. 左:Causal-HMM的时间序列网络架构;右:每个时间步的先验网络,后验网络及生成网络构成
03
实验结果
本文收集了507个个体样本,每个样本包括一到五年级的视网膜图像数据以及相应年级的属性数据,以及六年级的视盘萎缩疾病标签数据。作者对507个样本进行了数据集的划分,其中训练集验证集测试集的数量分别为300,100,107。为了更好地验证本文方法的泛化性能,作者将训练集验证集和测试集按照性别划分为两个不同的分布,其中前两者数据集的性别分布为男女比2:3而后者测试集性别分布为3:1。作者对包括一到五年级的十个所有可能的时间序列设置下进行了实验,并与多个现有的疾病进展预测和时间序列预测的方法进行了对比。本文的方法在几乎所有的实验设置及平均情况下的 ACC 和 AUC 指标均高于已有方法(如表1所示),展示了该方法在解决时间序列疾病预测问题上的优越性。
表1. 与对比方法RGL, Devised RNN, LogSparse Transformer在一到五年级所有时间序列上的ACC与AUC结果对比
同时作者对本文的方法进行了消融实验(如表2所示),分别测试了他们所使用的时间序列结构(CNN vs Seq VAE, CNN+LSTM vs Seq VAE),属性信息(Seq VAE vs Seq VAE + Att)及隐空间解耦机制(Seq VAE + Att vs Ours)的有效性。
表2. 针对时间序列网络结构, 属性信息以及解耦机制的消融实验在一到五年级所有时间序列上的ACC与AUC结果对比
此外作者设计了一个第二阶段的疾病分类器证明解耦出来的隐变量的鲁棒性,将已经训练好的 Causal-HMM 模型的疾病相关因子和无关的因子分别取出,在训练集验证集及新分布下的测试集上进行预测。疾病无关的因子在新分布下的预测准确率有很大下降,而解耦出的疾病相关因子在不同分布下的有着稳定和鲁棒的表现(如表3所示)。
表3. 在一到五年级所有时间序列上使用不同隐变量(s+v vs z)对疾病进行预测,在训练集,验证集及测试集上的ACC与AUC结果对比
作者对 Causal-HMM 模型所学习到的隐空间因子 s 和 z 通过 Grad-CAM 进行了可视化(如图4所示),结果表明疾病相关的因子 s 在视盘周围显示了高响应,而疾病无关的因子 z 的高响应处更多地散布在视网膜图像的其他区域如黄斑区等。本文的方法通过将 z 解耦出来可以去除其和疾病的伪相关关系,从而在不同分布下的疾病预测上有更为鲁棒的表现。
图4. 不同隐变量(s vs z)的特征图可视化
04
结 语
本文针对时间序列疾病预测问题提出了一个因果隐马尔可夫模型进行未来阶段的疾病预测。为了保证模型的泛化性能,作者对每个时间步下的隐空间进行了显式的解耦和分离,并通过可识别性的结果对该解耦机制给出了理论保证。针对因果隐马尔可夫模型,作者提出了一个新的时间序列变分框架进行该模型的学习和推断。实验上,作者将其方法应用在了视盘萎缩疾病的时序预测问题中,并和当前现有的最优方法进行了对比,在新的测试集分布下取得了更优越的性能,展示了该方法的有效性和鲁棒性。
参考文献
[1] Judea Pearl. Causality. Cambridge university press, 2009.
[2] Maxime Louis, Raphael Couronne, Igor Koval, Benjamin Charlier, and Stanley Durrleman. Riemannian geometry learning for disease progression modelling. In IPMI 2019.
[3] Ilyes Khemakhem, Diederik P Kingma, and Aapo Hyvärinen. Variational autoencoders and nonlinear ICA: A unifying framework. In AISTATS 2020.
[4] Xinwei Sun, Botong Wu, Chang Liu, Xiangyu Zheng, Wei Chen, Tao Qin, and Tie-yan Liu. Latent causal invariant model. arXiv preprint arXiv:2011.02203, 2020.
IEEE Conference on Computer Vision and Pattern Recognition(IEEE CVPR)是计算机视觉领域国际顶级会议(CCF A类),每年举办一次。CVPR 2021将于2021年6月19-25日在线上举行。