WhiteDLG's

自蒸馏:从自我提升到辅助训练

2026-3-4 Deep Learning / Model Compression 188 Views

1. Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation

关注的问题

为了让神经网络(如CNN)在自动驾驶、医疗诊断等对准确性要求极高的领域发挥作用,科学家通常会增加网络的层数(更深)或每层的神经元数量(更宽)。

这种做法的副作用是:模型会变得巨大,需要更强的计算能力和更多的存储空间,导致反应变慢,不适合在手机、嵌入式设备等资源有限的平台上运行。

设计方法

作者将一个完整的神经网络从前往后切成多个部分(分段)。然后,让网络后面(更深) 的部分(它通常能学到更抽象、更精准的特征)作为“老师”,去指导前面(更浅) 的部分(作为“学生”)进行学习。这样,浅层部分就能提前学到深层部分才能提取到的知识,从而提升整个网络的性能。

框架

自蒸馏框架图 (原图 image.png)
图1:配备了自蒸馏的ResNet细节图 (示意图)

框架图详解

(i) 将ResNet按深度分为四个部分:图中将一个完整的ResNet(如ResNet50)根据其残差块划分成了4个连续的段:Section 1, Section 2, Section 3, Section 4。Input从左侧进入,依次经过这些部分,最终从最顶部的Classifier 4/4输出。

(ii) 在每个部分后设置瓶颈层和全连接层,构成多个分类器:在每个分段之后,都额外添加了Bottleneck -> FC,构成新的分类器。因此整个网络有了4个分类器:Classifier 1/4, 2/4, 3/4, 4/4。Classifier 4/4是“老师”,其他是“学生”。

(iii) 所有分类器都可以独立使用:具有不同的精度和响应时间,为可扩展推理提供灵活性。

(iv) 每个分类器都在三种监督下进行训练

  • 绿色箭头(标签监督):每个分类器的输出和真实标签计算交叉熵损失。
  • 蓝色箭头(蒸馏监督):从最深“老师”Classifier 4/4的Softmax输出出发,指向浅层学生,计算KL散度损失。
  • 灰色箭头(提示监督):从最深老师特征图出发,指向学生瓶颈层,计算L2损失。

(v) 虚线以下的部分可以在推理时移除:蒸馏结束后,辅助分类器可被丢弃,不增加推理开销。

其他

  • 大规模的模型的部署成本往往较高,因此模型压缩技术成为研究热点。
  • 张林峰: 大模型的训练难度和成本更高,高校难以满足训练的要求。因此,我近几年对知识蒸馏的研究反而变少了。
  • 我们现在蒸馏大模型的方式还比较传统,知识蒸馏领域可以探索的内容还有很多。例如,通过蒸馏提升通用大模型在特定应用领域的专用能力。此外,知识蒸馏需同时加载教师模型和学生模型进行大量训练计算,蒸馏成本较高,我们可以研究如何降低蒸馏的成本。
  • 现在,我从事的工作包括数据集压缩和模型压缩,这两者在某种程度上具有统一性。我们从数据的视角出发,通过减少数据量来提高模型的推理和学习效率。具体来说,以往的模型压缩工作主要关注减小模型的参数量,而我针对数据集压缩的工作则是关注在固定模型参数量的前提下,减少模型处理的数据量,降低token的规模。这样,我们可以通过减少数据量来加快推理速度、提高学习效率。从这个角度来看,数据集压缩和我目前从事的模型压缩的最终目的是相同的,是降低AI模型计算成本的两个互补的角度。
        —— 《Dataset Distillation with Neural Characteristic Function: A Minmax Perspective》

瓶颈层和全连接层

在论文图中,它们组合在一起构成了一个辅助分类器

1. 全连接层

通俗解释:“分类头”。卷积神经网络的前面部分负责“看”和提取特征,全连接层负责“想”,输出概率分布。

技术作用:维度变换,分类决策。

2. 瓶颈层

通俗解释:“信息压缩器”或“适配器”。将不同尺寸的特征图统一压缩,以便与深层老师对比。

技术作用:降维、统一尺寸、减少干扰。

总结:两者的关系

Bottleneck在前,FC在后:Section输出 → Bottleneck(压缩适配) → FC(分类)。瓶颈层负责特征对齐(损失源3),全连接层负责输出对齐(损失源2和1)。

总结:把神经网络切分成多个部分,最后一部分作为老师教前面的,然后交叉熵损失+KL散度+L2损失。

2. Auxiliary Training: Towards Accurate and Robust Models

关注的问题

模型精度和鲁棒性之间存在一种令人尴尬的权衡——一方的提升会导致另一方的下降。如何同时提高精度和鲁棒性仍然是一个挑战。

架构

辅助训练框架图 (原图 image-1.png)
图2:辅助训练框架(Figure 1)示意图

一、自蒸馏设计方法解释

论文中的“输入感知自蒸馏”是一种更平等的“学生-学生”相互学习框架

1. 设计动机

传统知识蒸馏需要一个预训练的“教师”。但在鲁棒性问题上,很难找到万能教师。因此,作者将任务分解:主分类器负责精度(干净图像),辅助分类器专门处理特定损坏类型(学习鲁棒特征)。

2. “自蒸馏”的实现机制

更接近“协同正则化”:通过损失函数 \(\Omega\) 实现:

\[ \Omega (\theta_g^0,\theta_g^j) = \ell_{KL}(g(\hat{x}^0;\theta_g^0),g(\hat{x}^j;\theta_g^j)) + \gamma \| \theta_1 - \theta_2\| _2^2 \]
  • KL散度项:强制主分类器对干净样本的输出与辅助分类器对损坏样本的输出在概率分布上一致。
  • L2权重约束:训练后期强制主、辅分类器参数趋同,内化知识。

3. “输入感知”的含义

体现在选择性批归一化(Selective BN):根据输入图像的类型(干净或特定损坏),分别计算各自的均值和方差,避免统计量污染。

4. 总结:自蒸馏的本质

通过辅助分类器作为“锚点”,引导主分类器在特征空间中找到一个既靠近干净数据决策边界,又对输入扰动不敏感的位置。

二、框架图(Figure 1)详细解释

阶段 (a): 训练阶段 1 (Training Period 1) —— 并行学习

  • 输入图像:干净图像 + 损坏图像(模糊、噪声等)。
  • 共享卷积层 + 选择性BN:根据输入类型使用不同的统计数据归一化。
  • 特征分流与分类:主分类器接收干净图像特征;辅助分类器接收损坏图像特征(带注意力模块、瓶颈层)。
  • 关键机制:输入感知自蒸馏(KL散度对齐)。

阶段 (b): 训练阶段 2 (Training Period 2) —— 权重合并:最后阶段通过L2权重约束强制主、辅分类器权重相同。

阶段 (c): 测试阶段 (Testing Period) —— 无损推理:移除所有辅助分类器,仅用主分类器推理,无额外开销。

关于三个分类器

1. 主分类器 —— 负责"精度":只接收干净图像特征,结构简单,最终部署使用。

2. 噪声图像的辅助分类器 —— 负责针对"噪声"的鲁棒性:只接收噪声损坏图像特征,含注意力模块和瓶颈层。

3. 模糊图像的辅助分类器 —— 负责针对"模糊"的鲁棒性:只接收模糊损坏图像特征,结构同上。

它们的关系:分工与整合

  • 共享基础卷积层。
  • 知识传递通过KL散度(自蒸馏)。
  • 最终权重合并使主分类器内化专家知识。
总结:“学生-学生”相互学习框架。【做的人很多了】

其他

SBN:

损失/模糊图像用损失/模糊的批归一化均值与方差。

BN干嘛的?

核心作用:解决“中间层分布漂移”(内部协变量偏移)。BN强制把每一层的输入拉回到稳定分布。

润滑作用:允许更大的学习率,加速训练,使网络对学习率不敏感。

正则作用:提供隐形的抗过拟合,每个mini-batch的噪声带来泛化增益。

总结:没有BN会怎样?调参难、收敛慢、容易梯度爆炸/消失。因此选择性BN在论文中至关重要。

  Tags: Self Distillation, Model Compression, 张林峰, Auxiliary Training
Comments Section (Giscus Loading...)