pytorch를 배우면서 가장 두려웠던 부분이 Dataset을 어떻게 넣어야 할까 였습니다.
그래서 이 부분을 설명하고자 포스팅합니다(솔직히 pytorch 공홈에 보면 다 나와있더군요).
파이토치 공홈에 보면 다음과 같이 나와있습니다.
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
저는 예전에 파이썬을 공부하면서 "__" 가 들어간 기능들을 이해를 잘 못했습니다("__"가들어가는 기능을 magic method라고 하더군요!). 이번기회에 강제적으로 이해가 되더군요(머리로 이해가 안되는부분은 그냥 기능을 많이 써보면 이해가 되더군요.)
1. __init__ 의 경우
R을 해보셧다면 바로 이해가 가실겁니다. R에서 function을 할때, 그 기능에 어떤 변수가 들어갈것인지를 지정 하는데 그것과 같은 기능을 합니다. Class내에 변수에는 저런 친구들이 있다라고 선언해준다고 생각 하시면 될 것 같습니다.
인스턴스화를 시킨다라고 하는데, 그걸 변수로 만든다와 비슷하다고 생각합니다.
위키백과에 보면 인스턴스는 해당 클래스의 구조로 컴퓨터 저장공간에서 할당된 실체를 의미한다. 참 어렵게 써놓은 것 같은데, 제가 이해하기로는 아무것도 없는 세상에 어떠한 물체를 만들고 물체의 스탯을 지정하는것이 인스턴스화다 라고 이해가 되더군요.
__init__ 뒤에 나오는 변수들은 다른 def에서도 사용이 가능합니다.
2. __len__의 경우
__len__의 경우 데이터들의 길이를 내보내줍니다. 우리가 어떠한 class를 통해 변수를 지정하고 지정한 변수의 크기를 알고 싶을때, len이라는 것을 사용하는데, custom class에서는 __len__을 설정해주지 않으면 이 부분이 실행이 되지 않습니다.
input되는 데이터 개수의 길이를 알아야하는 이유는 모델을 제작할때, 정확한 train set, test set의 개수를 맞춰줄수있기 때문입니다.
3. __getitem__의 경우
__get item__의 경우 뒤에 보면 idx라는 부분이 있습니다. 데이터를 indexing해주는 부분으로 __len__에서 데이터의 length를 알아 낼 수가 있고, 데이터 내에서 특정 데이터를 가져오기 위해서는 __getitem__ 설정을 통해 가져 올 수 있게 합니다.
예제를 들어보도록 하겠습니다.
test_data=[1,2,3,4,5,6]
test_data1=[1,2,3,5,2,1]
class test:
def __init__(self,bio,info):
self.stat=[1,2,3,4,5]
self.int=bio
self.str=info
__init__만 넣게되면
bbbb=test(test_data,test_data1)
bbbb.stat
bbbb.int
bbbb.str
위 세가지에 변수에 값들이 들어가게되고, 각 변수들에 대해서는 len 사용이 가능하지만,
bbbb에는 사용이 불가능합니다.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_42784/3707354467.py in <module>
11 bbbb=test(test_data,test_data1)
12 b1=bbbb.stat
---> 13 len(bbbb)
TypeError: object of type 'test' has no len()
__len__을 추가해보도록 하겠습니다.
test_data=[1,2,3,4,5,6]
test_data1=[1,2,3,5,2,1]
class test:
def __init__(self,bio,info):
self.stat=[1,2,3,4,5]
self.int=bio
self.str=info
def __len__(self):
return len(self.stat)
이렇게 제작한 후에 len을 했을때는 동작이 됩니다. 이때 결과는 self.stat의 개수를 리턴하게 됩니다.
test_data,test_data1의 개수가 아닙니다.
__getitem__을 사용해보도록 하겠습니다.
test_data=[1,2,3,4,5,6]
test_data1=[1,2,3,5,2,1]
class test:
def __init__(self,bio,info):
self.stat=[1,2,3,4,5]
self.int=bio
self.str=info
def __len__(self):
return len(self.stat)
def __getitem__(self,idx):
test_d=self.int[idx]
test_d1=self.str[idx]
return test_d, test_d1
bbbb=test(test_data,test_data1)
bbbb.__getitem__(0)
# output
(1,1)
__getitem__을 통해서 데이터와 라벨을 한꺼번에 가지고 올 수가 있습니다.
__getitem__에서 다른 변수를 사용해서도 이용이 가능하나, 아마 제가 모르는 부분이 있지 않을까 싶습니다.
데이터를 외부에서 읽어와서 진행한다던지 등(아시는분 댓글좀.. 혹은 추후에 업데이트 하겠습니다.)
여기까지 최근 추가적으로 공부하고 알아낸 부분을 설명드렸습니다.
감사합니다.
'딥러닝\머신러닝' 카테고리의 다른 글
[pytorch, 이미지분석] CustomDataset 제작시 주의해야할 점. (0) | 2022.08.05 |
---|---|
[pytorch, 딥러닝] 모델 저장하는 방법 (0) | 2022.07.28 |
[딥러닝공부] Model의 기본 - pytorch (0) | 2022.05.17 |
[딥러닝공부] 모델의 기본 - tensorflow (0) | 2022.05.16 |
[tensorflow]python에서 이미지 읽는 방법 (0) | 2022.03.12 |
댓글