Logical Scribbles
[Project] VPT를 이용하여 Segmentation 모델 만들기 본문
코드는 아래 깃허브에서 확인할 수 있습니다.
Preliminary
이번 프로젝트에서는 Visual Prompt Tuning 논문을 응용하여 segmentation 모델을 만들어 보았다. 기존 VPT 논문에서는 classification task를 주로 수행하고 있지만, 나는 classification head를 삭제하고 segmentation을 위한 head를 새로 추가하였다. (기존 VPT 논문에 따르면 segmentation도 실험을 진행한 것 같지만, 구체적으로 서술되어 있지 않다.)
TransUNet에서 소개된 Upsampling CNN 구조를 이용하여 task specific head를 만들어보았으며, 데이터셋으로는 PASCAL VOC2012를 사용하였다.
그 결과 기존 모델( ViT_base_patch16_224)의 파라미터 중 약 0.78%를 사용하여 segmentation이 가능했다.
• TOP1 Accuracy : mIoU 58.7%
• ViT_base_patch16_224, deep, batch = 32, prompt_ token = 5
• Max lr = 0.001, 에포크 = 32, weight_ decay = 0.0001
• For training : Cross Entropy, Adam W, OneCycleLR
아래는 다양한 이미지에 대한 segmetation 이미지 예시다. 비록 결과는 SOTA에 크게 못 미쳤지만 VPT의 구조를 좀 더 면밀하게 이해하여 공부할 수 있었고, 헤드를 변경하고 데이터셋을 전처리하는 과정에서 많은 것을 배웠다!
'Projects' 카테고리의 다른 글
[Project] 자연어 처리와 Graph 이론을 이용한 Twilight 인물 Network 분석 (0) | 2023.12.30 |
---|---|
[Project] zero-Shot Photo Frame Recommendation Using Clustering Algorithms (0) | 2023.12.30 |