본문 바로가기
딥러닝\머신러닝

[pytorch, 이미지분석] CustomDataset 제작시 주의해야할 점.

by 인포메틱스 2022. 8. 5.
반응형

 

CustomDataset을 제작하는데 있어서 익숙한 포멧은 다음과 같습니다.

 

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

 

 위와 같이 제작을 하더라도 가끔 에러가 발생이 되는데, 이번 포스팅을 하게 된 이유가 데이터폴더 안에 이미지 파일이 아닌 폴더가 들어가 있는 경우입니다.

 

 이럴 때는 화내지말고(전 화냄), 물론 차근차근 배워온 사람이거나 컴공과는 쉽게 해결이 가능하겠지만 필자와 같이 무대포로 분석하는 경우 다음과 같이 해결하면 됩니다.

 

1. 데이터 폴더 내에 폴더가 있는 경우 아래와 같이 에러가 나온다.

from skimage import io, transform
import os
class coad_dataset(torch.utils.data.Dataset):
    
    def __init__(self,path):
        self.path = path
        
    def __len__(self):
        return len(os.listdir(self.img_path))
    
    def __getitem__(self,idx):
    	image_path=self.img_path
        image_path = os.path.join(image_path,os.listdir(image_path)[idx])
        img = plt.imread(image_path)[:,:,:3].astype('float32')
        img = np.array(img)
        img = img.transpose(2,0,1)
        return img,os.listdir(image_folder)[idx]
PermissionError: [Errno 13] Permission denied: [폴더이름]

2. 처음에는 __getitem__에서 이미지 경로를 가지고 os.path.isfile 을 이용하여 파일인경우 return하도록 제작을 했다.

-> 이럴 경우도 에러가 뜬다. 

from skimage import io, transform
import os
class coad_dataset(torch.utils.data.Dataset):
    def __init__(self,path):
        self.path = path
      
    def __len__(self):
        return len(os.listdir(os.path.join(self.path)))
    
    def __getitem__(self,idx):
        image_folder = os.path.join(self.path)
        img_list=os.listdir(image_folder)
        image_path = os.path.join(image_folder,img_list[idx])
        if os.path.isfile(image_path):
            img = plt.imread(image_path)[:,:,:3].astype('float32')
            img = np.array(img)
            img = img.transpose(2,0,1)
            return img,os.listdir(image_folder)[idx]

 

TypeError: object of type 'NoneType' has no len()

 

이럴경우 원인 찾기가 힘듭니다.

 

 그래서 해결책을 찾아보니 __init__이 부분에서 파일리스트를 걸러내고, __len__에도 적용시키고, __getitem__에도 적용시키면 문제가 해결됩니다.

from skimage import io, transform
import os
class coad_dataset(torch.utils.data.Dataset):
    
    def __init__(self,path):
        self.path = path
        self.img_path = os.path.join(self.path)
        # 미리 파일 리스트를 제작
        self.img_list = [i for i in os.listdir(self.img_path) if i.endswith('png')]
        
    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self,idx):
        image_path = os.path.join(self.img_path,self.img_list[idx])
        img = plt.imread(image_path)[:,:,:3].astype('float32')
        img = np.array(img)
        img = img.transpose(2,0,1)
        return img,os.listdir(image_folder)[idx]

 

728x90
반응형

댓글