我需要为 Keras VGG16 预训练权重吗?

2024-01-03

作为背景,我对机器学习领域相对较新,我正在尝试一个项目,目标是对 NBA 比赛中的比赛进行分类。我的输入是 NBA 比赛中每次比赛的 40 帧序列,我的标签是给定比赛的 11 个包罗万象的分类。

该计划是获取每个帧序列并将每个帧传递到 CNN 以提取一组特征。然后,给定视频中的每个特征序列将被传递到 RNN 上。

目前,我的大部分实现都使用 Keras,并且我选择对 CNN 使用 VGG16 模型。下面是一些相关代码:

video = keras.Input(shape = (None, 255, 255, 3), name = 'video')
cnn = keras.applications.VGG16(include_top=False, weights = None, input_shape=
(255,255,3), pooling = 'avg', classes=11)
cnn.trainable = True

我的问题是 - 如果我的目标是对 NBA 比赛的视频剪辑进行分类,将 VGG16 ConvNet 的权重初始化为“imagenet”对我是否仍然有益?如果是这样,为什么?如果没有,我如何训练 VGG16 ConvNet 以获得我自己的权重集,然后如何将它们插入到这个函数中?我没有找到任何有人在使用 VGG16 模型时包含自己的权重集的教程。

如果我的问题看起来很天真,我深表歉意,但我真的很感激任何帮助解决这个问题的帮助。


您是否应该针对您的特定任务重新训练 VGG16?绝对不!重新训练如此庞大的网络非常困难,并且在训练深度网络时需要大量的直觉和知识。让我们分析一下为什么您可以使用在 ImageNet 上预先训练的权重来完成您的任务:

  • ImageNet 是一个巨大的数据集,包含数百万张图像。 VGG16 本身已在强大的 GPU 上经过 3-4 天左右的训练。在 CPU 上(假设您没有像 NVIDIA GeForce Titan X 一样强大的 GPU)需要数周时间。

  • ImageNet 包含来自现实世界场景的图像。 NBA比赛也可以被视为现实世界的场景。因此,基于 ImageNet 特征的预训练很可能也可以用于 NBA 比赛。

实际上,您不需要使用预训练的 VGG16 的所有卷积层。让我们看一下内部 VGG16 层的可视化 https://blog.keras.io/img/vgg16_filters_overview.jpg看看他们检测到了什么(取自本文 https://blog.keras.io/how-convolutional-neural-networks-see-the-world.html;图片太大,为了紧凑,我只放了一个链接):

  • 第一个和第二个卷积块着眼于低级特征,例如角点、边缘等。
  • 第三和第四个卷积块着眼于表面特征、曲线、圆等。
  • 第五层着眼于高层特征

因此,您可以决定哪种功能对您的特定任务有益。您需要第五块的高级功能吗?或者您可能想使用第三块的中级功能?也许您想在 VGG 底层之上堆叠另一个神经网络?有关更多说明,请查看我编写的以下教程;它曾经出现在 SO 文档上。


使用 VGG 和 Keras 进行迁移学习和微调

在这个例子中,提出了三个简短而全面的子例子:

  • 从可用的预训练模型加载权重,包括Keras library
  • 在 VGG 的任意层之上堆叠另一个网络进行训练
  • 在其他图层中间插入一个图层
  • 使用 VGG 进行微调和迁移学习的技巧和一般经验法则

加载预先训练的权重

预训练于ImageNet型号,包括VGG-16 and VGG-19,可用于Keras。在这个例子中,这里和之后,VGG-16将会被使用。欲了解更多信息,请访问Keras 应用程序文档 https://keras.io/applications/.

from keras import applications

# This will load the whole VGG16 network, including the top Dense layers.
# Note: by specifying the shape of top layers, input tensor shape is forced
# to be (224, 224, 3), therefore you can use it only on 224x224 images.
vgg_model = applications.VGG16(weights='imagenet', include_top=True)

# If you are only interested in convolution filters. Note that by not
# specifying the shape of top layers, the input tensor shape is (None, None, 3),
# so you can use them for any size of images.
vgg_model = applications.VGG16(weights='imagenet', include_top=False)

# If you want to specify input tensor
from keras.layers import Input
input_tensor = Input(shape=(160, 160, 3))
vgg_model = applications.VGG16(weights='imagenet',
                               include_top=False,
                               input_tensor=input_tensor)

# To see the models' architecture and layer names, run the following
vgg_model.summary()

使用来自 VGG 的底层创建一个新网络

假设对于某些特定任务,图像尺寸为(160, 160, 3),您想要使用 VGG 的预训练底层,直到具有名称的层block2_pool.

vgg_model = applications.VGG16(weights='imagenet',
                               include_top=False,
                               input_shape=(160, 160, 3))

# Creating dictionary that maps layer names to the layers
layer_dict = dict([(layer.name, layer) for layer in vgg_model.layers])

# Getting output tensor of the last VGG layer that we want to include
x = layer_dict['block2_pool'].output

# Stacking a new simple convolutional network on top of it    
x = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(10, activation='softmax')(x)

# Creating new model. Please note that this is NOT a Sequential() model.
from keras.models import Model
custom_model = Model(input=vgg_model.input, output=x)

# Make sure that the pre-trained bottom layers are not trainable
for layer in custom_model.layers[:7]:
    layer.trainable = False

# Do not forget to compile it
custom_model.compile(loss='categorical_crossentropy',
                     optimizer='rmsprop',
                     metrics=['accuracy'])

删除多层并在中间插入新一层

假设您需要通过替换来加速 VGG16block1_conv1 and block2_conv2使用单个卷积层,以保存预训练权重的方式。 这个想法是将整个网络分解为不同的层,然后将其组装回来。这是专门针对您的任务的代码:

vgg_model = applications.VGG16(include_top=True, weights='imagenet')

# Disassemble layers
layers = [l for l in vgg_model.layers]

# Defining new convolutional layer.
# Important: the number of filters should be the same!
# Note: the receiptive field of two 3x3 convolutions is 5x5.
new_conv = Conv2D(filters=64, 
                  kernel_size=(5, 5),
                  name='new_conv',
                  padding='same')(layers[0].output)

# Now stack everything back
# Note: If you are going to fine tune the model, do not forget to
#       mark other layers as un-trainable
x = new_conv
for i in range(3, len(layers)):
    layers[i].trainable = False
    x = layers[i](x)

# Final touch
result_model = Model(input=layer[0].input, output=x)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

我需要为 Keras VGG16 预训练权重吗? 的相关文章

随机推荐

  • 将我的网站与 Google 日历集成

    我正在用 PHP 开发一个网站 该网站的用户可以从我提供的日历中进行预约 当用户进行预订时 应将其添加到我的谷歌日历中 对于这种情况 我需要什么样的身份验证机制 以下哪一项 1 网络应用程序 2 服务账户 3 安装的应用程序 注意 我不想访
  • Swift 3:获取 UIImage 中像素的颜色(更好:UIImageView)

    我尝试了不同的解决方案 例如this one https stackoverflow com questions 25146557 how do i get the color of a pixel in a uiimage with sw
  • React Native Lottie - 动画结束时反转

    Context 我是lottie react native的新手 并且已经成功实现了我的第一个动画 constructor props super props this state progress new Animated Value 0
  • 无限墙算法中的门

    问题 门在墙上你面对的是一堵向两个方向无限延伸的墙 墙上有一扇门 但你不知道有多远 也不知道在哪个方向 只有当你靠近门时你才能看到门 设计一种算法 使您能够通过最多步行 O n 步到达门 其中 n 是您的初始位置和门之间的 您未知的 步数
  • 在哪里获取 csv 样本数据? [关闭]

    Closed 此问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 作为开发的一部分 我需要处理一些 csv 文件 重要的是我正在用 java 编写一个超快速的 CSV 解
  • pdf图像色彩空间麻烦ios

    EDIT我一直在使用的pdf文件显然是 indesign 格式 无论这意味着什么 因此没有颜色配置文件 有谁知道如果可能的话我如何自己添加配置文件 编辑结束 预先感谢任何人可以为解决此问题提供帮助 首先让我告诉你 我在 IOS 开发方面是个
  • 客户端:访问 Windows Azure 驱动器?

    我正在开发一个 Azure 应用程序 其中一部分涉及用户浏览在线文件系统 为此 我尝试使用 Windows Azure 驱动器 但我不知道如何从客户端访问它 或者如何使其在服务器端可访问 目前 我只知道如何制作驱动器 CloudStorag
  • Docker推送错误“413请求实体太大”

    我设置了registry v2并使用nginx作为反向代理 当我将图像推送到注册表时 出现错误413 Request Entity Too Large 我已在 nginx conf 中将 client max body size 修改为 2
  • 使用起始 X/Y 和起始+扫描角度获取 ArcSegment 中的终点

    有没有人有一个好的算法来计算终点ArcSegment 这不是圆弧 而是椭圆弧 例如 我有这些初始值 起点 X 0 251 起点 Y 0 928 宽度半径 0 436 高度半径 0 593 起始角度 169 51 扫掠角 123 78 我知道
  • nginx 重定向循环,从 url 中删除 index.php

    我想要任何请求 例如http example com whatever index php 执行 301 重定向到http example com whatever 我尝试添加 rewrite index php 1 permanent l
  • 在 Java Web 应用程序中运行常规后台事件

    在播客 15 中 Jeff 提到他在 Twitter 上谈到了如何在后台运行常规事件 就好像它是一个正常功能一样 不幸的是我似乎无法通过 Twitter 找到它 现在我需要做类似的事情 并将这个问题抛给大众 我当前的计划是 当第一个用户 可
  • android.os.SystemProperties 在 Junit 测试期间不保存值

    android os SystemProperties 不能从外部使用 因此反射用于设置和获取操作 看android os SystemProperties 在哪里 https stackoverflow com questions 264
  • 如何使用 Boost Filesystem 忽略隐藏文件(以及隐藏目录中的文件)?

    我使用以下命令递归地迭代目录中的所有文件 try for bf recursive directory iterator end dir dir end dir const bf path p dir gt path if bf is re
  • 我的 Sublime 首选项文件在哪里?

    我正在使用优秀的Sublime Text 3 编辑器 http www sublimetext com 3在我的 Mac 上 我想关闭自动换行功能 所以我去了Preferences gt Settings Default 这将打开一个设置文
  • 错误C2995:函数模板已被定义

    此代码产生 17 错误 C2995 函数模板已被定义 在添加 include set h 标头之前存在一组单独的错误 有一个与此相关的私有 cpp 和 h 文件 File private set cpp Last modified on T
  • 如何告诉castor将空字段编组到空标签?

    我正在编组一个可以将某些字段设置为空的对象 我使用带有 xml 映射文件的 Castor 进行配置 我正在编组的课程是这样的 class Entity private int id private String name private S
  • 为什么我的 eclipse 控制台不以 StatE 启动

    我刚刚安装了带有 StatET 的 Eclipse 一切都是标准的 现在当我打开 StatET 透视图时 我的控制台似乎没有加载 有什么想法吗 我也使用 StatEt 它在这里有相同的行为 要启动 R 控制台 我必须从 运行按钮菜单 中选择
  • android.view.InflateException:错误膨胀类 android.widget.EditText 华硕 Android 5

    I m getting constant reports of this crash happening but it only happens on asus devices with android 5 as the image bel
  • 将重叠的多边形合并为单个多边形

    我有一个数据集 其中包含多个多边形的 x 和 y 坐标 例如 df lt data frame xpol c 0 304147897 0 272762377 0 239435395 0 204166952 0 166957048 0 127
  • 我需要为 Keras VGG16 预训练权重吗?

    作为背景 我对机器学习领域相对较新 我正在尝试一个项目 目标是对 NBA 比赛中的比赛进行分类 我的输入是 NBA 比赛中每次比赛的 40 帧序列 我的标签是给定比赛的 11 个包罗万象的分类 该计划是获取每个帧序列并将每个帧传递到 CNN