1.GAN生成mnist

2023-11-07

1.GAN:Generative Adversarial Network
2.生成器:随机生成一个一维的100个随机数(latent_dim)作为输入生成mnist图片

 def build_generator(self):
        model=Sequential()
        model.add(Dense(256,input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))     
        
        model.add(Dense(np.prod(self.img_shape),activation='tanh'))#全连接层,28*28*1个神经元
        model.add(Reshape(self.img_shape))#变成图片的形状
        
        noise=Input(shape=(self.latent_dim,))
        img=model(noise)#建立了从输入100维随机向量》》》》到28,28,1大小的图片》》》生成模型
        return Model(noise,img)

3.判别器:输入图片,输出一个一维判断结果(0或者1)

def build_discriminator(self):
        model=Sequential()
        model.add(Flatten(input_shape=self.img_shape))#将输入的图片扁平化,因为用的全是全连接层
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        
        model.add(Dense(1,activation='sigmoid'))#输出是一个维度,并用sigmoid映射到0到1
        img=Input(shape=self.img_shape)
        validity=model(img)#建立了从输入28,28,1图片》》》到输出一个维度的》》》》判别模型
        return Model(img,validity)

4.1构建判别网络:训练判别器

self.discriminator=self.build_discriminator()#此类的评判标准由此类的动作函数生成
self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])#对评判标准的结果进行验证准确度

4.2训练判别器

#训练判别网络
idx=np.random.randint(0,x_train.shape[0],batch_size)#从train训练集里面随机找出batc——size大小(这么多个)的索引值
imgs=x_train[idx]#取出1个batch大小的图片
noise=np.random.normal(0,1,(batch_size,self.latent_dim))#正态分布生成batch_size个100维向量作为输入
gen_imgs=self.generator.predict(noise)#用生成model的predict方法(model内部方法)将输入进行生成输出
d_loss_real=self.discriminator.train_on_batch(imgs,valid)#输入真实图片和标签全1》》到判别model,》》计算判别模型的loss
d_loss_fake=self.discriminator.train_on_batch(gen_imgs,fake)#输入假的图片和标签全0》》到判别model,》计算判别模型的loss d_loss=0.5*np.add(d_loss_real,d_loss_fake)#将两者损失结合作为总损失

5.1构建生成判别网络:训练生成器

self.generator=self.build_generator()#模型的生成者
gan_input=Input(shape=(self.latent_dim,))#定义此类模型的输入层形状   
img=self.generator(gan_input)#将输入送到生成者去生成图片
self.discriminator.trainable=False#此时只生成图片,不进行评判和两者改进
validity=self.discriminator(img)#评判生成的图片
self.combined=Model(gan_input,validity)#建立输入到评判结果的模型
self.combined.compile(loss='binary_crossentropy',optimizer=optimizer)#模型编译

5.2训练生成器

#训练生成网络
 noise=np.random.normal(0,1,(batch_size,self.latent_dim))
 g_loss=self.combined.train_on_batch(noise,valid)#如果输入噪音的输出是1,则正确,输入噪音输出是0,则生成网络需要改进,所以loss累加

6.全部代码

class GAN():
    def __init__(self):
        self.img_rows=28#定义图片属性
        self.img_cols=28
        self.channels=1
        self.img_shape=(self.img_rows,self.img_cols,self.channels)#shape是元组属性值
        self.latent_dim=100#随机生成input的形状属性值
        optimizer=Adam(0.0002,0.5)#确定模型优化器参数属性值
        self.discriminator=self.build_discriminator()#此类的评判标准由此类的动作函数生成
        self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])#对评判标准的结果进行验证准确度
        self.generator=self.build_generator()#模型的生成者
        gan_input=Input(shape=(self.latent_dim,))#定义此类模型的输入层形状
        img=self.generator(gan_input)#将输入送到生成者去生成图片
        self.discriminator.trainable=False#此时只生成图片,不进行评判和两者改进
        validity=self.discriminator(img)#评判生成的图片
        self.combined=Model(gan_input,validity)#建立输入到评判结果的模型
        self.combined.compile(loss='binary_crossentropy',optimizer=optimizer)#模型编译
    def build_generator(self):
        model=Sequential()
        model.add(Dense(256,input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))     
        
        model.add(Dense(np.prod(self.img_shape),activation='tanh'))#全连接层,28*28*1个神经元
        model.add(Reshape(self.img_shape))#变成图片的形状
        
        noise=Input(shape=(self.latent_dim,))
        img=model(noise)#建立了从输入100维随机向量》》》》到28,28,1大小的图片》》》生成模型
        return Model(noise,img)
    def build_discriminator(self):
        model=Sequential()
        model.add(Flatten(input_shape=self.img_shape))#将输入的图片扁平化,因为用的全是全连接层
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        
        model.add(Dense(1,activation='sigmoid'))#输出是一个维度,并用sigmoid映射到0到1
        img=Input(shape=self.img_shape)
        validity=model(img)#建立了从输入28,28,1图片》》》到输出一个维度的》》》》判别模型
        return Model(img,validity)
    def train(self,epochs,batch_size=128,sample_interval=50):
        (x_train,y_train),(x_test,y_test)=mnist.load_data()
        x_train=x_train/127.5-1#将图片像素值映射到-1到1
        x_train=np.expand_dims(x_train,axis=3)#输入时2维tensor,映射到三维,加了第三维1,表示1个通道
        valid=np.ones((batch_size,1))#batch——size大小的全是1的标签
        fake=np.zeros((batch_size,1))#batch_size大小全是0的标签
        for epoch in range(epochs):
            #训练判别网络
            idx=np.random.randint(0,x_train.shape[0],batch_size)#从train训练集里面随机找出batc——size大小(这么多个)的索引值
            imgs=x_train[idx]#取出1个batch大小的图片
            noise=np.random.normal(0,1,(batch_size,self.latent_dim))#正态分布生成batch_size个100维向量作为输入
            gen_imgs=self.generator.predict(noise)#用生成model的predict方法(model内部方法)将输入进行生成输出
            d_loss_real=self.discriminator.train_on_batch(imgs,valid)#输入真实图片和标签全1》》到判别model,》》计算判别模型的loss
            d_loss_fake=self.discriminator.train_on_batch(gen_imgs,fake)#输入假的图片和标签全0》》到判别model,》计算判别模型的loss
            d_loss=0.5*np.add(d_loss_real,d_loss_fake)#将两者损失结合作为总损失
            #训练生成网络
            noise=np.random.normal(0,1,(batch_size,self.latent_dim))
            g_loss=self.combined.train_on_batch(noise,valid)#如果输入噪音的输出是1,则正确,输入噪音输出是0,则生成网络需要改进,所以loss累加
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))#说明生成网络的性能
            if epoch %sample_interval==0:
                self.sample_images(epoch)
    def sample_images(self,epoch):#画出25张图片
        r,c=5,5
        noise=np.random.normal(0,1,(r*c,self.latent_dim))
        gen_imgs=self.generator.predict(noise)
        gen_imgs=0.5*gen_imgs+0.5
        fig,axs=plt.subplots(r,c)
        cnt=0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:,0],cmap='gray')
                axs[i,j].axis('off')
                cnt+=1
        fig.savefig('images/%d.png'%epoch)
        plt.close()
if __name__=='__main__':
    if not os.path.exists('./images'):
        os.makedirs('./images')
    gan=GAN()
    gan.train(epochs=600,batch_size=256,sample_interval=50)
            

参考:https://blog.csdn.net/weixin_44791964/article/details/103729797

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

1.GAN生成mnist 的相关文章

  • 活动回顾|解锁 AIGC 密码,探寻企业发展新商机

    5月24日 Google Cloud 与 Cloud Ace 联合主办的线下活动顺利落下帷幕 本次活动 有近 40 位企业精英到场支持 三位 Google Cloud 演讲嘉宾就本次活动主题 为大家带来了比较深度的演讲内容 干货满满 以下的
  • 图文详解GPT-4最强对手Claude2的使用方法

    大家好 我是herosunly 985院校硕士毕业 现担任算法研究员一职 热衷于机器学习算法研究与应用 曾获得阿里云天池比赛第一名 CCF比赛第二名 科大讯飞比赛第三名 拥有多项发明专利 对机器学习和深度学习拥有自己独到的见解 曾经辅导过若
  • 终于有本书把ChatGPT和AIGC讲清了!

    AIGC的各大门派是谁 典型技术都有什么 AIGC为什么在绘画领域先破圈 ChatGPT的有哪些局限性 为何科技企业争相推出大模型 人类的创新能力会被AIGC取代吗 诸如此类的这些话题呈现爆发性增长 频频被科技圈热议 与此同时 AI作画 虚
  • IPO 后,北森不断超越自身

    北森锐意变革的思路值得很多行业借鉴 数科星球原创 作者丨苑晶 编辑丨大兔 在所有 HR SaaS 软件中 北森较为独特 这种独特不光体现在其切入赛道的一体化产品 也体现在它所走过的发展路径 2023 年年中 当人口红利消逝之际 人们对人力资
  • 更新预告:chatGPT知识树。

    从一个知识点出发 无限扩展到无数个子知识点 是学习 了解其他行业知识 专业技能的利器 10分钟 就可以对一个行业 一个专业有大概 又专业的了解 简单列2个例子 让你感受下神的强大 上线时间 2023 9 7 体验地址 https ppwor
  • AIGC产业研究报告2023——语言生成篇

    本文阅读时间 10 分钟 今年以来 随着人工智能技术不断实现突破迭代 生成式AI的话题多次成为热门 而人工智能内容生成 AIGC 的产业发展 市场反应与相应监管要求也受到了广泛关注 为了更好地探寻其在各行业落地应用的可行性和发展趋势 易观对
  • 扩散模型实战(三):扩散模型的应用

    推荐阅读列表 扩散模型实战 一 基本原理介绍 扩散模型实战 二 扩散模型的发展 扩散只是一种思想 扩散模型也并非固定的深度网络结构 除此之外 如果将扩散的思想融入其他领域 扩散模型同样可以发挥重要作用 在实际应用中 扩散模型最常见 最成熟的
  • YOLOv5行人检测

    YOLOv5行人检测 1 数据准备 1 下载数据集 2 整理出jpg和xml 2 进行YOLOV5的部署训练 1 划分数据集 2 生成yolo的txt文件 3 配置自己数据集的文件 4 聚类找anchors 5 配置模型文件 6 训练模型
  • 招募 AIGC 训练营助教 @上海

    诚挚邀请对社区活动感兴趣的你 成为我们近期开展的训练营助教 与我们共同开启这场创新之旅 助教需要参与 协助策划和组织训练营活动 协助招募和筛选学员 协助制定训练营的宣传方案 负责协调和组织各项活动 助教可获得 AIGC知识库 获得社区提供的
  • 最近读的AIGC相关论文思路解读

    AIGC之SD可控生成论文阅读记录 提示 本博客是作者本人最近对AIGC领域相关论文调研后 临时记录所用 所有观点都是来自作者本人局限理解 以及个人思考 不代表对 如果你也正好看过相关文章 发现作者的想法和思路有问题 欢迎评论区留言指正 既
  • AIGC驱动产品开发创新,改变你所知的一切!

    你是否想过 3000年后的饮料是什么味道 9月12日 可口可乐全球创意平台 乐创无界 再度推出全新限定产品 首款联合人工智能 AI 打造的无糖可口可乐 未来3000年 从口味研发到包装设计都体现了AI的深度参与打造 Y3000与AI共创这一
  • ChatGPT办公应用:制作PPT大纲

    正文共 617字 阅读大约需要 4 分钟 解决方案专家必备技巧 您将在4分钟后获得以下超能力 制作PPT大纲 Beezy评级 B级 经过简单的寻找 大部分人能立刻掌握 主要节省时间 推荐人 Kim 编辑者 Yuke PPT技能是一项重要办公
  • ChatGPT发布一年后,搜索引擎的日子还好吗?

    导读 生成式AI 搜索引擎的终结者还是进化加速器 ChatGPT发布刚刚一年 互联网世界已经换了人间 2023年 以ChatGPT和大模型为代表的生成式AI浪潮对全球互联网 云计算 人工智能领域都带来巨大冲击 而且生成式AI在各行各业的应用
  • 法律情境扮演、逆向推理文字游戏、AIGC创作……见证AI极致生产力!

    飞桨星河社区 以飞桨和文心大模型为核心 集开放数据 开源算法 云端GPU算力及大模型开发工具于一体 在大模型范式下 为开发者提供模型与应用的高效开发环境 在成立的5年以来 已汇集660 万AI开发者 覆盖深度学习初学者 在职开发者 企业开发
  • 得帆信息创始人-张桐,受邀出席 BV百度风投AIGC主题论坛

    近日 得帆信息创始人兼CEO张桐 作为百度风投被投代表企业创始人受邀出席 向未来 共成长 BV百度风投AIGC主题论坛 与包括上海市徐汇区相关部门领导 百度集团相关事业部负责人及代表 以及来自国寿资本 中网投 麦顿投资的投资人 BV百度风投
  • 新书推介——《AI摄影绘画与PS优化从入门到精通》

    在这个数字化时代的浪潮中 人工智能技术以其惊人的创造力和创新性席卷全球 从智能助手到自动驾驶 从自然语言处理到机器学习 AI正日益成为我们日常生活和各个领域不可或缺的一部分 摄影和绘画领域也不例外 AI技术为我们提供了前所未有的创作和表达方
  • 10000亿规模AIGC产业,谁会成为下一个“巨头”?

    ChatGPT的热潮带火了大语言模型 也让AIGC插上了效率的翅膀 Midjourney 妙鸭相机等产品相继走入大众用户视线 根据艾瑞咨询的预测 2023年中国AIGC产业规模约为143亿元 而随着相关生态的完善 到2030年 中国AIGC
  • 10000亿规模AIGC产业,谁会成为下一个“巨头”?

    ChatGPT的热潮带火了大语言模型 也让AIGC插上了效率的翅膀 Midjourney 妙鸭相机等产品相继走入大众用户视线 根据艾瑞咨询的预测 2023年中国AIGC产业规模约为143亿元 而随着相关生态的完善 到2030年 中国AIGC
  • 蒙牛×每日互动合作获评中国信通院2023“数据+”行业应用优秀案例

    当前在数字营销领域 品牌广告主越来越追求品效协同 针对品牌主更注重营销转化的切实需求 数据智能上市企业每日互动 股票代码 300766 发挥自身数据和技术能力优势 为垂直行业的品牌客户提供专业的数字化营销解决方案 颇受行业认可 就在不久前举
  • Creator AIGC插件!一句话生成人脸

    近几个月以来 AIGC 一路高歌猛进 让我们见证了一场行业革命 然而 AIGC 在 3D 资产领域却仍是业内的难题 少有突破 小编今天给大家推荐一个 3D 角色 AIGC 利器 ChatAvatar 它可以算是 3D AIGC 领域的一匹黑

随机推荐

  • 计算机网路课程设计——电子邮件客户端的设计与实现——接收邮件(POP3协议)

    上一篇已经写了SMTP发送邮件客户端的代码 https blog csdn net dayexiaofan article details 85257320 这一篇我们来写一下POP3接收方的代码 注意这里的密码也是授权码 看代码 如果你能
  • React、Vue和Angular的优缺点

    React React 是一个用于构建用户界面的 JAVASCRIPT 库 React 主要用于构建 UI 很多人认为 React 是 MVC 中的 V 视图 React 起源于 Facebook 的内部项目 用来架设 Instagram
  • 数组的定义与使用

    一 数组的基本用法 1 什么是数组 数组本质上就是让我们能 批量 创建相同类型的变量 如果我们需要创建一个数据 int a 需要创建两个数据 int a int b 需要创建三个数据 int a int b int c 那如果要创建100万
  • 第五章:存储系统和结构

    5 1存储系统的组成 存储器的分类 1 按存储器在计算机系统中的作用分类 高速缓冲存储器 主存储器 辅助存储器 2 按存取方式分类 随机存取存储器RAM 只读存储器ROM 顺序存取存储器SAM 直接存取存储器DAM 3 按存储介质分类 磁芯
  • 【VScode设置免密登录及出现的问题】

    前言 使用VScode进行远程服务器代码调试时 每次都要输入密码 很麻烦 有木有 之前的操作请看 安装并使用VScode进行远程服务器代码调试及遇到的问题和解决办法 一 打开终端 登录上之后 创建一个新的终端 二 创建公钥和私钥 命令如下
  • Attention! No symbol directories found - please check your native debug configuration</font>

    我出现问题的版本是Android Studio2 2 3 之前项目是正常的 可以调试JNI代码 但是突然有一次不知道什么原因就无法调试 断点无法断下 调试时有这样的警告 Now Launching Native Debug Session
  • java进阶篇--TCP 为什么需要三次握手?

    TCP 协议是我们每天都在使用的一个网络通讯协议 因为绝大部分的网络连接都是建立在 TCP 协议上的 比如你此刻正在看的这篇文章是建立在 HTTP Hypertext Transfer Protocol 超文本传送协议 应用层协议的基础上的
  • 手把手教你微信第三方平台开发

    本文适合想接入第三方平台开发的同学 通过真实经验大致讲解一下相关业务 建议收藏以备不时之需 一 什么是微信开放平台 微信开放平台地址 微信开发平台实际上就是给微信外部人员提供微信能力的平台 我们可以在这个平台创建相关的应用 管理对应的认证
  • React服务端渲染框架Next.js入门之旅三:路由跳转和参数传递

    不带参数 静态路由 带参数 根据参数不同显示不同内容 动态路由 一 路由跳转 标签式跳转 在pages下新建juanA js以及juanB js作为两个跳转页面 juanA js import Link from next link exp
  • Vue => Vue监听组件滚动事件

    在dom元素上加ref 利用this refs recordwrapper获取到元素 添加滚动监听事件 希望得到的结果是滚动触发事件handleScroll 现在情况是失效 并没有监听到滚动动作 或者说滚动动作并没有出发事件 问题 监听事件
  • hadoop之hdfs分布式文件

    架构 HDFS是一个主从 Master Slaves 架构 由一个NameNode和一些DataNode组成 面向文件包含 文件数据 data 和文件元数据 metadata NameNode 负责存储和管理文件元数据 并维护了一个层次型的
  • 动态的为实体字段添加注解/注解属性

    可以动态的给实体添加注解 比如 导出表格的时候 根据条件决定是否导出该字段的列等使用 本例子将所有代码都放入工具类中 实际上有些不能实例化到内存中 只能作为一部分代码放在逻辑中 此种代码以再程序中标注 另一部分是可以持久化到内存 使用完工具
  • 移动端750怎么做响应式

    minimum scale 1 0 这个是同时设置最小缩放比例为1 0 在这里不写 user scalable no 禁用用户缩放功能 这样做的目的是为了确保网页在各种设备上都能够有合适的展示效果 缩放比例的限制可以避免用户过度缩放导致页面
  • JAVA IDEA中sout无法正常弹出,System.out.print,和System.out.println以及其他语句标红报错的问题。

    问题 在写代码时发现sout无法正常识别 println方法和println方法标红报错显示无法解析 问题分析 使用输出函数属于代码 而类中只能容纳变量以及方法 代码应该放在代码块 即方法 中 解决方法 在类中写一个方法 将代码放入方法中
  • macOS下 anaconda 虚拟环境及依赖包管理

    文章目录 环境管理 适用mac 1 2 创建虚拟环境失败后 排查问题 并再次成功创建虚拟环境的过程 依赖包管理 环境管理 适用mac 检查conda版本或是否已经安装 base lzh mac conda version conda 4 1
  • Yolo5の网络结构训练策略

    搬来的可能还是熟人的 抱歉啊 为了自己学习 讲解yolov5模型结构 数据增强 以及训练策略 官方地址 https github com ultralytics yolov5 yolov5模型训练流程 https blog csdn net
  • qt 编译时提示error: multiple definition of

    今天在用QT 5 4 1 编译程序时 提示error multiple definition 错误 以下红色字体为错误提示 D Wind PLT Projects BCS tmp moc Cntrlane cpp 156 error mul
  • 《Graph learning》

    上周发布的 图传播算法 上 中讲了关于图传播算法的基本范式和PageRank算法 本文将延续上周的文章 继续讲解剩下的三个算法 2 HITS HITS Hyperlink Induced Topic Search 另一个典型的图传播算法 其
  • 图形用户界面工具:Tkinter库

    Tkinter是Python默认的图形用户界面 Graphical User Interface GUI 库 Tkinter是T看interface 的缩写 意为Tkinter库是 Tkinter Tcl Tk的pathon接口 Tk它基于
  • 1.GAN生成mnist

    1 GAN Generative Adversarial Network 2 生成器 随机生成一个一维的100个随机数 latent dim 作为输入生成mnist图片 def build generator self model Sequ