Docker기반의 RESTful api로 pytorch model 배포하기[2-5]
- Flask RESTful을 이용하여, HTTP로 호출할 수 있는 RESTful api 만들기
- 학습한 모델의 weight와 graph를 하나의 파일로 만들기(torchscript)
- 보안을 위해, 2번에서 만든 torchscript파일을 encryption하기
- 보안을 위해, cython을 이용하여 python script를 library형태로 만들기
- Docker image 만들고 배포하기
지난 글에서는 flask restful을 이용해서 restful api를 만들었었다.
이제 web상에서 HTTP method를 이용하여 restful api를 호출할 수 있게 되었다.
다음으로 할 것은, POST method를 구현하기 위해 post 함수 내부에서 image를 매개변수로 받고 model로 prediction하는 일이다.
pytorch를 이용한 model의 save&load는 아래의 orderdict을 활용한 예제일 것이다.
import torch
''' save '''
state = {'net': net.state_dict()}
torch.save(state, 'net.pt')
''' load '''
state = torch.load('net.pt')
net.load_state_dict(state['net'])
이러한 방식은 network의 구조가 미리 선언되어 있어야 하는 조건이 있다. 또한 network의 parameter를 dictionary형태로 저장하므로, network의 이름 또한 저장한 이름과 동일해야 한다.
본인이 생각하기에는, 위의 방식을 사용할 경우 network 구조에 대한 선언이 web server에 포함되어야 하므로 좋지 않은 구조라고 판단했다. network를 업데이트할 때마다 web server에 있는 code도 변경되어야 한다면 중간에 실수할 가능성도 많아지니깐.
그러므로, 여기서는 위의 dictionary 기반의 방식을 사용하지 않고 network 구조와 weight를 하나의 파일로 저장할 수 있는 torchscript방식을 사용한다.
rand_input = torch.rand(1, 3, img_size[0], img_size[1])
traced_net = torch.jit.trace(net, rand_input)
pytorch에서 제공하는 torchscript는 script 방식과 trace 방식이 존재하는데, 만약 선언한 network의 forward에서 조건에 따른 분기가 존재하지 않는다면 trace 방식을 선택하면 된다.
간단하게 설명하면, trace는 말그대로 주어진 input에 따른 실행 흐름을 추적하여 IR로 표현하고 script는 pytorch에서 제공하는 script compiler를 이용하여 decorator로 명시된 부분에 대해 IR로 표현한다.
위의 코드를 실행할 경우, torchscript로 표현된 network가 생성되며 이를 아래와 같이 저장하고 불러오기를 하면 된다.
''' save '''
traced_net.save(traced_path)
''' load '''
net = torch.jit.load(traced_path)