猫狗数据集

2023-11-09

import numpy as np
import pickle
import cv2
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

#mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train_data = {b'data': [], b'labels': []}
with open("D:/TensorFlow_gpu/animal.pickle", mode='rb') as file:
data = pickle.load(file, encoding='bytes')
train_data[b'data'] += list(data['train_images'])
train_data[b'labels'] += list(data['train_label'])

train_epochs = 802 # 训练轮数
batch_size = 40 # 随机出去数据大小
display_step = 10 # 显示训练结果的间隔
learning_rate = 0.000001 # 学习效率
drop_prob = 0.2 # 正则化,丢弃比例
fch_nodes = 256 # 全连接隐藏层神经元的个数

def weight_init(shape):
weights = tf.truncated_normal(shape, stddev=0.1, dtype=tf.float32)#符合正太分布mean=0
#weights = tf.truncated_normal(shape, mean=0.01, stddev=0.1, dtype=tf.float32)
return tf.Variable(weights)


# 偏置的初始化
def biases_init(shape):
biases = tf.random_normal(shape, dtype=tf.float32)
# biases = tf.random_normal(shape, mean=-0.01, stddev=0.1, dtype=tf.float32)
return tf.Variable(biases)


# 随机选取mini_batch
def get_random_batchdata(n_samples, batchsize):
start_index = np.random.randint(0, n_samples - batchsize)
return (start_index, start_index + batchsize)


def xavier_init(layer1, layer2, constant=1):
Min = -constant * np.sqrt(6.0 / (layer1 + layer2))
Max = constant * np.sqrt(6.0 / (layer1 + layer2))
return tf.Variable(tf.random_uniform((layer1, layer2), minval=Min, maxval=Max, dtype=tf.float32))


def conv2d(x, w):
return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


x = tf.placeholder(tf.float32, [None, 224,224,3])
y = tf.placeholder(tf.float32, [None, 2])
# 把灰度图像一维向量,转换为28x28二维结构
x_image = x

w_conv1 = weight_init([3, 3, 3, 96]) # 3*3,深度为3,96
b_conv1 = biases_init([96])
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) # 输出张量的尺寸:112
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_init([3, 3, 96, 96])
b_conv2 = biases_init([96])
h_conv2 = tf.nn.tanh(conv2d(h_pool1, W_conv2) + b_conv2)#输出是56
h_pool2 = max_pool_2x2(h_conv2)#池化后输出16*16*96
#2-1
W_conv3 = weight_init([3, 3, 96, 128])
b_conv3 = biases_init([128])
h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)#输出28
h_pool3 = max_pool_2x2(h_conv3)#池化后输出16*16*96
#第2层卷积2-2

W_conv4 = weight_init([3, 3, 128, 128])
b_conv4 = biases_init([128])
h_conv4 = tf.nn.tanh(conv2d(h_pool3, W_conv4) + b_conv4)#14
h_pool4 = max_pool_2x2(h_conv4)#池化输出8*8*128
#3-1
W_conv5 = weight_init([3, 3, 128, 256])
b_conv5 = biases_init([256])
h_conv5 = tf.nn.relu(conv2d(h_pool4, W_conv5) + b_conv5)#7*7*256
h_pool5 = max_pool_2x2(h_conv5)#

h_pool5_flat = tf.reshape(h_pool5, [-1, 7 * 7 * 256])

w_fc1 = xavier_init(7 * 7 * 256, fch_nodes)
b_fc1 = biases_init([fch_nodes])
h_fc1 = tf.nn.relu(tf.matmul(h_pool5_flat, w_fc1) + b_fc1)

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=drop_prob)

# 隐藏层与输出层权重初始化
w_fc2 = xavier_init(fch_nodes, 2)
b_fc2 = biases_init([2])

# 未激活的输出
y_ = tf.add(tf.matmul(h_fc1, w_fc2), b_fc2)
#y_ = tf.add(tf.matmul(h_fc1_drop, w_fc2), b_fc2)


# 激活后的输出
y_out = tf.nn.softmax(y_)
#y_out = tf.nn.sigmoid(y_)

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_out), reduction_indices=[1]))
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
#optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)

# 准确率
# 每个样本的预测结果是一个(1,10)的vector
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_out, 1))
# tf.cast把bool值转换为浮点数
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init = tf.global_variables_initializer()
#mnist = input_data.read_data_sets('MNIST/mnist', one_hot=True)
n_samples = int(1800)
total_batches = int(n_samples / batch_size)

#x_train = np.array(train_data[b'data']) / 255
x_train = np.array(train_data[b'data'])
y_train = np.array(pd.get_dummies(train_data[b'labels']))
#x_test = test_data[b'data'] / 255

转载于:https://www.cnblogs.com/TheKat/p/11115554.html

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

猫狗数据集 的相关文章

随机推荐

  • RabbitMQ 消息有效期问题

    目录 一 默认情况 二 TTL Time To Live I TTL 的简介 II 单条消息过期 III 队列消息过期 IV 特殊情况 三 死信队列以及死信交换机 I 死信交换机 II 死信队列 III 具体操作 一 默认情况 在默认情况下
  • html 模板

    模板王 10000 免费网页模板 网站模板下载大全 mobanwang com http www mobanwang com
  • IEEE Transactions模板中参考文献作者缩写、期刊名缩写

    IEEE Transactions模板中参考文献作者缩写 期刊名缩写 本文章记录如何在IEEE Transactions的模板中 解决参考文献的作者缩写 期刊名字缩写的问题 目录 IEEE Transactions模板中参考文献作者缩写 期
  • python爬虫一:爬虫简介

    1 什么是爬虫 络爬 被称为 蜘蛛 络机器 就是模拟客户端发送 络请求 接收请求响应 种按照 定的规则 动地抓取互联 信息的程序 只要是浏览器能做的事情 原则上 爬 都能够做 可见即可爬 1 1爬虫有哪些用途 为其他数据提供数据源 像AI人
  • 数据挖掘的特点

    数据挖掘具有以下几个特点 1 基于大量数据 并非说小数据量上就不可以进行挖掘 实际上大多数数据挖掘的算法都可以在小数据量上运行并得到结果 但是 一方面过小的数据量完全可以通过人工分析来总结规律 另一方面来说 小数据量常常无法反映出真实世界中
  • kettle运行spoon.bat时找不到javaw文件

    我也遇到这问题了 分享一下解决方法吧以后没准还有人能用到 我机器的主要问题是环境变量JAVA HOME的值不对 应该写到jdk也就是C Program Files Java jdk1 7 0 25 并且 改完后要重启机器才行 这个很重要
  • DNS服务器的安装与配置

    一 DNS服务器的安装 步骤1 选择 开始 控制面板 添加或删除程序 添加 删除Windows组件 然后选取 网络服务 组件 再单击详细信息按钮 步骤2 选取 域名系统 DNS 组件后单击 确定 按钮 步骤3 回到前一个画面后 单击 下一步
  • vscode远程开发及公钥配置(告别密码登录)

    文章目录 vscode远程开发及公钥配置 简介 关于远程开发官网简介 关于SSH简介 环境 插件安装 配置服务器 找到配置文件 修改配置文件 连接服务器 配置密钥 简介 密钥生成 服务器上安装公钥 查看或配置打开密钥登录功能 服务器私钥复制
  • SSL/TLS 双向认证(一) -- SSL/TLS 工作原理

    本文部分参考 https www wosign com faq faq2016 0309 03 htm https www wosign com faq faq2016 0309 04 htm http blog csdn net hher
  • 四川计算机专业高职高考,四川职高计算机专业分数线

    类似问题答案 2016年贵州大学计算机类专业在四川录取分数线 学校 地 区 专业 年份 批次 类型 分数 贵州大学 四川 计算机类 2016 一批 理科 597 学校 地 区 专业 年份 批次 类型 分数 贵州大学 四川 计算机类 2016
  • 【华为OD统一考试B卷

    在线OJ 已购买本专栏用户 请私信博主开通账号 在线刷题 运行出现 Runtime Error 0Aborted 请忽略 华为OD统一考试A卷 B卷 新题库说明 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一
  • 分布式数据库核心原理 Zookeeper+Mysql

    原文 作者 1菩提行者1 笔者一直做java开发 由于技术演进做过大型微服务项目 微服务即将一个大的服务拆分成一个一个小的微服务 每个微服务自成生态 而在落地过程中紧紧只是应用层拆分 数据层往往用同一个库 有点形变神不变 当然将微服务与其对
  • JavaScript如何截取指定位置的字符串

    我们在日常开发中 经常需要对字符串进行删除截取增加的操作 我们这次说一下使用JavaScript截取指定位置的字符串 一 使用slice 截取 slice 方法可以通过指定的开始和结束位置 提取字符串的某个部分 并以新的字符串返回被提取的部
  • MIPI介绍(CSI DSI接口)

    MIPI介绍 CSI DSI接口 MIPI介绍 CSI DSI接口 视频接口 2 MIPI Solution mipi接口 缘来是你远去是我的博客 CSDN博客 MIPI LVDS RGB HDMI等接口对比 mipi和lvds区别 芒果5
  • socket编程

    socket 可以看做用户进程与内核网络协议栈的编程接口 可以用于本机进程间 网络上不同主机进程间的通信 对等通信 是全双工的 socket 异构系统 所以需要统一字节序统一后的字节序为大端字节序 x86为小端字节序 字节序转换函数 可以看
  • vsCode插件安装之汉化和浏览器打开

    一 汉化的方法 点击最左面第五个图标 在搜索框里面输入Chinese 点击如图第一个内容 点击Install 安装 安装后 重启软件即可 二 浏览器打开html 文件方法 在安装插件窗口搜索Browser 点击如图内容 点击install安
  • SpringBoot(十)SpringBoot自定义starter

    一个月的时间 转眼已经到了我的SpringBoot系列的第十篇文章 还记得我的第二篇文章SpringBoot 二 starter介绍 springboot的starter heart荼毒的博客 CSDN博客 曾经介绍过starter sta
  • mmdetection训练自己的VOC数据集 label=self.cat2label 报错解决方案

    废话不多说 直接上报错的图 看了GitHub上的大佬的回答 报错的原因是self cat2label值不对 所以根据大佬的建议 我print了self cat2label值 发现果然不对 类还是VOC数据集的类 而不是我自己的类 我的类是
  • ARM下高效C编程

    通过一定的风格来编写 C 程序 可以帮助 C 编译器生成执行速度更快的 ARM 代码 下面就是一些与性能相关的关键点 1 对局部变量 函数参数和返回值要使用 signed 和 unsigned int 类型 这样可以避免类型转换 而且可高效
  • 猫狗数据集

    import numpy as npimport pickleimport cv2import pandas as pdimport tensorflow as tfimport matplotlib pyplot as plt mnist