基于 Keras 对深度学习模型进行微调的全面指南 Part 2

2018-08-16    来源:raincent

容器云强势上线!快速搭建集群,上万Linux镜像随意使用
原标题 A Comprehensive guide to Fine-tuning Deep Learning Models in Keras (Part I),作者为 Felix Yu 。

本部分属该两部系列中的第二部分,该系列涵盖了基于 Keras 对深度学习模型的微调。第一部分阐述微调背后的动机和原理,并简要介绍常用的做法和技巧。本部分将详细地指导如何在 Keras 中实现对流行模型 VGG,Inception 和 ResNet 的微调。

为什么选择 Keras ?

Keras 是建立在 Theano 或 TensorFlow 之上的一个极简的神经网络库。该库允许开发人员快速地将想法原型化。除非你正在做一些涉及制定具有截然不同的激活机制的神经架构的前沿研究,否则 Keras 将提供构建相当复杂的神经网络所需的所有构建模块。

同时附带了大量的文档和在线资源。

硬件说明

我强烈建议在涉及繁重计算的Covnet训练时,使用GPU加速。速度差异相当大,我们谈论的 GPU 大约几小时而 CPU 需要几天。

我推荐使用 GTX 980 Ti 或者有点贵的 GTX 1080,它售价约 600 美元。

Keras 微调

我已经实现了基于 Keras 的微调启动脚本,这些脚本存放在这个 github 页面中。包括 VGG16,VGG19,GoogleLeNet,nception-V3 和 ResNet50 的实现。这样,你就可以为自己的微调任务定制脚本。

下面是如何使用脚本微调 VGG16 和 Inception-V3 模型的详细演练。

VGG16 微调

VGG16 是牛津大学视觉几何组(VGG)在 2014 年 ILVRC(ImageNet)竞赛中使用的 16 层卷积神经网络。该模型在验证集上达到了 7.5% 的前 5 错误率,这使得他们在竞赛中获得了第二名。

VGG16 模型示意图:

 

 

可以在 vgg16.py 中找到用于微调 VGG16 的脚本。vgg_std16_model 函数的第一部分是 VGG 模型的结构。定义全连接层之后,我们通过下面一行将 ImageNet 预训练权重加载到模型中:

 

 

为了进行微调,我们截断了原始的 softmax 层,并使用下面一段我们自己的代码替换:

 

 

最后一行的 num_class 变量代表我们分类任务中的类别标签的数量。

有时,我们希望冻结前几层的权重,使它们在整个微调过程中保持不变。假设我们想冻结前 10 层的权重,可以通过以下几行代码来完成:

 

 

然后,我们通过使用随机梯度下降 (SGD) 算法最小化交叉熵损失函数来微调模型。注意:我们使用的初始学习率为 0.001,小于从头开始训练的模型学习率(通常为 0.01)。

 

 

img_rows,img_cols 和 channel 定义输入的维度。对于分辨率为 224×224 的彩色图像,img_rows=img_cols=224,channel=3。

接下来,我们加载数据集,将其拆分为训练集和测试集,然后开始微调模型:

 

 

微调过程需要一段时间,具体取决于你的硬件。完成后,我们使用模型对验证集进行预测,并且返回交叉熵损失函数的分数。

 

 

Inception-V3 微调。 Inception-V3 在 2015 年 ImageNet 竞赛中获得第二名,验证集上的前 5 个错误率为 5.6%。

该模型的特点是使用了Inception模块,它是由不同维度的内核生产的特征映射的串联。

27 层 Inception-V1 模型示意图(类似于 V3 的想法):

 

 

用于微调 Inception-V3 的代码可以在 inception_v3.py 中找到。这个过程与 VGG16 很相似,但有细微差别。由于Inception模块分支需要合并,Inception-V3 不使用 Keras 的序列模型,因此我们不能简单地使用 model.pop() 截断顶层。

取而代之的是,在创建模型并加载 ImageNet 权重之后,我们通过在最后一个起始模块(X)上定义另一个全连接的 softmax(x_newfc) 来执行等效于顶层截断。这使用以下代码来完成:

 

 

这就是 Inception-V3。可以在此处找到其他模型(如 VGG19,GoogleLeNet 和 ResNet)。

网络微调操作

如果你是深度学习或者计算机视觉的从业人员,很可能你已经尝试过微调预训练的网络来解决自己的分类问题。

对我来说,我遇到了有趣的 Kaggle 比赛,要求候选人通过分析车载摄像头图像来识别注意力不集中的驾驶员。这是我尝试使用基于 Keras 微调的好机会。 按照上面列出的微调方法,结合数据预处理、数据增强和模型集成,我们团队在竞赛中获得了前 4% 的名次。

本文详细介绍了我们使用的方法和经验。

标签: Google 代码 脚本 网络

版权申明:本站文章部分自网络,如有侵权,请联系:west999com@outlook.com
特别注意:本站所有转载文章言论不代表本站观点!
本站所提供的图片等素材,版权归原作者所有,如需使用,请与原作者联系。

上一篇:理解Python数据类:Dataclass fields 的概述(下)

下一篇:在工程领域中,机器学习的数学理论基础尤为重要