Pytorch Advanced(三) Neural Style Transfer

2023-11-16

神经风格迁移在之前的博客中已经用keras实现过了,比较复杂,keras版本

这里用pytorch重新实现一次,原理图如下:


from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

加载图像

def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image and convert it to a torch tensor."""
    image = Image.open(image_path)
    
    if max_size:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(size.astype(int), Image.ANTIALIAS)
    
    if shape:
        image = image.resize(shape, Image.LANCZOS)
    
    if transform:
        image = transform(image).unsqueeze(0)
    
    return image.to(device)

这里用的模型是 VGG-19,所要用的是网络中的5个卷积层

class VGGNet(nn.Module):
    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps."""
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28'] 
        self.vgg = models.vgg19(pretrained=True).features
        
    def forward(self, x):
        """Extract multiple convolutional feature maps."""
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

 模型结构如下,可以看到使用序列模型来写的VGG-NET,所以标号即层号,我们要保存的是['0', '5', '10', '19', '28'] 的输出结果。

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

 训练:

接下来对训练过程进行解释:

1、加载风格图像和内容图像,我们在之前的博客中使用的一幅加噪图进行训练,这里是用的内容图像的拷贝。

2、我们需要优化的就是作为目标的内容图像拷贝,可以看到target需要求导。

3、VGGnet参数是不需要优化的,所以设置为验证状态。

4、将3幅图像输入网络,得到总共15个输出(每个图像有5层的输出)

5、内容损失:这里是遍历5个层的输出来计算损失,而在keras版本中只用了第4层的输出计算损失

6、风格损失:同样计算格拉姆风格矩阵,将每一层的风格损失叠加,得到总的风格损失,计算公式同样和keras版本有所不一样

7、反向传播

def main(config):
    
    # Image preprocessing
    # VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
    # We use the same normalization statistics here.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), 
                             std=(0.229, 0.224, 0.225))])
    
    # Load content and style images
    # Make the style image same size as the content image
    content = load_image(config.content, transform, max_size=config.max_size)
    style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])
    
    # Initialize a target image with the content image
    target = content.clone().requires_grad_(True)
    
    optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
    vgg = VGGNet().to(device).eval()
    
    for step in range(config.total_step):
        
        # Extract multiple(5) conv feature vectors
        target_features = vgg(target)
        content_features = vgg(content)
        style_features = vgg(style)

        style_loss = 0
        content_loss = 0
        for f1, f2, f3 in zip(target_features, content_features, style_features):
            # Compute content loss with target and content images
            content_loss += torch.mean((f1 - f2)**2)

            # Reshape convolutional feature maps
            _, c, h, w = f1.size()
            f1 = f1.view(c, h * w)
            f3 = f3.view(c, h * w)

            # Compute gram matrix
            f1 = torch.mm(f1, f1.t())
            f3 = torch.mm(f3, f3.t())

            # Compute style loss with target and style images
            style_loss += torch.mean((f1 - f3)**2) / (c * h * w) 
        
        # Compute total loss, backprop and optimize
        loss = content_loss + config.style_weight * style_loss 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step+1) % config.log_step == 0:
            print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' 
                   .format(step+1, config.total_step, content_loss.item(), style_loss.item()))

        if (step+1) % config.sample_step == 0:
            # Save the generated image
            denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
            img = target.clone().squeeze()
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))

写在if __name__=="__main__"后面的语句只会在本脚本中才能被执行,被调用时是不会被执行的。 

python的命令行工具:argparse,很优雅的添加参数

但是由于jupyter不支持添加外部参数,所以使用了外部博客的方法来支持(记住更改读取图片的位置)

import sys
if __name__ == "__main__":
    
    #解决方案来自于博客
    if '-f' in sys.argv:
        sys.argv.remove('-f')
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--content', type=str, default='png/content.png')
    parser.add_argument('--style', type=str, default='png/style.png')
    parser.add_argument('--max_size', type=int, default=400)
    parser.add_argument('--total_step', type=int, default=2000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)
    parser.add_argument('--style_weight', type=float, default=100)
    parser.add_argument('--lr', type=float, default=0.003)
    #config = parser.parse_args()
    config = parser.parse_known_args()[0]   #参考博客 https://blog.csdn.net/ken_for_learning/article/details/89675904
    print(config)
    main(config)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch Advanced(三) Neural Style Transfer 的相关文章

  • 如何避免使用 python 处理空的标准输入?

    The sys stdin readline 返回之前等待 EOF 或新行 所以如果我有控制台输入 readline 等待用户输入 相反 我想打印帮助并在没有需要处理的情况下退出并显示错误 而不是等待用户输入 原因 我正在寻找一个Pytho
  • 如何使用 Python 3 绕过 HTTP Error 403: Forbidden with urllib.request

    您好 不是每次都这样 但有时在尝试访问 LSE 代码时 我会收到每一个烦人的 HTTP 错误 403 禁止消息 任何人都知道我如何仅使用标准 python 模块来克服这个问题 遗憾的是没有漂亮的汤 import urllib request
  • 此 TypeError 消息中提到的“代码对象”是什么?

    在尝试使用Python时exec声明 我收到以下错误 TypeError exec arg 1 must be a string file or code object 我不想传递字符串或文件 但什么是代码对象 如何创建一个 创建代码对象的
  • 稀有对象的 python 类型注释,例如 psycopg2 对象

    我了解内置类型 但是我如何指定稀有对象 例如数据库连接对象 def get connection and cursor gt tuple psycopg2 extensions cursor psycopg2 extensions conn
  • 如何过滤 Pandas GroupBy 对象并获取 GroupBy 对象?

    当对 Pandas groupby 操作的结果执行过滤时 它返回一个数据帧 但假设我想执行进一步的分组计算 我必须再次调用 groupby 这似乎有点绕 有更惯用的方法吗 EDIT 为了说明我在说什么 我们无耻地从 Pandas 文档中窃取
  • 创建上下文后将 jar 文件添加到 pyspark

    我正在笔记本上使用 pyspark 并且不处理 SparkSession 的创建 我需要加载一个包含一些我想在处理 rdd 时使用的函数的 jar 您可以使用 jars 轻松完成此操作 但在我的特定情况下我无法做到这一点 有没有办法访问sp
  • 使用 Paramiko 进行 DSA 密钥转发?

    我正在使用 Paramiko 在远程服务器上执行 bash 脚本 在其中一些脚本中 存在与其他服务器的 ssh 连接 如果我只使用 bash 不使用 Python 我的 DSA 密钥将被第一个远程服务器上的 bash 脚本转发并使用 以连接
  • 在Python上获取字典的前x个元素

    我是Python的新手 所以我尝试用Python获取字典的前50个元素 我有一本字典 它按值降序排列 k 0 l 0 for k in len dict d l 1 if l lt 51 print dict 举个小例子 dict d m
  • Arcpy 模数在 Pycharm 中不显示

    如何将 Arcpy 集成到 Pycharm 中 我尝试通过导入模块但它没有显示 我确实知道该模块仅适用于 2 x python arcpy 在 PyPi Python 包索引 上不可用 因此无法通过 pip 安装 要使用 arcpy 您需要
  • Python HMAC:类型错误:字符映射必须返回整数、None 或 unicode

    我在使用 HMAC 时遇到了一个小问题 运行这段代码时 signature hmac new key secret key msg string to sign digestmod sha1 我收到一个奇怪的错误 File usr loca
  • Python Anaconda:如何测试更新的库是否与我现有的代码兼容?

    我在 Windows 7 机器上使用 Python 2 7 Anaconda 安装进行数据分析和科学计算 当新的库发布时 例如新版本的 pandas patsy 等 您建议我如何测试新版本与现有代码的兼容性 是否可以在同一台机器上安装两个
  • 查找 Pandas DF 行中的最短日期并创建新列

    我有一个包含多个日期的表 有些日期将为 NaN 我需要找到最旧的日期 所以一行可能有 DATE MODIFIED WITHDRAWN DATE SOLD DATE STATUS DATE 等 因此 对于每一行 一个或多个字段中都会有一个日期
  • 从 Flask 运行 NPM 构建

    我有一个 React 前端 我想在与我的 python 后端 API 相同的源上提供服务 我正在尝试使用 Flask 来实现此目的 但我遇到了 Flask 找不到我的静态文件的问题 我的前端构建是用生成的npm run build in s
  • pandas 相当于 np.where

    np where具有向量化 if else 的语义 类似于 Apache Spark 的when otherwise数据帧方法 我知道我可以使用np where on pandas Series but pandas通常定义自己的 API
  • 是否需要关闭没有引用它们的文件?

    作为一个完全的编程初学者 我试图理解打开和关闭文件的基本概念 我正在做的一项练习是创建一个脚本 允许我将内容从一个文件复制到另一个文件 in file open from file indata in file read out file
  • 如何将带有参数的Python装饰器实现为类?

    我正在尝试实现一个接受一些参数的装饰器 通常带有参数的装饰器被实现为双重嵌套闭包 如下所示 def mydecorator param1 param2 do something with params def wrapper fn def
  • minizinc python 安装

    我通过 anaconda 提示符在 python 上安装了 minizinc 就像其他软件包一样 pip install minizinc 该软件包表示已成功安装 我可以导入该模块 但是 我正在遵循基本示例https minizinc py
  • 将 Keras 集成到 SKLearn 管道?

    我有一个 sklearn 管道 对异构数据类型 布尔 分类 数字 文本 执行特征工程 并想尝试使用神经网络作为我的学习算法来拟合模型 我遇到了输入数据形状的一些问题 我想知道我想做的事情是否可能 或者我是否应该尝试不同的方法 我尝试了几种不
  • 如何使用 python 定位和读取 Data Matrix 代码

    我正在尝试读取微管底部的数据矩阵条形码 我试过libdmtx http libdmtx sourceforge net 它有 python 绑定 当矩阵的点是方形时工作得相当好 但当矩阵的点是圆形时工作得更糟 如下所示 另一个复杂问题是在某
  • 定义在文本小部件中双击时选择哪些字符

    在 Windows 上 双击文本小部件中的单词也将选择连接的标点符号 有什么方法可以定义您想要选择的角色吗 tcl wordchars该变量的值是一个正则表达式 可以设置它来控制什么被视为 单词 字符 例如 通过双击 Tk 中的文本来选择单

随机推荐

  • SiC MOSFET应用中出现的串扰问题,提出3种有效应用对策

    针对 SiC MOSFET 模块应用中出现的串扰问题 百度网盘 请输入提取码 提取码9dfv 本文对测量使用的差分探头进行了详细对比 由结果可知采用高带宽和高采样率的示波器和差分探头可测 量得到准确的信号波形 同时分析了串扰问题的产生 机制
  • 基于Xilinx XDMA 的PCIE通信

    基于Xilinx XDMA 的PCIE通信 概述 想实现基于FPGA的PCIe通信 查阅互联网各种转载 基本都是对PCIe的描述 所以想写一下基于XDMA的PCIe通信的实现 PCIe结构仅做简单的描述 笔记 了解详细结构移至互联网 实践实
  • GPT概述

    全局唯一标识分区表 GUID Partition Table 缩写 GPT 是一个实体硬盘的分区结构 它是可扩展固件接口标准的一部分 用来替代BIOS中的主引导记录分区表 传统的主启动记录 MBR 磁盘分区支持最大卷为 2 2 TB ter
  • C++之继承

    目录 1 继承的概念及定义 1 继承的概念 2 继承定义 2 基类和派生类对象赋值转换 3 继承中的作用域 4 派生类的默认成员函数 5 继承与友元 6 继承与静态成员 7 复杂的菱形继承及菱形虚拟继承 1 单继承 2 多继承 3 菱形继承
  • 拆解雪花算法生成规则

    1 介绍 雪花算法 Snowflake 是一种生成分布式全局唯一 ID 的算法 生成的 ID 称为 Snowflake IDs 或 snowflakes 这种算法由 Twitter 创建 并用于推文的 ID 目前仓储平台生成 ID 是用的雪
  • visual studio code 2019远程连接服务器

    一 安装sftp 二 配置sftp 按住ctrl ship p键 得到以下画面 选择SFTP Config 当右下角出现 意思时需要一个文件夹 点击open folder后 选择或者创建一个文件夹 再回来按住ctrl shif p就会看到一
  • QT 总结(三) 1.Qt 运行 bat 文件 QProcess 2.获取当前文件路径

    1 Qt 运行 bat 文件 QProcess QProcess p p start cmd exe QStringList lt lt c lt lt c WINDOWS upan2 bat if p waitForStarted p w
  • 求助:tp-link wr720n路由器,想刷打印服务器!

    求助 tp link wr720n路由器 想刷打印服务器 求固件和教程 希望大神赐教
  • 智能图像水位识别系统的工作原理

    系统组成 智能水位图像识别系统主要包括前端设备 传输网络 平台软件和显示终端 采用定时抓拍和自主抓拍图像两种形式 定时或根据需要上传水尺图片 前端设备主要包括网络高速摄像机 水尺 4G流量卡 传输网络主要通过4G网络传输至信息中心 在信息中
  • 会话技术Cookie&Session

    1 会话技术 从打开一个浏览器访问某个站点 到关闭这个浏览器的整个过程 成为一次会话 会 话技术就是记录这次会话中客户端态的状与数据的 会话技术分为Cookie和Session Cookie 数据存储在客户端本地 减少服务器端的存储的压力
  • Ubuntu20.04下载安装FFmpeg源码,并且编译FFmpeg

    一 Terminal终端输入 git clone git source ffmpeg org ffmpeg git ffmpeg 二 安装依赖环境 sudo apt get install y autoconf automake build
  • python基础之程序执行原理(科普)

    提示 文章写完后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 一 计算机的三大件 二 计算机执行 三 python程序的执行原理 四 程序的作用 一 计算机的三大件 1 cpu 本质上是一块超大规模集成电路 2 内存 存储设备
  • 博客摘录「 【MySQL】事务及其隔离性/隔离级别」2023年8月31日

    一般的数据库在可重复读情况的时候 无法屏蔽其他事务insert的数据 因为隔离性实现是对数据加锁完成的 而insert待插入的数据因为并不存在 那么一般加锁无法屏蔽这类问题 这会造成大部分内容虽然是可重复读的 但是insert的数据在可重复
  • web前端面试题整理(前端和计算机相关知识)

    1 你能描述一下渐进增强和优雅降级之间的不同吗 定义 优雅降级 graceful degradation 一开始就构建站点的完整功能 然后针对浏览器测试和修复 渐进增强 progressive enhancement 一开始只构建站点的最少
  • 面试阿里测开岗失败后,被面试官在朋友圈吐槽了......

    前一阵子有个徒弟向我诉苦 说自己在参加某大厂测试面试的时候被面试官怼得哑口无言 场面让他一度十分尴尬 印象最深的就是下面几个问题 根据你以前的工作经验和学习到的测试技术 说说你对质量保证的理解 非关系型数据库和关系型数据库的区别 谈谈优势比
  • 用tornado 连接mysql进行操作报错sqlalchemy.exc.OperationalEror: (pymysql.err.perationalError)(235,Can‘t comect

    用tornado 连接mysql进行操作报错sqlalchemy exc OperationalEror pymysql err perationalError 235 Can t comect to Wy lL serVer on 192
  • 【创作赢红包】云原生之使用Docker部署YApi接口管理服务平台

    云原生之使用Docker部署YApi接口管理服务平台 一 YApi介绍 1 YApi简介 2 YApi功能 二 检查docker环境 1 检查docker版本 2 检查docker状态 三 安装MongoDB数据库 1 创建MongoDB数
  • 百度文库等类似工具的免费下载工具

    百度文库如何免费下载文献 软件介绍 百度文库如何免费下载文献 冰点文库下载器V3 1 9 亲测 可用 软件介绍 无需积分就可以自由下载百度 豆丁 丁香 MBALib 道客巴巴 Book118等文库文档 无需注册和登录 下载的文档最终生成高清
  • 跳动爱心代码-李峋爱心代码(手把手教学)

    电视剧 点燃我 温暖你 打火机与公主裙 李洵爱心跳动效果 获取完整代码 公众号 ClassmateJie 回复爱心代码 本文分为两种方式讲解如何运行代码 第一种方式比较简单推荐新手 完全不懂编程的 第二种方式需要有一定的编程基础的人 跟着我
  • Pytorch Advanced(三) Neural Style Transfer

    神经风格迁移在之前的博客中已经用keras实现过了 比较复杂 keras版本 这里用pytorch重新实现一次 原理图如下 from future import division from torchvision import models