风格迁移背后原理及tensorflow实现

2023-11-05

前言


本文将详细介绍 tf 实现风格迁移的小demo,看完这篇就可以去实现自己的风格迁移了,复现的算法来自论文
Perceptual P e r c e p t u a l LossesforRealTime L o s s e s f o r R e a l − T i m e Style S t y l e Transfer T r a n s f e r and a n d SuperResolution S u p e r − R e s o l u t i o n

GitHub代码链接https://github.com/LDOUBLEV/style_transfer-perceptual_loss 如果感觉有用的话,帮忙给个star吧

本文分为以下部分:
第一节:深度学习在风格迁移上的背后原理;
第二节:风格迁移的代码详解
第三节:总结

图像风格迁移指的是将图像A的风格转换到图像B中去,得到新的图像,取个名字为new B,其中new B中既包含图像B的内容,也包含图像A的风格。如下图所示:
这里写图片描述
从左到右依次为图像A,图像B,图像new B

本文着重介绍基于深度学习技术的风格迁移的原理及其实现,实现使用的工具如下:

  • 框架:Tensorflow 1.4.1
  • 语言:python 2.7
  • 系统:ubuntu 16.04

注:其他条件同样可行,如有问题,欢迎评论、私信

最终效果部分展示:

原图:
这里写图片描述
风格迁移后的图像,右上角那一张明显风格迁移过头了,可以设置style_loss的比例做调整:
这里写图片描述

这里写图片描述
这里写图片描述,最满意的就是左上角那一张了。

第一节:深度学习在风格迁移的背后原理


1.1 背后原理简介

深度学习技术可谓无孔不入,在计算机视觉领域尤为明显,图像分类、识别、定位、超分辨率、转换、迁移、描述等等都已经可以使用深度学习技术实现。其背后的技术可以一言以蔽之:卷积网络具有超强的图像特征提取能力
其中,风格迁移算法的成功,其主要基于以下两点:

  1. 两张图像经过预训练好的分类网络,若提取出的高维特征( highlevel h i g h − l e v e l )之间的欧氏距离越小,则这两张图像内容越相似
  2. 两张图像经过与训练好的分类网络,若提取出的低维特征( lowlevel l o w − l e v e l )在数值上基本相等,则这两张图像越相似,换句话说,两张图像相似等价于二者特征的 Gram G r a m 矩阵具有较小的弗罗贝尼乌斯范数。

基于这两点,就可以设计合适的损失函数优化网络。

1.2 原理解读

对于深度网络来讲,深度卷积分类网络具有良好的特征提取能力,不同层提取的特征具有不同的含义,每一个训练好的网络都可以视为是一个良好的特征提取器,另外,深度网络由一层层的非线性函数组成,可以视为是复杂的多元非线性函数,此函数完成输入图像到输出的映射,因此,完全可以使用训练好的深度网络作为一个损失函数计算器。

Gram G r a m 矩阵的数学形式如下: Gj(x)=AAT G j ( x ) = A ∗ A T
Gram矩阵实际上是矩阵的内积运算,在风格迁移算法中,其计算的是feature map之间的偏心协方差,在feature map 包含着图像的特征,每个数字表示特征的强度,Gram矩阵代表着特征之间的相关性,因此,Gram矩阵可以用来表示图像的风格,因此可以通过Gram矩阵衡量风格的差异性。

1.3 论文解读

本次主要介绍的是论文:Perceptual Losses for Real-Time Style Transfer and Super-Resolution
直接上图:
这里写图片描述
网络框架分为两部分,其一部分是图像转换网络 T T (image transfrom net)和预训练好的损失计算网络VGG-16(loss network),图像转换网络TT以内容图像 x x 为输入,输出风格迁移后的图像yy,随后内容图像 yc y c (也即是 x x ),风格图像ysys,以及 y y ′ 输入vgg-16计算特征,损失计算如下:
内容损失: lφ;jfeat(y;y)=1CjHjWj||φj(y)φj(y)||2 l f e a t φ ; j ( y ; y ) = 1 C j H j W j | | φ j ( y ′ ) − φ j ( y ) | | 2 , 其中 φ φ 代表深度卷积网络VGG-16

感知损失如下:lφ;jstyle(y;y)=||Gj(y)Gj(y)||2Flstyleφ;j(y;y)=||Gj(y)Gj(y)||F2,其中G是Gram矩阵,计算过程如下:

Gφj(x)c,c=||Gφj(y)Gφj(y)|| G j φ ( x ) c ′ , c = | | G j φ ( y ′ ) − G j φ ( y ) | |

总损失定义如下: Losstotal=γ1lfeat+γ2lstyle L o s s t o t a l = γ 1 l f e a t + γ 2 l s t y l e

其中图像转换网络T定义如下图:
这里写图片描述

网络结构三个卷积层后紧接着5个残差块,然后两个上采样(邻近插值的方式),最后一个卷积层,第一层和最后一层的卷积核都是9x9,其余均为3x3。每个残差块中包含两层卷积。

第二节:代码详解


本次实验主要基于tf的slim模块,slim封装的很好,调用起来比较方便。接下来分为网络结构,损失函数,以及训练部分分别做介绍。

2.1 网络结构

slim = tf.contrib.slim
# 定义卷积,在slim中传入参数
def arg_scope(weight_decay=0.0005):
    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.conv2d_transpose],
                        activation_fn=None,
                        weights_regularizer=slim.l2_regularizer(weight_decay),
                        biases_initializer=tf.zeros_initializer()):
        with slim.arg_scope([slim.conv2d, slim.conv2d_transpose], padding='SAME') as arg_sc:
            return arg_sc

接下来就是图像转换网络结构部分,仿照上图,不过这里有一个trick,就是在输入之前对图像做padding,经过网络后再把padding的部分去掉,防止迁移后出现边缘效应。

def gen_net(imgs, reuse, name, is_train=True):
    imgs = tf.pad(imgs, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')
    with tf.variable_scope(name, reuse=reuse) as vs:
        # encoder : three convs layers
        out1 = slim.conv2d(imgs, 32, [9, 9], scope='conv1')
        out1 = relu(instance_norm(out1))

        out2 = slim.conv2d(out1, 64, [3, 3], stride=2, scope='conv2')
        out2 = instance_norm(out2)
        # out2 = relu(img_scale(out2, 0.5))

        out2 = slim.conv2d(out2, 128, [3, 3], stride=2, scope='conv3')
        out2 = instance_norm(out2)
        # out2 = relu(img_scale(out2, 0.5))

        # transform
        out3 = res_module(out2, 128, name='residual1')
        out3 = res_module(out3, 128, name='residual2')
        out3 = res_module(out3, 128, name='residual3')
        out3 = res_module(out3, 128, name='residual4')
        # decoder
        out4 = img_scale(out3, 2)
        out4 = slim.conv2d(out4, 64, [3, 3], stride=1, scope='conv4')
        out4 = relu(instance_norm(out4))
        # out4 = img_scale(out4, 128)

        out4 = img_scale(out4, 2)
        out4 = slim.conv2d(out4, 32, [3, 3], stride=1, scope='conv5')
        out4 = relu(instance_norm(out4))
        # out4 = img_scale(out4, 256)

        out = slim.conv2d(out4, 3, [9, 9], scope='conv6')
        out = tf.nn.tanh(instance_norm(out))

        variables = tf.contrib.framework.get_variables(vs)

        out = (out + 1) * 127.5

        height = out.get_shape()[1].value  # if is_train else tf.shape(out)[0]
        width = out.get_shape()[2].value  # if is_train else tf.shape(out)[1]

        out = tf.image.crop_to_bounding_box(out, 10, 10, height-20, width-20)
        # out = tf.reshape(out, imgs_shape)

    return out, variables

其中instance_norm是归一化部分[5],res_module是残差块,image_scale是采样部分,scale因子是2表示上采样,特征图扩大2倍:

def img_scale(x, scale):
    weight = x.get_shape()[1].value
    height = x.get_shape()[2].value

    try:
        out = tf.image.resize_nearest_neighbor(x, size=(weight*scale, height*scale))
    except:
        out = tf.image.resize_images(x, size=[weight*scale, height*scale])
    return out

# net = slim.conv2d(net, 4096, [1, 1], scope='fc7')

def res_module(x, outchannel, name):
    with tf.variable_scope(name_or_scope=name):
        out1 = slim.conv2d(x, outchannel, [3, 3], stride=1, scope='conv1')
        out1 = relu(out1)
        out2 = slim.conv2d(out1, outchannel, [3, 3], stride=1, scope='conv2')
        out2 = relu(out2)

        return x+out2

def instance_norm(x):
    epsilon = 1e-9

    mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)

    return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))

2.2图的构建

此部分流程:读取训练数据(coco数据集) − − 读取风格图像 − − 并输入图像转换网络计算出转换后的图像gen_img − − 原始图像,风格图像,转换后的图像一同输入VGG计算loss − − VGG权重加载

 def build_model(self):
        # data_path = '/home/liu/Tensorflow/BEGAN/Data/celeba/img_align_celeba'
        data_path = '/home/liu/Downloads/train2014'
        # 加载训练数据(coco数据集)
        imgs = load_data.get_loader(data_path, self.batch_size, self.img_size)
        # 加载风格图像
        style_imgs = load_style_img()

        with slim.arg_scope(model.arg_scope()):
            # 图像转换网络
            gen_img, variables = model.gen_net(imgs, reuse=False, name='transform')

            with slim.arg_scope(vgg.vgg_arg_scope()):
                # 对图像做处理
                gen_img_processed = [load_data.img_process(image, True)
                                     for image in tf.unstack(gen_img, axis=0, num=self.batch_size)]
                # f表示vgg每段卷积的特征图输出, exclude是VGG不需要加载的变量的名字
                f1, f2, f3, f4, exclude = vgg.vgg_16(tf.concat([gen_img_processed, imgs, style_imgs], axis=0))

                gen_f, img_f, _ = tf.split(f3, 3, 0)
                # 计算损失 content loss 和 style loss
                content_loss = tf.nn.l2_loss(gen_f - img_f) / tf.to_float(tf.size(gen_f))

                style_loss = model.styleloss(f1, f2, f3, f4)

                # load vgg model
                vgg_model_path = '/home/liu/Tensorflow-Project/temp/model/vgg_16.ckpt'
                vgg_vars = slim.get_variables_to_restore(include=['vgg_16'], exclude=exclude)
                # vgg_init_var = slim.get_variables_to_restore(include=['vgg_16/fc6'])
                init_fn = slim.assign_from_checkpoint_fn(vgg_model_path, vgg_vars)
                init_fn(self.sess)
                # tf.initialize_variables(var_list=vgg_init_var)
                print 'vgg s weights load done'

            self.gen_img = gen_img

            self.global_step = tf.Variable(0, name="global_step", trainable=False)

            self.content_loss = content_loss
            self.style_loss = style_loss*100   # 100是随意设置的,可以调整控制风格迁移的程度
            self.loss = self.content_loss + self.style_loss
            self.opt = tf.train.AdamOptimizer(0.0001).minimize(self.loss, global_step=self.global_step, var_list=variables)

        all_var = tf.global_variables()
        # init_var = [v for v in all_var if 'beta' in v.name or 'global_step' in v.name or 'Adam' in v.name]
        init_var = [v for v in all_var if 'vgg_16' not in v.name]
        init = tf.variables_initializer(var_list=init_var)
        self.sess.run(init)

        self.save = tf.train.Saver(var_list=variables)

训练部分代码:

    def train(self):
        print ('start to training')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            while not coord.should_stop():
                # start_time = time.time()
                _, loss, step, cl, sl = self.sess.run([self.opt, self.loss, self.global_step, self.content_loss, self.style_loss])

                if step%100 == 0:
                    gen_img = self.sess.run(self.gen_img)
                    if not os.path.exists('gen_img'):
                        os.mkdir('gen_img')
                    save_img.save_images(gen_img, './gen_img/{0}.jpg'.format(step/100))

                print ('[{}/40000],loss:{}, content:{},style:{}'.format(step, loss, cl, sl))

                if step % 2000 == 0:
                    if not os.path.exists('model_saved_s'):
                        os.mkdir('model_saved_s')
                    self.save.save(self.sess, './model_saved_s/wave{}.ckpt'.format(step/2000))
                # 训练40000次就停止,大概2epoch
                if step >= 40000:
                    break

        except tf.errors.OutOfRangeError:
                self.save.save(sess, os.path.join(os.getcwd(), 'fast-style-model.ckpt-done'))
        finally:
            coord.request_stop()
        coord.join(threads)

总结:


本文浮现的论文仍然有一些不足之处,比如根据一个风格图像训练一个model只能风格化此种图像,要风格化很多种图像就要训练不同的model,不过在后来的论文中已经得到了解决,以后有时间我会继续复现。
如果感兴趣,请关注微信公众号,还有更多精彩:
这里写图片描述

参考文献:

[1] Image Style Transfer Using Convolutional Neural Networks
[2] https://arxiv.org/abs/1603.08155
[3] https://www.zhihu.com/question/49805962/answer/199427278
[4] https://arxiv.org/abs/1603.08155
[5] https://github.com/hzy46/fast-neural-style-tensorflow

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

风格迁移背后原理及tensorflow实现 的相关文章

随机推荐

  • FasterTransformer 安装

    第一 安装TensorRT FasterTransformer 支持cuda10 0 所以TensorRT也下cuda10 0对应版本 1 下载TensorRT https developer nvidia com nvidia tenso
  • Android中的Wifi框架知识点

    一 Android wifi框架图 Android WIFI系统引入了wpa supplicant 它的整个WIFI系统以wpa supplicant为核心来定义上层接口和下层驱动接口 Android WIFI主要分为六大层 分别是WiFi
  • TCP套接字网络编程实例(二)

    TCP套接字网络编程实例 二 采用多线程实现客户端和服务器的聊天功能 OK 上代码 1 客户端部分 文件 tcp client c 内容 利用TCP实现客户端和服务器的实时聊天 注 服务器的端口号及IP 客户端的端口号及IP的输入通过mai
  • angular2+修改环境变量,不用重新build部署

    转载文章链接 How to use environment variables to configure your Angular application without a rebuild 整体思路 在assets目录下添加js文件 获取
  • warning negative label/yolo标签出现负值?

    问题如上图所示 出现场景 图像增强后 xml文件转txt文件 解决方法1 无脑粗暴 一秒见效 直接把负值转成正值 不影响标注与训练结果 代码如下 import os def process files in folder folder pa
  • BLE低功耗的设置参数

    广播间隔 连接间隔 扫描间隔 扫描窗口 广播间隔 两个相邻广播事件之间的时间称为广播间隔 可以选择 20ms 10 28s 不等的间隔 通常 一个广播中的设备会每一秒广播一次 必须是 0 625ms 的整数倍 由于设备间的时钟会不同程度的漂
  • f12弹出debug_360浏览器网站按f12弹出新窗口解决办法

    为何有些网站在360浏览器中按f12弹出新窗口 而不是在当前页面右侧出现调试部分呢 对于网站开发人员来说 我们需要得到的是在当前页面出现调试结果 不管是本地测试文件还是已经上线的网站 这里成都seo小冬 总结了下面三点 一起来试试吧 1 切
  • Nali:一个离线查询 IP 地理信息和 CDN 提供商的终端利器

    什么是 Nali dig nslookup traceroute 等都是非常实用的网络调试工具 Nali 是一个可以自动解析这些工具 stdout 中的 IP 并附上所属 ISP 和地理信息 对这些已是瑞士军刀般的工具可谓如虎添翼 Nali
  • 力扣(LeetCode)每日一题 LCP 50. 宝石补给

    简单题 不用解释直接看代码 class Solution public int giveGem int gem int operations for int i 0 i
  • Qt笔记(六十)之Qt实现无边框圆角窗口

    一 前言 设置无边框窗口之后 就会显示直角的风格 有用户反馈说 看着太锐了 让我给换成圆角 看着舒服一点 楼主一开始想用Qss实现 发现实在不行 后边想着 估计只能用绘图事件来操作了 二 实现过程 1 实现窗口无边框效果 setWindow
  • java实现mysql数据库增删改查

    本文将介绍java实现数据库增删改查的操作方法定义的代码 包括statment和preparestatment两种模式 两种的区别可以参考别的文章 按需选用 例 getdata是statment的查询的方法 pgetdata是prepare
  • flutter获取验证码输入框组件

    代码 import package flutter material dart class ValidataInputBoxWidget extends StatefulWidget ValidataInputBoxWidget Key k
  • dedecms列表页上一页下一页翻页单独调用的方法

    本文实例讲述了dedecms列表页上一页下一页翻页单独调用的方法 分享给大家供大家参考 具体实现方法如下 在列表页单独调用上一页和下一页 以及首页 简单搞了一下 仅作上下翻页 主页类似 可自行添加 在模板中以 复制代码代码如下 dede p
  • c++11中的tuple(元组)

    转自 http www cnblogs com qicosmos p 3318070 html 这次要讲的内容是 c 11中的tuple 元组 tuple看似简单 其实它是简约而不简单 可以说它是c 11中一个既简单又复杂的东东 关于它简单
  • [深入研究4G/5G/6G专题-61]: 关键概念和常见问题之Cell, UE 上下文, RRC连接,PDU会话, ,SRB Bear, DRB Bear,Qos Flow,

    目录 第1章 协议栈与承载 1 1 LTE空口协议栈 1 2 5G 空口协议栈 第2章 L3关键数据对象的层次架构
  • 一个关于malloc的面试题

    发表于1年前 2014 04 04 13 31 阅读 176 评论 0 9人收藏此文章 我要收藏 赞0 慕课网 程序员升职加薪神器 点击免费学习 前两天看了一个关于malloc的面试题 题目是这样的 1 2 3 4 5
  • 智慧城市视域下政府数据开放共享机制研究

    大数据背景下 基于对数据挖掘和运用基础上的智慧城市建设是城市发展的必然趋势 是新时期实现城市科学发展 高效管理与公共服务更优化的重要战略 政府作为社会管理 城市治理的主体 其形成 管理的数据资源约占全社会总量的80 对其进行有效治理将使数据
  • 程序员被裁,在互联网上该何去何从?

    你一直以来都是一名互联网程序员 而现在你却被裁了 你感到无助和迷茫 不知道该往哪里走 这是一个常见的问题 但是不要担心 我们来看看你未来有哪些方向可以选择 转行去做别的 首先 如果你对现在的工作感到厌倦或者想要寻找新的挑战 你可以考虑转行去
  • JAVA输出语句

    带回车的输出 System out println 输出内容 不带回车的输出 System out print 输出内容 拼接输出 int a 5 System out println a的值为 a
  • 风格迁移背后原理及tensorflow实现

    前言 本文将详细介绍 tf 实现风格迁移的小demo 看完这篇就可以去实现自己的风格迁移了 复现的算法来自论文 Perceptual P e r c e p t u a l