Tech

Sarathi-Serve 상세 기술 분석

임로켓 2025. 4. 5. 17:11
728x90

전체 아키텍처 개요 및 모듈 구조

Sarathi-Serve는 대규모 언어 모델(LLM)의 온라인 추론을 위한 고성능 서빙 엔진으로서, 낮은 지연 시간과 높은 처리량을 동시에 확보하기 위해 특화된 구조로 설계되었습니다. 전체 시스템은 엔진 프로세스와 워커 프로세스로 구분되어 운영되며, 주요 구성 요소로는 스케줄러, 시퀀스 관리자, 블록 메모리 관리자, 모델 실행기, 요청 처리기(API 서버) 등이 있습니다. 각 구성 요소의 역할은 다음과 같습니다.

엔진: 중앙 제어 모듈로서, 새로운 요청을 접수하고 스케줄러를 통해 요청들의 배치 및 실행 방법을 결정합니다. 엔진은 워커들과의 통신을 담당하며, 결과를 취합하여 응답을 반환합니다.

스케줄러: 실행 대기 중인 모든 시퀀스들을 관리하며, 각 반복 단계에서 특정 요청에 대해 프리필 또는 디코드를 실행할지를 결정합니다. Sarathi-Serve의 스케줄러에는 Stall-Free Batching 알고리즘 등이 구현되어 있습니다.

시퀀스 관리자: 현재 생성 중인 시퀀스들을 추적하며, 각 시퀀스의 토큰 길이, 상태(프리필 또는 디코드 단계), 완료 여부 등의 메타데이터를 관리합니다. 스케줄러의 결정에 따라 배치 실행 시 필요한 입력 텐서 준비 또는 시퀀스 종료 처리 등을 지원합니다.

블록 메모리 관리자: KV 캐시를 GPU 메모리 상에 효율적으로 관리하는 컴포넌트로서, 여러 요청의 KV 캐시를 하나의 큰 메모리 풀에서 고정 크기 블록 단위로 할당 및 해제합니다. vLLM의 메모리 관리 기법을 계승하여, 연속적인 메모리 블록을 활용함으로써 다수 요청 처리 시 메모리 단편화를 감소시키고 접근 효율성을 향상시킵니다.

모델 실행기: 실제 Transformer 모델을 호출하여 주어진 입력에 대한 출력을 계산하는 모듈입니다. 내부에 모델 러너가 존재하여 Hugging Face Transformers 기반 모델을 로드하고, 전처리된 입력 텐서와 KV 캐시를 사용하여 forward 연산을 수행합니다. 필요에 따라 다중 GPU 분산(예: 파이프라인 병렬)도 지원합니다.

워커: 각 GPU별로 구동되는 프로세스로, BaseWorker 클래스를 기반으로 동작합니다. 각 워커는 하나의 GPU와 연결되며, 엔진으로부터 전달받은 배치 작업을 해당 GPU에서 실행하고 결과를 엔진에 반환합니다. KV 캐시 메모리를 실질적으로 관리하고 보유하는 주체 또한 워커입니다. Sarathi-Serve에서는 ZeroMQ 통신을 활용하여 엔진과 워커 간에 명령 및 데이터를 교환합니다.

Sarathi-Serve의 전체 작동 방식은 엔진, 스케줄러, 워커의 순차적인 흐름으로 구성됩니다. 요청이 접수되면 엔진은 시퀀스 객체를 생성하여 스케줄러에 등록하며, 스케줄러는 주기적인 배치 반복을 통해 현재 배치에 포함할 요청을 결정합니다. 결정된 배치 작업은 엔진을 거쳐 워커들에게 전달되고(PUB/SUB 소켓 통신 활용), 워커는 GPU를 사용하여 모델을 실행한 후 결과 토큰들을 엔진에 전송합니다. 엔진은 수신된 결과를 시퀀스 관리자와 스케줄러를 통해 반영하고, 필요한 경우 다음 배치를 연속적으로 실행합니다. 이러한 과정은 각 요청이 완수될 때까지 반복되며, 완료된 요청의 최종 생성 텍스트를 응답으로 반환합니다..

 

스케줄러: Stall-Free 스케줄링 및 배치 구성

스케줄러는 Sarathi-Serve의 주요 알고리즘을 구현하는 모듈로서, 요청들을 배치로 구성하는 정책토큰 생성의 우선순위를 결정합니다. Sarathi-Serve에는 다양한 스케줄러 구현이 존재하며, sarathi/core/scheduler 디렉터리에서 확인할 수 있습니다. 기본 클래스인 BaseScheduler를 포함하여 vLLM 호환 스케줄러(VLLMScheduler), 단순 청크 프리필 스케줄러(SimpleChunkingScheduler) 및 Sarathi-Serve의 독자적인 Stall-Free 스케줄러(SarathiScheduler) 등이 구현되어 있습니다. 구성 파일을 통해 사용자가 원하는 스케줄러를 선택할 수 있으며, 기본적으로는 논문에서 제안된 Stall-Free Scheduling 기법이 적용된 SarathiScheduler를 사용하도록 설정되어 있습니다.

 

스케줄러 작동 메커니즘

스케줄러는 요청 시퀀스들의 상태(프리필 진행 여부, 잔여 토큰 길이 등)와 시스템 자원 현황을 기반으로, 현재 반복(iteration) 단계에서 실행할 배치를 결정합니다. 엔진은 내부적으로 주기적인 루프를 통해 scheduler.schedule() 메서드를 호출하며, 이 메서드에서 이번 GPU 실행을 위한 배치 작업(SchedulerOutputs)을 생성합니다. SchedulerOutputs에는 이번 배치에 포함될 각 시퀀스와 해당 시퀀스에서 이번에 처리할 토큰 수에 대한 정보가 포함됩니다.

Sarathi-Serve의 스케줄러는 iteration 단위 배치 스케줄링을 수행합니다. 즉, 한 번의 모델 forward 실행 시 여러 요청을 동시에 처리하고, 다음 iteration에서는 잔여 작업 또는 신규 요청을 다시 배치로 구성하는 반복적 방식을 채택합니다. 각 iteration에서 스케줄러는 다음과 같은 절차를 따릅니다.

  1. 신규 요청 추가 확인: 엔진에 새로 도착하여 아직 배치에 포함되지 않은 신규 시퀀스를 파악합니다. (엔진은 add_request() 호출을 통해 스케줄러에 시퀀스를 추가합니다.)
  2. 배치 대상 선정: 활성화된 시퀀스 중 이번 iteration에서 실행할 시퀀스를 선택합니다. 시퀀스는 프리필 단계(프롬프트 입력 처리 중) 또는 디코드 단계(출력 토큰 생성 중)의 두 가지 상태를 가질 수 있습니다.
  3. 청크 프리필 적용: 프리필이 진행 중인 시퀀스의 경우, 이번 iteration에서 처리할 프롬프트 토큰 수를 결정합니다. Sarathi-Serve의 청크드 프리필(chunked-prefill) 알고리즘은 프리필 요청을 거의 동일한 연산량을 갖는 청크로 분할하여 여러 iteration에 나누어 수행합니다. 이를 위해 스케줄러는 각 시퀀스별로 _get_seq_next_num_prefill_tokens() 등을 통해 이번 배치에 투입할 프리필 토큰 수를 계산합니다. 청크 크기는 사전 설정된 목표 토큰 수 또는 GPU 부하 기준 계산에 따라 결정되며, 모든 프리필 시퀀스의 청크 합이 GPU 메모리 및 연산 예산을 초과하지 않도록 조정합니다.
  4. 디코드 시퀀스 포함: 프리필을 완료하고 디코드 단계에 있는 시퀀스는 이번 iteration에서 다음 하나의 토큰을 생성하도록 배치에 포함됩니다. Sarathi-Serve 스케줄러는 가능한 한 많은 디코드 시퀀스를 배치에 함께 묶어 GPU 활용도를 극대화하고자 합니다. (이는 Decode-maximal batching 개념과 유사합니다.)
  5. 배치 완료 및 반환: 선정된 시퀀스에 대해 이번 iteration에서 처리할 토큰 수 정보를 취합하여 SchedulerOutputs를 구성합니다. 예를 들어, “시퀀스 A: 프리필 20 토큰”, “시퀀스 B: 디코드 1 토큰”, “시퀀스 C: 디코드 1 토큰” 등의 정보가 포함될 수 있습니다. 이렇게 작성된 배치를 엔진에 반환합니다. 실행할 시퀀스가 없는 경우 SchedulerOutputs.is_empty()로 표시하여 엔진에 빈 배치임을 알립니다.

스케줄러가 배치를 결정하면, 엔진은 이 정보를 바탕으로 워커에 실제 모델 실행을 요청합니다. 스케줄러는 iteration이 완료된 후 scheduler.on_step_completed() 호출을 통해 내부 상태를 갱신합니다. 이는 이번 iteration에서 처리된 토큰 수를 각 시퀀스에서 차감하고, 완료된 시퀀스를 제거하며, 필요한 통계를 집계하는 역할을 수행합니다.

 

Stall-Free Scheduling

Stall-Free 스케줄링은 진행 중인 토큰 생성 작업을 멈추지 않고도 새로운 요청을 배치에 편입시킬 수 있는 스케줄링 방법을 의미합니다. 즉, 디코드(token 생성)를 수행 중인 배치가 있어도, 적절한 시점에 새 프리필 작업을 끼워 넣어 GPU 자원 사용의 공백 없이 추가 요청을 처리합니다 . 이를 통해 디코드 토큰 생성 동안 GPU가 놀지 않고도 새로운 프롬프트 처리를 병렬로 진행할 수 있어, 지연시간 증가 없이 처리량을 높이는 효과를 얻습니다.

기본적으로, Stall-Free 배치는 현재 디코드 중인 시퀀스들의 배치신규 프리필 청크합쳐 하나의 큰 배치로 실행하는 방식으로 구현됩니다. 예를 들어, 이전 iteration에서 디코드 중이던 시퀀스들이 이번 iteration에도 각자 1토큰씩 생성해야 하는 상황이라면, 여기에 새로 도착한 요청의 프리필 청크(예: 16토큰)를 함께 묶어 GPU에 한 번에 입력합니다. 이렇게 하면 새로운 요청의 프리필도 바로 처리되면서, 원래 디코드 배치는 중단 없이 이어지고, GPU는 더 큰 배치를 처리하여 효율이 올라갑니다. 

Sarathi-Serve의 스케줄러는 배치를 구성할 때 항상 디코드 토큰들을 최대한 포함시키고 남는 계산 여력을 프리필 청크로 채워 배치 크기를 일정 수준으로 유지합니다. 이러한 방식을 논문에서는 decode-maximal batching이라고도 부르며, Sarathi-Serve의 Stall-Free 스케줄링은 decode-maximal batching을 실현한 것이라 볼 수 있습니다.

3개의 디코드 요청이 진행 중인 상황에서 새로운 프리필 요청이 1개 추가되었다고 가정해 보겠습니다. 기존 스케줄러는 새로운 요청의 프리필 처리를 위해 기존 디코드 요청을 일시적으로 중단하거나 별도의 배치로 처리해야 했습니다. 그러나 Stall-Free 스케줄러는 신규 프리필 요청의 프롬프트를 여러 개의 청크로 분할하여 디코드 배치에 순차적으로 통합합니다. 즉, 첫 번째 반복에서는 3개의 디코드와 1개의 프리필 청크를 처리하고, 다음 반복에서는 3개의 디코드와 다른 프리필 청크를 처리하는 방식으로 진행됩니다. 이러한 방식을 통해 각 반복마다 디코드 토큰 생성을 빠짐없이 수행하면서도, 여러 번의 반복에 걸쳐 프리필을 완료할 수 있습니다. 결과적으로 어느 한쪽의 처리로 인해 다른 쪽의 처리가 중단되는 현상을 방지할 수 있습니다.

 

Sarathi-Serve의 SarathiScheduler 클래스는 이러한 로직을 구현합니다. 코드 구현을 살펴보면, 프리필이 완료되지 않은 시퀀스와 디코드 대기 시퀀스를 분리하여 관리하며, 각 반복마다 우선적으로 디코드 토큰을 모두 포함한 후 남은 여유 공간에 가능한 만큼 프리필 청크를 배치하는 형태로 구성됩니다. 또한, 프리필 청크는 항상 배치의 앞부분에서 처리되도록 배치 순서를 정렬하여 모델 내부 구현(특히 가변 시퀀스 길이에 따른 마스킹 처리)에 문제가 발생하지 않도록 주의를 기울입니다. 이와 같은 Stall-Free Scheduling을 통해 Sarathi-Serve는 프리필 우선 정책에서 발생하던 생성 중단 현상을 해결하는 동시에 디코드 위주 정책의 낮은 처리량 문제 또한 극복할 수 있습니다.

 

Chunked Prefill

Chunked-prefill은 Stall-Free 스케줄링을 가능하게 해주는 또 하나의 핵심 기법입니다. Chunked-prefill이란 프리필 요청을 동등한 연산량을 갖는 청크들로 분할하여 여러 iteration에 걸쳐 처리하는 것을 의미합니다. 예를 들어 1000토큰 길이의 프롬프트가 있다면, 이를 한 번에 처리하는 대신 200토큰씩 5번의 iteration으로 나누어 처리할 수 있습니다 (청크 크기는 모델과 GPU 성능에 따라 설정).

이러한 분할의 이점은 두 가지입니다. 첫째, 새 요청의 프리필 작업이 오래 걸려 다른 디코드 작업을 지연시키는 것을 막을 수 있습니다. 프리필을 작은 청크로 쪼개면, 각 iteration마다 일부만 처리하고 바로 디코드 작업과 번갈아가며 실행되므로 특정 요청 때문에 몇 초씩 디코드가 멈추는 현상을 예방합니다 . 둘째, 프리필 청크당 GPU 활용도를 균일하게 유지할 수 있습니다. 각 청크는 거의 일정한 계산 소요를 갖도록 크기를 조절하기 때문에, 매 iteration의 배치가 편차 없이 유사한 실행 시간을 갖습니다. 이는 특히 파이프라인 병렬(Pipeline Parallelism) 환경에서 마이크로배치 간 불균형을 완화하여 파이프라인 버블을 줄여주는 효과가 있습니다 .

Sarathi-Serve 구현에서 청크드 프리필은 주로 스케줄러와 블록 메모리 관리자 측면에서 지원됩니다. 스케줄러는 앞서 설명한 바와 같이 _get_seq_next_num_prefill_tokens() 함수를 통해 현재 시퀀스에 남아있는 프롬프트 토큰 중 이번에 처리할 양(prompt_chunk_len)을 결정합니다. 한편, 블록 메모리 관리자는 청크드 프리필에 적합하게 메모리 할당을 유연하게 지원합니다. 예를 들어, 최초 프리필 청크 실행 시 필요한 KV 캐시 블록들을 할당하고, 추가 청크 처리 시 다음 블록들을 연속적으로 사용하는 방식을 채택합니다. Sarathi-Serve에서는 vLLM의 메모리 관리 기법을 기반으로 블록 단위 할당을 수행하므로, 프리필을 청크 단위로 분할하여도 KV 캐시 메모리가 연속적인 공간에 저장됩니다. 따라서 청크 분할로 인한 메모리 단편화는 발생하지 않으며, 각 청크 처리 완료 후 이미 확보된 KV 캐시를 활용하여 후속 토큰 생성을 신속하게 진행할 수 있습니다.

 

모델 실행 흐름 및 배치 처리

Sarathi-Serve에서 모델 실행은 주로 워커 프로세스에서 이루어지지만, 엔진-스케줄러-워커가 협업하여 배치 처리 파이프라인을 완성합니다. 여기에서는 요청이 들어와 응답이 나가기까지의 전체 흐름을 단계별로 설명하며, 특히 프리필/디코드 배치 처리KV 캐시 활용 부분에 초점을 맞춥니다.

요청 수신과 시퀀스 추가

클라이언트의 요청(예: HTTP API 호출)이 접수되면, 엔진 프로세스의 API 서버(요청 처리기)가 해당 요청을 처리합니다. Sarathi-Serve는 OpenAI 호환 API 엔드포인트를 제공하도록 구현되어 있으며, sarathi/entrypoints/openai/serving_chat.py 등의 모듈에서 REST API 요청을 파싱한 후 엔진에 전달합니다. API 서버는 프롬프트 및 생성 파라미터(예: 최대 토큰 수, 온도 등)를 추출하여, 엔진 인스턴스의 add_request(prompt, sampling_params, ...) 메소드를 호출합니다.

 

엔진(BaseLLMEngine)의 add_request 구현은 아래와 같습니다(pseudo code).

def add_request(self, prompt: str, sampling_params: SamplingParams, ...):
    # 1. 프롬프트 문자열을 토크나이즈하여 token ID 리스트 얻기
    if prompt_token_ids is None:
        prompt_token_ids = self.tokenizer.tokenize(prompt)
    # 2. 새로운 시퀀스 객체 생성 (시퀀스 ID, 토큰 목록, 파라미터 등 포함)
    seq = Sequence(seq_id=..., tokens=prompt_token_ids, sampling_params=sampling_params, ...)
    # 3. 스케줄러에 시퀀스 추가
    self.scheduler.add_seq(seq)
    # 4. 메트릭 저장소에 도착 기록
    self.metrics_store.on_request_arrival(seq)
    # (실제 응답은 engine.step() 호출들을 통해 생성 완료 후 반환)

 

새로운 시퀀스 객체에는 요청 식별을 위한 seq_id, 토큰화된 프롬프트, 생성 완료 조건(예: 최대 생성 토큰 수) 등이 포함됩니다. scheduler.add_seq(seq)를 호출하면 스케줄러 내부 큐에 시퀀스를 등록하여 차기 배치 스케줄링 시 고려 대상에 포함합니다. 요청이 추가됨에 따라, 엔진은 해당 시퀀스 처리를 위한 배치 실행 루프를 개시합니다. Sarathi-Serve에서는 엔진이 별도 스레드 또는 비동기 방식으로 engine.step()을 반복 호출하여, 더 이상 처리할 시퀀스가 없을 때까지 배치를 지속적으로 수행합니다. (또는 일정 간격으로 step을 호출하여 요청 풀을 비우는 방식으로 작동합니다.)

 

배치 스케줄링 및 워커 실행

 

앞서 기술된 바와 같이, 엔진의 step() 메소드는 단일 반복 실행 단위를 수행합니다. 해당 메소드의 주요 작동 절차를 pseudo code 형식으로 나타내면 다음과 같습니다.

def step(self) -> List[RequestOutput]:
    # 1. 스케줄러에게 다음 배치 구성 지시
    scheduler_outputs = self.scheduler.schedule()   
    if scheduler_outputs.is_empty():
        return []  # 처리할 작업 없음
   
    # 2. 시퀀스 관리자에 배치 정보 전달 (메모리 예약 등)
    ignored_seqs, seq_metadata_list = self.seq_manager.on_schedule(scheduler_outputs)
   
    # 3. 워커로 배치 실행 요청 전송 (ZMQ PUB 소켓)
    step_inputs = StepInputs(scheduler_outputs, new_seqs=self._get_new_seqs())
    self.enqueue_socket.send_pyobj(step_inputs)     
   
    # 4. 워커로부터 결과 수신 (ZMQ PULL 소켓, blocking 대기)
    sampler_outputs = self.output_socket.recv_pyobj()
   
    # 5. 완료된 iteration 처리 (시퀀스/스케줄러 상태 갱신 등)
    result_outputs = self._on_step_completed(scheduler_outputs, ignored_seqs, seq_metadata_list, sampler_outputs, start_time)
   
    return result_outputs  # 새 토큰이 추가된 시퀀스들의 출력 결과 리스트

 

  1. `scheduler.schedule()` 호출을 통해 이번 반복 실행에 필요한 배치(scheduler_outputs)를 획득합니다. 예를 들어, scheduler_outputs에는 시퀀스 A(프리필 16개 토큰), 시퀀스 B(디코드 1개 토큰), 시퀀스 C(디코드 1개 토큰) 등의 정보가 포함될 수 있습니다.
  2. 획득한 정보를 `seq_manager.on_schedule`에 전달하면, 시퀀스 관리자는 해당 시퀀스들의 메타데이터를 갱신하고, 이번 배치에서 제외해야 할 시퀀스(ignored_seqs, 예를 들어 응답이 완료되어 결과만 수집하고 제외해야 하는 시퀀스)를 처리합니다. 일반적으로 대부분의 시퀀스는 무시되지 않으며, seq_metadata_list를 통해 메타 정보가 반환됩니다.
  3. 다음으로, StepInputs 객체를 생성하여 워커에게 전송합니다. StepInputs에는 배치 실행 정보(scheduler_outputs)와 신규 시퀀스 목록(new_seqs)이 포함됩니다. 여기서 신규 시퀀스란 이번 반복 실행에서 처음 프리필을 시작하는 시퀀스들을 의미합니다. 이 정보에는 해당 시퀀스들의 프롬프트 토큰 ID 배열이 함께 포함되어 워커가 즉시 첫 번째 프리필 청크를 처리할 수 있도록 합니다. 엔진은 ZeroMQ의 PUB 소켓(enqueue_socket)을 통해 이 객체를 전송하며, 모든 워커 프로세스는 각자의 SUB 소켓으로 이 메시지를 수신합니다. Sarathi-Serve에서는 모든 워커가 동일한 모델의 일부를 담당하거나, 단일 GPU 환경에서는 하나의 워커만 존재합니다. 여기서는 단일 워커 시나리오를 중심으로 설명합니다.
  4. 작업자는 전달받은 StepInputs를 처리하여 GPU에서 모델 실행을 수행합니다. BaseWorker의 _execution_loop 스레드는 SUB 소켓 메시지를 수신하면, 메시지 내의 scheduler_outputs를 분석하여 처리해야 할 시퀀스와 토큰 수를 확인합니다. 이후 ModelRunner를 통해 실제 모델 forward를 수행하며, 이때 토큰 ID 입력, 과거 KV 캐시, 그리고 생성될 출력 개수 등을 인자로 모델을 호출합니다. 작업자의 ModelRunner는 HuggingFace 모델을 래핑하여 vLLM 스타일의 KV 캐시 관리 기법을 적용한 forward 함수를 사용합니다. 예를 들어, 프리필 단계에서는 입력 토큰들을 일괄 처리하여 KV 캐시를 채우고, 디코드 단계에서는 과거 캐시로부터 다음 토큰의 로짓을 계산합니다. 이때 작업자는 WorkerSequenceManager와 BlockSpaceManager를 활용하여 필요한 KV 캐시 메모리 블록을 할당/참조하고, 입력 시퀀스들의 현재 토큰 포인터 등을 관리합니다.

모델 연산이 완료되면 작업자는 각 시퀀스에 대한 생성 결과를 정리합니다. 디코드 단계 시퀀스들은 각각 1개의 신규 생성 토큰(또는 배치에 따라 여러 토큰)을 얻게 되며, 해당 토큰 ID와 토큰 문자열 등을 SamplerOutputs에 담아 엔진으로 반환합니다. 이때 작업자는 ZeroMQ PUSH 소켓(output_socket)을 이용하여 sampler_outputs 객체를 엔진의 PULL 소켓으로 전송합니다. sampler_outputs에는 배치에 포함된 각 시퀀스에 대해 새로 생성된 토큰들과 해당 확률(또는 로짓) 등의 정보가 포함됩니다.엔진은 output_socket.recv_pyobj()를 통해 블로킹 대기 상태를 유지하다가 작업자로부터 결과 객체를 수신합니다.

5. 엔진은 _on_step_completed(...)를 호출하여 이번 iteration을 마무리합니다. 이 함수에서 수행되는 작업은 다음과 같습니다.

• seq_manager.on_step_completed(scheduler_outputs, sampler_outputs)를 호출하여 각 시퀀스의 상태를 업데이트합니다. 예를 들어, 디코드 단계 시퀀스들은 방금 생성된 토큰을 시퀀스 토큰 목록에 추가하고 현재 길이를 증가시킵니다. 프리필 단계 시퀀스 중 이번에 프롬프트 일부를 처리한 경우 프리필 진행률을 업데이트하고, 프롬프트를 전부 처리했다면 해당 시퀀스를 디코드 단계로 전환시킵니다. 또한 시퀀스가 종료 조건(예: 토큰 최대 길이 도달 또는 스트림 종료 토큰 등)에 도달했는지 확인하여 완료 플래그를 설정할 수 있습니다.

 

• self.scheduler.on_step_completed()를 호출하여 스케줄러 내부 상태를 갱신합니다. Sarathi-Serve의 스케줄러 구현에서는 iteration 종료 시 필요한 정리 작업(예: 이번에 프리필 청크를 처리한 시퀀스들의 처리량 업데이트, 현재 배치의 청크 처리량에 따라 다음 청크 크기 조정 등)을 수행합니다. 스케줄러는 완료된 시퀀스를 내부 큐에서 제거하고, 다음 iteration에 고려할 새 시퀀스들을 반영합니다.

 

• 메트릭 수집: metrics_store.on_batch_end(...)를 호출하여 이번 배치의 처리 시간, 포함된 토큰 수 등의 통계를 기록합니다. Sarathi-Serve는 각 iteration별 토큰 처리량, 대기 시간 등을 측정하여 추후 성능 분석 또는 SLA 확인에 활용합니다.

 

• seq_manager.generate_request_outputs(ignored_seqs, seq_metadata_list)를 호출하여 최종 응답 생성 결과를 취합합니다. 이 함수는 시퀀스 중 응답이 완료된 시퀀스들을 찾아 RequestOutput 형태로 변환합니다. 여기에는 완성된 텍스트, 완료 상태(완료/중단/길이 초과 등), 사용된 토큰 수 등의 정보가 포함됩니다. 완료된 시퀀스는 엔진의 응답 리스트에 추가되고, 더 이상 스케줄러에서 처리하지 않도록 조치됩니다.

engine.step()은 이러한 결과 RequestOutput 리스트를 반환합니다. 일반적으로 한 iteration에서 대부분의 시퀀스는 완료되지 않고 진행 중이므로 이 리스트는 비어 있는 경우가 많습니다. 하지만 특정 시퀀스가 마지막 토큰을 생성하여 종료됐다면 이 리스트에 해당 요청의 최종 결과가 포함됩니다.

엔진은 내부적으로 지속적으로 step()을 반복 호출하기 때문에, 곧 다음 iteration의 scheduler.schedule()이 실행되고 새로운 배치가 구성되어 처리됩니다. 이 루프는 스케줄러에 더 이상 unfinished sequence가 없을 때까지, 즉 모든 요청이 완료될 때까지 돌아갑니다. 

응답 반환 및 스트리밍

  1. 요청 접수: API 서버는 엔진에 add_request 함수를 호출하여 요청을 추가합니다.
  2. 배치 결정: 엔진은 스케줄러의 schedule() 함수를 호출하여 SchedulerOutputs을 생성하고 배치를 결정합니다.
  3. 배치 전송: 엔진은 워커에게 StepInputs을 PUB 방식으로 전송하여 배치를 전달합니다.
  4. 모델 실행: 워커는 GPU를 사용하여 ModelRunner를 통해 프리필/디코드 작업을 수행하고, KV 캐시를 활용하여 모델을 실행합니다.
  5. 결과 수신: 워커는 SamplerOutputs을 PUSH 방식으로 수신하여 엔진에 결과를 전달합니다.
  6. 상태 갱신: 엔진은 SequenceManager와 Scheduler를 업데이트하여 상태를 갱신합니다.
  7. 반복 수행: 미완료된 시퀀스가 존재할 경우, 2번 단계부터 step() 함수를 반복 수행합니다.
  8. 응답 처리: 완료된 요청에 대해서는 API 서버를 통해 클라이언트에게 응답을 반환합니다.

이러한 과정을 통해 Sarathi-Serve는 실시간으로 다수의 요청을 병렬 처리하여 GPU 활용률을 높이고, 처리량을 극대화합니다. Sarathi-Serve는 토큰 단위로 결과를 제공할 수 있는 구조를 갖추고 있으며, 스트리밍 응답 또한 지원 가능합니다. ZeroMQ PUB/SUB 패턴을 사용하여 엔진의 명령을 모든 워커가 수신하며, 싱글 모델 싱글 워커 또는 동기 병렬 처리를 염두에 두고 설계되었습니다.

 

KV 캐시 및 메모리 관리 기법

KV 캐시는 Transformer 모델의 각 Self-Attention 레이어에서 이전 토큰들의 Key/Value 행렬을 저장하며, 디코딩 시 새로운 토큰 생성 시에 이를 재활용함으로써 반복적인 연산을 최소화합니다. Sarathi-Serve는 vLLM과 유사하게, 다수의 요청이 존재하는 환경에서 KV 캐시를 효율적으로 관리하기 위해 특화된 메모리 관리 전략을 활용합니다.

 

블록 단위 메모리 관리

Sarathi-Serve는 GPU 메모리에 KV 캐시를 일괄 확보하고, 이를 세분화된 블록 단위로 나누어 각 시퀀스에 할당하는 방식을 채택합니다. 기본적인 동작 방식은 vLLM과 동일합니다. 

 

• 초기 메모리 풀 할당: 엔진 시작 시 모델 크기, GPU 메모리 용량, 설정된 gpu_memory_utilization 비율 등을 기반으로 KV 캐시에 할당될 메모리 풀 크기를 결정합니다. 예를 들어, GPU 메모리의 30%가 KV 캐시에 사용되도록 설정되었다면, 해당 범위 내에서 최대한 많은 토큰을 저장할 수 있는 크기로 풀을 생성합니다. Sarathi-Serve 엔진 초기화 과정 중 _init_cache() 단계에서 프로파일링을 통해 이 값을 산출합니다. 워커가 모델 하나를 순방향 실행하여 토큰당 필요한 메모리 등을 계산하고, 이를 바탕으로 num_gpu_blocks를 설정합니다.

 

• 블록 구성: 전체 KV 캐시 메모리 풀은 동일한 크기의 블록들로 분할됩니다. 각 블록은 모델의 한 레이어에 대한 한 토큰 분량의 KV를 저장할 수 있는 크기를 가집니다. 예를 들어, GPT-3 계열 모델의 경우 한 블록은 모든 헤드의 key/value 벡터를 담을 수 있는 크기입니다. 이러한 블록 크기 정의에 따라, 1000개의 블록이 존재한다면 이는 해당 GPU에서 최대 1000 토큰 분량의 KV를 저장할 수 있음을 의미합니다.

 

• 할당 정책: 새로운 시퀀스가 입력되면, 해당 시퀀스가 생성할 최대 토큰 수만큼 블록을 예약합니다. vLLM에서는 지연 할당 방식으로 필요 시 할당하지만, Sarathi-Serve에서는 prefill 전에 일부 블록을 확보하고 시작합니다. EngineSequenceManager는 스케줄러의 배치 결정 시 on_schedule()에서 필요한 블록을 미리 확보하도록 워커에 지시합니다. 워커의 init_cache_engine 호출 시 각 워커의 gpu_cache(실제 CUDA 메모리)와 cache_engine(관리자)을 초기화합니다.

 

• 접근 방식: 모델 순방향 실행 시, 내부 Attention 연산은 기존의 (batch, seq_len, hidden) 형태의 KV를 참조하는 대신, Sarathi-Serve 커스텀 Attention 백엔드를 통해 블록 단위로 분산 저장된 KV를 취합하여 사용합니다. Sarathi-Serve는 vLLM과 마찬가지로 HuggingFace 모델의 Attention 부분을 오버라이드하여, 토큰 아이디 대신 블록 인덱스 리스트를 통해 KV 값을 조회하는 방식을 사용합니다. 예를 들어, FlashInferAttentionWrapper 등이 이러한 역할을 수행합니다. 따라서 다수의 시퀀스 KV가 단일 거대한 텐서에 존재하더라도 필요한 부분만 인덱싱하여 접근할 수 있습니다.

• 해제 및 재사용: 한 시퀀스의 생성이 완료되거나 중단되면, 해당 시퀀스가 사용하던 블록들은 즉시 해제되어 풀에 반환됩니다. Sarathi-Serve는 vLLM과 동일하게 프리 리스트 방식을 통해 빈 블록들을 관리합니다. 이후 다른 시퀀스가 입력되면 빈 블록들 중 연속된 영역을 할당받아 사용합니다. 이러한 방식을 통해 메모리 재사용 효율성을 높이고, GPU 메모리 할당/해제 호출 빈도를 줄일 수 있습니다.

 

메모리 관리와 Stall-Free/Chunked Prefill의 관계

청크 분할 사전 입력(Chunked Prefill) 및 Stall-Free 스케줄링은 메모리 관리 측면에서 상당한 이점을 제공합니다. 사전 입력을 여러 청크로 분할함으로써, 초기 청크 처리 시에만 대량의 블록 할당이 발생하며, 후속 청크들은 이미 확보된 블록의 연속된 영역을 활용하게 됩니다. 예를 들어, 1000 토큰의 프롬프트를 5회에 걸쳐 200 토큰씩 처리하는 경우, 최초 200 토큰 처리 시 200개의 블록이 할당되고 키-값(KV) 데이터가 채워집니다. 이후 200 토큰 청크는 이미 할당된 200번부터 399번째 블록 영역에 순차적으로 채워집니다. 즉, 단일 시퀀스의 사전 입력 전체에 걸쳐 연속적인 블록 영역을 확보하여 사용하는 것입니다. Sarathi-Serve의 BlockSpaceManager는 이러한 연속 할당을 고려하여 작동하므로, 청크 분할 사전 입력으로 인한 추가적인 메모리 오버헤드는 발생하지 않습니다.

 

또한, stall free 방식으로 디코딩과 사전 입력이 혼합되어 실행되더라도, 메모리 관리자는 다수의 시퀀스의 키-값 데이터를 동시에 수용할 수 있도록 충분히 큰 풀을 확보해 뒀음으로, 병행 처리에 문제가 발생하지 않습니다. 다만, 극단적인 경우, 과도한 요청이 동시에 유입되어 예약된 블록 수를 초과하게 되면, 새로운 시퀀스를 수용하지 못하거나 (스케줄러가 대기시킴) 메모리 부족(OOM) 위험이 발생할 수 있습니다. Sarathi-Serve는 이를 방지하기 위해 scheduler.schedule() 단계에서 가용 블록 수를 고려하여 사전 입력 시퀀스를 제한합니다. 예를 들어, 현재 가용 블록이 500개인 상황에서 새로운 사전 입력 청크 600 토큰을 한 번에 처리하려는 경우, 스케줄러는 _get_seq_next_num_prefill_tokens() 함수에서 값을 500 이하로 조정하거나, 일부 디코딩 시퀀스를 이번 배치에서 제외하여 블록을 확보하는 방식으로 조절할 수 있습니다.

즉, Sarathi-Serve의 Key-Value 캐시 관리 메커니즘은 다음과 같습니다.

  1. 대규모 연속 메모리 풀을 확보하고, 이를 균일한 크기의 블록으로 분할하여 관리하는 vLLM 방식을 채택합니다. 
  2. Key-Value 캐시는 각 시퀀스에 연속된 블록을 할당하는 방식으로 관리되며, 사용이 완료된 블록은 동적으로 해제 및 재활용됩니다. 
  3. Chunked Prefill 방식에도 적합하도록 연속적인 블록 사용을 보장합니다. 
  4. 여러 시퀀스가 동시에 실행되는 상황에서도 블록 풀 내에서 효율적인 관리가 이루어집니다. 

이러한 메모리 관리 메커니즘을 통해 Sarathi-Serve는 수백 건의 동시 생성 요청에 대해서도 효율적인 캐시 운영이 가능하며, GPU 메모리의 활용도를 극대화함으로써 탁월한 성능을 발휘합니다.

 

주요 클래스 및 구현 세부 내용

Sarathi-serv 코드에서 주목할 만한 주요 클래스와 함수들을 간략히 소개합니다. 이들은 앞서 설명된 기능들을 실제로 구현하는 핵심 요소들입니다.

LLMEngine (sarathi/engine/llm_engine.py): 엔진의 진입점 클래스로, LLMEngine.from_system_config()를 통해 설정에 따라 BaseLLMEngine 또는 PipelineParallelLLMEngine 인스턴스를 생성합니다. 통상적으로 BaseLLMEngine을 사용하며, 이 클래스의 step() 및 add_request() 등은 주요 메소드입니다. 또한 엔진 초기화 시 _init_zmq_sockets()로 ZMQ 통신 채널을 설정하고, _init_workers_ray()로 Ray를 이용한 워커 프로세스 생성을 수행합니다.

BaseLLMEngine (sarathi/engine/base\_llm\_engine.py): 엔진 로직을 구현한 클래스입니다. self.scheduler, self.seq\manager, self.tokenizer, self.workers 등을 멤버로 포함하고 있습니다. __init_에서 EngineSequenceManager와 SchedulerRegistry.get(...)을 호출하여 시퀀스 관리자와 스케줄러 인스턴스를 초기화합니다. 또한 Ray를 통해 RayWorker(워커 프로세스)를 시작하고 준비 상태를 확인합니다. 주요 메소드는 다음과 같습니다.

• add_request(prompt, sampling_params, ...): 새로운 요청을 시퀀스로 추가하는 함수로, 앞서 의사 코드로 설명된 내용에 해당하며, scheduler.add_seq()를 호출하여 스케줄러에 등록합니다.

• step(): 한 번의 반복(iteration)을 실행하는 함수로, 스케줄링 -> 전송 -> 수신 -> 완료 처리 흐름을 구현합니다. 이 함수는 사라티-서브 동작의 핵심이며, 앞서 상세히 설명되었습니다.

• _on_step_completed(...): 반복 종료 처리를 담당하는 내부 함수로, 시퀀스 매니저/스케줄러 업데이트 및 RequestOutput 수집을 수행합니다.

• _run_workers(method, ...): Ray로 관리되는 멀티 워커들에게 특정 메소드를 원격 호출하는 헬퍼 함수입니다. 예를 들어, _run_workers("init_cache_engine", ...)으로 모든 워커의 init_cache_engine을 실행하거나, _run_workers("get_model_parallel_ranks")로 각 워커의 모델 병렬 랭크 정보를 가져옵니다. Ray의 execute_method.remote를 이용하여 호출하고 결과를 수집합니다.

 

EngineSequenceManager / WorkerSequenceManager (sarathi/core/sequence_manager): 시퀀스 관리자의 엔진 측과 워커 측 구현입니다. 엔진 측은 주로 시퀀스 메타데이터를 관리하고, 워커 측은 실제 시퀀스의 토큰 및 캐시 상태를 관리합니다. EngineSequenceManager.on_schedule(scheduler_outputs)는 엔진에서 배치 결정 후 호출되어 이번 배치에서 처리될 시퀀스들의 메타데이터 리스트를 반환하고 완료된 시퀀스(ignored)를 식별합니다. 한편 워커에서는 WorkerSequenceManager가 BaseWorker 내에서 사용되어 워커가 새 토큰을 생성할 때 해당 시퀀스의 캐시 인덱스, 현재 토큰 위치 등을 관리합니다.

Scheduler 구현들 (sarathi/core/scheduler): 다양한 스케줄러 클래스가 존재합니다.

  • VLLMScheduler: vLLM의 기본 continuous batching 스케줄링을 구현합니다.
  • SimpleChunkingScheduler: 프리필을 chunk로 분할하는 단순 정책을 채택합니다.
  • SarathiScheduler: Stall-Free 및 Chunked Prefill을 완전하게 구현하며, 내부 함수인 _schedule_prefill_and_decode() 등에서 running_prefills와 running_decodes 리스트를 관리하고 배치 토큰 수를 조절합니다.
  • 각 스케줄러 클래스는 BaseScheduler를 상속하며, add_seq(), schedule(), on_step_completed() 등의 공통 인터페이스를 오버라이드합니다.
  • SchedulerRegistry: 스케줄러 종류 문자열(예: “sarathi”, “vllm”)을 클래스에 매핑하는 팩토리로, 엔진 초기화 시 활용됩니다.

BlockSpaceManager 구현들 (sarathi/core/block_space_manager):

  • VLLMBlockSpaceManager: 내부에 거대한 torch.Tensor를 할당하고 블록 단위로 관리합니다. allocate_blocks(num_blocks)로 연속 블록 영역 할당, free_blocks(start, count)로 해제 등의 메소드를 제공합니다.
  • SarathiBlockSpaceManager / SimpleChunkingBlockSpaceManager: 현재 코드상 내용은 pass로, 별도의 추가 동작 없이 기본 VLLM 동작을 그대로 사용합니다. 이는 Sarathi-Serve의 메모리 관리가 vLLM과 동일함을 의미합니다.
  • OrcaBlockSpaceManager, FasterTransformerBlockSpaceManager: 실험적 구현으로, Orca 또는 NVIDIA FasterTransformer 통합 시의 메모리 관리 기법(논문에 언급됨)을 포함하나, 핵심 아이디어는 VLLM 방식과 유사합니다.

ModelRunner (sarathi/model_executor/model_runner.py): 모델 실행을 추상화한 클래스입니다. 초기화 시 ModelLoader를 통해 HuggingFace 모델을 로드하고, 해당 모델을 Sarathi-Serve의 커스텀 모듈(특히 Attention 부분)로 패치합니다. ModelRunner.execute_step(step_inputs) 등의 메소드는 StepInputs를 받아 실제 모델 forward를 호출하고 SamplerOutputs를 반환합니다. 내부적으로 prefill인지 decode인지에 따라 다른 함수를 호출하며, KV 캐시 텐서의 인덱싱을 처리합니다. 

  • 프리필: model(..., use_cache=True)를 호출하여 입력 토큰 전체를 통과시키고, 출력으로 나온 past_key_values를 Sarathi의 KV 캐시 공간에 복사 저장합니다.
  • 디코드: 과거 캐시를 불러와 model(input_ids=last_token, past_key_values=..., use_cache=True)를 호출하고, 새로운 past_key_values 한 스텝분을 업데이트 및 로짓을 반환합니다.
  • 이러한 동작을 위해 transformers_utils의 configure_model_for_sarathi 류의 함수가 호출되어 모델의 forward signature를 변경하거나, past_key_values 대신 자체 cache 구조를 사용하도록 설정합니다.

RayWorker & BaseWorker (sarathi/engine/ray_utils.py, sarathi/worker/base_worker.py):

  • RayWorker는 Ray actor를 사용하여 BaseWorker를 원격으로 제어합니다. BaseWorker는 실제 작업을 실행하며, 초기화 후 _execution_loop를 통해 StepInputs를 받아 ModelRunner를 호출하고 결과를 전송합니다. 또한, 캐시 엔진 초기화, GPU 병렬 랭크 정보 제공, 모델 가중치 로드 및 동기화 등 엔진과 통신하기 위한 다양한 메소드를 포함합니다.

끝.




728x90