Paper review

PyTorch FSDP: Experiences on Scaling Fully Shared Data Parallel #2

임로켓 2025. 3. 3. 09:30
728x90

https://arxiv.org/pdf/2304.11277

https://unnamed-underdogs.tistory.com/39

 

PyTorch FSDP: Experiences on Scaling Fully Shared Data Parallel #1

https://arxiv.org/pdf/2304.11277 이 논문에서는 대규모 모델 학습을 위한 PyTorch Fully Sharded Data Parallel (FSDP)에 대해 소개합니다. FSDP는 Tensor 구현, 디스패처 시스템, CUDA memory caching allocator등과 밀접하게

unnamed-underdogs.tistory.com

 

리뷰 2탄에서는 Implementation, Evaluation 까지 살펴보고 마치겠습니다. 

FSDP 사용 방식

FSDP를 사용하는 방법은 크게 두 가지로 나눌 수 있습니다:

  • FullyShardedDataParallel 모델 래퍼(wrapper): 전체 모델을 래핑하여 각 하위 모듈(sub-module)을 FSDP 유닛으로 자동 변환합니다.
  • fully_shard 모듈 어노테이터(annotator): 모델 구조를 유지하면서 nn.Module의 forward 및 backward 훅을 통해 FSDP 로직을 추가하는 방식입니다.
  •  

FSDP 유닛 초기화 전략

1. GPU에서 비분할(unsharded)된 모델 초기화

모델 전체를 단일 GPU에 초기화할 수 있을 만큼 메모리가 충분하다면 이 방식을 사용할 수 있습니다. 하지만 이 경우에도 최적화(optimizer) 단계에서 추가적인 메모리 소모가 발생할 수 있기 때문에 FSDP를 사용하여 초기화 하는 것이 좋다고 합니다. 이 방식은 모델이 반드시 GPU 메모리 내에 들어 가야만 사용할 수 있기 때문에 대규모 모델에서는 할 수 없는 방법입니다. 

2. CPU에서 비분할(unsharded)된 모델 초기화

모델이 GPU 메모리를 초과할 만큼 매우 큰 경우, CPU에서 모델 전체를 초기화한 뒤, 모델을 작은 단위(unit)로 나누어 GPU로 순차적으로 이동시키고 각 유닛을 즉시 분할합니다. 이 "스트리밍(streaming)" 방식은 다음 유닛을 처리하기 전에 메모리 오버헤드를 감소시켜 대규모 모델도 효과적으로 초기화할 수 있도록 하는 것입니다. 당연히 GPU 메모리 내에 한 번에 들어가는 모델 대비 규모가 큰 모델을 이렇게 하기 때문에 초기화 속도가 1번에 대비해서 느릴 수 밖에 없겠죠. 거기에 CPU의 제한된 메모리 대역폭 및 병렬 처리 능력 때문에 좀 더 느려질 수 있습니다. 

 

PyTorch FSDP의 Flat Parameters 이해하기

1탄에서 잠깐 알아본 FlatParameter의 개념과 관리하는 방법에 대해서 좀 더 자세히 살펴 보겠습니다. 

FlatParameter의 정의와 역할

  • FlatParameter 클래스는 PyTorch의 nn.Parameter를 상속하여, 일반적인 nn.Parameter와 동일하게 동작합니다.
  • FSDP는 FlatParameter를 관리하기 위한 별도의 클래스인 FlatParamHandle을 제공하며, 사용자는 FullyShardedDataParallel 또는 fully_shard를 통해 FlatParamHandle을 통해서만 FlatParameter와 상호작용합니다.

FlatParameter의 구조 및 중요성

하나의 FlatParameter는 하나의 FSDP 유닛(unit) 내의 모든 파라미터 텐서를 저장하는 공간으로 작동합니다. 유닛의 경계를 어디에 두느냐에 따라 각 유닛 내 파라미터들이 함께 모아지거나 분산되는 시점이 결정되기 때문에 이 경계를 모델의 실제 실행 순서와 잘 맞추면, 통신과 계산이 더 효과적으로 중첩(overlap)되어 GPU 자원 활용률과 전체적인 학습 속도가 크게 향상됩니다.

 

FSDP에서 통신 시점이 유닛의 경계에 따라 결정되는 이유는, 통신이 각 FSDP 유닛 내에 포함된 파라미터들을 단위로 진행되기 때문입니다.

  • AllGather: 유닛 내의 분할된(sharded) 파라미터들을 하나로 모으는 작업이므로 유닛의 경계가 설정되면, 이 경계 내부의 파라미터들을 동시에 모으게 됩니다.
  • ReduceScatter: 유닛 내에서 계산된 그래디언트들을 프로세스 간에 통합한 후 다시 분할하여 각 GPU로 나눠주므로, 유닛의 경계에 따라 어떤 파라미터들을 통합할지 결정됩니다.

이러한 이유로, 유닛의 경계가 잘못 설정되면 통신과 연산이 비효율적으로 동기화될 수 있으며, 반대로 경계를 잘 설정하면 연산과 통신이 자연스럽게 겹쳐져(overlap) 자원 활용률과 처리량이 최적화됩니다.

 

Runtime에서의 통신 관리

FSDP는 로컬 모델 인스턴스에 그래디언트 감소(reduce)와 파라미터 수집(gather)을 위한 통신 연산을 추가하여 확장합니다. 통신 연산을 정확하고 효율적으로 수행하려면, 타이밍이 매우 중요합니다.

  • 너무 일찍 통신 연산을 시작하면, 아직 업데이트되지 않은 파라미터나 그래디언트를 사용하게 되어 정확성을 해칩니다.
  • 너무 늦게 통신을 시작하면 네트워크 대역폭이 낭비되며 다음 연산이 지연됩니다.

Forward 단계

  • FullyShardedDataParallel은 PyTorch의 기본 nn.Module의 forward() 메소드를 오버라이드하여 forward 전후에 필요한 로직을 추가합니다.
  • 함수형 방식인 fully_shard register_forward_pre_hook()  register_forward_hook()을 사용하여 이러한 로직을 삽입합니다.

Backward 단계

PyTorch에서 제공하는  다양한 hook을 사용하여, 이 작업을 정밀하게 관리합니다.

  • Tensor Hook (register_hook()):
    • Tensor의 그래디언트가 생성될 때 커스텀 로직을 실행합니다.
    • FSDP는 각 유닛의 forward 출력 텐서에 이 후크를 걸어, backward가 해당 유닛에 진입하기 전에 미리 통신 작업을 시작합니다.
  • Backward Hook (queue_callback()):
    • 전체 backward가 끝나기 직전에 실행됩니다.
    • FSDP는 이를 통해 아직 끝나지 않은 통신 작업이 모두 완료될 때까지 기다려서, 옵티마이저가 그래디언트를 너무 빨리 사용하지 않도록 합니다.
  • AccumulateGrad Hook:
    • 파라미터 그래디언트 축적이 완료될 때 즉시 실행됩니다.
    • 각 FlatParameter의 AccumulateGrad에 이 hook을 걸어, 그래디언트 준비 직후 바로 ReduceScatter를 실행하여 지연을 최소화합니다.

 

Native Mixed Precision 관리

  • 파라미터는 저정밀도(low precision)와 고정밀도(full precision) 두 가지로 관리됩니다.
  • forward와 backward 계산은 저정밀도로, 옵티마이저 단계는 고정밀도로 수행합니다.
  • 사용자는 파라미터, 그래디언트 축소 및 버퍼를 독립적으로 원하는 정밀도로 지정할 수 있습니다.

기존 혼합 정밀도 접근법은 메모리 사용량을 증가시킵니다. 그러나 FSDP는 각 로컬 sharded FlatParameter만 GPU 메모리에 유지하고, 비분할된(unsharded) FlatParameter는 필요할 때만 동적으로 할당하여 메모리 사용량을 감소시킵니다.

FSDP의 혼합 정밀도 방식은 operator 수준에서 즉각적인 변환을 수행하는 torch.amp.autocast와 달리, 각 FlatParameter별로 미리 변환 작업을 최소화하고, 모든 통신 작업을 낮은 정밀도로 실행하여 통신 비용도 절감합니다.

일반적으로 사용자는 FP16 또는 BF16을 낮은 정밀도로, FP32를 높은 정밀도로 선택합니다. FP16은 FP32보다 동적 범위가 작아 언더플로우(underflow)와 오버플로우(overflow)의 위험이 있으며, 이를 방지하기 위해 그래디언트 스케일링(gradient scaler)을 사용합니다. 그러나 FSDP는 그래디언트를 분할 관리하므로 일반적인 그래디언트 스케일링 방식이 맞지 않으며, 대신 자체적으로 설계한 분할 그래디언트 스케일러(sharded gradient scaler)를 제공합니다.

 

 

 

이 그림은 FSDP의 성능과 효율성을 다양한 관점에서 실험하여 비교한 결과입니다. 

(a) Model Scale

  • 모델 크기(611M, 2.28B, 11.3B 파라미터 수)에 따라 네 가지 방식(Full Sharding, Hybrid Sharding, Full Replication, DDP)의 성능을 비교한 그래프입니다.
  • 수직축은 GPU당 처리 가능한 성능(TFLOPS/GPU)을 나타냅니다.
  • 가장 작은 모델(611M)에서는 네 가지 방법 간 성능 차이가 거의 없지만, 모델 크기가 커질수록(11.3B 모델) Full Sharding 및 Hybrid Sharding의 성능이 DDP보다 크게 향상됨을 보여줍니다. 즉, FSDP의 샤딩 방식은 모델이 클수록 더 큰 성능 우위를 나타냅니다.

(b) GPT-175B Backward Prefetch

  • GPT-175B 모델에서 GPU 개수(128, 256, 512)에 따른 성능을 비교하며, 특히 Backward Prefetching 기능이 있는 경우와 없는 경우를 비교합니다.
  • Backward Prefetching을 사용할 때 GPU당 성능이 더 높으며, GPU 수가 증가해도 성능이 상대적으로 덜 저하됩니다. 이는 Backward Prefetching이 GPU 개수가 많은 환경에서 특히 성능 향상에 효과적임을 보여줍니다.

(c) Rate Limiter (Ms = Machines)

  • 다양한 모델(RegNet, T5, DeepViT)을 서로 다른 머신 수에서 Rate Limiter의 유무에 따라 배치당 지연 시간(Latency)을 비교한 그래프입니다.
  • Rate Limiter를 사용할 경우(Limit, 파란색 막대) 일반적으로 배치당 지연 시간이 더 짧아지며, Rate Limiter를 사용하지 않을 경우(No Limit, 주황색 막대) 보다 높은 지연이 발생합니다.
  • 이는 Rate Limiter가 적절한 시점에 통신을 제한하여 메모리 블록의 재사용을 높이고, 불필요한 메모리 할당/해제를 줄임으로써 성능을 향상시킴을 나타냅니다.

 

 

이 그림은 FSDP의 훈련 처리량(Training Throughput)을 다양한 상황과 모델에서 분석한 것입니다. 각각의 그래프는 특정 조건에서 FSDP의 성능이 어떻게 변화하는지 보여주고 있네요.

(a) DHEN QPS (Queries Per Second)

  • X축은 사용된 GPU(A100 80GB)의 수를 나타내며, Y축은 QPS(초당 처리된 쿼리 수)의 90번째 백분위수 성능을 나타냅니다.
  • 네 가지 설정을 비교합니다:
    • Full Sharding (RAF/NRAF)
      RAF는 "reshard after forward"로 forward 이후 파라미터를 다시 나누는 방식,
      NRAF는 forward 이후 재분할을 하지 않는 방식입니다.
    • Hybrid Sharding (RAF/NRAF)
      일부 파라미터만 sharding 하는 혼합 방식입니다.
  • GPU 수가 증가할수록 전반적으로 성능이 감소하는 경향을 보이며, 특히 128 GPU 이상에서 급격한 성능 감소가 나타납니다.
    • 특히 Full Sharding 방식에서 성능 하락 폭이 더 큽니다. Full Sharding 방식은 모든 파라미터를 각 GPU에 완전히 나누어 저장하기 때문에 forward와 backward 단계마다 파라미터를 모으고(AllGather), 그래디언트를 나누는(ReduceScatter) 통신 작업을 반복적으로 수행해야 합니다. 따라서 GPU 개수가 많아질수록 각 GPU가 보유한 파라미터가 매우 작아지며, 통신 횟수는 그대로 유지되거나 오히려 늘어나기 때문에, 통신으로 인한 오버헤드가 급격히 커지기 때문에 성능 하락폭이 커지게 됩니다. 

(b) GPT-175B TFLOPS

  • GPT-175B 모델을 이용하여 GPU 수를 증가시켰을 때의 성능 변화를 보여줍니다.
  • 두 가지 배치 크기(B=1, B=2)를 비교하며, 배치 크기가 클수록 GPU당 처리 성능이 더 좋게 유지됩니다.
  • GPU 수가 증가해도 성능이 비교적 일정하게 유지되며, 큰 배치 크기(B=2)에서 더 높은 성능을 나타냅니다.

(c) T5-11B TFLOPS

  • T5-11B 모델을 이용하여 GPU 수에 따른 성능 변화를 나타냅니다.
  • 여기서도 배치 크기(B=8, B=16)를 비교하며, 큰 배치 크기(B=16)일 때 성능이 더 높습니다.
  • GPU 수가 256을 넘어서면서 성능이 급격히 떨어지는 현상이 나타납니다. 특히 작은 배치 크기(B=8)에서는 256개 이상의 GPU에서 성능이 크게 저하됩니다.

이 결과들을 종합해서 봤을 때, GPU의 개수를 늘릴 때, 성능이 무조건 비례적으로 증가하지는 않으며, 특정 시점 이후 급격한 성능 저하가 발생할 수 있고, 배치 크기를 크게 하면 성능을 상대적으로 유지할 수 있지만, 이 역시도 GPU 개수가 매우 많아지면 한계에 도달하게 되는 것을 알 수 있습니다. Sharding 방식의 선택(Full vs Hybrid)과 forward 이후의 재분할 방식(RAF vs NRAF)이 성능에 영향을 주는데 Hybrid sharding과 NRAF 방식이 상대적으로 성능을 더 잘 유지하는 경향이 있는 것도 알 수 있죠.

 

 

이 그림은 FSDP의 메모리 사용량(Memory Footprint)을 분석한 결과로, 모델의 크기와 설정 방식에 따라 GPU에서의 최대 메모리 사용량(Peak Memory)이 어떻게 변화하는지를 나타냅니다. 

(a) DHEN 모델의 메모리 사용량

  • X축은 사용된 GPU(A100 80GB)의 수를 나타내며, Y축은 각 GPU당 최대 메모리 사용량(GB)을 나타냅니다.
  • 네 가지 Sharding 방식(Full Sharding, Hybrid Sharding)과 forward 후 재분할 여부(RAF/NRAF)를 비교했습니다.
  • GPU 수가 증가할수록 각 GPU가 가지는 파라미터 크기가 줄어들기 때문에, GPU당 최대 메모리 사용량이 전반적으로 감소합니다.
  • 특히 Full Sharding 방식이 Hybrid Sharding 방식보다 GPU 메모리 사용량이 조금 더 낮은 경향이 있습니다. 이는 Full Sharding이 모든 파라미터를 분할하여 더 효율적으로 메모리를 사용하기 때문입니다.

(b) GPT-175B 모델의 메모리 사용량

  • GPU 수(128~512개)에 따른 최대 메모리 사용량을 분석한 결과로, 각 GPU에서 실제로 할당된(Alloc), 활성 상태인(Active), 예약된(Reserved) 메모리의 사용량을 나타냅니다.
  • 두 가지 배치 크기(B=1, B=2)에 따라 메모리 사용량을 비교했습니다.
  • GPU 수가 증가할수록 각 GPU에서 실제로 필요한 메모리 사용량(Alloc)은 감소하지만, Reserved 메모리는 여전히 높게 유지됩니다. 즉, PyTorch의 캐싱 메모리 할당자가 GPU에 여유 공간을 미리 확보하기 때문입니다.
  • 더 큰 배치 크기(B=2)는 각 GPU가 처리하는 데이터가 많아져 메모리 사용량이 전반적으로 증가하지만, GPU가 많아질수록 여전히 감소하는 경향이 나타납니다.

(c) T5-11B 모델의 메모리 사용량

  • T5-11B 모델에서의 GPU 수에 따른 메모리 사용량을 배치 크기(B=8, B=16)에 따라 비교한 그래프입니다.
  • 이 경우에도 GPU 수가 증가하면 GPU당 메모리 사용량은 감소합니다. 그러나 GPU 수가 매우 많아지면(512개) 메모리 사용량 감소가 포화 상태에 도달하며 거의 변화하지 않습니다.
  • 배치 크기가 클수록(B=16) 메모리 사용량은 높지만, 증가 폭이 상대적으로 일정하게 유지됩니다. 작은 배치 크기(B=8)에서는 GPU 수가 많아질수록 메모리 사용 효율성이 떨어져 Reserved 메모리 양이 상대적으로 많이 남습니다.

따라서, GPU 자원을 효율적으로 사용하기 위해 GPU 개수와 배치 크기, Sharding 방식을 신중히 결정해야 한다는 것이죠. FSDP는 GPU 수가 많아질수록 GPU당 메모리 사용량을 효과적으로 감소시키고, 특히 Full Sharding 방식이 가장 효과적으로 메모리 절약을 할 수 있는 것을 알 수 있습니다. 하지만 Reserved 메모리(실제 할당되지 않은 메모리)는 GPU 수가 많아도 일정 수준에서 유지되므로, 메모리 절약 효과는 특정 GPU 개수 이상에서는 포화 상태에 도달할 수 있습니다. 그리고 배치 크기가 클수록 각 GPU의 메모리 사용량은 높아지지만 GPU가 많아지면 그 증가량은 완화됩니다. 

 

이 논문을 리뷰하기 전에 DeepSpeed의 ZeRO (https://arxiv.org/pdf/1910.02054)를 먼저 읽고 굉장히 유사한 접근이라고 생각했는데, 이 논문에서도 ZeRO를 여러 번 언급합니다. 

 

FSDP는 기존의 ZeRO와 Cross-Replica Sharding 접근법에서 영감을 받아 설계되었다고 하면서도, 근본적으로 다른 특징을 갖고 있다고 선을 긋고 있죠. 😀 기존 접근법인 ZeRO 및 Cross-Replica Sharding은 모델의 파라미터 텐서를 GPU 간에 나누기 위해 모델 파티셔닝(model partitioning) 또는 파라미터별로 나누는 방식(per-parameter sharding)을 사용합니다. 또한, Broadcast 및 Gather와 같은 통신 연산(collective communication primitives)을 이용하여 각 GPU 간 값을 동기화합니다. 이러한 접근은 동일한 기능을 제공할 수 있지만, 다음과 같은 문제점들을 동반할 수 있습니다. 

  1. 부하 분산의 불균형
    • 파라미터별 또는 모델 파티셔닝 방식은 각 GPU가 처리하는 작업량이 고르지 않을 수 있습니다. 이러한 불균형은 분산된 GPU 간의 성능을 떨어뜨리고, 동기화된 분산 학습의 효율성을 저하시킵니다.
  2. 프레임워크 내부 변경에 따른 호환성 문제
    • 이 방식은 머신러닝 프레임워크의 내부를 직접적으로 수정하는 방식입니다(텐서 저장 및 메모리 관리 등). 따라서 프레임워크가 업데이트되거나 새로운 기능이 추가될 경우 내부 구조가 변하면서 기존 접근법이 제대로 동작하지 않을 가능성이 있습니다.

FSDP는 프레임워크의 핵심 요소와 함께 설계된(native) 방식으로 프레임워크의 내부 구조와 밀접하게 결합되어 있기 때문에 내부 구현이 변경되더라도 쉽게 호환성을 유지할 수 있습니다. 개인적으로 1번도 효율을 추구하기 위해서 매우 중요한 부분이지만, 유지 보수성과 재사용성을 고려해봤을 때 2번도 무시 못할 부분이라고 생각합니다. 오픈소스를 가져와 사용할 때 항상 고민되는 지점이기도 하고요.

 

아뭏든 FSDP 리뷰는 여기서 마칩니다. 

끝. 

728x90