Tech

GEMM (General Matrix to Matrix Multiplication)과 GPU 아키텍처의 이해

임로켓 2025. 3. 1. 13:06
728x90

1. GEMM이란?

GEMM(General Matrix to Matrix Multiplication)은 일반적인 행렬 곱셈을 의미합니다.

특히 딥 러닝 연산에서 핵심적인 연산으로, 아래처럼 표현됩니다. 

 

여기에서 A, B는 곱셈의 대상이 되는 행렬이며, C는 결과 행렬, ɑ, β는 스칼라 값입니다.  즉 A X B  행렬 곱 결과에 스칼라를 곱하고, 기존 C 행렬에 스칼라를 곱한 값을 더하는 연산입니다. 신경망의 대부분 레이어는 사실상 큰 규모의 행렬 곱 연산을 수행합니다. GEMM 최적화는 곧 딥러닝 성능 향상과 직결되는 것이죠. 그래서 딥러닝 프레임워크 (PyTorch, Tensor Flow)와 GPU (CUDA, cuBLAS)는 GEMM 연산을 최대한 빠르게 실행하려고 최적화를 계속해 나가는 것이고요. 

 

이러한 GEMM 연산을 최적화하는 기법에는 여러 방법들이 있는데요, 하나씩 간단히 살펴보겠습니다.

먼저 행렬의 데이터가 CPU 또는 GPU의 빠른 메모리(Cache)에 존재하도록 데이터 접근 패턴을 최적화하고, 메모리 접근 시간을 최소화하여 성능을 극대화하는 것이 필요합니다. 이것을 Cache Locality 최적화라고 합니다. 

다음으로는 Tiling, Blocking이라고 하는 방법입니다. 큰 행렬을 작은 블록(tile)으로 나누어 연산하는 것으로, 이를 통해서 데이터 재사용성을 높이고 캐시 hit ratio를 높일 수 있습니다. 

그리고 GPU의 SM(Streaming multiprocessor)가 병렬로 연산할 수 있도록 CUDA 커널을 활용하여 처리 속도를 증가시키고요. 

마찬가지로 SIMD(Vectorization)이라고 하는, CPU의 SIMD(Single Instrcution, Multiple Data) 명령어 및 GPU의 warp 단위 연산을 활용하여 한 번의 연산으로 여러 데이터를 처리하기도 합니다. 그 외에 Loop unrolling 등의 Instruction Level Parallelism을 증가시키는 방법도 있습니다. 

 

일반적으로 기본적인 행렬 연산은 BLAS(Basic Linear Algebra Subprograms)라고 하는 라이브러리를 통해 제공되는데, NVIDIA가 제공하는 GPU 기반의 고성능 BLAS 라이브러리인 cuBLAS는 GEMM 최적화가 뛰어납니다.

OpenBLAS CPU 코드와 GPU cuBLAS API 함수 DGEMM 계산 속도 비교(19.2배 향상) - [출처] https://developer.nvidia.com/ko-kr/blog/nvidia-math-라이브러리를-통한-gpu-애플리케이션-가속/

 

2. Attention과 GEMM의 관계

Attention은 transformer 모델의 핵심이고, GEMM을 주요 연산으로 사용합니다. 

[AI 상식] LLM은 어떻게 동작할까 - Attention 편에서 알아본 것처럼, attention은 query(Q), key(K), value(V)를 통해 계산합니다. 

[AI 상식] LL은 어떻게 동작할까 - Attention

 

 

attention 연산의 수식은 위와 같습니다. 이 연산에서 GEMM은 다음 단계에서 사용되죠.

- Query, Key, Value 생성 과정에서의 행렬 곱
- Attention 점수(QK^T) 계산
- Attention 점수와 Value(V) 곱셈

따라서 GEMM 최적화는 attention 성능 향상의 핵심이 되는 것입니다. 

실제 구현에서는 많은 Attention Head와 배치가 동시에 처리되므로, Batch GEMM (torch.bmm)을 사용하여 GPU 병렬성을 높이고요, 큰 사이즈의 행렬을 처리하기 때문에 tiling을 통해 캐시 효율을 높이는 것도 매우 중요합니다. 

 

3. Transformer의 MLP 연산

MLP(Multi-Layer Perceptron)는 transformer 내부의 Feed-Forward Network를 의미하며, 다음과 같은 수식으로 표현할 수 있습니다.


이 과정에서 보면, 두 번의 선형 변환(Linear projection)이 GEMM 연산으로 수행되죠.
- 첫 번째: 입력 데이터(x)와 W₁의 곱
- 두 번째: 활성화 함수 이후 결과와 W₂의 곱

MLP에서는 입력 텐서 X [batch_size, seq_len, hidden_dim]와 W [hidden_dim, intermediate_dim]을 곱하면서, 입력 텐서의 차원을 확장(=모델의 표현력을 확장)하고, activation 이후 다시 W2 [intermediate_dim, hidden_dim]과 곱하는 연산을 통해 확장된 차원을 다시 원복 하는 동작을 합니다. 여기서 hidden_dim 보다 intermediate_dim 자체가 크기 때문에 매우 큰 규모의 GEMM 연산이 되겠죠. 따라서 MLP에서도 GEMM 연산의 비중이 매우 크고, 성능의 병목이 될 수 있습니다. 

 

4. GPU 아키텍처와 GEMM 연산에서의 Tiling

Tiling은 큰 행렬을 작은 블록으로 나눠서 처리하는 최적화 기법입니다. 큰 행렬을 작은 블록(타일)으로 나누어 연산하고, 각 타일을 독립적으로 계산해서 결과를 얻는 간단한 아이디어이죠. 이 아이디어를 바탕으로 GEMM 연산의 최적화를 하려면 GPU 메모리 구조와 밀접하게 연관되어 있는데요. GPU 메모리 구조는 크게 Registers -> Shared Memory -> Global Memory로 나뉩니다. Global memory가 가장 접근이 느리고, Shared Memory는 더 빠르지만 용량에 제한이 있습니다. 따라서 GEMM 같은 대규모 연산을 최적화하기 위해서는 연산 시 Global Memory 접근 횟수를 줄이고, Shared Memory를 효율적으로 사용하여 데이터 접근 속도를 향상합니다. 특히 warp 내 thread들이 연속된 메모리를 read 할 경우 GPU에서 이를 한 번에 read 할 수 있게 되기 때문에 매우 효율적으로 동작할 수 있게 됩니다. 

 

이러한 tiling의 실제 구현을 간단하게 알아보면, 이런 형태를 갖게 됩니다. 

1. 타일 크기 결정 (보통 16x16 또는 32x32)
2. 데이터를 Global Memory에서 Shared Memory로 복사
3. Shared Memory에서 Thread 단위로 빠르게 연산
4. 결과를 다시 Global Memory에 저장

 

5. GPU의 Thread와 Warp 개념

Thread는 GPU 연산의 가장 작은 단위이고, Warp는 32개의 Thread가 모여 동시에 실행되는 GPU 연산의 최소 단위로, SIMT(Single Instruction, Multiple Thread) 방식으로 동작합니다. 즉, GPU는 개별 thread 단위가 아니라 warp 단위로 명령을 처리하는 것이죠. 따라서 warp 내 모든 thread가 동일한 명령을 실행하는 것이 효율적이고, 만약 다른 경로로 나뉘는 코드가 있으면 branch divergence (분기 발산)이 발생하여 성능이 저하됩니다. 

 

GPU에서의 thread는 CPU thread 대비 가벼운 것으로 알려져 있는데요, 이는 다음과 같은 아키텍처의 차이 때문입니다. 

GPU는 한 번에 수천~수만 개의 Thread를 병렬적으로 처리하는 SIMT 방식을 사용하는데, 여러 개의 Thread를 같은 명령어로 묶어서 실행하므로, Thread 관리가 상대적으로 간단합니다. CPU는 일반적으로 서로 다른 명령어를 개별적으로 처리하여, 하나의 Thread 관리가 복잡한 반면, GPU는 동일 명령을 처리하기 때문에 개별 Thread에 대한 관리 부담이 적습니다. 또한 GPU의 Thread는 CPU Thread와 달리 독립적인 Program Counter (PC), Register File, Stack 같은 무거운 리소스를 많이 소모하지 않습니다. 아주 제한된 레지스터와 간소한 상태 정보만 가지고 있습니다. 이에 따라 Thread를 전환할 때 CPU처럼 복잡한 상태 정보를 저장하고 복구하는 과정이 필요 없습니다

 

Warp는 warp scheduler에 의해서 실행되는데, SM(Sreaming multiprocessor) 내에 일반적으로 2~4개의 Warp Scheduler가 존재하며, 각 scheduler는 warp를 동시에 실행합니다. Warp Scheduler는 warp가 메모리 접근 등으로 멈추면 즉시 다른 warp를 실행하여 GPU 자원을 최대한 활용 방식을 사용하여 latency를 hiding 합니다. 

 

 

이상으로 GEMM 연산과, Attention, MLP에서의 GEMM 연산의 중요성, 그리고 GPU의 아키텍처를 간단하게 살펴보고 GEMM 최적화를 어떻게 하는지 알아봤습니다. GEMM 연산 최적화는 GPU 아키텍처의 특성(메모리 구조, Warp와 Thread의 실행 방식)을 잘 이해하고 활용해야 하고요, tiling, warp 최적화 등의 전략을 적절히 사용하여야 큰 성능 향상을 이룰 수 있게 되는 것입니다. 

 

728x90