GAN之生成对抗网络(Matlab)

2023-10-27

代码来源

代码全文


clear all; close all; clc;
%% Basic Generative Adversarial Network
%% Load Data  
load('mnistAll.mat')
trainX = preprocess(mnist.train_images); 
trainY = mnist.train_labels;
testX = preprocess(mnist.test_images); 
testY = mnist.test_labels;
%% Settings
settings.latent_dim = 100;
settings.batch_size = 32; settings.image_size = [28,28,1]; 
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;

%% Initialization
%% Generator
paramsGen.FCW1 = dlarray(...
    initializeGaussian([256,settings.latent_dim],.02));
paramsGen.FCb1 = dlarray(zeros(256,1,'single'));
paramsGen.BNo1 = dlarray(zeros(256,1,'single'));
paramsGen.BNs1 = dlarray(ones(256,1,'single'));
paramsGen.FCW2 = dlarray(initializeGaussian([512,256]));
paramsGen.FCb2 = dlarray(zeros(512,1,'single'));
paramsGen.BNo2 = dlarray(zeros(512,1,'single'));
paramsGen.BNs2 = dlarray(ones(512,1,'single'));
paramsGen.FCW3 = dlarray(initializeGaussian([1024,512]));
paramsGen.FCb3 = dlarray(zeros(1024,1,'single'));
paramsGen.BNo3 = dlarray(zeros(1024,1,'single'));
paramsGen.BNs3 = dlarray(ones(1024,1,'single'));
paramsGen.FCW4 = dlarray(initializeGaussian(...
    [prod(settings.image_size),1024]));
paramsGen.FCb4 = dlarray(zeros(prod(settings.image_size)...
    ,1,'single'));

stGen.BN1 = []; stGen.BN2 = []; stGen.BN3 = [];

%% Discriminator
paramsDis.FCW1 = dlarray(initializeGaussian([1024,...
     prod(settings.image_size)],.02));
paramsDis.FCb1 = dlarray(zeros(1024,1,'single'));
paramsDis.BNo1 = dlarray(zeros(1024,1,'single'));
paramsDis.BNs1 = dlarray(ones(1024,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([512,1024]));
paramsDis.FCb2 = dlarray(zeros(512,1,'single'));
paramsDis.BNo2 = dlarray(zeros(512,1,'single'));
paramsDis.BNs2 = dlarray(ones(512,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb3 = dlarray(zeros(256,1,'single'));
paramsDis.FCW4 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb4 = dlarray(zeros(1,1,'single'));

stDis.BN1 = []; stDis.BN2 = [];

% average Gradient and average Gradient squared holders
avgG.Dis = []; avgGS.Dis = []; avgG.Gen = []; avgGS.Gen = [];
%% Train
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~out
    tic; 
    trainXshuffle = trainX(:,randperm(size(trainX,2)));
    fprintf('Epoch %d\n',epoch) 
    for i=1:numIterations
        global_iter = global_iter+1;
        noise = gpdl(randn([settings.latent_dim,...
            settings.batch_size]),'CB');
        idx = (i-1)*settings.batch_size+1:i*settings.batch_size;
        XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');

        [GradGen,GradDis,stGen,stDis] = ...
                dlfeval(@modelGradients,XBatch,noise,...
                paramsGen,paramsDis,stGen,stDis);

        % Update Discriminator network parameters
        [paramsDis,avgG.Dis,avgGS.Dis] = ...
            adamupdate(paramsDis, GradDis, ...
            avgG.Dis, avgGS.Dis, global_iter, ...
            settings.lrD, settings.beta1, settings.beta2);

        % Update Generator network parameters
        [paramsGen,avgG.Gen,avgGS.Gen] = ...
            adamupdate(paramsGen, GradGen, ...
            avgG.Gen, avgGS.Gen, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        if i==1 || rem(i,20)==0
            progressplot(paramsGen,stGen,settings);
%             if i==1 || (epoch>=0 && i==1) 
%                 h = gcf;
%                 % Capture the plot as an image 
%                 frame = getframe(h); 
%                 im = frame2im(frame); 
%                 [imind,cm] = rgb2ind(im,256); 
%                 % Write to the GIF File 
%                 if epoch == 0
%                   imwrite(imind,cm,'GANmnist.gif','gif', 'Loopcount',inf); 
%                 else 
%                   imwrite(imind,cm,'GANmnist.gif','gif','WriteMode','append'); 
%                 end 
%             end
        end
        
    end

    elapsedTime = toc;
    disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")
    epoch = epoch+1;
    if epoch == settings.maxepochs
        out = true;
    end    
end
%% Helper Functions
%% preprocess
function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end
%% extract data
function x = gatext(x)
x = gather(extractdata(x));
end
%% gpu dl array wrapper
function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end
%% Weight initialization
function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2
    sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end
%% Generator
function [dly,st] = Generator(dlx,params,st)
% fully connected
%1
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN1)
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,params.BNs1);
% else
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,...
%         params.BNs1,st.BN1.mu,st.BN1.sig);
% end
%2
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN2)
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,params.BNs2);
% else
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,...
%         params.BNs2,st.BN2.mu,st.BN2.sig);
% end
%3
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,0.2);
% if isempty(st.BN3)
%     [dly,st.BN3.mu,st.BN3.sig] = batchnorm(dly,params.BNo3,params.BNs3);
% else
%     [dly,st.BN3.mu,st.BN3.sig] = batchnorm(dly,params.BNo3,...
%         params.BNs3,st.BN3.mu,st.BN3.sig);
% end
%4
dly = fullyconnect(dly,params.FCW4,params.FCb4);
% tanh
dly = tanh(dly);
end
%% Discriminator
function [dly,st] = Discriminator(dlx,params,st)
% fully connected 
%1
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
% if isempty(st.BN1)
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,params.BNs1);
% else
%     [dly,st.BN1.mu,st.BN1.sig] = batchnorm(dly,params.BNo1,...
%         params.BNs1,st.BN1.mu,st.BN1.sig);
% end
%2
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
% if isempty(st.BN2)
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,params.BNs2);
% else
%     [dly,st.BN2.mu,st.BN2.sig] = batchnorm(dly,params.BNo2,...
%         params.BNs2,st.BN2.mu,st.BN2.sig);
% end
%3
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,0.2);
dly = dropout(dly);
%4
dly = fullyconnect(dly,params.FCW4,params.FCb4);
% sigmoid
dly = sigmoid(dly);
end
%% modelGradients
function [GradGen,GradDis,stGen,stDis]=modelGradients(x,z,paramsGen,...
    paramsDis,stGen,stDis)
[fake_images,stGen] = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,stDis);
[d_output_fake,stDis] = Discriminator(fake_images,paramsDis,stDis);

% Loss due to true or not
d_loss = -mean(.9*log(d_output_real+eps)+log(1-d_output_fake+eps));
g_loss = -mean(log(d_output_fake+eps));

% For each network, calculate the gradients with respect to the loss.
GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end
%% progressplot
function progressplot(paramsGen,stGen,settings)
r = 5; c = 5;
noise = gpdl(randn([settings.latent_dim,r*c]),'CB');
gen_imgs = Generator(noise,paramsGen,stGen);
gen_imgs = reshape(gen_imgs,28,28,[]);

fig = gcf;
if ~isempty(fig.Children)
    delete(fig.Children)
end

I = imtile(gatext(gen_imgs));
I = rescale(I);
imagesc(I)
title("Generated Images")
colormap gray

drawnow;
end
%% dropout
function dly = dropout(dlx,p)
if nargin < 2
    p = .3;
end
n = p*10;
mask = randi([1,10],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*

代码展示

         本代码采用MNIST手写数字数据集(训练集60000个,测试集10000个,本例中采用训练集数据),可实现数据集自动下载,最大的epoch次数为50,单个epoch中有1875个batch,batch_size为32,结果迭代如下:

第0次迭代

第10次迭代

第20次迭代

第30次迭代

第40次迭代

 第50迭代

 代码以及相关资料附件

ggg9 

内容简介:

1、上述网址源代码压缩包(本文代码在...\github_repo\GAN下,GAN.m文件)

2、生成对抗网络的源文献

 关注公众号“故障诊断与寿命预测工具箱”,每天进步一点点。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN之生成对抗网络(Matlab) 的相关文章

  • 如何在 MATLAB 编译的应用程序中运行外部 .m 代码? [复制]

    这个问题在这里已经有答案了 我有一个 MATLAB 项目 我使用 MCC 对其进行编译以获得单个可执行文件 然后我想知道外部程序员是否可以在 exe 中执行他的一些 m 文件 而无需重新编译整个项目 重点是提供一个应用程序 其他开发人员可以
  • 将 kinect RGB 和深度值转换为 XYZ 坐标

    我正在寻找一种简单的方法将 kinect RGB 和深度值转换为 XYZ 坐标 使用 MATLAB 我的目标是一个输入为以下内容的函数 每个点的 RGB 和深度值Kinect相机 并输出 每个点的 x y 和 z 值 RGB 深度 RGB
  • 在 MATLAB 中绘图后恢复轴

    从文本文件绘制多种方法的输出后 未显示轴的右侧和上侧 我需要拥有它们并将它们加粗 就像当前的轴一样 绘制的数据来自存储每种方法数据的文件 每个数据文件都是一个 256x2 文件 包含 0 1 之间的值 第一列是精度 第二列是召回率 figu
  • 图像梯度角计算

    我实际上是按照论文的说明进行操作的 输入应该是二进制 边缘 图像 输出应该是一个新图像 并根据论文中的说明进行了修改 我对指令的理解是 获取边缘图像的梯度图像并对其进行修改 并使用修改后的梯度创建一个新图像 因此 在 MATLAB Open
  • Deploytool for MATLAB R2013b 不起作用,发生了什么变化?

    多年来我一直在使用集成deploytool为我的同事创建易于分发的 exe 文件 我几天前安装了R2013b 但无法使用deploytool不再了 尝试打包时的日志文件给出了以下内容 ant
  • matlab 中的动画绘图

    我正在尝试创建一个三角形的动画图 最终结果应该是十个三角形 后面跟着两个更大的三角形 后面跟着一条直线 使用matlab文档 https de mathworks com help matlab ref drawnow html 我最终得到
  • 保存符号方程以供以后使用?

    From here http www mathworks com help releases R2011a toolbox symbolic brvfu8o 1 html brvfxem 1 我正在尝试求解这样的符号方程组 syms x y
  • MATLAB 变量传递和惰性赋值

    我知道在 Matlab 中 当将新变量分配给现有变量时 会进行 惰性 评估 例如 array1 ones 1 1e8 array2 array1 的价值array1不会被复制到array2除非元素array2被修改 由此我推测Matlab中
  • Matlab:2行10列的子图

    如何在 matlab 中绘制 20 幅图像 2 行 10 列 我知道我必须使用 子图 功能 但我对给出的参数感到困惑 我尝试给予 子图 2 10 行索引 列索引 但它似乎不起作用 请帮忙 的前两个参数subplot函数分别给出图中子图的总行
  • Mathworks 生成 Matlab HTML 文档的方法是什么?

    我正在开发共享的 Matlab 代码 我们希望在本地网络中将生成的文档作为可搜索的 HTML 文档共享 我知道以下生成文档的方法 编写一个类似于 C 文件的转换器 这是在中完成的将 Doxygen 与 Matlab 结合使用 http ww
  • MATLAB parfor 和 C++ 类 mex 包装器(需要复制构造函数?)

    我正在尝试使用概述的方法将 C 类包装在 matlab mex 包装器中here http www mathworks com matlabcentral newsreader view thread 278243 基本上 我有一个初始化
  • 绘制布朗运动 matlab

    首先 我只想说我不太习惯使用matlab 但我需要一个作业 我应该创建一个 布朗运动 我的代码目前如下所示 clf hold on prompt Ge ett input size input prompt numParticles inp
  • 如何在Matlab中将图像从笛卡尔坐标更改为极坐标?

    我正在尝试将图像的像素从 x y 坐标转换为极坐标 但我遇到了问题 因为我想自己编写该函数 这是我到目前为止所做的代码 function newImage PolarCartRot read and show the image image
  • 二维随机微分方程 (SDE)

    我第一次研究随机微分方程 我正在寻求模拟和求解二维随机微分方程 模型如下 dp F t p dt G t p dW t where p 是一个 2 1 向量 p theta t phi t F是列向量 F sin theta Psi cos
  • 用于读取csv写入数组的c++程序;然后操作并打印到文本文件中(已经用 matlab 编写)

    我想知道是否有人可以帮助我 我正在尝试构建一个程序 从 csv 文件中读取大小未知的浮点数大数据块 我已经在 MATLAB 中编写了此代码 但想要编译和分发此代码 因此转向 C 我只是在学习并尝试阅读本文以开始 7 5 19892 4 23
  • 可以避免迭代元胞数组时的“s{1} 烦恼”吗?

    The s 1 标题的 烦恼 指的是下面的 for 块中的第一行 for s some cell array s s 1 unpeel the enclosing cell do stuff with s end This s s 1 业务
  • MATLAB 图形渲染:OpenGL 与 Painters?

    当谈到使用哪个渲染器来处理 MATLAB 图形或何时它很重要时 我一无所知 但我遇到过某些示例 其中does matter plot 0 0 ko markersize 50 linewidth 8 set gcf renderer ope
  • Python 中的 eig(a,b) 给出错误“需要 1 个位置参数,但给出了 2 个”

    根据https docs scipy org doc numpy 1 15 0 user numpy for matlab users html https docs scipy org doc numpy 1 15 0 user nump
  • 如何从一个清晰的例子计算二维图像中的吉布斯能量

    我有一个关于矩阵的有趣问题 在吉布斯分布中 吉布斯能量U x 可以计算为 这是所有可能的派系 C 上的派系势 Vc x 的总和 右图 团 c 被定义为 S 中站点的子集 x 蓝色像素的邻域是左图中黄色像素的邻居 其中每对不同的站点都是邻居
  • matlab 中的 for 或 while 循环

    我刚刚开始在编程课的 matlab 中使用 for 循环 基本的东西对我来说很好 但是我被要求 使用循环创建一个 3 x 5 矩阵 其中每个元素的值是其行号其列号除以行号和列号之和的幂 例如元素 2 3 的值为 2 3 2 3 1 6 那么

随机推荐

  • MySQL - 唯一索引

    唯一索引 所谓唯一索引 就是在创建索引时 限制索引的字段值必须是唯一的 通过该类型的索引可以比普通索引更快速地查询某条记录 1 创建表时定义索引 CREATE TABLE tablename propname1 type1 propname
  • 利用用adb查看短信、通讯录、拨号的应用数据

    利用用adb查看短信 通讯录 拨号的应用数据 1 进入root界面 adb shell su 2 查看你想要查看的应用手机目录下应用界面的包名 adb shell dumpsys activity findstr mFocusedActiv
  • 第 0004 题: 任一个英文的纯文本文件,统计其中的单词出现的个数

    import os file open wz txt mode r dict for line in file h line line split for key in h line if key 1 gt a and key 1 lt z
  • Centos-启动network报错RTNETLINK answers: File exists解决方法

    背景 今天在Vcenter上 用模板克隆了一个虚拟机 启动之后 网卡启动不了 报错如下 RTNETLINK answers File exists 说明 环境 Centos6 6 X64 网卡两个 原因 由于用模板克隆虚拟机 所以网卡的配置
  • ts类型体操Concat

    533 Concat by Andrey Krasovsky bre30kra69cs easy array Question Implement the JavaScript Array concat function in the ty
  • 总结的一些MySQL索引相关的知识点

    博客迁移 http cui zhbor com article 14 html MySQL索引 有很多很多的东西需要去学习 我会写一些自己的总结 这些总结主要是平时运用在实际项目中的 有很多的经验往往设计表的人很清楚 但是总是有 这个东西就
  • 【实例分割】3、Mask Scoring R-CNN

    文章目录 摘要 1 引言 2 相关工作 2 1 实例分割 2 2 检测得分校正 3 方法 3 1 动机 3 2 Mask scoring in Mask R CNN 4 实验 4 1 实验细节 4 2 定量结果分析 4 3 消融学习 4 4
  • 时序逻辑和组合逻辑

    一 组合逻辑与时序逻辑的对比 1 组合逻辑的输出状态与输入直接相关 时序逻辑还必须在时钟上升沿触发后输出新值 2 组合逻辑容易出现竞争 冒险现象 时序逻辑一般不会 3 组合逻辑的时序较难保证 时序逻辑更容易达到时序收敛 时序逻辑可控 4 组
  • IP代理安全吗?如何防止IP被限制访问?

    你是否遇到过可以正常上网 但访问某个网站却被禁止 注册某个网站账号 却被封号 那都是因为IP出现问题 您的IP地址透露很多关于您的信息 包括您的位置和互联网活动 在本文中 我们将一起了解IP地址 网站如何利用它来跟踪您 以及与IP代理如何帮
  • 求助:stm32+proteus+adc采集电压仿真显示为零

    求助一下大佬 因为板子上的oled不是ssd1306驱动的所以现在只能学习跑仿真 在学adc采集电压的实验 OLED显示没问题 现在的问题是采集不到电压 显示总是0 麻烦好心人帮我看看是哪里出了问题 软件用的keil mdk5 24 pro
  • Game【HDU-6873】【Splay】

    2020 Multi University Training Contest 9 G题 题意 有N个有各自高度的位置 按1 N从左到右排列 现在我们有两种操作 x y将第x列 第y行的方块 包括它上面的方块从右往左的移动过去 同时推动前面的
  • 【导入导出测试用例编写】

    导入导出测试用例编写 一 导出模板测试用例 二 导出数据测试用例 三 导入数据测试用例 一 导出模板测试用例 1 检查模板是否可以正常下载正常打开 2 检查模板表头格式展示是否正确 与系统列表中的字段是否一致 3 检查必填项 字段长度 字段
  • 接口性能 指标

    接口测试响应时间 通用得接口响应使时间分布情况 100ms为优良 500ms为及格 1000ms以上为不可忍受 金融接口响应时间得分布情况 100ms为优良 200ms为及格 300ms以上为不可忍受
  • 动态链接库(一)--动态链接库简介

    写在前面 自从微软推出的第一个版本的Windows操作系统以来 动态链接库 DLL 一直就是Windows操作系统的基础 动态链接库通常不能直接运行 也不能接收消息 它们一直是独立的文件 其中包含能被可执行程序或其他DLL文件调用来完成某项
  • 【解决ElementUI 和Antd的对话弹窗样式冲突问题】

    项目中使用了Antd 和element UI两种UI库 Antd是全局样式 element ui则是按需引入 在使用element ui的页面处点击退出 弹出的对话框就会样式失效 首先在随便一个地方点击退出登录看一下正常效果 再打开F12查
  • Unity3D中的ref、out、Params三种参数的使用

    目录 ref out Params ref 作用 将一个变量传入一个函数中进行 处理 处理 完成后 再将 处理 后的值带出函数 语法 使用时形参和实参都要添加ref关键字 using System Collections using Sys
  • JavaSE学习总结:面向对象编程

    Java面向对象编程 1 类与对象 1 1面向对象的理解 1 1 1面向对象和面向过程的区别 1 1 2面向对象的好处 1 1 3面向对象的思考步骤 1 2类与对象 1 2 1什么是类 1 2 2什么是对象 1 2 3二者的区别 1 2 4
  • ubuntu设置环境变量

    vim bashrc export VCPKG FORCE SYSTEM BINARIES 1 export VCPKG HOME PATH vcpkg export X VCPKG ASSET SOURCES x azurl http 1
  • GPT-4最强竞争对手Claude 2震撼发布,据说超过GPT-4?

    OpenAI 发布了 GPT 4 的 API 和令人兴奋的 最强插件 代码解释器 这无疑给竞争对手们敲响了警钟 而最近 Anthropic 旗下的 Claude 揭开了它的第二代面纱 免费使用Claude 2请加微信wyxyellow 相较
  • GAN之生成对抗网络(Matlab)

    代码来源 代码全文 clear all close all clc Basic Generative Adversarial Network Load Data load mnistAll mat trainX preprocess mni