知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍
2022/3/19 12:58:30
本文主要是介绍知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
如有错误,恳请指出。
这篇博客将记录我看视频后对知识蒸馏的笔记,视频链接见参考资料[1],其中包含知识蒸馏的算法原理(训练流程与测试流程),以及知识蒸馏能够正常工作的背后机理与其发展的趋势及展望。
在这篇博客中,主要都是介绍没有涉及具体的代码,我另外还将会记录一下知识蒸馏的测试过程,见另外的一篇博客。
文章目录
- 框架
- 1. 知识蒸馏的算法原理
- 1.1 知识的表示与迁移
- 1.2 训练流程
- 1.3 推理过程
- 1.4 KD与Labe Smoothing的区别
- 2. 知识蒸馏的应用场景
- 3. 知识蒸馏的背后机理
- 4. 知识蒸馏的发展趋势
框架
1)第一个方向是把一个已经训练好的臃肿的网络进行瘦身
权值量化:把模型的权重从原来的32个比特数变成用int8,8个比特数来表示,节省内存,加速运算
剪枝:去掉多余枝干,保留有用枝干。分为权重剪枝和通道剪枝,也叫结构化剪枝和非结构化剪枝,一根树杈一根树杈的剪叫非结构化剪枝,也可以整层整层的剪叫结构化剪枝。
2)第二个方向是在设计时就考虑哪些算子哪些设计是轻量化的
轻量化网络有很多需要考虑的内容:参数量、计算量
3)第三个方向是在数值运算的角度来加速各种算子的运算
比如im2col+GEMM,就是把卷积操作转成矩阵操作,矩阵操作是很多算法库里内置的功能,比如py,tf和matlab都有底层的加速到极致的矩阵运算的算子
4)第四个方向就是硬件部署
用英伟达的TensorRT库,把模型压缩成中间格式,部署在Jetson开发板上;Tensorflow-slim和Tensorflow-lite是tensorflow轻量化的生态;因特尔的openvino;FPGA集成电路也可以部署人工智能算法
1. 知识蒸馏的算法原理
1.1 知识的表示与迁移
把左边的马图像喂给分类模型,会有很多类别,每个类别识别出一个概率,训练网络时,我们只会告诉网络,这张图片是马,其余是驴是汽车的概率都是0,这个就是hard targets,用hard targets训练网络,但这就相当于告诉网络,这就是一匹马,不是驴不是车,而且不是驴不是车的概率是相等的,这是不科学的。若是把马的图片喂给已经训练好的网络里面,网络给出soft targets这个结果,是马的概率为0.7,为驴的概率为0.25,为车的概率是0.05,所以soft targets就传递了更多的信息
总结:Soft Label包含了更多“知识”和“信息,像谁,不像谁,有多像,有多不像,特别是非正确类别概率的相对大小(驴和车)
此外还引入蒸馏温度T,把原来比较硬的soft targets变的更软,更软的soft targets去训练学生网络,那些非正确类别概率的信息就暴露的越彻底,相对大小的知识就暴露出来,让学生网络去学
- T为1,就是原softmax函数,softmax本来就是把每个类别的logic强行变成0-1之间的概率,并且求和为1,是有放大差异的功能,如果logic高一点点,经过softmax,都会变的很高。
- T越小,非正确类别的概率相对大小的信息就会暴露的更明显;T越大,曲线就会变得更soft,高的概率给降低,低的概率会变高,贫富差距就没有了。
关于对softmax的温度测试同样可以见我另外的一篇博客,其中包含对温度改变后logits的变化。
1.2 训练流程
总的损失分为两个部分:
1)Distillation loss:一部分是来自于利用温度T进行软分类,也就是与教师网络的输出结果进行交叉熵损失
2)Student loss:另一部分是进行原始的硬分类,也就是与真实标签进行交叉熵损失
总的损失就是以上两个损失的加权和,Distillation loss与Student loss的具体计算方法见上图所示。
具体来说,知识蒸馏的流程是,一方面让Student模型去拟合Teacher网络输出的软标签信息,从而让Student网络可以学校到一些潜在的语义信息归纳Teacher网络的经验;另一方面,让Student网络与真实的硬标签做一个交叉熵损失了解真实数据的差异,两种损失通过一个权重相加形成总损失。
1.3 推理过程
当训练完成后,推理过程中就不需要温度为T去进行测试了,直接对网络输出的logit进行softmax进行预测即可。Teacher网络是比较臃肿的,Student网络是比较轻巧的。也就是可以利用一个比较轻巧的模型学习到一个质量比较好但是参数量比较大的模型,然后就可以部署在嵌入式设备中。
1.4 KD与Labe Smoothing的区别
Label Smoothing是为了模型太过自信,为此给予其他类别也有一点的分数,也就是杜绝了模型读cat类别的100%预测,使其拥有一点回流余地,但是很明显Label Smoothing会丢失很多的信息,其不能判断类别之间的关系,不能判断类别之间有多像与有多不像,所以Labe Smoothing是没有Soft label进行蒸馏的效果好的。
2. 知识蒸馏的应用场景
1)无监督的训练
将海量没有标签的数据集输入到已经训练好的Teacher网络中,获得的soft label就可以指导训练Student网络,这是无监督的一种方式。
2)Few Shot/Zero Shot
由于Teacher网络将经验传授了给Student网络,这使得就算训练集中没有测试集中要出现的数据,也就是Student网络可能从来没有见过某一类的数据,但在测试的过程中仍然可以对其进行正确分类。又或者只给Student网络提供少量的某种类型的数据,还是可以在测试过程中对这种少样本的数据集进行正确的分类。
3)防止过拟合
对于大模型的训练很容易会出现过拟合的情况,所以需要设置一些正则化Dropout或者是数据增强Data Aug来增加模型的泛化能力。而对于Student网络来说,由于需要部署所以是轻量级,模型参数肯定比较小,所以训练Student网络不容易出现过拟合的情况。
4)模型压缩
知识蒸馏的最重要的目的就是为了让一个轻量级的模型可以获得重量级模型的经验,从而可以轻易的部署在移动端或者是嵌入式端中,而且这种soft label的训练方式可以有效的指导Student网络。
3. 知识蒸馏的背后机理
question:为什么知识蒸馏的效果这么好,这里有一个有说服力的解释:
1)解释一
就是说,对于大模型来说,其可解的空间可能会比较大(比如Teacher网络的可解空间是绿色区域),更容易找到一个比较好的解;而对于一个小模型,其可解的空间相比大模型来说会比较小(比如Student网络的可解空间是蓝色区域),那么其找一个比较好的解可能比较困难而且也其可解区域不完全包括大模型的解集,优化的方向也难以控制。
而Student网络的作用就体现出来了,大模型的红色解集区域会慢慢引导小模型的黄色解集区域到一个比较靠近大模型解集附近的也橙色区域,这就是知识蒸馏的一个知识引导迁移的作用,让小模型获得一个更好的解,毕竟一般来说大模型的解肯定要比小模型的解要好的。
2)解释二
在Bert中也用到了知识蒸馏技术,其中它也给了一个有说服力的解释:在训练一个大型的语言模型时,会训练出很多比较容易的特征;而迁移学习与微调就是把这些冗余的特征精选出一些有用的特征来进行泛化和迁移。而如何获取揽括比较多的有用特征,就只能是模型参数量大一点,显而易见对于小模型来说其参数量是比较小的,所以其很难从浩如烟海的数据中找出比较有用的特征,所以队大模型进行知识蒸馏就是告诉小模型哪里知识的冗余的哪些特征是有用的。
这样通过蒸馏的技术,相当于是把大模型进行了一个精炼,让小模型知道大模型的哪些参数哪些部位是有用的,从而达到了模型压缩的效果。
4. 知识蒸馏的发展趋势
主要方向:
1)教学相长
之前一直是大模型老师网络来教导学生网络,但是其实也可以通过学生网络反过来指导教师网络,使得教师网络可以进一步成长。两个网络互帮互助,相互学习,我觉得这种模式其实是互学习的一种。
2)助教,多个老师/同学
在刚刚的角色中只有一个老师网络与一个学生网络,但是可以引入多个老师多个学生的模式,甚至是进入一个助教的模式。也就是不需要全部问题都问实力强厚的老师网络,可以先从助教网络来吸取部分经验,再让老师网络占总体方向的引导,也就是分工来指导学生网络。
3)多模态,知识图谱,预训练大模型的知识蒸馏
4)知识的表示(中间层),数据集蒸馏,对比学习
对于刚刚所展示的知识蒸馏,其实用的是网络输出最后一层的soft target表示出来的,那么其实网络的中间层也有尝试的解剖出来进行知识蒸馏,整个中间层的结果可以是feature map,可以是feature map构建出来的自注意力图,也可以是层之间的关系。几篇例子:
- Attention Transfer:用中间层的feature map来进行知识蒸馏
- Channel-wise knowledge distillation for dense prediction:用中间层的注意力图来进行知识蒸馏
- Contrastive Representation Distillation:用对比学习来进行知识蒸馏
下面分别对几种知识蒸馏的知识表示进行拓展展示:
- Response-Based Knowledge:把预测结果作为知识的表示
- Feature-Based Knowledge:把中间层作为知识的表示
- Relation-Based Knowledge:把注意图之间的关系作为知识的表示
彩蛋:代码库工具
ps:在视频的随后,up还贴了几个知识蒸馏的代码块工具,这里我顺便贴出来
1)MMRazor:OpenMMLab模型压缩工具(github/open-mmlab/mmrazor)
2)MMDeploy:OpenMMLab模型转换与部署工具箱(github/open-mmlab/mmdeploy)
3)RepDistiller:12个SoTA知识蒸馏算法的Pytorch复现(github/Hobbitlong/RepDistiller)
参考资料:
1)https://www.bilibili.com/video/BV1N44y1n7mU
2)https://www.bilibili.com/read/cv15391720?from=note
这篇关于知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23Springboot应用的多环境打包入门
- 2024-11-23Springboot应用的生产发布入门教程
- 2024-11-23Python编程入门指南
- 2024-11-23Java创业入门:从零开始的编程之旅
- 2024-11-23Java创业入门:新手必读的Java编程与创业指南
- 2024-11-23Java对接阿里云智能语音服务入门详解
- 2024-11-23Java对接阿里云智能语音服务入门教程
- 2024-11-23JAVA对接阿里云智能语音服务入门教程
- 2024-11-23Java副业入门:初学者的简单教程
- 2024-11-23JAVA副业入门:初学者的实战指南