温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

Pytorch中Dataset数据处理的示例分析

发布时间:2021-12-27 10:08:16 来源:亿速云 阅读:133 作者:小新 栏目:开发技术

这篇文章给大家分享的是有关Pytorch中Dataset数据处理的示例分析的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。

    Pytorch系列是了解与使用Pytorch编程来实现卷积神经网络。

    学习如何对卷积神经网络编程;首先,需要了解Pytorch对数据的使用(也是在我们模型流程中对数据的预处理部分),其中有两个包Dataset,DataLoaderDatasetPytorch对于单个数据的处理类似于给一堆数据进行编号,(在有标签的图像处理中)对其有序地提取图像与标签,
    DataLoader则是一坨一坨的数据进行批次的处理。

    此实验运用的数据是北邮邓伟洪老师的人脸表情包的数据集,

    当然大家也可以自己手动做个二分类数据集之类的就将一幅幅的图片放图标签命名的文件夹中即可。

    将邓伟洪老师的RAF-DB简单来刨析,假设其只有Image,没有真正的Annotation等,
    则其根路径(整个data的大体位置)设为 root_dir = "D:\data\basic"
    (由于以下考虑了Annotation,"Image"放入label)标签路径(data下的label位置)设为label_dir="Image\aligned(original)"

    可参考下图理解:

    Pytorch中Dataset数据处理的示例分析

    假设alignedoriginal是标签,但是它是真正的图片的路径

    Pytorch中Dataset数据处理的示例分析

    Pytorch中Dataset数据处理的示例分析

     现在开始编程:

    因为使用Dataset,即让新的类(MyData)来继承Dataset需要改写 def __getitem__(self,item):def __len__(self):
    其中, def __getitem__ (self,item):输入一系列图像的path与图像的index(组合为一张图像的详细地址),输出图像与标签,代码中默认item为序列号,但是为了方便将item改写为idx;
    def __len__(self):输入一系列图像的路径,输出这些图像的个数。
    其他的函数就可以创新加载自己定义的类里。

    from torch.utils.data import Dataset #Dataset的包
    import os #路径需要这个
    import cv2 # 需要读取图片,最好用opencv-python,当然也可以用PIL只是我不顺手
    
    
    class MyData(Dataset): #我定义的这个类
        def __init__(self, root_dir, label_dir):
         #下面需要使用的变量,在__init__定义好,
            self.root_dir = root_dir # 根路径 data在电脑或者服务器大致的位置
            self.label_dir = label_dir # label的位置(这里假设Image的名字就是label的位置)
            self.path = os.path.join(self.root_dir, self.label_dir)# 将这个两个合在一起就能找到整体图片的大致路径
            self.img_path = os.listdir(self.path) #得到整体图片的路径(可取其中的一张一张的图像的名字)
    
        def __getitem__(self, idx): 
        # 改写__getitem__(self,item)函数,最后得到图像,标签
          #获取具体的一幅图像的名字
            img_name = self.img_path[idx]
            #获取一幅图像的详细地址
            img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
            #用opencv来读取图像
            img = cv2.imread(img_item_path)
            #获取标签(这里简单写了aligned与original)
            label = self.label_dir
            return img, label
    
        def __len__(self):
        #改写整体图像的大小
            return len(self.img_path)
    
    
    root_dir = "D://data//basic"
    img_dir = "Image"
    aligned_label_dir = "aligned"
    # aligned_label_dir = "Image//aligned"
    aligned_label_dir = os.path.join(img_dir, aligned_label_dir)
    
    original_label_dir = "original"
    #original_label_dir = "Image//original"
    original_label_dir = os.path.join(img_dir, original_label_dir)
    
    #aligned_data = "D://data//basic//Image//aligned"
    aligned_data = MyData(root_dir, aligned_label_dir)
    #original_data = "D://data//basic//Image//original"
    original_data = MyData(root_dir, original_label_dir)
    data = aligned_data + original_data
    # 15339
    print(len(aligned_data))
    # 15339
    print(len(original_data))
    # 30678
    print(len(data))
    img_1, label_1 = data[15338]
    img_2, label_2 = data[15339]
    print(label_1) # Image\aligned
    print(label_2) # Image\original

    感谢各位的阅读!关于“Pytorch中Dataset数据处理的示例分析”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,让大家可以学到更多知识,如果觉得文章不错,可以把它分享出去让更多的人看到吧!

    向AI问一下细节

    免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

    AI