본문 바로가기
딥러닝\머신러닝/이미지분석

[pytorch] transforms.Compose 사용법

by 인포메틱스 2024. 1. 15.
반응형

 

1. Intro

 요즘 이미지 분석에 대해서 흥미가 생겨서 열심히 공부하고 있습니다(사실 관련된 일을 맡게 되었습니다. 먹고살려고 빡시게 공부중입니다.).

 

tensorflow를 공부하려다가 pytorch가 사용하기 편하다는 이야기를 듣고 바로 pytorch로 마음을 돌렸습니다. 

 

바로 본론으로 들어가도록 하겠습니다.

 

딥러닝에서 이미지관련 모델을 제작할 때, 힘든 부분 중 하나가 바로 데이터의 양이지 않을까 싶습니다. 

 

제한된 이미지에서 좋은 모델을 제작하기 위해서 사람들이 생각을 해낸 것은 Data augmentation입니다. 

 

Data augmentation은 생각보다 쉽습니다. 각도나, 명암을 변경하던지, 부분을 자르던지 다양한 방법으로 통해 augmentation 을 할 수 있습니다.

 

Data augmentation을 코드로 구현하려면 torchvision에서 transforms이나, albumentations module을 사용하면 됩니다.

 

주의할 점모델 제작에는 torchvision의 transforms을 사용하는데, 모델을 돌릴 때에는 albumentations을 사용하면 안됩니다. 결과에서 차이가 날 수 있다는 것을 인지하셔야 합니다(1년 동안 경험한 이야기).

 

이번 포스팅에서는 torchvision에서 transforms 기능에 대한 이야기를 하고자 합니다.

 

2. transforms.Compose

  딥러닝 공부를 하면서 transforms.Compose 라는 코드를 많이 보셨을 겁니다. transforms.Compose는 다양한 Data augmentation을 한꺼번에 손쉽게 해주는 기능입니다. transforms에 있는 기능들 대부분 범위로 들어가게 되고, 랜덤으로 변경되어 적용됩니다.

 

아래 예시는 Data augmentation에 자주 사용되는 RandomRotation 을 이용하여 기본적인 사용예시를 들었습니다.

 

import torchvision.transforms as transforms
transforms.RandomRotation(90)
## output 
RandomRotation(degrees=[-90.0, 90.0], interpolation=nearest, expand=False, fill=0)

transforms.RandomRotation(190)
## output 
RandomRotation(degrees=[-190.0, 190.0], interpolation=nearest, expand=False, fill=0)

transforms.RandomRotation([30,100])
## output
RandomRotation(degrees=[30.0, 100.0], interpolation=nearest, expand=False, fill=0)

 

 transforms.Compose에는 Augmentation이외에도 모델을 제작할 때 사용하는 기본적인 기능도 있습니다.

 

 모델을 제작해야 할 때, Tensor 형태로 들어가야 하는데, 이를 transforms.ToTensor를 이용하면 변경시켜서 나옵니다.

 

 참고로 ToTensor는 albumentations에서도 비슷한 기능이 있는데, ToTensor(old버전만있음), ToTensorV2라는 기능입니다. torchvision.Compose에서의 ToTensor는 이미지 값을 0~1사이, 혹은 -1~1사이로 Scale로 만들어주고, HWC -> CHW로 변경하고 Tensor형태로 변경해준다는 것입니다. 반면 albumentation의 ToTensorV2의 경우 단지 torch.Tensor로 변경, HWC ->CHW변경만 해줍니다.

 

그렇기 때문에 사용할 때에는 ToTensor의 경우 Normalize 전에 추가해야되고, ToTensorV2의 경우 ToTensorV2이후에 Normalize를 진행해야 합니다.

 

 그렇다고 transforms.Compose([ ])에다가 albumentations의 ToTensorV2를 사용하는 우를 범하지 마시고, module마다 각자 적용을 해야합니다(어떤 모듈이든 간에 비슷한 기능을 할지라도 작동되는 방법이 다르기 때문에 각자 적용해야 하는 것을 추천합니다.).

 

from torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


transform = A.Compose([
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    A.pytorch.ToTensorV2()
])

 

 

 normalize기능에 대해서 잠깐 설명하자면 코드를 보면 대부분 mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]의 값을 확인할 수가 있습니다.

 

이 값들은 도대체 어디서 나왔는가?  ImageNet 데이터 기반으로 제작된 기준들로서 많이 사용되는 이유대부분 모델을 제작할때, 적용해보는 데이터 셋이 ImageNet이기도 하고, transfer learning할때 사용하는 대부분의 모델도 해당 데이터 기준으로 모델을 제작했기 때문에 특별하게 측정하지 않는 이상 해당 normalize를 사용합니다(이 부분에 대해서도 다양한 의견이 있긴합니다만 지금은 패스).

 

3. custom function

transforms.Compose에 들어갈 기능들은 제한적인데, custom function을 추가하고 싶을 경우에는 어떻게 해야하냐?

 

custom function을 제작하기 위해서는 주의 사항이 있습니다. transforms.Compose에 들어갈 input이 어떤 타입이 들어가야하는지를 알아야 합니다.

 

 예를 들어 custom function에 input, output이 numpy.array인데, 다른 transforms.Compose function들과 같이 사용할 경우 아래 같은 에러가 뜰 수 있습니다. 혹은 그냥 이미지를 읽고 나서 적용할때 다음과 같은 에러가 발생할 수 있습니다.

TypeError: Unexpected type <class 'numpy.ndarray'>

 

 위와 같은 에러가 뜨는 이유는 transforms.Compose에는 PIL의 Image를 이용해야하고, Image의 차원이 3차원(3개의 채널)이 되어야 합니다. 그렇기 때문에 다음과 같이 적용을 하면 돌아가게 됩니다.

 

image = Image.open(file_path)
image = np.array(image)[:,:,:3]
image = Image.fromarray(image)

 

 다음과 같은 custom function을 기반으로 제작하면 됩니다.

from PIL import Image
def Custom_fun(img):
	img = np.array(img)[:,:,:3]
    if img.max()<1:
    	img*=255
	## 관련 후처리 기능 추가 ##
    img = img.astype(np.uint8)
    return Image.fromarray(img)
    
 trn = transforms.Compose([
 Custom_fun,
 transforms.ToTensor(),
 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 ])
 
# PIL.Image 방법
Image = Image.open(file_path)
Image = np.array(Image)[:,:,:3]
Image = Image.fromarray(Image)
Image_trn=trn(Image)
 
# matplotlib.pyplot as plt 방법
Image = plt.imread(file_path)
image = np.array(image)[:,:,:3]
image *=255
image=image.astype(np.uint8)
image = Image.fromarray(image)
Image_trn=trn(Image)

 

 

 

728x90
반응형

댓글