본문 바로가기
딥러닝\머신러닝

[pytorch, 딥러닝] 모델 저장하는 방법

by 인포메틱스 2022. 7. 28.
반응형

 

 최근 논문을 보면서 github에 공개된 코드를 해석하는 중에 공개된 모델을 사용해보고자 하였습니다.

 

 처음에는 제가 배운 모델 로드 방법이 먹히질 않아서(kaggle에서 얻게된 모델 로드방법...) 정말 오랫동안 알아 보는 도중에 알아내게 되었습니다.

 

 그걸 포스팅 해보고자 합니다. 머리카락 한움쿰정도 뽑아진것 같네요.

 

 제가 kaggle에서 배운 모델 저장 방법은 checkpoint를 저장하는 방법이고 모델을 제작할때용된 정보들을 이용하여 dictionary를 제작 후에 저장하는 방법입니다.

 

checkpoint = {
            'epoch': epoch + 1,
            'loss': epoch_val_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        
torch.save(checkpoint)

 

위와 같은 경우에는 model에다 각 지정을 해줘야합니다.

 

model = 제작할때와 같은 모델(가중치 없는 빈깡통)
checkpoint = torch.load(checkpoint_fpath)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
valid_loss_min = checkpoint['loss']

 

이걸 kaggle에서 배웠는데, 좀 더 쉽고 간단하게 저장하는 방법이 있습니다.

 

model # <-  학습후의 모델이라 가정
# 1.
torch.save(model.state_dict(),'저장할 위치1') # 가중치만 저장
# 2.
torch.save(model,'저장할 위치2') # 모델 전체 저장


# 사용시
model = 빈깡통의 모델로 지정
# 1.
model.load_stat_dict(torch.load('저장한위치1')) 
# 2.
model=torch.load('저장한위치2')

 

위와 같이 진행하면 됩니다.

 

 만약에 정체모를 모델을 받았을때, 그냥 torch.load해서 읽어오시고 보면 차이가 납니다. dictionary로 되어있느냐, 혹은 바로 OrderedDict로 시작하냐 입니다. 혹은 모델 구조가 나오는 경우는 모델을 저장한 경우겠죠.

 

 

 

 

 

728x90
반응형

댓글