深度学习——图像增强 小组代码

2023-11-18

TJU暑期的深度学习训练营,这是人脸识别运用图像增强后的一段代码~

import os, shutil
# ! unzip tjudataset.zip
base_dir = './tjudataset'

# read data
train_dir = os.path.join(base_dir,'train')
validation_dir = os.path.join(base_dir,'validation')
test_dir = os.path.join(base_dir,'test')

from keras import layers
from keras import models
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
model = models.Sequential()
model.add(layers.Conv2D(64, (2, 2), activation='relu',
                        input_shape=(210, 210, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (2, 2), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(256, (2, 2), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(512, (2, 2), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.3))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(61, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(lr=1e-4),
              metrics=['acc'])
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 209, 209, 64)      832       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 104, 104, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 103, 103, 128)     32896     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 51, 51, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 50, 50, 256)       131328    
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 25, 25, 256)       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 24, 24, 512)       524800    
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 12, 12, 512)       0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 73728)             0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 73728)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               37749248  
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 61)                31293     
=================================================================
Total params: 38,470,397
Trainable params: 38,470,397
Non-trainable params: 0
_________________________________________________________________
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=10,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

# Note that the validation data should not be augmented!
validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        # This is the target directory
        train_dir,
        # All images will be resized to 150x150
        target_size=(210, 210),
        batch_size=61,
        # Since we use binary_crossentropy loss, we need binary labels
        class_mode='categorical')

validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=(210, 210),
        batch_size=61,
        class_mode='categorical')
Found 549 images belonging to 61 classes.
Found 61 images belonging to 61 classes.
from keras.callbacks import ModelCheckpoint   
from matplotlib import pyplot as plt
import numpy as np
checkpointer = ModelCheckpoint(filepath='TJUFACE.augmentation.model.weights.best.hdf5', verbose=1, 
                               save_best_only=True)
before = 0
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(210, 210),
        batch_size=61,
        class_mode='categorical')
xx = []
yy = []

for i in range(200):
    xx += [i]
    print('The ',i+1,'times:')
    history = model.fit_generator(
      train_generator,
      steps_per_epoch=9,
      epochs=1, 
      validation_data=validation_generator,
      validation_steps=1, 
      callbacks=[checkpointer], 
      verbose=1)
#     model.load_weights('TJUFACE.augmentation.model.weights.best.hdf5')
    test_loss, test_acc = model.evaluate_generator(test_generator, steps=1)
    yy += [test_acc]
    if test_acc > before:
        print('-------------------------------------------------------------------------')
        print('epochs = ',i+1)
        print('Test_acc:', test_acc)
        print('-------------------------------------------------------------------------')
        before = test_acc
print()
print('The highest test_acc :',before)
Found 61 images belonging to 61 classes.
The  1 times:
Epoch 1/1
9/9 [==============================] - 14s 2s/step - loss: 4.1443 - acc: 0.0109 - val_loss: 4.0953 - val_acc: 0.0984

Epoch 00001: val_loss improved from inf to 4.09529, saving model to TJUFACE.augmentation.model.weights.best.hdf5
-------------------------------------------------------------------------
epochs =  1
Test_acc: 0.09836065769195557
-------------------------------------------------------------------------
The  2 times:
Epoch 1/1
9/9 [==============================] - 11s 1s/step - loss: 4.1036 - acc: 0.0328 - val_loss: 4.0612 - val_acc: 0.0656

Epoch 00001: val_loss improved from 4.09529 to 4.06118, saving model to TJUFACE.augmentation.model.weights.best.hdf5
The  3 times:
Epoch 1/1
9/9 [==============================] - 11s 1s/step - loss: 4.0701 - acc: 0.0510 - val_loss: 3.9763 - val_acc: 0.0656
……
……
……
Epoch 00001: val_loss did not improve from 0.33823
The  199 times:
Epoch 1/1
9/9 [==============================] - 11s 1s/step - loss: 0.4557 - acc: 0.8743 - val_loss: 0.8803 - val_acc: 0.8361

Epoch 00001: val_loss did not improve from 0.33823
The  200 times:
Epoch 1/1
9/9 [==============================] - 11s 1s/step - loss: 0.5764 - acc: 0.8452 - val_loss: 0.4710 - val_acc: 0.8852

Epoch 00001: val_loss did not improve from 0.33823

The highest test_acc : 0.9508196711540222
plt.figure(figsize=(10,5),dpi=200)
plt.title('Test_acc')
plt.xlabel('Epochs')
plt.xticks(np.arange(0,205,10))
plt.ylabel('Test_acc')
plt.yticks(np.arange(0,1,0.1))
def smooth_curve(points, factor=0.8):
  smoothed_points = []
  for point in points:
    if smoothed_points:
      previous = smoothed_points[-1]
      smoothed_points.append(previous * factor + point * (1 - factor))
    else:
      smoothed_points.append(point)
  return smoothed_points
plt.plot(xx,smooth_curve(yy),label='Final',color='m',marker=',',linestyle='-')# 16进制颜色码
plt.legend()
plt.show()

在这里插入图片描述

model.load_weights('TJUFACE.augmentation.model.weights.best.hdf5')
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(210, 210),
        batch_size=61,
        class_mode='categorical')

test_loss, test_acc = model.evaluate_generator(test_generator, steps=1)
print('test acc:', test_acc)
Found 61 images belonging to 61 classes.
test acc: 0.9016393423080444
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习——图像增强 小组代码 的相关文章

  • 如果两点之间的距离低于某个阈值,则从列表中删除点

    我有一个点列表 只有当它们之间的距离大于某个阈值时 我才想保留列表中的点 因此 从第一个点开始 如果第一个点和第二个点之间的距离小于阈值 那么我将删除第二个点 然后计算第一个点和第三个点之间的距离 如果该距离小于阈值 则比较第一点和第四点
  • 如何手动计算分类交叉熵?

    当我手动计算二元交叉熵时 我应用 sigmoid 来获取概率 然后使用交叉熵公式并平均结果 logits tf constant 1 1 0 1 2 labels tf constant 0 0 1 1 1 probs tf nn sigm
  • Django 的内联管理:一个“预填充”字段

    我正在开发我的第一个 Django 项目 我希望用户能够在管理中创建自定义表单 并向其中添加字段当他或她需要它们时 为此 我在我的项目中添加了一个可重用的应用程序 可在 github 上找到 https github com stephen
  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 将html数据解析成python列表进行操作

    我正在尝试读取 html 网站并提取其数据 例如 我想查看公司过去 5 年的 EPS 每股收益 基本上 我可以读入它 并且可以使用 BeautifulSoup 或 html2text 创建一个巨大的文本块 然后我想搜索该文件 我一直在使用
  • 跟踪 pypi 依赖项 - 谁在使用我的包

    无论如何 是否可以通过 pip 或 PyPi 来识别哪些项目 在 Pypi 上发布 可能正在使用我的包 也在 PyPi 上发布 我想确定每个包的用户群以及可能尝试积极与他们互动 预先感谢您的任何答案 即使我想做的事情是不可能的 这实际上是不
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 独立滚动矩阵的行

    我有一个矩阵 准确地说 是 2d numpy ndarray A np array 4 0 0 1 2 3 0 0 5 我想滚动每一行A根据另一个数组中的滚动值独立地 r np array 2 0 1 也就是说 我想这样做 print np
  • 使用Python请求登录Google帐户

    在多个登录页面上 需要谷歌登录才能继续 我想用requestspython 中的库以便让我自己登录 通常这很容易使用requests库 但是我无法让它工作 我不确定这是否是由于 Google 做出的一些限制 也许我需要使用他们的 API 或
  • 在Python中连接反斜杠

    我是 python 新手 所以如果这听起来很简单 请原谅我 我想加入一些变量来生成一条路径 像这样 AAAABBBBCCCC 2 2014 04 2014 04 01 csv Id TypeOfMachine year month year
  • Python,将函数的输出重定向到文件中

    我正在尝试将函数的输出存储到Python中的文件中 我想做的是这样的 def test print This is a Test file open Log a file write test file close 但是当我这样做时 我收到
  • 如何使用 pybrain 黑盒优化训练神经网络来处理监督数据集?

    我玩了一下 pybrain 了解如何生成具有自定义架构的神经网络 并使用反向传播算法将它们训练为监督数据集 然而 我对优化算法以及任务 学习代理和环境的概念感到困惑 例如 我将如何实现一个神经网络 例如 1 以使用 pybrain 遗传算法
  • pip 列出活动 virtualenv 中的全局包

    将 pip 从 1 4 x 升级到 1 5 后pip freeze输出我的全局安装 系统 软件包的列表 而不是我的 virtualenv 中安装的软件包的列表 我尝试再次降级到 1 4 但这并不能解决我的问题 这有点类似于这个问题 http
  • 如何使用原始 SQL 查询实现搜索功能

    我正在创建一个由 CS50 的网络系列指导的应用程序 这要求我仅使用原始 SQL 查询而不是 ORM 我正在尝试创建一个搜索功能 用户可以在其中查找存储在数据库中的书籍列表 我希望他们能够查询 书籍 表中的 ISBN 标题 作者列 目前 它
  • 如何在 Windows 命令行中使用参数运行 Python 脚本

    这是我的蟒蛇hello py script def hello a b print hello and that s your sum sum a b print sum import sys if name main hello sys
  • 根据列 value_counts 过滤数据框(pandas)

    我是第一次尝试熊猫 我有一个包含两列的数据框 user id and string 每个 user id 可能有多个字符串 因此会多次出现在数据帧中 我想从中导出另一个数据框 一个只有那些user ids列出至少有 2 个或更多string
  • 使用for循环时如何获取前一个元素? [复制]

    这个问题在这里已经有答案了 可能的重复 Python 循环内的上一个和下一个值 https stackoverflow com questions 1011938 python previous and next values inside
  • Django-tables2 列总计

    我正在尝试使用此总结列中的所有值文档 https github com bradleyayers django tables2 blob master docs pages column headers and footers rst 但页
  • 使用 z = f(x, y) 形式的 B 样条方法来拟合 z = f(x)

    作为一个潜在的解决方案这个问题 https stackoverflow com questions 76476327 how to avoid creating many binary switching variables in gekk

随机推荐

  • 21_OpenCV复制矩阵

    本文是关于矩阵复制的相关函数 目录 1 cv repeat 根据需要重复多次复制 2 实现矩阵的转置操作 cv transpose 1 cv repeat 根据需要重复多次复制 函数原型 void cv repeat cv InputArr
  • java.net.SocketTimeoutException: Read timed out问题排查

    欢迎关注博主微信订阅号 问题日志 java sql SQLException I O Error Read timed out at net sourceforge jtds jdbc TdsCore executeSQL TdsCore
  • Windows安装Apache Maven 3.5.4

    一 安装前的准备 1 1 官网下载Apache Maven Maven 3 6 3 此时最新版 的下载地址 https maven apache org download cgi Maven其他版本的下载地址 https archive a
  • 服务器启动显示按f1f2f10,电脑开机提示按f1f2f5 电脑开机要按F1F2F5

    电脑开机要求按F1 F2 F3或F5 有朋友跟我反应说他的XP系统 开机的时候要手动按F1才可以进WIN程序 那怎么改成默认的呢 可以尝试下以下方法 方法一 开启计算机或重新启动计算机后 及时按下 Del 键进入BIOS的设置界面 随便点击
  • Android:JNI与NDK(二)交叉编译与动态库,静态库

    本篇目录 一 前言 本篇主要以window开发环境为背景介绍一下NDK开发中需要掌握的交叉编译等基础知识 选window系统主要是照顾大多数读者 mac linux操作系统基本是同样适用的 交叉编译就是在A平台编译出可以在B平台执行的文件
  • J-Link仿真器与JTAG和SWD下载与接线

    目录 1 JTAG 1 1JTAG今天被用来主要的三大功能 1 2JTAG引脚 1 3可选引脚 2 SWD 2 1 SWD引脚 2 2 可选择引脚 2 3 JTag和SWD模式引脚定义 3 J Link仿真器 4 IAR与MDK配置两种下载
  • lol服务器位置峡谷之巅,lol峡谷之巅

    英雄联盟峡谷之巅第六赛季的奖励正式的公布了 这次只要排位赛胜场最多的2000名玩家就可以领取到奥术师佐伊至臻的皮肤 很多玩家还不清楚在哪领取峡谷之巅第六赛季的奖励 下面就来为大家分享一下地址 英雄联盟的官方在7月6日的下午5点发布了最新的峡
  • QT入门Buttons之QCheckBox

    目录 一 界面布局介绍 1 布局器中的位置及使用 2 常用属性 二 属性功能介绍 1 常用信号 2 测试信号stateChanged int 3 组合框效果 三 Demo展示 此文为作者原创 转载标明出处 一 界面布局介绍 1 布局器中的位
  • 从一个对象数组中的某一个属性组成新数组,然后比较大小

    需求 从一个对象数组中的某一个属性组成新数组 然后比较大小 示例数组 原始数组 expmArr name zhangsan age 18 name lisi age 20 name wangwu age 17 name zhaoliu ag
  • 编码与调制

    一 信道 信道是信号的传输媒介 一般用来表示向某一个方向传送信息的介质 因此一条通信线路往往包含一条发送信道和一条接收信道 信道根据传输信号分为数字信道 传输数字信号 和模拟信道 传输模拟信号 根据传输介质可分为无线信道和有线信道 同时根据
  • Qt 图片适应QLabel控件大小(饱满缩放和按比例缩放)

    直接上代码 QImage Image Image load d test jpg QPixmap pixmap QPixmap fromImage Image int with ui gt labPic gt width int heigh
  • 【计算机毕业设计】基于微信小程序的流浪动物救助系统 动物领养系统

    毕设帮助 源码交流 技术解答 见文末 一 前言 目前对流浪动物的救助采用的方式非常有限 一般是通过微信群 论坛 贴吧等平台发布流浪动物信息 由其它用户参与救助 这种方式由于没有监控渠道 造成有很多骗子的出现 而且这种方式的宣传力度也不够 经
  • 服务器管理口IP及账号密码(知识汇总)

    HP管理口 ILO 默认用户 密码 Administrator password HP以前管理口登陆MP卡 通过网线连接MP卡的RJ 45口 通过telnet方式登录 默认用户 密码 Admin Admin DELL服务器管理口 idac
  • 生产级logback-spring.xml配置明细

  • win32平台中的程序转换为wince中的一些错误 . 未能为“VCCLCompilerTool”工具生成命令行

    转载自 http blog csdn net shirui1125 article details 6095774 gt ToolBox error PRJ0004 未能为 VCCLCompilerTool 工具生成命令行 从原有的平台复制
  • 第一个nodejs应用

    应用这个词很火 哪里都在用 这里的nodejs应用其实是一个站点 准确的说是运行在本地的一个小小的Http站点 但是nodejs开发主要还是集中在少数的几个核心功能上 而不是那种动辄几千几万个文件 支撑多少并发多少功能的这种大型站点 所以n
  • jmeter接口关联-跨线程和正则表达式提取headers信息(视频详解)

    首先 看下常见的jmeter工作中的3个问题 1 如何提取响应头里面的cookie 2 参数md5加密后 再请求接口 3 多个线程组之间参数如何关联 技术知识 jmeter 跨线程关联 1 提取器 正则表达式 2 md5加密函数 3 Bea
  • 量化分析小函数——上穿函数

    量化分析小函数 上穿函数 上穿函数用于判断上穿信号的有无 输入为两条信号 obj和ref 两者数据类型为python列表 主要判断obj是否上穿ref 1 参考代码 import talib as tl import pandas as p
  • 短文简单理解遗传算法和代码审计应用思路

    短文简单理解遗传算法和代码审计应用思路 如何理解遗传算法 假设小明爷爷DNA之中带有A字段 小明爸也有 小明也有 说明A字段会遗传 如果A是存在危险函数 这就是遗传 同样的在代码之中多数存在包含关系 也称为调用 所以危险函数是可以被 遗传
  • 深度学习——图像增强 小组代码

    TJU暑期的深度学习训练营 这是人脸识别运用图像增强后的一段代码 import os shutil unzip tjudataset zip base dir tjudataset read data train dir os path j