FastChat의 train/test split 버그
2023년 8월 12일 초안 작성
개요
FastChat으로 학습 후 모델에서 일부 학습 데이터의 결과가 유독 두드러져 보인다는 보고가 있어 이를 실험을 통해 직접 확인해봤다.
실험 환경
계산하기 쉽도록 데이터는 1,000건으로 구성했고, 이 중 10%를 eval로 처리했다.
- train 900건
- eval 100건
학습 옵션은 다음과 같다.
num_train_epochs=5
gradient_accumulation_steps=16
per_device_train_batch_size=2
iterations가 35가 나오며, 따라서 전체 학습 횟수는 다음과 같다.
35 iterations * 16 steps * 2 batch * 4 gpus = 4,480
참고로 900건에 5 epochs는 4,500이다.
실험
20개 문서를 샘플링하여 이 문서가 몇 번이나 학습에 참여하는지 횟수를 확인해봤다. 각 문서가 5번씩 학습에 참여해야 정상일 것이다.
문서 | GPU 0 | GPU 1 | GPU 2 | GPU 3 | 합계 |
---|---|---|---|---|---|
DOC #1 | 0 | 1 | 0 | 2 | 3 |
DOC #2 | 1 | 1 | 3 | 3 | 8 |
DOC #3 | 0 | 1 | 2 | 2 | 5 |
DOC #4 | 1 | 3 | 1 | 2 | 7 |
DOC #5 | 0 | 2 | 0 | 2 | 4 |
DOC #6 | 3 | 1 | 1 | 1 | 6 |
DOC #7 | 0 | 1 | 1 | 2 | 4 |
DOC #8 | 1 | 1 | 2 | 0 | 4 |
DOC #9 | 3 | 1 | 2 | 2 | 8 |
DOC #10 | 1 | 0 | 0 | 1 | 2 |
DOC #11 | 2 | 1 | 2 | 0 | 5 |
DOC #12 | 0 | 2 | 0 | 0 | 2 |
DOC #13 | 1 | 2 | 1 | 2 | 6 |
DOC #14 | 0 | 1 | 1 | 1 | 3 |
DOC #15 | 0 | 1 | 2 | 2 | 5 |
DOC #16 | 2 | 2 | 0 | 2 | 6 |
DOC #17 | 2 | 0 | 1 | 1 | 4 |
DOC #18 | 1 | 1 | 1 | 0 | 3 |
DOC #19 | 1 | 0 | 3 | 0 | 4 |
DOC #20 | 1 | 1 | 2 | 0 | 4 |
이처럼 2 ~ 8회까지 불규칙한 분포를 보이며, 각각의 GPU가 1,125회 학습에 참여해 전체를 계산해보면 4,500회 정상적으로 학습은 마쳤으나 실제로 학습에 참여한 문서는 GPU 당 675 ~ 694건으로 각각의 GPU는 전체 문서를 보지 못했다. 각 GPU는 900건의 문서 중 GPU의 갯수만큼으로 나눈(\(\frac{1}{4}\)) 225건 까지는 전체를 정확히 1:1 샘플링하지만 그 이상은 랜덤하게 샘플링하고 1,125회를 추출해도 700건을 채 보지 못한다.
그렇다면 GPU를 1개만 동작했을 때는 정상적으로 5번만 살펴볼까.
문서 | GPU 0 |
---|---|
DOC #1 | 5 |
DOC #2 | 5 |
DOC #3 | 5 |
DOC #4 | 5 |
DOC #5 | 0 |
DOC #6 | 5 |
DOC #7 | 5 |
DOC #8 | 5 |
DOC #9 | 0 |
DOC #10 | 0 |
DOC #11 | 5 |
DOC #12 | 5 |
DOC #13 | 5 |
DOC #14 | 5 |
DOC #15 | 5 |
DOC #16 | 5 |
DOC #17 | 5 |
DOC #18 | 5 |
DOC #19 | 5 |
DOC #20 | 5 |
eval로 분리된 10%(여기서는 약간 수치를 넘어선 3개)를 제외하고 정확히 5회씩 학습에 참여하는 것을 확인할 수 있다. 이런 형태가 GPU 2개 일때, 4개 일때도 동일하게 나타나야 한다. 그러나 4개일 때는 앞서 본 것처럼 2 ~ 8회로 불규칙한 모습을 보인다. 그렇다면 GPU를 2개만 사용하면 어떻게 될까.
문서 | GPU 0 | GPU 1 | 합계 |
---|---|---|---|
DOC #1 | 2 | 0 | 2 |
DOC #2 | 3 | 2 | 5 |
DOC #3 | 1 | 2 | 3 |
DOC #4 | 3 | 2 | 5 |
DOC #5 | 3 | 2 | 5 |
DOC #6 | 4 | 2 | 6 |
DOC #7 | 3 | 3 | 6 |
DOC #8 | 3 | 0 | 3 |
DOC #9 | 1 | 2 | 3 |
DOC #10 | 4 | 3 | 7 |
DOC #11 | 3 | 4 | 7 |
DOC #12 | 0 | 3 | 3 |
DOC #13 | 0 | 1 | 1 |
DOC #14 | 2 | 2 | 4 |
DOC #15 | 3 | 3 | 6 |
DOC #16 | 1 | 4 | 5 |
DOC #17 | 4 | 1 | 5 |
DOC #18 | 2 | 2 | 4 |
DOC #19 | 2 | 2 | 4 |
DOC #20 | 0 | 0 | 0 |
마찬가지로 0 ~ 7회로 불규칙하며 이렇게 해서는 제대로 학습이 되지 않는다. 또한 데이터도 정확히 GPU 갯수만큼 나눈 450건 까지는 1:1 샘플링하지만 그 이상이 되면 랜덤 샘플링을 하며 마찬가지로 868 ~ 873건 수준에서 더 이상 보지는 못한다. 각 GPU는 4개 일때에 비해 2배 더 많은 2,250회를 추출하지만 여전히 전체를 다 보지는 못하는 것이다.
코드 분석
예전부터 의심했던 부분은 make_supervised_data_module()
에서 train/test를 split하는 부분이다. 코드는 다음과 같다.
perm = np.random.permutation(len(raw_data))
split = int(len(perm) * 0.9)
train_indices = perm[:split]
eval_indices = perm[split:]
이처럼 랜덤 순열을 생성한 다음 train/eval로 split하는데, 문제는 seed가 없다 보니 각 프로세스가 저마다 다르게 split한다는 점이다. eval 조차도 프로세스 별로 다르게 갖게 되니 어떤 프로세스에서는 학습 문서인 것이 어떤 프로세스에서는 검증 문서가 된다. 각 문서는 5 또는 0으로 균등한 횟수를 보여야 하지만 이처럼 각 프로세스가 다른 순서로 데이터를 갖고 있으면 문서마다 학습에 참여하는 횟수가 상이하므로 문제가 있다.
그나마 2→4 GPU는 loss 변화에 차이가 있지만 4→8은 거의 차이가 없고, 8→16은 아예 비슷하게 떨어진다. 데이터를 뒤섞어 놓고 GPU가 많을수록 더 적은 수를 샘플링하니 오히려 더 불규칙하게 추출되는 것이다. 이 때문에 GPU 8장을 초과하는 멀티 노드에서는 제대로 학습이 되지 않는다.
이 문제는 다음과 같이 seed를 부여함으로써 해결할 수 있다.
np.random.seed(42)
perm = np.random.permutation(len(raw_data))
이제 랜덤하게 추출해도 seed가 고정이므로 각 프로세스마다 동일한 순서로 데이터를 갖게 된다. GPU 4개로 다시 한번 실험한 결과는 다음과 같다.
문서 | GPU 0 | GPU 1 | GPU 2 | GPU 3 | 합계 |
---|---|---|---|---|---|
DOC #1 | 0 | 1 | 4 | 0 | 5 |
DOC #2 | 2 | 1 | 2 | 0 | 5 |
DOC #3 | 1 | 1 | 2 | 1 | 5 |
DOC #4 | 0 | 1 | 2 | 2 | 5 |
DOC #5 | 0 | 2 | 2 | 1 | 5 |
DOC #6 | 2 | 0 | 2 | 2 | 6 |
DOC #7 | 3 | 0 | 0 | 2 | 5 |
DOC #8 | 0 | 0 | 0 | 0 | 0 |
DOC #9 | 3 | 1 | 0 | 1 | 5 |
DOC #10 | 1 | 1 | 2 | 1 | 5 |
DOC #11 | 0 | 0 | 0 | 0 | 0 |
DOC #12 | 0 | 0 | 0 | 0 | 0 |
DOC #13 | 2 | 0 | 3 | 0 | 5 |
DOC #14 | 2 | 1 | 1 | 1 | 5 |
DOC #15 | 1 | 0 | 0 | 4 | 5 |
DOC #16 | 2 | 0 | 2 | 1 | 5 |
DOC #17 | 1 | 1 | 1 | 2 | 5 |
DOC #18 | 2 | 2 | 1 | 0 | 5 |
DOC #19 | 0 | 3 | 2 | 0 | 5 |
DOC #20 | 2 | 0 | 2 | 1 | 5 |
거의 모든 문서가(1개 제외) 5회씩 학습에 참여한다. eval로 split되어 0회 참여한 문서도 각 프로세스마다 동일하다. 이제 멀티 GPU에서 문제 없이 정상적으로 학습이 된다.
결론
seed 고정으로 문제를 해결한 후 FastChat 공식 레포를 뒤적여 보니 이미 5월 29일자로 패치가 적용된 것을 확인할 수 있었다. 이전 버전을 쓰다보니 몰랐던 사항이고, 지금은 이 코드 또한 다른 파일로 리팩토링 된 상태다.
이제 멀티 GPU 뿐만 아니라 멀티 노드에서도 문제 없이 학습이 잘 될 것이다.
wandb에서 X축을 Relative Time (Wall), Y축을 eval/loss
로 설정해서 보면 기존에는 멀티 노드에서도 loss가 동일하게 떨어졌지만 이제 6 노드에서는 1 노드에 비해 훨씬 더 큰 폭으로 loss가 떨어지는 것을 확인할 수 있다.