机器学习——从0开始构建自己的GAN网络

2023-10-29

目录

一 前言

二 生成式对抗网络GAN

三 GAN的训练思路

四 数据集——Chinese MNIST

五 代码——python

1.文件展示

2.代码(一) ——数据预处理

3.代码(二) ——生成器的构建

4.代码(三) ——判别器的构建

5.代码(四) ——图像的储存

6.代码(五) ——网络的训练

7.代码(六) ——网络参数的定义

8.完整代码

六 运行效果

总结


一 前言

本文仅作为经验分享以及学习记录,如有问题,可以在评论区和我讨论。

具体的理论知识暂且不讲,待有时间了我就会慢慢分享理论知识,目前就整点干货,直接上代码,怎么从零开始构建自己的GAN网络。

本项目Github地址

本人环境:

python:3.7        kares:2.3.1        tensorflow:1.14.0        opencv:4.5.5

二 生成式对抗网络GAN

生成对抗网络(GAN)有两个部分:生成网络G(Generator)和判别网络D(Discriminator)。
1生成网络G:用来生成图片的网络,它接收一个随机的噪声noise,通过这个噪声生成图片。
2判别网络D:用来判别图片是否真实的网络。它的输入是一张图片img,输出是img为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

生成网络G的目的是努力生成一个图片来骗过判别网络D,判别网络D的目的是努力鉴别出生成出来的图片是假的。两个网络在不断博弈中互相进步,达到理想状态:D(G(noise))=0.5(即判别网络D也不确定是到底是不是真实的)

三 GAN的训练思路

GAN的训练要同时训练两个网络,我们使用的方法是:单独交替迭代训练(即训练一个网络的时候,固定住一个网络,去训练另一个网络)

这样做的目的是防止其中一个网络比另一个网络强大太多,导致网络性能弱化。在整个训练过程中,两个网络不断变强,达到理想状态。

四 数据集——Chinese MNIST

我的数据集选的是Kaggle网站上的Chinese MNIST,下载地址

下载速度慢的可以参考我的另一篇博客——解决Kaggle网站下载数据集速度慢,不方便下载的可以联系我发给你压缩包。

数据集举例:

五 代码——python

如果需要替换成自己数据集,我会在每部分代码首部进行特别说明。

这里我们直接开始,直接上代码,通过代码,一方面有助于我梳理本次学习思路,二是我觉得这样更直接明了一些,毕竟动手才有趣。

1.文件展示

2.代码(一) ——数据预处理

这部分函数用来加载path路径下的文件,即我的数据集,也可以根据你们需求换成别的数据集。只需要更改自己的数据集文件夹即可。

 def load_data(self, path):
        print("loading images...")
        data = []
        labels = []
        imagePaths = sorted(list(paths.list_images(path)))
        random.seed(42)
        random.shuffle(imagePaths)
        for imagePath in imagePaths:
            image = cv2.imread(imagePath)
            image = cv2.resize(image, (self.img_rows, self.img_cols))
            image = img_to_array(image)
            data.append(image)

            label = str(imagePath.split(os.path.sep)[-2])
            labels.append(label)

        data = np.array(data, dtype="float") / 255.0
        return data

3.代码(二) ——生成器的构建

这部分代码用来构建生成网络G,不需要更改,尽管网络性能不是很好,但不是必须修改的。

# 构建生成器
    def build_generator(self):
        model = Sequential()        # 模型选用的是传统的线性模型
        model.add(Dense(256, input_dim=self.latent_dim))  # 全连接层
        model.add(LeakyReLU(alpha=0.2))  # 带泄露修正线性单元
        model.add(BatchNormalization(momentum=0.8))  # 批归一化
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))  # np.prod()计算所有乘积,输入
        model.add(Reshape(self.img_shape))  # reshape成图片的尺寸

        # model.summary()  # 日志

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

4.代码(三) ——判别器的构建

这部分代码用来构建判别网络D,不需要更改,尽管网络性能不是很好,但不是必须修改的。

# 构建判别器
    def build_discriminator(self):
        # 模型选用的是传统的线性模型,CNN中用的也是这个
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))  # 展平层
        model.add(Dense(512))  # 全连接层
        model.add(LeakyReLU(alpha=0.2))  # 带泄露修正线性单元
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        # model.summary()

        img = Input(shape=self.img_shape)  # 输入尺寸
        validity = model(img)

        return Model(img, validity)

5.代码(四) ——图像的储存

这部分代码用来储存生成网络不同epoch的输出,不必更改。

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        # 保存地址为:"images/"
        fig.savefig("images/%d.png" % epoch)
        plt.close()

6.代码(五) ——网络的训练

这部分代码用来训练网络,不必更改。

    def train(self, epochs, batch_size=128, sample_interval=50, file_path=None):

        # 加载数据
        X_train = self.load_data(file_path)
        # 标准化
        # X_train = np.expand_dims(X_train, axis=3)

        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            g_loss = self.combined.train_on_batch(noise, valid)

            if epoch % 200 == 0:
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # 图像的保存,每sample_interval次保存图片一次
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
            # 模型权重的保存,每2000个epoch,保存一次模型,保存地址为"weights/"
            if epoch % 2000 == 0:
                os.makedirs('weights', exist_ok=True)
                self.generator.save_weights("weights/gen_epoch%d.h5" % epoch)
                self.discriminator.save_weights("weights/dis_epoch%d.h5" % epoch)

7.代码(六) ——网络参数的定义

这部分代码定义了网络的一些参数,比如输入尺寸(我的数据集图片大小是[64,64,3]),优化器等等。

需要根据自己的数据集图片的大小,更改self.img_rows、self.img_cols、self.channels

    def __init__(self):
        # 图片尺寸 在这里更改!!!!
        self.img_rows = 64
        self.img_cols = 64
        self.channels = 3
        # 输入的图片尺寸
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Adam优化器
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        self.generator = self.build_generator()

        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        self.discriminator.trainable = False

        validity = self.discriminator(img)

        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

8.完整代码

需要根据自己需求,更改

epochs训练次数;batch_size每组的数量;sample_interval多少次输出一张图片;file_path数据集路径

完整代码我已上传到github中,代码地址

六 运行效果

迭代0次

迭代10000次

迭代30000次

 迭代50000次

 由于时间以及本人显卡配置的限制,只进行了50000次迭代,为了更好的效果可以增加迭代次数。

总结

至此本博客,从0开始搭建GAN网络就结束了,有什么问题欢迎和我讨论。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习——从0开始构建自己的GAN网络 的相关文章

  • 如何将base64字符串直接解码为二进制音频格式

    音频文件通过 API 发送给我们 该文件是 Base64 编码的 PCM 格式 我需要将其转换为 PCM 然后再转换为 WAV 进行处理 我能够使用以下代码解码 gt 保存到 pcm gt 从 pcm 读取 gt 保存为 wav decod
  • 切片稀疏(scipy)矩阵

    我将不胜感激任何帮助 以理解从 scipy sparse 包中切片 lil matrix A 时的以下行为 实际上 我想根据行和列的任意索引列表提取子矩阵 当我使用这两行代码时 x1 A list 1 x2 x1 list 2 一切都很好
  • Python 2.7 将比特币私钥转换为 WIF 私钥

    作为一名编码新手 我刚刚完成了教程 教程是这样的 https www youtube com watch v tX XokHf nI https www youtube com watch v tX XokHf nI 我想用 1 个易于阅读
  • Python在postgresql表中查找带有单引号符号的字符串

    我需要从 psql 表中查找包含多个单引号的字符串 我当前的解决方案是将单引号替换为双单引号 如下所示 sql query f SELECT exists SELECT 1 FROM table name WHERE my column m
  • 当单词以“|”分隔时如何读取文件(埃因霍温)?

    在Python中 我有一个文件 其中的单词由 例如 city state zipcode 我的文件阅读器无法区分单词 另外 我希望我的文件阅读器从第 2 行而不是第 1 行开始 如何让我的文件阅读器分隔单词 import os import
  • 根据开始列和结束列扩展数据框(速度)

    我有一个pandas DataFrame含有start and end列 加上几个附加列 我想将此数据框扩展为一个时间序列 从start值并结束于end值 但复制我的其他专栏 到目前为止 我想出了以下内容 import pandas as
  • 更改 python tkinter canvas 中的线坐标

    我画了一条线tkinter Canvas现在我想移动一端 这可能吗 例如和itemconfig import tkinter tk tkinter Tk canvas tkinter Canvas tk canvas pack line c
  • Apache Spark 中的高效字符串匹配

    我使用 OCR 工具从屏幕截图中提取文本 每个大约 1 5 句话 然而 当手动验证提取的文本时 我注意到时不时会出现一些错误 鉴于文本 你好 我真的很喜欢 Spark 我注意到 1 像 I 和 l 这样的字母被 替换 2 表情符号未被正确提
  • 动态 __init_subclass__ 方法的参数绑定

    我正在尝试让类装饰器工作 装饰器会添加一个 init subclass 方法到它所应用的类 但是 当该方法动态添加到类中时 第一个参数不会绑定到子类对象 为什么会发生这种情况 举个例子 这是可行的 下面的静态代码是我试图最终得到的示例 cl
  • `list()` 被认为是一个函数吗?

    list显然是内置类型 https docs python org 3 library stdtypes html list在Python中 我看到底下有一条评论this https stackoverflow com a 53645813
  • Python 声音(“铃声”)

    我想让一个 python 程序在完成任务时通过发出嘟嘟声来提醒我 目前 我使用import os然后使用命令行语音程序说 进程完成 我更愿意它是一个简单的 铃 我知道有一个函数可以用于Cocoa apps NSBeep 但我认为这与此没有太
  • 如何使用 Keras ImageDataGenerator 预测单个图像?

    我已经训练 CNN 对图像进行 3 类分类 在训练模型时 我使用 keras 的 ImageDataGenerator 类对图像应用预处理功能并重新缩放它 现在我的网络在测试集上训练得非常准确 但我不知道如何在单图像预测上应用预处理功能 如
  • 是否可以将 pd.Series 分配给无序 pd.DataFrame 中的列而不映射到索引(即不重新排序值)?

    在 Pandas 中创建或分配新列时 我发现了一些意外的行为 当我对 pd DataFrame 进行过滤或排序 从而混合索引 然后从 pd Series 创建新列时 Pandas 会重新排序该系列以映射到 DataFrame 索引 例如 d
  • 检测 IDLE 的存在/如何判断 __file__ 是否未设置

    我有一个脚本需要使用 file 所以我了解到 IDLE 没有设置这个 有没有办法从我的脚本中检测到 IDLE 的存在 if file not in globals file is not set 如果你想做一些特别的事情 file 未设置
  • 处理大文件的最快方法?

    我有多个 3 GB 制表符分隔文件 每个文件中有 2000 万行 所有行都必须独立处理 任何两行之间没有关系 我的问题是 什么会更快 逐行阅读 with open as infile for line in infile 将文件分块读入内存
  • 为什么 smtplib.SMTP().sendmail 不发送 DKIM 签名邮件

    我已经在服务器上设置了 postfix 以及 openDKIM 当我跑步时 echo Testing setup mail s Postfix test my email address 我收到电子邮件 邮件标题中有一个DKIM Signa
  • 如何循环遍历字典列表并打印特定键的值?

    我是 Python 新手 有一个问题 我知道这是一个非常简单的问题 运行Python 3 4 我有一个需要迭代并提取特定信息的列表 以下是列表 称为部分 的示例 已截断 数千个项目 state DEAD id phwl type name
  • python sklearn中的fit方法

    我问自己关于 sklearn 中拟合方法的各种问题 问题1 当我这样做时 from sklearn decomposition import TruncatedSVD model TruncatedSVD svd 1 model fit X
  • 如何获取所有mysql元组结果并转换为json

    我能够从表中获取单个数据 但是当我试图获取表上的所有数据时 我只得到一行 cnn execute sql rows cnn fetchall column t 0 for t in cnn description for row in ro
  • 使用 urllib 编码时保持 url 参数有序

    我正在尝试用 python 模拟 get 请求 我有一个参数字典 并使用 urllib urlencode 对它们进行 urlencode 我注意到虽然字典的形式是 k1 v1 k2 v2 k3 v3 urlencoding 后参数的顺序切

随机推荐

  • LeetCode--数组类算法:删除排序数组中的重复项 II

    题目 给定一个排序数组 你需要在原地删除重复出现的元素 使得每个元素最多出现两次 返回移除后数组的新长度 不要使用额外的数组空间 你必须在原地修改输入数组并在使用 O 1 额外空间的条件下完成 示例一 给定 nums 1 1 1 2 2 3
  • 梳理webpack

    一 入门 1 项目初始化 新建一个目录 初始化npm npm init 此时会需要填入一些项目的基本描述 webpack是运行在node环境中的 我们需要安装以下两个npm包 npm i D webpack webpack cli 生成no
  • 【mcuclub】扫码枪-(型号:M100(1D)-TTL)(型号:GM861S)

    一 实物图 型号 M100 1D TTL 只能扫描一维条形码 二 原理图 编号 名称 功能 1 VCC 电源正 2 GND 电源地 3 TXD 串口数据发送引脚 接单片机上的RX引脚 4 RXD 串口数据接收引脚 接单片机上的TX引脚 三
  • Unity 处理mono内存(堆内存)泄露问题

    先讲解一下mono特性 一个很重要的信息 mono内存从系统里面申请的内存不会返回给系统 mono内存不足的时候会预申请内存 内存大小不定有可能10m有可能5m 最近优化一个mono内存泄露问题 引起mono一直撑大多数都是内存泄露 要不就
  • ArrayBlockingQueue和LinkedBlockingQueue

    ArrayBlockingQueue ArrayBlockingQueue是一个用数组实现的有界阻塞队列 其是线程安全的 内部通过 互斥锁 保护竞争资源 此队列按照先进先出 FIFO 的原则对元素进行排序 队列的头部是在队列中存在时间最长的
  • el-tabs组件切换之前拦截函数异常踩坑记录

    背景 产品需求在离开当前tab之前要对页面填写信息进行校验 若没有任何改动则可以直接切换tab 若有改动 则需要在跳转之前进行拦截 提示用户 当前页面信息未保存 确定离开吗 确定或取消由用户选择 代码实现
  • 逆向工程核心原理——DLL注入——创建远程线程

    什么是DLL注入 dll注入是一种将Windows动态链接库注入到目标进程中的技术 具体的说 就是将dll文件加载到一个进程的虚拟地址空间中 对某个进程进行dll注入 也就意味着dll模块与该进程共用一个进程空间 则这个dll文件就有了操纵
  • 可变频率正弦信号发生器的FPGA实现(Quartus)

    一 说明 实现平台 Quartus17 1 MATLAB2021a和Modelsim SE 64 10 4 二 内容 1 产生一个完整周期的正弦波信号 并保存为 mif文件 2 设计一个ROM 将正弦波信号文件初始化如该ROM中 3 设计一
  • 内存分配---kmalloc

    kmalloc 内存分配引擎是一个功能强大的工具 下面我们来讲解一下这个函数 Kmalloc 函数分配内存时有几个特点 1 获取内存空间时不会对内存空间进行清零 也就是说 分配给它的区域仍然保持着原有的数据 2 它分配的区域在物理内存中也是
  • Ubuntu中火狐浏览器Firefox打不开网页

    浏览器地址栏输入 about config 搜索 general useragent override 无则新建 输入字符串 Mozilla 5 0 X11 Linux x86 64 AppleWebKit 537 36 KHTML lik
  • 2021-09-02防火墙和CDN、Ajax跨域

    欢迎大家一起来Hacking水友攻防实验室学习 渗透测试 代码审计 免杀逆向 实战分享 靶场靶机 求关注 CDN 内容分发网络 Content Delivery Network 简称CDN 是建立并覆盖在承载网之上 由分布在不同区域的边缘节
  • 如何查看mac系统是32位还是64位的操作系统

    一 点击工具栏左上角点击 苹果Logo 标志 关于本机 gt 更多信息 gt 系统报告 gt 左侧栏中 软件 二 打开终端 输入命令 uname a 回车 x86 64 表示系统为64位 i686 表示系统32位的 比如我的 三 在终端输入
  • js实现模糊搜索

    功能一 关键字搜索 总结 1 搜索出的结果 前台先要清空原有表格 tbody empty 2 后台返回的json格式字符串 js eval 专成对象var stus eval msg 在循环进行字符串拼接到表格上 tbody html st
  • Ubuntu上vsftpd安装与多用户目录配置

    vsftpd安装与多用户目录配置 文章配置使用Ubuntu进行配置 CentOS系统的配置也是大同小异 主要理解虚拟用户的加载方式和权限目录的配置 配置目标 在 home vsftpd 目录下有3个子目录分别为folder1 folder2
  • 二叉搜索树的建立和排序

    二叉搜索树的建立和排序 今天面了一家自研 有一道二叉搜索树的题目 但是自己做的不好 就是有几个学生和成绩 使用树来存储 左子树大于等于root 右节点小于root package org example public class Main
  • 《Apache MINA 2.0 用户指南》第二章:基础知识

    最近准备将Apache MINA 2 0 用户指南英文文档翻译给大家 但是我偶然一次百度 发现 Defonds 这位大牛已经翻译大部分文档 原文链接 http mina apache org mina project userguide c
  • LAN9252芯片控制资料

    一 整个ethercat项目开发流程 通过STM32相关学习板 理解EtherCAT协议栈和通信步骤 根据项目需求构建XML 该XML将会由TwinCAT2解析 将相关特STM32程序烧写 修改应用层协议的程序 STM32作为SPI主模式与
  • Faiss流程与原理分析

    1 Faiss简介 Faiss是Facebook AI团队开源的针对聚类和相似性搜索库 为稠密向量提供高效相似度搜索和聚类 支持十亿级别向量的搜索 是目前最为成熟的近似近邻搜索库 它包含多种搜索任意大小向量集 备注 向量集大小由RAM内存决
  • 华为机试题--坐标移动

    题目描述 开发一个坐标计算工具 A表示向左移动 D表示向右移动 W表示向上移动 S表示向下移动 从 0 0 点开始移动 从输入字符串里面读取一些坐标 并将最终输入结果输出到输出文件里面 输入 合法坐标为A 或者D或者W或者S 数字 两位以内
  • 机器学习——从0开始构建自己的GAN网络

    目录 一 前言 二 生成式对抗网络GAN 三 GAN的训练思路 四 数据集 Chinese MNIST 五 代码 python 1 文件展示 2 代码 一 数据预处理 3 代码 二 生成器的构建 4 代码 三 判别器的构建 5 代码 四 图