《图像分割Unet网络分析及其Pytorch版本代码实现》

2023-11-19

  最近两个月在做学习图像分割方面的学习,踩了无数的坑,也学到了很多的东西,想了想还是趁着国庆节有时间来做个总结,以后有这方面需要可以来看看。

  神经网络被大规模的应用到计算机视觉中的分类任务中,说到神经网络的分类任务这里不得不提到CNN(卷积神经网络), 在我的认识中,CNN的分类是对整个训练图像对应的标签进行分类,而图像分割网络Unet是对图像的各个像素进行分类,在图像分类时像素 “0” 一般都代表着背景,其他的像素代表你自己需要训练分割的类别,比如在你进行第一个类别图像标注时,你可以用像素 “1” 代表你的第一类,用像素 “2” 代表第二个类,以此类推。当然,你也可以用其他的任意的像素表示自己的类别或者背景,这不重要,只是用像素 “0”、“1”、“2”等代表你的背景像素或者类别比较方便,训练起来消耗的时间也比较用其他像素短。

接下来,我们开始分析网络结构以及Pytorch版本的图像分割Unet。

1、Unet网络结构

图1-1 Unet网络结构

            

1.1 Unet网络结构

Unet网络可以分成两个结构:

(1)图像特征提取层:该层由卷积(Conv)、下采样(Pooling)构成,如图 1-1 左半部分,输入大小为 572x572x1(w,h,c) 的图像数据image到网络后先进行两次卷积得到C1(568x568x64),再进行下采样得到D1(284x284x64),继续对D1层进行进行两次卷积得到C2(280x280x128),对C2进行下采样得到D2(140x140x128),以此类推,后面分别计算出C3(136x136x256)、D3(68x68x256)、C4(64x64x512),D4(32x32x512)。至此,特征提取层结束。

(2)图像特征融合层:该层由卷积(Conv)、上采样(使用转置卷积或线性采样)、图像数据的拼接构成,首先一样的使用C4进行两次卷积得到C5(28x28x1024),再进行装置卷积或者线性采样得到U1(56x56x1024),此时再与C4进行拼接得到O1(56x56x1024),O1再进行两次卷积、上采样等操作,以此类推最后得到输出图像output(388x388x2)。至此整个Unet网络完成。

2、Pytorch版本代码实现

  这里使用的是大佬的图像分割网络Unet进行学习的,bilibili链接:https://www.bilibili.com/video/BV11341127iK/?spm_id_from=333.999.0.0&vd_source=35b62865b997e4f1a87b1ab816f5296b

2.1 图像标注

  这里使用开源图像标注工具labelme,命令行cmd使用命令 pip install labelme 进行安装,安装完成后在命令行中输入 labelme 打开工具进行标注。

图2-1 labelme标注工具

   标注完成保存之后会生成 .json 文件,在该标签图像路径输入 labelme_json_to_dataset + 你 .json文件名就可以生成标签文件,如图2-2 至 图2-3显示,则成功标注该图像。其中img为标注原图、label为标签图像、label_names.txt文本文件里面是标注的类别以及背景类、label_viz为标注原图与标签图像融合之后得到的图像。

图2-2 labelme命令

图2-3 labelme生成的标注图像

图2-4 标注原图

图2-5 标签图像

图2-6 label_names.txt

图2-7 标注原图与标签图像融合

 当标注的图像比较多时,使用labelme工具自带的解析器一个一个标签图像的生成会浪费很多时间,因此我自己写了一个代码来自动使用labelme的解析器,以下代码能够批量的生成标签图像。其中image路径为.json和标注原图的路径,JPEGImages为生成的训练图像路径,SegmentationClass为生成的标签图像路径。

  json_to_dataset.py

from __future__ import print_function
import argparse
import glob
import math
import json
import os
import os.path as osp
import shutil
import numpy as np
import PIL.Image
import PIL.ImageDraw
import cv2
import time



def json_to_dataset(json_path, image_path, label_path):
    if osp.isdir(label_path):
        shutil.rmtree(label_path)
        #print(label_path)
    os.makedirs(label_path)
    image_path_list = []
    json_path_list = []
    for file_path in os.listdir(json_path):  # 0.png - 10.png
        #print(file_path)
        if file_path.endswith(".png"):  
            image_name = file_path.split(".")[0]   # 0 - 10
            json_name = os.path.join(json_path, image_name + ".json")   # C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/image\0.json
            
            image_path_list.append(os.path.join(json_path , file_path))  # ['C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/image\\0.png']
            json_path_list.append(json_name)   # ['C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/image\\0.json']
    
    # print(image_path_list)
    # print(json_path_list)
    for i in range(len(image_path_list)):
        # 读取原图像
        image = cv2.imread(image_path_list[i])
        h, w = image.shape[:2]
        # 生成与原图像大小的一样的标签图像
        mask = np.zeros([h, w, 1], np.uint8)

        # 打开json文件
        with open(json_path_list[i], "r") as f:
            label = json.load(f)
        # 提取json文件中的 shapes
        label = label["shapes"]
        for label in label:
            category = label["label"]   #  标签
            points = label["points"]  # 标记的点
            #print(category, points_array)
            points_array = np.array(points, dtype=np.int32)

            # 填充
            mask = cv2.fillPoly(mask, [points_array], category_types.index(category))
            
            # 保存原图像至 JPEGImages
            cv2.imwrite(os.path.join(image_path, image_path_list[i].split("\\")[-1]), image)
            # 保存标签图像至 SegmentationClass
            cv2.imwrite(os.path.join(label_path, image_path_list[i].split("\\")[-1]), mask)
    
    print("Pictures has been saved!")



if __name__=='__main__':
    json_path = "C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/image"
    image_path = "C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/JPEGImages"
    label_path = "C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/SegmentationClass"

    # 标签种类
    category_types = ["Background", "1", "2"]

    json_to_dataset(json_path, image_path, label_path)

至此,图像的标注完成。

2.2 图像预处理加载图像数据

  这里大佬先找到图像的最长边,然后用黑色像素来填充另外一边形成的高和宽相等的图像来进行训练,生成大小长宽相等的图像之后再把图像大小重置为256x256进行训练,比如标注的图像大小为640x480,则找到图像的最长边640,另外一边长为480的边则用黑色像素填充为640,最后得到的标注图像大小为640x640,再把图像大小重置为256x256,标签图像同理。图像预处理代码如下:

  utils.py

from PIL import Image


def keep_image_size_open(path, size=(256, 256)):
    img = Image.open(path)
    temp = max(img.size)
    mask = Image.new('P', (temp, temp))
    
    mask.paste(img, (0, 0))
    ;
    mask = mask.resize(size)
    #mask.save(path)
    return mask
def keep_image_size_open_rgb(path, size=(256, 256)):
    img = Image.open(path)
    temp = max(img.size)
    mask = Image.new('RGB', (temp, temp))
    mask.paste(img, (0, 0))
    mask = mask.resize(size)
    #mask.save(path)
    return mask

if __name__ == '__main__':

    image_path = "C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/JPEGImages/0.png"
    label_path = "C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/SegmentationClass/0.png"
    
    image1 = keep_image_size_open_rgb(image_path)
    print(image1.mode)
    print(image1.size)
    image1.show('test1')

    image2 = keep_image_size_open(label_path)
    print(image2.mode)
    print(image2.size)
    #image2.save('C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/JPEGImages/PASS2022_04_29_11_16_49_924.jpg')
    image2.show('test2')
图2-8 生成的标注图像

                 

  若是训练的图像是类似与下图中图像,背景占比较大,而我们需要分类的像素是图像的某个特征点,则我们可以使用opencv的查找轮廓函数进行图像的特征提取,如不进行提取的话,背景多余的干扰影响会很大,导致要训练更多的次数才能把图像中的类别给分割出来,并且效果很一般,这时可以使用opencv中的特征查找、特征提取函数进行提取特征。提取图像特征代码如下:

  image_corp.py

import cv2
from PIL import Image
import os

def get_picture_path(file_path):

    image_path_list = []
    for i in os.listdir(file_path):
        image_path = i.split(".")
       
        if image_path[-1] != "json":
            image_crop(os.path.join(file_path, i), i)
        
    print(image_path_list)
    

def image_crop(image_path, image_name):
    
    img = Image.open(image_path)

    image = cv2.imread(image_path)
    image_copy = image.copy()

    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    ret, thresh = cv2.threshold(image, 200, 255,cv2.THRESH_BINARY)
    contours,hierarchy=cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    for i in range(len(contours)):
        area = cv2.contourArea(contours[i])
        #print(area)
        if 111208.0 <area< 1308417.0:
            print(area)
            #image_copy=cv2.drawContours(image_copy,contours[i],-1,(0,255,0),2)  # img为三通道才能显示轮廓 cv2.FILLED
            x, y, w, h = cv2.boundingRect(contours[i])   
            cv2.rectangle(image_copy, (x-20,y-20), (x+w+20,y+h+20), (255,0,0), 2) 
            img = img.crop((x, y, x+w, y+h))

        else:
            continue

    img.save(os.path.join("C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/image", image_name))
    #cv2.imwrite("C:/Users/yiyaxin/Desktop/PASS2022_04_29_11_08_20_749_test.jpg", image_copy)

    #cv2.imshow("thresh", image_copy)


if __name__ =='__main__':

    file_path = 'C:/Users/yiyaxin/Desktop/bls/20220513/8X2.5-BIAOMIAN-ng'
    get_picture_path(file_path)


    cv2.waitKey(0)

图2-9 未使用opencv提取的标注图像

图2-10 使用opencv提取后的标注图像

  

  众所周知,在Pytorch中加载自己的训练图像时重写Dataset中类中的初始化函数(init)、长度函数(len)和加载图像函数(getitem),在大佬的代码里,初始化函数是找到图像的路径,长度函数则是返回图像数据的数量,getitem函数里面则是先把图像处理成长宽相等的图像,再重置大小为256x256。接着再使用pytorch中的transforms把图像数据和标签数据转换成向量的形式,传入网络训练。加载图像代码如下:

  data.py

import os

import numpy as np
import torch
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor()
])


class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.name = os.listdir(os.path.join(path, 'SegmentationClass'))

    def __len__(self):
        #print("len(name)", len(self.name))
        return len(self.name)

    def __getitem__(self, index):
        segment_name = self.name[index]  # xx.png
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        image_path = os.path.join(self.path, 'JPEGImages', segment_name)
        #print("segment_name: ", segment_name)
        #print("image_path: ", image_path)
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open_rgb(image_path)

        # print(image.size)
        segment_image = np.array(segment_image)
        # print(image.shape[1])
        for i in range(segment_image.shape[1]):
            print(np.array(segment_image[i]))
            # for j in range(segment_image.shape[0]):
            #     print(np.array(segment_image[i][j]))
        
        return transform(image), torch.Tensor(np.array(segment_image))


if __name__ == '__main__':
    #from torch.nn.functional import one_hot
    data = MyDataset('C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data')
    print("image: ", data[0][0].shape)
    print("label:", data[0][1].shape)
    #out=one_hot(data[0][1].long())
    #print("one_hot:", out.shape)

2.3 Unet网络搭建

  在这一步中,分别构建卷积层类、下采样类、上采样类、Unet网络类。其中卷积层包括两次卷积函数,使用Pytorch中的BatchNorm2d函数进行数据归一化、Dropout2d函数进行数据的随机丢弃,目的是防止数据过大而产生过拟合,使用的激活函数为LeakyRelu()函数。下采样使用卷积函数Conv2d和BatchNorm2d及LeakyRelu()函数,上采样使用卷积函数Conv2d使图像通道变为原来的一半,接着使用转置卷积函数ConvTranspose2d或者使用线性采样函数interpolate进行上采样,最后再使用cat函数进行图像的拼接。接着再按照论文中的Unet网络结构进行搭建Unet类,这里可以用不用sigmoid或softmax等激活都无所谓,其中num_classes为图像预测的类别,完成网络的搭建后可以测试以下,比如输入(1,3,256,256)大小的图像数据,若经过网络计算后输出的图像数据依然是(1,3,256,256)大小的图像数据,则搭建的网络没有问题。

至此,Unet网络的搭建完成。搭建Unet网络代码如下:

  net.py

import torch
from torch import nn
from torch.nn import functional as F

class Conv_Block(nn.Module):
    def __init__(self,in_channel,out_channel):
        super(Conv_Block, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)


class DownSample(nn.Module):
    def __init__(self,channel):
        super(DownSample, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)


class UpSample(nn.Module):
    def __init__(self,channel):
        super(UpSample, self).__init__()
        self.layer=nn.Conv2d(channel,channel//2,1,1)
        self.up =torch.nn.ConvTranspose2d(channel,channel,2,2)
    def forward(self,x,feature_map):
        out=self.layer(self.up(x))
        return torch.cat((out,feature_map),dim=1)


class UNet(nn.Module):
    def __init__(self,num_classes):
        super(UNet, self).__init__()
        self.c1=Conv_Block(3,64)
        self.d1=DownSample(64)
        self.c2=Conv_Block(64,128)
        self.d2=DownSample(128)
        self.c3=Conv_Block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_Block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_Block(512,1024)
        self.u1=UpSample(1024)
        self.c6=Conv_Block(1024,512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        self.out=nn.Conv2d(64,num_classes,3,1,1)
        print("num_classes: ", num_classes)

    def forward(self,x):
        R1=self.c1(x)
        #print(R1.size())
        R2=self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        #print(R5.size())
        O1=self.c6(self.u1(R5,R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.out(O4)
        #return F.log_softmax(self.out(O4),dim=1)

if __name__ == '__main__':
    x=torch.randn(1,3,256,256)
    net=UNet(5)
    print("shape: ", net(x).shape)

  2.4 图像训练与预测

  搭建好Unet网络之后就可以开始训练了,这里最需要注意的是背景也为一类,也就是说如果你的标注图像标注的是2个类别,那么进行训练时的类别是3类,传入的num_classes参数应该为3,首先把图像加载到网络里进行训练,再进行反向传播就可以了,这里使用的优化器是自适应Adam,使用的损失函数为多分类损失函数交叉熵损失函数CrossEntropyloss函数,如果只想进行二分类也可以只用BCE损失函数,不过网络之后的激活函数要换成sigmoid激活函数。训练函数代码如下:

train.py

import os
 
import tqdm
from torch import nn, optim
import torch
from torch.utils.data import DataLoader#数据集加载器
from data import *
from net import *
from torchvision.utils import save_image

import os



os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

 
device = torch.device('cuda')
weight_path = 'C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/params/unet.pth'#权重地址
data_path = r'C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data'#数据集地址
save_path = 'C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/train_image'#训练时保存的图像地址
if __name__ == '__main__':
    num_classes = 2+ 1  # +1是背景也为一类
    data_loader = DataLoader(MyDataset(data_path), batch_size=2, shuffle=True)#加载数据集,batch_size批次,根据自身电脑的情况进行修改
    net = UNet(num_classes).to(device)#实例化Unet网路
    if os.path.exists(weight_path):#判断权重是否存在
        net.load_state_dict(torch.load(weight_path))
        print('successful load weight!')
    else:
        print('not successful load weight')
    
    

    opt = optim.Adam(net.parameters())
    #opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    loss_fun = nn.CrossEntropyLoss()  # nn.BCELoss()
 
    epoch = 1
    while epoch < 100:
        for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):

            #print("标签种类:", segment_image.max(), segment_image.min())
            # print(image.size)
            # print(segment_image.size)
            image, segment_image = image.to(device), segment_image.to(device)
            if hasattr(torch.cuda, 'empty_cache'):
	            torch.cuda.empty_cache()

            out_image = net(image)
            
            train_loss = loss_fun(out_image, segment_image.long())
            #print("train_loss:", train_loss)
           
            opt.zero_grad()
            try:
                train_loss.backward()
            except RuntimeError as e:
                print("异常:", e)
            opt.step()
 
            if i % 5 == 0:
                print(f'\t{epoch}-{i}-train_loss===>>{train_loss.item()}')
 
            _image = image[0]
            _segment_image = torch.unsqueeze(segment_image[0], 0) * 255
            _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255
 
            img = torch.stack([_segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')
        if epoch % 20 == 0:#每20次保存一次权重
            torch.save(net.state_dict(), weight_path)
            #torch.save(net, weight_path)
            print('save successfully!')
        epoch += 1

  测试代码就是使用图像预处理时的函数加载图像到训练好的网络里进行分类,由于网络输出的是(1,3,256,256)大小的图像数据,所有只需要把这个图像数据进行降维成(1,256,256)大小的图像数据,接着把图像数据的像素转换成没有重复数据的矩阵就可以知道它预测出来的类别了,想要使用opencv查看的话需要把图像数据转换成(256,256,1)的形式进行保存或者显示,预测分类图像的代码如下:

  test.py

import os

import cv2
import numpy as np
import torch

from net import *
from utils import *
from data import *
from torchvision.utils import save_image
from PIL import Image

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
net=UNet(3).cuda()

weights='C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/params/unet.pth'
if os.path.exists(weights):
    net.load_state_dict(torch.load(weights))
    #net.load(weights)
    print('successfully')
else:
    print('no loading')

#_input='C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/test/PASS2022_04_29_11_16_49_924.jpg'
_input = 'C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/test/0.jpg'

img=keep_image_size_open_rgb(_input)

#img = Image.open(_input)
#img = img.convert("RGB")

img_data=transform(img).cuda()   # (3, 256, 256)
img_data=torch.unsqueeze(img_data,dim=0)  # (1, 3, 256, 256)
print("img_data.size: ", img_data.shape)
net.eval()
out=net(img_data) # 网络输出 (1, 2, 256, 256)
out=torch.argmax(out,dim=1)    # (1, 256, 256)
out=torch.squeeze(out,dim=0)    # (256, 256)
out=out.unsqueeze(dim=0)        # (1, 256, 256)
print(set((out).reshape(-1).tolist()))
out=(out).permute((1,2,0)).cpu().detach().numpy()   # (256, 256, 1)[
cv2.imwrite('C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/result/result.png',out)
cv2.imshow('out',out*255.0)
cv2.waitKey(0)

  上面的代码显示的只是像素为0或255(黑或白)的图像,若想看的它的类别的话可以使用以下的代码进行显示:

label_path = 'C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/result/ret.jpg'
#label_path = 'C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/data/SegmentationClass'

label = np.asarray(Image.open(label_path), dtype=np.float32)
np.save("test.npy",label)

img3 = np.load("test.npy")
print("img3.shape", img3.shape)
print(set((img3).reshape(-1).tolist()))
plt.imshow(img3)
plt.show()

  若是想使用opencv中的查找轮廓函数显示分割结果,可以使用下面的代码进行显示:

image1 = cv2.imread('C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/result/ret.jpg')

print(set((image1).reshape(-1).tolist()))

image2 = cv2.imread('C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/test/0.jpg')
#image2 = keep_image_size_open_rgb('C:/Users/yiyaxin/Desktop/C++/pytorch-UNet-master/test/PASS2022_04_29_11_08_46_933.png')
# image2 = cv2.cvtColor(np.array(image2), cv2.COLOR_RGB2BGR)

image2 = cv2.resize(image2, (256, 256))

image3 = image2.copy()
# print(image3.shape)
image1=cv2.cvtColor(image1,cv2.COLOR_BGR2GRAY)
# #print(image1.shape)
ret,thresh=cv2.threshold(image1,0,255,0)

#cv2.imshow('imageshow',thresh)  # 显示返回值image,其实与输入参数的thresh原图没啥区别
 
contours,hierarchy=cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
#print(contours)
for i in range(len(contours)):
    area = cv2.contourArea(contours[i])
    #print(area)
    if area < 10:
        continue
    print(area)
    image1=cv2.drawContours(image2,contours[i],-1,(0,255,0),2)  # img为三通道才能显示轮廓 cv2.FILLED


# # 核心拼接代码
image = np.concatenate([image3, image2], axis=1)

cv2.imshow('drawimg',image)
cv2.waitKey(0)
cv2.destroyAllWindows()

图2-9 预测图像1

  

图2-10 网络输出结果

图2-11 预测结果

图2-12 可以查看类别的图像

图2-13 预测图像2

图2-14 预测结果2

图2-15 预测图像3

图2-16 预测结果3

图2-17 预测图像4

图2-18 预测图像4

图2-19 预测图像5

图2-20 预测图像5

   至此,Unet分割网络项目完成。

3、项目总结

  这次的Unet分割网络主要可以分为四个步骤:

 一、图像预处理:安装labelme工具进行标注,进行把图像预处理为等高的256x256x3的图像数据,重写Pytorch中Dataset中的加载图像数据函数。

二、搭建Unet网络:首先构建卷积类、下采样类、上采样类,其中num_classes为自己标注的类别加一,因为背景像素也是一个类别,如果只进行二分类,则使用的激活函数为sigmoid,损失函数为BCE损失函数,若进行多分类则可以使用激活函数为softmax,损失函数为交叉熵损失函数CrossEntorpyLoss函数。上采样可以使用转置卷积函数或者线性采样函数,在上采样的最后要进行图像数据的拼接。

三、训练和预测:按照官方标准的训练测试函数构建。

  最后,分析一下这个网络的优缺点,在我学习中看来,Unet网络对于大图像特征的分割还是比较不错的,可以使用较少的训练图像和较少的训练次数就能够得到很好的分类结果,网络搭建起来也是比较简单的,特别是熟悉Pytorch的话搭建起来超级方便。最大的缺点我觉得是对于图像的特征提取不够好,这个或许是跟本身的网络结构有问题,由于它对图像的特征提取并没有那么好,因此在训练背景像素干扰比较大,图像也比较大,想要分类的图像比较细致的话结果并没有那么理想,对于这种图像需要训练的次数还是比较多的,而且分割出来的图像特征干扰还是比较多的。第二个就是对于类别特别多的图像有时候根本分割不出全部的图像类别,会损失掉一两个的图像特征,对于这点以我目前的知识还没想到是什么原因导致的。

  下次我会给大家带来C++版的基于libtorch的Unet图像分割网络分析和代码,以及在实现过程中我所踩过的坑。

  欢迎大家对此项目提出您最宝贵的建议,并在此处留言,指正我在文章内出现的错误或者与我交流您对于Unet分割网络的宝贵见解。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

《图像分割Unet网络分析及其Pytorch版本代码实现》 的相关文章

随机推荐

  • 用Python画一个生日蛋糕并写上生日祝福对象及生日祝福语

    用Python画一个生日蛋糕并写上生日祝福对象及生日祝福语 画一个双层蛋糕并点上蜡烛 代码运行时间较长 请静待惊喜出现 代码运行截图 完整程序代码 干货主要有 200 多本 Python 电子书 和经典的书籍 应该有 Python标准库资料
  • 掌握 Ajax,第 7 部分: 在请求和响应中使用 XML

    偶尔使用 Ajax 的开发人员也会注意到 Ajax 中的 x 并意识到它代表 XML XML 是编程中最常用的数据格式之一 对于异步应用程序中的服务器响应能够带来切实的好处 在本文中 您将看到服务器如何在请求响应中发送 XML 现在如果不使
  • Java设计模式:装饰者模式(Decorator Pattern)

    装饰者模式 涉及的重要设计原则 类应该对扩展开放 对修改关闭 装饰者模式定义 装饰者模式动态地将责任附加到对象上 若要扩展功能 装饰者提供了比继承更有弹性的替代方案 UML类图 装饰者模式事例 咖啡店 咖啡种类 1 深焙咖啡 DarkRoa
  • Python中如何使用boolean类型的数据

    在写代码的过程中 遇到了定义boolean类型变量的问题 之前一直试图用java或者c定义布尔变量的方法 一直不奏效 经过一旦学习之后才明白 和java竟然只是大小写的问题 在python中将java中的true携程True 将false携
  • Educational Codeforces Round 149 (Rated for Div. 2)A~D

    Grasshopper on a Line 题意 给出n和k 求从0到n最少走几步 以及步长 要求步长不能整除k 思路 从n往下找到 k不等于0的数 输出该数和n 该数即可 如果n k 0 那就只需要一步 代码 gt File Name a
  • 探索Java8——默认方法

    文章目录 什么是默认方法 不断演进的API 初始版本API 第二版API 概述默认方法 什么是默认方法 在传统的Java程序中 实现接口的方式是通过Implements把接口中的每一个方法提供一个实现 或者从父类继承他的实现 然而 在实际开
  • redis搜索 - KEYS命令

    文章目录 KEYS命令 使用 使用场景 KEYS命令 KEYS命令用于搜索匹配某个模式的所有key 例如常见的keys 命令 会返回所有的键 Time complexity O N 使用 KEYS命令支持以下正则匹配模式 h llo mat
  • [STM32]详解单片机GPIO输入模式配置-上拉下拉与浮空

    前面说到单片机的GPIO主要输出模式主要有推挽模式和开漏模式 除了连接到片内外设的模拟输入模式和复用输入功能以外 这里再说一下通用输入模式配置 STM32单片机的通用输入模式主要有输入浮空 输入上拉与输入下拉 当配置成上拉模式 即GPIO
  • python rsa加密之后byte类型存储到数据库中_python3 rsa加密

    遇到了跟你一样的问题 此js封装的源码 如下 希望看到的大神解决了的话帮我一下 RSA a suite of routines for performing RSA public key computations in JavaScript
  • c语言字符串相关函数的分析

    c语言中 常见的字符串相关函数主要分为两类 1 与字符串长度无关的函数 如strcpy strcat strcmp 2 与字符串长度有关的函数 如strlen strncpy strncat strncmp strlen 用于求字符串的长度
  • 1130:找第一个只出现一次的字符(C C++)

    题目描述 给定一个只包含小写字母的字符串 请你找到第一个仅出现一次的字符 如果没有 输出no 输入 一个字符串 长度小于100000 输出 输出第一个仅出现一次的字符 若没有则输出no 输入样例 abcabd 输出样例 c 代码 inclu
  • 【Unity】Delegate, Event, UnityEvent, Action, UnityAction, Func 傻傻分不清

    Unity Delegate Event UnityEvent Action UnityAction Func 傻傻分不清 Delegate 委托 函数指针 一个简单的例子 一对一依赖 一个简单的例子 一对多依赖 所以话说 委托有啥用呢 事
  • LDAP简介及其使用

    LDAP简介 LDAP Lightweight Directory Access Protocol 的意思是 轻量级目录访问协议 是一个用于访问 目录服务器 Directory Servers 的协议 这里所谓的 目录 是指一种按照树状结构
  • java button中加入背景图片不显示

    emmmm 写一下关于在button中添加图片作为背景的经历 就 先记录下错误的地方 JLabel stat new JLabel new ImageIcon img left png 这里再left png的路径的开头少了个点 就一直都不
  • Centos7安装Nessus教程

    本文为学习笔记 仅限学习交流 不得利用 从事危害国家或人民安全 荣誉和利益等活动 请参阅 中华人民共和国网络安全法 Nessus安装包 链接 https pan baidu com s 1FJMu8WMZPSjoqQpes GCng 提取码
  • C++中#ifndef, #define, #endif的作用和使用的注意事项

    在C 语言编程中 我们经常会接触到头文件 比如说声明类 或者声明命名空间等 而每次在编写xxx h的头文件时 编程书上都会让我们在代码的前后加上如下的三句代码 ifndef XXX H define XXX H endif 其中 代表中间具
  • DDP入门

    DDP 即动态动态规划 可以用于解决一类带修改的DP问题 我们从一个比较简单的东西入手 最大子段和 带修改的最大子段和其实是常规问题了 经典的解决方法是用线段树维护从左 右开始的最大子段和和区间最大子段和 然后进行合并 现在我们换一种方法来
  • 软件测试人员必备的60个测试工具清单,果断收藏了!

    据统计 中国软件外包市场的潜力和机会已远远超过软件王国印度 不过由于软件人才的严重不足致使我国软件发展遭遇 瓶颈 国家为了大力培养软件人才 不断采取积极有效的措施 我国对软件测试人才的需求数量还将持续增加 因此软件测试工程师也就成为了IT职
  • golang ---JSON-ITERATOR 使用

    jsoniter json iterator 是一款快且灵活的 JSON 解析器 Jsoniter 是最快的 JSON 解析器 它最多能比普通的解析器快 10 倍之多 独特的 iterator api 能够直接遍历 JSON 极致性能 0
  • 《图像分割Unet网络分析及其Pytorch版本代码实现》

    最近两个月在做学习图像分割方面的学习 踩了无数的坑 也学到了很多的东西 想了想还是趁着国庆节有时间来做个总结 以后有这方面需要可以来看看 神经网络被大规模的应用到计算机视觉中的分类任务中 说到神经网络的分类任务这里不得不提到CNN 卷积神经