【BraTS】Brain Tumor Segmentation 脑部肿瘤分割3--构建数据流

2023-11-17

往期回顾:

在上一篇网络复现中提到:

  • 输入图像变成: 4 * 155 * 240 * 240,155层,每一层4个channel,每一个channel是240*240大小

在在接下来构建数据流时候,输出的图像部分,也就按照这个大小进行构建即可。

构建数据流

在数据篇中,了解到4个模态,每一个模态存储的都是一个3维数据。而上文网络输入是一个(batch_size, channel=4, width, height),没有了Z轴的维度信息。所以,在构建数据的时候,需要按Z轴,将每一层单独拿出来,组成新的数据形式。

  1. 读取数据,区分图像和标签
  2. 处理成图像(channel, width, height)和标签(width, height)形式

pytorch的训练数据处理中,以下这段内容,可以作为嵌套的依据:

class Brats15DataLoader(Dataset):
    def __init__(self, data_dir, train=True):
        self.data = []

        if train:
        	self.data  ···
        else:
        	self.data  ···

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # ********** get file dir **********
        image, label = self.data[index]  # get whole data for one subject
        # ********** change data type from numpy to torch.Tensor **********
        image = torch.from_numpy(image).float()  
        label = torch.from_numpy(label).float()  
        return image, label

简单点就是干这么3件事情:

  1. 告诉我数据在哪里,我好去读取
  2. 有了数据,处理成模型需要feed的数据形式,按index一个一个的返回
  3. 返回数据长度,好知道一次循环结束了

插入一点内容:创建Brats15DataLoader类中__len____getitem__方法,可以类比创建列表的时候,是实现了一个列表对象的实例化a_list

  • len(a_list)函数实际上是调用列表类中的私有方法__len__()
  • 用列表索引的时候a_list[1],实际上是调用了__getitem__()方法,传入的index=1

直接将代码放到这里,注释部分加入了自己的学习和理解,如存在问题,欢迎评论区交流。

# coding:utf-8
from torch.utils.data import Dataset
from src.utils import *

modals = ['flair', 't1', 't1c', 't2']

class Brats15DataLoader(Dataset):
    def __init__(self, data_dir, conf='../config/train15.conf', train=True):
        img_lists = []
        train_config = open(conf).readlines()
        for data in train_config:
            img_lists.append(os.path.join(data_dir, data.strip('\n')))  # 获取图像列表

        print('\n' + '~' * 50)
        print('******** Loading data from disk ********')
        self.data = []
        self.freq = np.zeros(5)     # 频率  array([0., 0., 0., 0., 0.])
        self.zero_vol = np.zeros((4, 240, 240))  # 初始化 volume 大小
        count = 0
        for subject in img_lists:   # 逐文件获取全部数据
            count += 1
            if count % 10 == 0:
                print('loading subject %d' % count)
            volume, label = Brats15DataLoader.get_subject(subject)   # 4 * 155 * 240 * 240,  155 * 240 * 240   ******重点*****
            volume = norm_vol(volume)   # 归一化

            self.freq += self.get_freq(label)
            if train is True:
                length = volume.shape[1]    # length = 155, 4 * 155 * 240 * 240
                for i in range(length):     # 沿Z轴逐层扫描
                    name = subject + '=slice' + str(i)
                    # 如果当前层的内容为空,则跳过
                    # all() 函数用于判断给定的可迭代参数 iterable 中的所有元素是否都为 TRUE,如果是,返回 True;否,返回 False。
                    if (volume[:, i, :, :] == self.zero_vol).all():  # when training, ignore zero data
                        continue
                    else:
                        self.data.append([volume[:, i, :, :], label[i, :, :], name])    # self.data.append([4,240,240], [240, 240], name)
            else:
                volume = np.transpose(volume, (1, 0, 2, 3))
                self.data.append([volume, label, subject])

        self.freq = self.freq / np.sum(self.freq)
        self.weight = np.median(self.freq) / self.freq      # np.median中位数
        print('********  Finish loading data  ********')
        print('********  Weight for all classes  ********')
        print(self.freq)
        print(self.weight)
        if train is True:
            print('********  Total number of 2D images is ' + str(len(self.data)) + ' **********')
        else:
            print('********  Total number of subject is ' + str(len(self.data)) + ' **********')

        print('~' * 50)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # ********** get file dir **********
        [image, label, name] = self.data[index]  # get whole data for one subject
        # ********** change data type from numpy to torch.Tensor **********
        image = torch.from_numpy(image).float()  # Float Tensor 4, 240, 240
        label = torch.from_numpy(label).float()    # Float Tensor 240, 240
        return image, label, name

    """
    一般来说,要使用某个类的方法,需要先实例化一个对象再调用方法。
    而使用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用。
    这有利于组织代码,把某些应该属于某个类的函数给放到那个类里去,同时有利于命名空间的整洁。
    """
    # 单个文件处理
    @staticmethod
    def get_subject(subject):
        """
        :param subject: absolute dir
        :return:
        volume  4D numpy    4 * 155 * 240 * 240
        label   4D numpy    155 * 240 * 240
        """
        # **************** get file ****************
        files = os.listdir(subject)  # [XXX.Flair, XXX.T1, XXX.T1c, XXX.T2, XXX.OT]    T2加权液体衰减反转恢复(FLAIR)、T1加权(T1)、T1加权对比增强(T1c)、T2加权(T2)、OT-label
        multi_mode_dir = []     # 图像文件名
        label_dir = ""          # 标签文件名
        for f in files:
            if f == '.DS_Store':
                continue
            if 'Flair' in f or 'T1' in f or 'T2' in f:    # if is data, (.Flair, .T1, .T1c, .T2)
                multi_mode_dir.append(f)
            elif 'OT.' in f:                              # if is label
                label_dir = f

        # ********** load 4 mode images **********
        multi_mode_imgs = []  # list size :4      item size: 155 * 240 * 240
        for mod_dir in multi_mode_dir:
            path = os.path.join(subject, mod_dir)  # absolute directory
            img = load_mha_as_array(path)
            multi_mode_imgs.append(img)

        # ********** get label **********
        label_dir = os.path.join(subject, label_dir)
        label = load_mha_as_array(label_dir)

        volume = np.asarray(multi_mode_imgs)
        return volume, label

    def get_freq(self, label):
        """
        :param label: numpy 155 * 240 * 240     val: 0,1,2,3,4
        :return:
        """
        class_count = np.zeros((5))     #  array([0., 0., 0., 0., 0.])
        for i in range(5):
            a = (label == i) + 0        # label维度的 0 or 1 数组
            class_count[i] = np.sum(a)  # 有就对应类别位置+1
        return class_count

# test case
if __name__ == "__main__":
    data_dir = '../data_sample/'
    conf = '../config/sample15.conf'
    # test data loader for training data
    brats15 = Brats15DataLoader(data_dir=data_dir, conf=conf, train=True)
    print(len(brats15))
    image2d, label2d, im_name = brats15[70]

    print('image size ......')
    print(image2d.shape)             # (4,  240, 240)

    print('label size ......')
    print(label2d.shape)             # (240, 240)
    print(im_name)
    name = im_name.split('/')[-1]
    save_one_image_label(image2d, label2d, 'img5/img_label_%s.jpg' % name)

    # test data loader for testing data
    brats15_test = Brats15DataLoader(data_dir=data_dir, conf=conf, train=False)
    print(len(brats15_test))
    image_volume, label_volume, subject = brats15_test[0]
    print(image_volume.shape)
    print(label_volume.shape)
    print(subject)

打印结果如下:

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
******** Loading data from disk ********
********  Finish loading data  ********
********  Weight for all classes  ********
[9.87486111e-01 2.95374104e-03 5.94198029e-03 9.16218638e-05
 3.52654570e-03]
[3.57123575e-03 1.19392515e+00 5.93496701e-01 3.84902200e+01
 1.00000000e+00]
********  Total number of 2D images is 132 **********
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
132
image size ......
torch.Size([4, 240, 240])
label size ......
torch.Size([240, 240])
../data_sample/HGG/brats_2013_pat0001_1=slice80

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
******** Loading data from disk ********
********  Finish loading data  ********
********  Weight for all classes  ********
[9.87486111e-01 2.95374104e-03 5.94198029e-03 9.16218638e-05
 3.52654570e-03]
[3.57123575e-03 1.19392515e+00 5.93496701e-01 3.84902200e+01
 1.00000000e+00]
********  Total number of subject is 1 **********
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1
torch.Size([155, 4, 240, 240])
torch.Size([155, 240, 240])
../data_sample/HGG/brats_2013_pat0001_1

其中保存一个index的图像内容如下:

1

总结

本文代码部分参考自GitHub,完整代码可以点击下方链接直达。在上文部分,着重对学习过程的主要代码进行了描述,其他辅助的功能函数,还需要你自行学习。

参考GitHub:stm_multi_modal_UNet

下一章博客是对训练的主函数做个简单的介绍,还有训练和测试过程做个描述。本文的数据处理和网络复现部分最为重要,所以篇幅较多。

作者提供了解决一个问题的多个思路和角度,这种对问题的分析方式,值得我好好学习和思考。于是对论文和代码部分进行学习、训练之余,分享给大家进行参考。

也希望大家对其中不对的地方多多在评论区批评指正,谢谢。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【BraTS】Brain Tumor Segmentation 脑部肿瘤分割3--构建数据流 的相关文章

随机推荐

  • SpringBoot + Spring Security多种登录方式:账号+微信网页授权登录

    大家好 我是宝哥 一 概述 实现账号用户名 微信网页授权登录集成在Spring Security的思路 最重要的一点是要实现微信登录通过Spring Security安全框架时 不需要验证账号 密码 二 准备工作 要实现该功能 首先需要掌握
  • win10台式机rtl8188eu(FW 150 UM V2.0)无线网卡无法连接wifi(无法连接到这个网络)

    同一个网卡 同一个WiFi 在笔记本上能用 能连接wifi 但是在台式机上就不能连接wifi 提示 无法连接到这个网络 如下图 win10版本都是1903 尝试换各种驱动都没解决 最后更新主板bios 然后从微星主板客服得知可以问京东自营的
  • 高校评优评奖管理系统

    这是一个高校评优评奖管理系统 供大家参考学习 不懂的地方可以联系本人 1 管理员登陆 学生申请 管理员后台 评优记录 数据维护 信息统计 系统设置 学生申报 微信 17777665965 QQ 1161724197
  • 纯 CSS 开关切换按钮

  • brk和sbrk及内存分配函数相关

    brk和sbrk主要的工作是实现虚拟内存到内存的映射 在GNUC中 内存分配是这样的 每个进程可访问的虚拟内存空间为3G 但在程序编译时 不可能也没必要为程序分配这么大的空间 只分配并不大的数据段空间 程序中动态分配的空间就是从这 一块分配
  • 【Spark系列2】reduceByKey和groupByKey区别与用法

    在spark中 我们知道一切的操作都是基于RDD的 在使用中 RDD有一种非常特殊也是非常实用的format pair RDD 即RDD的每一行是 key value 的格式 这种格式很像Python的字典类型 便于针对key进行一些处理
  • 微前端--qiankun原理概述

    demo放最后了 一 微前端 一 微前端概述 微前端概念是从微服务概念扩展而来的 摒弃大型单体方式 将前端整体分解为小而简单的块 这些块可以独立开发 测试和部署 同时仍然聚合为一个产品出现在客户面前 可以理解微前端是一种将多个可独立交付的小
  • Android 热补丁动态修复框架小结

    转载 http blog csdn net xdgaozhan article details 51848570 一 概述 最新github上开源了很多热补丁动态修复框架 大致有 https github com dodola HotFix
  • Python与OpenCV(三)——基于光流法的运动目标检测程序分析

    光流的概念是指在连续的两帧图像当中 由于图像中的物体移动或者摄像头的移动而使得图像中的目标形成的矢量运动轨迹叫做光流 本质上光流是个向量场 表示了一个像素点从第一帧过渡到第二帧的运动过程 体现该像素点在成像平面上的瞬时速度 而当我们对图像当
  • oracle 游标 上限,ORA-01000: 超出打开游标的最大数

    语言 java 数据库 oracle 开发中通过jdbc做批量删除对象时 出现了如下异常 java sql SQLException ORA 01000 超出打开游标的最大数 at oracle jdbc driver T4CTTIoer
  • UE5_创建C++项目报错

    UE官方VS安装推荐 https docs unrealengine com 4 26 en US ProductionPipelines DevelopmentSetup VisualStudioSetup UE5报错 A fatal e
  • 下拉框怎么用ajax实现添加功能,ajax实现动态下拉框示例

    许多页面上都涉及有下拉框 即select标签 对于简单的下拉框 被选择的数据是不需要改变的 我们可以用写死 这样下拉框的数据永远都是那几条 示例 信息一 信息二 信息三 信息四 但是有些项目或者工程是需要将数据库中的数据呈现出来并提供选择的
  • CH3-栈和队列

    文章目录 3 1栈和队列的定义和特点 栈的应用 队列的应用 3 1 1栈的定义和特点 3 1 2队列的定义和特点 3 2案例引入 案例3 1 进制转换 案例3 2 括号匹配的检验 案例3 3 表达式求值 案例3 4 舞伴问题 3 3栈的表示
  • 网络传输方式

    1 单播 1 1 定义 单播是指一种向单个目标地址传送数据的方式 即单独的一对一通讯方式 1 2 可使用协议 UDP TCP等协议 1 3 常见的场景 发送电子邮件 传输文件 2 广播 2 1 定义 一种向本地网络中所有设备发送数据的方式
  • FISCO BCOS 三、多群组部署以及新节点加入群组

    本章主要以星形组网和并行多组组网拓扑为例 指导您了解如下内容 了解如何使用build chain sh创建多群组区块链安装包 了解build chain sh创建的多群组区块链安装包目录组织形式 学习如何启动该区块链节点 并通过日志查看各群
  • 正确加载 Javascript 和 CSS 到 WordPress

    原文 http technerdia com 1789 include jquery css html 正确加载 jQuery Javascript 和 CSS 到你的WordPress网站也许是一件比较痛苦的事情 本文将讲解如何使用Wor
  • 使用python处理selenium中的获取元素属性

    获取我的订单元素class属性值 get class name driver find element by link text 我的订单 get attribute class 判断class属性值是否为active self asser
  • 深度学习半自动化视频标注工具——VATIC使用教程

    Vatic简介 Vatic是一个带有目标跟踪的半自动化视频标注工具 适合目标检测任务的标注工作 输入一段视频 支持自动抽取成粒度合适的标注任务并在流程上支持接入亚马逊的众包平台Mechanical Turk 当然也可以自己在本地标注 最大的
  • 使用markdown写大论文

    目的 用自己中意的 Markdown 编辑器来写论文初稿 使用 Zotero 来管理大量参考文献 然后论文转换成 Office Word 文档让老师们查看 当 Markdown 内容并转换成 Word 格式后 所有引用都需要被 Zotero
  • 【BraTS】Brain Tumor Segmentation 脑部肿瘤分割3--构建数据流

    往期回顾 BraTS Brain Tumor Segmentation 脑部肿瘤分割1 数据篇 BraTS Brain Tumor Segmentation 脑部肿瘤分割2 UNet的复现 在上一篇网络复现中提到 输入图像变成 4 155