인공지능을 좋아하는 곧미남

pytorch의 def forward(self, x) 본문

code_study/pytorch

pytorch의 def forward(self, x)

곧미남 2022. 1. 18. 11:12

오늘은 간단히 pytorch에서 제공하는 대표적인 모델 구현에 사용하는 nn.Module이나 기타 다른 모델에서 input data를 손 쉽게 전달할 수 있게 만든 def forward()에 관해 알아보겠습니다.

 

이건 앞전에 제가 python 문법인 "__call__" 설명 게시물을 한번 보고 오시면 도움이 되실겁니다.

 

오늘 내용은 아래 목차와 같습니다.

 

< INDEX >

 

1. "def forward()"의 역할

 

2. 어떻게 pytorch에서 "def forward()"를 구현했는가?


1. "def forward()"의 역할

우선 def forward(self, x)의 메커니즘을 알기위해서는 python의 "__call__"과 클래스 상속에 관한 내용을 숙지해야합니다.

 

제 블로그 게시물에 "__call__"과 클래스 상속 내용이 있으니 한번 살펴보고 오시면 좋습니다.

 

"__call__"기능만 활용해서 충분히 forward() 함수를 직접 커스텀 코드로 구현한 내용을 확인할 수 있으니 꼭 참고하시길 바랍니다.

 

- "__call__": https://soonhandsomeguy.tistory.com/17?category=976687

 

- 클래스상속: https://soonhandsomeguy.tistory.com/24?category=976687

 

그럼, 여기선 그 내용을 이해했다는 가정하에 설명하도록 하겠습니다.

 

자, forward(self, x)를 사용해서 정말 간단하게 input data를 전달 받아 conv2d layer에 입력시킵니다. 이런 역할을 하는 forward 함수는 어떻게 구현이 되었을까요? 아래에서 살펴보시죠.

2. 어떻게 pytorch에서 "def forward()"를 구현했는가?

보통은 Layer 구성을 할때 커스텀하고 싶은 모델 클래스를 선언하고 nn.Module을 부모 클래스로 상속 받습니다.

 

이때, 상속 받은 nn.Module에서 정의되어 있는 forward 인스턴스를 사용하는 것이 핵심포인트입니다.

 

아래 코드를 보시면, Model 클래스에 nn.Module 클래스를 부모 클래스로 상속 받았습니다. 그럼 부모 클래스의 객체를 사용할 수 있는데요. 

여기서 nn.Module을 타고 들어가보면, forward: Callable[..., Any] = _forward_unimplemented 라고 forward변수가 선언되어 있습니다.

 

이 코드를 분석해보자면, 

 

- forward: Callable[..., Any] -> forward 변수가 __call__ 객체인것을 힌트를 주는 코드  

- forward = _forward_unimplemented 로 "_forward_unimplemented " 메서드를 실제값으로 가지는 forward 변수를 선언한 것입니다.

- "_forward_unimplemented " 메서드는 무슨 역할을 하겠습니까?

 

- 아래 코드를 보면 Should be overridden by all subclasses. 라고 주석이 달려있습니다. 

 

- 즉, forward는 자식클래스에서 override되어 사용된다는 말인데요.

 

- 그래서 맨 처음 사진의 코드와 같이 커스텀한 클래스에서 nn.Module을 상속받아 forward를 override해서 사용할 수 있는 것입니다.

3. 내용 고찰

파이썬 문법을 아주 잘 활용하고 전체적인 디자인이 멋진 pytorch의 코드를 분석하는 것은 여러모로 코딩을 잘하기위한 인사이트를 얻을 수 있습니다. 독자분들도 시간이 많을 때나 코딩 실력을 높이고 싶을 때 이렇게 API 코드를 뜯어보고 분석하는 것을 추천드립니다. 저도 지속적으로 열심히 코드 분석하여 업로드 하겠습니다.

 

읽어 주셔서 갑사합니다.

Comments