欢迎来到代码驿站!

Python代码

当前位置:首页 > 软件编程 > Python代码

Pytorch使用技巧之Dataloader中的collate_fn参数详析

时间:2023-02-03 07:39:24|栏目:Python代码|点击:

以MNIST为例

from torchvision import datasets
mnist = datasets.MNIST(root='./data/', train=True, download=True)
print(mnist[0])

结果

(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)

MINIST数据集的dataset是由一张图片和一个label组成的元组

dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x)
for each in dataloader:
    print(each)
    break

结果

[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]

collate_fn为lamda x:x时表示对传入进来的数据不做处理

下面自定义collate_fn看看什么效果

def collate(data):
    img = []
    label = []
    for each in data:
        img.append(each[0])
        label.append(each[1])
    return img,label
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x))
for each in dataloader:
    print(each)
    break

结果

([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])

说明:若不设置collate_fn参数则会使用默认处理函数

但必须保证传进来的数据都是tensor格式否则会报错

附:DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=<function default_collate>,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

总结

上一篇:Python+Matplotlib绘制3D图像的示例详解

栏    目:Python代码

下一篇:Python爬虫获取基金变动信息

本文标题:Pytorch使用技巧之Dataloader中的collate_fn参数详析

本文地址:http://www.codeinn.net/misctech/225034.html

推荐教程

广告投放 | 联系我们 | 版权申明

重要申明:本站所有的文章、图片、评论等,均由网友发表或上传并维护或收集自网络,属个人行为,与本站立场无关。

如果侵犯了您的权利,请与我们联系,我们将在24小时内进行处理、任何非本站因素导致的法律后果,本站均不负任何责任。

联系QQ:914707363 | 邮箱:codeinn#126.com(#换成@)

Copyright © 2020 代码驿站 版权所有