往期回顾:
在上一篇网络复现中提到:
- 输入图像变成: 4 * 155 * 240 * 240,155层,每一层4个channel,每一个channel是240*240大小
在在接下来构建数据流时候,输出的图像部分,也就按照这个大小进行构建即可。
构建数据流
在数据篇中,了解到4个模态,每一个模态存储的都是一个3维数据。而上文网络输入是一个(batch_size, channel=4, width, height)
,没有了Z轴的维度信息。所以,在构建数据的时候,需要按Z轴,将每一层单独拿出来,组成新的数据形式。
- 读取数据,区分图像和标签
- 处理成图像(channel, width, height)和标签(width, height)形式
在pytorch
的训练数据处理中,以下这段内容,可以作为嵌套的依据:
class Brats15DataLoader(Dataset):
def __init__(self, data_dir, train=True):
self.data = []
if train:
self.data ···
else:
self.data ···
def __len__(self):
return len(self.data)
def __getitem__(self, index):
# ********** get file dir **********
image, label = self.data[index] # get whole data for one subject
# ********** change data type from numpy to torch.Tensor **********
image = torch.from_numpy(image).float()
label = torch.from_numpy(label).float()
return image, label
简单点就是干这么3件事情:
- 告诉我数据在哪里,我好去读取
- 有了数据,处理成模型需要feed的数据形式,按index一个一个的返回
- 返回数据长度,好知道一次循环结束了
插入一点内容:创建Brats15DataLoader
类中__len__
和__getitem__
方法,可以类比创建列表的时候,是实现了一个列表对象的实例化a_list
。
- len(a_list)函数实际上是调用列表类中的私有方法
__len__()
- 用列表索引的时候a_list[1],实际上是调用了
__getitem__()
方法,传入的index=1
直接将代码放到这里,注释部分加入了自己的学习和理解,如存在问题,欢迎评论区交流。
# coding:utf-8
from torch.utils.data import Dataset
from src.utils import *
modals = ['flair', 't1', 't1c', 't2']
class Brats15DataLoader(Dataset):
def __init__(self, data_dir, conf='../config/train15.conf', train=True):
img_lists = []
train_config = open(conf).readlines()
for data in train_config:
img_lists.append(os.path.join(data_dir, data.strip('\n'))) # 获取图像列表
print('\n' + '~' * 50)
print('******** Loading data from disk ********')
self.data = []
self.freq = np.zeros(5) # 频率 array([0., 0., 0., 0., 0.])
self.zero_vol = np.zeros((4, 240, 240)) # 初始化 volume 大小
count = 0
for subject in img_lists: # 逐文件获取全部数据
count += 1
if count % 10 == 0:
print('loading subject %d' % count)
volume, label = Brats15DataLoader.get_subject(subject) # 4 * 155 * 240 * 240, 155 * 240 * 240 ******重点*****
volume = norm_vol(volume) # 归一化
self.freq += self.get_freq(label)
if train is True:
length = volume.shape[1] # length = 155, 4 * 155 * 240 * 240
for i in range(length): # 沿Z轴逐层扫描
name = subject + '=slice' + str(i)
# 如果当前层的内容为空,则跳过
# all() 函数用于判断给定的可迭代参数 iterable 中的所有元素是否都为 TRUE,如果是,返回 True;否,返回 False。
if (volume[:, i, :, :] == self.zero_vol).all(): # when training, ignore zero data
continue
else:
self.data.append([volume[:, i, :, :], label[i, :, :], name]) # self.data.append([4,240,240], [240, 240], name)
else:
volume = np.transpose(volume, (1, 0, 2, 3))
self.data.append([volume, label, subject])
self.freq = self.freq / np.sum(self.freq)
self.weight = np.median(self.freq) / self.freq # np.median中位数
print('******** Finish loading data ********')
print('******** Weight for all classes ********')
print(self.freq)
print(self.weight)
if train is True:
print('******** Total number of 2D images is ' + str(len(self.data)) + ' **********')
else:
print('******** Total number of subject is ' + str(len(self.data)) + ' **********')
print('~' * 50)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
# ********** get file dir **********
[image, label, name] = self.data[index] # get whole data for one subject
# ********** change data type from numpy to torch.Tensor **********
image = torch.from_numpy(image).float() # Float Tensor 4, 240, 240
label = torch.from_numpy(label).float() # Float Tensor 240, 240
return image, label, name
"""
一般来说,要使用某个类的方法,需要先实例化一个对象再调用方法。
而使用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用。
这有利于组织代码,把某些应该属于某个类的函数给放到那个类里去,同时有利于命名空间的整洁。
"""
# 单个文件处理
@staticmethod
def get_subject(subject):
"""
:param subject: absolute dir
:return:
volume 4D numpy 4 * 155 * 240 * 240
label 4D numpy 155 * 240 * 240
"""
# **************** get file ****************
files = os.listdir(subject) # [XXX.Flair, XXX.T1, XXX.T1c, XXX.T2, XXX.OT] T2加权液体衰减反转恢复(FLAIR)、T1加权(T1)、T1加权对比增强(T1c)、T2加权(T2)、OT-label
multi_mode_dir = [] # 图像文件名
label_dir = "" # 标签文件名
for f in files:
if f == '.DS_Store':
continue
if 'Flair' in f or 'T1' in f or 'T2' in f: # if is data, (.Flair, .T1, .T1c, .T2)
multi_mode_dir.append(f)
elif 'OT.' in f: # if is label
label_dir = f
# ********** load 4 mode images **********
multi_mode_imgs = [] # list size :4 item size: 155 * 240 * 240
for mod_dir in multi_mode_dir:
path = os.path.join(subject, mod_dir) # absolute directory
img = load_mha_as_array(path)
multi_mode_imgs.append(img)
# ********** get label **********
label_dir = os.path.join(subject, label_dir)
label = load_mha_as_array(label_dir)
volume = np.asarray(multi_mode_imgs)
return volume, label
def get_freq(self, label):
"""
:param label: numpy 155 * 240 * 240 val: 0,1,2,3,4
:return:
"""
class_count = np.zeros((5)) # array([0., 0., 0., 0., 0.])
for i in range(5):
a = (label == i) + 0 # label维度的 0 or 1 数组
class_count[i] = np.sum(a) # 有就对应类别位置+1
return class_count
# test case
if __name__ == "__main__":
data_dir = '../data_sample/'
conf = '../config/sample15.conf'
# test data loader for training data
brats15 = Brats15DataLoader(data_dir=data_dir, conf=conf, train=True)
print(len(brats15))
image2d, label2d, im_name = brats15[70]
print('image size ......')
print(image2d.shape) # (4, 240, 240)
print('label size ......')
print(label2d.shape) # (240, 240)
print(im_name)
name = im_name.split('/')[-1]
save_one_image_label(image2d, label2d, 'img5/img_label_%s.jpg' % name)
# test data loader for testing data
brats15_test = Brats15DataLoader(data_dir=data_dir, conf=conf, train=False)
print(len(brats15_test))
image_volume, label_volume, subject = brats15_test[0]
print(image_volume.shape)
print(label_volume.shape)
print(subject)
打印结果如下:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
******** Loading data from disk ********
******** Finish loading data ********
******** Weight for all classes ********
[9.87486111e-01 2.95374104e-03 5.94198029e-03 9.16218638e-05
3.52654570e-03]
[3.57123575e-03 1.19392515e+00 5.93496701e-01 3.84902200e+01
1.00000000e+00]
******** Total number of 2D images is 132 **********
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
132
image size ......
torch.Size([4, 240, 240])
label size ......
torch.Size([240, 240])
../data_sample/HGG/brats_2013_pat0001_1=slice80
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
******** Loading data from disk ********
******** Finish loading data ********
******** Weight for all classes ********
[9.87486111e-01 2.95374104e-03 5.94198029e-03 9.16218638e-05
3.52654570e-03]
[3.57123575e-03 1.19392515e+00 5.93496701e-01 3.84902200e+01
1.00000000e+00]
******** Total number of subject is 1 **********
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1
torch.Size([155, 4, 240, 240])
torch.Size([155, 240, 240])
../data_sample/HGG/brats_2013_pat0001_1
其中保存一个index的图像内容如下:
总结
本文代码部分参考自GitHub,完整代码可以点击下方链接直达。在上文部分,着重对学习过程的主要代码进行了描述,其他辅助的功能函数,还需要你自行学习。
参考GitHub:stm_multi_modal_UNet
下一章博客是对训练的主函数做个简单的介绍,还有训练和测试过程做个描述。本文的数据处理和网络复现部分最为重要,所以篇幅较多。
作者提供了解决一个问题的多个思路和角度,这种对问题的分析方式,值得我好好学习和思考。于是对论文和代码部分进行学习、训练之余,分享给大家进行参考。
也希望大家对其中不对的地方多多在评论区批评指正,谢谢。