Deepseek v3 code review - model
자체 개발 중인 inference engine에 MoE 기능을 넣기 위해, Deepseek v3와 Llama MoE를 분석해보고자 합니다.
여기에서는 먼저 요즘 핫한 Deepseek을 먼저 살펴 봅니다.
github repo: https://github.com/deepseek-ai/DeepSeek-V3/tree/main/inference
글이 길어져서 결론을 먼저 위에 씁니다.
결론
Llama model과 비교를 해봤을 때, Deepseek의 구조는 동일한 transformer architecture 기반으로 전체적인 골격은 유사합니다.
가장 큰 차이는 당연하게 MoE 부분인 FFN의 설계입니다. Llama는 dense MLP 구조를 사용하므로, 모든 토큰이 동일한 MLP
path를 거치게 되고요. Deepseek은 라우팅 로직을 통해, 각 토큰이 어느 expert를 사용할지 결정합니다. 즉, gata + expert + shared expert가 결합되어 MoE 구조를 형성하고, 토큰마다 실제로 활성화되는 Expert 수만큼 계산하므로, 대규모 모델에서 계산 효율과 모델 용량을 모두 확보할 수 있는 것입니다.
parallelism 관련해서는 llama에서도 모두 충분히 지원가능하고, 이미 llama model을 서빙한다면 적용했을 부분이라서 다른 점은 없다고 생각합니다. 다만 임베딩 레이어에도 parallelization을 한 부분은 약간 다른점이라고 볼 수 있을 것 같네요.
그래서 개인적으로는 Llama architecture에 MoE를 적용하면 Deepseek MoE와 다른 부분이 있을까? 하는 짧은 생각이 드네요.
High-Level Overview
deepseek v3는 고성능 LLM 프레임워크로서, 일반적인 Decoder Transformer에 다음과 같은 특장점을 결합했습니다.
- FP8 Quantization: 기본 FP32/BF16보다 훨씬 작은 비트폭으로 연산/저장을 수행해 메모리 절감과 추론 속도 향상을 노립니다.
- Tensor Parallelization: Embedding과 Linear 계층을 모델 차원으로 분할(Column/Row)하여, 여러 GPU에서 병렬 수행.
- MoE(Mixture-of-Experts): 일부 블록에서 Gate를 통해 입력 토큰마다 Top-k Expert만 활성화함으로써 모델 용량을 효과적으로 확장.
- RoPE(Rotary Position Embedding) 확장: 기존 시퀀스 길이(original_seq_len)보다 긴 입력에 대해서도 부작용 없이 로터리 임베딩을 보정해 처리.
- LoRA: Attention의 Q/K/V projection에서 lora_rank 기법을 사용해 파라미터 효율을 높임.
- 분산 추론: NCCL(또는 다른 백엔드)을 통해 분산 환경을 초기화하고, rank별로 모델 파라미터 일부를 담당.
이는 특히 긴 시퀀스, 대규모 파라미터, 저렴한 메모리 사용량이 필요한 환경에서 유리한 구조입니다.
Deepseek의 지향점이 어디인지 명확하게 알 수 있는 부분이죠.
디렉터리 및 파일 구성
여기에서는 전체 코드 중 inference directory 내의 코드만 살펴봅니다.
크게 다음과 같은 코드들로 이루어져있고, 그 외 몇 개의 코드가 있는데 생략하겠습니다.
- kernel.py: Triton 기반의 커스텀 커널 - FP8 GEMM(fp8_gemm_kernel), Activation 양자화(act_quant_kernel), Weight dequantize (weight_dequant_kernel) 등.
- model.py: Transformer 전체 구성 요소로, ModelArgs, MoE(Gate/Expert), MLA(Attention), Block/Transformer 등 구현.
- generate.py: 추론용 CLI 스크립트. 분산 환경 초기화, 모델/토크나이저 로드, 대화형(interactive) 혹은 배치(input_file) 모드에서 텍스트 생성 수행.
model.py (Transformer 및 MoE 구현)
ModelArgs class
@dataclass
class ModelArgs:
...
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
# moe
n_routed_experts: int = 64
n_shared_experts: int = 2
n_activated_experts: int = 6
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.
# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
# yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.
모델의 주요 하이퍼파라미터/설정을 한 곳에 정의하고 있습니다. 뒤에 kerne.py 분석 시 좀 더 살펴보겠지만, dtype이 fp8이면 triton kernel이 사용됩니다.
ParalleleEmbedding class
class ParallelEmbedding(nn.Module):
def __init__(self, vocab_size, dim):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
assert vocab_size % world_size == 0
self.part_vocab_size = (vocab_size // world_size)
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
def forward(self, x):
# x: (B, S), 토큰 ID
# 1) 범위 밖 토큰 mask, offset 조정
# 2) F.embedding
# 3) dist.all_reduce
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
x[mask] = 0
y = F.embedding(x, self.weight)
if world_size > 1:
y[mask] = 0
dist.all_reduce(y)
return y
return y
여기서는 먼저, vocab_size를 world_size만큼 나누어 rank별로 일부분만 보관합니다.
- self.part_vocab_size = (vocab_size // world_size)
- self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
이렇게 vocab_size를 world_size만큼 나누어 rank별로 일부분만 보관하는 가장 큰 목적은 대규모 vocab 임베딩의 메모리 사용량을 분산하기 위함입니다. 보통 LLM에서는 vocab_size가 매우 커서(수십만~수백만 단어 이상) 그에 따른 임베딩 매트릭스 역시 상당히 큰 메모리를 차지합니다. 임베딩 매트릭스가 vocab size * 모델의 차원 dim이 되기 때문에 단일 GPU에 올리면 메모리 부담이 되니까, 1/n을 하면 GPU 메모리 부담을 낮출 수 있겠죠. vocal_start_idx와 vocab_end_idx로 각 rank 별로 담당할 구간도 지정해 줬습니다.
이 방식은 tensor parrellism으로, 자주 활용되는 기법으로, 많은 분들이 익숙하실거라 생각합니다.
Forward()에서는 입력 토큰이 현재 rank 범위에 속하면 해당 rank의 로컬 파라미터로 임베딩을 계산하고, 범위에 속하지 않으면 mask를 통해 임시로 0으로 대체한 뒤 계산하고 그 결과를 all_reduce로 합쳐 최종 임베딩 벡터를 얻고 있습니다. 각각의 rank에서 자신의 파라미터로만 값을 계산하고 자신의 구간에 속하지 않는 값은 0으로 대체되고, 각 계산된 값을 all_reduce 통해 합치는 것이죠.
linear()
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
...
if weight.element_size() > 1:
return F.linear(x, weight, bias)
elif gemm_impl == "bf16":
weight = weight_dequant(weight, weight.scale)
return F.linear(x, weight, bias)
else:
x, scale = act_quant(x, block_size)
y = fp8_gemm(x, scale, weight, weight.scale)
if bias is not None:
y += bias
return y
이 linear()은 주석에 적혀 있는 것처럼, y = xA^T + b 의 선형 변환을 수행하는 간단한 함수입니다.
입력 텐서가 주로 (batch, in_features) 이고 weight가 (out_features, in_features) 형태로 사용되기 때문에, weight에 대해서 transpose 해서 output (batch, out_features)로 변환되는 것입니다.
이 함수에는 quantization 여부 및 data type에 따라 내부 동작이 분기되는 특징만 있습니다. 코드는 간단해서 깊게 설명할 부분은 없을 것 같습니다. weight.element_size()는 pytorch tensor에서 각 element가 차지하는 bytet 수를 return 하는데요. 이 값이 1 (1byte, fp8) 이상, 즉 bf16 (2byte), fpt32 (4byte) 이면, 단순하게 pytorch의 F.linear()를 사용하고 있습니다.
weight.element_size()가 1 (fp8 weight)인 경우에, gemm_impl이 bf16이라는 것은 연산은 bf16 모드로 하고 싶다는 것이므로, fp8 weight를 다시 bf16으로 복원(dequantize)하고, F.linear()를 하고 있네요. 마지막 else에서는 weight와 연산 모두 fp8로 동작하는 구간입니다.
Linear class
class Linear(nn.Module):
...
dtype = torch.bfloat16
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
else:
self.register_parameter("scale", None)
if bias:
self.bias = nn.Parameter(torch.empty(self.part_out_features))
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
...
return linear(x, self.weight, self.bias)
이 클래스는 FP8 quantization 및 scale parameter를 함께 처리할 수 있도록 설계되었습니다.
이 코드에서 보면, self.weight = nn.Parameter(torch.empty(out_features, in_features, ...) 로 가중치 텐서를 생성하는 것을 볼 수 있네요. 이 코드 역시 새로운 부분은 없는데, fp8 quantization 사용할 경우에는 블록화된 scale parameter를 만들어 등록하는 부분이 있습니다. (블록 단위 scale을 사용)
- if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
fp8은 bf16, fp32 대비 부동소수 표현에 훨씬 더 적은 bit을 사용하기 때문에, 넓은 범위의 수를 담기 어렵고 오차가 쉽게 발생합니다. 이를 보완하기 위해 scale을 함께 관리해서 실제 계산시 fp8 x scale을 수행하여 올바른 실수값을 얻는 것입니다. (dequantization)
forward()에서는 이전에 살펴 봤던 linear()를 그대로 호출하고 있습니다.
ColumnParallelLinear class
class ColumnParallelLinear(Linear):
...
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert out_features % world_size == 0
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
...
y = linear(x, self.weight, self.bias)
return y
이 클래스는 model parallelism를 구현하기 위한 linear 레이어입니다. 주로 대규모 모델을 여러 GPU로 나누어 학습할 때, 하나의 linear를 여러 rank가 각자 일부(Output features의 일부 column, column-wise)를 맡아서 병렬 계산하는 구조입니다. 코드는 간단합니다. 각 rank(process)가 담당해야 할 output(출력차원, column 수)을 계산합니다. 단순히 전체 GPU 개수 (wolrd_size)로 나눠줍니다.
- self.par_out_features = out_feature // world_size
그리고 super().__init__()을 통해 Linear(in_features, part_out_features, ...) 형태로 초기화하기 때문에, 결과적으로 각 rank가 자기만의 (in_features, part_out_features) 형태의 가중치(self.weight)를 갖게 됩니다. 즉 (in_features, out_features)를 여러 rank가 나누어 가지고 있게 되는 것입니다. 그리고 각 rank에서는 part_out_features만을 계산하게 되는 것이죠. (output[batch, part_out_features]의 형태)
따라서 최종적으로는 concat 또는 all_gather를 통해 여러 rank가 구한 결과를 이어 붙이거나, 교환해서 output을 완성해야 합니다. 여기에는 없고, 아마 caller에 해당 루틴이 있을 것 같네요.
RowParallelLinear class
class RowParallelLinear(Linear):
...
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert in_features % world_size == 0
self.part_in_features = in_features // world_size
super().__init__(self.part_in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
...
y = linear(x, self.weight)
if world_size > 1:
dist.all_reduce(y)
if self.bias is not None:
y += self.bias
return y
ColumnParallelLinear class와 개념적으로 동일한데, 반대(?) 입니다. RowParallelLinear class는 입력 차원(in_features)을 여러 프로세스(또는 GPU)로 나누어 처리하는 Row-Parallel 방식의 선형 레이어 구현입니다. 그래서 보면 __init__() 내부의 구현이 in_feature로 바뀌고 동일한 것을 알 수 있습니다.
다만 forward()는 조금 달라집니다. 단일 GPU가 아닌 여러 대의 GPU에서 수행되는 경우에는 row-parallel, 즉 전체 입력의 일부만을 각 프로세스가 보고 있는 상황이므로, 서로 다른 프로세스(rank)에서 수행된 결과를 합쳐야 최종 output이 되는 것입니다. (선형 연산의 결과 = 모든 입력 차원의 곱의 합). All reduce() 후에는 모든 rank에 동일한 최종 결과 값이 존재하게 됩니다.
다시 정리해보면, ColumnParallelLinear 연산은 전체 입력을 동일하게 각 GPU가 받아서 output의 부분을 각자 계산한 후, all gather 나 concat을 통해 최종 output을 만드는 형태이고, RowParallelLinear 연산은 각 GPU가 부분 입력을 받아, 각자 계산한 후, 이 부분 결과를 all reduce를 통해 합산 하는 것입니다.
RMSNorm class
class RMSNorm(nn.Module):
...
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor):
...
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
RMSNorm(Root Mean Square Layer Normalization)은 layernorm에 비해, 계산량이 살짝 줄고 sequence model 등에서는 성능(layernorm 대비 평균 계산이 없음)이 나아서 transformer 모델에서 주로 쓰이고 있습니다.
dim은 입력 텐서의 차원 크기를 나타내고, eps는 계산 과정에서 분모가 0이 되는 불상사(?)를 막기 위해 사용하는 보정 값입니다. (epsilon, 1e-6). RMSNorm에서도 LayerNorm과 마찬가지로, 학습 가능한 스케일 파라미터(길이가 dim인 벡터)를 둡니다.
forward()에서는 F.rms_norm()을 그대로 사용하고 있네요. Pytorch 2.0부터 지원합니다.
RMSNorm의 연산을 살짝 살펴 보면, 다음과 같습니다.
⍺는 scale parameter이고, RMS(x)를 보면, 아래와 같은데요.
각 원소의 제곱 합에 1/n을 곱함으로써, 원소 제곱의 평균을 구하고 여기에 root를 씌운 값이 RMS입니다. x가 전반적으로 얼마나 큰지를 나타내는 이 값을 사용해 x를 길이가 1에 가까운 벡터로 만든 다음 (x/RMS(x)), 다시 ⍺를 곱해 원하는 크기로 스케일링 하는 것입니다.
precomput_freqs_cis()
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
...
dim = args.qk_rope_head_dim
seqlen = args.max_seq_len
beta_fast = args.beta_fast
beta_slow = args.beta_slow
base = args.rope_theta
factor = args.rope_factor
def find_correction_dim(num_rotations, dim, base, max_seq_len):
...
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
...
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim-1)
def linear_ramp_factor(min, max, dim):
...
if min == max:
max += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
if seqlen > args.original_seq_len:
low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth
t = torch.arange(seqlen)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
이 precompute_freqs_cis()는 RoPE(Rotary Positional Embedding)의 frequency를 미리 계산하고, 이를 복소수 형태로 반환합니다. sequence 길이가 original_seq_len보다 클 때, 확장 보정 로직이 추가 되어 있고, beta_fast, beta_slow, rope_factor 등을 사용해 주파수를 부드럽게 보정(smoothing)하고 있습니다.
apply_rotary_emb()
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
...
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
y = torch.view_as_real(x * freqs_cis).flatten(3)
return y.to(dtype)
입력 텐서 (x)를 복소수 형태로 해석한 뒤 (torch.view_as_complex), 미리 계산된 복소 지수 (freqs_cis)와 곱해 회전(로테이션)을 적용하고, 다시 실수 형태로 복원(torch.view_as_real)하는 과정을 거칩니다. 이를 통해 좌표축상의 회전 변환이 일어나며, 모델이 위치 정보를 학습할 수 있게 됩니다. 중간에 x.float으로 변환하여 연산 정확도를 높인 후, 마지막에는 to(dtype)을 통해 원래 데이터 타입으로 반환하는 것을 알 수 있습니다. (fp/bf16 -> fp32 -> fp/bf16)
MLA(Multi-Headed Attention Layer) class
class MLA(nn.Module):
...
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.qk_nope_head_dim = args.qk_nope_head_dim
self.qk_rope_head_dim = args.qk_rope_head_dim
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
self.v_head_dim = args.v_head_dim
if self.q_lora_rank == 0:
self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
else:
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim ** -0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
if attn_impl == "naive":
self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
else:
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
...
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
if self.q_lora_rank == 0:
q = self.wq(x)
else:
q = self.wq_b(self.q_norm(self.wq_a(x)))
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
if attn_impl == "naive":
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(self.kv_norm(kv))
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
self.k_cache[:bsz, start_pos:end_pos] = k
self.v_cache[:bsz, start_pos:end_pos] = v
scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
else:
wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
if mask is not None:
scores += mask.unsqueeze(1)
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
if attn_impl == "naive":
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
else:
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x
여기서부터는 코드가 꽤 길어 집니다. 그 만큼 중요한 부분이라는 거겠죠? Deepseek V3 코드 중에 이 클래스 코드가 가장 깁니다.
일반적은 MLA를 확장한 형태로, 크게 특징을 살펴 보면 다음과 같습니다.
• LoRA로 Query, Key, Value를 부분적으로 분해하여 학습.
• RoPE를 위해 qk_rope_head_dim 차원에만 복소수 회전(cos, sin 변환) 적용.
• ColumnParallelLinear, RowParallelLinear를 이용해 GPU 분산. (Tensor parallel)
• Attn Impl이 "naive" vs "absorb"로 나누어져, Cache를 저장 및 갱신하는 방법이 다름.
먼저 __init__() 부터 살펴 보겠습니다.
if self.q_lora_rank == 0:
self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
else:
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
lora rank가 0이면, Q = wq(x)로 그대로 계산하고, 0이 아닌 경우에는 Q를 low rank 분해 (wq_a) 후 RMSNorm을 거쳐 다시 full rank (wq_b)로 확장하고 있습니다. 이 부분이 일반적인 attention 구현과 다른 부분으로, 일반 attention에서는 Q = xWq 처럼 직접 선형 변환을 하는 대신에, 여기 구현에서는 위에서 보는 것처럼 LoRA (Low-Rank Adaption) 기법을 사용하고 있습니다.
LoRA는 LLM에서 finetuning 시, 메모리 사용량 및 학습 비용을 크게 줄이고자 고안된 방법으로, pretrained 된 모델의 모든 파라미터를 직접 업데이트 하는 대신, 특정 가중치 (Q, K, V)에 대해만 low rank 행렬을 추가 학습하여, 파라미터 효율적으로 finetuning하는 기법입니다. 여기서 말하는 rank는 특정 행렬이 표현할 수 있는 선형 조합의 최대 개수로, W = (d x k)라는 행렬이 있을 때, rank = r 이라고 하고 W를 두 행렬 A, B로 나누게 되면 A = (d x r), B(r x k) 가 되어, 𝛥W = A x B의 실제 표현 범위가 r 차원 부분에 국한 된다는 것이죠.
그래서 W (768, 768) 인 경우에 최대 rank는 768이고, LoRA에서 rank 4 처럼 작은 값을 사용할 경우, 𝛥W 가 rank(𝛥W) <= 4 인 부분 공간에 속하게 되니까 rank가 낮을 수록 파라미터는 감소하지만, 표현력 또한 제한 된다는 것입니다. (weight 행렬 전체를 다 표현하지 못하고, 그 중 몇가지 방향성에 대해서만 학습한다는 의미). 그럼에도 불구하고 LLM에서는 LoRA를 사용해서 도메인 특화 finetuning 하면 의외로 좋은 결과들이 나와서 많은 연구를 하고 있다고 하네요.
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
kv weight 관련해서는 입력 (self.dim)을 kv_lora_rank + qk_rope_head_dim으로 투영합니다. (LoRA와 RoPE를 동시에 처리)
- K를 LoRA로 학습할 랭크 r + K가 회전을 적용할 차원을 묶어서 한 덩어리로 뽑는 과정
- 즉, wkv_a는 key, value 공통 전처리 역할을 하는 것입니다. (K의 일부 + v의 일부를 구하기 위한 중간 텐서)
위에서 살펴보았던 ColumnParallelLinear()를 통해 colume-parallel (각 GPU별로 출력 분할)을 통해 K/Q_nope + V 차원을 만듭니다. 마지막으로 RowParallelLinear()를 통해 n_heads * v_head_dim -> dim 으로 다시 축소합니다. 즉 멀티 헤드 결과를 다시 합친다는 것입니다.
MLP Class
class MLP(nn.Module):
...
def __init__(self, dim: int, inter_dim: int):
...
super().__init__()
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim)
self.w3 = ColumnParallelLinear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
...
return self.w2(F.silu(self.w1(x)) * self.w3(x))
일반적인 MLP 구현은 다음과 같습니다. 2개의 선형 함수 레이어와 비 선형 활성 함수를 갖는 것이죠. 비 선형 함수를 사용함으로써, 선형 함수들로는 달성할 수 없는 복잡한 매핑을 표현할 수 있는 것입니다.
일반적으로 transfomers FFN에서는, 입력 차원을을 중간 차원 (4배 정도)으로 확장해서, 다양한 조합/연산을 수행한 뒤, 비 선형 함수를 거치게 됩니다. 그리고 마지막에는 모델 전체가 같은 차원을 주고 받을 수 있도록 차원을 다시 줄여서 정렬합니다.
이 class에서는 일반적인 FFN과 달리 다음과 같은 특징을 갖고 있습니다.
- 세 개의 선형 레이어: w1, w2, w3
- SiLU(Sigmoid Linear Unit)
- 입력 z가 음수일 때는 시그모이드값(0 < σ < 1)이 작아서 결과가 작아지거나 거의 0에 가까워짐
- 입력 z가 양수일 때는 z에 σ(z) ≈ 1 이 곱해져 크게 유지
- 출력 계산 시 SiLU(w1(x)) x w3(x) 라는 gating 형태의 곱셉 연산
- 게이트(Gate) 개념: w3(x) 또는SiLU(w1(x)) 등의 활성 함수를 쓴 부분과 별도의 선형 변환 결과를 곱하여, 불필요한 정보는 축소 혹은 차단하고 필요한 정보만 남기는 효과를 냅니다.
- 값이 0에 가까우면(닫힘), 해당 위치(채널)의 정보가 거의 차단되고,
- 값이 1에 가까우면(열림), 해당 정보가 손실 없이 그대로 반영됩니다.
w1은 ColumnParallelLinear()를 통해 입력 차원 dim에서 중간 차원 inter_dim으로 맵핑합니다.
w2은 RowParallelLinear()를 통해 확장된 inter_dim을 다시 dim으로 줄여주고요.
w3은 ColumnParallelLinear()을 사용하고, w1과 동일합니다. SiLU연산 시 w1과 w3를 사용합니다.
forward()에서는 간단하게 아래 코드를 사용하고 있습니다.
- F.silu(self.w1(x)) * self.w3(x))
self.w1()을 통해 x를 inter_dim으로 확장하고, 여기에 활성함수 SiLU()를 적용합니다.
이렇게 함으로써 일반적인 FFN 대비, 모델이 더 중요한 정보를 골라내는 능력이 커져 궁금적으로 성능 향상과 학습 안정성을 얻게 되는 것입니다.
Gate class
class Gate(nn.Module):
...
def __init__(self, args: ModelArgs):
...
super().__init__()
self.dim = args.dim
self.topk = args.n_activated_experts
self.n_groups = args.n_expert_groups
self.topk_groups = args.n_limited_groups
self.score_func = args.score_func
self.route_scale = args.route_scale
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
...
scores = linear(x, self.weight)
if self.score_func == "softmax":
scores = scores.softmax(dim=-1, dtype=torch.float32)
else:
scores = scores.sigmoid()
original_scores = scores
if self.bias is not None:
scores = scores + self.bias
if self.n_groups > 1:
scores = scores.view(x.size(0), self.n_groups, -1)
if self.bias is None:
group_scores = scores.amax(dim=-1)
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
scores = (scores * mask.unsqueeze(-1)).flatten(1)
indices = torch.topk(scores, self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func == "sigmoid":
weights /= weights.sum(dim=-1, keepdim=True)
weights *= self.route_scale
return weights.type_as(x), indices
MoE(Mixture-of-Experts)에서 입력을 여러 expert 중 일부(Top-k)에게만 보내기 위해, 각 입력마다 “어떤 expert를 사용하면 좋을지”를 스코어링해야 합니다. 여기의 Gate 클래스는 그 스코어를 학습 가능한 weight, bias로 계산한 뒤, 원하는 방식으로(softmax or sigmoid) 정규화하고, Top-k Expert를 골라낼 뿐만 아니라, 해당 라우팅 점수(weights)도 반환합니다.
먼저 __init__()을 살펴 보겠습니다.
self.top_k = args.n_activated_experts는, 한 입력 당 선택(활성화)할 experts의 개수 입니다.
self.n_groups = args.n_expert_groups는, MoE에서의 라우팅 그룹 개수로, 1이면 모든 입력을 한 그룹으로 처리하고, 그 이상이면 특정 형태의 라우팅을 하는 것입니다.
self.topk_groups = args.n_limited_groups는 그룹 라우팅 중, 각 그룹에서 Top-k(여기선 topk_groups) 그룹만 실제 라우팅할 때 사용합니다.
self.weight는 (args.n_routed_experts, args.dim)의 형태의 gate weight 입니다.
forward()에서는, 먼저 라우팅 스코어 계산을 수행합니다.
- scores = linear(x, self.weight)
x(batch, dim) 이고 self.weight(n_routed_experts, dim) 이기 때문에 linear()를 수행하면 scores(batch, n_routed_experts)가 됩니다.
즉, 각 배치당 각 expert에 대한 score가 생기는 것입니다.
그리고 그 후 score_func 설정에 따라 softmax 또는 sigmoid로 모든 expert간 확률처럼 정규화를 수행합니다.
아래 코드는 Multi-group MoE를 다루는 케이스 입니다.
if self.n_groups > 1:
scores = scores.view(x.size(0), self.n_groups, -1)
if self.bias is None:
group_scores = scores.amax(dim=-1)
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
scores = (scores * mask.unsqueeze(-1)).flatten(1)
MoE에서 n_groups를 설정하게 되면, n_routed_expers를 해당 그룹 수 만큼 묶어서 취급합니다. score.view(x.size(0), self.n_groups, -1)을 수행하게 되면, 기존 score(batch, n_routed_groups) -> (batch, n_groups, n_routed_expers/n_groups)가 됩니다.
그리고 각 그룹 내 최대값 or 상위 2개 합 등을 사용해 그룹별 대표 점수를 산출합니다.
- group_scores = scores.amax(dim=-1)
- 이 path에서는 그룹 내 점수가 가장 큰 expert 점수를 그 group의 대표 점수로 사용하게 되고요.
- group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
- 이 path에서는 그룹 내 점수가 가장 높은 상위 2개 expert를 찾고, 그 점수의 합을 대표 점수로 사용하는 것입니다.
indices = group_scores.topk(self.topk_groups, dim=-1)[1] 에서는 배치별로 topk_groups를 선택하고요. 그 이후 마스크를 이용해 선택되지 않은 그룹의 점수는 모두 0으로 만듭니다. 즉 여러 그룹 중에 topk groups만 활성화하는 것입니다.
indices = torch.topk(scores, self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func == "sigmoid":
weights /= weights.sum(dim=-1, keepdim=True)
weights *= self.route_scale
return weights.type_as(x), indices
다음으로 각 입력 별로 상위 top-k expert index를 뽑고요.
- indices = torch.topk(scores, self.topk, dim=-1)[1]
- torch.topk()는 결과를 (values, indices) 형태의 튜플로 반환하는데, values = 상위 k개 값, indices = 그 값들이 원래 텐서의 어느 위치에 있었는지를 나타내기 때문에, 어떤 expert 또는 어느 그룹이 선택 되었는지를 알 수 있게 되는 것입니다.
게이트 weigths도 original_scores에서 해당 indices만 추출합니다. 이후 sigmoid 및 route scale을 곱해주고요.
최종적으로 각 배치에 top-k rouing score + top-k expert index 를 리턴해줍니다. 따라서 이 (weights, indices)는 다음 단계에서 “해당 입력을 어떤 Expert(s)에 얼마만큼 보내고, 어떻게 합산할지”를 결정하는 필수 정보가 됩니다.
Expert class
class Expert(nn.Module):
...
def __init__(self, dim: int, inter_dim: int):
...
super().__init__()
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
...
return self.w2(F.silu(self.w1(x)) * self.w3(x))
이 클래스 에서는 개별 expert의 레이어를 정의하고 있습니다. 굉장히 간단합니다.
내부적으로 gate + MLP 형태로 입력을 처리하고 있습니다. MLP와 동일한 형태로 세 개의 선형 레이어가 있습니다.
다 아시겠지만, MLP와 다르게 Expert는 단일 GPU 내에서 수행하는 것으로 되어 있네요. MoE에서는 여러 GPU들을 분산 시켜 놓고 routing을 하고 있기 때문에, tensor parallelism 까지 적용할 경우 복잡도 증가 및 통신 오버헤드 부담이 있어 사용하지 않는 것 같습니다.
한 번 실험해 보고 싶은데? 라는 생각을 하긴 했는데, 가만 생각해보니 좀 더 분할이 필요하고 GPU 리소스가 충분하다면 expert를 더 만들어서 분할하면 될 것 같다는 생각이 들긴 합니다. 굳이 tensor parallelism을 할 필요가 없지 않나? 싶네요.
MOE class
class MoE(nn.Module):
...
def __init__(self, args: ModelArgs):
...
super().__init__()
self.dim = args.dim
assert args.n_routed_experts % world_size == 0
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)])
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
...
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
z = self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return (y + z).view(shape)
드디어 대망의 MOE class 입니다.
__init__()을 먼저 살펴 보면, 이전에 설명 했던 내용들은 skip하고, 아래 내용부터 보겠습니다.
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList([
Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx
else None
for i in range(self.n_routed_experts)
])
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
n_local_experts는 각 rank에서 처리할 expert의 수가 되겠고요. 아래에는 이 expert들의 index를 설정해주고 있네요.
Gate() 모듈을 설정해주고, 전체 expert 목록을 업데이트 합니다. 이 rank에서 맡지 않는 expert는 None으로 처리 해주고요.
마지막에는 shared_experts를 설정해주는데, 모든 입력에 대해 동일하게 적용되는 공유 expert를 말합니다. 이건 MLP로 설정해주네요.
def forward(self, x: torch.Tensor) -> torch.Tensor:
...
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
z = self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return (y + z).view(shape)
forward()에서는 먼저 self.gate(x)를 통해서 각 토큰이 어떤 expert에 배정되었는지, 그리고 가중치는 얼마나 되는지 획득합니다.
- weights, indices = self.gate(x)
그리고 각 expert에 라우팅된 토큰이 몇 개인지 확인합니다.
- counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
- indices를 1D로 펴서(flatten()) 각 expert id가 몇 번 나오는지 셉니다.
- counts[i]가 0이면, expert i에 할당된 토큰이 없다는 의미
이 rank가 담당하는 expert에 대해서만 연산을 수행하고요. (self.experts_start_idx, self.experts_end_idx)
idx, top = torch.where(indices == i) 에서 idx, top은 배치 내 토큰 index, 몇 번째 activated_expert인지를 가르킵니다. 이 정보를 바탕으로 expert 연산을 수행하고, gate weight 만큼 곱해서 y에 더해주네요. y는 로컬 expert들의 출력을 이렇게 누적하는 텐서입니다.
- y[idx] += expert(x[idx]) * weights[idx, top, None]
그 후 shared_experts()는 모든 토큰에, 모든 rank에서 동일 연산을 수행하고요.
- z = self.shared_experts(x)
마지막으로 multi-gpu 인 경우 각 로컬 expert들의 partial 출력 y를 all_reduce()하여 모든 rank가 동일한 합산 결과를 공유하도록 합니다. 그리고 로컬 expert들의 결과 + shared expert 결과를 합쳐서 원래 shape로 복원해서 리턴하고 끝입니다.
- if world_size > 1: dist.all_reduce(y)
- return (y + z).view(shape)
다시 요약해보면, 이 MoE class에서는 MLP가 2번 일어나는 것입니다. 전용(Expert) + 공유(Shared) 구조로 동작하는 것이죠. Expert에서는 일부가 선택되서 전문적 처리를 수행하고, Shared에서모든 토큰에 대해 공통적으로 필요한 변환이나, 기본처리를 담당합니다.
Block class
class Block(nn.Module):
...
def __init__(self, layer_id: int, args: ModelArgs):
...
super().__init__()
self.attn = MLA(args)
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
self.attn_norm = RMSNorm(args.dim)
self.ffn_norm = RMSNorm(args.dim)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
...
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
x = x + self.ffn(self.ffn_norm(x))
return x
이 block 클래스는 Transformer 블록을 정의한 예시로, self-Attention((attn)) 레이어와 Feed-Forward Network((ffn)) 레이어를 차례로 적용하고, 각각을 위한 Layer Normalization((attn_norm, ffn_norm))이 포함된 구조입니다. layer_id가 n_dense_layers보다 작으면 일반 MLP, 그 외에는 MoE로 동작하게 되네요. 그리고 MLA를 사용하는 것을 알 수 있습니다. 이제 이 블록을 여러개 쌓아서 Transfomer를 구성하면 코드는 끝입니다.
Transformer class
class Transformer(nn.Module):
...
def __init__(self, args: ModelArgs):
...
global world_size, rank
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
super().__init__()
self.max_seq_len = args.max_seq_len
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim)
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
...
seqlen = tokens.size(1)
h = self.embed(tokens)
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)[:, -1]
logits = self.head(h)
if world_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
dist.all_gather(all_logits, logits)
logits = torch.cat(all_logits, dim=-1)
return logits
마지막으로, Transformer class에서는 토큰 임베딩 + 여러 개의 Transformer Block + 최종 Normalization 및 Linear Projection을 수행하고, 특히 Rotary Position Embedding, 분산 병렬화 등 다양한 요소가 함께 들어 있습니다.
forward()는 다른 transformer model들과 크게 다르지 않습니다.
임베딩을 거쳐, 마스크를 생성하고, RMSNorm -> Attention -> RMSNorm -> MLP or MoE -> Residual 연결을 레이어 만큼 수행합니다. 마지막으로 normalization을 하고, 종 시퀀스 중에서 [-1] 위치(마지막 토큰)에 대해서만 로짓을 계산하고 output projection하고, multi-gpu 환경에서는 all-gather로 logit을 모두 합쳐 리턴합니다.
kernel.py와 generate.py는 2탄에서 알아 보겠습니다. 글이 너무 길어져서..