在PyTorch中加载图片数据集一般有两种方法:
使用torchvision.datasets中的ImageFolder 方法描述
ImageFolder函数读取指定路径下的图片数据集,并将其组织为类别文件夹结构。
使用torchvision.transforms对图像进行预处理,如调整尺寸和归一化。
示例代码:
```python
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt
定义预处理变换
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
加载数据集
train_dir = "../data/hotdog/train"
test_dir = "../data/hotdog/test"
train_dataset = ImageFolder(root=train_dir, transform=transform)
test_dataset = ImageFolder(root=test_dir, transform=transform)
使用DataLoader加载数据集
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
显示图片
for i, (images, labels) in enumerate(train_loader):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images.permute(1, 2, 0))
plt.title(f"Label: {labels}")
plt.axis("off")
if i == 8:
break
plt.show()
```
自定义Dataset类 方法描述
继承torch.utils.data.Dataset类,实现自定义的数据读取和预处理逻辑。
这种方法更加灵活,适用于复杂的数据集结构。
示例代码: