学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)

2023-11-09

       视觉问题引入深度神经网络后,针对端对端的训练和预测网络,可以看是特征的表达和任务的决策问题(分类,回归等)。当我们自己的训练数据量过小时,往往借助牛人已经预训练好的网络进行特征的提取,然后在后面加上自己特定任务的网络进行调优。目前,ILSVRC比赛(针对1000类的分类问题)所使用数据的训练集126万张图像,验证集5万张,测试集10万张(标注未公布),大家一般使用这个比赛的前几名的网络来搭建自己特定任务的神经网络。

      本篇博文主要简单讲述怎么使用TensorFlow调用预训练好的VGG网络,其他的网络(如Alex, ResNet等)也是同样的套路。分为三个部分:第一部分下载网络架构定义以及权重参数,第二部分是如何调用预训练网络中的feature map,第三部分给出参考资料。注:资料是学习查找整理而得,理解有误的地方,请多多指正~

一、下载网络架构定义以及权重参数

https://github.com/leihe001/tensorflow-vgg  训练和测试网络的定义

https://mega.nz/#!YU1FWJrA!O1ywiCS2IiOlUCtCpI6HTJOMrneN-Qdv3ywQP5poecM VGG16

https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs VGG19


二、调用预训练网络中的feature map(以VGG16为例)

import inspect
import os

import numpy as np
import tensorflow as tf
import time

VGG_MEAN = [103.939, 116.779, 123.68]


class Vgg16:
    def __init__(self, vgg16_npy_path=None):
        if vgg16_npy_path is None:
            path = inspect.getfile(Vgg16)
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg16.npy")
            vgg16_npy_path = path
            print path
	# 加载网络权重参数
        self.data_dict = np.load(vgg16_npy_path, encoding='latin1').item()
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG

        :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1]
        """

        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0

        # Convert RGB to BGR
        red, green, blue = tf.split(3, 3, rgb_scaled)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        bgr = tf.concat(3, [
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.pool3 = self.max_pool(self.conv3_3, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
        self.pool4 = self.max_pool(self.conv4_3, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.pool5 = self.max_pool(self.conv5_3, 'pool5')

        self.fc6 = self.fc_layer(self.pool5, "fc6")
        assert self.fc6.get_shape().as_list()[1:] == [4096]
        self.relu6 = tf.nn.relu(self.fc6)

        self.fc7 = self.fc_layer(self.relu6, "fc7")
        self.relu7 = tf.nn.relu(self.fc7)

        self.fc8 = self.fc_layer(self.relu7, "fc8")

        self.prob = tf.nn.softmax(self.fc8, name="prob")

        self.data_dict = None
        print("build model finished: %ds" % (time.time() - start_time))

    def avg_pool(self, bottom, name):
        return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def fc_layer(self, bottom, name):
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])

            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)

            # Fully connected layer. Note that the '+' operation automatically
            # broadcasts the biases.
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")

    def get_fc_weight(self, name):
        return tf.constant(self.data_dict[name][0], name="weights")

以上是VGG16网络的定义,假设我们现在输入图像image,打算做分割,那么我们可以使用端对端的全卷积网络进行训练和测试。针对这个任务,我们只需要输出pool5的feature map即可。

#以上你的网络定义,初始化方式,以及数据预处理...

vgg = vgg16.Vgg16()
vgg.build(image)
feature_map = vgg.pool5
mask = yournetwork(feature_map)

#以下定义loss,学习率策略,然后train...


三、 参考资料

https://github.com/leihe001/tensorflow-vgg

https://github.com/leihe001/tfAlexNet

https://github.com/leihe001/tensorflow-resnet


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc) 的相关文章

  • 在 TensorFlow 中,tf.identity 有何用途?

    我见过tf identity在一些地方使用过 例如官方 CIFAR 10 教程和 stackoverflow 上的批量规范化实现 但我不明白为什么有必要 它是用来做什么的 谁能给出一两个用例吗 一种建议的答案是它可以用于 CPU 和 GPU
  • 张量流和线程

    下面是来自 Tensorflow 网站的简单 mnist 教程 即单层 softmax 我尝试通过多线程训练步骤对其进行扩展 from tensorflow examples tutorials mnist import input dat
  • 可视化 TFLite 图并获取特定节点的中间值?

    我想知道是否有办法知道 tflite 中特定节点的输入和输出列表 我知道我可以获得输入 输出详细信息 但这不允许我重建发生在Interpreter 所以我要做的是 interpreter tf lite Interpreter model
  • 阻止 TensorFlow 访问 GPU? [复制]

    这个问题在这里已经有答案了 有没有一种方法可以纯粹在CPU上运行TensorFlow 我机器上的所有内存都被运行 TensorFlow 的单独进程占用 我尝试将 per process memory fraction 设置为 0 但未成功
  • 在相同任务上,Keras 比 TensorFlow 慢

    我正在使用 Python 运行斩首 DCNN 本例中为 Inception V3 来获取图像特征 我使用的是 Anaconda Py3 6 和 Windows7 使用 TensorFlow 时 我将会话保存在变量中 感谢 jdehesa 并
  • Tensorflow 中的图像叠加图像卷积

    假设我有两组图像 A 和 B 每个图像都是 11X5x5x3 其中 11 是示例数量 5x5x3 是图像尺寸 Tensorflow 中是否有一种简单的方法可以对 A i 中的每个图像应用 B i 上的卷积 即 B i 扮演过滤器角色 A i
  • 张量流服务错误:参数无效:JSON 对象:没有命名输入

    我正在尝试使用 Amazon Sagemaker 训练模型 并且希望使用 Tensorflow 服务来为其提供服务 为了实现这一目标 我将模型下载到 Tensorflow 服务 docker 并尝试从那里提供服务 Sagemaker 的训练
  • 如何将张量流模型部署到azure ml工作台

    我在用Azure ML Workbench执行二元分类 到目前为止 一切正常 我有很好的准确性 我想将模型部署为用于推理的 Web 服务 我真的不知道从哪里开始 azure 提供了这个doc https learn microsoft co
  • tf.gather_nd 直观上是做什么的?

    你能直观地解释一下或者举更多例子吗tf gather nd用于在 Tensorflow 中索引和切片为高维张量 我读了API https www tensorflow org api docs python tf gather nd 但它保
  • 在张量流中向卷积神经网络提供可变大小的输入

    我正在尝试使用 feed dict 参数将不同大小的 2d numpy 数组列表传递给卷积神经网络 x tf placeholder tf float32 batch size None None None y tf placeholder
  • 合并张量流数据集批次

    请考虑下面的代码 import tensorflow as tf import numpy as np simple features np array 1 1 1 2 2 2 3 3 3 4 4 4 5 5 5 simple labels
  • Ray:如何在一个 GPU 上运行多个 Actor?

    我只有一个 GPU 我想在该 GPU 上运行许多 Actor 这是我使用的方法ray 下列的https ray readthedocs io en latest actors html https ray readthedocs io en
  • 默认情况下,Keras 自定义层参数是不可训练的吗?

    我在 Keras 中构建了一个简单的自定义层 并惊讶地发现参数默认情况下未设置为可训练 我可以通过显式设置可训练属性来使其工作 我无法通过查看文档或代码来解释为什么会这样 这是应该的样子还是我做错了什么导致默认情况下参数不可训练 代码 im
  • 在 Keras 模型中删除然后插入新的中间层

    给定一个预定义的 Keras 模型 我尝试首先加载预先训练的权重 然后删除一到三个模型内部 非最后几层 层 然后用另一层替换它 我似乎找不到任何有关的文档keras io https keras io 即将做这样的事情或从预定义的模型中删除
  • Keras:如何保存模型或权重?

    如果这个问题看起来很简单 我很抱歉 但是阅读 Keras 保存和恢复帮助页面 https www tensorflow org beta tutorials keras save and restore models https www t
  • 在 Tensorflow 中使用 tf.while_loop 更新变量

    我想更新 Tensorflow 中的变量 因此我使用 tf while loop 例如 a tf Variable 0 0 0 0 0 0 dtype np int16 i tf constant 0 size tf size a def
  • 无法使用tensorflow 2.0.0 beta1保存模型

    我已尝试了文档中描述的所有选项 但没有一个允许我将模型保存在tensorflow 2 0 0 beta1中 我还尝试升级到 也不稳定 TF2 RC 但这甚至破坏了我在测试版中工作的代码 所以我很快就回滚到测试版 请参阅下面的最小复制代码 我
  • 如何使用tensorFlow C++ API中的fileWrite摘要在Tensorboard中查看它

    无论如何 我是否可以获得与 FileWriter 相对应的张量名称 以便我可以写出我的摘要以在 Tensorboard 中查看它们 我的应用程序是基于C 的 所以我必须使用C 来进行训练 FileWriter 不是张量 import ten
  • 移动设备上的 TensorFlow(Android、iOS、Windows Phone)

    我目前正在寻找不同的深度学习框架 特别是用于训练和部署卷积神经网络 要求是 它可以在带有 GPU 的普通 PC 上进行训练 但训练后的模型必须部署在三个主要的移动操作系统上 即 Android iOS 和 Windows Phone Ten
  • 让 TensorFlow 在 ARM Mac 上使用 GPU

    我已经安装了TensorFlow在 M1 上 ARM Mac 根据这些说明 https github com apple tensorflow macos issues 153 一切正常 然而 模型训练正在进行CPU 如何将培训切换到GPU

随机推荐

  • Java 学习历程

    最近论坛上看到好几个朋友都在问 如何学习 Java的问题 我已经学习了J2SE 怎么样才能转向J2EE 我看完了Thinking in Java 可以学习J2EE了么 于是就有了写这篇文章的想法 希望能帮助初学者少走一些弯路 也算是对自己几
  • Java中参数的传递机制,究竟是值传递还是引用传递?

    先说结论 Java语言中 本质上只有值传递 没有引用传递 废话不说 咱们直接来看例子 public class Demo public static void main String args int i 10 testInt i Syst
  • 您的嵌入式开发团队的静态代码分析工具是什么? 这份指南你一定需要

    所有的静态分析工具从50 000英尺高空看去往往都是一样的 当计划部署静态分析时 重要的是选择一个适合组织需求的解决方案 并能随着未来的需求而增长 一个工具应该具备的特点和能力可以分成两组 第一组是常见的 预期的技术功能 如支持的语言 ID
  • 波形失真总结

    失真是输入信号与输出信号在幅度比例关系 相位关系及波形形状产生变化的现象 音频功放的失真分为电失真和声失真两大类 电失真是由电路引起的 声失真是由还音器件扬声器引起的 电失真的类型有 谐波失真 互调失真 瞬态失真 声失真主要是交流接口失真
  • QApplication、QGuiApplication和QCoreApplication三者的区别与联系

    为什么80 的码农都做不了架构师 gt gt gt 从继承关系看 QApplication父类是QGuiApplication QGuiApplication父类是QCoreApplication 开发的应用无图像界面 就使用QCoreAp
  • Ant Trip 【HDU - 3018】【欧拉通路一笔画问题】

    题目链接 欧拉通路与欧拉回路不同 欧拉通路其实不强制要求走回 也就是不要求最后从哪开始 然后再回到哪 这道题 是问的我们需要走几次一笔画 那么 很显然 考虑入度出度以及连通性 在同一个联通块中 我们可以拆分成如下几种可能 形成闭环 无奇数度
  • REST API 最佳入门指南

    点击上方 程序员大咖 选择 置顶公众号 关键时刻 第一时间送达 如果你看到这里 你以前可能听说过API 和REST 然后你就会想 这些都是什么东西 也许你已经了解过一些这方面的知识 但却不知道从何入手 在这个教程中 我将会诠释REST的基础
  • create umi创建项目

    1 环境准备 安装node node确保它是 8 10 或更高版本 node v v14 17 0 安装yarn 推荐用于yarn管理 npm 依赖 npm install g yarn gt yarn 1 22 10 preinstall
  • keepalived工作原理和配置说明

    keepalived是什么 keepalived是集群管理中保证集群高可用的一个服务软件 其功能类似于heartbeat 用来防止单点故障 keepalived工作原理 keepalived是以VRRP协议为实现基础的 VRRP全称Virt
  • 斗智斗勇 -- 谷歌浏览器的主页被篡改

    不知道从什么时候开始 每次我打开谷歌浏览器 都会跳出2345网址导航 界面花里胡哨的 今天实在是忍无可忍了 就对他动手了 百度了半天 又是禁服务 又是删注册表的 一直然并软 最后实在没办法 只能装个电脑管家试试了 解决完问题再卸载吧 安装好
  • 在rdesktop 远程时报如下错误Autoselecting keyboard map ‘en-us‘ from localeCore(warning): Certificate received

    在rdesktop 时报如下错误 Autoselecting keyboard map en us from locale Core warning Certificate received from server is NOT trust
  • IT项目管理第七次作业

    完成作业1 3的要求 使用 project 或其他项目管理工具 1 假设 每项工作的单位小时成本数如下表 项目经理单位小时成本为100 项目团队成员单位小时成本为60 WBS条目 小时数 单位小时成本 美元 子层总合 美元 WBS第二层的总
  • java 内存溢出 扩大jvm内存

    随手小记 今天下午遇到一个问题 java lang OutOfMemoryError Java heap space 内存溢出问题 遇到这个问题一般有两个解决方式 第一种 修改代码程序 代码中存在大量未被释放的对象引用 或者gc 机制没有来
  • 全排列 Ⅱ--回溯算法

    LeetCode 全排列 给定一个可包含重复数字的序列 返回所有不重复的全排列 示例 输入 1 1 2 输出 1 1 2 1 2 1 2 1 1 解法 回溯法 解题思路 思路很简单 因为要全排列 所以每一个数字都可能选择 即选择区间为 0
  • 最新版Bootstrap5教程——Bootstrap5基础

    个人主页 这个昵称我想了20分钟 往期专栏 速成之路 jQuery 速成之路 SQLserver 速成之路 Ajax 系列专栏 最新Bootstrap5教程 Bootstrap5 Bootstrap5简介 Bootstrap5下载 Boot
  • Linux下shell脚本实战之批量新建用户

    Linux下shell脚本实战之批量新建用户 一 脚本要求 二 脚本内容 三 运行脚本 一 脚本要求 二 脚本内容 三 运行脚本 一 脚本要求 1 使用提供的user txt用户列表 2 批量新建user txt中用户 二 脚本内容 1 查
  • java 简单 数组 自然合并排序

    题目 对所给元素存储于数组中或链表中 选择一种情形 写出自然合并排序算法 结果演示 基本思想 自然排序是在合并排序的基础上修改而成 合并排序 给出一个n个元素无序的整数数组 将其一分为2 则一个子集为n 2 再将子集划分为2 不断划分直到只
  • 电赛控制-----经验分享

    1 赛前准备 先简单介绍一下电赛 电赛是两年一届 单数年是大电赛 全称是全国大学生电子设计大赛 之前由瑞萨电子赞助 所以之前也叫瑞萨杯 从19年开始赞助方变成了TI公司 偶数年是小电赛 全称是 TI杯 模电邀请赛 这里不得不提TI公司的实力
  • json库报错(TypeError: the JSON object must be str, bytes or bytearray, not TextIOWrapper)

    使用json库导入json文件时 报错 TypeError the JSON object must be str bytes or bytearray not TextIOWrapper import json f open data d
  • 学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)

    视觉问题引入深度神经网络后 针对端对端的训练和预测网络 可以看是特征的表达和任务的决策问题 分类 回归等 当我们自己的训练数据量过小时 往往借助牛人已经预训练好的网络进行特征的提取 然后在后面加上自己特定任务的网络进行调优 目前 ILSVR