
Residual SDE-Net for uncertainty estimates of deep neural networks

WANG Yongguang, YAO Shuzhen, TAN Huobin

Citation: WANG Y G, YAO S Z, TAN H B. Residual SDE-Net for uncertainty estimates of deep neural networks[J]. Journal of Beijing University of Aeronautics and Astronautics, 2023, 49(8): 1991-2000 (in Chinese). doi: 10.13700/j.bh.1001-5965.2021.0604


doi: 10.13700/j.bh.1001-5965.2021.0604
Funds: National Key R&D Program of China (2018YFB1402600)
Corresponding author. E-mail: thbin@buaa.edu.cn
  • CLC number: TP183

  • Abstract:

    The neural stochastic differential equation model (SDE-Net) can quantify the epistemic uncertainty of deep neural networks (DNNs) from a dynamical-systems perspective. SDE-Net, however, faces two problems: first, on large-scale datasets its performance degrades as the number of network layers increases; second, it performs poorly on the aleatoric uncertainty caused by in-distribution data with noise or high missing rates. To address this, a residual SDE-Net (ResSDE-Net) is designed, which adopts the residual blocks of improved residual networks (ResNets) within SDE-Net to obtain consistent stability and higher performance. For in-distribution data with noise or high missing rates, convolutional conditional neural processes (ConvCNPs), which are translation equivariant, are introduced to repair the data and thereby improve ResSDE-Net's performance on such inputs. Experimental results show that ResSDE-Net achieves consistently stable performance on both in-distribution and out-of-distribution data, and still attains average accuracies of 89.89%, 65.22%, and 93.02% on MNIST, CIFAR10, and the real-world SVHN dataset, respectively, even when 70% of the pixels are missing.
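    As a rough illustration of the dynamics described above, the following is a minimal PyTorch sketch, not the authors' code, of an SDE-Net-style forward pass whose drift network is built from a basic residual block and integrated with the Euler-Maruyama scheme. The class and parameter names are invented, and the time dependence of the drift and the exact form of the diffusion network are simplified.

```python
import torch
import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions plus an identity skip."""
    def __init__(self, channels: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x + self.body(x))

class ResSDEBlock(nn.Module):
    """Euler-Maruyama integration of dx = f(x) dt + g(x) dW, where the
    drift f is a residual block and g is a learned diffusion term
    (a simplified sketch of the SDE-Net-style layer)."""
    def __init__(self, channels: int, n_steps: int = 4, horizon: float = 1.0):
        super().__init__()
        self.f = BasicResidualBlock(channels)      # drift network
        self.g = nn.Conv2d(channels, channels, 1)  # diffusion network
        self.n_steps = n_steps
        self.dt = horizon / n_steps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _ in range(self.n_steps):
            dw = torch.randn_like(x) * self.dt ** 0.5  # Brownian increment
            x = x + self.f(x) * self.dt + self.g(x) * dw
        return x
```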

     

  • Figure 1.  DNNs constructed with plain networks, basic residual blocks, or bottleneck residual blocks
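    For reference, a hedged sketch of the bottleneck variant shown in the figure, following the 1×1 / 3×3 / 3×3 / 1×1 channel pattern listed in Table 1 below. The post-activation ordering is chosen for brevity; the improved pre-activation design of He et al. would move the ReLUs and add batch normalization.

```python
import torch
import torch.nn as nn

class BottleneckResidualBlock(nn.Module):
    """Bottleneck residual block matching Table 1's pattern, e.g.
    1x1 (64->128), 3x3 (128->128), 3x3 (128->128), 1x1 (128->64),
    with an identity skip connection around the whole stack."""
    def __init__(self, channels: int = 64, mid_channels: int = 128):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, mid_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x + self.body(x))
```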

    Figure 2.  Framework of the uncertainty estimation method of the proposed ResSDE-Net for the training and testing phases
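    The testing half of this framework relies on the stochasticity of the SDE block: repeated forward passes over the same input differ through the Brownian term, and their spread serves as an uncertainty estimate. A minimal sketch under that assumption follows; predict_with_uncertainty and n_samples are illustrative names, not the paper's API, and the model is assumed to return class logits.

```python
import torch

@torch.no_grad()
def predict_with_uncertainty(model, x: torch.Tensor, n_samples: int = 10):
    """Run several stochastic forward passes; the mean is the predictive
    distribution and the per-class variance is an uncertainty estimate."""
    probs = torch.stack([model(x).softmax(dim=-1) for _ in range(n_samples)])
    return probs.mean(dim=0), probs.var(dim=0)
```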

    Figure 3.  Training and testing accuracy of 8/16/20/24-layer SDE-Net models built with basic residual blocks, bottleneck residual blocks, and plain building blocks

    Figure 4.  Completed MNIST images based on cResSDE-Net

    Figure 5.  Completed images from the real-world SVHN dataset based on ResSDE-Net

    Table 1.  24-layer neural network architectures constructed with basic residual blocks, bottleneck residual blocks, and plain blocks

    | Layer | 24-layer basic residual blocks | 24-layer bottleneck residual blocks | 24-layer plain blocks |
    | --- | --- | --- | --- |
    | Conv1 (downsampling) | 3×3, 64, stride 1 | 4×4, 64, stride 2 | 4×4, 64, stride 2 |
    | ConcatConv2_x | [3×3, (64, 64); 3×3, (64, 64)] × 2 | [1×1, (64, 64); 3×3, (64, 64); 3×3, (64, 64); 1×1, (64, 64)] | [3×3, (64, 64)] × 4 |
    | ConcatConv3_x | [3×3, (64, 128); 3×3, (128, 128)] + [3×3, (128, 64); 3×3, (64, 64)] | [1×1, (64, 128); 3×3, (128, 128); 3×3, (128, 128); 1×1, (128, 64)] | [3×3, (64, 128); 3×3, (128, 128); 3×3, (128, 64); 3×3, (64, 64)] |
    | ConcatConv4_x | [3×3, (64, 256); 3×3, (256, 256)] + [3×3, (256, 64); 3×3, (64, 64)] | [1×1, (64, 256); 3×3, (256, 256); 3×3, (256, 256); 1×1, (256, 64)] | [3×3, (64, 256); 3×3, (256, 256); 3×3, (256, 64); 3×3, (64, 64)] |
    | ConcatConv5_x | [3×3, (64, 512); 3×3, (512, 512)] + [3×3, (512, 64); 3×3, (64, 64)] | [1×1, (64, 512); 3×3, (512, 512); 3×3, (512, 512); 1×1, (512, 64)] | [3×3, (64, 512); 3×3, (512, 512); 3×3, (512, 64); 3×3, (64, 64)] |
    | ConcatConv6_x | [3×3, (64, 1024); 3×3, (1024, 1024)] + [3×3, (1024, 64); 3×3, (64, 64)] | [1×1, (64, 1024); 3×3, (1024, 1024); 3×3, (1024, 1024); 1×1, (1024, 64)] | [3×3, (64, 1024); 3×3, (1024, 1024); 3×3, (1024, 64); 3×3, (64, 64)] |
    | Output layer | Average pooling, 10-class fully connected layer, softmax (shared by all three variants) | | |
    Note: each entry reads "kernel size, (input channels, output channels)"; "× n" means the bracketed block is stacked n times, and blocks joined by "+" are stacked in sequence.
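    To make the table's notation concrete, here is a sketch of the ConcatConv3_x plain-block column as a plain PyTorch stack; the interleaved ReLUs are an assumption, the variable name is invented, and the "Concat" prefix presumably refers to concatenating the SDE time variable as an extra input channel, which this sketch omits.

```python
import torch.nn as nn

# ConcatConv3_x, plain-block column of Table 1:
# [3x3, (64, 128); 3x3, (128, 128); 3x3, (128, 64); 3x3, (64, 64)]
concat_conv3_plain = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
)
```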

    Table 2.  Classification and OOD detection on the MNIST and SVHN datasets

    | Model | Cls. acc. ① | Cls. acc. ② | TNR@TPR95 ① | TNR@TPR95 ② | AUROC ① | AUROC ② | Det. acc. ① | Det. acc. ② | AUPR-In ① | AUPR-In ② | AUPR-Out ① | AUPR-Out ② |
    | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
    | Threshold | 99.5±0.0 | 95.2±0.1 | 90.1±2.3 | 66.1±1.9 | 96.8±0.9 | 94.4±0.4 | 92.9±1.1 | 89.8±0.5 | 90.0±3.5 | 96.7±0.2 | 98.7±0.3 | 84.6±0.8 |
    | DeepEnsemble | 99.6 | 95.4 | 92.7 | 66.5 | 98.0 | 94.6 | 94.1 | 90.1 | 94.5 | 97.8 | 99.1 | 84.8 |
    | MC-dropout | 99.5±0.0 | 95.2±0.1 | 88.7±0.6 | 66.9±0.6 | 95.9±0.4 | 94.3±0.1 | 92.0±0.3 | 89.8±0.2 | 87.6±2.0 | 96.7±0.1 | 98.4±0.1 | 84.8±0.2 |
    | PNs | 99.3±0.1 | 95.0±0.1 | 90.4±2.8 | 66.9±2.0 | 94.1±2.2 | 89.9±0.6 | 93.0±1.4 | 87.4±0.6 | 73.2±7.3 | 92.5±0.6 | 98.0±0.6 | 82.3±0.9 |
    | BBP | 99.2±0.3 | 93.3±0.6 | 80.5±3.2 | 42.2±1.2 | 96.0±1.1 | 90.4±0.3 | 91.9±0.9 | 83.9±0.4 | 92.6±2.4 | 96.4±0.2 | 98.3±0.4 | 73.9±0.5 |
    | p-SGLD | 99.3±0.2 | 94.1±0.5 | 94.5±2.1 | 63.5±0.9 | 95.7±1.3 | 94.3±0.4 | 95.0±1.2 | 87.8±1.2 | 75.6±5.2 | 97.9±0.2 | 98.7±0.2 | 83.9±0.7 |
    | ResSDE-Net | 99.6±0.0 | 95.5±0.2 | 96.3±1.2 | 80.6±1.9 | 99.0±0.3 | 96.1±0.4 | 96.3±0.8 | 91.2±0.5 | 96.8±0.9 | 98.3±0.2 | 99.7±0.1 | 91.7±0.9 |
    Note: all values are percentages; boldface (in the original table) indicates the best average performance. Experiment ①: ID dataset MNIST, OOD dataset SVHN. Experiment ②: ID dataset SVHN, OOD dataset CIFAR10.
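    The OOD metrics in Table 2 are standard threshold-free detection scores. The sketch below shows one way to compute them from scalar uncertainty scores with scikit-learn; ood_metrics is an illustrative name, and the convention that a higher score means "more likely OOD" is an assumption, not necessarily the paper's evaluation code.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

def ood_metrics(score_id: np.ndarray, score_ood: np.ndarray):
    """Compute AUROC, AUPR-In/Out, TNR at TPR 95%, and detection accuracy
    from per-sample uncertainty scores (higher = more likely OOD)."""
    y = np.concatenate([np.zeros_like(score_id), np.ones_like(score_ood)])
    s = np.concatenate([score_id, score_ood])
    auroc = roc_auc_score(y, s)
    aupr_out = average_precision_score(y, s)      # OOD as the positive class
    aupr_in = average_precision_score(1 - y, -s)  # ID as the positive class
    fpr, tpr, _ = roc_curve(y, s)
    tnr_at_tpr95 = 1.0 - fpr[np.searchsorted(tpr, 0.95)]
    det_acc = np.max(0.5 * (tpr + (1.0 - fpr)))   # best balanced accuracy
    return auroc, aupr_in, aupr_out, tnr_at_tpr95, det_acc
```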

    Table 3.  Classification accuracy (%) on the ID datasets MNIST, CIFAR10, and SVHN with missing rate MR = {0.1, 0.3, 0.5, 0.7, 0.9}

    | MR | MNIST cResSDE-Net | MNIST ResSDE-Net | MNIST SDE-Net | CIFAR10 cResSDE-Net | CIFAR10 ResSDE-Net | CIFAR10 SDE-Net | SVHN cResSDE-Net | SVHN ResSDE-Net | SVHN SDE-Net |
    | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
    | 0.1 | 99.41±0.12 | 98.82±0.14 | 98.87±0.06 | 79.53±0.59 | 29.11±3.68 | 23.23±0.25 | 94.72±0.18 | 65.27±3.75 | 62.11±5.55 |
    | 0.3 | 99.04±0.17 | 92.93±0.54 | 94.98±0.19 | 77.60±0.46 | 14.80±2.23 | 13.11±2.89 | 94.68±0.13 | 36.21±3.18 | 31.22±4.80 |
    | 0.5 | 97.64±0.11 | 74.64±1.52 | 80.54±0.36 | 73.70±0.26 | 12.25±0.98 | 10.33±0.06 | 94.46±0.17 | 24.53±1.78 | 22.28±2.27 |
    | 0.7 | 89.89±0.20 | 44.02±2.25 | 49.25±0.11 | 65.22±0.39 | 10.92±0.59 | 10.10±0.08 | 93.02±0.14 | 17.80±0.60 | 19.06±0.47 |
    | 0.9 | 38.95±1.49 | 16.09±1.33 | 14.56±0.34 | 40.69±0.71 | 10.89±0.57 | 10.18±0.06 | 71.77±0.30 | 13.73±0.51 | 18.18±0.82 |
    Note: boldface (in the original table) indicates the best average performance.
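    For context, MR is the fraction of pixels removed from each in-distribution image before classification (and, in the cResSDE-Net rows, before ConvCNP completion). A sketch of that corruption step, assuming pixels are dropped uniformly at random; drop_pixels is an illustrative helper, not the paper's pipeline.

```python
import torch

def drop_pixels(images: torch.Tensor, mr: float):
    """Randomly zero out a fraction `mr` of pixels in (N, C, H, W) images.
    Returns the masked images and the observation mask that a
    ConvCNP-style completion model would condition on."""
    mask = (torch.rand(images.shape[0], 1, *images.shape[2:],
                       device=images.device) >= mr).float()
    return images * mask, mask

# Example: keep 30% of the pixels (MR = 0.7), as in Table 3's fourth row.
x = torch.rand(8, 1, 28, 28)
x_masked, mask = drop_pixels(x, mr=0.7)
```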
Publication history
  • Received: 2021-10-13
  • Accepted: 2022-01-14
  • Available online: 2022-01-29
  • Issue published: 2023-08-31
