华为开源自研AI框架昇思MindSpore应用案例:Colorization自动着色

2023-05-16

目录

  • 一、环境准备
    • 1.进入ModelArts官网
    • 2.使用ModelArts体验实例
  • 二、数据处理
    • 数据准备
    • 训练集可视化
    • 构建网络
    • 损失函数
  • 三、模型实现
    • 算法流程
    • 模型训练
    • 模型推理
    • 总结

自动着色算法之Colorization
当桃乐丝在1939年的电影《绿野仙踪》中走进奥兹国时,从黑白到鲜艳的色彩的转变使它成为电影史上最令人叹为观止的时刻之一。毫无疑问,颜色是一种有效的表达工具,但它们通常是有代价的。在制作现代动画电影和漫画时,图像着色是最费力和昂贵的阶段之一。自动着色过程可以帮助减少制作漫画或动画电影所需的成本和时间

模型简介
Colorization算法是来自加里福利亚大学的一项研究,采用的是CNN的结构。该算法可以实现灰度图像的自动着色,由Richard
Zhang等人在论文Colorful Image
Colorization中提出,并发表在2016年的ECCV会议中。该模型由8个conv层组成,每个conv层由2个或3个重复的卷积层和ReLU层组成,后面跟着一个BatchNorm层。网络中不包含池化层。

网络特点

  1. 设计了一个合适的损失函数来处理着色问题中的多模不确定性,维持了颜色的多样性。
  2. 将图像着色任务转化为一个自监督表达学习的任务。
  3. 在一些基准模型上获得了最好的效果。

完整的样例代码:Colorization.ipynb

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

在这里插入图片描述

一、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述

等待环境搭建完成

在这里插入图片描述

2.使用ModelArts体验实例

进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述

在ModelArts中切换规格

在这里插入图片描述

打开一个Terminal,输入安装命令

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge

在这里插入图片描述

再点击侧边栏中的Clone a Repository,输入

https://github.com/mindspore-courses/applications.git

在这里插入图片描述

二、数据处理

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore Vision套件。

数据准备

本案例使用ImageNet数据集作为训练集和测试集。请在官网下载。训练集中包含1000个类别,总计大约120万张图片,测试集中包含5万图片。

解压后的数据集目录结构如下:

.dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
└── val/

训练集可视化

import os
import argparse
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import mindspore
from src.process_datasets.data_generator import ColorizationDataset


#加载参数
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='./dataset/train', help='path to dataset')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--shuffle', type=bool, default=True)
args = parser.parse_args(args=[])
plt.figure()

#加载数据集
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle, args.num_parallel_workers)
data = dataset.run()
show_data = next(data.create_tuple_iterator())
show_images_original, _ = show_data
show_images_original = show_images_original.asnumpy()
#循环处理
for i in range(1, 5):
    plt.subplot(1, 4, i)
    temp = show_images_original[i-1]
    temp = np.clip(temp, 0, 1)
    plt.imshow(temp)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)

在这里插入图片描述

构建网络

处理完数据后进行网络的搭建,Colorization的网络结构较为简单,采用CNN的网络结构。具体结构如下图所示
在这里插入图片描述
网络的详细配置为:
在这里插入图片描述
其中X输出的空间分辨率,C输出的通道数;S计算步幅,大于1表示卷积后下采样,小于1表示卷积前上采样;D内核扩张;Sa在所有前一层的累积步数(积于前一层的所有步数);相对于输入的层的有效膨胀(层膨胀乘以累积步幅);BN层后是否使用BatchNorm层;L表示是否施加了1x1的卷积和交叉熵损失层。

损失函数

在这里插入图片描述
分类再平衡
在这里插入图片描述

分类概率到点估计

在这里插入图片描述

class NetLoss(nn.Cell):
    """连接网络和损失"""
    def __init__(self, net):
        super(NetLoss, self).__init__(auto_prefix=True)
        self.net = net
        self.loss = nn.CrossEntropyLoss(reduction='none')

    def construct(self, images, targets, boost, mask):
        """ build network """
        outputs = self.net(images)
        boost_nongray = boost * mask
        squeeze = mindspore.ops.Squeeze(1)
        boost_nongray = squeeze(boost_nongray)
        result = self.loss(outputs, targets)
        result_loss = (result * boost_nongray).mean()
        return result_loss

在这里插入图片描述

三、模型实现

MindSpore要求将损失函数、优化器等操作也看做nn.Cell的子类,所以我们可以自定义Color类,将网络和loss连接起来。

class ColorModel(nn.Cell):
    """定义Colorization网络"""

    def __init__(self, my_train_one_step_cell_for_net):
        super(ColorModel, self).__init__(auto_prefix=True)
        self.my_train_one_step_cell_for_net = my_train_one_step_cell_for_net

    def construct(self, result, targets, boost, mask):
        loss = self.my_train_one_step_cell_for_net(result, targets, boost,
                                                   mask)
        return loss

在这里插入图片描述

算法流程

在这里插入图片描述

模型训练

实例化损失函数,优化器,使用Model接口编译网络,开始训练。

import argparse
import os
from tqdm import tqdm

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import ops
import numpy as np
import matplotlib.pyplot as plt
from src.utils.utils import PriorBoostLayer, NNEncLayer, NonGrayMaskLayer, decode

from src.model.model import ColorizationModel
from src.model.colormodel import ColorModel
from src.process_datasets.data_generator import ColorizationDataset
from src.losses.loss import NetLoss
import warnings

warnings.filterwarnings('ignore')
#加载参数

parser = argparse.ArgumentParser()
parser.add_argument('--device_target',
                    default='GPU',
                    choices=['CPU', 'GPU', 'Ascend'],
                    type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--image_dir',
                    type=str,
                    default='./dataset/train',
                    help='path to dataset')
parser.add_argument('--checkpoint_dir',
                    type=str,
                    default='./checkpoints',
                    help='path for saving trained model')
parser.add_argument('--test_dirs',
                    type=str,
                    default='./images',
                    help='path for saving trained model')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--shuffle', type=bool, default=True)
parser.add_argument('--num_epochs', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_parallel_workers', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.5e-4)
parser.add_argument('--save_step',
                    type=int,
                    default=200,
                    help='step size for saving trained models')
args = parser.parse_args(args=[])

if context.get_context('device_id') != args.device_id:
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

encode_layer = NNEncLayer(args)
boost_layer = PriorBoostLayer(args)
non_gray_mask = NonGrayMaskLayer()

#网络实例化
net = ColorizationModel()

#设置优化器
net_args = nn.Adam(net.trainable_params(), learning_rate=args.learning_rate)

#实例化NetLoss
net_with_criterion = NetLoss(net)

#实例化TrainOneStepWithLossScaleCell
scale_sense = nn.FixedLossScaleUpdateCell(1)
myTrainOneStepCellForNet = nn.TrainOneStepWithLossScaleCell(
    net_with_criterion, net_args, scale_sense=scale_sense)
colormodel = ColorModel(myTrainOneStepCellForNet)
colormodel.set_train()

#加载数据集
dataset = ColorizationDataset(args.image_dir, args.batch_size, args.shuffle,
                              args.num_parallel_workers)
data = dataset.run().create_tuple_iterator()

for epoch in range(args.num_epochs):
    iters = 0

    #为每轮训练读入数据
    for images, img_ab in tqdm(data):
        images = ops.expand_dims(images, 1)
        encode, max_encode = encode_layer.forward(img_ab)
        targets = mindspore.Tensor(max_encode, dtype=mindspore.int32)
        boost = mindspore.Tensor(boost_layer.forward(encode),
                                 dtype=mindspore.float32)
        mask = mindspore.Tensor(non_gray_mask.forward(img_ab),
                                dtype=mindspore.float32)
        net_loss = colormodel(images, targets, boost, mask)
        #输出训练数据
        print('[%d/%d]\tLoss_net:: %.4f' % (epoch + 1, args.num_epochs, net_loss[0]))
        #中间保存训练结果
        if iters % args.save_step == 0:
            if not os.path.exists(args.checkpoint_dir):
                os.makedirs(args.checkpoint_dir)
            mindspore.save_checkpoint(
                net,
                os.path.join(args.checkpoint_dir, 'net' + str(epoch + 1) + '_' +
                             str(iters) + '.ckpt'))
            img_ab_313 = net(images)
            out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
            color_img = decode(images, img_ab_313, args.resource)
            if not os.path.exists(args.test_dirs):
                os.makedirs(args.test_dirs)
            plt.imsave(
                args.test_dirs + '/' + str(epoch + 1) + '_' + str(iters) +
                '%s_infer.png', color_img)
        iters = iters + 1

在这里插入图片描述

在这里插入图片描述

模型推理

运行下面代码,将一张灰度图像输入到网络中,即可生成具有合理色彩的图像。

import argparse
import os

import matplotlib.pyplot as plt
import mindspore
import numpy as np
from mindspore import (context, load_checkpoint, load_param_into_net, ops)
from mindspore.train.model import Model
from tqdm import tqdm

from src.model.model import ColorizationModel
from src.process_datasets.data_generator import ColorizationDataset
from src.utils.utils import decode


parser = argparse.ArgumentParser()
parser.add_argument('--img_path', type=str, default='./dataset/val')
parser.add_argument('--ckpt_path', type=str, default='./checkpoints/net44_1600.ckpt')
parser.add_argument('--resource', type=str, default='./src/resources/')
parser.add_argument('--device_target', default='GPU', choices=['CPU', 'GPU', 'Ascend'], type=str)
parser.add_argument('--device_id', default=1, type=int)
parser.add_argument('--infer_dirs', default='./dataset/output', type=str)
args = parser.parse_args(args=[])


mindspore.context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

#实例化网络
net = ColorizationModel()

#加载参数
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(net, param_dict)
colorizer = Model(net)
dataset = ColorizationDataset(args.img_path, 1, prob=0)
data = dataset.run().create_tuple_iterator()
iters = 0

if not os.path.exists(args.infer_dirs):
    os.makedirs(args.infer_dirs)

#循环处理图像
for images, img_ab in tqdm(data):
    images = ops.expand_dims(images, 1)
    img_ab_313 = colorizer.predict(images)
    out_max = np.argmax(img_ab_313[0].asnumpy(), axis=0)
    color_img = decode(images, img_ab_313, args.resource)
    plt.imsave(args.infer_dirs+'/'+str(iters)+'_infer.png', color_img)
    iters = iters + 1

在这里插入图片描述
在这里插入图片描述

总结

本案例对Colorful Image
Colorization文中提出的模型进行了详细的解释,向读者完整地展现了该算法的流程,分析了Colorization在着色方面的优势和存在的不足。如需查看详细代码,可参考MindSpore
Vision套件。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

华为开源自研AI框架昇思MindSpore应用案例:Colorization自动着色 的相关文章

  • 13:SpringBoot跨域解决方案-Java Spring

    目录 13 1 CorsFilter13 2 64 CrossOrigin13 3 WebMvcConfigurer 13 1 CorsFilter SpringBoot设置CORS的的本质都是通过设置响应头信息来告诉前端该请求是否支持跨域
  • 14:Servlet并发机制-Java Spring

    目录 14 1 并发14 2 Servlet并发机制14 3 Tomcat并发特点14 4 Tomcat线程模型 14 1 并发 并发 xff08 Concurrent xff09 是指多个任务交替执行的现象 xff0c 把CPU运行时间划
  • 手写字体识别实验-Python课程设计

    安装python 打开手写识别文件夹中的安装包文件夹 xff0c 双击python3 7 1可执行文件 xff0c 进行安装 弹出窗口 第一步 xff0c 勾选第二个复选框 Add Python 3 7 to PATH xff0c 然后点击
  • 生产企业原材料订购与运输的研究-数据处理课程设计

    目录 摘要1 引言2 规划问题说明3 问题重述3 1 问题分析3 2 数据说明3 3 模型假设3 4 符号说明 4 实验及分析4 1 问题一模型的建立与求解4 2 问题二模型的建立与求解 5 总结5 1 模型的优点5 2 模型的缺点 参考文
  • 信号发生器-电路与电子技术课程设计

    目录 1 设计任务与要求1 1 设计任务1 2 设计要求 2 方案设计与论证2 1 方案设计2 2 论证 3 信号发生器设计与计算3 1 信号发生器设计3 2 方波振荡电路图3 3 三角波振荡电路图3 4 参数计算 4 总原理图及元器件清单
  • 增益可控放大电路-电路与电子技术课程设计

    目录 1 设计任务与要求1 1 设计任务1 2 设计要求 2 方案设计与论证2 1 方案设计2 2 论证 3 放大电路设计与计算3 1 放大电路设计3 2 电子开关切换电路设计3 3 六档控制电路3 4 参数计算 4 总原理图及元器件清单4
  • 超声波测距实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习超声波测距传感器的使用方法 xff0c 了解超声波测距传感器的原理和电路及实际应用 xff0c 了解超声波测距传感器的基本操作
  • 光敏传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习光敏传感器的使用方法 xff0c 了解光敏传感器的基本实验原理和实际应用 xff0c 熟练掌握光敏传感器实验的操作步骤 xff
  • 红外反射传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习红外反射传感器的使用方式 xff0c 了解红外反射传感器的实验原理和实际应用 xff0c 学习并理解Modbus数据格式所代表
  • 酒精传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习酒精传感器MQ 3的使用方法 xff0c 了解酒精传感器的实验原理和实际应用 xff0c 了解酒精传感器的基本操作模式 xff
  • hdoj 1575 Tr A (矩阵快速幂)

    Tr A Time Limit 1000 1000 MS Java Others Memory Limit 32768 32768 K Java Others Total Submission s 4549 Accepted Submiss
  • MapReduce排序过程

    排序是MapReduce框架中最重要的操作之一 MapTask和ReduceTask均会对数据按照key 进行排序 该操作属于Hadoop 的默认行为 xff0c 任何应用程序中的数据均会被排序 xff0c 而不管逻辑上是否需要 默认排序是
  • 温湿度传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习温湿度传感器的使用方法 xff0c 了解温湿度传感器的基本实验原理和实际应用 xff0c 熟练掌握温湿度传感器的基本步骤 xf
  • 烟雾检测传感器实验-传感器原理及应用实验

    目录 一 实验实训主要内容二 实验实训方法 过程步骤三 实验实训结果与分析四 讨论小结 一 实验实训主要内容 学习烟雾检测传感器的原理及检测方式 xff0c 了解烟雾检测传感器的实验原理和技术指标 xff0c 熟练掌握烟雾检测传感器的工作步
  • 4:Servlet-Java Web

    目录 4 1 Servlet简介4 2 HTTP协议4 3 Servlet与JSP4 4 Servlet处理的基本流程4 5 Servlet 容器4 6 Servlet程序实现 4 1 Servlet简介 Servlet是用Java语言编写
  • 5:Servlet程序-Java Web

    目录 5 1 Servlet要求5 2 创建Servlet5 3 第一个Servlet5 4 Servlet编译5 5 Servlet配置 5 1 Servlet要求 如果要开发一个可以处理HTTP请求的Servlet程序 xff0c 首先
  • 6:部署Servlet-Java Web

    目录 6 1 部署Servlet6 2 请求Servlet6 3 找不到servlet包6 4 Servlet映射的细节 6 1 部署Servlet 部署就是把Servlet的字节码文件放在适当的地方 为了在浏览器上访问Servlet xf
  • 7:Servlet表单-Java Web

    目录 7 1 Servlet响应7 2 Servlet获取客户端参数7 3 Servlet接受表单数据 7 1 Servlet响应 通过response对象对用户进行响应 创建输出流对象 PrintWriter out 61 respons
  • 8:Servlet生命周期-Java Web

    目录 8 1 Servlet生命周期8 2 Servlet生命周期对应的方法8 3 Servlet的多线程机制 8 1 Servlet生命周期 Servlet程序是运行在服务器端的一段Java程序 xff0c 其生命周期将受到Web容器的控
  • 9:中文乱码处理-Java Web

    目录 9 1 常见字符集9 2 乱码原因9 3 解决乱码 9 1 常见字符集 ASCII 最原始的一套编码 xff0c 所有编码都是由一个字节的二进制数对应 xff0c 尽管包含8位 xff0c 但是第一位始终是0 xff0c 也就是128

随机推荐