解耦知识蒸馏

文摘   科技   2024-09-03 11:00   广东  

作者:刘博

图片来源于网络

研究背景

随着深度学习的快速发展,知识蒸馏(Knowledge Distillation, KD)已成为提升轻量级模型性能的重要技术。现有KD方法主要集中于从教师模型的中间层提取深度特征,然而,基于logit蒸馏的潜力却常常被忽视。基于logit蒸馏通过教师模型输出的概率分布引导学生模型学习,这一过程在许多任务中被证明是有效的,因此如何有效利用logit信息以提升模型性能变得尤为重要。本文研究的核心问题是:如何通过解耦知识蒸馏(Decoupled Knowledge Distillation, DKD),充分利用logit信息,提升学生模型的表现和训练效率。

研究方法

本文提出了一种新的知识蒸馏方法,通过将传统的KD损失函数重构为如下两部分来实现。

1. 目标类知识蒸馏(Target Classification Knowledge Distillation,TCKD):

  • 该部分专注于目标类的二分类预测,旨在传递训练样本的“难度”信息。这意味着,学生模型可以更好地识别哪些样本更具挑战性,从而在训练过程中优先关注这些困难样本。

  • 在处理难度较大的数据时,TCKD可以帮助模型集中学习复杂的特征,从而提高整体准确性。

2. 非目标类知识蒸馏(Non-target Classification Knowledge Distillation, NCKD):

  • 这部分关注于非目标类的多类预测,其对基于logit蒸馏的有效性至关重要。通过学习非目标类的知识,学生模型能够更全面地理解类间关系,提升分类能力。

  • NCKD帮助模型捕捉更丰富的上下文信息,增强模型的泛化能力。

在DKD中,研究者引入了超参数,允许对TCKD和NCKD的权重进行动态调整。这种灵活性不仅提升了模型的学习效率,还增强了知识转移的适应性,使得模型能够根据不同任务的需求进行优化。传统知识蒸馏(KD)与解耦知识蒸馏(DKD)的区别如图1所示。

图1  传统知识蒸馏(KD)与解耦知识蒸馏(DKD)

图2展示了目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)在CIFAR-100上的性能表现。实验显示,单独使用TCKD对学生模型的帮助有限,甚至可能导致性能下降,而NCKD的蒸馏效果与经典KD相当,甚至更优秀。这表明,与非目标类之间的知识相比,目标类相关知识的重要性可能较低。

图2  TCKD和NCKD对知识蒸馏性能的影响

实验结果

研究者在多个标准数据集上进行了系统的实验,并在不同维度分析了算法的表现。

 1. 性能提升

DKD方法在多个数据集上的表现均优于传统的特征蒸馏方法,尤其是在CIFAR-100和ImageNet上,准确率提升幅度可达几个百分点。实验结果如图3所示。

图3  在CIFAR-100数据集上实验验证

2. 训练效率

DKD在模型性能与训练成本(例如训练时间和额外参数)之间实现了最佳权衡。由于DKD是从经典KD重新公式化而来,因此其计算复杂度几乎与KD相同,并且不需要额外的参数。然而,基于特征的蒸馏方法需要额外的训练时间来蒸馏中间层特征,同时还会增加GPU内存成本。结果如图4所示。

图4  训练时间vs模型精度

研究结论

本文通过解耦知识蒸馏中的TCKD和NCKD部分,显著提升了知识蒸馏的整体效果,展示了基于logit蒸馏的重要性和潜在优势。DKD方法为未来的知识蒸馏研究提供了新的思路和方法,期待在更广泛的应用场景中发挥作用。这项研究不仅对知识蒸馏的理论发展具有重要意义,也为模型在资源受限的环境中的轻量化部署应用提供了强有力的支持。

相关论文

[1] Zhao B, Cui Q, Song R, et al. Decoupled knowledge distillation[C]//Proceedings of the IEEE/CVF Conference on computer vision and pattern recognition. 2022: 11953-11962.

[2] Hinton G. Distilling the Knowledge in a Neural Network[J]. arXiv preprint arXiv:1503.02531, 2015.

写在最后

我们的文章可以转载了呢~欢迎转载转发

想了解更多前沿科技与资讯?

点击上方入口关注我们!

欢迎点击右上方分享到朋友圈

香港中文大学(深圳)

网络通信与经济实验室

微信号 : ncel_cuhk


网络通信与经济
介绍网络、通信和经济交叉领域的最新科研成果和活动 —香港中文大学(深圳)网络通信与经济学实验室
 最新文章