tensorflow搭建自己的残差网络(ResNet)

2023-11-10

废话不说,直接上代码:

首先

pip install tflearn

训练代码

# -*- coding: utf-8 -*-  

from __future__ import division, print_function, absolute_import  

import tflearn  

# Residual blocks  
# 32 layers: n=5, 56 layers: n=9, 110 layers: n=18  
n = 5  

#numClass看你是几分类
numClass = 10

#这里需要用户自己得到(X, Y), (validationX, validationY)
(X, Y), (validationX, validationY)
Y = tflearn.data_utils.to_categorical(Y, numClass )  
testY = tflearn.data_utils.to_categorical(testY, numClass )  

# Real-time data preprocessing  
img_prep = tflearn.ImagePreprocessing()  
img_prep.add_featurewise_zero_center(per_channel=True)  

# Real-time data augmentation  
img_aug = tflearn.ImageAugmentation()  
img_aug.add_random_flip_leftright()  
img_aug.add_random_crop([256, 256], padding=4)  

# Building Residual Network  
net = tflearn.input_data(shape=[None, 256, 256, 3],  
                         data_preprocessing=img_prep,  
                         data_augmentation=img_aug)  
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)  
net = tflearn.residual_block(net, n, 16)  
net = tflearn.residual_block(net, 1, 32, downsample=True)  
net = tflearn.residual_block(net, n-1, 32)  
net = tflearn.residual_block(net, 1, 64, downsample=True)  
net = tflearn.residual_block(net, n-1, 64)  
net = tflearn.batch_normalization(net)  
net = tflearn.activation(net, 'relu')  
net = tflearn.global_avg_pool(net)  
# Regression  
net = tflearn.fully_connected(net, numClass, activation='softmax')  
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)  
net = tflearn.regression(net, optimizer=mom,  
                         loss='categorical_crossentropy')  
# Training  
model = tflearn.DNN(net, checkpoint_path='model_resnet_mymodel',  
                    max_checkpoints=10, tensorboard_verbose=0,  
                    clip_gradients=0.)  

model.fit(X, Y, n_epoch=200, validation_set=(validationX, validationY),  
          snapshot_epoch=False, snapshot_step=500,  
          show_metric=True, batch_size=128, shuffle=True,  
          run_id='resnet_mymodel')  

(X, Y), (validationX, validationY)分别表示训练和验证的数据和标签,具体代码需要自己实现。numClass表示是几分类。

测试代码

from __future__ import division, print_function, absolute_import  

import tflearn  
import numpy as np
from PIL import Image 
import os


def buildModel():
    numClass = 10
    # Residual blocks  
    # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18  
    n = 5 

# Real-time data preprocessing  
    img_prep = tflearn.ImagePreprocessing()  
    img_prep.add_featurewise_zero_center(per_channel=True)  

    # Real-time data augmentation  
    img_aug = tflearn.ImageAugmentation()  
    img_aug.add_random_flip_leftright()  
    img_aug.add_random_crop([256, 256], padding=4)  


    # Building Residual Network  
    net = tflearn.input_data(shape=[None, 256, 256, 3],  
                         data_preprocessing=img_prep,  
                         data_augmentation=img_aug)  
    net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)  
    net = tflearn.residual_block(net, n, 16)  
    net = tflearn.residual_block(net, 1, 32, downsample=True)  
    net = tflearn.residual_block(net, n-1, 32)  
    net = tflearn.residual_block(net, 1, 64, downsample=True)  
    net = tflearn.residual_block(net, n-1, 64)  
    net = tflearn.batch_normalization(net)  
    net = tflearn.activation(net, 'relu')  
    net = tflearn.global_avg_pool(net)  
    # Regression  
    net = tflearn.fully_connected(net, numClass, activation='softmax')  
    mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)  
    net = tflearn.regression(net, optimizer=mom,  
                         loss='categorical_crossentropy')  
    # Training  

    model = tflearn.DNN(net,tensorboard_verbose=0, clip_gradients=0.) 
    #写入你保存model的路径
    model.load(model_file=YourPath, weights_only=False)
    return model
def predicMsk(picPath,model):

    # Data loading  
    test = []
    image = Image.open(picPath)
    image = image.resize([256, 256])
    image = np.array(image)
    test.append(image/255)
    test = np.array(test)

    a = model.predict(test)
    return a

这样就可以用残差网络处理自己的数据集了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow搭建自己的残差网络(ResNet) 的相关文章

随机推荐

  • (React入门)组件通信

    1 父组件向子组件通信 props 父组件直接通过props向子组件传递需要的信息 import React Component from react 子组件 class Child extends Component render con
  • 怎么搭建微服务架构?

    我要何时使用微服务架构 又如何将应用程序分解为微服务 分解后 要如何去搭建微服务架构 同时 在微服务架构中 因为会涉及到多个组件 那么这些组件又可以使用什么技术来实现呢 接下来的几个小节中 我们将对这些问题进行详细的讲解 微服务的拆分 对于
  • Jenkins

    Jenkins jenkins定义 jenkins功能 tomcat项目部署 jenkins安装 安装后的jenkins的配置 Jenkins jenkins定义 jenkins tomcat 部署 安装 jenkins 是一个开源软件项目
  • 栈的概念和基本使用

    栈 Stack 是一种线性序列结构 其操作仅限于逻辑上特定的一端 即新元素只能从栈的一端插入 也只能从栈的一端弹出 允许操作的一端叫做栈顶 禁止操作的一端叫做栈底 插入元素称为入栈 弹出元素称为出栈 栈中各个元素的操作次序遵循所谓的后进先出
  • 涅普2021训练营-MIsc(部分)

    bin txt 打开附件发现是一大串的二进制字符串 使用python直接转换成16进制 最好别用网页转换 字符串太长了 也不需要创建什么函数 直接进制转换 参考格式 hex int 字符串 2 将得到的16进制吗复制到文本 将0x删除后复制
  • git 服务端钩子做代码检查

    需求分析 在代码修改后可以对代码进行检查 比如代码规范检查 代码构建 单元测试等 我们需要禁止成员推送不符合规范的代码到服务端 Git 钩子能在特定的重要动作发生时触发自定义脚本 钩子分为客户端和服务器端两类 使用客服端钩子可以在commi
  • 字符游戏-智能蛇

    字符游戏 智能蛇 一 VT 100 终端标准 这里按照老师的课件要求 体验一下VT 100 输入输出功能以及清屏操作 代码直接复制课件中代码 这里就不再放一次了 直接给出运行效果 gcc sin demo c osin out lm sin
  • OpenAI开发系列(一):一文搞懂大模型、GPT、ChatGPT等AI概念

    全文共5000余字 预计阅读时间约10 20分钟 满满干货 建议收藏 本文目标 详细解释大型语言模型 LLM 和OpenAI的GPT系列的基本概念 一 什么是大模型 大型语言模型 也称大语言模型 大模型 Large Language Mod
  • 解决Centos7没有ens33

    进入centos7操作 ifconfig ens33 up systemctl stop NetworkManager systemctl disable NetworkManager ifup ens33 systemctl restar
  • explain查看索引使用

    CREATE TABLE test id int 11 NOT NULL name varchar 20 DEFAULT NULL dep id int 11 DEFAULT NULL age int 11 DEFAULT NULL tt
  • Qt中正确引用外部头文件和库文件的方法和注意点

    Qt中正确引入外部库文件的方法和注意点 一 什么报错是外部库导入错误导致的 二 解决外部库使用的方法 一 写入系统环境变量中的外部库调用 1 解释说明 2 使用演示 1 头文件 2 库文件 二 未写入系统环境变量中的外部库调用 1 解释说明
  • controller层

    前言 controller层代码主要流程都是 1 获取前端数据 运用request getParameter 数据名 2 创建user对象 用来传递参数 创建Service对象 用来使用Service服务的方法 3 调用Service的方法
  • C++11内存对齐之std::aligned_storage与alignas与alignof

    1 std aligned storage 插播一下POD的含义 Plain old data structure 缩写为POD 是C 语言的标准中定义的一类数据结构 POD适用于需要明确的数据底层操作的系统中 POD通常被用在系统的边界处
  • DateTime转换为时间戳

  • 记一次线性插值方法(Mathf.Lerp())的使用体会

    对Mathf Lerp 方法使用体会源于一次开发游戏对警报灯闪烁问题进行处理时 public static float Lerp float from float to float t 分析一下对线性插值函数的认识 就是在from与to之间
  • 看完这篇文章保你面试稳操胜券——小程序篇

    进大厂收藏这一系列就够了 全方位搜集总结 为大家归纳出这篇面试宝典 面试途中祝你一臂之力 共分为四个系列 本 篇 为 看 完 这 篇 文 章 保 你 面 试 稳 操 胜 券 第 四 篇
  • springboot mysql链接语句字段分析

    jdbc mysql localhost 3306 xxxx useUnicode true characterEncoding utf8 zeroDateTimeBehavior convertToNull useSSL true ser
  • 几个简单的system(const char* _Command)函数命令

    几个简单的system const char Command 函数命令 呼出终端 Windows键 r 然后输入cmd system const char Command 函数常用命令 如 system cls 1 shutdown常用命令
  • JS 实现全屏切换,移动端适用

    JS 实现全屏切换 移动端适用 直接看代码吧 简单 只是有些人不知道这个 api 我之前就不知道
  • tensorflow搭建自己的残差网络(ResNet)

    废话不说 直接上代码 首先 pip install tflearn 训练代码 coding utf 8 from future import division print function absolute import import tf