GAN(生成对抗网络)Matlab代码详解

2023-11-08

这篇博客主要是对GAN网络的代码进行一个详细的讲解:

首先是预定义:

clear; clc; %%%clc是清除当前command区域的命令,表示清空,看着舒服些 。而clear用于清空环境变量。两者是不同的。

%%%装载数据集

train_x=load('Normalization_wbc.txt');%train_x就是我们希望GAN网络能够生成与其相似的数据。

[m,n]=size(train_x);%m表示train_x有多少行,n表示有多少列。

%%%定义模型

generator=nnsetup([30,15,30]);%第一个30代表第一层有30个神经元,这是要与train_x的维度相同的,最后一个30也是要与train_x的维度相同。

discriminator=nnsetup([30,15,1]);%第一个30要与生成器的最后一层的神经元个数相同,最后一层是1个神经元,输出的是每个样本来自于真实数据的概率。

%%参数设置

batch_size=m; %batchsize表示一次输入多少样本进行训练,因为我的数据量少,直接全部输入进去就行了。

iteration=100;%迭代多少次,或者说走多少次正向传播。

images_num=m;

batch_num=floor(images_num / batch_size);

learning_rate=0.0001;

其次是根据预定义的网络模型构建神经网络:

function nn=nnsetup(architecture)

nn.architecture= architecture;%%%把预定义的网络结构传递给nn(neuron network)这个结构体

nn.layers_count= numel(nn.architecture);% 计算传递过来的网络结构有多少层

%%%%%%%adam优化器需要设置的参数%%%

nn.t=0;

nn.beta1=0.9;

nn.beta2=0.999;

nn.epsilon=10^(-8);

%%%%%%%%%%%%%%%%%%%%%%%

for  i=2:nn.layers_count

nn.layers{i}.w=normrnd(0,0.02,nn.architecture(i-1),nn.architecture(i));%normrnd是指生成正态分布的随机数,第一个数字0表示均值为0,第二个数字0.02表示sigma=0.02,第三个值表示生成的维度大小。例如第三与第四的值分别为30,15,则表示生成30*15的矩阵。

nn.layers{i}.b = normrnd(0, 0.02, 1, nn.architecture(i));%生成偏置
        nn.layers{i}.w_m = 0;%好像是跟权重偏置有关的参数,但是都设置为了0,好像没啥意义。
        nn.layers{i}.w_v = 0;
        nn.layers{i}.b_m = 0;
        nn.layers{i}.b_v = 0;

end 

end

第3部分:正向传播

function nn=nnff(nn,x)

        nn.layers{1}.a=x;%%%将数据集x作为输入层

        for i=2:nn.layers_count %%%%nn.layers_count在传入nn时,就已经是网络的层数了

                input=nn.layers{i-1}.a;

                w=nn.layers{i}.w;

                b=nn.layers{i}.b;

                nn.layers{i}.z=input * w +repmat(b,size(input,1),1);

                if i~=nn.layers_count

                        nn.layers{i}.a=relu(nn.layers{i}.z);%%%%如果不是最后一层,就过relu激活函数

                else

                        nn.layers{i}.a=sigmoid(nn.layers{i}.z);

                end

        end

end

第四部分:反向传播,反向传播又分为生成器的反传,判别器的反传

A.判别器的反传

function nn=nnbp_d(nn, y_h, y)

%判别器的输入是生成器的最后一层,输出的数据Fake data 和我们手头有的真实数据train_x,即real data

n=nn.layers_count;

nn.layers{n}.d=delta_sigmoid_cross_entropy(y_h,y); %%%%nn.layers{n}.d表示最后一层的残差

for i=n-1:-1:2%%%n是i的初始值,1是终止值,-1是步长。即从i=n开始,每次都加 -1,即减1,直到i等于1为止.

        d=nn.layers{i+1}.d;

        w=nn.layers{i+1}.w;

        z=nn.layers{i}.z;

        nn.layers{i}.d=d*w' .*delta_relu(z);

end

for i=2:n

        d=nn.layers{i}.d;

        a=nn.layers{i-1}.a;

        nn.layers{i}.dw=a'*d /size(d,1);

        nn.layers{i}.db=mean(d,1);

end

end

第五部分,生成器的反传

function g_net=nnbp_g(g_net,d_net)

        n=g_net.layers_count;

        a=g_net.layers{n}.a;

        g_net.layers{n}.d=d_net.layers{2}.d*d_net.layers{2}.w' .* (a .*(1-a));

        for i=n-1:-1:2

                d=g_net.layers{i+1}.d;

                w=g_net.layers{i+1}.w;

                z=g_net.layers{i}.z;

                g_net.layers{i}.d=d*w' .* delta_relu(z);

        end

        %计算偏导数

        for i=2:n

               d=g_net.layers{i}.d;

                a=g_net.layers{i-1}.a;

                g_net.layers{i}.dw=a'*d/size(d,1); 

                g_net.layers{i}.db=mean(d,1);

        end

end

第六部分,激活函数与损失函数

%%%sigmoid激活函数

function output=sigmoid(x)

        output=1 ./(1+exp(-x));

end

%%%relu激活函数

function output=relu(x)

        output=max(x,0);

end

%%%Leaky_Relu激活函数

function output = Leaky_ReLU(x)
a=2;
if x>=0
    output=x;
else
    output=x/a;
end
end

%%%%损失函数

%relu对x的导数

function output=delta_relu(x)

        output=max(x,0);

        output(output>0)=1;

end

%%%%sigmoid交叉熵损失函数

function result=sigmoid_cross_entropy(logits,labels)

        result=max(logits,0) -logits .*labels +log(1+exp(-abs(logits)));

        result=mean(result);

end

%%%sigmoid交叉熵对logits的导数

function result=delta_sigmoid_cross_entropy(logits, labels)

        temp1=max(logits,0);

        temp1(temp1>0)=1;

        temp2=logits;

        temp2(temp2>0)=-1;

        temp2(temp2<0)=1;

        result=temp1-labels+exp(-abs(logits)) ./ (1+exp(-abs(logits))) .* temp2;

end

第七部分,Adam优化器

%Adam优化器
function nn = nnapplygrade(nn, learning_rate);
    n = nn.layers_count;
    nn.t = nn.t+1;
    beta1 = nn.beta1;
    beta2 = nn.beta2;
    lr = learning_rate * sqrt(1-nn.beta2^nn.t) / (1-nn.beta1^nn.t);
    for i = 2:n
        dw = nn.layers{i}.dw;
        db = nn.layers{i}.db;
        %使用adam更新权重与偏置
        nn.layers{i}.w_m = beta1 * nn.layers{i}.w_m + (1-beta1) * dw;
        nn.layers{i}.w_v = beta2 * nn.layers{i}.w_v + (1-beta2) * (dw.*dw);
        nn.layers{i}.w = nn.layers{i}.w -lr * nn.layers{i}.w_m ./ (sqrt(nn.layers{i}.w_v) + nn.epsilon);
        nn.layers{i}.b_m = beta1 * nn.layers{i}.b_m + (1-beta1) * db;
        nn.layers{i}.b_v = beta2 * nn.layers{i}.b_v + (1-beta2) * (db .* db);
        nn.layers{i}.b = nn.layers{i}.b -lr * nn.layers{i}.b_m ./ (sqrt(nn.layers{i}.b_v) + nn.epsilon);        
    end
    
end

第八部分, 上正餐,开始训练GAN。

for i=1:iteration

        kk=randperm(images_num);

        images_real=train_x;

        noise=unifrnd(0,1,m,30);

        generator=nnff(generator,noise);

        images_fake=generator.layers{generator.layers_count}.a;

        discriminator=nnff(discriminator,images_fake);

        logits_fake=discriminator.layers{discriminator.layers_count}.z;

        discriminator=nnbp_d(discriminator, logits_fake, ones(batch_size,1));

        generator= nnbp_g(generator, discriminator);

        generator=nnbp_g(generator, discriminator);

        generator=nnapplygrade(generator,learning_rate);

        %%%%%%%开始更新判别器

        generator=nnff(generator,noise);

        images_fake=generator.layers{generator.layers_count}.a;

        images=[images_fake;images_real];

        discriminator=nnff(discriminator,images);

        logits=discriminator.layers{discriminator.layers_count}.z;

        logits = discriminator.layers{discriminator.layers_count}.z;
    labels = [zeros(batch_size,1); ones(batch_size,1)];%预定义一个标签,前面的数据是0,后面的是1,也进行了拼接。
    discriminator = nnbp_d(discriminator, logits, labels);%logits与真实的标签进行对比,注意与第29行代码进行对比。
    discriminator = nnapplygrade(discriminator, learning_rate);%更新了辨别器网络的权重。
    
    %----输出loss损失
    c_loss(i,:) = sigmoid_cross_entropy(logits(1:batch_size), ones(batch_size,1));%这是生成器的损失
    d_loss (i,:)= sigmoid_cross_entropy(logits, labels);%判别器的损失

end

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

GAN(生成对抗网络)Matlab代码详解 的相关文章

随机推荐

  • redis 作为缓存总结

    redis缓存服务器笔记 redis是一个高性能的key value存储系统 能够作为缓存框架和队列 但是由于他是一个内存内存系统 这些数据还是要存储到数据库中的 作为缓存框架 create updae delete 同时存到redis和数
  • CentOS 安装nginx最简单办法

    我看了很多都挺复杂 然后查了下管网就有安装步骤 参考这个链接 http nginx org en linux packages html RHEL CentOS 第一步 sudo yum install yum utils 如果yum命令遇
  • 【Python人工智能】Python全栈体系(十六)

    人工智能 第四章 分类模型 一 分类业务模型 分类预测模型与回归不同 回归模型是根据已知的输入和输出寻找一个性能最佳的模型 从而通过未知输出的样本得到连续的输出 而分类模型则是需要得到离散的输出 即根据已知样本的所属类别预测未知输出的样本所
  • 解决RedisTemplate 使用 setIfAbsent 做分布式锁出现返回值为 null 的问题

    我们现在较少使用RedisTemplate 提供的setIfAbsent 做分布式锁 解决并发场景问题 一般使用成熟的三方工具Redisssion来解决分布式锁问题 但是有时候还是需要手动通过RedisTemplate 提供的setIfAb
  • 线圈自感的计算公式

    线圈自感等于总的磁通量除以电流 磁路的磁阻R为 l是磁通的总长度 mu 电路材料的相对磁导率 0 mu 0 0 自由空间的磁导率 4
  • Qt中三个窗口基类(QMainWindow , QWidget , QDialoh)的区别

    在平常qt开发中 通常要写自己的窗口类 那么这个窗口类该继承自哪个类呢 下面就来看下三个窗口基类的区别 1 QMainWindow QMainWindow类提供一个带有菜单条 工具条和一个状态条的主应用程序窗口 主窗口通常提供一个大的中央窗
  • 聚类与分类的定义

    1 聚类的概念 有一堆数据 讲这堆数据分成几类称为聚类 举个例子 比如有一堆水果 我们按着不同的特征分为 苹果 橘子 香蕉三类叫做分类 2 分类的概念 在聚类的前提下 拿来一个新水果 我们按着他的特征 把他分到橘子或者香蕉那类中 叫做分类
  • Spring Data JPA 定义实体类关系:一对一

    JPA使用 OneToOne来标注一对一的关系 实体 Dept 部门 实体 User 用户 Dept和 User 是一对一的关系 这里使用关联字段描述JPA的一对一关系 通过关联字段的方式 一个实体通过关联字段关联到另一个实体的主键 通过关
  • SAP事务码MM17物料主数据批量维护

    这个事务码真的很有意思 因为可以看到物料主数据不同层次的内容 为什么这么说呢 进入MM17
  • mysql 修改数据库字段长度限制_修改数据库字段长度问题,非常紧急!大家来帮忙...

    你的位置 问答吧 gt JAVA gt 问题详情 修改数据库字段长度问题 非常紧急 大家来帮忙 我有一个表里有个主键id char 3 第一个问题 能不能把char 3 改为varchar2 10 alter table sys compa
  • Hadoop安装过程与问题解决

    Hadoop安装过程与问题解决 安装环境 CentOS JDK1 8 如何查看系统版本号 如下图 cat etc redhat release 下载Hadoop 包 可以通过在windows下下载 然后通过linux的客户端工具进行上传 这
  • AI测试中的数据收集

    人工智能 通俗来说就是让机器最大程度的接近于人 如人与人之间沟通 识别图像 奔跑 越障等 例如之前被刷屏的波士顿动力机器人 猎豹移动在世界机器人大会展出的研磨咖啡机器人 图像识别是目前人工智能应用的一大类型 不断地收集 调整 完善测试数据来
  • 【深度长文】人脸识别:人脑认知与计算机算法(五部曲)

    来源 本文经作者 Owl of Minerva 授权转载 链接 https zhuanlan zhihu com HicRhodushicsalta 1 初期预测和介绍 现阶段 人脸识别是人工智能领域最炙手可热的话题之一 Google和Fa
  • 用Python画圣诞树

    拿去给自己所思所念之人 import turtle as t as就是取个别名 后续调用的t都是turtle from turtle import import random as r import time n 100 0 speed f
  • uniapp微信小程序使用axios(vue3+axios+ts版)

    版本号 vue 3 2 45 axios 1 4 0 axios miniprogram adapter 0 3 5 安装axios及axios适配器 适配小程序 yarn add axios axios miniprogram adapt
  • CentOS安装docker

    Docker这两年大受追捧 风光无二 Docker是一个轻量级容器技术 类似于虚拟机技术 xen kvm VMware virtualbox Docker是直接运行在当前操作系统 Linux 之上 而不是运行在虚拟机中 但是也实现了虚拟机技
  • vmware workstation pro 14 虚拟机无法开启、黑屏的解决方案汇总

    方案1 卸载鲁大师 重启 方案2 管理员命令行 输入netsh winsock reset 重启 方案3 360安全管家修复LSP 重启 方案4 卸载14 0 安装12 0 手动导入虚拟机 转载于 https www cnblogs com
  • 【待解决】【OpenCV图像处理】1.27 模板匹配(Template Match)

    1 相关理论 直观介绍 介绍 模板匹配就是在整个图像区域发现与给定子图像匹配的小块区域 所以模板匹配首先需要一个模板图像T 给定的子图像 另外需要一个待检测的图像 源图像S 工作方法 在带检测图像上 从左到右 从上向下计算模板图像与重叠子图
  • 解决ModuleNotFoundError: No module named ‘pip‘

    pip install U pip 把pip搞没了 报错 环境路径 Scripts pip script py is not present 这个错误可以通过两行简单的cmd命令行语句进行改正修复 python m ensurepip py
  • GAN(生成对抗网络)Matlab代码详解

    这篇博客主要是对GAN网络的代码进行一个详细的讲解 首先是预定义 clear clc clc是清除当前command区域的命令 表示清空 看着舒服些 而clear用于清空环境变量 两者是不同的 装载数据集 train x load Norm