배움 기록/Deep Learning

[PyTorch, MONAI] UNETR 모델 생성 및 Forward

Spezi 2023. 1. 12. 03:24
반응형

UNETR이 무엇인지는 

2023.01.11 - [Programming] - [MONAI] UNETR 이란? (feat. Vision Transformers)

 

[MONAI] UNETR 이란? (feat. Vision Transformers)

이 설명은 당장 UNETR를 써야하는데 빨리 뭔지 대충 알고 싶은 경우만 살짝 도움이 될뿐 자세한 내용은 아래의 논문에서 확인가능 https://arxiv.org/abs/2103.10504 Background 한 줄 정리 FCNN(Fully Convolutional N

jedemanfangwohnteinzauberinne.tistory.com

UNETR 모델에 쓰기위한 데이터의 전처리는

2023.01.11 - [Programming] - [MONAI, PyTorch] UNETR 에 넣을 데이터 전처리 하기

 

[MONAI, PyTorch] UNETR 에 넣을 데이터 전처리 하기

이번에 인턴십 중에 MONAI를 이용한 프로젝트를 하게 되었다. 내용을 공부할겸 여기에 정리를 해본다. MONAI MONAI(Medical Open Network for AI) 는 엔비디아가 만든 헬스케어용 오픈소스 프레임워크(파이토

jedemanfangwohnteinzauberinne.tistory.com

 


UNETR 모델 생성

from monai.networks.nets import UNETR
import torch
import torchio as tio

#create model
model = UNETR(
            in_channels=1, #dimension of input channels.
            out_channels=8, #dimension of output channels.
            img_size=(96, 96, 96), #dimension of input image.
            feature_size=16, #dimension of network feature size.
            hidden_size=768, #dimension of hidden layer.
            mlp_dim=3072, #dimension of feedforward layer.
            num_heads=12, #number of attention heads.
            pos_embed="perceptron", #position embedding layer type.
            norm_name="instance", #feature normalization type and arguments.
            res_block=True, #bool argument to determine if residual block is used.
            conv_block=True, #bool argument to determine if convolutional block is used.
            dropout_rate=0.0, #faction of the input units to drop.
        )

UNETR 모델을 위와 같이 생성 해준다. (참고: https://docs.monai.io/en/stable/_modules/monai/networks/nets/unetr.html)

Saving and loading models

state_dict = torch.load("PATH")["state_dict"] #save
state_dict = {k.partition("model.")[2]: state_dict[k] for k in state_dict.keys()}
model.load_state_dict(state_dict) #load

learning parameters(weights, bias)를 불러와서 state_dict에 저장한다. (이미 train 한 결과의 weight 값을 가져와서 현 모델에 저장하는것)

load_state_dict() 함수는 dictionary 를 object 로 취하기 때문에 state_dict을 먼저 만들어주는 것이다. 성공적으로 load 하면 <All keys matched successfully> 를 확인할 수 있다.

 

Forward

def forward(self, x_in):
        x, hidden_states_out = self.vit(x_in) # visual transformer
        enc1 = self.encoder1(x_in)
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2))
        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3))
        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4))
        dec4 = self.proj_feat(x)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        return self.out(out)

 

 

 

 


배움을 기록하기 위한 공간입니다. 

수정이 필요한 내용이나 공유하고 싶은 것이 있다면 언제든 댓글로 남겨주시면 환영입니다 :D

반응형