一、综述
预训练网络(pretrained network)是一个保存好的网络,之前已在大型数据集(通常是大规模图像分类任务)上训练好。如果这个原始数据集足够大且足够通用,那么预训练网络学到的特征的空间层次结构可以有效地作为视觉世界的通用模型,即使这些新问题涉及的类别和原始任务完全不同。
假设有一个在 ImageNet 数据集(140 万张标记图像,1000 个不同的类别)上训练 好的大型卷积神经网络。ImageNet 中包含许多动物类别,其中包括不同种类的猫和狗,因此可 以认为它在猫狗分类问题上也能有良好的表现。
使用预训练网络有两种方法:特征提取(feature extraction)和微调模型(fine-tuning)。
二、特征提取
特征提取是使用之前网络学到的表示来从新样本中提取出有趣的特征。然后将这些特征输 入一个新的分类器,从头开始训练。
用于图像分类的卷积神经网络包含两部分:首先是一系列池化层和卷积层,最后是一个密集连接分类器。第一部分叫作模型的卷积基(convolutionalbase)。对于卷积神经网络而言,特征提取就是取出之前训练好的网络的卷积基,在上面运行新数据,然后在输出上面 训练一个新的分类器
为什么仅重复使用卷积基?我们能否也重复使用密集连接分类器?一般来说,应该避免这么做。原因在于卷积基学到的表示可能更加通用,因此更适合重复使用。卷积神经网络的特征图表示通用概念在图像中是否存在,无论面对什么样的计算机视觉问题,这种特征图都可能很有用。但是,分类器学到的表示必然是针对于模型训练的类别,其中仅包含某个类别出现在整张图像中的概率信息。此外,密集连接层的表示不再包含物体在输入图像中的位置信息。密集连接层舍弃了空间的概念
某个卷积层提取的表示的通用性(以及可复用性)取决于该层在模型中的深度。模型中更靠近底部的层提取的是局部的、高度通用的特征图(比如视觉边缘、颜色和纹理),而更 靠近顶部的层提取的是更加抽象的概念(比如“猫耳朵”或“狗眼睛”)。如果你的新数据集与原始模型训练的数据集有很大差异,那么最好只使用模型的前几层来做特征提取,而不是使用整个卷积基。
三、特征提取程序
1、在你的数据集上运行卷积基,将输出保存成硬盘中的Numpy 数组,然后用这个数据作 为输入,输入到独立的密集连接分类器中。
验证集损失从一开始就在增加,而验证集精度在90%上下摆动,说明从一开始就已经过拟合了,原因是由于训练数据太少了。因此,我添加了一个ImageDataGenerator以对图片进行增强处理,想要将增强图片输入进卷积基以得到增强的卷积基的输出,得增强处理后的图片达到8000张,但是情况却和增强之前一样,不知为何?推测是因为卷积基对这些增强前后的图片的预测输出都没有太大的区别
from keras.applications import VGG16
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(150, 150, 3))
conv_base.summary()
def extract_features(directory, sample_count, enhance=0):
features = np.zeros(shape=(sample_count, 4, 4, 512))
labels = np.zeros(shape=(sample_count))
if enhance:
generator = train_datagen.flow_from_directory(
directory,
target_size=(150, 150),
batch_size=batch_size,
class_mode='binary')
else:
generator = test_datagen.flow_from_directory(
directory,
target_size=(150, 150),
batch_size=batch_size,
class_mode='binary')
i = 0
for inputs_batch, labels_batch in generator:
features_batch = conv_base.predict(inputs_batch)
features[i * batch_size : (i + 1) * batch_size] = features_batch
labels[i * batch_size : (i + 1) * batch_size] = labels_batch
i += 1
if i * batch_size >= sample_count:
break
return features, labels
train_features , train_labels = extract_features(train_dir, 2000)
validation_features , validation_labels = extract_features(validation_dir, 1000)
test_features , test_labels = extract_features(test_dir, 1000)
train_features = np.reshape(train_features, (len(train_features), 4 * 4 * 512))
test_features = np.reshape(test_features, (len(test_features), 4 * 4 * 512))
validation_features = np.reshape(validation_features, (len(validation_features), 4 * 4 * 512))
model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_dim=4*4*512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
history = model.fit(train_features, train_labels,
epochs=30,
batch_size=20,
validation_data=(validation_features, validation_labels))
2、在顶部添加 Dense 层来扩展已有模型,并在输入数据上端到端地运行整个模型。
这样可以使用数据增强,因为每个输入图像进入模型时都会经过卷积基。 但出于
同样的原因,这种方法的计算代价比第一种要高很多。
书上的测试该方式可以达到96%,但是测试中只达到90%精度,损失智在一定幅度内摆动
from keras import layers, models
from keras.applications import VGG16
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(150, 150, 3))
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
len(model.trainable_weights)
conv_base.trainable = False
len(model.trainable_weights)
四、模型微调
微调是指将其顶部的几层“解冻”,并将这解冻的几层和新增加的部分联合训练。之所以叫作
微调,是因为它只是略微调 整了所复用模型中更加抽象的表示,以便让这些表示与手头的问题更
加相关。
冻结VGG16 的卷积基是为了能够在上面训练一个随机初始化的分类器。同理,只有上面的
分类器已经训练好了,才能微调卷积基的顶部几层。
该方式与书籍的测试结果相同,达到97%的准确率
微调网络的步骤如下。
(1) 在已经训练好的基网络(base network)上添加自定义网络。
(2) 冻结基网络。
(3) 训练所添加的部分。
(4) 解冻基网络的一些层。
(5) 联合训练解冻的这些层和添加的部分。
conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
if layer.name == 'block5_conv1':
set_trainable = True
if set_trainable:
layer.trainable = True
else:
layer.trainable = False
model.compile(loss='binary_crossentropy',
optimizer=optimizers.RMSprop(lr=1e-5),
metrics=['acc'])
五、
如果损失没有降低,那么精度怎么能保持稳定或提高呢?答案很简单:图中展示的是逐点(pointwise)
损失值的平均值,但影响精度的是损失值的分布,而不是平均值,因为精度是 模型预测的类别概率的二
进制阈值。即使从平均损失中无法看出,但模型也仍然可能在改进。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)