Paper review

PyTorch FSDP: Experiences on Scaling Fully Shared Data Parallel #1

임로켓 2025. 3. 3. 08:32
728x90

https://arxiv.org/pdf/2304.11277

 

이 논문에서는 대규모 모델 학습을 위한 PyTorch Fully Sharded Data Parallel (FSDP)에 대해 소개합니다. FSDP는 Tensor 구현, 디스패처 시스템, CUDA memory caching allocator등과 밀접하게 동작하도록 공동 설계되어 있습니다. 이를 통해 다양한 하드웨어 구성에서 자원 활용을 최적화하는 여러 기법과 설정을 자연스럽게 통합하고 있습니다. 이 논문의 실험에서 FSDP는 Districuted Data Parallel과 유사한 성능을 달성하고 있으며, 큰 모델을 지원하면서 TFLOPS 기준으로 "거의" 선형적으로 확장 가능합니다. 

 

모델의 규모가 매우 빠르게 증가하면서, 이러한 모델의 훈련을 간소화하고 높은 효율성을 가진 도구의 필요성 또한 증대 되고 있습니다. 

이러한 도구/방법론에는 다음과 같은 것들이 있습니다. 

 

Pipelie parallelism은 모델 인스턴스를 여러 stage로 분할하고, 여러 장치에 이 stage들을 분배하여, activation과 gradient를 stage 경계를 넘어서 통신하는 것입니다.

 

Tensor parallelism은 모델의 매개변수를 분할하고 각 장치에서 부분 계산을 수행하며 필요한 layer 경계에서 activation을 통신합니다.

 

Zero-Redundancy parallelism은 tensor parallelsim과 마찬가지로 매개변수를 분할하지만, 필요한 경우 매개변수를 통신하여 다시 비 분할 형태로 북구하고, 모든 장치에 이것이 복제되어 있는 것처럼 모델을 실행하는 것입니다.

 

이러한 방법들에 대한 문제점으로는 다음과 같은 것들이 있습니다. 

1) 이러한 방법의 일부는 특정 모델 아키텍처와 밀접하게 통합되어 있어, 일반적인 솔루션으로 사용하기 어렵습니다. 

2) 일부 기술은 기반이 되는 머신 러닝 프레임워크의 내부 인터페이스 위에 구축이 되어 있는데, 이 머신 러닝 프레임워크는 급변하고 있어 프레임 워크 구현의 변화에 매우 취약합니다. 

 

이 논문에서 제안하는 FSDP는 DeepSpeed의 ZeroRedundancy Optimizer에서 영감을 받았으나, PyTorch 기반으로 수정된 설계 및 구현을 갖고 있습니다. FSDP는 모델 인스턴스를 더 작은 단위로 분해하고, 각 단위 내의 모든 매개변수를 평탄화하고 분할합니다. 분할된 매개변수는 계산 전에 필요에 따라 통신 및 복구 되고, 이후 즉시 폐기합니다. 이러한 접근 방식은 FSDP가 한 번에 하나의 단위에서만 매개변수를 생성하면 되므로 피크 메모리 소모를 상당히 줄여줍니다. 

 

FSDP의 구현과 설계 측면에서 고려해야 할 점은 다음과 같습니다. 

 

User Experience 측면에서, 분산 훈련의 사용자 경험을 로컬 훈련과 동일하게 하는 것입니다. DistributeDataParallel(DDP)를 사용할 경우 모든 장치에서 모델을 복제해야 하는데, FSDP는 DDP의 API를 채택할 수는 있지만, 대형 모델의 경우 하나의 GPU에 올릴 수 없기 때문에 효율적으로 초기화 하기가 어렵습니다. 또한 이종의 하드웨어로 구성된 GPU 클러스터 내에서 최적화해야 한다는 것이죠. 그리고 분산 훈련 중 GPU 장치가 완전하게 활용될 수 있도록 하는 것도 중요합니다. 이를 보장하려면 비 계산 작업으로 인한 다운타임을 최소화하는 것이 필수적입니다. Memory Planning 역시 중요한 역할을 합니다. PyTorch는 캐싱을 통해 GPU 메모리 블록 할당을 효율적이고 관리할 수 있도록 만들어 줍니다. Memory fragmentation 으로 인한 잦은 조각 모음은 훈련 속도를 매우 느리게 만들고, 대규모 모델에서 특히 두드러 지기 때문에 가능한 많은 GPU 메모리를 잘 활용할 수 있도록 해야 합니다. 

 

FSDP는 이 문제들을 다양한 방법으로 접근해서 해결하고 있습니다. 

 

Deferred Initialization

FSDP는 지연 초기화를 도입하여, 사용자가 더미 장치에서 모델 인스턴스를 생성하고 초기화 중 호출된 작업을 기록할 수 있게 하고, 기록된 작업을 실제 GPU 장치에서 재생하여 모델을 유닛 단위로 초기화하고 샤딩할 수 있도록 합니다. 이를 통해 로컬 훈련과 유사한 사용자 경험을 제공하고 대규모 모델 훈련의 효과적인 스케일링이 가능하도록 한 것입니다. 

 

Configurable sharding strategies

하드웨어 이질성을 처리하기 위해, 클러스터의 물리적 상호 연결 토폴로지에 맞게 사용자 정의 하여 구성 가능한 샤딩 전략을 제공합니다. 

 

Parameter sharding design

매개변수 샤딩 설계는 필연적으로 통신을 삽입하여 계산이 차단되고 버블이 생길 수 있는데, FSDP에서는 작업 재정렬 및 매개 변수 미리 가져오기를 통해 계산과 통신이 겹치도록하여 버블을 제거합니다. 

 

이러한 방법들을 통해서 FSDP의 성능을 평가해보면, 소형 모델에서는 DSSP와 유사한 성능을 달성할 수 있고, TFLOPS 측면에서는 거의 선형적으로 활장 가능한 훨씬 더 큰 모델도 지원할 수 있음을 알 수 있습니다. 

 

 

FSDP를 이해하기 전에, PyTorch의 분산 훈련 방법들에 대해 잠깐 살펴 보겠습니다. 

 

Model Replication

이 방식은 고용량 데이터 셋을 처리하기 위해 여러 장치에 걸쳐 계산을 확장하고 분산하도록 설계된 것으로, DistributedDataParallel(DDP)는 각 장치에 모델 복제본을 유지하고, 역전파 동안 집합적인 AllReduce 연산을 통해 기울기를 동기화하여 훈련 중 복제본 간의 모델 일관성을 보장합니다. 훈련을 가속화하기 위해 DDP는 기울기 통신을 역방향 계산과 겹치게 하여 다양한 리소스에서 동시에 워크로드를 실행할 수 있도록 하고 있습니다. 하지만 DDP는 모든 모델 파라미터, 기울기 및 최적화 상태가 하나의 GPU 장치의 메모리에 맞아야 하는 단점이 있습니다. 이 문제 때문에 대형 모델을 지원하는데 적합하지 않죠. (Out of Memory)

 

Model Partitioning

단일 GPU 장치에 모델을 적재할 수 없는 경우 모델을 더 작은 구성 요소로 분할하고, 여러 장치에 분산시켜야만 합니다. 파이프라인 병렬 처리 및 텐서 원격 프로시저 호출이 바로 이것입니다. 두 기술 모두 대규모 모델을 여러 장치에 걸쳐 확장할 수 있는데 사용할 수 있지만, 모델을 sequence of stage로 제한하거나, 원격 계산을 삽입하기 위해 모델 작성 코드를 수정해야 하는 요구가 있어 사용자 채택에 장애가 될 수 있습니다. 

 

Model Sharding

모델의 파라미터 샤딩은 메모리 사용량을 줄이고,  단일 GPU 장치의 메모리 용량을 넘어서는 모델을 지원할 수 있습니다. 모델을 샤딩하면 각 랭크는 모델 파라미터의 일부만 보유하게 되기 때문에 로컬 훈련과 동일한 계산을 수행할 수는 없습니다. 정확성을 보장하기 위해서는 훈련 과정에서 파라미터 샤드로 계산을 수행하고 activation을 그에 맞게 통신해야 합니다. 

 

 

아래 그림은 FSDP의 Algorithm Overview 입니다. 간단한 여섯 개의 레이어 모델을 사용하여 전체 워크 플로우를 보여주고 있습니다. FSDP가 모델을 [Layer0, Layer3], [Layer1, Layer2], [Layer4, Layer5]의 세부분으로 분해하고 이 세 부분 각각을 하나의 FSDP 단위로 래핑하고 매개 변수를 그에 맞게 분할합니다. 

 

 

포워드 계산이 레이어 1에 들어가기 전에 FSDP는 다른 랭크로부터의 샤드를 모아 레이어 1과 레이어 2의 비샤딩 데이터를 수집합니다. 비샤딩 파라미터를 사용하여 FSDP는 해당 레이어의 로컬 계산을 수행하고, memory footprint를 줄이기 위해 방금 수집한 peer 샤드를 해제합니다. 그렇기 때문에 전체 포워드 패스 동안 FSDP는 한 번에 하나의 유닛만 완전히 materialize 해야 하고, 다른 모든 유닛은 샤딩 상태를 유지할 수 있습니다. 

 

FSDP에서 사용한 deferred initialization이라는 메커니즘은 위에서 잠깐 이야기 했었는데, 모델 매개변수 텐서를 시뮬레이션 된 가짜 장치에 할당하여, 이 과정에서 텐서에 수행된 모든 초기화 작업을 기록하고, 가짜 장치에서 GPU 장치로 텐서가 디오하면 모든 기록된 작업을 자동으로 재생하도록 합니다. 이 기술을 채택함으로써 사용자는 GPU 메모리 블록을 할당하지 않고도 어떤 third party 라이브러리에서도 모델 인스턴스를 생성할 수 있으며, 여전히 매개변수 초기화 구현을 정확하게 할 수 있습니다. 

 

FSDP의 샤딩 전략은 메모리 사용량과 통신 오버헤드를 결정하는데 중요한 역할을 하는 요소로, 여러 전략을 제공 합니다. 샤딩 요소를 1로 설정하면 FSDP는 모델을 완전히 복제하고 AllReduce를 사용하여 그래디언트 감소를 수행하는 일반적인 데이터 병렬 처리로 단순화됩니다. 반대로 샤딩 요소를 장치의 수와 같게 설정하면 FSDP는 모델을 완전히 샤딩하여 각 장치가 모델을 1/N 만큼 보유하게 되죠. 하이브리드 샤딩은 샤딩 요소가 1과 W(WoldSize) 내에 있을 때를 말합니다. 샤딩 요소를 1로 하여 전체 복제를 하면  DDP와 거의 유사합니다. 

 

Full sharding 전략은 당연하게도 가장 낮은 메모리 사용량을 초래하지만, 반대로 가장 많은 통신 오버헤드를 발생시킵니다. 예를 들어 전체 조각화를 하게 되면 대역폭 최적 링 알고리즘을 사용할 경우 DDP 대비 1.5배의 통신 오버헤드와 볼륨을 갖고 있습니다. 따라서 효율성 극대화를 하기 위해서는 이 통신을 신중하게 최적화 해야 겠죠. 

 

논문에서는 집합 통신 효율성에 대한 입력 크기의 영향성을 이해하기 위한 두 가지 실험 세트를 수행했습니다. 

이 그래프 왼쪽에서는 입력 크기의 불균등성 때문에 All-gather 연산 성능의 저하 및 latency의 변동성이 나타나는 것을 볼 수 있습니다. 그리고 오른쪽에서는 입력 크기가 작아지면 오히려 GPU 간의 통신 효율이 급격히 나빠지는 것을 알 수 있습니다. 

 

 

여기서 FlatParameter는 모델의 파라미터를 1차원 텐서로 변환한 것입니다. 효율적인 통신을 제공하기 위해서 FSDP는 하나의 FSDP 단위 내 모든 매개변수를 이렇게 큰 FlatParameter로 구성하는 것입니다(concat). 모델의 파라미터를 여러 장치에 분산하여 저장하고 있는 것을 알 수 있습니다. 총 16개의 GPU에 파라미터를 분산하고 있습니다. 파라미터의 수가 장치의 수에 완벽하게 나누어 떨어지지 않을 경우에는 Padding을 하게 되는데 사이즈가 각각인 매개변수를 이렇게 concat하여 1차원 텐서로 만든 다음 각각의 rank에 동일한 사이즈로 분산할 경우 이 때 나누는 최소 단위를 F라고 하면 패딩은 F-1 만큼만 생기게 됩니다. 기존처럼 여러 사이즈의 매개변수를 GPU 개수로 나누는 것 대비 패딩이 적게 생성되는 장점이 있는 것입니다. 

 

샤딩이 1보다 크고, 전체 월드 사이즈 (W)보다 작게 사용된 경우 이를 하이브리드 샤딩이라고 합니다. 샤딩과 복제를 결합해서 하이브리드라고 부르죠. 매개변수는 각 그룹 내에서 샤딩되고, replication 그룹 내에서 복제 됩니다. 그레디언트 reduction을 위해 모든 랭크에 대한 단일 reduce-scatter는 각각의 샤딩된 그룹 내에서 reduce-scatter가 되고, 각 복제된 그룹 내에서 all-reduce가 이뤄집니다. 

 

각 모델이 여러 GPU에 나뉘어 저장되어 있기 때문에 각 GPU는 자신이 담당한 부분의 그래디언트를 계산하고, 각 GPU가 가진 그레디언트를 모두 합산(Reduce) 한 다음, 다시 GPU 별로 나누어 분배(Scatter)합니다. 이를 통해 각 GPU는 모델의 특정 부분에 대한 그레디언트 만을 정확하게 다시 나누어 받게 되는데, 중복 저장과 통신량이 줄어 들게 되죠. 그리고 모델이 replication되어 있기 때문에 이 replication 그룹 내 GPU 끼리는 추가적으로 통신하여 모두 동일하고 완전한 그레디언트를 갖도록 하는 것입니다. 이 단계에서는 각 replication 그룹 내에서 All-reduce 연산을 통해 GPU들끼리 가진 데이터를 합산한 뒤, 모든 GPU가 동일한 데이터를 가지도록 합니다. 

 

이렇게 2 단계를 거치는 이유는 효율성입니다. 모든 GPU가 항상 전체 그레디언트를 갖게 되면 메모리 사용량이 커지고, 통신량도 많아지게 되서 효율성이 떨어집니다. 따라서 처음에는 필요한 부분만 나누도록하고, 이후 data parallelism 방식으로 같은 데이터를 갖는 GPU끼리는 최종 그레디언트를 일치시키기 위해 all-reduce를 사용하는 것이죠. 그리고 데이터센터에서는 GPU간 통신 속도가 성능에 큰 영향을 주기 때문에 데이터 센터의 물리적 구조를 잘 활용하는 것이 중요합니다. 논문에서 착안한 것은 데이터 센터가 oversubscribed fat-tree topology(최상위 계층으로 갈수록 대역폭이 제한)를 일반적으로 사용하기 때문에, 서버 내부(host 내부)의 GPU간 통신 속도가 서버 간 통신 속도보다 훨씬 빠르다는 점입니다. 따라서 호스트 내부에서의 통신(locality)를 우선적으로 사용하는 방식을 사용할 수 있도록 device mesh(각 장치가 서로 어떻게 연결될지를 정의)를 특정한 배열로 구성하는 것입니다. 

 

그래서 위 그림을 다시 보면, 같은 샤딩 그룹 내부에 있는 GPU들끼리는 전체 파라미터를 나누어서 가지고 있으니까 forward 및 backward 계산 시 반드시 다른 GPU들과 데이터를 주고받는 통신(All-Gather와 Reduce-Scatter)이 필요하게 되므로, 샤딩 그룹은 최대한 빠른 통신이 가능한 GPU끼리 묶습니다. 같은 서버내 NVLink를 사용할 수 있도록 하면 성능이 가장 좋겠죠? 반대로 replication 그룹 간 데이터 교환은 빈번하지 않고 최소화되어 있습니다. 그래서 상대적으로 더 먼 거리(서로 다른 서버)에 있는 GPU끼리 묶어도 됩니다. 이런 구성을 통해 상대적으로 느린 서버 간 네트워크 대역폭을 최소한으로 사용하도록 하여 통신 비용을 절감합니다.

 

논문에서는 이러한 장점 외에도, 중간 크기 모델의 요구 사항 역시 하이브리드 샤딩 설계의 중요한 부분이라고 말하고 있습니다. 중간 크기 모델에 전체 샤딩을 사용할 경우 가속기 메모리를 완전히 활용할 만큼 크지 않은 경우 런타임 오버헤드 및 메모리 낭비를 초래한다는 것입니다. 하이브리드 샤딩에서는 샤딩 계수 F를 단순히 조절함으로써, 훨씬 더 풍부한 메모리-throughput 트레이드 오프를 조절할 수 있게 됩니다. 

 

PyTorch에서 분산 학습을 위한 핵심 모듈인 c10d 라이브러리는 프로세스 그룹(ProcessGroup)이라는 개념을 제공하여, 여러 프로세스가 함께 통신 작업(collective operations)을 수행할 수 있도록 합니다. 특히 GPU를 활용한 NCCL 백엔드를 사용할 때, ProcessGroupNCCL 클래스는 GPU마다 별도의 내부 CUDA 스트림을 생성합니다. 이는 기본적인 연산을 수행하는 디폴트 스트림과 비동기적으로 작동하며, 서로 병렬적으로 실행되어 연산 효율성을 높이는 데 도움을 줍니다.

 

이 비동기 통신 작업들은 Work 객체를 반환하고, 이를 통해 작업 완료 시점을 추적할 수 있습니다. Work.wait() 메소드를 호출하면 통신 작업이 끝날 때까지 CPU가 기다리게 됩니다. 정확한 연산 수행을 위해 ProcessGroupNCCL은 내부 스트림을 기본 스트림과 동기화한 후 collective 연산을 시작합니다.DistributedDataParallel(DDP)은 이 특징을 활용하여 역전파(Backward) 과정 중 그래디언트 AllReduce 연산을 비동기적으로 미리 실행해 두고, 역전파 연산과 병렬로 수행하여 GPU 효율성을 극대화합니다.

 

그러나 Fully Sharded Data Parallel(FSDP)의 경우, eager 실행 환경에서는 다음 단계에서 필요한 FlatParameter가 미리 정해지지 않으므로, 연산이 끝난 뒤에야 AllGather를 수행할 수 있습니다. 즉, DDP와 달리 연산 후에 collective 연산이 실행되기 때문에, DDP와 같은 방식으로 비동기적 collective 방식을 그대로 적용할 수 없습니다.

 

ProcessGroupNCCL은 기본 스트림과 항상 동기화해야 하므로, AllGather가 연산이 완료되기 전에는 시작되지 않습니다. 이를 해결하기 위해 FSDP는 기본 연산 스트림과 분리된 별도의 CUDA 스트림에서 AllGather를 실행하여, 이전 연산과의 불필요한 의존성을 끊고 두 작업을 병렬로 실행할 수 있게 합니다. 결과적으로 FSDP의 통신 작업들은 단순히 Work 객체를 기다리는 것이 아니라 스트림 간의 동기화 방식으로 실행됩니다.

 

추가로, FSDP는 성능 최적화를 위해 역전파 과정에서 가장 외부 FSDP 단위의 파라미터는 메모리에 유지합니다. 이를 통해 forward가 끝난 후 파라미터를 해제하고 backward 시작 시 재차 AllGather하는 비효율적인 작업을 방지합니다.

 

FSDP는 각 프로세스(rank)가 하나의 CUDA 디바이스만 사용하며, 모든 통신 작업(AllGather 및 ReduceScatter)을 하나의 프로세스 그룹에서 수행합니다. 이 때문에 FSDP의 통신 작업들은 프로세스 그룹 내부의 NCCL 스트림에서 순차적으로 실행됩니다.

특히 역전파(Backward) 과정에서 FSDP는 현재 FlatParameter에 대한 ReduceScatter를 실행한 후 다음 FlatParameter를 위한 AllGather를 수행합니다. 이렇게 하나의 NCCL 스트림을 사용하다 보니, ReduceScatter가 다음 AllGather를 차단하게 되고, 결과적으로 다음 그래디언트 연산까지 차단되는 현상이 발생하여 성능 저하의 원인이 될 수 있습니다.

 

이러한 문제를 해결하고 연속된 두 개의 통신 작업이 성능 저하를 일으키지 않도록 하기 위해, FSDP는 "Backward Prefetching"이라는 방식을 사용합니다. 이 방식은 현재의 ReduceScatter가 실행되기 전에 미리 다음 AllGather 작업을 시작함으로써 통신 작업 간의 불필요한 대기 시간을 줄이는 것입니다.그러나 eager 실행 환경에서는 다음에 어떤 FlatParameter를 AllGather 해야 하는지 미리 결정하는 것이 어렵다는 문제가 있어서, FSDP는 이를 해결하기 위해 forward 단계에서 모듈이 실행된 순서를 역방향으로 기록하여, 이를 backward 실행 순서의 지표로 사용합니다. 

 

CPU 연산이 비교적 느린 작업 부하의 경우, CPU 쓰레드가 다음 forward의 AllGather 작업을 충분히 빠르게 시작하지 못해 NCCL 스트림의 효율적인 활용이 어려울 수 있습니다. 이때 모델의 연산 그래프가 반복(iteration) 간에 일정하게 유지되는(static) 경우, FSDP는 이전 iteration의 모듈 실행 순서를 가정하고 미리 forward 단계에서 다음 AllGather를 명시적으로 실행하여 이러한 병목 현상을 방지합니다. 이를 통해 forward 연산 전에 다음 AllGather를 시작하는 "Forward Prefetching" 전략이 적용됩니다.

 

메모리 매니지먼트 측면에서도 고려한 부분이 있습니다. FSDP의 메모리 관리 PyTorch는 GPU 메모리 할당과 해제를 효율적으로 관리하기 위해 CUDA 캐싱 할당자(caching allocator)를 사용하는데, caching allocator는 빈번한 cudaMalloc 및 cudaFree 호출을 피함으로써 GPU 동기화 비용을 줄입니다. 그러나 CPU 쓰레드가 GPU 실행 속도보다 지나치게 앞설 경우, 캐싱 할당자가 미리 할당된 메모리 블록을 재사용하지 못하고 새로 할당해야 하는 상황이 발생하여 성능이 저하될 수 있습니다.

FSDP는 이러한 문제를 해결하기 위해 "Rate Limiter"를 도입하여 CPU 쓰레드가 너무 빠르게 AllGather 작업을 시작하지 않도록 제한합니다. 이를 통해 메모리 블록이 효율적으로 재사용되고, 불필요한 cudaMalloc과 cudaFree 호출을 방지하여 시스템이 안정적으로 작동하도록 돕습니다. Rate Limiter는 동시에 최대 두 개의 AllGather 작업만 실행될 수 있도록 하여 통신과 연산 간 효율적인 중첩(overlap)을 유지합니다.

 

글이 생각보다 길어져서, 나머지 부분은 2탄에서 이어 하겠습니다. 

 

 

 

728x90