파이토치에서는 데이터셋을 좀 더 쉽게 다룰 수 있도록 유용한 도구로서 torch.utils.data.Dataset과 torch.utils.data.DataLoader를 제공한다.
이를 사용하면 미니 배치 학습, 데이터 셔플(shuffle), GPU 병렬 처리까지 간단히 수행할 수 있다고 한다.
데이터셋마다 load하는 방법과 전처리가 다르므로 그에 맞게 Dataset 클래스를 바꾸어주기위해 사용한다고 보면 된다.
딥러닝 구현 코드를 보던 중 Dataset을 부모 클래스로 두는 경우가 있어서 찾아보니 부모 class로 둔 Dataset이 torch.utills.data.Dataset 이라는걸 알게 되어 이를 정리해보고자 한다.
기본적인 사용 방법은 Dataset을 정의하고, 이를 DataLoader에 전달하는 것이다.
앞서 말한 부모 클래스로 torch.utils.data.Dataset을 두는 경우는 orch.utils.data.Dataset을 상속받아 직접 커스텀 데이터셋(Custom Dataset)을 만드는 경우다.
torch.utils.data.Dataset은 파이토치에서 데이터셋을 제공하는 추상 클래스이다.
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
데이터셋의 전처리를 해주는 부분
def __len__(self):
데이터셋의 길이. 즉, 총 샘플의 수를 적어주는 부분
len(dataset)을 했을 때 데이터셋의 크기를 리턴할 len
def __getitem__(self, idx):
데이터셋에서 특정 1개의 샘플을 가져오는 함수
dataset[i]을 했을 때 i번째 샘플을 가져오도록 하는 인덱싱을 위한 get_item
데이터셋을 커스텀 할 때 반드시 필요한 3개의 뼈대이다. init, len, get_item
class FashionMnistDataset(Dataset):
def __init__(self, root, train=True, transform=transforms.ToTensor(),
minor_class_num=8, ratio=0.025):
super(FashionMnistDataset, self).__init__()
self.transform = transform
processed_folder = os.path.join(root, 'FashionMNIST/processed')
...
def __getitem__(self, index):
img, target = self.data[index], int(self.labels[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
img = self.transform(img)
return img, target
def __len__(self):
return len(self.data)
#참고하고 있는 코드인데, 여기서도 init, get_item, len 3개의 뼈대로 구성되어있음을 알 수 있다.
__init__에 보면 transform이 있는데 이건 데이터를 어느 형태로 바꿀건지를 설정하기 위함이다.
##이 부분은 추후 공부를 더 한 후 추가할 예정이다
'정리' 카테고리의 다른 글
Conv와 ConvTranspose (0) | 2023.03.18 |
---|---|
class 와 super정리 (0) | 2022.11.16 |
Dense Layer 정리 (0) | 2022.11.12 |