CycleGAN算法原理(附源代码,可直接运行)

2023-05-16

前言

CycleGAN是在今年三月底放在arxiv(论文地址CycleGAN)的一篇文章,文章名为Learning to Discover Cross-Domain Relations with Generative Adversarial Networks,同一时期还有两篇非常类似的DualGAN(论文地址:DualGAN)和DiscoGAN(论文地址:DiscoGAN),简单来说,它们的功能就是:自动将某一类图片转换成另外一类图片。不同于GAN和CGAN(上节已经介绍过),CycleGAN不需要配对的训练图像。当然了配对图像也完全可以,不过大多时候配对图像比较难获取。

这里写图片描述
这里写图片描述
配对图像
这里写图片描述
未配对的图像

CycleGAN能做什么?

CycleGAN可以完成GAN和CGAN的工作,如上述配对图像所示,可以从一个特定的场景模式图生成另外一个场景模式图,这两张场景模式中的物体完全相同。除此之外,CycleGAN还可以完成从一个模式到另外一个模式的转换,转换的过程中,物体发生了改变,比如下面的图像中从猫到狗,从男人到女人。

CycleGAN算法原理

如下图所示CycleGAN其实是由两个判别器( Dx D x Dy D y )和两个生成器(G和F)组成,但是为什么要连两个生成器和两个判别器呢?论文中说,是为了避免所有的X都被映射到同一个Y,比如所有男人的图像都映射到范冰冰的图像上,这显然不合理,所以为了避免这种情况,论文采用了两个生成器的方式,既能满足X->Y的映射,又能满足Y->X的映射,这一点其实就是变分自编码器VAE的思想,是为了适应不同输入图像产生不同输出图像。那么下面的四个公式也很清楚了,(1)是判别器Y对X->Y的映射G的损失,判别器X对Y->X映射的损失也非常类似(2)是两个生成器的循环损失,这里其实是 L1 L 1 损失(3)是总损失(4)是对总损失进行优化,先优化D然后优化G和F,这一点和GAN类似

这里写图片描述
这里写图片描述
(1)
这里写图片描述
(2)
这里写图片描述
(3)
这里写图片描述
(4)

源代码

训练源代码

import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePool

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_integer('image_size', 128, 'image size, default: 256')
tf.flags.DEFINE_bool('use_lsgan', True,
                     'use lsgan (mean squared error) or cross entropy loss, default: True')
tf.flags.DEFINE_string('norm', 'instance',
                       '[instance, batch] use instance norm or batch norm, default: instance')
tf.flags.DEFINE_integer('lambda1', 10.0,
                        'weight for forward cycle loss (X->Y->X), default: 10.0')
tf.flags.DEFINE_integer('lambda2', 10.0,
                        'weight for backward cycle loss (Y->X->Y), default: 10.0')
tf.flags.DEFINE_float('learning_rate', 2e-4,
                      'initial learning rate for Adam, default: 0.0002')
tf.flags.DEFINE_float('beta1', 0.5,
                      'momentum term of Adam, default: 0.5')
tf.flags.DEFINE_float('pool_size', 50,
                      'size of image buffer that stores previously generated images, default: 50')
tf.flags.DEFINE_integer('ngf', 64,
                        'number of gen filters in first conv layer, default: 64')

tf.flags.DEFINE_string('X', 'tfrecords/apple.tfrecords',
                       'X tfrecords file for training, default: tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'tfrecords/orange.tfrecords',
                       'Y tfrecords file for training, default: tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,
                        'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')


def train():
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        lambda2=FLAGS.lambda1,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf
    )
    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )
        if step % 100 == 0:
          train_writer.add_summary(summary, step)
          train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))

        if step % 1000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)

def main(unused_argv):
  train()

if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  tf.app.run()

测试源代码

"""Translate an image to another image
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \
                       --input input_sample.jpg \
                       --output output_sample.jpg \
                       --image_size 256
"""

import tensorflow as tf
import os
from model import CycleGAN
import utils

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('model', 'model/apple2orange.pb', 'model path (.pb)')
tf.flags.DEFINE_string('input', 'samples/real_apple2orange_4.jpg', 'input image path (.jpg)')
tf.flags.DEFINE_string('output', 'output/output_sample3.jpg', 'output image path (.jpg)')
tf.flags.DEFINE_integer('image_size', '256', 'image size, default: 256')

def inference():
  graph = tf.Graph()

  with graph.as_default():
    with tf.gfile.FastGFile(FLAGS.input, 'rb') as f:
      image_data = f.read()
      input_image = tf.image.decode_jpeg(image_data, channels=3)
      input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))
      input_image = utils.convert2float(input_image)
      input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    [output_image] = tf.import_graph_def(graph_def,
                          input_map={'input_image': input_image},
                          return_elements=['output_image:0'],
                          name='output')

  with tf.Session(graph=graph) as sess:
    generated = output_image.eval()
    with open(FLAGS.output, 'wb') as f:
      f.write(generated)

def main(unused_argv):
  inference()

if __name__ == '__main__':
  tf.app.run()

实验结果

在这里是以相同物体不同模式下的数据集做训练(由于没有找到不同物体不同模式下的数据,当然你也可以自己做),从苹果到橘子的训练,测试结果如下:

从上图可以看出,苹果的颜色已经改成橘色,效果得到了体现。
源代码链接:CycleGAN source code

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CycleGAN算法原理(附源代码,可直接运行) 的相关文章

  • PS流详解(载荷H264)

    目录 PS简介标准结构标准H264流结构定长音频帧和其他流式私有数据的结构 PS流封装标准PSH结构PES包结构PSM包结构体 元素流 PS 封装规则H264元素流封装规则音频元素流封装规则私有信息封装规则 PS简介 PS 封装方式需要支持
  • Postman中的authorization

    1 概述 Authorization是验证是否拥有从服务器访问所需数据的权限 当发送请求时 xff0c 通常必须包含参数 xff0c 以确保请求具有访问和返回所需数据的权限 Postman提供了授权类型 xff0c 可以轻松地在Postma
  • 操作pdf,提示startxref not found

    startxref not found多半是文件被损坏了 xff0c 检查一下 xff0c 是不是之前自己写的代码把pdf文件跑崩了 可以尝试重新生成一遍该pdf文件 xff0c 然后再进行操作 或者尝试一下 xff1a https www
  • FTP 530未登录

    提供一种思路 xff1a 如果说FTP服务器已开 xff0c 服务器也能ping通 就得考虑是不是我们在FTP服务器上设置的默认路径有问题 xff08 不符合我们的需求 xff09 Windows10下 xff0c FTP设置默认位置 xf
  • 开源个小demo

    https github com UnderADome epms 内部项目管理
  • LDAP的基本知识

    https zhuanlan zhihu com p 147768058 https www cnblogs com gaoyanbing p 13967860 html
  • 「权威发布」2019年电赛最全各类题目细节问题解答汇总

    点击上方 大鱼机器人 xff0c 选择 置顶 星标公众号 福利干货 xff0c 第一时间送达 xff01 各位朋友大家上午好 xff0c 今天是比赛的第二天 xff0c 许多朋友都给我发消息 xff0c 我不是不回 xff0c 我实在是回不
  • Unable to find explicit activity class

    做项目从一个activity逐渐转向到使用多个activity xff0c 这个时候新手就容易出现一个问题 xff0c 忘了给activity在AndroidManifest xml中注册 打开日志 xff0c 在遇到这个报错信息的时候 x
  • Errors running builder 'Maven Project Builder'

    由于第一次玩maven的时候 xff0c 很多东西都还是懵懵懂懂 xff0c 不是很清楚 xff0c 不知道怎么把Myeclipse中的maven配置弄坏了 xff0c 从外部导入maven项目的时候 xff0c 总会报一些错误 xff1a
  • Type handler was null on parameter mapping for property '__frch_id_0'

    1 Type handler was null on parameter mapping for property frch id 0 2 Type handler was null on parameter mapping or prop
  • 如何解决error: failed to push some refs to 'xxx(远程库)'

    在使用git 对源代码进行push到gitHub时可能会出错 xff0c 信息如下 此时很多人会尝试下面的命令把当前分支代码上传到master分支上 git push u origin master 但依然没能解决问题 出现错误的主要原因是
  • expected an indented block

    Python中没有分号 xff0c 用严格的缩进来表示上下级从属关系 导致excepted an indented block这个错误的原因一般有两个 xff1a 1 冒号后面是要写上一定的内容的 xff08 新手容易遗忘这一点 xff09
  • C 实现TCP服务端(select、poll、epoll)

    使用C简单的实现一个tcp server xff0c 包括常规server 多线程实现server select实现server poll实现server epoll实现server IO模型原理可以看上一篇文章 常规模式 define M
  • UART串口通信

    目录 一 通信特点二 通信应用三 接线示意图三 UART通信协议四 STM32F4 串口使用1 资源分布2 特性3 UART框图4 使用方法5 相关库函数6 函数实例 五 实战 上位机控制开发板小灯 一 通信特点 异步 串行 全双工 一般描
  • 项目:文件搜索助手(FileSeeker)

    目录 1 项目简介 2 项目源代码 3 相关技术 4 实现原理 5 项目架构图 6 项目功能 7 测试报告 7 1 测试用例 7 2 测试环境 7 3 测试结论 7 3 1 功能测试 7 3 2 性能测试 7 3 3 兼容性 7 3 4 容
  • cocos2d实现2D地图A*广度路径算法

    h ifndef HELLOWORLD SCENE H define HELLOWORLD SCENE H include 34 cocos2d h 34 USING NS CC enum PatchFront Uper 61 1 Down
  • Keil 中,仿真调试查看局部变量值总是显示<not in scope>

    原因 xff1a 编译器把代码优化掉了 xff0c 直接导致在仿真中变量根本没有分配内存 xff0c 也就无法查看变量值 以后调试中遇到这种情况的解决办法 xff1a 核心思想是 xff1a 让变量值在代码中被读取其内存值 1 把变量定义为
  • 联合体在串口通讯中的妙用

    背景 本文主要涉及到的是一种串口通讯的数据处理方法 xff0c 主要是为了解决浮点数在串口通讯中的传输问题 xff1b 通常而言 xff0c 整形的数据类型 xff0c 只需进行移位运算按位取出每个字节即可 xff0c 那么遇到浮点型的数据
  • Linux 平均负载

    本文首发自公众号 LinuxOK xff0c ID 为 xff1a Linux ok 关注公众号第一时间获取更新 xff0c 分享不仅技术文章 xff0c 还有关于职场生活的碎碎念 在 Linux 系统中 xff0c 所谓平均负载 xff0
  • Linux 进程状态

    Linux 进程状态是平时排查问题 程序稳定性测试的基础知识 xff0c 查看进程状态的常用工具有 top 和 ps 以 top 的输出为例 xff1a S 列 xff08 Status xff09 表示进程的状态 xff0c 图中可见 D

随机推荐

  • Docker 是什么

    本文首发自公众号 LinuxOK xff0c ID为 xff1a Linux ok xff0c 关注公众号第一时间获取更新 xff0c 分享记录职场开发过程中所见所感 Docker 是一个用 GO 语言实现的开源项目 xff0c 它可以将应
  • 哈希表示例

    哈希表的意义在于高效查找 对于查找来说 xff0c 如果数据量特别大 xff0c 二分查找和哈希表算法十分有用了 二分查找前面已经讲过 xff0c 现来讲讲哈希表算法 就像输入数据数组下标返回数组元素一样 xff0c 这样的查找方式是最高效
  • RS-485通讯协议

    1 硬件层协议 通讯协议主要是实现两个设备之间的数据交换功能 xff0c 通讯协议分硬件层协议和软件层协议 硬件层协议决定数据如何传输问题 xff0c 比如要在设备1向设备2发送0x63 xff0c 0x63的二进制数为0110 0011
  • udp通讯中的connect()和bind()函数

    本文收录于微信公众号 LinuxOK xff0c ID为 xff1a Linux ok xff0c 关注公众号第一时间获取更多技术学习文章 udp是一个基于无连接的通讯协议 xff0c 通讯基本模型如下 可以看出 xff0c 不论是在客户端
  • c语言和c++的相互调用

    本文收录于微信公众号 LinuxOK xff0c ID为 xff1a Linux ok xff0c 关注公众号第一时间获取更多技术学习文章 在实际项目开发中 xff0c c和c 43 43 代码的相互调用是常见的 xff0c c 43 43
  • MSVC 版本号对应

    MSVC 43 43 14 0 MSC VER 61 61 1900 Visual Studio 2015 MSVC 43 43 12 0 MSC VER 61 61 1800 Visual Studio 2013 MSVC 43 43 1
  • SPI通讯协议介绍

    来到SPI通讯协议了 废话两句 xff0c SPI很重要 xff0c 这是我在学校时候听那些单片机开发工程师说的 出来实习 xff0c 到后来工作 xff0c 确实如此 xff0c SPI的使用很常见 xff0c 那么自然重要咯 SPI S
  • Qt多线程中的信号与槽

    1 Qt对象的依附性和事务循环 QThread继承自QObject xff0c 自然拥有发射信号 定义槽函数的能力 QThread默认声明了以下几个关键信号 信号只能声明不能定义 xff1a 1 线程开始运行时发射的信号 span clas
  • TCP/IP协议四层模型

    本文收录于微信公众号 LinuxOK xff0c ID为 xff1a Linux ok xff0c 关注公众号第一时间获取更多技术学习文章 接下来的学习重心会放在Linux网络编程这一块 xff0c 我的博客也会随之更新 参照的书籍有 Li
  • 常见的DoS攻击

    本文收录于微信公众号 LinuxOK xff0c ID为 xff1a Linux ok xff0c 关注公众号第一时间获取更多技术学习文章 拒绝服务攻击DoS Denial of Service xff1a 使系统过于忙碌而不能执行有用的业
  • stm32的can总线理解及应用——程序对应stm32f103系列

    CAN 是Controller Area Network 的缩写 xff08 以下称为CAN xff09 xff0c 是ISO国际标准化的串行通信协议 它的通信速度较快 xff0c 通信距离远 xff0c 最高1Mbps xff08 距离小
  • 多视图几何三维重建实战系列之MVSNet

    点击上方 计算机视觉工坊 xff0c 选择 星标 干货第一时间送达 1 概述 MVS是一种从具有一定重叠度的多视图视角中恢复场景的稠密结构的技术 xff0c 传统方法利用几何 光学一致性构造匹配代价 xff0c 进行匹配代价累积 xff0c
  • LiLi-OM: 走向高性能固态激光雷达惯性里程计和建图系统

    点击上方 计算机视觉工坊 xff0c 选择 星标 干货第一时间送达 编辑丨当SLAM遇见小王同学 声明 本文只是个人学习记录 xff0c 侵权可删 论文版权与著作权等全归原作者所有 xff0c 小王自觉遵守 中华人民共和国著作权法 与 伯尔
  • LITv2来袭 | 使用HiLo Attention实现高精度、快速度的变形金刚,下游任务均实时

    点击上方 计算机视觉工坊 xff0c 选择 星标 干货第一时间送达 作者丨ChaucerG 来源丨集智书童 近两年来 xff0c ViT 在计算机视觉领域的取得了很多重大的突破 它们的高效设计主要受计算复杂度的间接度量 xff08 即 FL
  • 问答|多重曝光相关论文有哪些?

  • ECCV 2022 | 清华&腾讯AI Lab提出REALY: 重新思考3D人脸重建的评估方法

    作者丨人脸人体重建 来源丨人脸人体重建 编辑丨极市平台 极市导读 本文围绕3D人脸重建的评估方式进行了重新的思考和探索 作者团队通过构建新数据集RELAY xff0c 囊括了更丰富以及更高质量的脸部区域信息 xff0c 并借助新的流程对先前
  • Arduino for ESP32-----ESP-NOW介绍及使用

    ESP NOW ESP NOW介绍ESP NOW支持以下特性ESP NOW技术也存在以下局限性获取ESP32的MAC地址ESP NOW单向通信 One way communication ESP32单板间的双向通信一对多通信 xff08 一
  • CLIP还能做分割任务?哥廷根大学提出一个使用文本和图像prompt,能同时作三个分割任务的模型CLIPSeg,榨干CLIP能力...

    点击上方 计算机视觉工坊 xff0c 选择 星标 干货第一时间送达 作者丨小马 来源丨我爱计算机视觉 本篇分享 CVPR 2022 论文 Image Segmentation Using Text and Image Prompts xff
  • Faster RCNN算法解析(附源代码,可以直接运行)

    一 前言知识 1 基于Region Proposal xff08 候选区域 xff09 的深度学习目标检测算法 Region Proposal xff08 候选区域 xff09 xff0c 就是预先找出图中目标可能出现的位置 xff0c 通
  • CycleGAN算法原理(附源代码,可直接运行)

    前言 CycleGAN是在今年三月底放在arxiv xff08 论文地址CycleGAN xff09 的一篇文章 xff0c 文章名为Learning to Discover Cross Domain Relations with Gene