pytorch分割网络数据输入接口

2023-11-20

 pytorch的自定义接口是真的方便, 记录一下自己分割数据输入的脚本:

# -*- coding: utf-8 -*-
# @Time    : 2019/10/31 21:36
# @Author  : Yunyun Xu
# @Contact : 1443563995@qq.com
# @File    : MyDatasetReader.py
# @Software: Pycharm
# @Blog    : https://me.csdn.net/xuyunyunaixuexi

import os
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from mypath import Path
from torchvision import transforms
import custom_transforms as tr

class MyEggSegmentation(data.Dataset):
    #NUM_CLASSES = 19

    def __init__(self, args, root = Path.db_root_dir("MyEggs"), split = "train"):

        self.root = root
        self.split = split
        self.args = args
        self.image_files = {}
        self.label_files = {}
        #files = {train:[]}

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split)
        self.image_files[split] = self.recursive_glob(rootdir=self.images_base, suffix=".png")
        self.label_files[split] = self.recursive_glob(rootdir = self.annotations_base,suffix = ".png")

        if not self.image_files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        return len(self.image_files[self.split])

    def __getitem__(self, index):
        img_path = self.image_files[self.split][index].rstrip()
        lbl_path = self.label_files[self.split][index].rstrip()
        #将RGBA转为RGB三通道
        _img = Image.open(img_path).convert("RGB")
        #读取索引图
        _target = Image.open(lbl_path)

        sample = {"images":_img, "label":_target}

        if self.split == "train":
            return self.transform_tr(sample)
        if self.split == "val":
            return self.transform_tr(sample)
        if self.split == "test":
            return self.transform_tr(sample)

    def recursive_glob(self, rootdir = '.', suffix = " "):
        return [os.path.join(looproot, filename) for looproot, _, filenames in os.walk(rootdir)
                for filename in filenames if filename.endswith(suffix)]

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_ts(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

测试了一下,是可以遍历的,证明自定义数据集接口(继承data.Dataset)是正确的:

但是本人也有一个问题, 就是分割网络如果根据自己的数据集大小,去确定crop_size, 还是只能不断的尝试去效果??希望得到大家的解答.........

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch分割网络数据输入接口 的相关文章

  • 循环神经网络RNN以及几种经典模型

    RNN简介 现实世界中 很多元素都是相互连接的 比如室外的温度是随着气候的变化而周期性的变化的 我们的语言也需要通过上下文的关系来确认所表达的含义 但是机器要做到这一步就相当得难了 因此 就有了现在的循环神经网络 他的本质是 拥有记忆的能力
  • Pytorch学习笔记(六)

    简单的LeNet网络模型 torchvision datasets torchvision是pytorch的一个图形库 它服务于PyTorch深度学习框架的 主要用来构建计算机视觉模型 以下是torchvision的构成 torchvisi
  • (pytorch进阶之路)Masked AutoEncoder论文及实现

    文章目录 1 导读 2 论文地址 3 代码实现思路 3 1 预处理阶段 3 2 Encoder 3 3 Decoder 3 4 fine tuning 3 5 linear probing 3 6 evaluation 4 代码地址 5 如
  • 使用GPU进行神经网络计算详解

    Pytorch学习笔记 六 使用GPU的简单LeNet网络模型中也提到了如何实现GPU上的运算 虽然不详细 但是也足够 总结 如果对于总结知识已经比较熟悉 那么下面的详解可以不用看 默认CPU进行计算 CPU上变量或模型不能与GPU上变量或
  • 如何使用tensorboard及打开tensorboard生成文件

    一 使用tensorboard tensorboard中常用函数 1 writer add scalar def add scalar self tag scalar value global step None walltime None
  • PyTorch——解决报错“RuntimeError: running_mean should contain *** elements not ***”

    1 问题描述 在使用PyTorch编程的时候 经常遇到一种报错就是 RuntimeError running mean should contain elements not 这次我具体的报错信息是 File home software p
  • 【transformers】tokenizer用法(encode、encode_plus、batch_encode_plus等等)

    tranformers中的模型在使用之前需要进行分词和编码 每个模型都会自带分词器 tokenizer 熟悉分词器的使用将会提高模型构建的效率 string tokens ids 三者转换 string tokens tokenize te
  • torch.nn.LocalResponseNorm(局部响应归一化)详解(附源码解析)

    torch nn LocalResponseNorm 局部响应归一化的理解 局部归一化的动机 在神经生物学有一个概念叫做侧抑制 lateral inhibitio 指的是被激活的神经元抑制相邻神经元 归一化的目的是 抑制 局部响应归一化就是
  • Pytorch学习(6) —— 加载模型部分参数的用法

    上一节 我们给出了模型加载和保存的简要示例 但是 我们有时候会用别人的参数 他们的层参数名和我们的名称很容易不同 因此这里将会对源码进入深入剖析 分析参数提取和保存是如何实现的 我们使用pytorch的VGG16预训练模型 加载 返回其类型
  • Mask掩码

    Python中Mask的用法 引例 Numpy的MaskedArray模块 小于 或小于等于 给定数值 大于 或大于等于 给定数值 在给定范围内 超出给定范围 在算术运算期间忽略NaN和 或infinite值 All men are scu
  • 将tensor张量转换成图片格式并保存

    这是一个工具包 功能 反向操作transforms Normalize和transforms ToTensor函数 将tensor格式的图片转换成 jpg png格式的图片 注 这里是我原始的写法 但是是存在着一些改进空间的 如评论区所言
  • 树莓派4B安装Pytorch, torchvision(附已编译安装包)

    树莓派4B安装Pytorch torchvision Install Pytorch Raspberry Pi 4B Linux raspberrypi 4 19 75 v7l 1270 SMP Tue Sep 24 18 51 41 BS
  • python遍历目录的方法

    简单暴力法 递归 假设在 E 盘中 有个名为 Python 的文件夹 该文件夹中也有两个文件夹 分别是 A 和 B 另外 在 A 文件夹中还有一个 results txt 的文本文件 因此 Python 文件夹的文件结构如下 Python
  • pytorch图像检索评价指标MAP

    map是图像检索模型的一个评价指标 以图片中第一个计算AP值为例 P的分别是 1 2 3 3 6 4 9 5 10 R值分别是 1 5 2 5 3 5 4 5 1 AP计算结果 1 2 3 3 6 4 9 5 10 5 https blog
  • KORNIA与torch 版本存在依赖关系

    KORNIA 0 58对应torch 1 7 以上对应1 8 可以多下载几个多次安装 直到支持
  • PyTorch实战使用Resnet迁移学习

    PyTorch实战使用Resnet迁移学习 项目结构 项目任务 项目代码 网络模型测试 项目结构 数据集存放在flower data文件夹 cat to name json是makejson文件运行生成的 TorchVision文件主要存放
  • (pytorch进阶之路)DDPM扩散概率模型

    文章目录 概述 前置知识 diffusion图示 扩散过程 逆扩散过程 后验的扩散条件概率 似然函数 算法 代码实现 概述 扩散概率模型 deep unsupervised learning using nonequilibrium the
  • Pytorch并行训练方法-单机多卡

    简单方便的 nn DataParallel DataParallel 可以帮助我们 使用单进程控 将模型和数据加载到多个 GPU 中 控制数据在 GPU 之间的流动 协同不同 GPU 上的模型进行并行训练 细粒度的方法有 scatter g
  • [pytorch]关于cross_entropy函数

    loss F cross entropy output labels output 网络的全连接层的输出 值可能是有正有负的 例如 1 56 2 43 等 labels 正常标签 例如一共5个类别 值就是0 4
  • 【PyTorch学习】(三)自定义Datasets

    torchvision datasets源码地址 https github com pytorch vision blob master torchvision datasets 前两篇从搭建经典的ResNet DenseNet入手简单的了

随机推荐