使用tf-slim的ResNet V1 152和ResNet V2 152预训练模型进行图像分类

2023-11-09

本文使用tf-slim的ResNet V1 152和ResNet V2 152预训练模型进行图像分类,并研究slim网络的scope命名等。

tf-slim文档不太多,实现过程中多参考官网的源码: https://github.com/tensorflow/models/tree/master/research/slim
注意resnet v2的预处理有点不一样,输入是299而不是224
ResNet V2 152
(tf-slim: ResNet V2 models use Inception pre-processing and input image size of 299 (use –preprocessing_name inception –eval_image_size 299 when using eval_image_classifier.py). Performance numbers for ResNet V2 models are reported on the ImageNet validation set.)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 29 16:25:16 2017

@author: wayne


我们用的是tf1.2,最新的tf1.3地址是
https://github.com/tensorflow/models/tree/master/research/slim

http://geek.csdn.net/news/detail/126133
如何用TensorFlow和TF-Slim实现图像分类与分割

https://www.2cto.com/kf/201706/649266.html
【Tensorflow】辅助工具篇——tensorflow slim(TF-Slim)介绍

https://stackoverflow.com/questions/39582703/using-pre-trained-inception-resnet-v2-with-tensorflow
The Inception networks expect the input image to have color channels scaled from [-1, 1]. As seen here.
You could either use the existing preprocessing, or in your example just scale the images yourself: im = 2*(im/255.0)-1.0 before feeding them to the network.
Without scaling the input [0-255] is much larger than the network expects and the biases all work to very strongly predict category 918 (comic books).

TensorFlow实现ResNet(ResNet 152网络结构的forward耗时检测)
http://blog.csdn.net/superman_xxx/article/details/65452735
ResNet原理及其在TF-Slim中的实现
http://www.jianshu.com/p/3af06422c768
"""

import tensorflow as tf
slim = tf.contrib.slim
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import imagenet  #注意需要用最新版tf中的对应文件,否则http地址是不对的

from inception_resnet_v2 import *
from resnet_v1 import *
from resnet_v2 import *

import inception_preprocessing
import vgg_preprocessing


'''
inception_resnet_v2
  Returns:
    tensor_out: output tensor corresponding to the final_endpoint.
    end_points: a set of activations for external use, for example summaries or
                losses.
'''
tf.reset_default_graph()

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'
image = tf.image.decode_jpeg(tf.read_file('dog.jpeg'), channels=3) #['dog.jpg', 'panda.jpg']

image_size = inception_resnet_v2.default_image_size #  299

'''这个函数做了裁剪,缩放和归一化等'''
processed_image = inception_preprocessing.preprocess_image(image, 
                                                        image_size, 
                                                        image_size,
                                                        is_training=False,)
processed_images  = tf.expand_dims(processed_image, 0)

'''Creates the Inception Resnet V2 model.'''
arg_scope = inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
  logits, end_points = inception_resnet_v2(processed_images, is_training=False)   

probabilities = tf.nn.softmax(logits)

saver = tf.train.Saver()


with tf.Session() as sess:
    saver.restore(sess, checkpoint_file)

    #predict_values, logit_values = sess.run([end_points['Predictions'], logits])
    logits2, image2, network_inputs, probabilities2 = sess.run([logits,
                                                                image,
                                                       processed_images,
                                                       probabilities])

    print(logits2)  
    print(logits2.shape) #(1, 1001)

    print(network_inputs.shape)
    print(probabilities2.shape)
    probabilities2 = probabilities2[0,:]
    sorted_inds = [i[0] for i in sorted(enumerate(-probabilities2),
                                        key=lambda x:x[1])]    


# 显示下载的图片
plt.figure()
plt.imshow(image2)#.astype(np.uint8))
plt.suptitle("Original image", fontsize=14, fontweight='bold')
plt.axis('off')
plt.show()

# 显示最终传入网络模型的图片
plt.imshow(network_inputs[0,:,:,:])
plt.suptitle("Resized, Cropped and Mean-Centered inputs to network",
             fontsize=14, fontweight='bold')
plt.axis('off')
plt.show()

names = imagenet.create_readable_names_for_imagenet_labels()
for i in range(5):
    index = sorted_inds[i]
    print(index)
    # 打印top5的预测类别和相应的概率值。
    print('Probability %0.2f => [%s]' % (probabilities2[index], names[index+1]))





'''https://github.com/tensorflow/models/blob/master/research/slim/train_image_classifier.py'''
def _get_variables_to_train():
    """Returns a list of variables to train.
    Returns:
      A list of variables to train by the optimizer.
    """
    trainable_scopes = 'InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits'

    if trainable_scopes is None:
      return tf.trainable_variables()
    else:
      scopes = [scope.strip() for scope in trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
      variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
      variables_to_train.extend(variables)
    return variables_to_train

'''
一些关于inception_resnet_v2变量的测试,在理解模型代码和迁移学习中很有用
'''
exx = tf.trainable_variables()
print(type(exx))
print(exx[0])
print(exx[-1])
print(exx[-2])
print(exx[-3])
print(exx[-4])
print(exx[-5])
print(exx[-6])
print(exx[-7])
print(exx[-8])
print(exx[-9])
print(exx[-10])

print('###############################################################')
variables_to_train = _get_variables_to_train()
print(variables_to_train)

print('###############################################################')
exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
variables_to_restore = slim.get_variables_to_restore(exclude = exclude)
print(variables_to_restore[0])
print(variables_to_restore[-1])

print('###############################################################')
exclude = ['InceptionResnetV2/Logits']
variables_to_restore = slim.get_variables_to_restore(exclude = exclude)
print(variables_to_restore[0])
print(variables_to_restore[-1])



'''
resnet_v2 152
    num_classes: Number of predicted classes for classification tasks. If None
      we return the features (2048) before the logit layer.

  Returns:
    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
      If global_pool is False, then height_out and width_out are reduced by a
      factor of output_stride compared to the respective height_in and width_in,
      else both height_out and width_out equal one. If num_classes is None, then
      net is the output of the last ResNet block, potentially after global
      average pooling. If num_classes is not None, net contains the pre-softmax
      activations.
    end_points: A dictionary from components of the network to the corresponding
      activation.
'''
tf.reset_default_graph()

checkpoint_file = 'resnet_v2_152.ckpt'
image = tf.image.decode_jpeg(tf.read_file('dog.jpeg'), channels=3) #['dog.jpg', 'panda.jpg']

image_size = inception_resnet_v2.default_image_size #  299

'''这个函数做了裁剪,缩放和归一化等'''
processed_image = inception_preprocessing.preprocess_image(image, 
                                                        image_size, 
                                                        image_size,
                                                        is_training=False,)
processed_images  = tf.expand_dims(processed_image, 0)

'''Creates the Resnet V2 model.'''
arg_scope = resnet_arg_scope()
with slim.arg_scope(arg_scope):
    net, end_points = resnet_v2_152(processed_images, 1001, is_training=False)   


probabilities = tf.nn.softmax(net)

saver = tf.train.Saver()


with tf.Session() as sess:
    saver.restore(sess, checkpoint_file)

    net2, image2, network_inputs, end_points2, probabilities2= sess.run([net, 
                                                             image,
                                                       processed_images,
                                                       end_points,
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用tf-slim的ResNet V1 152和ResNet V2 152预训练模型进行图像分类 的相关文章

随机推荐

  • ARP(地址解析协议)

    ARP Address Resolution Protocol 地址解析协议 可以在以太网上 根据已知的IP地址查找主机的硬件地址 一 ARP的工作原理 我们以以太网的工作环境作为背景来探讨这一协议 串行链路由于是点到点链路 故而不需要AR
  • 微信网页开发分享

    首先提供一个微信官方地址点击打开链接 早期web项目中经常用到微信分享功能 现在整理一下 供记忆与分享 开发环境为JAVA H5 1 微信的开发环境不在多说 大概为 使用已备案的域名 设置 公众号设置 的三项域名 设置开发者密码 AppSe
  • Java直接杀死线程方法_如何杀死一个线程?

    1 简介 在这篇短文中 我们将讲述一下java中如果结束一个线程 事实上 这并没有想象中的那么简单 因为 Thread stop 方法已经被废弃啦 根据Oracle的解释 stop 方法可以导致被监视对象遭受破坏 2 使用一个Flag 我们
  • DWT数字水印算法(Python)

    DWT数字水印算法的基本原理 结合Arnold变换的基于DWT的数字水印的嵌入 充分利用了小波变换的特点 采用Haar小波 把原始图像及水印图像进行三级小波分解 然后在多分辨率分解后的频段嵌入水印信号 得到嵌入水印的图像 数字水印最重要的性
  • Keil5识别不到ST-Link的解决办法

    刚开始还以为是pack的问题 下载好多pack也没解决 后来发现其实是驱动的问题 从官网上下载驱动 之后进行基本的配置 如下所示 点击魔术棒标志 然后 然后 点击settings 点击add 添加自己的芯片类型 选择erase full c
  • 基于BERT模型实现文本分类任务(transformers+torch)

    BERT的原理分析可以看这 BERT Pre training of Deep Bidirectional Transformers for Language Understanding 论文笔记 代码实现主要用到huggingface的t
  • 如何保证MQ不丢失信息

    为了保证消息队列 MQ 不丢失信息 有以下几种方法可以考虑 增加冗余 通过将数据存储到多个不同的地方来防止数据丢失 使用持久化存储 通过将数据存储到磁盘上 而不是内存中 以确保数据不会丢失 引入数据备份 定期对数据进行备份 以防止意外数据丢
  • 二. go 常见控制结构实现原理之 select

    目录 一 基础问题 select 与channel select 与 channel 二 实现原理 1 select 底层结构 2 select选择case的执行逻辑 一 基础问题 select是Golang在语言层面提供的多路IO复用的机
  • Vue基础--组件的创建和使用

    一 组件化思想 一个页面中所有的处理概述逻辑全部放在一起 处理起来就会变得非常复杂 不利于后续的管理以及扩展 但是 我们将一个页面逻辑复杂的页面拆分成一个个小的功能块 每个功能块只完成属于自己这部分独立的功能 把大功能拆分成一个个小的功能
  • 51单片机0-9数字LED灯循环输出

    代码 include
  • 703n的OpenWrt配置一:安装和基本设置

    OpenWrt支持的路由可以从官网查到 顺藤摸瓜也可以找到固件的下载地址 如果知道路由器的cpu也可以从这里分类查找路由器型号 对于703n的ar71xx就是点我里面搜索703n找到的那几个文件 挑最小的固件下载 这样可以剩下更多空间安装其
  • 【C++】类的默认成员函数——构造函数、析构函数、拷贝构造函数、赋值运算符重载

    文章目录 一 前言 二 构造函数 1 基本概念 2 初始化列表 3 自动生成的构造函数 三 析构函数 1 基本概念 2 自动生成的析构函数 四 拷贝构造函数 1 基本概念 2 自动生成的拷贝构造函数 五 赋值运算符重载 1 基本概念 2 自
  • 全面剖析PMD静态代码扫描工具

    PMD是使用JavaCC生成解析器来解析源代码并生成AST 抽象语法树 的 这两天对PMD及自定义规则做了调研及实验 部分说明来自官方说明文档 做了大部分参数的详细描述及测试 少数几个参数不明白含义 有了解的朋友欢迎讨论 1 调研对象 pm
  • 如何连接安卓手机到mac并传文件

    平时你有没有需求将文件拖拽到安卓手机文件夹下呢 我最近就需要安装许多插件包到我的手机上 今天就记录下我是如何做这个事情的 本文纯属自己记录自己的学习过程 下面交代下步骤 1 mac端下载HandShaker 2 安装HandShaker包
  • 2.1.cuda驱动API-概述

    目录 前言 1 Driver API概述 2 补充知识 总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程 之前有看过一遍 但是没有做笔记 很多东西也忘了 这次重新撸一遍 顺便记记笔记 本次课程学习精简 CUDA 教程 Dr
  • git rebase 合并提交与避免分叉合并

    本文让你熟练使用 rebase 学会以下两种操作 从此拒绝杂乱无章的 git 提交 目录 用法一 合并当前分支的多个commit记录 step1 找到想要合并的 commit 使用 rebase i step2 进入 Interact 交互
  • 阮一峰ES6 入门教程

    学习地址 https es6 ruanyifeng com
  • 书单(含资源链接,快撸!)

    撸资源 笨办法 学Python 第3版 https www jianshu com p 67a4827e88a1 Python 编写高质量Python代码的59个有效方法 https pan baidu com s 1vAw1R9bP5EC
  • 计算机视觉毕业后找不到工作怎么办?

    点击上方 视学算法 选择加 星标 置顶 重磅干货 第一时间送达 编辑 Amusi 来源 知乎 https www zhihu com question 335451320 本文仅作为学术分享 如果侵权 会删文处理 计算机视觉毕业后找不到工作
  • 使用tf-slim的ResNet V1 152和ResNet V2 152预训练模型进行图像分类

    本文使用tf slim的ResNet V1 152和ResNet V2 152预训练模型进行图像分类 并研究slim网络的scope命名等 tf slim文档不太多 实现过程中多参考官网的源码 https github com tensor