LLaMA의 Stable Diffusion Moment, 찾아오다
2023년 4월 10일 초안 작성
개요
시작은 LLaMA부터였다.
학술적인 연구 목적으로만 액세스를 허용하겠다는 페이스북의 발표와 달리, LLaMA는 발표 첫날부터 토렌트를 통해 만천하에 공개됐다. 그게 페이스북이 의도한 바든 아니든, 덕분에 언어 생성 모델에도 이미지 생성 모델이 작년에 그랬던 것처럼 Stable Diffusion Moment가 찾아왔다. 공교롭게도 기반이 되는 ChatGPT와 DALL-E 모두 오픈AI의 작품이고, ‘오픈’하겠다는 회사명과 달리 둘 다 공개하지 않다 보니 오히려 전 세계의 오픈소스 진영이 더욱 들끓고 있다.
LLaMA
LLaMA 자체는 Transformer 모델에서 크게 변한게 없다. 그간 딥러닝의 발전속도를 보건데 지금쯤이면 충분히 새로운 모델이 나올법한데 2017년에 공개된 모델이 여전히 최고라는 점은 무척 인상적이다. 물론 Transformer의 decoder라는 점까지는 동일해도 LLaMA에 이르러 약간의 차이는 있다. 논문에 따르면 차이점은 다음과 같다.
-
Pre-normalization [GPT3]
학습 안정성을 높이기 위해 Transformer의 sub-layer output 대신 input(Megatron-LM 적용)에 RMSNorm(T5 적용)(Zhang and Sennrich, 2019)으로 normalize했다. -
SwiGLU activation function [PaLM]
Activation을 ReLU에서 SwiGLU(Shazeer, 2020)로 교체했다. PaLM에서 먼저 도입했으나 성능을 높이기 위해 4d 대신 2/3 4d dimension을 사용했다. 참고로 SwiGLU는 Swish(Google, 2017) + GLU(Microsoft, 2016)의 조합이다. -
Rotary Embeddings [GPTNeo]
이제 Positional Encoding을 적용한 논문은 더 이상 찾을 수가 없다. 개인적으로 Positional Encoding만으로 가능할까 의심했는데, 훨씬 결과가 좋았던 Positional Embeddings를 거쳐 지금은 Rotary Positional Embeddings(RoPE)가 대세가 됐다. LLaMA도 RoPE(Su et al., 2021)를 사용했으며, 2048 토큰보다 긴 입력도 처리할 수 있다.
이외에도,
- Optimizer는 AdamW(Loshchilov and Hutter, 2017)를 적용했다.
- 학습 속도를 높이기 위한 여러 최적화:
- 메모리 사용량을 줄이기 위해 causal multi-head attention을 구현한 xformers 라이브러리를 사용했다.
- 체크포인팅 backward pass시 재계산되는 activations를 줄였다. 계산 비용이 많이 드는 activations를 저장해두었으며, 이를 위해 PyTorch의 autograd를 사용하지 않고 Transformer layers에 대한 backward 함수를 수동으로 구현했다.
학습 데이터
1.4T Tokens로 학습했다. GPT-3는 500B Tokens였으니 약 3배 더 크다. OPT-175B는 180B Tokens라 학습이 부족했다는 것이 중론이라고.1 CommonCrawl 데이터는 위키백과의 인용 출처로 쓰일 수 있는가를 판별하는 분류기를 학습시켜 이를 통과한 데이터만 학습에 투입했다.1
참고로 BloombergGPT는 700B Tokens로 학습했고, 이 중 363B는 블룸버그의 자체 재무 데이터에서 추출했다. 이렇게 학습한 5B 모델이 Finance Tasks에서는 175B 모델의 성능을 능가한다.
학습 시간
65B 모델 학습 시 80GB A100 2,048장으로 380 Tokens/1 GPU/1s를 처리했으며, 1.4T 전체 토큰 학습에 21일이 소요됐다. 전체 처리 속도만 보면 778K Tokens/s이며, TPU v3-256 Pod을 이용한 6B GPT-J(151K Tokens/s), 2.7B GPT-Neo(148K Tokens/s) 보다 훨씬 더 빠르다. 일반적으로 Megatron-LM으로는 24일 정도 소요되는데, 이보다 빨랐다고.1
현재 LLaMA의 라이센스는 Non-commercial bespoke license로 상업적으로 활용할 수 없다. 따라서 상업적 용도를 위해 좀 더 작은 모델을 원할하게 학습하려면 적어도 80GB A100 GPU 500장 이상은 확보해야 할 것으로 보인다.
학습 코드
LLaMA는 학습 코드를 공개하지 않았고, 공개할 계획이 없으나, Lightning AI에서 학습 코드를 공개2했고, Hugging Face에서도 LoRA 학습을 자체 peft 라이브러리에 추가했고, 이를 이용해 학습하는 방식3을 소개하고 있다.
SFT
LLaMA에서 SFT까지 진행한 모델은 다음과 같다.
-
Alpaca: A Strong, Replicable Instruction-Following Model
스탠포드 CRFM(Percy Liang이 디렉터로 있는 Foundation Model 연구소)에서 공개. LLaMA 7B 모델에 사람이 생성한 175개 instruction-output 쌍을 seed로 한 다음, text-davinci-003에 52K examples 생성 후 SFT를 진행했다. 데이터의 다양성을 확보하기 위해 많은 노력을 기울였다. A100 GPU 4장에 PyTorch FSDP로 진행했고, 당시에는 아직 Merge 되지 않은 LLaMA PR이 적용된 Hugging Face 버전을 사용해 학습했다. 데모에는 유해 컨텐츠를 필터링하기 위해 오픈AI의 Moderations API를 적용했고, 모든 모델 출력에 워터마크를 표시했다. -
Vicuna: An Open Chatbot Impressing GPT-4
ShareGPT에서 수집한 대화 70K개를 활용해 SFT를 진행했다. 첫 버전은 LLaMA 13B를 SFT 진행한 Vicuna-13B부터 공개. 지금은 7B도 공개됐다. 별도의 FastChat이라는 파이썬 패키지를 만들고, 여기에 Gradio를 이용한 데모 서비스까지 덧붙여 편리하게 활용할 수 있게 했다. LLaMA를 Hugging Face 포맷으로 변환한 다음, 다시 FastChat에서 제공하는 apply_delta를 적용하면 Vicuna weights가 된다. LLaMA 라이센스 준수를 위해 이 같은 방식을 택함. 실험 결과 GPU 메모리 30GB 정도면 한 장에도 모델이 충분히 올라갔다. 학습은 A100 GPU 8장에 PyTorch FSDP로 하루 동안 했다. GPT-4로 Evaluation했을 때 ChatGPT를 14% 능가했다. 하지만 아쉽게도 다른 모델과 마찬가지로 상당한 Hallucination 문제가 존재한다.
Model Name | LLaMA | Alpaca | Vicuna | Bard/ChatGPT |
---|---|---|---|---|
Dataset | Publicly available datasets(1T token) | Self-instruct from davinci-003 API(52K samples) | User-shared conversations(70K samples) | N/A |
Training code | N/A | Available | Available | N/A |
Evaluation metrics | Academic benchmark | Author evaluation | GPT-4 assessment | Mixed |
Training cost(7B) | 82K GPU-hours | $500(data) + $100(training) | $140(training) | N/A |
Training cost(13B) | 135K GPU-hours | N/A | $300(training) | N/A |
- Koala: A Dialogue Model for Academic Research
UC버클리의 BAIR에서 공개한 모델. Vicuna와 마찬가지로 ShareGPT 데이터 60K를 이용했다. 데이터 품질을 유지하기 위해 중복과 영어가 아닌 대화는 제거했고, 이렇게 해서 30K를 남겼다. 평가는 100명의 사람이 진행했다. - Baize: An Open-Source Chat Model with Parameter-Efficient Tuning on Self-Chat Data
LoRA 방식으로 학습한 모델, 마찬가지로 100K 다이얼로그는 ChatGPT를 이용해 생성했다. - GPT4All
gpt-3.5-turbo를 이용해 800K 샘플을 구축하고 SFT를 진행했다.
RLHF
앞서 SFT만으로도 좋은 결과를 보였으나 InstructGPT 논문과 동일하게 RLHF까지 진행한 모델은 다음과 같다.
-
StackLLaMA
초기 학습은 StackExchange 데이터셋을 사용했으며, upvotes를 이용해 Reward 점수를 자동으로 구축했다. 메모리 사용량을 줄이기 위해 LoRA 학습을 자체 peft 라이브러리에 추가했고, 이를 이용했다. 8비트(파라미터 당 1바이트 메모리만 사용)로 모델을 로드했으며, LLaMA 7B의 경우 7GB 메모리만 차지했다. 이 덕분에 SFT의 경우 Google Colab에서도 가능하다. RLHF는 자체 개발 중인 Transformer Reinforcement Learning 라이브러리인 TRL을 사용했다. -
ColossalChat
Colossal-AI 프로젝트의 일환으로 LLM에 RLHF를 적용한 ColossalChat 프로젝트를 진행 중이다. -
Open Assistant
“Stable Diffusion이 세상을 새로운 방식으로 예술과 이미지를 만드는 데 도움을 준 것과 마찬가지로 놀라운 대화형 AI를 제공하여 세상을 개선하고자 합니다.”라고 자신들을 소개하고 있다. 비전쪽에서는 LAION-400M 데이터셋으로 유명한 LAION AI에서 진행하는 프로젝트다. 이미 Pythia 12B를 기반으로 SFT를 진행한 모델은 공개4한 상태다.
기타
- trlX
EleutherAI에서 분사한 CarperAI에서 만든 Transformer Reinforcement Learning eXtended 프로젝트로, TRL을 기반으로 확장한 프로젝트다. 파라미터 20B 까지는 Hugging Face의 Accelerate를 지원하고, 그 이상은 NeMo도 지원한다. NeMo의 PPO 버전은 아직 개발중인 단계. 그러나 WandB에서 trlX를 이용한 Summarize Task RLHF를 공개5하는 등 사실상 ChatGPT RLHF를 위한 프레임워크는 거의 완성된 단계다.
-
llama.cpp
LLaMA를 A100 GPU가 아니라 로컬 맥북에서도 구동할 수 있도록 하는 것을 목표로 하는 프로젝트. C++ 최적화, 4비트 Quantization, CPU 지원을 주요 목표로 하고 있다. 예전에 프랑스의 천재적인 해커 Fabrice Bellard를 연상케 하는(요즘은 이 아저씨도 GPT와 LLM에 올인 중이다) 불가리아의 Georgi Gerganov가 진행하는 개인 프로젝트. 갑자기 등장한 프로젝트는 아니고 오래전부터 GPT를 CPU로 Inference하는 C++ 프로젝트를 진행해왔고 이의 연장선 상이다. 상당히 높은 완성도를 보이며 이 정도면 Production에서도 CPU 서버만으로 충분히 서비스가 가능할 것으로 보인다. -
LLaMA 4bit ChatBot Guide v2
LLaMA를 4비트로 실행하는 가이드를 제공한다. 여러 리소스를 제공하며, 이를 통해 CPU Production으로 운영 비용을 획기적으로 절감할 수 있을 것으로 보인다. -
LLM Worksheet
각종 LLM의 성능을 비교한 시트. LLaMA의 성능이 매우 뛰어남을 확인할 수 있다.
References
-
https://devocean.sk.com/blog/techBoardDetail.do?ID=164601&boardType=techBlog ↩ ↩2 ↩3
-
https://lightning.ai/pages/community/tutorial/accelerating-llama-with-fabric-a-comprehensive-guide-to-training-and-fine-tuning-llama/ ↩
-
https://huggingface.co/docs/trl/main/en/using_llama_models ↩
-
https://huggingface.co/OpenAssistant/oasst-sft-1-pythia-12b ↩
-
https://wandb.ai/carperai/summarize_RLHF/reports/Implementing-RLHF-Learning-to-Summarize-with-trlX--VmlldzozMzAwODM2 ↩