# MyDataset
# 作者 (author): 晓博, 2021-10-05 16:27:47, 所有人可见 (visible to everyone), 阅读 248 (248 views)
import os
import PIL.Image as Image
import torch.utils.data as data
class MyDataset(data.Dataset):
    """Paired image/mask dataset read from a single flat directory.

    ``root_dir`` must contain files named ``000.png``, ``000_mask.png``,
    ``001.png``, ``001_mask.png``, ... Sample ``i`` is the input image
    ``"%03d.png" % i`` together with its label mask ``"%03d_mask.png" % i``.
    Other entries in the directory are ignored.

    Args:
        root_dir: Directory holding the image/mask pairs.
        transform: Optional callable applied to the input image.
        target_transform: Optional callable applied to the mask image.

    Raises:
        FileNotFoundError: If an expected image or mask file is missing.
    """

    def __init__(self, root_dir, transform=None, target_transform=None):
        names = set(os.listdir(root_dir))
        # Count mask files rather than halving the total entry count: halving
        # overshoots as soon as the folder holds anything besides the pairs
        # (.DS_Store, Thumbs.db, a README, sub-directories, ...).
        n = sum(1 for name in names if name.endswith("_mask.png"))
        imgs = []
        for i in range(n):
            img_name = "%03d.png" % i
            mask_name = "%03d_mask.png" % i
            # Fail at construction time instead of midway through an epoch.
            missing = [nm for nm in (img_name, mask_name) if nm not in names]
            if missing:
                raise FileNotFoundError(
                    f"{root_dir}: missing dataset file(s) {', '.join(missing)}"
                )
            # Paths of the input image and of its label mask.
            imgs.append(
                [os.path.join(root_dir, img_name), os.path.join(root_dir, mask_name)]
            )
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index`` with transforms applied.

        When the corresponding transform is ``None`` the element is returned
        as a fully loaded PIL image.
        """
        x_path, y_path = self.imgs[index]
        img_x = self._load_image(x_path)
        img_y = self._load_image(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgs)

    @staticmethod
    def _load_image(path):
        """Open ``path`` with Pillow, decode the pixels and close the file."""
        # Image.open is lazy and holds the file handle until pixel data is
        # read; loading inside ``with`` releases it deterministically, so an
        # untransformed sample does not keep a descriptor open.
        with Image.open(path) as im:
            im.load()
        return im