Unet网络搭建(Pytorch)

2023-11-09

Unet是一个经典的语义分割网络,常常被用于医学影像的分割。在Unet的网络结构中,可以分为卷积模块,下采样模块以及上采样模块,详见下面的网络结构图:
在这里插入图片描述 在网络的搭建过程中,也是依照分为三大块这种思路进行搭建。话不多说,直接上代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

class conv_block(nn.Module):
    def __init__(self,in_c,out_c):
        super(conv_block,self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_c,out_c,kernel_size=(3,3),stride=1,padding=1,padding_mode='reflect'),
            nn.BatchNorm2d(out_c),
            nn.Dropout(0.3),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=(3, 3), stride=1, padding=1, padding_mode='reflect',bias = False),
            nn.BatchNorm2d(out_c),
            nn.Dropout(0.3),
            nn.ReLU(inplace=True),
        )

    def forward(self,x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

class Downsample(nn.Module):
    def __init__(self,channel):
        super(Downsample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=(3, 3), stride=2, padding=1,  bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU()
        )

    def forward(self,x):
        return self.layer(x)


class Upsample(nn.Module):
    def __init__(self,channel):
        super(Upsample, self).__init__()
        self.conv1 = nn.Conv2d(channel,channel//2,kernel_size=(1,1),stride=1)

    def forward(self,x,featuremap):
        x = F.interpolate(x,scale_factor=2,mode='nearest')
        x = self.conv1(x)
        x = torch.cat((x,featuremap),dim=1)
        return x

class UNET(nn.Module):
    def __init__(self,in_channel,out_channel):
        super(UNET, self).__init__()
        self.layer1 = conv_block(in_channel,out_channel)
        self.layer2 = Downsample(out_channel)
        self.layer3 = conv_block(out_channel,out_channel*2)
        self.layer4 = Downsample(out_channel*2)
        self.layer5 = conv_block(out_channel*2,out_channel*4)
        self.layer6 = Downsample(out_channel*4)
        self.layer7 = conv_block(out_channel*4,out_channel*8)
        self.layer8 = Downsample(out_channel*8)
        self.layer9 = conv_block(out_channel*8,out_channel*16)
        self.layer10 = Upsample(out_channel*16)
        self.layer11 = conv_block(out_channel*16,out_channel*8)
        self.layer12 = Upsample(out_channel*8)
        self.layer13 = conv_block(out_channel*8,out_channel*4)
        self.layer14 = Upsample(out_channel*4)
        self.layer15 = conv_block(out_channel*4,out_channel*2)
        self.layer16 = Upsample(out_channel*2)
        self.layer17 = conv_block(out_channel*2,out_channel)
        self.layer18 = nn.Conv2d(out_channel,3,kernel_size=(1,1),stride=1)
        self.act = nn.Sigmoid()

    def forward(self,x):
        x = self.layer1(x)
        f1 = x
        x = self.layer2(x)
        x = self.layer3(x)
        f2 = x
        x = self.layer4(x)
        x = self.layer5(x)
        f3 = x
        x = self.layer6(x)
        x = self.layer7(x)
        f4 = x
        x = self.layer8(x)
        x = self.layer9(x)
        x = self.layer10(x,f4)
        x = self.layer11(x)
        x = self.layer12(x,f3)
        x = self.layer13(x)
        x = self.layer14(x,f2)
        x = self.layer15(x)
        x = self.layer16(x,f1)
        x = self.layer17(x)
        x = self.layer18(x)
        return self.act(x)


if __name__ == '__main__':
    #device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(10,3,256,256)
    model = UNET(3,64)
    #if hasattr(torch.cuda, 'empty_cache'):
        #torch.cuda.empty_cache()

    x = model(x)
    print(x.size())

    wiriter = SummaryWriter('log1')
    wiriter.add_graph(model,x)

最后,我们可以使用tensorboard查看网络结构:
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Unet网络搭建(Pytorch) 的相关文章

随机推荐

  • java 异常 错误_有关JAVA异常和错误(ERROR)的处理

    异常的处理主要包括捕获异常 程序流程的跳转和异常处理语句块的定义等 当一个异常被抛出时 应该有专门的语句来捕获这个被抛出的异常对象 这个过程被称为捕获异常 当一个异常类的对象被捕获后 用户程序就会发生流程的跳转 系统中止当前的流程而跳转至专
  • 【毕设选题】最新51单片机毕业设计项目集合 - 500例

    文章目录 1前言 2 STM32 毕设课题 3 如何选题 3 1 不要给自己挖坑 3 2 难度把控 3 3 如何命名题目 4 最后 1前言 更新单片机嵌入式选题后 不少学弟学妹催学长更新STM32和C51选题系列 感谢大家的认可 来啦 以下
  • Java对文件的基本操作(查找、读取)

    1 读取目录下的所有文件 隐藏的文件一网打尽 快速定位要找的那个文件 查询路径下的所有文件 param path 路径 private static void find String path File file new File path
  • 使用DatagramSocket发送、接收数据(1)

    Java使用DatagramSocket代表UDP协议的Socket DatagramSocket本身只是码头 不维护状态 不能产生IO流 它的唯一作用就是接收和发送数据报 Java使用DatagramPacket来代表数据报 Datagr
  • 给 Typora 改个背景颜色

    因为白色 在多云天气的时候 看上去有的扎眼 所以就想修改一下颜色 但本地的主题 不好看 所以就想简简单单换一个颜色 网上有很多 自定义主题的文章 我懒 只想改背景颜色 不想弄那么多的操作 换成这种颜色 就是好看 哈哈哈 操作 在 typor
  • 清华大学LightGrad-TTS,且流式实现

    论文链接 https arxiv org abs 2308 16569 代码地址 https github com thuhcsi LightGrad 数据支持 针对BZNSYP和LJSpeech提供训练脚本 针对Grad TTS提出两个问
  • stm32外部中断

    目录 1 STM32的外部中断线 2 NVIC嵌套向量中断器 3 外部中断 事件控制器 EXTI 4 STM32CubeMX配置外部中断 1 外部中断是什么 想象一个场景 你在家里玩游戏 这时候突然来电话了 这时你会停止玩游戏去接电话 电话
  • 开源|携程机票 App KMM 跨端 KV 存储库 MMKV-Kotlin

    作者简介 禹昂 携程移动端资深工程师 专注于 Kotlin 移动端跨平台领域 Kotlin 中文社区核心成员 图书 Kotlin 编程实践 译者 一 背景 携程机票移动端研发团队自 2021 年始就一直在移动端实践 Kotlin Multi
  • 关于二进制的练习

    前言 一 二题为牛客网练习 都有题目链接 文章目录 一 两个整数二进制位不同个数 二 输入一个整数 n 输出该数32位二进制表示中1的个数 其中负数用补码表示 三 获取一个整数二进制序列中所有的偶数位和奇数位 分别打印出二进制序列 一 两个
  • 马尔可夫过程

    马尔可夫过程的定义 平稳过程的平稳性保证了未来可以通过过去来预知 而马尔科夫是这样的一类过程 即未来只与现在有关 与过去无关 就是你的过去是什么样子不重要 未来只与自己当下的努力有关 我们只需要知道当前的信息就够了 举一个实际例子比如说卖电
  • 静态路由协议的默认管理距离是_距离矢量路由选择协议

    上一节我们主要讲述了影响路由选择协议的四个因素 路径决策 度量 收敛 负载均衡 也提了一下大多数路由选择协议的分类有距离矢量和链路状态 本节我们主要讲述一下距离矢量路由选择协议 首先说一下 该路由选择协议的由来 由于该路由选择协议通告的方式
  • https网络编程——如何做web的访问控制机制(ACL)

    参考 如何做web的访问控制机制 ACL 地址 https qingmu blog csdn net article details 108286660 spm 1001 2014 3001 5502 目录 ACL含义 例子 具体实现 AC
  • Linux相关的小知识点

    Linux 中每个 TCP 连接最少占用多少内存 详细解释 Linux 内核到底长啥样详细解释
  • GPS模块启动模式

    文章目录 GPS启动模式 1 冷启动 2 热启动 3 温启动 GPS模块举例 GPS启动模式 有3种启动模式 冷启动 温启动 热启动 启动时间 冷启动 gt 温启动 gt 热启动 启动时间越长定位越慢 用户使用体验越差 1 冷启动 冷启动是
  • Segmentation简记1-The Liver Tumor Segmentation Benchmark (LiTS)

    创新点 最主要的创新是建立了一个肝脏CT图像分割数据库 总结 类似于综述加上数据库的介绍 没有细看 医学方面时候会用到
  • 并发编程系列文章-Java线程的创建方式

    文章目录 继承Thread类 实现Runnable接口 使用Callable和Future创建有返回值的线程 使用Executor框架创建线程池 几个关键类的关系图 实战例子 常见的Java线程的4中方式包括 继承Thread类 实现Run
  • 用docker命令时报错,提示:Cannot connect to the Docker daemon at unix:///var/run/docker.sock.

    报错现象 root node02 docker ps Cannot connect to the Docker daemon at unix var run docker sock Is the docker daemon running
  • 工作中报错故障集合

    OOM常见报错排查之堆外内存溢出 报错 ExecutorLostFailure executor xxx exited caused by one of the running tasks Reason Container killed b
  • numpy和torch的一些操作

    1 如何把数据从1维扩充成2维 np expand dims x1 axis 1 或者x1 x1 None 从 2 33075 换成两个 1 33075 x1 x1 None 2 numpy trace array 返回数组沿对角线元素的和
  • Unet网络搭建(Pytorch)

    Unet是一个经典的语义分割网络 常常被用于医学影像的分割 在Unet的网络结构中 可以分为卷积模块 下采样模块以及上采样模块 详见下面的网络结构图 在网络的搭建过程中 也是依照分为三大块这种思路进行搭建 话不多说 直接上代码 import