Pytorch读取npy数据格式,编写dataset模块,可配合Dataloader进行使用

        在训练模型前,最重要的部分就是制作好数据集,有些情况下,由于图片数据过多,然后存储很不方便,我们就需要将数据制作成npy类型的数据格式。npy数据格式是一个四维的数组[N,H,W, C],其中N代表数据集的总数,H, W,C分别代表每一张图片对应的长、宽、以及通道数。

数据制作好之后,就是如何加载数据问题,TF中加载数据相对比较容易,但是Pytorch中,我们一般都是将数据制作成dataset,再传入Dataloader进行加载,因此就需要继承Dataset的类,然后编写读取npy的数据格式。Dataset中,我们需要定义三个函数。

一、__init__(self,data) 函数

主要是用来加载npy数据的,也可以加载数据预处理的函数,比如将数据转化为tensor之类的操作

 def __init__(self, data):
        self.data = np.load(data) #加载npy数据
        self.transforms = transform #转为tensor形式

二、__len__(self)函数

这个函数就是用来返回数据的总个数

 def __len__(self):
        return self.data.shape[0] #返回数据的总个数

三、 __getitem__(self,index)函数

这个是最要的函数,类似一个for循环,从头开始,每次读取一个保存在npy里面的数据,然后进行处理后,可以同时返回训练数据,以及对应的标签

    def __getitem__(self, index):
        hdct= self.data[index, :, :, :]  # 读取每一个npy的数据
        hdct = np.squeeze(hdct)  # 删掉一维的数据,就是把通道数这个维度删除
        ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声
        hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式
        ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式
        hdct= self.transforms(hdct)  #转为tensor形式
        ldct= self.transforms(ldct)  #转为tensor形式
        return ldct,hdct #返回数据还有标签

完整的代码如下:

import torch
import numpy as np
import skimage
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(1)  # reproducible

transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
])
'''NPY数据格式'''
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = np.load(data) #加载npy数据
        self.transforms = transform #转为tensor形式
    def __getitem__(self, index):
        hdct= self.data[index, :, :, :]  # 读取每一个npy的数据
        hdct = np.squeeze(hdct)  # 删掉一维的数据,就是把通道数这个维度删除
        ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声
        hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式
        ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式
        hdct= self.transforms(hdct)  #转为tensor形式
        ldct= self.transforms(ldct)  #转为tensor形式
        return ldct,hdct #返回数据还有标签
    def __len__(self):
        return self.data.shape[0] #返回数据的总个数

def main():
    dataset=MyDataset('.\data_npy\img_covid_poisson_glay_clean_BATCH_64_PATS_100.npy')
    data= DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)

if __name__ == '__main__':
	main()

 

更多相关推荐

pytorch读取自己的数据集_给pyto...

在用tensorflow的时候,可以将数据转化成tfrecord的数据格式,增加数据读取效率。这时候你看nv...

继续阅读

pytorch读取自己的数据集_给训练...

需求最近在训练coco数据集,训练集就有11万张,训练一个epoch就要将近100分钟,训练100个epoch...

继续阅读

PyTorch学习—6.PyTorch数据读取...

文章目录一、PyTorch数据读取机制Dataloader一、PyTorch数据读取机制Dataloader  PyTorch数据...

继续阅读

pytorch dataset_有关如何在PyTo...

本篇文章包括:torch.utils.data.Datasettorch.utils.data.TensorDataset拆分我们的数据集:ra...

继续阅读

fashionmnist数据集_简单的使用P...

概述本文的目的是为那些想要使用PyTorch和FashionMNIST进行简单深度学习图像分类网络的人提供...

继续阅读

【Tensorflow】训练keras模型+te...

1.数据集jpg图像数据格式的MNIST数据集:(放在database1文件夹下面) 2.利用tensorflow-V2的...

继续阅读

img标签读取本地图片_PyTorch 学...

加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发...

继续阅读

使用pytorch 的torch.utils.Data...

目录导入必要的库用pandas读入数据定义一个显示图片和landmarks的函数定义一个Dataset类,继承t...

继续阅读

使用PyTorch进行数据处理

  在深度学习中,数据的处理对于神经网络的训练来说十分重要,良好的数据(包括图像、文本、语...

继续阅读

(第一篇)pytorch数据预处理三...

前言:在深度学习中,数据的预处理是第一步,pytorch提供了非常规范的处理接口,本文将针对处...

继续阅读