pytorch用自己数据集训练Unet

2023-05-16

在图像分割这个问题上,主要有两个流派:Encoder-Decoder和Dialated Conv。本文介绍的是编解码网络中最为经典的U-Net。随着骨干网路的进化,很多相应衍生出来的网络大多都是对于Unet进行了改进但是本质上的思路还是没有太多的变化。比如结合DenseNet 和Unet的FCDenseNet, Unet++

一、Unet网络介绍

论文: https://arxiv.org/abs/1505.04597v1 (2015)

UNet的设计就是应用与医学图像的分割。由于医学影像处理中,数据量较少,本文提出的方法有效提升了使用少量数据集训练检测的效果,提出了处理大尺寸图像的有效方法。

UNet的网络架构继承自FCN,并在此基础上做了些改变。提出了Encoder-Decoder概念,实际上就是FCN那个先卷积再上采样的思想。

上图是Unet的网络结构,从图中可以看出,

结构左边为Encoder,即下采样提取特征的过程。Encoder基本模块为双卷积形式, 即输入经过两个conv\quad3\times3,使用的valid卷积,在代码实现时我们可以增加padding使用same卷积,来适应Skip Architecture。下采样采用的池化层直接缩小2倍。

结构右边是Decoder,即上采样恢复图像尺寸并预测的过程。Decoder一样采用双卷积的形式,其中上采样使用转置卷积实现,每次转置卷积放大2倍。

结构中间copy and crop是一个cat操作,即feature map的通道叠加。

二、VOC训练Unet

2.1 Unet代码实现 

根据上面对于Unet网络结构的介绍,可见其结构非常对称简单,代码Unet.py实现如下:

from turtle import forward
import torch.nn as nn
import torch


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.output = nn.Conv2d(64, out_ch, 1)


    def forward(self, x):
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)

        up6 = self.up6(conv5)
        meger6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(meger6)
        up7 = self.up7(conv6)
        meger7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(meger7)
        up8 = self.up8(conv7)
        meger8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(meger8)
        up9 = self.up9(conv8)
        meger9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(meger9)
        out = self.output(conv9)
        
        return out



if __name__=="__main__":
    model = Unet(3, 21)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(model)

2.2 数据集处理

数据来源于kaggle,下载地址我忘了。包含2个类别,1个车,还有1个背景类,共有5k+的数据,按照比例分为训练集和验证集即可。具体见carnava.py

from PIL import Image
from requests import check_compatibility
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
import numpy as np
import os
import matplotlib.pyplot as plt



class Car(Dataset):
    def __init__(self, root, train=True):
        self.root = root
        self.crop_size = (256, 256)
        self.img_path = os.path.join(root, "train_hq")
        self.label_path = os.path.join(root, "train_masks")
        img_path_list = [os.path.join(self.img_path, im) for im in os.listdir(self.img_path)]
        train_path_list, val_path_list = self._split_data_set(img_path_list)

        if train:
            self.imgs_list = train_path_list
        else:
            self.imgs_list = val_path_list

        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
        self.transforms = T.Compose([
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                normalize
            ])

        self.transforms_val = T.Compose([
            T.Resize(256),
            T.CenterCrop(256)
        ])
        self.color_map = [[0, 0, 0], [255, 255, 255]]
        

    def __getitem__(self, index: int):
        im_path = self.imgs_list[index]
        image = Image.open(im_path).convert("RGB")
        data = self.transforms(image)
        (filepath, filename) = os.path.split(im_path)
        filename = filename.split('.')[0]
        label = Image.open(self.label_path +"/"+filename+"_mask.gif").convert("RGB")
        label = self.transforms_val(label)

        cm2lb=np.zeros(256**3)
        for i,cm in enumerate(self.color_map):
            cm2lb[(cm[0]*256+cm[1])*256+cm[2]]=i
        image=np.array(label,dtype=np.int64)
        idx=(image[:,:,0]*256+image[:,:,1])*256+image[:,:,2]
        label=np.array(cm2lb[idx],dtype=np.int64)
        label=torch.from_numpy(label).long()
        return data, label

    def label2img(self, label):
        cmap = self.color_map
        cmap = np.array(cmap).astype(np.uint8)
        pred = cmap[label]
        return pred

    def __len__(self):
        return len(self.imgs_list)

    def _split_data_set(self, img_path_list):
        val_path_list = img_path_list[::8]
        train_path_list = []
        for item in img_path_list:
            if item not in val_path_list:
                train_path_list.append(item)
        return train_path_list, val_path_list






if __name__=="__main__":
    root = "../dataset/carvana"
    car_train = Car(root,train=True)
    train_dataloader = DataLoader(car_train, batch_size=8, shuffle=True)
    print(len(car_train))
    print(len(train_dataloader))
    # for data, label in car_train:
    #     print(data.shape)
    #     print(label.shape)
    #     break
    
    (data, label) = car_train[190]
    label_np = label.data.numpy()
    label_im = car_train.label2img(label_np)
    plt.figure()
    plt.imshow(label_im)
    plt.show()

2.3 训练过程

 分割其实就是给每个像素分类而已,所以损失函数依旧是交叉熵函数,正确率为分类正确的像素点个数/全部的像素点个数

import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from voc import VOC
from carnava import Car
from unet import Unet
import os
import numpy as np
from torch import optim
import torch.nn as nn
import util


# 计算混淆矩阵
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    return acc, acc_cls, mean_iu


out_path = "./out"
if not os.path.exists(out_path):
    os.makedirs(out_path)
log_path = os.path.join(out_path, "result.txt")
if os.path.exists(log_path):
    os.remove(log_path)
model_path = os.path.join(out_path, "best_model.pth")
root = "../dataset/carvana"
epochs = 5
numclasses = 2

train_data = Car(root, train=True)
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
val_data = Car(root, train=False)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=True)

net = Unet(3, numclasses)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()


def train_model():
    best_score = 0.0
    for e in range(epochs):
        net.train()
        train_loss = 0.0
        label_true = torch.LongTensor()
        label_pred = torch.LongTensor()
        for batch_id, (data, label) in enumerate(train_dataloader):
            data, label = data.to(device), label.to(device)
            output = net(data)
            loss = criterion(output, label)

            pred = output.argmax(dim=1).squeeze().data.cpu()
            real = label.data.cpu()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss+=loss.cpu().item()
            label_true = torch.cat((label_true,real),dim=0)
            label_pred = torch.cat((label_pred,pred),dim=0)

        train_loss /= len(train_dataloader)
        acc, acc_cls, mean_iu = label_accuracy_score(label_true.numpy(),label_pred.numpy(),numclasses)
        print("\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}".format(
            e+1, train_loss, acc, acc_cls, mean_iu))
        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, train_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
                e+1,train_loss,acc, acc_cls, mean_iu))

        net.eval()
        val_loss = 0.0
        val_label_true = torch.LongTensor()
        val_label_pred = torch.LongTensor()
        with torch.no_grad():
            for batch_id, (data, label) in enumerate(val_dataloader):
                data, label = data.to(device), label.to(device)
                output = net(data)
                loss = criterion(output, label)
                pred = output.argmax(dim=1).squeeze().data.cpu()
                real = label.data.cpu()
                val_loss += loss.cpu().item()
                val_label_true = torch.cat((val_label_true, real), dim=0)
                val_label_pred = torch.cat((val_label_pred, pred), dim=0)

            val_loss/=len(val_dataloader)
            val_acc, val_acc_cls, val_mean_iu = label_accuracy_score(val_label_true.numpy(),
                                                                    val_label_pred.numpy(),numclasses)
        print('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(e+1, val_loss, val_acc, val_acc_cls, val_mean_iu))

        with open(log_path, 'a') as f:
            f.write('\n epoch:{}, val_loss:{:.4f}, acc:{:.4f}, acc_cls:{:.4f}, mean_iu:{:.4f}'.format(
            e+1,val_loss,val_acc, val_acc_cls, val_mean_iu))

        score = (val_acc_cls+val_mean_iu)/2
        if score > best_score:
            best_score = score
            torch.save(net.state_dict(), model_path)



def evaluate():
    import util
    import random
    import matplotlib.pyplot as plt

    net.load_state_dict(torch.load(model_path))
    index = random.randint(0, len(val_data)-1)
    val_image, val_label = val_data[index]
    out = net(val_image.unsqueeze(0).to(device))
    pred = out.argmax(dim=1).squeeze().data.cpu().numpy()
    label = val_label.data.numpy()
    img_pred = val_data.label2img(pred)
    img_label = val_data.label2img(label)

    temp = val_image.numpy()
    temp = (temp-np.min(temp)) / (np.max(temp)-np.min(temp))*255
    fig, ax = plt.subplots(1,3)
    ax[0].imshow(temp.transpose(1,2,0).astype("uint8"))
    ax[1].imshow(img_label)
    ax[2].imshow(img_pred)
    plt.show()


if __name__=="__main__":
    # train_model()
    evaluate()


最终训练结果是:

由于数据比较简单,训练到epoch为5时,mIOU就已经达到0.97了。

最后测试一下效果:

 

从左到右分别是:原图、真实label、预测label 

 备注:

其实最开始使用voc数据集训练的,但效果极差,也没发现哪里有问题。换个数据集效果就好了,可能有两个原因:

1. voc数据我在处理数据时出错了,没检查出来

2. 这个数据集比较简单,容易学习,所以效果差不多。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch用自己数据集训练Unet 的相关文章

  • C++ STL中递归锁与普通锁的区别

    在多线程编程中 xff0c 保护共享资源的访问很重要 xff0c 为了实现这个目标 xff0c C 43 43 标准库 xff08 STL xff09 中提供了多种锁 xff0c 如std mutex和std recursive mutex
  • VS+Qt开发环境

    VS Qt下载 VS下载 xff1a https visualstudio microsoft com zh hans vs Qt下载安装 xff1a https www bilibili com video BV1gx4y1M7cM VS
  • windows下使用ShiftMediaProject编译调试FFmpeg

    为什么要编译FFmpeg xff1f 定制模块调试源码 windows下编译 推荐项目ShiftMediaProject FFmpeg 平时总是看到一些人说windows下编译FFmpeg很麻烦 xff0c 这时候我就都是微微一笑 xff0
  • RTSP分析

    RTSP使用TCP来发送控制命令 xff08 OPTIONS DESCRIBE SETUP PLAY xff09 xff0c 因为TCP提供可靠有序的数据传输 xff0c 而且TCP还提供错误检测和纠正 RTSP的报文格式可以参考HTTP的
  • RTP分析

    参考 RTP xff08 A Transport Protocol for Real Time Applications 实时传输协议 xff0c rfc3550 xff09 RTP Payload Format for H 264 Vid
  • VS链接器工具错误 LNK2019:无法解析的外部符号

    常见的问题 以下是一些导致 LNK2019 的常见问题 xff1a 未链接的对象文件或包含符号定义的库 在 Visual Studio 中 xff0c 验证包含定义源代码文件是生成 xff0c 分别链接为项目的一部分 在命令行中 xff0c
  • FFmpeg合并视频流与音频流

    mux h ifndef MUX H define MUX H ifdef cplusplus extern 34 C 34 endif include 34 common h 34 include 34 encode h 34 typed
  • 解决电脑同时使用有线网上内网,无线网上外网的冲突

    由于内网有网络限制 xff08 限制娱乐等 xff09 xff0c 所以肯定要用外网 xff08 无线网卡 xff09 但是有的网站只能用内网访问 xff0c 比如gitlab xff0c oa等 我电脑刚开始连接了wifi后上不了gitl
  • Python斗鱼直播间自动发弹幕脚本

    工具 xff1a Python xff0c Chrome浏览器 因为不会用短信验证码登录 xff0c 所以使用QQ帐号登录 xff0c 必须要斗鱼帐号绑定QQ号 难点主要是帧的切换 查找元素可以通过chrome浏览器鼠标指向该元素 xff0
  • Qt+FFmpeg录屏录音

    欢迎加QQ群309798848交流C C 43 43 linux Qt 音视频 OpenCV 源码 xff1a Qt 43 FFmpeg录屏录音 NanaRecorder 之前的录屏项目ScreenCapture存在音视频同步问题 xff0
  • Qt源码分析(一)

    欢迎加QQ群309798848交流C C 43 43 linux Qt 音视频 OpenCV 源码面前 xff0c 了无秘密 阅读源码能帮助我们理解实现原理 xff0c 然后更灵活的运用 接下来我用VS2015调试Qt5 9源码 首先提一下
  • python实现批量提取图片中文字的小工具

    要实现批量提取图片中的文字 xff0c 我们可以使用Python的pytesseract和Pillow库 pytesseract是一个OCR xff08 Optical Character Recognition xff0c 光学字符识别
  • 在虚拟机与wsl中编译内核,并启用kasan

    linux内核编译 虚拟机编译自己的内核WSL编译自己的内核测试kasan 想学习一下kasan的相关配置 xff0c 但kasan是内核中的相关配置 xff0c 开启得编译内核 xff0c 本文使用虚拟机与wsl两种方式来编译内核启动ka
  • mac时间机器删除旧备份

    查 span class token function sudo span tmutil listlocalsnapshots 删除 span class token function sudo span tmutil deleteloca
  • Linux的ssh服务

    Linux中真机与虚拟机的连接 1 先修改虚拟机的IP xff0c 在虚拟机中输入命令 xff1a nm connection editor 2 选中已存在的System eth0 xff0c 选择Delete 然后点击Add xff0c
  • 51单片机——控制步进电机加速、减速及反转

    加速 xff1a include lt reg52 h gt define uchar unsigned char define uint unsigned int define MotorData P1 uchar phasecw 4 6
  • 【安装教程】——Linux安装opencv

    安装教程 Linux安装opencv 一 安装相关软件包二 获取Source三 安装OpenCV 未完待续 一 安装相关软件包 安装相关软件包 打开终端 xff0c 安装以下软件包 sudo apt install build span c
  • 请教:linux下的/opt目录是做什么用的?

    请教 linux下的 opt目录是做什么用的 蛋疼YMG 浏览 4934 次 发布于2014 10 24 13 35 荒漠探险 答题闯关 好礼连连 最佳答案 opt 主机额外安装软件所摆放的目录 默认是空的 一般安装软件的时候 xff0c
  • kali里装java

    1 下载最新的JAVA JDK jdk 8u91 linux x64 2 解压缩文件并移动至 opt tar xzvf jdk 8u91 linux x64 tar gz mv jdk1 8 0 91 opt cd opt jdk1 8 0
  • ffmpeg--tcp

    TCP代码分析 xff1a include 34 avformat h 34 include 34 libavutil avassert h 34 include 34 libavutil parseutils h 34 include 3

随机推荐

  • GO调用ffmpeg动态库

    package main cgo CFLAGS I usr local ffmpeg include cgo LDFLAGS L usr local ffmpeg lib lavformat include 34 libavformat a
  • ftp上传文件时出现 550 Permission denied,不是用户权限问题

    查了半天 xff0c 发现是因为服务器已经有同名的文件了 xff0c 所以无法上传 xff0c 上传文件名改成不重复的就可以了
  • Ubuntu16.04解决登录闪退问题

    Ubuntu16 04解决登录闪退问题 夏哈哈 64 64 2020 07 21 08 42 46 637 收藏 1 分类专栏 xff1a 程序开发 Ubuntu装机 文章标签 xff1a 深度学习 版权 问题 xff1a Ubuntu16
  • syntax error at or near “.“

    SQL select t id t name t sex t age t address t phone from user where id 61 1 Cause org postgresql util PSQLException ERR
  • ubuntu编译安装opencv4.5.1(C++)

    1 在https github com opencv opencv上下载opencv源码 包括opencv4 5 1和opencv contrib4 5 1 也可以不用 xff0c 其中opencv是主体 xff0c opencv cont
  • Pytorch自定义CNN网络实现猫狗分类

    数据集下载地址 xff1a https www microsoft com en us download confirmation aspx id 61 54765 Dogs vs Cats 猫狗大战 来源Kaggle上的一个竞赛题 xff
  • yolov5训练自己的数据集

    yolo系列在目标检测领域的地位就不用说了 xff0c github上有pytorch实现的训练yolov5的代码 xff0c 本文将用自己的数据去训练一个yolov5的模型 参考代码地址 https github com ultralyt
  • pytorch用自己数据训练VGG16

    一 VGG16的介绍 VGG16是一个很经典的特征提取网络 xff0c 原来的模型是在1000个类别中的训练出来的 xff0c 所以一般都直接拿来把最后的分类数量改掉 xff0c 只训练最后的分类层去适应自己的任务 xff08 又叫迁移学习
  • mmdetection v2.x模型训练与测试

    因为工作关系 xff0c 接触到了mmdetection 它是商汤科技和香港中文大学开源的一款基于pytorch底层的目标检测工具箱 xff0c 隶属于mmlab项目 mmdetection 2 7文件结构 其中configs文件夹下为网络
  • Pytorch用自己的数据训练ResNet

    一 ResNet算法介绍 残差神经网络 ResNet 是由微软研究院的何恺明等人提出的 ResNet 在2015 年的ILSVRC中取得了冠军 通过实验 xff0c ResNet随着网络层不断的加深 xff0c 模型的准确率先是不断的提高
  • 新装Ubuntu系统基本环境安装配置(conda)

    Ubuntu系统用conda来配置深度学习环境 1 配置显卡驱动 查看显卡驱动 nvidia smi 如果是这个命令没有反应 xff0c 就是没有显卡驱动 xff0c 这个时候要先安装显卡驱动 然后再安装cuda xff0c 在这个图中 x
  • Python源码加密与Pytorch模型加密

    0 前言 深度学习领域 xff0c 常常用python写代码 xff0c 而且是建立在一些开源框架之上 xff0c 如pytorch 在实际的项目部署中 xff0c 也有用conda环境和python代码去部署服务器 xff0c 在这个时候
  • Paddle-Lite终端部署深度学习模型流程

    Paddle Lite是飞桨基于Paddle Mobile全新升级推出的端侧推理引擎 xff0c 在多硬件 多平台以及硬件混合调度的支持上更加完备 xff0c 为包括手机在内的端侧场景的AI应用提供高效轻量的推理能力 xff0c 有效解决手
  • nohup: failed to run command `java': No such file or directory

    问题描述 xff1a 平台研发项目 xff0c ActiveQM做消息队列 xff0c zookeeper做集群 xff0c zkui做可视化服务管理 xff0c skynet是引擎服务 xff0c skynet下面有一个xmanager是
  • 【原理篇】一文读懂Faster RCNN

    0 Faster RCNN概述 论文地址 xff1a https arxiv org pdf 1506 01497 pdf Faster R CNN源自2016年发表在cs CV上的论文 Faster R CNN Towards Real
  • 【原理篇】一文读懂Mask RCNN

    Mask RCNN 何凯明大神的经典论文之一 xff0c 是一个实例分割算法 xff0c 正如文中所说 xff0c Mask RCNN是一个简单 灵活 通用的框架 xff0c 该框架主要作用是实例分割 xff0c 目标检测 xff0c 以及
  • 【原理篇】一文读懂FPN(Feature Pyramid Networks)

    论文 xff1a feature pyramid networks for object detection 论文链接 xff1a https arxiv org abs 1612 03144 这篇论文是CVPR2017年的文章 xff0c
  • pytorch用voc分割数据集训练FCN

    语义分割是对图像中的每一个像素进行分类 xff0c 从而完成图像分割的过程 分割主要用于医学图像领域和无人驾驶领域 和其他算法一样 xff0c 图像分割发展过程也经历了传统算法到深度学习算法的转变 xff0c 传统的分割算法包括阈值分割 分
  • 【学习笔记】语义分割综述

    语义分割就是图像分割 xff0c 是图像像素级的分类 xff0c 即给图像的每一个像素点分类 与之临近的一个概念叫实例分割 xff0c 实例分割就是语义分割 43 目标检测 语义分割只能分割出所有同类的像素 xff0c 目标检测把不同的个体
  • pytorch用自己数据集训练Unet

    在图像分割这个问题上 xff0c 主要有两个流派 xff1a Encoder Decoder和Dialated Conv 本文介绍的是编解码网络中最为经典的U Net 随着骨干网路的进化 xff0c 很多相应衍生出来的网络大多都是对于Une