DistributedDataParallel의 동작원리

torch.nn.parallel.DistributedDataParallel를 사용하면 기존 model을 랩핑하는 코드로 간단히 데이터 패러럴 처리를 할 수 있다. 그렇다면 DDP는 내부적으로 어떻게 동작을 하고 어떤 과정을 통해 해당 동작을 수행하는지, 내부 코드를 살펴보면서 확인해보도록 한다.

2022년 6월 14일 초안 작성

내용

torch.nn.parallel.DistributedDataParallel(이하 DDP)를 사용하면 기존에 model을 랩핑하는 코드로 간단히 데이터 패러럴 처리를 진행할 수 있다.

ddp_model = torch.nn.parallel.DistributedDataParallel(model)

그렇다면 DDP는 어떤 작업을 진행할까? 공식 문서1에는 다음과 같이 기술되어 있다.

This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension. The module is replicated on each machine and each device, and each such replica handles a portion of the input. During the backwards pass, gradients from each node are averaged.

입력을 분할하여 각 노드에서 학습을 진행하고, 백워드(backwards)시 각 노드는 그래디언트(gradients)의 평균을 갖는다고 한다.

DDP 역할

2

즉, 위 그림처럼 각 노드는 동일한 모델을 복제하고 있다가 입력 데이터를 나눠 받고, 이를 각자 학습하여 백워드시 계산한 그래디언트를 동기화(sync grads)한다. 동기화 후에는 평균을 구하고 이 값으로 옵티마이저가 모델 파라미터를 업데이트 한다.

모델 파라미터 전체 브로드캐스팅

참고로, 위 그림에는 나와 있지 않지만 각각의 노드는 모델 뿐만 아니라 모델 파라미터도 항상 동일한 값을 갖는다. 그 이유는 DDP를 초기화하는 시점에 rank 0의 모델 파라미터를 전체 rank로 브로드캐스팅하는 작업이 진행되기 때문이다. 이로 인해 처음부터 모두 동일한 값을 갖고 시작하며 중간에 그래디언트도 동기화하면서 계속해서 같은 값을 반영한다. 따라서 모델 파라미터는 항상 동일한 값을 유지한다.

실제로 DDP로 초기화할 때 PyTorch의 코드를 ditributed.py에서 살펴보면, 다음과 같이 rank 0의 모델 파라미터를 브로드캐스팅하는 작업이 있다.

# distributed.py
# Sync params and buffers. Ensures all DDP models start off at the same value.
_sync_module_states(
    module=self.module,
    process_group=self.process_group,
    broadcast_bucket_size=self.broadcast_bucket_size,
    src=0,
    params_and_buffers_to_ignore=self.parameters_to_ignore,
)

3

그래디언트 동기화

그렇다면 그래디언트 동기화는 어떤식으로 진행될까?

만약 DDP를 사용하지 않는다고 가정하면 다음과 같이 학습 과정 중에 그래디언트를 동기화 하고 평균을 구하는 부분을 직접 구현할 수 있다.4 실제로 이 코드는 PyTorch 공식 가이드에서 DDP의 원리를 설명할 때 직접 구현한 코드다.

""" Distributed Synchronous SGD Example """
def run(rank, size):
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset()
    model = Net()
    optimizer = optim.SGD(model.parameters(),
                          lr=0.01, momentum=0.5)

    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        print('Rank ', dist.get_rank(), ', epoch ',
              epoch, ': ', epoch_loss / num_batches)

위 코드는 학습을 진행하는 코드이며, 여기서 그래디언트를 동기화하고 평균을 구하는 average_gradients() 함수는 다음과 같다.

""" Gradient averaging. """
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size

이 함수는 모든 모델 파라미터의 그래디언트를 all-reduce 하고, 이 값을 world_size로 나눈다. 이렇게 간단한 작업을 통해 그래디언트의 평균을 구하며, all-reduce 상태에서 평균은 모든 노드가 동일하므로 각각의 노드는 항상 동일한 모델 파라미터 값을 유지하게 된다. 물론 이렇게 직접 그래디언트 평균을 구현해도 되지만 실제로는 DDP를 사용하는 것이 훨씬 더 효율적이다. 그 이유는 바로 아래 설명한다.

DDP 버킷 최적화

DDP는 그래디언트 동기화 시 버킷 최적화가 추가로 진행된다.

5

먼저 바닐라 방식은 백워드가 진행되면 계산이 모두 끝난 후에 그래디언트를 동기화하는 작업이 시작된다. 통신 방식을 NCCL이라고 가정하면 이때부터 NCCL 통신이 시작되고 NCCL로 all-reduce가 완료되면 이제 그래디언트의 평균을 구하고 모델을 업데이트하게 된다. 그러나 DDP(Horovod도 마찬가지로)는 모델 파라미터를 일정 기준 버킷 단위로 할당하고 백워드가 뒤에서부터 진행되면서 버킷 만큼의 그래디언트 계산이 끝나면 바로 NCCL 통신을 시작한다. 이렇게 하면 나머지 백워드를 진행하면서 NCCL 통신을 동시에 진행할 수 있게 된다. 실제로 NCCL은 내부적으로 별도의 쓰레드를 생성하고 쓰레드간 직접 소켓 통신을 진행하기 때문에 계산이 진행되는 쓰레드와는 병목없이 동시에 수행이 가능하다. 물론 파이썬의 경우에는 GIL로 멀티 쓰레드 작업이 어렵지만 이 작업은 모두 C++로 구현된 모듈이므로 GIL의 제약 없이 동시 진행이 가능하다.

# distributed.py
# Can remove this branching once #73732 is landed.
if static_graph is True or self.find_unused_parameters is False:
    bucket_size_limits = [sys.maxsize]
else:
    bucket_size_limits = [dist._DEFAULT_FIRST_BUCKET_BYTES, self.bucket_bytes_cap]
bucket_indices, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
    parameters,
    bucket_size_limits,
    expect_sparse_gradient,
)

# Note: reverse list of buckets because we want to approximate the
# order in which their gradients are produced, and assume they
# are used in the forward pass in the order they are defined.
self.reducer = dist.Reducer(
    parameters,
    list(reversed(bucket_indices)),
    list(reversed(per_bucket_size_limits)),
    self.process_group,
    expect_sparse_gradient,
    # The bucket size limit is specified in the constructor.
    # Additionally, we allow for a single small bucket for parameters
    # that are defined first, such that their gradients don't spill into
    # a much larger bucket, adding unnecessary latency after gradient
    # computation finishes. Experiments showed 1MB is a reasonable value.
    self.bucket_bytes_cap,
    self.find_unused_parameters,
    self.gradient_as_bucket_view,
    param_to_name_mapping,
    # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
    # bucket.
    dist._DEFAULT_FIRST_BUCKET_BYTES
)

6

PyTorch의 distributed.py를 보면 백워드 진행시 all-reduce가 진행될 수 있도록 self.reducer를 설정하는 부분이 있다. 또한 그래디언트는 뒤에서부터 계산된다고 가정하므로 버킷 리스트를 뒤집어서 보낸다.

// reducer.cpp
// Hook to execute after the gradient accumulator has executed.
hooks_.emplace_back(
  grad_accumulator->add_post_hook(
      torch::make_unique<torch::autograd::utils::LambdaPostHook>(
          [=](const torch::autograd::variable_list& outputs,
              const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
            this->rpc_context_.set(
                ThreadLocalDistAutogradContext::getContextPtr());
#endif
            this->autograd_hook(variable_index);
            return outputs;
          })),
  grad_accumulator);

7

self.reducer를 설정할 때 dist.Reducer() 초기화 함수를 실행하는데, 이 부분의 구현은 reducer.cpp에 C++로 구현되어 있다. 여러가지 초기화 작업과 함께 특히 autograd_hook()은 나중에 그래디언트 계산이 완료되면 autograd 엔진이 호출하게 되는 부분인데, 이를 통해 그래디언트 계산과 동시에 해당 버킷의 NCCL 통신이 진행될 수 있게 한다.

버킷의 갯수는 앞서 distributed.py의 self.reducer 설정 바로 앞 부분에서 dist._compute_bucket_assignment_by_size()를 호출하여 계산하도록 되어 있는데, 이는 마찬가지로 reducer.cpp에 C++로 구현되어 있다. std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>> compute_bucket_assignment_by_size()를 호출하며 이 부분의 코드를 확인해보면 텐서의 크기 등을 이용해 버킷 사이즈를 계산하도록 구현되어 있다.

// reducer.cpp
for (const auto i : c10::irange(tensors.size())) {
  ...

  // when tensor_indices is empty, the index of tensors[i] assigned to
  // bucket is i, otherwise the tensor index is tensor_indices[i].
  auto tensor_index = i;
  if (!tensor_indices.empty()) {
    tensor_index = tensor_indices[i];
  }
  // If we expect a sparse gradient to be produced for this tensor, it cannot
  // be grouped together with other gradients and gets its own bucket.
  if (!expect_sparse_gradient.empty() &&
      expect_sparse_gradient[tensor_index]) {
        result.emplace_back(std::vector<size_t>({tensor_index}), kNoSizeLimit);
        continue;
  }

  auto key = BucketKey(tensor.scalar_type(), tensor.device());
  auto& bucket = buckets[key];
  bucket.indices.push_back(tensor_index);
  bucket.size += tensor.numel() * tensor.element_size();

  ...
}

8

이 과정을 통해 DDP는 백워드 진행시 그래디언트를 계산해 나가다 버킷 단위로 계산이 완료되면 바로 NCCL 통신을 진행하면서 계산과 동시에 통신을 진행하여 보다 빨리 전체 그래디언트 작업을 완료할 수 있게된다.

해당 사항은 DDP를 직접 개발한 Shen Li가 PyTorch 포럼에서 동일하게 확인해준 내용9이기도 하며,

Usually, there are 4 steps in distributed data parallel training:

local forward to compute losslocal backward to compute local gradientsallreduce (communication) to compute global gradients. This would be allreduce with SUM + divide by world size to calculate averageoptimizer step to use global gradients to update parameters

그의 설명에 따르면 DDP는 그래디언트를 버킷으로 구성하여 2번과 3번 과정을 오버랩(overlap)하여 병렬(parallel)로 고속 처리가 가능하다고 설명한다.

관련 연구

DDP는 PyTorch C++ Extentions를 이용하면 별도의 독립된 모듈 형태로도 개발할 수 있으며, 이를 이용해 다른 회사에서 구현한 구현체도 많이 나와 있다.

  • NVIDIA Apex는 torch.nn.parallel.DistributedDataParallel과 호환되는 apex.parallel.DistributedDataParallel 모듈을 제공한다. 내부 구현은 파이썬 코드로 거의 유사하게 구현되어 있다.

  • Microsoft DeepSpeed는 DeepSpeed 엔진을 이용해 DDP 대신 분산 학습을 구현할 수 있다. 그래디언트 평균을 구하는 것 외에도 Loss Scaling, Learning Rate Scheduler 등의 기능도 함께 제공한다.

정리

torch.nn.parallel.DistributedDataParallel를 사용하면 기존 model을 랩핑하는 코드로 간단히 데이터 패러럴 처리를 할 수 있다. DDP는 내부적으로 모델 파라미터를 브로드캐스팅하면서 초기화 하고 그래디언트를 동기화 하면서 각각의 노드가 항상 동일한 모델 파라미터 값을 유지하도록 한다. 또한 백워드 진행시 그래디언트를 계산해 나가다 버킷 단위로 계산이 완료되면 바로 NCCL 통신을 진행하면서 계산과 동시에 통신을 진행하여 보다 빨리 전체 그래디언트 작업을 완료할 수 있도록 최적화 되어 있다.

References

is a collection of Papers I have written.
© 2000 - Sang-Kil Park Except where otherwise noted, content on this site is licensed under a CC BY 4.0.
This site design was brought from Distill.