FastChat의 train/test split 버그

FastChat으로 학습 후 모델에서 일부 학습 데이터의 결과가 유독 두드러져 보인다는 보고가 있어 이를 실험을 통해 직접 확인해봤다.

2023년 8월 12일 초안 작성

개요

FastChat으로 학습 후 모델에서 일부 학습 데이터의 결과가 유독 두드러져 보인다는 보고가 있어 이를 실험을 통해 직접 확인해봤다.

실험 환경

계산하기 쉽도록 데이터는 1,000건으로 구성했고, 이 중 10%를 eval로 처리했다.

  • train 900건
  • eval 100건

학습 옵션은 다음과 같다.

  • num_train_epochs=5
  • gradient_accumulation_steps=16
  • per_device_train_batch_size=2

iterations가 35가 나오며, 따라서 전체 학습 횟수는 다음과 같다.

35 iterations * 16 steps * 2 batch * 4 gpus = 4,480

참고로 900건에 5 epochs는 4,500이다.

실험

20개 문서를 샘플링하여 이 문서가 몇 번이나 학습에 참여하는지 횟수를 확인해봤다. 각 문서가 5번씩 학습에 참여해야 정상일 것이다.

문서 GPU 0 GPU 1 GPU 2 GPU 3 합계
DOC #1 0 1 0 2 3
DOC #2 1 1 3 3 8
DOC #3 0 1 2 2 5
DOC #4 1 3 1 2 7
DOC #5 0 2 0 2 4
DOC #6 3 1 1 1 6
DOC #7 0 1 1 2 4
DOC #8 1 1 2 0 4
DOC #9 3 1 2 2 8
DOC #10 1 0 0 1 2
DOC #11 2 1 2 0 5
DOC #12 0 2 0 0 2
DOC #13 1 2 1 2 6
DOC #14 0 1 1 1 3
DOC #15 0 1 2 2 5
DOC #16 2 2 0 2 6
DOC #17 2 0 1 1 4
DOC #18 1 1 1 0 3
DOC #19 1 0 3 0 4
DOC #20 1 1 2 0 4

이처럼 2 ~ 8회까지 불규칙한 분포를 보이며, 각각의 GPU가 1,125회 학습에 참여해 전체를 계산해보면 4,500회 정상적으로 학습은 마쳤으나 실제로 학습에 참여한 문서는 GPU 당 675 ~ 694건으로 각각의 GPU는 전체 문서를 보지 못했다. 각 GPU는 900건의 문서 중 GPU의 갯수만큼으로 나눈(\(\frac{1}{4}\)) 225건 까지는 전체를 정확히 1:1 샘플링하지만 그 이상은 랜덤하게 샘플링하고 1,125회를 추출해도 700건을 채 보지 못한다.

그렇다면 GPU를 1개만 동작했을 때는 정상적으로 5번만 살펴볼까.

문서 GPU 0
DOC #1 5
DOC #2 5
DOC #3 5
DOC #4 5
DOC #5 0
DOC #6 5
DOC #7 5
DOC #8 5
DOC #9 0
DOC #10 0
DOC #11 5
DOC #12 5
DOC #13 5
DOC #14 5
DOC #15 5
DOC #16 5
DOC #17 5
DOC #18 5
DOC #19 5
DOC #20 5

eval로 분리된 10%(여기서는 약간 수치를 넘어선 3개)를 제외하고 정확히 5회씩 학습에 참여하는 것을 확인할 수 있다. 이런 형태가 GPU 2개 일때, 4개 일때도 동일하게 나타나야 한다. 그러나 4개일 때는 앞서 본 것처럼 2 ~ 8회로 불규칙한 모습을 보인다. 그렇다면 GPU를 2개만 사용하면 어떻게 될까.

문서 GPU 0 GPU 1 합계
DOC #1 2 0 2
DOC #2 3 2 5
DOC #3 1 2 3
DOC #4 3 2 5
DOC #5 3 2 5
DOC #6 4 2 6
DOC #7 3 3 6
DOC #8 3 0 3
DOC #9 1 2 3
DOC #10 4 3 7
DOC #11 3 4 7
DOC #12 0 3 3
DOC #13 0 1 1
DOC #14 2 2 4
DOC #15 3 3 6
DOC #16 1 4 5
DOC #17 4 1 5
DOC #18 2 2 4
DOC #19 2 2 4
DOC #20 0 0 0

마찬가지로 0 ~ 7회로 불규칙하며 이렇게 해서는 제대로 학습이 되지 않는다. 또한 데이터도 정확히 GPU 갯수만큼 나눈 450건 까지는 1:1 샘플링하지만 그 이상이 되면 랜덤 샘플링을 하며 마찬가지로 868 ~ 873건 수준에서 더 이상 보지는 못한다. 각 GPU는 4개 일때에 비해 2배 더 많은 2,250회를 추출하지만 여전히 전체를 다 보지는 못하는 것이다.

코드 분석

예전부터 의심했던 부분은 make_supervised_data_module()에서 train/test를 split하는 부분이다. 코드는 다음과 같다.

perm = np.random.permutation(len(raw_data))
split = int(len(perm) * 0.9)
train_indices = perm[:split]
eval_indices = perm[split:]

이처럼 랜덤 순열을 생성한 다음 train/eval로 split하는데, 문제는 seed가 없다 보니 각 프로세스가 저마다 다르게 split한다는 점이다. eval 조차도 프로세스 별로 다르게 갖게 되니 어떤 프로세스에서는 학습 문서인 것이 어떤 프로세스에서는 검증 문서가 된다. 각 문서는 5 또는 0으로 균등한 횟수를 보여야 하지만 이처럼 각 프로세스가 다른 순서로 데이터를 갖고 있으면 문서마다 학습에 참여하는 횟수가 상이하므로 문제가 있다.

그나마 2→4 GPU는 loss 변화에 차이가 있지만 4→8은 거의 차이가 없고, 8→16은 아예 비슷하게 떨어진다. 데이터를 뒤섞어 놓고 GPU가 많을수록 더 적은 수를 샘플링하니 오히려 더 불규칙하게 추출되는 것이다. 이 때문에 GPU 8장을 초과하는 멀티 노드에서는 제대로 학습이 되지 않는다.

이 문제는 다음과 같이 seed를 부여함으로써 해결할 수 있다.

np.random.seed(42)
perm = np.random.permutation(len(raw_data))

이제 랜덤하게 추출해도 seed가 고정이므로 각 프로세스마다 동일한 순서로 데이터를 갖게 된다. GPU 4개로 다시 한번 실험한 결과는 다음과 같다.

문서 GPU 0 GPU 1 GPU 2 GPU 3 합계
DOC #1 0 1 4 0 5
DOC #2 2 1 2 0 5
DOC #3 1 1 2 1 5
DOC #4 0 1 2 2 5
DOC #5 0 2 2 1 5
DOC #6 2 0 2 2 6
DOC #7 3 0 0 2 5
DOC #8 0 0 0 0 0
DOC #9 3 1 0 1 5
DOC #10 1 1 2 1 5
DOC #11 0 0 0 0 0
DOC #12 0 0 0 0 0
DOC #13 2 0 3 0 5
DOC #14 2 1 1 1 5
DOC #15 1 0 0 4 5
DOC #16 2 0 2 1 5
DOC #17 1 1 1 2 5
DOC #18 2 2 1 0 5
DOC #19 0 3 2 0 5
DOC #20 2 0 2 1 5

거의 모든 문서가(1개 제외) 5회씩 학습에 참여한다. eval로 split되어 0회 참여한 문서도 각 프로세스마다 동일하다. 이제 멀티 GPU에서 문제 없이 정상적으로 학습이 된다.

결론

seed 고정으로 문제를 해결한 후 FastChat 공식 레포를 뒤적여 보니 이미 5월 29일자로 패치가 적용된 것을 확인할 수 있었다. 이전 버전을 쓰다보니 몰랐던 사항이고, 지금은 이 코드 또한 다른 파일로 리팩토링 된 상태다.

이제 멀티 GPU 뿐만 아니라 멀티 노드에서도 문제 없이 학습이 잘 될 것이다.

wandb에서 X축을 Relative Time (Wall), Y축을 eval/loss로 설정해서 보면 기존에는 멀티 노드에서도 loss가 동일하게 떨어졌지만 이제 6 노드에서는 1 노드에 비해 훨씬 더 큰 폭으로 loss가 떨어지는 것을 확인할 수 있다.

is a collection of Papers I have written.
© 2000 - Sang Park Except where otherwise noted, content on this site is licensed under a CC BY 4.0.
This site design was brought from Distill.