Welcome to Mini-YAIE

Mini-YAIE (Yet Another Inference Engine) is an educational project designed to demystify modern Large Language Model (LLM) inference engines.

Driven by the need for efficiency, modern engines like SGLang, vLLM, and TensorRT-LLM use sophisticated techniques to maximize GPU throughput and minimize latency. Mini-YAIE provides a simplified, clean implementation of these concepts, focusing on:

  • Continuous Batching
  • Paged KV Caching
  • Radix Attention (Prefix Sharing)

How to use this guide

This documentation is structured to take you from high-level concepts to low-level implementation.

  1. Core Concepts: Start here to understand the why and what of inference optimization.
  2. Architecture: Understand how the system components fit together.
  3. Implementation Guides: Step-by-step guides to implementing the missing “kernels” in Python and CUDA.

Your Mission

The codebase contains placeholders (NotImplementedError) for critical components. Your goal is to implement these components following this guide, turning Mini-YAIE from a skeleton into a fully functional inference engine.

Prerequisites

To successfully implement the kernels in Mini-YAIE, you should be familiar with:

Programming Languages

  • Python (Intermediate): Understanding of classes, inheritance, type hinting, and PyTorch tensors.
  • C++ (Basic): For reading and writing the CUDA kernels (though much of the boilerplate is provided).
  • CUDA (Basic): Understanding of the GPU execution model (blocks, threads, shared memory).

Machine Learning Concepts

  • Transformer Architecture: Queries, Keys, Values, Attention mechanism.
  • Tensors: Shapes, dimensions, matrix multiplication.

Tools

  • Git: For version control.
  • Linux/Unix Shell: For running commands.

Environment Setup

1. Clone the Repository

git clone https://github.com/yourusername/Mini-YAIE.git
cd Mini-YAIE

2. Python Environment

It is highly recommended to use a virtual environment.

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
pip install -e .

3. CUDA Requirements (Optional)

To build and run the CUDA kernels, you need:

  • NVIDIA GPU (Compute Capability 7.0+)
  • CUDA Toolkit 11.8+
  • PyTorch with CUDA support

If you do not have a GPU, you can still implement the Python logic and the CPU fallback kernels.

4. Documentation Setup

To serve this documentation locally:

  1. Install mdbook:

    # If you have Rust/Cargo installed:
    cargo install mdbook
    
    # Or download the binary from their GitHub releases.
    
  2. Serve the docs:

    mdbook serve docs
    

    Navigate to http://localhost:3000 in your browser.

Model Loading

YAIE supports loading models from HuggingFace Hub with automatic caching and local model support.

ModelLoader Class

The ModelLoader class in src/models/loader.py handles all model and tokenizer loading operations.

Initialization

from src.models.loader import ModelLoader

# Load from HuggingFace Hub
loader = ModelLoader("microsoft/DialoGPT-medium")

# Load from local path
loader = ModelLoader("/path/to/local/model")

Loading Models

# Load the model
model = loader.load_model()

# Load the tokenizer
tokenizer = loader.load_tokenizer()

Supported Model Sources

HuggingFace Hub Models

YAIE can load any compatible model from HuggingFace Hub:

# Popular conversational models
loader = ModelLoader("microsoft/DialoGPT-medium")
loader = ModelLoader("microsoft/DialoGPT-large")

# Code generation models
loader = ModelLoader("Salesforce/codegen-350M-mono")

# General purpose models
loader = ModelLoader("gpt2")
loader = ModelLoader("gpt2-medium")

Local Models

You can also load models from local directories:

# Load from local path
loader = ModelLoader("./models/my-custom-model")

Caching Behavior

Automatic Caching

Models are automatically cached in the standard HuggingFace cache directory:

  • Linux/macOS: ~/.cache/huggingface/
  • Windows: C:\Users\<username>\.cache\huggingface\

Cache Structure

~/.cache/huggingface/
├── hub/
│   └── models--microsoft--DialoGPT-medium/
│       ├── blobs/
│       ├── refs/
│       └── snapshots/
│           └── abc123.../
│               ├── config.json
│               ├── pytorch_model.bin
│               └── tokenizer.json
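The folder names above follow a simple convention: the `/` in the repository ID becomes `--`, and the result is prefixed with the repository type (`models--`). The sketch below assumes the default layout and resolves the location from `HF_HOME`. It ignores other overrides such as `HF_HUB_CACHE` and `XDG_CACHE_HOME`. `hub_cache_dir` and `model_cache_folder` are hypothetical helpers, not `huggingface_hub` functions:

```python
import os
from pathlib import Path


def hub_cache_dir() -> Path:
    """Resolve the Hub cache directory from HF_HOME (simplified: other overrides are ignored)."""
    hf_home = os.environ.get("HF_HOME", os.path.join("~", ".cache", "huggingface"))
    return Path(hf_home).expanduser() / "hub"


def model_cache_folder(repo_id: str) -> str:
    """Map a model repo ID such as 'org/name' to its cache folder name."""
    return "models--" + repo_id.replace("/", "--")


# Where a cached DialoGPT-medium download would live on this machine
print(hub_cache_dir() / model_cache_folder("microsoft/DialoGPT-medium"))
```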

Cache Management

The loader automatically:

  1. Checks for existing models in cache
  2. Downloads missing models from HuggingFace Hub
  3. Uses cached models for subsequent loads

Model Configuration

Data Types

Models are loaded with optimized data types:

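import torch
from transformers import AutoModelForCausalLM

# Models are loaded with float16 by default for efficiency
model = AutoModelForCausalLM.from_pretrained(
    model_path,                 # Hub model ID or local directory
    torch_dtype=torch.float16,  # Half precision
    device_map="auto"           # Automatic device placement (requires the accelerate package)
)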

Device Placement

  • Single GPU: Model is loaded directly to GPU
  • Multi-GPU: Automatically distributed across available GPUs
  • CPU: Falls back to CPU if no GPU available

Tokenizer Configuration

Automatic Pad Token

The loader ensures tokenizers have a pad token:

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

This is important for batch processing where sequences need to be padded to the same length.
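The snippet below shows why the pad token matters. It right-pads a batch of token-ID lists to a common length and builds the matching attention mask (1 = real token, 0 = padding). `pad_batch` is a hypothetical helper, and the token IDs in the usage line are arbitrary. In practice, calling a HuggingFace tokenizer with `padding=True` does this for you:

```python
from typing import List, Tuple


def pad_batch(seqs: List[List[int]], pad_id: int, side: str = "right") -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence to the longest one and return (input_ids, attention_mask)."""
    max_len = max(len(s) for s in seqs)
    input_ids, attention_mask = [], []
    for s in seqs:
        pad = [pad_id] * (max_len - len(s))
        mask_pad = [0] * len(pad)
        if side == "left":
            input_ids.append(pad + s)
            attention_mask.append(mask_pad + [1] * len(s))
        else:
            input_ids.append(s + pad)
            attention_mask.append([1] * len(s) + mask_pad)
    return input_ids, attention_mask


# GPT-2's EOS id is 50256; reusing it as the pad id mirrors the fallback above
ids, mask = pad_batch([[15496, 995], [15496]], pad_id=50256)
```

Decoder-only models are usually left-padded for generation, so that each prompt ends right where new tokens begin. Passing `side="left"` produces that variant.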

Memory Optimization

Lazy Loading

Models are loaded on-demand, not at import time:

# Model is not loaded here
loader = ModelLoader("gpt2")

# Model is loaded here when requested
model = loader.load_model()
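One way to get this on-demand behavior is to store only the model name at construction time and call the loading function the first time `load_model()` runs. In the hypothetical sketch below, `LazyModelLoader` takes the loading function as a `factory` argument, so it runs without downloading anything. In real use the factory could be `AutoModelForCausalLM.from_pretrained`. Memoizing the loaded model is an assumption of this sketch, not a documented property of YAIE's `ModelLoader`:

```python
from typing import Any, Callable, List, Optional


class LazyModelLoader:
    """Defers the expensive load until load_model() is first called (hypothetical sketch)."""

    def __init__(self, model_name: str, factory: Callable[[str], Any]):
        self.model_name = model_name
        self._factory = factory            # e.g. AutoModelForCausalLM.from_pretrained
        self._model: Optional[Any] = None  # nothing is loaded yet

    def load_model(self) -> Any:
        if self._model is None:            # first call: actually load
            self._model = self._factory(self.model_name)
        return self._model                 # later calls reuse the same object


calls: List[str] = []


def fake_factory(name: str) -> dict:
    """Stand-in for a real loader that records each call."""
    calls.append(name)
    return {"name": name}


loader = LazyModelLoader("gpt2", factory=fake_factory)  # no load happens here
model = loader.load_model()                             # the factory runs here
```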

Memory Mapping

Large models use memory mapping to reduce RAM usage during loading.

Error Handling

Network Issues

If downloading fails, the loader will retry and provide clear error messages.

Incompatible Models

Models must be compatible with AutoModelForCausalLM. Incompatible models will raise clear errors.

Disk Space

Large models require significant disk space. The loader shows download progress and estimated sizes.

Performance Tips

Pre-download Models

For production deployments, pre-download models:

# This will cache the model
python -c "from src.models.loader import ModelLoader; loader = ModelLoader('microsoft/DialoGPT-medium'); loader.load_model()"

Cache Location

You can customize the cache location by setting the HF_HOME environment variable:

export HF_HOME=/path/to/custom/cache

Model Selection

Choose appropriate model sizes for your hardware:

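  • Small models (~120M parameters, about 0.25GB in float16): gpt2, DialoGPT-small
  • Medium models (~350M parameters, about 0.7GB in float16): gpt2-medium, DialoGPT-medium
  • Large models (~770M parameters, about 1.5GB in float16): gpt2-large, DialoGPT-large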
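Weight memory follows from simple arithmetic: parameter count × bytes per parameter. This rough sketch ignores activations, the KV cache, and framework overhead. The parameter count in the example is approximate:

```python
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_bytes(num_params: int, dtype: str = "float16") -> int:
    """Memory needed for the weights alone."""
    return num_params * BYTES_PER_PARAM[dtype]


def gib(num_bytes: int) -> float:
    return num_bytes / 1024**3


# GPT-2 small has roughly 124M parameters
print(f"float32: {gib(weight_bytes(124_000_000, 'float32')):.2f} GiB")
print(f"float16: {gib(weight_bytes(124_000_000, 'float16')):.2f} GiB")
```

Halving the bytes per parameter is why the loader's float16 default roughly halves GPU memory use compared with float32.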

Troubleshooting

Common Issues

“Model not found” errors:

  • Check model name spelling
  • Verify model exists on HuggingFace Hub
  • Ensure internet connection for downloads

Out of memory errors:

  • Try smaller models
  • Reduce batch sizes in configuration
  • Use CPU-only mode if GPU memory is insufficient

Tokenizer issues:

  • Some models may require special token handling
  • Check the model’s documentation on HuggingFace Hub

LLM Inference: The Basics

Large Language Model (LLM) inference is the process of generating text from a trained model. It consists of two distinct phases.

1. Prefill Phase (The “Prompt”)

  • Input: The user’s prompt (e.g., “Write a poem about cats”).
  • Operation: The model processes all input tokens in parallel.
  • Output: The KV (Key-Value) cache for the prompt and the first generated token.
  • Characteristic: Compute-bound. We maximize parallelism here.

The Process Visualized

sequenceDiagram
    participant U as User
    participant E as Engine
    participant M as Model

    rect rgb(200, 220, 255)
    note right of U: Prefill Phase (Parallel)
    U->>E: Prompt: "A B C"
    E->>M: Forward(["A", "B", "C"])
    M-->>E: KV Cache + Logits(C)
    end

    rect rgb(220, 255, 200)
    note right of U: Decode Phase (Serial)
    loop Until EOS
        E->>M: Forward([Last Token])
        M-->>E: Update KV + Logits
        E->>E: Sample Next Token
    end
    end
    E->>U: Response

2. Decode Phase (The “Generation”)

  • Input: The previously generated token.
  • Operation: The model generates one token at a time, autoregressively.
  • Output: The next token and an updated KV cache.
  • Characteristic: Memory-bound. We are limited by how fast we can move weights and KV cache from memory to the compute units.

The KV Cache

State management is crucial. Instead of recomputing the Key and Value projections for every previous token at each step, we compute them once and cache them. Each new token then only needs its own Key and Value computed, and its attention reads the rest from the cache. This is the KV Cache. Managing this cache efficiently is the main challenge of high-performance inference engines.
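A back-of-the-envelope count shows what the cache saves. The sketch counts how many per-token Key/Value projections a single layer computes. It assumes that generating N tokens takes one prefill pass (which also emits the first token) followed by N - 1 decode steps. `count_kv_projections` is a hypothetical helper:

```python
def count_kv_projections(prompt_len: int, num_new_tokens: int, use_cache: bool) -> int:
    """Per-token K/V projections one layer computes to generate num_new_tokens tokens."""
    total = prompt_len          # prefill: K/V for every prompt token, computed in parallel
    seq_len = prompt_len
    for _ in range(num_new_tokens - 1):
        seq_len += 1            # the last generated token becomes the next input
        # With a cache only the new token is projected; without one the whole sequence is
        total += 1 if use_cache else seq_len
    return total


# The "A B C" prompt from the diagram, generating four tokens
print(count_kv_projections(3, 4, use_cache=True))   # 3 + 1 + 1 + 1 = 6
print(count_kv_projections(3, 4, use_cache=False))  # 3 + 4 + 5 + 6 = 18
```

With the cache, decode cost per step stays constant. Without it, cost grows with sequence length, so total work grows quadratically.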

Continuous Batching

The Problem: Static Batching

In traditional deep learning (like training), we use static batches: all sequences in a batch must have the same length (padded to the max length).

  • Waste: Padding wastes computation and memory.
  • Latency: We must wait for the longest sequence to finish generating before finishing the batch.

Visualizing the Difference

gantt
    title Static Batching (Inefficient)
    dateFormat YYYY-MM-DD
    axisFormat %H:%M

    section Batch 1
    Req A (Short) :done, a1, 2024-01-01, 2d
    Padding       :crit, 2024-01-03, 2d
    Req B (Long)  :active, b1, 2024-01-01, 4d

    section Batch 2
    Req C :c1, 2024-01-05, 2d

gantt
    title Continuous Batching (Efficient)
    dateFormat YYYY-MM-DD
    axisFormat %H:%M

    section GPU Stream
    Req A (Short) :done, a1, 2024-01-01, 2d
    Req C (New!)  :active, c1, 2024-01-03, 2d

    section GPU Stream 2
    Req B (Long)  :active, b1, 2024-01-01, 4d

The Solution: Continuous Batching (Orca)

Introduced by the Orca paper, Continuous Batching (or Iteration-level Batching) schedules work one iteration at a time rather than one whole batch at a time, so the set of requests in the batch can change between iterations.

  1. Iteration Level: The engine runs one iteration (one forward pass) at a time.
  2. Dynamic Insertion: As soon as a request finishes, it enters the “Completed” state. A new request from the queue can immediately take its place in the next iteration.
  3. No Padding: We process only the valid tokens for each request.

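This significantly improves throughput (requests per second). Short requests also no longer wait behind long ones, which cuts queueing delay. The time per token for an individual request stays roughly comparable, although a larger running batch can make each iteration somewhat slower.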
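The difference between the two Gantt charts can be reproduced by counting forward passes. In this toy simulation, each request needs a fixed number of decode iterations. Static batching runs each batch until its longest member finishes. Continuous batching refills a freed slot at the next iteration. Both function names are hypothetical, and real engines also budget for memory and prefill cost:

```python
from collections import deque
from typing import List


def static_batching_steps(lengths: List[int], batch_size: int) -> int:
    """Each fixed batch runs until its longest request finishes."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))


def continuous_batching_steps(lengths: List[int], slots: int) -> int:
    """A finished request's slot is refilled from the queue at the next iteration."""
    queue = deque(lengths)
    running: List[int] = []   # remaining decode steps for each active request
    steps = 0
    while queue or running:
        while queue and len(running) < slots:
            running.append(queue.popleft())          # admit new requests into free slots
        steps += 1                                   # one forward pass for the whole batch
        running = [r - 1 for r in running if r > 1]  # drop requests that just finished
    return steps


# Req A = 2 steps, Req B = 4 steps, Req C = 2 steps, as in the charts above
print(static_batching_steps([2, 4, 2], batch_size=2))  # 6
print(continuous_batching_steps([2, 4, 2], slots=2))   # 4
```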

Radix Attention (SGLang)

Radix Attention is the core innovation of SGLang. It optimizes the Prefill Phase by reusing computation from previous requests.

The Intuition

If two users ask:

  1. “Write a Python script to scrape a website.”
  2. “Write a Python script to sort a list.”

They share the prefix “Write a Python script to”. In a standard engine, we would compute the KV cache for this prefix twice.

graph TD
    classDef shared fill:#aaffaa,stroke:#333,stroke-width:2px;
    classDef unique fill:#ffaaaa,stroke:#333,stroke-width:2px;

    Root((Root)) --> Node1["Write a Python script to"]:::shared
    Node1 --> Node2["scrape a website"]:::unique
    Node1 --> Node3["sort a list"]:::unique

    style Node1 fill:#aaffaa

The Radix Tree

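SGLang maintains a Radix Tree (a compressed trie) over all token sequences currently held in the KV cache.

  • Nodes: Each node stores a run of tokens along with a reference to their KV cache entries. The path from the root to a node spells out a cached prefix.
  • Edges: Branch points where sequences that share a prefix diverge into different continuations.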

When a new request arrives, we map its prompt to the longest matching path in the Radix Tree.

  • Hit: We reuse the KV Cache for the matched part. The prefill only needs to compute the new suffix.
  • Miss: We compute from scratch.
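The hit/miss logic amounts to a longest-prefix lookup. This minimal sketch uses a token-level trie and whitespace-split words as stand-in tokens. A real radix tree merges single-child chains into one edge and stores KV-cache indices at each node, both of which are omitted here. `PrefixCache` and `match_prefix` are hypothetical names, not SGLang's API:

```python
from typing import Dict, List


class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[str, "TrieNode"] = {}


class PrefixCache:
    """Token-level trie; a real radix tree compresses chains of single children."""

    def __init__(self) -> None:
        self.root = TrieNode()

    def insert(self, tokens: List[str]) -> None:
        node = self.root
        for tok in tokens:
            if tok not in node.children:
                node.children[tok] = TrieNode()
            node = node.children[tok]

    def match_prefix(self, tokens: List[str]) -> int:
        """Return how many leading tokens are already cached."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node.children:
                break
            node = node.children[tok]
            matched += 1
        return matched


cache = PrefixCache()
cache.insert("Write a Python script to scrape a website".split())
prompt = "Write a Python script to sort a list".split()
hit = cache.match_prefix(prompt)
print(f"reuse KV for {hit} tokens, prefill {len(prompt) - hit} new tokens")
```

For the two prompts above, the second request reuses the five cached tokens and prefills only the remaining three.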

Benefits

  • Reduced Latency: “Time To First Token” (TTFT) drops sharply for cached prefixes, since only the uncached suffix has to be prefilled.
  • Higher Throughput: Less computation required per request.
  • Complex Workflows: Enables efficient multi-turn chat, few-shot learning, and tree-of-thought prompting.

Paged Attention (vLLM)

Paged Attention is the core innovation of vLLM. It optimizes the Decode Phase by managing memory like an Operating System.

The Problem: Memory Fragmentation

Before vLLM, engines allocated contiguous memory for the maximum possible length of a request.

  • Internal Fragmentation: If a request was shorter than max length, memory was wasted.
  • External Fragmentation: We couldn’t fit a new request even if total free memory was sufficient, because no single contiguous block was large enough.

The Solution: Paging

Inspired by virtual memory in operating systems:

  1. Blocks: Divide the KV Cache into fixed-size blocks (e.g., 16 tokens per block).
  2. Non-Contiguous: Blocks can be stored anywhere in physical GPU memory.
  3. Mapping: A per-request “Block Table” maps each logical block of the sequence to the physical block that holds it.

graph LR
    subgraph Logical[Logical Sequence Request]
        L0[Block 0: "Hello"]
        L1[Block 1: "World"]
        L2[Block 2: "!"]
    end

    subgraph Table[Block Table]
        T0[0 -> 7]
        T1[1 -> 2]
        T2[2 -> 9]
    end

    subgraph Physical[GPU Memory Physical Blocks]
        B0[Block 0]
        B1[Block 1]
        B2[Block 2: "World"]:::used
        B3["..."]
        B7[Block 7: "Hello"]:::used
        B8["..."]
        B9[Block 9: "!"]:::used
    end

    L0 --> T0 --> B7
    L1 --> T1 --> B2
    L2 --> T2 --> B9

    classDef used fill:#aaffaa;

The Kernel

The Paged Attention kernel lets the attention computation read keys and values from these non-contiguous blocks on the fly. Because a request only ever holds whole blocks, its wasted memory is limited to the unused slots in its last block.
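The translation a paged kernel performs for every token is plain integer arithmetic: divide the token position by the block size to get the logical block, then look that block up in the block table. The sketch reuses the 16-token blocks and the `[7, 2, 9]` table from the diagram. The helper names are hypothetical:

```python
from typing import List, Tuple

BLOCK_SIZE = 16  # tokens per block, as in the example above


def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    return -(-num_tokens // block_size)  # ceiling division


def translate(block_table: List[int], token_pos: int, block_size: int = BLOCK_SIZE) -> Tuple[int, int]:
    """Map a logical token position to (physical block, offset within that block)."""
    logical_block, offset = divmod(token_pos, block_size)
    return block_table[logical_block], offset


block_table = [7, 2, 9]            # logical block i lives in physical block block_table[i]
print(translate(block_table, 20))  # token 20 -> logical block 1 -> physical block 2, offset 4

seq_len = 40
waste = blocks_needed(seq_len) * BLOCK_SIZE - seq_len  # only the last block has empty slots
print(waste)                       # 3 blocks hold 48 slots, 8 unused
```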

System Architecture Overview

Introduction

Mini-YAIE (Yet Another Inference Engine) is an educational implementation of modern LLM inference techniques, specifically designed to demonstrate concepts from state-of-the-art systems like SGLang, vLLM, and TensorRT-LLM. The architecture focuses on three core optimizations:

  1. Continuous Batching: Dynamically batching incoming requests to maximize GPU utilization
  2. Radix Attention: Efficient attention mechanism with prefix sharing for similar requests
  3. Paged KV-Cache: Memory-efficient key-value cache management

High-Level Architecture

┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│   API Layer     │    │  Engine Core    │    │  Model/Kernels  │
│  (FastAPI)      │◄──►│  (Scheduler,    │◄──►│  (PyTorch/      │
│                 │    │   Attention)    │    │   CUDA)         │
└─────────────────┘    └─────────────────┘    └─────────────────┘
         ▲                      ▲                      ▲
         │                      │                      │
┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│   CLI Layer     │    │  Model Loading  │    │  Memory Mgmt    │
│  (yaie serve/   │    │  (HuggingFace   │    │  (Paged Cache)  │
│   yaie chat)    │    │   Integration)  │    │                 │
└─────────────────┘    └─────────────────┘    └─────────────────┘

Core Components

1. Main Inference Engine (engine.py)

The main inference engine orchestrates all components and provides the high-level API for inference. It implements SGLang-style continuous batching with radix attention and prefix sharing.

Key Responsibilities:

  • Request orchestration and management
  • Integration between scheduler, attention mechanisms, and memory management
  • API layer communication
  • Model loading and tokenizer management

2. SGLang Scheduler (core/sglang_scheduler.py)

The SGLang-style scheduler implements advanced request scheduling with:

  • Prefix-based request grouping: Groups requests with common prefixes for computation sharing
  • Separate prefill and decode scheduling: Optimizes for the different computational patterns
  • Memory-aware batch sizing: Considers available KV-cache memory when scheduling
  • Continuous batching optimization: Maintains high GPU utilization

3. Radix Attention System (kernels/radix_attention.py)

Implements the radix attention mechanism with:

  • Prefix sharing: Reduces redundant computation for requests with common prefixes
  • Paged KV-cache integration: Efficient memory management for variable-length requests
  • RoPE (Rotary Position Embeddings): Supports position-aware attention

4. Paged KV-Cache Management (kernels/kv_cache.py)

Efficient memory management using page-based allocation:

  • Fixed-size blocks: Reduces memory fragmentation
  • Request-to-block mapping: Tracks which blocks belong to which requests
  • Dynamic allocation/deallocation: Manages memory based on request lifecycle

5. Radix Tree System (kernels/radix_tree.py)

Enables efficient prefix matching and computation sharing:

  • Trie-based structure: Organizes token sequences hierarchically
  • Request grouping: Identifies requests with shared prefixes
  • Computation optimization: Provides information for scheduler optimization

6. Sampling Kernel (kernels/sampling.py)

Implements core sampling algorithms:

  • Temperature scaling: Controls randomness in generation
  • Top-K sampling: Limits selection to top K most probable tokens
  • Top-P (Nucleus) sampling: Limits selection to the smallest set of most-probable tokens whose cumulative probability reaches P
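The three steps compose in order on a single logits vector. The following is a pure-Python sketch, not the project's `SamplingKernel`, which presumably works on batched tensors. It renormalizes after top-k before applying top-p. Libraries differ on such ordering details, so treat that choice as an assumption. `filter_candidates` and `sample` are hypothetical names:

```python
import math
import random
from typing import List, Optional, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_candidates(logits: List[float], temperature: float = 1.0,
                      top_k: Optional[int] = None, top_p: float = 1.0) -> Tuple[List[int], List[float]]:
    """Return (token_ids, renormalized probs) left after temperature, top-k, and top-p."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0 (use argmax for greedy decoding)")
    probs = softmax([x / temperature for x in logits])  # T < 1 sharpens, T > 1 flattens
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k is not None and top_k > 0:
        order = order[:top_k]                           # keep the k most probable tokens
    mass = sum(probs[i] for i in order)
    kept, cumulative = [], 0.0
    for i in order:                                     # smallest prefix reaching top_p
        kept.append(i)
        cumulative += probs[i] / mass
        if cumulative >= top_p:
            break
    kept_mass = sum(probs[i] for i in kept)
    return kept, [probs[i] / kept_mass for i in kept]


def sample(logits: List[float], temperature: float = 1.0, top_k: Optional[int] = None,
           top_p: float = 1.0, rng: Optional[random.Random] = None) -> int:
    rng = rng or random.Random()
    ids, weights = filter_candidates(logits, temperature, top_k, top_p)
    return rng.choices(ids, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.0, -1.0]
print(filter_candidates(logits, top_p=0.9)[0])  # tokens whose cumulative mass first reaches 0.9
print(sample(logits, top_k=2, rng=random.Random(0)))
```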

7. API Server (server/api.py)

Provides OpenAI-compatible API endpoints:

  • RESTful design: Follows OpenAI’s API specification
  • Streaming support: Real-time token streaming
  • Health monitoring: Server status endpoints

Data Flow

The system processes requests in the following sequence:

  1. Request Arrival: Client sends a request through the API layer
  2. Request Scheduling: SGLang scheduler groups requests with common prefixes
  3. Prefill Phase: Process full prompt sequences using radix attention
  4. Decode Phase: Generate tokens one-by-one with shared computation
  5. KV-Cache Management: Efficient memory allocation and sharing
  6. Response Generation: Return results via API layer

Key Design Principles

Modularity

Each component is designed to be independent, allowing for focused learning and experimentation.

Educational Focus

Clean, well-documented code with comprehensive explanations of key concepts.

SGLang-Style Optimization

Focus on prefix sharing and radix trees to maximize computational efficiency.

Memory Efficiency

Paged cache management to reduce memory fragmentation and maximize utilization.

Architecture Benefits

  1. High Throughput: Continuous batching and prefix sharing maximize GPU utilization
  2. Memory Efficiency: Paged KV-cache reduces fragmentation and enables larger batch sizes
  3. Scalability: Modular design allows for optimization of individual components
  4. Educational Value: Clean implementation of state-of-the-art techniques

Integration Points

The system integrates components through well-defined interfaces:

  • Engine connects to scheduler for request management
  • Scheduler connects to memory manager for KV-cache coordination
  • Attention mechanisms access KV-cache through the memory manager
  • Sampler provides token selection for generation
  • API layer communicates with the engine for request processing

Configuration Management

Overview

Mini-YAIE uses a flexible configuration system that allows users to customize various aspects of the inference engine without modifying the code. The configuration system provides settings for memory management, scheduling, model loading, and performance optimization.

Configuration Structure

The configuration system is built around the SGLangConfig dataclass in src/config.py. The system supports:

  • Dataclass-based configuration with type hints
  • Default values for all parameters
  • Dictionary-based overrides
  • Component-specific configuration sections

Key Configuration Parameters

Scheduler Configuration

# Maximum batch size for processing requests
max_batch_size: int = 8

# Maximum batch size for prefill operations
max_prefill_batch_size: int = 16

# Maximum batch size for decode operations
max_decode_batch_size: int = 256

# Maximum sequence length allowed
max_seq_len: int = 2048

KV Cache Configuration

# Number of GPU memory blocks for KV-cache
num_gpu_blocks: int = 2000

# Number of CPU memory blocks for swapping
num_cpu_blocks: int = 1000

# Size of each memory block (in tokens)
block_size: int = 16

Model Configuration

# Data type for model weights and KV-cache
dtype: str = "float16"  # Options: "float16", "float32", "bfloat16"

# Tensor parallelism size
tensor_parallel_size: int = 1

# GPU memory utilization fraction
gpu_memory_utilization: float = 0.9

# CPU swap space in GB
swap_space: int = 4

Generation Configuration

# Default maximum tokens to generate per request
default_max_tokens: int = 1024

# Default sampling temperature
default_temperature: float = 1.0

# Default top-p value
default_top_p: float = 1.0

SGLang-Specific Features

# Enable radix attention cache for prefix sharing
enable_radix_cache: bool = True

# Enable chunked prefill for long prompts
enable_chunked_prefill: bool = True

# Scheduling policy: "fcfs" (first-come-first-served)
schedule_policy: str = "fcfs"

# Enable prefix caching
enable_prefix_caching: bool = True

# Maximum scheduling steps before preemption
max_num_schedule_steps: int = 1000
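Taken together, these fields fit naturally into a single dataclass with dictionary-style overrides. The sketch below is a hypothetical stand-in (`SGLangConfigSketch`, `make_config`), not the contents of src/config.py. Field names and defaults are copied from the lists above, and the validation rules in `__post_init__` are illustrative assumptions:

```python
from dataclasses import dataclass, fields


@dataclass
class SGLangConfigSketch:
    # Scheduler
    max_batch_size: int = 8
    max_prefill_batch_size: int = 16
    max_decode_batch_size: int = 256
    max_seq_len: int = 2048
    # KV cache
    num_gpu_blocks: int = 2000
    num_cpu_blocks: int = 1000
    block_size: int = 16
    # Model
    dtype: str = "float16"
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    swap_space: int = 4
    # Generation defaults
    default_max_tokens: int = 1024
    default_temperature: float = 1.0
    default_top_p: float = 1.0
    # SGLang-specific
    enable_radix_cache: bool = True
    enable_chunked_prefill: bool = True
    schedule_policy: str = "fcfs"
    enable_prefix_caching: bool = True
    max_num_schedule_steps: int = 1000

    def __post_init__(self) -> None:
        # Illustrative checks only; the real config may validate differently
        if self.dtype not in ("float16", "float32", "bfloat16"):
            raise ValueError(f"unsupported dtype: {self.dtype}")
        if not 0.0 < self.gpu_memory_utilization <= 1.0:
            raise ValueError("gpu_memory_utilization must be in (0, 1]")
        if self.block_size <= 0 or self.num_gpu_blocks <= 0:
            raise ValueError("block_size and num_gpu_blocks must be positive")


def make_config(**overrides) -> SGLangConfigSketch:
    """Start from the defaults and apply dictionary-style overrides."""
    known = {f.name for f in fields(SGLangConfigSketch)}
    unknown = set(overrides) - known
    if unknown:
        raise TypeError(f"unknown config keys: {sorted(unknown)}")
    return SGLangConfigSketch(**overrides)


config = make_config(max_batch_size=16, num_gpu_blocks=4000)
```

Rejecting unknown keys up front catches typos such as `num_blocks` instead of `num_gpu_blocks`, which would otherwise be silently ignored by a looser override scheme.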

Configuration Loading

Default Configuration

When no explicit configuration is provided, the system uses sensible defaults that work well for most educational purposes:

  • Conservative memory usage to work on most GPUs
  • Balanced performance settings
  • Safe batch sizes that avoid out-of-memory errors

Custom Configuration

Users can customize configurations by:

  1. Direct parameter passing to constructors
  2. Environment variables for deployment scenarios
  3. Configuration files (when implemented)

Configuration Best Practices

Performance Tuning

For production use, consider these configuration adjustments:

  • Increase batch sizes based on available GPU memory
  • Adjust block size for optimal cache utilization
  • Tune memory pool size based on request patterns

Memory Management

Configure memory settings based on your hardware:

# For high-end GPUs (24GB+ VRAM)
num_gpu_blocks = 4000
max_batch_size = 32

# For mid-range GPUs (8-16GB VRAM)
num_gpu_blocks = 1000
max_batch_size = 8

# For entry-level GPUs (4-8GB VRAM)
num_gpu_blocks = 500
max_batch_size = 4
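To pick a block count for your GPU, estimate how many bytes one block occupies. Each cached token stores a Key and a Value vector for every head in every layer, which gives the formula below. This sketch assumes every layer keeps its own copy of each block, and it uses GPT-2 small's shape (12 layers, 12 heads, head dimension 64) as the example. The function names are hypothetical:

```python
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2}


def kv_bytes_per_block(num_layers: int, num_heads: int, head_dim: int,
                       block_size: int = 16, dtype: str = "float16") -> int:
    # 2x because every token stores both a Key and a Value vector per head per layer
    return 2 * num_layers * block_size * num_heads * head_dim * DTYPE_BYTES[dtype]


def max_blocks(free_bytes: int, num_layers: int, num_heads: int, head_dim: int,
               block_size: int = 16, dtype: str = "float16") -> int:
    """How many whole blocks fit into the memory left over after loading the weights."""
    per_block = kv_bytes_per_block(num_layers, num_heads, head_dim, block_size, dtype)
    return free_bytes // per_block


per_block = kv_bytes_per_block(num_layers=12, num_heads=12, head_dim=64)
print(per_block)                               # bytes for 16 tokens across all layers
print(f"{2000 * per_block / 1024**3:.2f} GiB")  # total for num_gpu_blocks = 2000
```

Larger models have more layers and heads, so the same VRAM budget buys far fewer blocks, which is why the block counts above shrink with GPU size.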

Integration with Components

Engine Integration

The main engine uses the SGLangConfig for initialization:

from src.config import SGLangConfig, get_sglang_config

# Use default config
config = get_sglang_config()

# Or override specific parameters
config = get_sglang_config(
    max_batch_size=16,
    num_gpu_blocks=4000
)

# Initialize components with config values
scheduler = SGLangScheduler(
    max_batch_size=config.max_batch_size,
    max_prefill_batch_size=config.max_prefill_batch_size,
    max_decode_batch_size=config.max_decode_batch_size
)

Scheduler Configuration

The SGLang scheduler uses configuration for scheduling policies:

  • Batch size limits
  • Prefill/decode phase sizing
  • Memory-aware scheduling decisions

Memory Manager Configuration

The KV-cache manager uses configuration for:

  • Total memory pool size
  • Block allocation strategies
  • Memory optimization policies

Environment-Specific Configuration

Development Configuration

For development and learning:

  • Conservative memory limits
  • Detailed logging
  • Debug information enabled

Production Configuration

For production deployment:

  • Optimized batch sizes
  • Performance-focused settings
  • Minimal logging overhead

Configuration Validation

The system validates configuration parameters to prevent:

  • Memory allocation failures
  • Invalid parameter combinations
  • Performance-degrading settings

Future Extensions

The configuration system is designed to accommodate:

  • Model-specific optimizations
  • Hardware-aware tuning
  • Runtime configuration updates
  • Performance auto-tuning

Configuration Examples

Basic Configuration

# Minimal configuration for learning
config = {
    "max_batch_size": 4,
    "num_gpu_blocks": 1000,
    "block_size": 16
}

Performance Configuration

# Optimized for throughput
config = {
    "max_batch_size": 32,
    "max_decode_batch_size": 512,
    "num_gpu_blocks": 4000,
    "block_size": 32
}

Troubleshooting Configuration Issues

Memory Issues

If experiencing out-of-memory errors:

  1. Reduce num_blocks in KV-cache
  2. Lower batch sizes
  3. Check available GPU memory

Performance Issues

If experiencing low throughput:

  1. Increase batch sizes
  2. Optimize block size for your model
  3. Verify CUDA availability and compatibility

Engine Orchestration

Overview

The Inference Engine serves as the main orchestrator of the Mini-YAIE system, coordinating between various components to provide a unified interface for LLM inference. The engine implements SGLang-style continuous batching with radix attention and prefix sharing to maximize efficiency and throughput.

Engine Architecture

The main engine is implemented in src/engine.py and follows a modular design pattern where each component is responsible for specific aspects of request processing:

┌──────────────────────┐
│ API Layer            │  ← Requests enter here
├──────────────────────┤
│ Engine Orchestration │  ← Coordination happens here
├──────────────────────┤
│ Scheduler            │  ← Request scheduling
├──────────────────────┤
│ Memory Manager       │  ← KV-cache management
├──────────────────────┤
│ Attention Core       │  ← Radix attention computation
├──────────────────────┤
│ Model/Kernel         │  ← Forward pass execution
└──────────────────────┘

Core Components

1. Model Loading Integration

The engine handles model and tokenizer loading through the ModelLoader component:

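from transformers import PreTrainedTokenizer

def __init__(self, model_name: str):
    self.model_name = model_name  # assumed to be read by the _load_* helpers
    self.tokenizer: PreTrainedTokenizer = self._load_tokenizer()
    self.model = self._load_model()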

This ensures that models are properly loaded from HuggingFace or local cache with appropriate configuration.

2. SGLang-Style Scheduler

The engine integrates with the SGLangScheduler for advanced request scheduling:

self.scheduler = SGLangScheduler(
    max_batch_size=8, 
    max_prefill_batch_size=16, 
    max_decode_batch_size=256
)

The scheduler implements prefix grouping and multi-step processing for computation sharing.

3. Radix Attention System

The engine connects to the radix attention mechanism:

self.radix_attention = RadixAttentionWithPagedKVCache(
    num_layers=self.model.config.num_hidden_layers,
    num_heads=self.model.config.num_attention_heads,
    head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
)

4. KV-Cache Management

The engine manages memory through the KVCacheManager:

self.kv_cache_manager = KVCacheManager(
    num_blocks=2000,
    block_size=16,
    num_heads=self.model.config.num_attention_heads,
    head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
    dtype=torch.float16,
)

Request Processing Flow

1. Request Addition

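from typing import List

def generate(self, prompts: List[str], **kwargs) -> List[str]:
    # Add requests to scheduler
    request_ids = []
    for prompt in prompts:
        req_id = self.scheduler.add_request(prompt, **kwargs)
        request_ids.append(req_id)
    # Drive the generation loop until every request has finished
    return self._run_generation_loop(request_ids)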

2. Generation Loop

The engine runs a main generation loop that processes requests:

def _run_generation_loop(self, request_ids: List[str]) -> List[str]:
    # Process requests in batches
    # Handle prefill and decode phases
    # Manage KV-cache efficiently
    ...

3. Response Generation

The engine generates responses with proper tokenization and formatting:

# Generate response using the existing generate method
responses = self.generate([formatted_prompt], **kwargs)
generated_text = responses[0] if responses else ""

SGLang-Style Optimization Features

1. Continuous Batching

The engine supports continuous batching where requests at different stages can be processed together:

  • Prefill requests (processing full prompts)
  • Decode requests (generating single tokens)
  • Mixed batches combining both types

2. Prefix Sharing

The engine enables computation sharing for requests with common prefixes:

  • Radix tree identifies shared prefixes
  • Common computations are performed once
  • Results are shared among multiple requests

3. Memory Efficiency

The engine optimizes memory usage through:

  • Paged KV-cache management
  • Block allocation strategies
  • Memory reclamation for completed requests

API Integration

1. Chat Completion API

The engine provides OpenAI-compatible chat completion:

def chat_completion(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
    # Format messages using chat template
    # Process through generation pipeline
    # Return in OpenAI format
    ...

2. Streaming Support

The engine supports streaming responses for real-time applications:

def chat_completion_stream(self, messages: List[Dict[str, str]], **kwargs):
    # Generate tokens one by one
    # Yield chunks immediately
    # Maintain OpenAI stream format
    ...
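The OpenAI stream format wraps each piece of generated text in a `chat.completion.chunk` object and sends it as a Server-Sent Event, ending with a `data: [DONE]` line. The following is a minimal sketch of that wrapping, independent of the engine. `openai_stream_chunks` and `to_sse` are hypothetical helpers. The final chunk always reports `"stop"` here, whereas a real engine would report `"length"` when `max_tokens` is reached:

```python
import json
import time
import uuid
from typing import Dict, Iterable, Iterator, Optional


def openai_stream_chunks(token_texts: Iterable[str], model: str) -> Iterator[Dict]:
    """Wrap generated text pieces in OpenAI-style chat.completion.chunk dicts."""
    chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def chunk(delta: Dict, finish_reason: Optional[str] = None) -> Dict:
        return {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }

    yield chunk({"role": "assistant"})     # first chunk announces the role
    for text in token_texts:
        yield chunk({"content": text})     # one chunk per decoded piece of text
    yield chunk({}, finish_reason="stop")  # final chunk carries the finish reason


def to_sse(chunks: Iterable[Dict]) -> Iterator[str]:
    """Server-Sent Events framing, terminated by the [DONE] sentinel."""
    for c in chunks:
        yield f"data: {json.dumps(c)}\n\n"
    yield "data: [DONE]\n\n"


for line in to_sse(openai_stream_chunks(["Hel", "lo"], model="gpt2")):
    print(line, end="")
```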

Performance Optimization

1. Batch Size Management

The engine dynamically adjusts batch sizes based on available memory and request characteristics:

  • Prefill batches: Optimized for prompt processing
  • Decode batches: Optimized for token generation
  • Mixed batches: Balanced between both phases

2. Memory Management

The engine coordinates memory usage across components:

# Connect scheduler to memory manager for optimization
self.scheduler.connect_memory_manager(self.kv_cache_manager)

3. Computation Sharing

The engine maximizes computation sharing through radix attention:

  • Shared prefix processing
  • Common token computations
  • Reduced redundant calculations

Error Handling and Resilience

1. Request Validation

The engine validates requests before processing:

  • Input format validation
  • Parameter range checking
  • Resource availability verification

2. Graceful Degradation

When resources are constrained, the engine gracefully degrades:

  • Reduced batch sizes
  • Fallback mechanisms
  • Proper error reporting

3. Resource Management

The engine manages system resources effectively:

  • GPU memory monitoring
  • Request queue management
  • Memory cleanup for completed requests

Integration Points

1. Model Interface

The engine interfaces with any HuggingFace-compatible model:

outputs = self.model(current_ids)  # Standard HuggingFace model interface

2. Sampling Integration

The engine uses the sampling kernel for token generation:

sampling_kernel = SamplingKernel()
next_token_id = sampling_kernel.sample(
    next_token_logits,
    temperature=request.temperature,
    top_p=request.top_p
)

3. Scheduler Integration

The engine coordinates closely with the scheduler:

# Add requests to scheduler
req_id = self.scheduler.add_request(prompt, **kwargs)
# Process in generation loop
responses = self._run_generation_loop(request_ids)

Engine Configuration

The engine supports various configuration options:

  • Model selection and loading
  • Batch size limits
  • Memory allocation settings
  • Performance optimization parameters

Future Extensions

The engine design supports:

  • Additional optimization techniques
  • New attention mechanisms
  • Enhanced scheduling algorithms
  • Advanced memory management strategies

Scheduler Logic: SGLang-Style Request Management

Overview

The SGLang-style scheduler (core/sglang_scheduler.py) implements advanced request scheduling with prefix grouping and computation sharing capabilities. Unlike traditional schedulers, this implementation focuses on maximizing computational efficiency by identifying and leveraging shared prefixes among different requests.

Key Concepts

Request States

The scheduler manages requests through several states:

  • PENDING: New requests awaiting initial processing
  • SCHEDULED_PREFILL: Requests ready for prefill phase
  • RUNNING_PREFILL: Currently processing full prompts
  • SCHEDULED_DECODE: Requests ready for token generation
  • RUNNING_DECODE: Currently generating tokens
  • COMPLETED: Finished requests
  • CANCELLED: Cancelled requests
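A small state machine keeps these states honest. The transition table below is an assumption inferred from the state names, not taken from core/sglang_scheduler.py. A request moves through prefill into a repeating decode cycle. It can finish after prefill (for example, when only one token was requested) or after any decode step. Cancellation is allowed from every non-terminal state:

```python
from enum import Enum, auto


class RequestState(Enum):
    PENDING = auto()
    SCHEDULED_PREFILL = auto()
    RUNNING_PREFILL = auto()
    SCHEDULED_DECODE = auto()
    RUNNING_DECODE = auto()
    COMPLETED = auto()
    CANCELLED = auto()


S = RequestState
# Assumed happy-path transitions; CANCELLED is handled separately below
TRANSITIONS = {
    S.PENDING: {S.SCHEDULED_PREFILL},
    S.SCHEDULED_PREFILL: {S.RUNNING_PREFILL},
    S.RUNNING_PREFILL: {S.SCHEDULED_DECODE, S.COMPLETED},
    S.SCHEDULED_DECODE: {S.RUNNING_DECODE},
    S.RUNNING_DECODE: {S.SCHEDULED_DECODE, S.COMPLETED},  # loop once per generated token
    S.COMPLETED: set(),
    S.CANCELLED: set(),
}
TERMINAL = {S.COMPLETED, S.CANCELLED}


def can_transition(src: RequestState, dst: RequestState) -> bool:
    if dst is S.CANCELLED:
        return src not in TERMINAL  # any live request may be cancelled
    return dst in TRANSITIONS[src]
```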

Prefix-Based Grouping

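The scheduler uses prefix hashing to group requests with common prefixes:

import hashlib
from typing import Optional

def _calculate_prefix_hash(self, prompt: str) -> Optional[str]:
    # Calculate hash to identify common prefixes
    return hashlib.sha256(prompt.encode("utf-8")).hexdigest()

As written, the hash covers the entire prompt, so only identical prompts land in the same group. Requests that merely start the same way need a hash over a leading portion of the prompt, or a radix-tree lookup.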
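The following is a minimal sketch of grouping by a leading portion of the prompt instead of the whole prompt. It hashes only the first `prefix_len` tokens, using whitespace-split words as stand-in tokens. `prefix_len`, `leading_prefix_hash`, and `group_by_prefix` are hypothetical, and returning `None` for prompts shorter than the prefix mirrors the `Optional[str]` return type above:

```python
import hashlib
from collections import defaultdict
from typing import Dict, List, Optional


def leading_prefix_hash(prompt: str, prefix_len: int = 5) -> Optional[str]:
    tokens = prompt.split()
    if len(tokens) < prefix_len:
        return None  # too short to share a meaningful prefix
    prefix = " ".join(tokens[:prefix_len])
    return hashlib.sha256(prefix.encode("utf-8")).hexdigest()


def group_by_prefix(prompts: List[str], prefix_len: int = 5) -> Dict[str, List[str]]:
    groups: Dict[str, List[str]] = defaultdict(list)
    for p in prompts:
        h = leading_prefix_hash(p, prefix_len)
        if h is not None:
            groups[h].append(p)
    return dict(groups)


groups = group_by_prefix([
    "Write a Python script to scrape a website",
    "Write a Python script to sort a list",
    "Summarize this article",
])
```

The two "Write a Python script to" prompts share one group, while the short prompt is left ungrouped.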

Architecture

Core Data Structures

Request Management

# Separate queues for different processing phases
self.pending_requests: List[Request] = []
self.prefill_requests: List[Request] = []
self.running_prefill: List[Request] = []
self.decode_requests: List[Request] = []
self.running_decode: List[Request] = []
self.completed_requests: List[Request] = []

Prefix Grouping

# Group requests by common prefixes for shared computation
self.prefix_groups: Dict[str, List[Request]] = defaultdict(list)
self.request_lookup: Dict[str, Request] = {}

Scheduling Strategy

The scheduler implements an SGLang-inspired strategy:

  1. Prioritize Decode Requests: Minimize token-to-token latency
  2. Maximize Prefill Efficiency: Process new requests efficiently
  3. Leverage Prefix Sharing: Share computation for similar requests
  4. Memory-Aware Scheduling: Respect KV-cache limitations

Detailed Implementation

Request Lifecycle

1. Request Addition

def add_request(self, prompt: str, max_tokens: int = 128, ...) -> str:
    # Calculate prefix hash for grouping
    prefix_hash = self._calculate_prefix_hash(prompt)
    # Add to prefix group if applicable
    if prefix_hash:
        self.prefix_groups[prefix_hash].append(request)
        request.request_group = prefix_hash

2. Scheduling Step

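def schedule_step(self) -> Tuple[List[Request], List[Request]]:
    # First, prioritize decode requests (latency-sensitive)
    num_decodes = min(len(self.decode_requests), self.max_decode_batch_size, self.max_batch_size)
    decode_batch = self.decode_requests[:num_decodes]
    prefill_batch = []
    
    # Calculate remaining capacity after decode allocation
    remaining_capacity = self.max_batch_size - len(decode_batch)
    
    # Fill remaining capacity with prefill requests
    prefill_candidates = self.prefill_requests
    if remaining_capacity > 0:
        num_prefills = min(len(prefill_candidates), remaining_capacity, self.max_prefill_batch_size)
        prefill_batch = prefill_candidates[:num_prefills]
    
    # Return order (prefill, decode) is assumed here
    return prefill_batch, decode_batch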

Batch Selection Policy

The scheduler implements a multi-level priority system:

  1. Decode Priority: Continue existing generation to minimize latency
  2. Prefill Efficiency: Process new requests in efficient batches
  3. Memory Management: Ensure sufficient KV-cache for all requests
  4. Prefix Sharing: Group similar requests for computation sharing

Prefill Processing

Prefill requests undergo full prompt processing:

def process_prefill_batch(self, requests: List[Request]) -> List[Request]:
    for req in requests:
        # Process full prompt in one forward pass
        req.status = RequestStatus.SCHEDULED_DECODE
        # Initialize output sequence
        if req.output_ids is None:
            req.output_ids = []

Decode Processing

Decode requests generate tokens one-by-one:

def process_decode_batch(self, requests: List[Request]) -> List[Request]:
    for req in requests:
        # Placeholder: random logits stand in for a real model forward pass
        dummy_logits = torch.randn(1, 1000)
        
        # Sample next token using kernel
        next_token_tensor = self.sampling_kernel.sample(
            dummy_logits,
            temperature=req.temperature,
            top_p=req.top_p
        )
        # Record the sampled token ID
        req.output_ids.append(int(next_token_tensor))
        
        # Update position and check termination
        req.current_position += 1
        if req.current_position >= req.max_tokens:
            req.status = RequestStatus.COMPLETED

SGLang-Style Optimizations

1. Computation Sharing

The scheduler identifies requests with shared prefixes:

def find_shared_prefixes(self, token_ids: List[int]) -> Tuple[List[str], List[int]]:
    # Traverse radix tree to find matching prefixes
    # Return requests that can share computation
    ...

2. Memory-Aware Scheduling

The scheduler connects to the memory manager for KV-cache coordination:

def connect_memory_manager(self, memory_manager):
    self.memory_manager = memory_manager

3. Continuous Batching

The scheduler maintains continuous processing by balancing prefill and decode requests:

  • Decode requests have higher priority (latency-sensitive)
  • Prefill requests fill remaining batch capacity
  • Memory requirements are considered during scheduling

Performance Considerations

Batch Size Optimization

The scheduler uses different batch size limits:

  • max_prefill_batch_size: Limits prefill batch size for memory efficiency
  • max_decode_batch_size: Larger limit for decode due to smaller memory footprint
  • max_batch_size: Overall system limit

Memory Management

The scheduler coordinates with the KV-cache manager to:

  • Allocate blocks for new requests
  • Track memory usage during processing
  • Ensure sufficient memory for scheduled requests

Integration with Other Components

Memory Manager Integration

def process_prefill_batch(self, requests: List[Request]) -> List[Request]:
    if self.memory_manager:
        # Allocate KV cache blocks for requests
        pass

Sampling Kernel Integration

def process_decode_batch(self, requests: List[Request]) -> List[Request]:
    # Use sampling kernel for token selection
    next_token_tensor = self.sampling_kernel.sample(...)

Request Status Monitoring

Queue Status

The scheduler provides detailed status information:

def get_queue_status(self) -> Dict[str, int]:
    return {
        "pending": len(self.pending_requests),
        "prefill_queue": len(self.prefill_requests),
        "running_prefill": len(self.running_prefill),
        "decode_queue": len(self.decode_requests),
        "running_decode": len(self.running_decode),
        "completed": len(self.completed_requests),
        "total_active": self.get_active_request_count(),
    }

Request Result Access

def get_request_result(self, req_id: str) -> Optional[Dict[str, Any]]:
    # Check completed requests for results
    ...

Implementation Challenges

1. Prefix Hashing

For educational purposes, the implementation uses simple string hashing. In production:

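  • Use token ID sequences for more accurate prefix matching
  • Match the longest shared prefix (for example, with the radix tree) instead of requiring whole-prompt equality
  • Keep in mind that semantic similarity alone does not enable sharing: cached keys and values can be reused only for an identical token prefix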

2. Memory Allocation

The current implementation shows integration points for memory management. A full implementation would:

  • Calculate precise memory requirements
  • Implement cache eviction policies
  • Handle memory fragmentation

3. Computation Sharing

The radix tree integration points exist but require full implementation of:

  • Efficient tree traversal
  • Shared computation tracking
  • Result distribution to multiple requests

Scheduling Algorithm Details

Step-by-Step Process

  1. Decode Prioritization: Schedule as many decode requests as possible
  2. Capacity Calculation: Determine remaining batch capacity
  3. Prefill Scheduling: Fill remaining capacity with prefill requests
  4. Memory Verification: Confirm sufficient KV-cache availability
  5. Batch Execution: Process scheduled requests
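The five steps above can be sketched as a standalone function. This illustration rests on simplifying assumptions and is not the scheduler's code:

- Req is a hypothetical stand-in for Request that records only how many KV blocks its prefill needs.
- Decode requests are assumed to already hold their blocks.
- The memory check stops at the first prefill request that does not fit, so queue order is preserved.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Req:
    """Hypothetical stand-in for Request: an ID and the KV blocks its prefill needs."""
    req_id: str
    blocks_needed: int = 1


def schedule_step(
    decode_queue: List[Req],
    prefill_queue: List[Req],
    num_free_blocks: int,
    max_batch_size: int = 8,
    max_prefill_batch_size: int = 4,
) -> Tuple[List[Req], List[Req]]:
    """Return (decode_batch, prefill_batch)."""
    # 1. Decode prioritization: in-flight requests go first
    #    (assumed to already hold their KV blocks)
    decode_batch = decode_queue[:max_batch_size]
    # 2. Capacity calculation
    remaining = max_batch_size - len(decode_batch)
    # 3-4. Prefill scheduling, admitting a request only if its blocks fit
    prefill_batch: List[Req] = []
    for req in prefill_queue:
        if remaining == 0 or len(prefill_batch) >= max_prefill_batch_size:
            break
        if req.blocks_needed > num_free_blocks:
            break  # stop rather than skip, so queue order is preserved
        num_free_blocks -= req.blocks_needed
        prefill_batch.append(req)
        remaining -= 1
    # 5. The caller executes both batches
    return decode_batch, prefill_batch


decodes = [Req("d1"), Req("d2")]
prefills = [Req("p1", 2), Req("p2", 3), Req("p3", 1)]
d, p = schedule_step(decodes, prefills, num_free_blocks=4, max_batch_size=4)
print([r.req_id for r in d], [r.req_id for r in p])  # ['d1', 'd2'] ['p1']
```

Stopping at p2 instead of skipping ahead to p3 is a fairness choice. Skipping would pack the batch more tightly but could starve large prompts.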

Optimization Strategies

The scheduler implements several optimization strategies:

  1. Temporal Multiplexing: Interleave prefill and decode for efficiency
  2. Spatial Multiplexing: Group similar requests for shared computation
  3. Memory Multiplexing: Optimize KV-cache usage across requests

Future Extensions

The scheduler design supports:

  • Advanced prefix matching algorithms
  • Dynamic batch size adjustment
  • Request preemption and rescheduling
  • Multi-GPU coordination
  • Custom scheduling policies

Memory Management: Paged KV-Cache System

Overview

The memory management system in Mini-YAIE implements a paged KV-cache mechanism inspired by systems like vLLM and SGLang. This approach addresses the memory fragmentation challenges in LLM inference by using fixed-size memory blocks (pages) that can be allocated and deallocated independently for each request.

Core Concepts

Paged Memory Architecture

Traditional KV-cache management allocates contiguous memory blocks for each request, leading to fragmentation when requests have varying lengths. Paged KV-cache solves this by:

  • Dividing the KV-cache into fixed-size blocks (pages)
  • Allowing requests to use non-contiguous memory blocks
  • Enabling efficient memory reuse and sharing

Key Benefits

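  1. Reduced Fragmentation: Fixed-size blocks eliminate external fragmentation; internal waste is limited to the unused slots of each request's last block
  2. Efficient Memory Utilization: Unused blocks can be allocated to other requests
  3. Scalability: Supports variable-length requests without reserving memory for the maximum sequence length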
  4. Computation Sharing: Enables shared prefixes to use the same memory blocks

Architecture

KVCacheBlock Class

Each KVCacheBlock represents a fixed-size memory block:

class KVCacheBlock:
    def __init__(self, block_id: int, size: int, num_heads: int, head_dim: int, ...):
        self.block_id = block_id  # Unique identifier for the block
        self.size = size          # Number of tokens this block can hold
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.keys = None          # [size, num_heads, head_dim] tensor
        self.values = None        # [size, num_heads, head_dim] tensor

KVCacheManager Class

The main manager orchestrates all memory operations:

class KVCacheManager:
    def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int, ...):
        self.num_blocks = num_blocks      # Total number of blocks in the pool
        self.block_size = block_size      # Size of each block in tokens
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.blocks: List[KVCacheBlock] = []  # Pool of all blocks
        self.free_block_list: List[int] = []  # Available blocks for allocation
        self.block_tables: dict = {}      # Maps request_id to list of block_ids

Memory Management Operations

1. Block Allocation

When a request needs KV-cache memory, the manager allocates the required number of blocks:

def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
    # Calculate required blocks
    num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
    
    # Check availability
    if len(self.free_block_list) < num_blocks_needed:
        raise RuntimeError("Not enough free blocks")
    
    # Allocate from free list
    allocated_block_ids = []
    for _ in range(num_blocks_needed):
        block_id = self.free_block_list.pop(0)  # Remove from free list
        allocated_block_ids.append(block_id)
        self.blocks[block_id].allocate()  # Allocate GPU memory
    
    # Track allocation
    self.block_tables[request_id] = allocated_block_ids
    return allocated_block_ids

Key Aspects:

  • Calculates minimum blocks needed based on token count and block size
  • Ensures sufficient free blocks before allocation
  • Updates free block list and request tracking
  • Allocates the block tensors (on the GPU when one is available)

2. Block Deallocation

When requests complete, their blocks are returned to the free pool:

def free_blocks(self, request_id: str):
    if request_id in self.block_tables:
        block_ids = self.block_tables[request_id]
        self.free_block_list.extend(block_ids)  # Return to free pool
        self.free_block_list.sort()  # Maintain sorted order
        del self.block_tables[request_id]  # Remove tracking entry
        
        # Optionally clear tensors to free GPU memory
        for block_id in block_ids:
            self.blocks[block_id].keys = None
            self.blocks[block_id].values = None

Key Aspects:

  • Returns blocks to the free list for reuse
  • Keeps the free list sorted so allocation hands out the lowest-numbered blocks first
  • Removes request tracking information
  • Optionally clears GPU tensors to free memory

3. Block Copying

For advanced operations like request preemption or memory defragmentation:

def copy_blocks(self, src_block_ids: List[int], dst_block_ids: List[int]):
    if len(src_block_ids) != len(dst_block_ids):
        raise ValueError("Source and destination lists must have same length")
    
    for src_id, dst_id in zip(src_block_ids, dst_block_ids):
        src_block = self.blocks[src_id]
        dst_block = self.blocks[dst_id]
        
        # Allocate destination if needed
        if dst_block.keys is None or dst_block.values is None:
            dst_block.allocate()
        
        # Copy data
        with torch.no_grad():
            dst_block.keys.copy_(src_block.keys)
            dst_block.values.copy_(src_block.values)

Memory Layout and Access

Block Organization

The memory is organized as a collection of fixed-size blocks:

Global Memory Pool:
┌─────────┬─────────┬─────────┬─────────┬─────────┬─────────┐
│ Block 0 │ Block 1 │ Block 2 │ Block 3 │ Block 4 │ Block 5 │
│ [16xHxD]│ [16xHxD]│ [16xHxD]│ [16xHxD]│ [16xHxD]│ [16xHxD]│
└─────────┴─────────┴─────────┴─────────┴─────────┴─────────┘

Request A: Uses [Block 0, Block 2] for non-contiguous sequence storage
Request B: Uses [Block 1, Block 4, Block 5] for its sequence

Where H = num_heads and D = head_dim

Block Tables

Each request has an associated block table that maps logical token positions to physical blocks:

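Request A Block Table (block_size = 16):
Logical:  [0-15]   [16-31]
Physical: Block 0  Block 2

Logical block i (tokens 16i to 16i+15) is stored in physical block block_table[i]. The table has no gaps, and it grows by one entry each time the request fills its last block.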
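Because each table entry covers block_size consecutive tokens, finding where a token's K/V lives takes one integer division and one modulo. The minimal sketch below assumes block_size = 16, as in the diagram. locate_token is a hypothetical helper, not part of the KVCacheManager API shown above.

```python
from typing import List, Tuple


def locate_token(block_table: List[int], position: int, block_size: int = 16) -> Tuple[int, int]:
    """Map a logical token position to (physical_block_id, offset_in_block)."""
    logical_block = position // block_size
    if logical_block >= len(block_table):
        raise IndexError(f"position {position} is beyond the allocated blocks")
    return block_table[logical_block], position % block_size


request_a = [0, 2]  # Request A: logical block i -> physical block request_a[i]
print(locate_token(request_a, 5))   # (0, 5)
print(locate_token(request_a, 20))  # (2, 4)
```

A paged attention kernel performs essentially this lookup when it gathers keys and values for each position.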

Integration with SGLang Features

Computation Sharing

The paged system enables computation sharing by allowing requests with shared prefixes to reference the same memory blocks:

  • Requests with common prefixes can share the same KV-cache blocks
  • Multiple requests can reference the same physical memory location
  • Reduces redundant computation and memory usage

Memory Efficiency

By using fixed-size blocks:

  • External fragmentation is eliminated, since any free block can serve any request
  • Blocks are reused as soon as their request finishes
  • Internal waste is bounded by at most block_size - 1 slots per request

Performance Considerations

Block Size Selection

The block size parameter is critical for performance:

  • Smaller blocks: Less internal fragmentation, more overhead for block management
  • Larger blocks: More internal fragmentation, less management overhead
  • Typical values: 8-32 tokens per block work well in practice
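The trade-off is easy to quantify. A request of num_tokens tokens needs ⌈num_tokens / block_size⌉ blocks. Only its last block can be partly empty, so at most block_size - 1 slots are wasted per request. Here is a small sketch for a hypothetical 100-token request:

```python
def blocks_needed(num_tokens: int, block_size: int) -> int:
    # Ceiling division without floating point
    return (num_tokens + block_size - 1) // block_size


def wasted_slots(num_tokens: int, block_size: int) -> int:
    """Unused token slots in the request's last, partially filled block."""
    return blocks_needed(num_tokens, block_size) * block_size - num_tokens


for block_size in (8, 16, 32):
    print(block_size, blocks_needed(100, block_size), wasted_slots(100, block_size))
# 8 13 4
# 16 7 12
# 32 4 28
```

Smaller blocks waste fewer slots but produce longer block tables and more allocator calls per request.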

Memory Allocation Strategy

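The system uses a simple first-fit strategy, always taking the lowest-numbered free block. Note that list.pop(0) is O(n) in the length of the free list, while popping from the end is O(1):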

# Allocate from beginning of free list
block_id = self.free_block_list.pop(0)

For production systems, more sophisticated strategies might include:

  • Best-fit to minimize fragmentation
  • Coalescing strategies to combine blocks
  • Preallocation to reduce allocation overhead

Memory Safety and Management

GPU Memory Management

Each block allocates its key and value tensors on the GPU when one is available, falling back to the CPU otherwise:

def allocate(self):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.keys = torch.zeros(
        self.size, self.num_heads, self.head_dim, 
        dtype=self.dtype, device=device
    )
    self.values = torch.zeros(
        self.size, self.num_heads, self.head_dim, 
        dtype=self.dtype, device=device
    )

Memory Cleanup

Proper cleanup prevents memory leaks:

  • Free blocks when requests complete
  • Clear GPU tensors to release memory
  • Maintain consistent state in block tables

Advanced Features

Dynamic Block Resizing

For requests that need to extend beyond their initial allocation:

  • Allocate additional blocks as needed
  • Maintain logical sequence continuity
  • Update block tables accordingly

Memory Pool Management

Advanced implementations might include:

  • Block migration to reduce fragmentation
  • Eviction policies for memory-constrained scenarios
  • Prefetching strategies for better performance

Error Handling

Out of Memory Conditions

The system fails fast with a descriptive error when the block pool is exhausted:

if len(self.free_block_list) < num_blocks_needed:
    raise RuntimeError(f"Not enough free blocks. Need {num_blocks_needed}, have {len(self.free_block_list)}")

Block Validation

Before operations, the system validates block states:

  • Verify blocks are allocated before accessing
  • Check for proper tensor dimensions
  • Validate request associations

Future Enhancements

Memory Optimization

Potential improvements include:

  • Compressed KV-cache storage
  • Offloading to CPU memory when possible
  • Cache eviction policies for long-running requests

Performance Optimization

Advanced techniques might include:

  • Block prefetching for better cache performance
  • Heterogeneous memory management (different memory types)
  • Asynchronous memory operations

Implementation Variations

SGLang-Style Memory Management

For SGLang-specific optimizations:

  • Prefix sharing memory management
  • Radix tree integration for shared computation
  • Advanced scheduling based on memory access patterns

Integration Points

The memory manager connects with other components:

  • Scheduler: Consults memory availability when deciding which requests to schedule
  • Attention modules: Access KV-cache through block tables
  • Model execution: Uses paged cache for efficient attention computation

Python Kernels Guide

Overview

The Python kernels in Mini-YAIE implement the core computational components that enable SGLang-style inference optimization. These kernels provide the foundational functionality for attention mechanisms, memory management, and token sampling that make efficient LLM inference possible.

Kernel Architecture

Core Components

The kernel system consists of several interconnected modules:

  1. Radix Tree: Implements prefix matching for shared computation
  2. KV Cache Manager: Manages paged key-value storage
  3. Radix Attention Module: Implements radix attention with shared computation
  4. Sampling Module: Provides token selection algorithms

SGLang-Style Optimization

The kernels are designed to support SGLang’s key optimization strategies:

  • Prefix Sharing: Share computation for requests with common prefixes
  • Continuous Batching: Dynamically batch requests at different processing stages
  • Paged Memory Management: Efficiently manage KV-cache memory using fixed-size blocks
  • Radix Attention: Optimize attention computation for shared prefixes

Python Kernel Implementation

Design Philosophy

The Python kernels follow these design principles:

1. Educational Focus

  • Clean, well-documented code
  • Clear algorithm implementation
  • Comprehensive comments explaining concepts

2. SGLang Compatibility

  • Implement SGLang-style optimization techniques
  • Support for radix attention and prefix sharing
  • Continuous batching integration

3. Modularity

  • Independent components that can be tested individually
  • Clean interfaces between components
  • Easy to extend and modify

4. Performance Considerations

  • Efficient data structures
  • Proper memory management
  • Optimized algorithm implementations

Implementation Structure

Each kernel follows a similar pattern:

class KernelName:
    def __init__(self, parameters):
        # Initialize kernel with configuration
        pass
    
    def process(self, input_data):
        # Core processing logic
        pass
    
    def update_state(self, new_data):
        # State management for ongoing requests
        pass

Integration with System Components

Engine Integration

The main inference engine instantiates the kernels it needs:

# Engine uses kernels for computation
self.radix_attention = RadixAttentionWithPagedKVCache(...)
self.kv_cache_manager = KVCacheManager(...)
self.sampling_kernel = SamplingKernel()

Scheduler Coordination

Kernels work with the SGLang scheduler:

  • Provide computation sharing opportunities
  • Manage memory allocation and deallocation
  • Coordinate with scheduling policies

Memory Management

Kernels connect with the paged memory system:

  • Request memory allocation through the manager
  • Manage KV-cache blocks efficiently
  • Support for shared memory blocks

Performance Characteristics

Computational Efficiency

The Python kernels provide:

  • Efficient attention computation
  • Optimized memory access patterns
  • Shared computation for common prefixes

Memory Usage

Optimized memory management includes:

  • Paged cache allocation
  • Block-level memory sharing
  • Efficient reuse of allocated blocks

Scalability

The kernel design supports:

  • Variable batch sizes
  • Multiple concurrent requests
  • Higher throughput as more requests share a batch

Advanced Features

Computation Sharing

The radix tree and attention modules enable:

  • Shared prefix identification
  • Computation reuse across requests
  • Efficient memory utilization

Adaptive Processing

Kernels adapt to:

  • Different request patterns
  • Variable sequence lengths
  • Changing memory requirements

Testing and Validation

Unit Testing

Each kernel includes:

  • Comprehensive unit tests
  • Edge case validation
  • Performance benchmarking

Integration Testing

Kernels are tested as part of:

  • Full inference pipeline
  • SGLang-style optimization scenarios
  • Memory management validation

Extensibility

Adding New Kernels

The system supports:

  • Easy addition of new kernel types
  • Pluggable architecture for kernel replacement
  • Backwards compatibility

Customization

Kernels can be customized for:

  • Specific model architectures
  • Hardware optimization
  • Performance tuning

This Python kernel system forms the computational backbone of Mini-YAIE, implementing SGLang-style optimization techniques in an educational and accessible way.

Radix Tree (kernels/radix_tree.py)

1. Concept: The Prefix Tree

A Radix Tree (or Compressed Trie) is a data structure that succinctly stores sequences of tokens. Unlike a standard trie where each edge is a single character (or token), a Radix Tree allows edges to be sequences of tokens.

Optimization Goal

When two requests start with "The quick brown fox", we want to store that sequence once.

  • Request A: "The quick brown fox jumps"
  • Request B: "The quick brown fox sleeps"

In our tree, we should have a shared node for "The quick brown fox", which then branches into "jumps" and "sleeps".

flowchart TD
    Root((Root)) --> Shared["The quick brown fox"]
    Shared --> Branch1["jumps"]
    Shared --> Branch2["sleeps"]

    style Shared fill:#aaffaa

2. Implementation Guide

Open src/kernels/radix_tree.py. You will implement the RadixTree class step-by-step.

Step 1: Define the Tree Node

First, we need a node structure. Unlike a binary tree, a Radix Node can have many children.

from typing import Dict, List, Optional

class RadixTreeNode:
    def __init__(self, prefix: List[int]):
        self.prefix = prefix               # The sequence of tokens on this edge
        self.children: Dict[int, "RadixTreeNode"] = {} # Map: first_token -> Child Node
        self.request_id: Optional[str] = None # If a request ends here, store its ID
        self.lock_count = 0                # Reference counting (how many requests use this?)

Task: Locate the RadixTreeNode class and ensure it has these fields.


Step 2: Implement Prefix Matching (find_shared_prefixes)

Before inserting, we need a way to see how much of a new prompt already exists in the tree.

Algorithm:

  1. Start at self.root.
  2. Compare the input token_ids with the edges in the tree.
  3. Traverse down as long as the tokens match exactly.
  4. Return the last matching Node and the number of matching tokens.

Your Turn: Implement find_shared_prefixes(token_ids) in RadixTree.

Hint: Use a while loop. At each node, look up node.children.get(token_ids[current_idx]). Plain indexing raises KeyError when no child starts with that token. If a child exists, check whether its full edge child.prefix matches the next chunk of your input.


Step 3: Implement insert_request (The Hard Part)

Now, inserting a new request. This involves splitting nodes if a partial match is found.

Scenario:

  • Tree has edge [1, 2, 3, 4].
  • You insert [1, 2, 5].

Algorithm:

  1. Trace the path like in Step 2.
  2. If the input diverges in the middle of an edge (e.g., 1, 2 match, but the edge continues with 3 while your input has 5):
    • Split: Create a new parent node for the shared part [1, 2].
    • Shorten the old node's prefix to [3, 4] and make it a child of this new parent.
    • Create your new node [5] as another child.
flowchart TD
    Start([Insert 1, 2, 5]) --> Match{Match 1, 2?}
    Match -->|Yes| Diverge{Next is 3 vs 5}
    Diverge --> Split[Split Edge]
    Split --> Old[Child: 3, 4]
    Split --> New[Child: 5]

Your Turn: Implement insert_request(request_id, token_ids).

  • Reuse the prefix-matching logic from Step 2.
  • Handle the 3 cases: Exact match, New Branch, or Split Edge.

Step 4: Verify

Create a test script tests/test_radix_manual.py:

tree = RadixTree()
tree.insert_request("req1", [1, 2, 3])
match, count = tree.find_shared_prefixes([1, 2, 3, 4])
print(f"Matched {count} tokens") # Should be 3!
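After you have written your own version, you can compare it against this minimal sketch of how prefix matching and the three insertion cases fit together. The sketch makes several assumptions:

- match_prefix is a hypothetical name; the guide's API is find_shared_prefixes.
- match_prefix returns the deepest node whose whole edge matched, plus the total number of matching tokens.
- Reference counting (lock_count) is omitted to keep the sketch short.

```python
from typing import Dict, List, Optional, Tuple


class RadixTreeNode:
    def __init__(self, prefix: List[int]):
        self.prefix = prefix                            # tokens on the edge into this node
        self.children: Dict[int, "RadixTreeNode"] = {}  # first token -> child node
        self.request_id: Optional[str] = None           # set if a request ends here


def _common_length(edge: List[int], tokens: List[int], start: int) -> int:
    """Count how many leading tokens of edge match tokens[start:]."""
    n = 0
    while n < len(edge) and start + n < len(tokens) and edge[n] == tokens[start + n]:
        n += 1
    return n


class RadixTree:
    def __init__(self):
        self.root = RadixTreeNode([])

    def match_prefix(self, token_ids: List[int]) -> Tuple[RadixTreeNode, int]:
        """Return (deepest fully matched node, number of matching tokens)."""
        node, idx = self.root, 0
        while idx < len(token_ids):
            child = node.children.get(token_ids[idx])
            if child is None:
                break
            common = _common_length(child.prefix, token_ids, idx)
            if common < len(child.prefix):
                return node, idx + common  # match ends inside this edge
            node, idx = child, idx + common
        return node, idx

    def insert_request(self, request_id: str, token_ids: List[int]) -> None:
        node, idx = self.root, 0
        while idx < len(token_ids):
            first = token_ids[idx]
            child = node.children.get(first)
            if child is None:
                # Case "new branch": hang the remaining tokens off this node
                child = RadixTreeNode(token_ids[idx:])
                node.children[first] = child
                node = child
                break
            common = _common_length(child.prefix, token_ids, idx)
            if common < len(child.prefix):
                # Case "split edge": the shared part becomes a new parent node
                parent = RadixTreeNode(child.prefix[:common])
                child.prefix = child.prefix[common:]
                parent.children[child.prefix[0]] = child
                node.children[first] = parent
                child = parent
            node, idx = child, idx + common
        # Case "exact match" (or the node created above): mark the request end
        node.request_id = request_id


tree = RadixTree()
tree.insert_request("req1", [1, 2, 3, 4])
tree.insert_request("req2", [1, 2, 5])  # splits [1, 2, 3, 4] into [1, 2] -> {[3, 4], [5]}
node, count = tree.match_prefix([1, 2, 5, 9])
print(count, node.request_id)           # 3 req2
```

Because children are keyed by their first token, each step of the walk is a single dictionary lookup. Only the split case rewrites existing nodes.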

KV Cache Manager (kernels/kv_cache.py)

1. Concept: Paged Attention

In a standard implementation, KV Cache is a huge contiguous tensor [MAX_SEQ_LEN, HEADS, DIM]. This wastes memory because most prompts are shorter than MAX_SEQ_LEN.

Paged Attention breaks this tensor into small fixed-size blocks (e.g., size 16).

  • Physical Memory: A big pool of blocks [NUM_BLOCKS, 16, HEADS, DIM].
  • Logical Memory: For each request, we just keep a list of block indices [0, 5, 12].

Your job is to write the Allocator (like malloc in C).


2. Implementation Guide

Open src/kernels/kv_cache.py.

Step 1: Initialization

We need to track which blocks are free and which are used.

Task: In __init__:

  1. Create a list self.free_block_list. Initially, it should contain all integers from 0 to num_blocks - 1. Avoid naming it self.free_blocks: an instance attribute with that name would shadow the free_blocks() method you write in Step 3, and manager.free_blocks("req1") would then fail with "'list' object is not callable".
  2. Create a dictionary self.block_tables. This will map request_id -> List[int] (the list of blocks owned by that request).
# Hint
self.free_block_list = list(range(num_blocks))
self.block_tables = {}

Step 2: The allocate_blocks Method

When a request comes in (or generates new tokens), it needs memory.

Signature:

def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:

Algorithm:

  1. Calculate how many blocks are needed.
    • $N_{\text{blocks}} = \lceil \text{num\_tokens} / \text{block\_size} \rceil$
  2. Check whether self.free_block_list holds enough blocks.
    • If len(self.free_block_list) < N_blocks, raise an error (or handle OOM).
  3. Pop the blocks from self.free_block_list.
  4. Assign them to self.block_tables[request_id].
  5. Return the list of allocated block indices.

Your Turn: Implement this logic. Watch out for integer division: num_tokens // block_size rounds down, so use ceiling division, (num_tokens + block_size - 1) // block_size.


Step 3: The free_blocks Method

When a request finishes, we must reclaim memory.

Algorithm:

  1. Look up the blocks for request_id.
  2. Append them back to self.free_block_list.
  3. Delete the entry from self.block_tables.

Critical: Do not double-free! Freeing the same request twice would put its block IDs on the free list twice, and two later requests could then be handed the same physical block. Check that request_id still has a block table before freeing. A set would reject duplicate IDs automatically, but a plain list used as a stack is simpler and keeps allocation order predictable.


Step 4: Connecting to the Engine

The get_kv_tensors method checks whether you can translate the logical view (a request's block table) into the physical view (the cached tensors).

Task: Implement get_kv_tensors.

  • It is expected to return the GPU tensors for the request's blocks.
  • Note: In this Python simulation, returning the block indices is often enough for the Scheduler to know the mapping. The actual tensor access happens in the CUDA kernel.

Step 5: Verify

Create tests/test_kv_manual.py:

manager = KVCacheManager(num_blocks=10, block_size=16, ...)
# Alloc 20 tokens -> needs 2 blocks (indices 0, 1)
blocks = manager.allocate_blocks("req1", 20)
print(blocks)
# Free
manager.free_blocks("req1")
print(len(manager.free_block_list)) # Should be 10 again
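Once your version passes the checks above, you can compare it with this minimal, tensor-free sketch of Steps 1 to 3. BlockAllocator is a hypothetical name. The free_block_list naming follows the KVCacheManager overview earlier, which also avoids clashing with the free_blocks() method. One design choice is an assumption: a second allocation for the same request extends its block table instead of replacing it, which is what a request that keeps generating tokens needs.

```python
from typing import Dict, List


class BlockAllocator:
    """Hypothetical stand-alone version of the allocator logic (no tensors)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_block_list: List[int] = list(range(num_blocks))
        self.block_tables: Dict[str, List[int]] = {}

    def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
        needed = (num_tokens + self.block_size - 1) // self.block_size  # ceiling
        if len(self.free_block_list) < needed:
            raise RuntimeError(
                f"Not enough free blocks. Need {needed}, have {len(self.free_block_list)}"
            )
        # Take the lowest-numbered free blocks, as in the overview
        allocated = self.free_block_list[:needed]
        del self.free_block_list[:needed]
        # Extend (rather than overwrite) if the request already owns blocks
        self.block_tables.setdefault(request_id, []).extend(allocated)
        return allocated

    def free_blocks(self, request_id: str) -> None:
        # pop() with a default turns a second free into a no-op (no double-free)
        block_ids = self.block_tables.pop(request_id, [])
        self.free_block_list.extend(block_ids)
        self.free_block_list.sort()


mgr = BlockAllocator(num_blocks=10, block_size=16)
print(mgr.allocate_blocks("req1", 20))  # [0, 1]
mgr.free_blocks("req1")
mgr.free_blocks("req1")                 # second call does nothing
print(len(mgr.free_block_list))         # 10
```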

Radix Attention Module (kernels/radix_attention.py)

1. Concept: Connecting the Dots

We have a Radix Tree (prefix matching) and a Paged KV Cache (memory management). The RadixAttentionWithPagedKVCache class is the glue that runs on the CPU (Python side) to manage these resources before we launch the GPU kernels.

It doesn’t run the attention math (that’s the CUDA kernel’s job). Instead, it manages the metadata:

  • “Request A needs to append ‘cat’ to its sequence.”
  • “Does ‘cat’ already exist in the Radix Tree?”
  • “If yes, reuse the block.”
  • “If no, allocate a new block.”

2. Implementation Guide

Open src/kernels/radix_attention.py.

Step 1: Initialization

You need to initialize the two sub-components we built earlier.

class RadixAttentionWithPagedKVCache:
    def __init__(self, ...):
        # ...
        self.radix_tree = RadixTree()
        self.kv_cache_manager = KVCacheManager(...)

Step 2: append_slot (The Critical Logic)

This method is called when we want to add a new token (or tokens) to a request.

Signature:

def append_slot(self, key: torch.Tensor, value: torch.Tensor, request_id: str):
  • key/value: The computed K/V tensors for the new token(s).

Algorithm:

  1. Check Tree: Use self.radix_tree to check whether this (request_id + new_token) path already exists.
    • Note: In a real system, we check before computing K/V. Here, we might just be managing the cache storage.
  2. Allocate: If we need new space, call self.kv_cache_manager.allocate_blocks().
  3. Store: We need to perform the copy.
    • Ideally, we just return the indices of where to write, and the GPU kernel does the writing.
    • For this Python simulation, you might simulate the copy or just track the metadata.

Step 3: get_kv_cache

The scheduler asks: “I am about to run requests [R1, R2]. Where is their data?”

Algorithm:

  1. Loop through request_ids.
  2. For each, ask self.kv_cache_manager for its block table (list of integers).
  3. Pack these lists into a single Tensor block_tables, padding shorter lists (e.g., with -1) so every row has the same length.
  4. Return block_tables to the Engine.
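Because requests own different numbers of blocks, the lists must be padded before they can be stacked into one tensor. The minimal sketch below uses plain Python. pack_block_tables and the -1 padding value are assumptions; -1 is chosen because it can never be a valid block index.

```python
from typing import Dict, List


def pack_block_tables(
    block_tables: Dict[str, List[int]], request_ids: List[str], pad_value: int = -1
) -> List[List[int]]:
    """Stack per-request block lists into one rectangular table.

    Shorter lists are right-padded with pad_value so every row has the
    length of the longest block list; row order follows request_ids.
    """
    rows = [block_tables[rid] for rid in request_ids]
    width = max((len(r) for r in rows), default=0)
    return [r + [pad_value] * (width - len(r)) for r in rows]


tables = {"R1": [0, 5, 12], "R2": [3]}
packed = pack_block_tables(tables, ["R1", "R2"])
print(packed)  # [[0, 5, 12], [3, -1, -1]]
```

To hand the result to the GPU kernel, wrap it in a tensor, for example torch.tensor(packed, dtype=torch.int32), and move it to the device. The kernel must then skip entries equal to the padding value.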

Step 4: free_request

When a request is done:

  1. self.radix_tree.remove_request(request_id) (Decrement ref counts).
  2. self.kv_cache_manager.free_blocks(request_id) (Reclaim memory).

3. The RadixAttentionBlock (Model Layer)

The class RadixAttentionBlock is the PyTorch module that sits in the model.

Task: In forward():

  1. Compute Q, K, V projections.
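  2. Apply RoPE (Rotary Position Embeddings) to Q and K.
  3. If Prefill: Use Flash Attention (or standard attention) on the new tokens, and store their K/V in the paged cache so later decode steps can attend to the prompt.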
  4. If Decode:
    • Call append_slot to save the new K/V.
    • Call paged_attention_kernel (the CUDA op) to attend to the entire history using the block tables.

Exercise: Since we don’t have the full model weight loading for this specific block, focus on the logic flow in the comments.

Sampling (kernels/sampling.py)

1. Concept: From Logits to Tokens

The model outputs logits: a vector of size [VOCAB_SIZE] (e.g., 50,000) containing raw scores for the next token. We need to pick one token ID.

  • Greedy: Just pick argmax(). Boring, repetitive.
  • Sampling: Pick randomly based on probability. Creative!

We control the randomness with Temperature, Top-P (Nucleus), and Top-K.


2. Implementation Guide

Open src/kernels/sampling.py.

Step 1: Temperature Scaling

Temperature ($T$) controls confidence.

  • $T < 1$: Makes the distribution peakier (more confident).
  • $T > 1$: Makes it flatter (more random).

Algorithm:

logits = logits / temperature
  • Watch out: If $T$ is very close to 0, just do argmax to avoid division by zero!

Step 2: Softmax

Convert logits to probabilities (0.0 to 1.0).

probs = torch.softmax(logits, dim=-1)

Step 3: Top-K Filtering

Keep only the $K$ most likely tokens. Zero out the rest.

Algorithm:

  1. Find the value of the $K$-th highest score.
  2. Mask (set to $-\infty$) anything below that value in logits (or 0 in probs).
  3. Re-normalize probabilities.
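The three steps map directly onto torch.topk and Tensor.masked_fill. This minimal sketch filters logits rather than probabilities, so the softmax that follows takes care of re-normalization. top_k_filter is a hypothetical helper name. All ties with the K-th value are kept, so slightly more than K tokens can survive.

```python
import torch


def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest logits per row; set the rest to -inf."""
    if k <= 0 or k >= logits.size(-1):
        return logits  # nothing to filter
    # Value of the k-th highest score, kept as a column for broadcasting
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))


logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
probs = torch.softmax(top_k_filter(logits, 2), dim=-1)  # softmax re-normalizes
print(probs)  # only the first two entries are non-zero
```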

Step 4: Top-P (Nucleus) Filtering (The Tricky One)

Keep the smallest set of tokens whose cumulative probability adds up to $P$ (e.g., 0.9). This dynamically truncates the long tail of “nonsense” words.

Algorithm:

  1. Sort probabilities in descending order: sorted_probs, sorted_indices = torch.sort(probs, descending=True).
  2. Compute cumulative sum: cumulative_probs = torch.cumsum(sorted_probs, dim=-1).
  3. Find cut-off: Mask where cumulative_probs > top_p.
    • Tip: You want to keep the first token that crosses the threshold, so shift the mask right by one position.
  4. Scatter the mask back to the original ordering.
  5. Re-normalize.
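The five steps can be written with torch.sort, torch.cumsum and Tensor.scatter. This minimal sketch operates on probabilities. top_p_filter is a hypothetical helper, not necessarily how SamplingKernel is organized.

```python
import torch


def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Zero out tokens outside the smallest set whose cumulative prob reaches top_p."""
    # 1-2. Sort descending and accumulate
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # 3. Mark tokens past the threshold ...
    remove_sorted = cumulative_probs > top_p
    # ... shifted right by one so the token that crosses the threshold stays
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = False
    # 4. Scatter the mask back to the original vocabulary order
    remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
    # 5. Re-normalize the surviving probabilities
    filtered = probs.masked_fill(remove, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)


probs = torch.tensor([[0.15, 0.5, 0.05, 0.3]])
print(top_p_filter(probs, top_p=0.7))  # keeps the 0.5 and 0.3 entries, re-normalized
```

Because index 0 is never masked, the most likely token always survives, even when its probability alone exceeds top_p.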

Step 5: The Final Selection

Once you have your clean probability distribution:

next_token = torch.multinomial(probs, num_samples=1)

Your Turn: Implement sample in SamplingKernel. Start simple (just Temperature) and verify, then add Top-P.
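Following the advice to start simple, a temperature-only version might look like the sketch below. The function name and the 1e-5 threshold for falling back to greedy decoding are assumptions. The real SamplingKernel.sample also takes top_p, which you would apply to probs before calling torch.multinomial.

```python
import torch


def sample(logits: torch.Tensor, temperature: float = 1.0, eps: float = 1e-5) -> torch.Tensor:
    """Temperature-only sampling; returns one token ID per row, shape [batch, 1]."""
    if temperature < eps:
        # Near-zero temperature: greedy decoding avoids dividing by ~0
        return torch.argmax(logits, dim=-1, keepdim=True)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)


logits = torch.tensor([[0.1, 3.0, -2.0]])
print(sample(logits, temperature=0.0))        # tensor([[1]]): greedy pick
torch.manual_seed(0)
print(sample(logits, temperature=0.8).shape)  # torch.Size([1, 1])
```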

Build System and CUDA Kernels

Overview

The Mini-YAIE project includes a comprehensive build system for compiling custom CUDA kernels that provide optimized performance for SGLang-style inference operations. The build system is designed to be both educational and production-ready, allowing users to learn CUDA kernel development while achieving high performance.

Build System Architecture

Build Scripts

The project provides multiple ways to build kernels:

Shell Script: build_kernels.sh

# Run the comprehensive build script for the CUDA kernels
./build_kernels.sh

Makefile Integration

# Build kernels using make
make build-kernels

Direct Python Build

# Direct build using Python setup
python setup_kernels.py build_ext --inplace

Build Dependencies

The build system requires:

  • CUDA Toolkit: Version 11.0 or higher
  • PyTorch with CUDA Support: For CUDA extensions
  • NVIDIA GPU: With compute capability >= 6.0
  • System Compiler: GCC/Clang with C++14 support
  • Python Development Headers: For Python C API

CUDA Kernel Design

SGLang-Style Optimization Focus

The CUDA kernels are specifically designed to optimize SGLang-style inference operations:

  1. Radix Tree Operations: Efficient prefix matching on GPU
  2. Paged Attention: Optimized attention for paged KV-cache
  3. Memory Operations: High-performance memory management
  4. Radix Attention: GPU-optimized attention with prefix sharing across requests

Performance Goals

The kernels target SGLang-specific optimizations:

  • Memory Bandwidth Optimization: Minimize memory access overhead
  • Computation Sharing: Implement efficient prefix sharing on GPU
  • Batch Processing: Optimize for variable-length batch processing
  • Cache Efficiency: Optimize for GPU cache hierarchies

CUDA Kernel Components

1. Memory Operations Kernels

Paged KV-Cache Operations

// Efficient operations on paged key-value cache
__global__ void copy_blocks_kernel(
    float* dst_keys, float* dst_values,
    float* src_keys, float* src_values,
    int* block_mapping, int num_blocks
);

Block Management

  • Block allocation and deallocation
  • Memory copying between blocks
  • Block state management

2. Flash Attention Kernels

Optimized Attention Computation

// Optimized attention for both prefill and decode phases
template<typename T>
__global__ void flash_attention_kernel(
    const T* q, const T* k, const T* v,
    T* output, float* lse,  // logsumexp for numerical stability
    int num_heads, int head_dim, int seq_len
);

Features

  • Memory-efficient attention computation
  • Numerical stability with logsumexp
  • Support for variable sequence lengths
  • Optimized memory access patterns

3. Paged Attention Kernels

Paged Memory Access

// Attention with paged key-value cache support
__global__ void paged_attention_kernel(
    float* output,
    const float* query,
    const float* key_cache,
    const float* value_cache,
    const int* block_tables,
    const int* context_lens,
    int num_kv_heads, int head_dim, int block_size
);

Features

  • Direct paged cache access patterns
  • Efficient block index computation
  • Memory coalescing optimization
  • Support for shared prefix computation

4. Radix Operations Kernels

Prefix Matching Operations

// GPU-accelerated prefix matching for SGLang-style sharing
__global__ void radix_tree_lookup_kernel(
    int* token_ids, int* request_ids,
    int* prefix_matches, int batch_size
);

Features

  • Parallel prefix matching
  • Efficient tree traversal on GPU
  • Shared computation identification
  • Batch processing optimization

Build Process Details

Setup Configuration

The build system uses setuptools together with PyTorch’s torch.utils.cpp_extension helpers to compile the CUDA extension:

# setup_kernels.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="yaie_kernels",
    ext_modules=[
        CUDAExtension(
            name="yaie_kernels",
            sources=[
                "kernels/cuda/radix_attention.cu",
                "kernels/cuda/paged_attention.cu",
                "kernels/cuda/memory_ops.cu",
                "kernels/cuda/radix_ops.cu"
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"]
            }
        )
    ],
    # BuildExtension supplies the nvcc/C++ compiler integration
    cmdclass={"build_ext": BuildExtension}
)

Compilation Flags

The build system uses optimization flags for performance:

  • -O3: Maximum optimization
  • --use_fast_math: Fast math operations
  • -arch=sm_60: Target specific GPU architectures
  • -lineinfo: Include debug line information

Architecture Targeting

The system supports multiple GPU architectures:

# Specify target architectures with PyTorch's TORCH_CUDA_ARCH_LIST environment variable
TORCH_CUDA_ARCH_LIST="7.5" python setup_kernels.py build_ext --inplace  # Turing GPUs
TORCH_CUDA_ARCH_LIST="8.0" python setup_kernels.py build_ext --inplace  # Ampere GPUs

CUDA Kernel Implementation Guidelines

Memory Management

Unified Memory vs Regular Memory

// Use unified memory for easier management (if available)
cudaMallocManaged(&ptr, size);

// Or regular device memory for better performance
cudaMalloc(&ptr, size);

Memory Pooling

  • Implement memory pooling for frequently allocated objects
  • Reuse memory blocks across operations
  • Batch memory operations when possible

Thread Organization

Block and Grid Sizing

// Optimize for your specific algorithm
dim3 blockSize(256);
dim3 gridSize((N + blockSize.x - 1) / blockSize.x);

Warp-Level Primitives

  • Use warp-level operations for better efficiency
  • Align memory accesses with warp boundaries
  • Minimize warp divergence

Synchronization

Cooperative Groups

#include <cooperative_groups.h>
using namespace cooperative_groups;

// Inside a __global__ kernel: get a handle to the current thread block
thread_block block = this_thread_block();
block.sync();  // same effect as __syncthreads()

Memory Barriers

  • Use appropriate memory barriers for consistency
  • Minimize unnecessary synchronization overhead

Performance Optimization Strategies

Memory Bandwidth Optimization

Coalesced Access

// Ensure memory accesses are coalesced
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// Access data[tid] by threads in order for coalescing
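To see why access order matters, the following toy model counts how many 32-byte memory sectors one warp of 32 threads touches when thread `lane` reads `data[lane * stride]` from a float32 array. It is a deliberate simplification (it ignores caches, alignment, and the exact transaction rules of each GPU generation), and the helper names are illustrative, but it shows the trend: consecutive threads reading consecutive elements touch only a few sectors, while large strides cost one sector per thread.

```python
from typing import List


def warp_addresses(stride_elems: int, warp_size: int = 32, elem_bytes: int = 4) -> List[int]:
    # Thread `lane` reads data[lane * stride_elems] (float32 -> 4 bytes per element)
    return [lane * stride_elems * elem_bytes for lane in range(warp_size)]


def sectors_touched(thread_addresses: List[int], sector_bytes: int = 32) -> int:
    """Count distinct memory sectors a warp touches in one load (simplified model)."""
    return len({addr // sector_bytes for addr in thread_addresses})


for stride in (1, 2, 8, 32):
    print(stride, sectors_touched(warp_addresses(stride)))
```

This prints 4, 8, 32, and 32 sectors for strides 1, 2, 8, and 32: the coalesced pattern moves 8x less memory than the widely strided ones.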

Shared Memory Usage

  • Use shared memory for frequently accessed data
  • Implement tiling strategies for large operations
  • Minimize global memory access

Computation Optimization

Warp-Level Operations

  • Leverage warp-level primitives when possible
  • Use vectorized operations (float4, int4)
  • Minimize thread divergence within warps

Kernel Fusion

Combined Operations

  • Fuse multiple operations into single kernels
  • Reduce kernel launch overhead
  • Improve memory locality

Integration with Python

PyTorch Extensions

The CUDA kernels integrate with PyTorch using extensions:

import torch
import yaie_kernels  # Compiled extension

# Use kernel from Python
result = yaie_kernels.radix_attention_forward(
    query, key, value, 
    radix_tree_info, 
    attention_mask
)

Automatic GPU Management

The integration handles:

  • GPU memory allocation
  • Device synchronization
  • Error propagation
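  • Backpropagation support, but only if the kernel is wrapped in a torch.autograd.Function that defines a backward pass (inference-only kernels usually skip this)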

Error Handling and Debugging

Build-Time Errors

Common build issues and solutions:

# CUDA toolkit not found
export CUDA_HOME=/usr/local/cuda

# Architecture mismatch
# Check the GPU's compute capability and adjust the target architecture
python -c "import torch; print(torch.cuda.get_device_capability())"

# Missing PyTorch
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Runtime Error Checking

Kernels should include error checking:

#define CUDA_CHECK(call) \
    do { \
        cudaError_t err = call; \
        if (err != cudaSuccess) { \
            fprintf(stderr, "CUDA error at %s:%d - %s\n", __FILE__, __LINE__, \
                    cudaGetErrorString(err)); \
            exit(1); \
        } \
    } while(0)

Testing and Validation

Kernel Testing

Test kernels with:

  • Unit tests for individual functions
  • Integration tests with Python interface
  • Performance benchmarks
  • Memory correctness validation

SGLang-Specific Tests

Test SGLang optimization features:

  • Prefix sharing correctness
  • Memory management validation
  • Performance gain verification
  • Edge case handling

Development Workflow

Iterative Development

The development process includes:

  1. Kernel Design: Design algorithm for GPU execution
  2. Implementation: Write CUDA kernel code
  3. Building: Compile with build system
  4. Testing: Validate correctness and performance
  5. Optimization: Profile and optimize based on results

Profiling and Optimization

Use NVIDIA tools for optimization:

  • Nsight Systems: Overall system profiling
  • Nsight Compute: Detailed kernel analysis
  • nvprof: Legacy profiling tool (not supported on GPUs with compute capability 8.0 or newer)

Future Extensions

Advanced Features

Potential kernel enhancements:

  • Quantized attention operations
  • Sparse attention kernels
  • Custom activation functions
  • Advanced memory management

Hardware Support

Expand support for:

  • Different GPU architectures
  • Multi-GPU operations
  • Heterogeneous computing
  • Tensor Core optimizations

Together, this build system and CUDA kernel architecture let Mini-YAIE implement SGLang-style performance optimizations while remaining educational and extensible.

Memory Operations (kernels/cuda/memory_ops.cu)

Concept

Paged Attention frequently needs to copy KV-cache data from one physical block to another, for example to duplicate a shared block before one request modifies it.

Implementation Goal

Implement copy_blocks_kernel:

Signature

void copy_blocks_kernel(
    torch::Tensor key_cache,      // [num_blocks, block_size, num_heads, head_dim]
    torch::Tensor value_cache,    // [num_blocks, block_size, num_heads, head_dim]
    torch::Tensor block_mapping,  // [num_mappings, 2] (src, dst)
    int num_mappings
);

Logic

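  1. Parallelism: Launch one thread block per (src, dst) mapping; the threads of that block cooperatively copy the block’s elements.
  2. Indexing:
    • mapping_idx = blockIdx.x
    • src_block = block_mapping[mapping_idx][0]
    • dst_block = block_mapping[mapping_idx][1]
  3. Copy:
    • Each thread starts at offset threadIdx.x and strides by blockDim.x, reading key/value elements from src_block.
    • Write each element to the same offset in dst_block.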
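Before writing the CUDA version, it can help to simulate the indexing on the CPU. This sketch treats one flattened cache (keys or values) as a Python list laid out block by block, with `elems_per_block = block_size * num_heads * head_dim`, and mimics the launch: the outer loop plays `blockIdx.x`, the middle loop `threadIdx.x`, and the inner loop is each thread’s block-stride loop. It copies in place and assumes no block is both a source and a destination across mappings; the function name is illustrative.

```python
from typing import List, Tuple


def copy_blocks(cache: List[float], block_mapping: List[Tuple[int, int]],
                elems_per_block: int, threads_per_block: int = 256) -> None:
    """Simulate copy_blocks_kernel on one flattened cache, in place."""
    for src_block, dst_block in block_mapping:      # one CUDA block per mapping (blockIdx.x)
        for tid in range(threads_per_block):        # threads of that block (threadIdx.x)
            # Block-stride loop: thread tid copies elements tid, tid + blockDim.x, ...
            for offset in range(tid, elems_per_block, threads_per_block):
                cache[dst_block * elems_per_block + offset] = cache[src_block * elems_per_block + offset]


cache = [float(i) for i in range(3 * 24)]           # 3 blocks of 24 elements each
copy_blocks(cache, [(0, 2)], elems_per_block=24, threads_per_block=8)
print(cache[48:52])                                  # block 2 now starts with block 0's data
```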

Flash Attention (kernels/cuda/flash_attention.cu)

1. Concept: Memory Bandwidth

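The main bottleneck in standard Attention is memory traffic: it writes the full $N \times N$ score matrix to slow Global Memory (HBM) and reads it back for the softmax and the multiplication by V. Flash Attention breaks the problem into small “tiles” that fit into the GPU’s fast on-chip SRAM (Shared Memory) and computes everything for a tile there, never materializing the full matrix in Global Memory.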

graph TB
    subgraph GlobalMemory[Global Memory HBM]
        Q[Matrix Q]
        K[Matrix K]
        V[Matrix V]
    end

    subgraph SRAM[Shared Memory SRAM]
        TileQ[Tile Q]
        TileK[Tile K]
        TileV[Tile V]
        Comp(("Compute QK^T * V"))
    end

    Q --> TileQ
    K --> TileK
    V --> TileV

    TileQ --> Comp
    TileK --> Comp
    TileV --> Comp

2. Implementation Guide

We will implement a simplified version. Implementing full FlashAttention-2 is considerably more complex, so we aim for “Tiled Attention”.

Step 0: The Setup

Open src/kernels/cuda/flash_attention.cu. Identify the flash_attention_forward function.

You have pointers to:

  • query (Q), key (K), value (V) residing in Global Memory.

Step 1: Define Thread Layout

We want to process tiles.

  • Grid: one block per (sequence, head) pair; longer inputs can additionally be split into query chunks.
  • Block: the threads within a block cooperate on the elements of that head.

// Example
dim3 grid(num_batches, num_heads);
dim3 block(128); // 128 threads work together on one head

Step 2: Load Tiles to Shared Memory

You need __shared__ memory arrays.

__shared__ float s_Q[TILE_SIZE][HEAD_DIM];
__shared__ float s_K[TILE_SIZE][HEAD_DIM];

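Use threadIdx.x to cooperatively load data from Global Q into shared s_Q. Remember to call __syncthreads() after loading (so every thread sees the complete tile) and again before the next tile overwrites shared memory.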

Step 3: Compute $QK^T$ (Scores)

Iterate over your shared Q and K tiles, computing each dot product and scaling it by $1/\sqrt{d}$, where $d$ is the head dimension. Keep the result in a register (local variable).

Step 4: Softmax (The “Online” Trick)

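In standard softmax, you need the max of the entire row, but here we only see one tile at a time. Trick: keep a running max ($m$) and a running sum ($l$), and update them as each new tile of scores $s_j$ arrives:

  • $m_{new} = \max\left(m_{old}, \max_j s_j\right)$
  • $l_{new} = e^{m_{old} - m_{new}}\, l_{old} + \sum_j e^{s_j - m_{new}}$
  • Rescale everything accumulated from earlier tiles (including the partial output) by the same factor $e^{m_{old} - m_{new}}$.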
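The update rules can be checked on the CPU before touching CUDA. This pure-Python sketch computes one row of attention output from precomputed scores, visiting the keys in tiles of size `tile` and carrying only $m$, $l$, and an unnormalized output accumulator between tiles. `online_attention_row` is a hypothetical helper, and the scores are assumed to already include the $1/\sqrt{d}$ scaling; for any tile size the result should match a direct softmax-weighted sum.

```python
import math
from typing import List


def online_attention_row(scores: List[float], values: List[List[float]], tile: int) -> List[float]:
    """softmax(scores) @ values for one query row, visiting keys one tile at a time."""
    m = -math.inf                     # running max of the scores seen so far
    l = 0.0                           # running sum of exp(score - m)
    acc = [0.0] * len(values[0])      # running (unnormalized) output
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        correction = math.exp(m - m_new)  # rescales old state; exp(-inf) = 0.0 on the first tile
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pj * v[d] for pj, v in zip(p, v_tile))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]         # normalize once, after the last tile


print(online_attention_row([0.3, 2.0, -1.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]], tile=2))
```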

Step 5: Compute Score $\times$ V

Once you have the (unnormalized) probabilities $e^{s_j - m_{new}}$ for the tile, multiply them by s_V (which you also loaded) and accumulate into the output. After the last tile, divide the output by $l$.


3. Hints

  • Start with a Naive kernel first! Forget shared memory. Just loops.
    • Thread per query token.
    • Loop over all key tokens.
    • Compute.
    • This is $O(N^2)$ memory reads but verifies your logic is correct.
  • Only optimize to Shared Memory once logic works.
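A CPU reference for the naive approach is handy for checking any kernel’s output. In this sketch each iteration of the outer loop plays the role of one thread (one query token), which then loops over every key it may see. It assumes a single head, plain Python lists of shape [seq_len][head_dim], a causal mask, and $1/\sqrt{d}$ scaling; the function name is illustrative.

```python
import math
from typing import List

Matrix = List[List[float]]


def naive_causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference attention for one head: one 'thread' (loop iteration) per query token."""
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for i, q_i in enumerate(q):  # thread per query token
        # Loop over all key tokens this query may see (causal: j <= i)
        scores = [scale * sum(a * b for a, b in zip(q_i, k[j])) for j in range(i + 1)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / z
                    for d in range(len(v[0]))])
    return out


# With identical keys every visible token gets equal weight: outputs are running means of v
print(naive_causal_attention([[0.3], [1.0], [-2.0]], [[1.0], [1.0], [1.0]], [[0.0], [3.0], [6.0]]))
```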

Paged Attention (kernels/cuda/paged_attention.cu)

1. Concept: Indirection

Paged Attention is standard attention in which K and V are not stored contiguously: they live in fixed-size blocks scattered through the cache, so we have to “gather” them using a Page Table (the block table).

graph LR
    Thread -->|1. Get Logical idx| Logic[Token #42]
    Logic -->|2. Lookup Table| Table[Block 2, Offset 10]
    Table -->|3. Get Physical Addr| Phys[0xA000...]
    Phys -->|4. Read| Data[Value]

2. Implementation Guide

Step 1: Understand the Block Table

You are passed a block_tables tensor of shape [num_seqs, max_blocks].

  • It holds integer indices of physical blocks.
  • block_tables[req_id][0] is the first block of that request.

Step 2: Calculate Physical Address

Inside your kernel, you want the Key vector for token t of request r.

int block_idx = t / BLOCK_SIZE;
int block_offset = t % BLOCK_SIZE;
int physical_block_number = block_tables[r][block_idx];

// Pointer arithmetic
float* k_ptr = key_cache_base
             + physical_block_number * (BLOCK_SIZE * HEAD_DIM * NUM_HEADS)
             + ... // navigate to specific head and offset
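The elided part of the pointer arithmetic depends on how the cache is laid out. Assuming the [num_blocks, block_size, num_heads, head_dim] layout used by the paged attention kernel later in this guide, the full element offset can be sketched in Python as follows; `paged_key_offset` is a hypothetical helper.

```python
from typing import List


def paged_key_offset(block_tables: List[List[int]], r: int, t: int, head: int,
                     block_size: int, num_heads: int, head_dim: int) -> int:
    """Flat offset of K[r][t][head][0] in a cache laid out as
    [num_blocks, block_size, num_heads, head_dim] (assumed layout)."""
    block_idx = t // block_size       # which logical block holds token t
    block_offset = t % block_size     # position of token t inside that block
    physical_block = block_tables[r][block_idx]
    return ((physical_block * block_size + block_offset) * num_heads + head) * head_dim


print(paged_key_offset([[7, 2]], r=0, t=5, head=1, block_size=4, num_heads=2, head_dim=8))
```

With block_tables = [[7, 2]], block_size = 4, num_heads = 2, and head_dim = 8, token 5 of request 0 sits in physical block 2 at offset 1, so head 1 starts at element 2·64 + 1·16 + 1·8 = 152.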

Step 3: Load Data

Using the pointer k_ptr, load the vector into registers or shared memory.

Step 4: Compute Attention

Once loaded, the math is identical to standard Attention or Flash Attention: $\mathrm{softmax}\left(QK^T / \sqrt{d}\right) V$.


3. Your Task

Implement paged_attention_kernel in src/kernels/cuda/paged_attention.cu.

  1. Focus on the address calculation logic. That is the only difference!
  2. Use the copy_blocks kernel (Memory Ops) to help set up test data if needed.

Radix Operations (kernels/cuda/radix_ops.cu)

Concept

If we have a Radix Tree, we can optimize attention even further by knowing exactly which tokens are shared.

Implementation Goal

This is an advanced extension.

Logic

  1. Tree Traversal on GPU: Mapping the Radix Tree structure to a GPU-friendly format (e.g., flattened arrays).
  2. Prefix Matching: A kernel that takes a batch of prompts and quickly identifies the longest common prefix node ID for each.

Note: In the simplified version, this logic usually stays on the CPU (in Python) and only the resulting KV indices are passed to the GPU.
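One possible GPU-friendly layout is a compressed sparse row (CSR) encoding of the tree: node `n`’s children occupy the slice `child_start[n]:child_start[n + 1]` of two parallel arrays holding their tokens (sorted) and node ids. The Python sketch below builds these arrays from token sequences and performs a longest-prefix match using only array reads and one binary search per step, the kind of loop a single GPU thread per prompt could run. It is an illustration under those assumptions, not Mini-YAIE’s format, and like the Python RadixTree in this guide it stores one token per node rather than compressing edges.

```python
from bisect import bisect_left
from typing import Dict, List, Tuple


def flatten_trie(sequences: List[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Build a token trie and flatten it into CSR-style arrays (node 0 is the root)."""
    children: List[Dict[int, int]] = [{}]
    for seq in sequences:
        node = 0
        for tok in seq:
            if tok not in children[node]:
                children[node][tok] = len(children)  # new node id
                children.append({})
            node = children[node][tok]
    child_start, child_token, child_node = [0], [], []
    for kids in children:
        for tok in sorted(kids):                     # sorted so lookups can binary-search
            child_token.append(tok)
            child_node.append(kids[tok])
        child_start.append(len(child_token))
    return child_start, child_token, child_node


def longest_prefix(prompt: List[int], child_start: List[int], child_token: List[int],
                   child_node: List[int]) -> Tuple[int, int]:
    """Return (deepest matched node id, matched prefix length)."""
    node, length = 0, 0
    for tok in prompt:
        lo, hi = child_start[node], child_start[node + 1]
        j = bisect_left(child_token, tok, lo, hi)    # binary search among node's children
        if j == hi or child_token[j] != tok:
            break
        node, length = child_node[j], length + 1
    return node, length


arrays = flatten_trie([[1, 2, 3], [1, 2, 4], [5]])
print(longest_prefix([1, 2, 4, 9], *arrays))
```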

Kernels Implementation Guide

Overview

This guide provides a comprehensive walkthrough for implementing the core kernels in Mini-YAIE that enable SGLang-style inference optimization. The implementation focuses on three key areas:

  1. Python Implementations: Educational implementations of core algorithms
  2. CUDA Kernels: Performance-optimized GPU implementations
  3. Integration: Connecting kernels with the main inference engine

Implementation Roadmap

Phase 1: Core Python Kernels

Implement the educational Python versions first:

  1. Radix tree for prefix matching
  2. Basic attention mechanisms
  3. KV-cache management
  4. Sampling algorithms

Phase 2: CUDA Kernel Development

Develop optimized GPU versions:

  1. Memory operations kernels
  2. Paged attention implementation
  3. Flash attention optimization
  4. Radix operations acceleration

Phase 3: Integration and Optimization

Connect kernels to the main system:

  1. Engine integration
  2. Performance validation
  3. Correctness verification

Python Kernel Implementation

1. Radix Tree Implementation

Start with the radix tree that enables prefix sharing:

File: src/kernels/radix_tree.py

from typing import Dict, List, Optional, Tuple

class RadixTreeNode:
    def __init__(self, token_id: Optional[int] = None):
        self.token_id = token_id
        self.children: Dict[int, "RadixTreeNode"] = {}
        self.request_ids: List[str] = []
        self.kv_cache_refs: List[str] = []
        self.is_terminal = False

class RadixTree:
    def __init__(self):
        self.root = RadixTreeNode()
        self.request_to_path: Dict[str, List[int]] = {}
        self.path_to_node: Dict[str, RadixTreeNode] = {}
    
    def insert_request(self, request_id: str, token_ids: List[int]):
        """Insert a request into the radix tree based on its token sequence"""
        current = self.root
        for token_id in token_ids:
            if token_id not in current.children:
                current.children[token_id] = RadixTreeNode(token_id)
            current = current.children[token_id]
            if request_id not in current.request_ids:
                current.request_ids.append(request_id)
        current.is_terminal = True
        self.request_to_path[request_id] = token_ids
        path_str = self._path_to_string(token_ids)
        self.path_to_node[path_str] = current
    
    def find_shared_prefixes(self, token_ids: List[int]) -> Tuple[List[str], int]:
        """Find requests that share prefixes with the given token sequence"""
        current = self.root
        matched_requests = []
        prefix_length = 0
        
        for i, token_id in enumerate(token_ids):
            if token_id in current.children:
                current = current.children[token_id]
                matched_requests.extend(current.request_ids)
                prefix_length = i + 1
            else:
                break
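        return list(set(matched_requests)), prefix_length
    
    def _path_to_string(self, token_ids: List[int]) -> str:
        # Assumed implementation: path_to_node is keyed by the comma-joined token ids
        return ",".join(str(t) for t in token_ids)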

2. KV-Cache Management

Implement the paged KV-cache system:

File: src/kernels/kv_cache.py

import torch
from typing import List

class KVCacheBlock:
    def __init__(self, block_id: int, size: int, num_heads: int, head_dim: int, dtype=torch.float16):
        self.block_id = block_id
        self.size = size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.keys = None
        self.values = None
    
    def allocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.keys = torch.zeros(self.size, self.num_heads, self.head_dim, dtype=self.dtype, device=device)
        self.values = torch.zeros(self.size, self.num_heads, self.head_dim, dtype=self.dtype, device=device)

class KVCacheManager:
    def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int, dtype=torch.float16):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype
        
        self.blocks: List[KVCacheBlock] = []
        for i in range(num_blocks):
            block = KVCacheBlock(i, block_size, num_heads, head_dim, dtype)
            self.blocks.append(block)
        
        self.free_block_list: List[int] = list(range(num_blocks))
        self.block_tables: dict = {}
    
    def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_block_list) < num_blocks_needed:
            raise RuntimeError(f"Not enough free blocks. Need {num_blocks_needed}, have {len(self.free_block_list)}")
        
        allocated_block_ids = []
        for _ in range(num_blocks_needed):
            block_id = self.free_block_list.pop(0)
            allocated_block_ids.append(block_id)
            self.blocks[block_id].allocate()
        
        self.block_tables[request_id] = allocated_block_ids
        return allocated_block_ids

3. Radix Attention Implementation

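Implement the radix attention mechanism. In this Python version it is self-attention with rotary position embeddings (RoPE); a shared prefix is reused by passing its cached keys and values as past_key_value: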

File: src/kernels/radix_attention.py

import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    # cos/sin are already gathered for the tokens' positions by the caller,
    # so position_ids is unused here (kept for API compatibility).
    # [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
    # [batch, seq_len, head_dim] -> [batch, 1, seq_len, head_dim]
    if cos.dim() == 2:
        cos, sin = cos[None, None], sin[None, None]
    else:
        cos, sin = cos[:, None], sin[:, None]
    
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class RadixAttentionBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, max_position_embeddings: int = 2048):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        
        total_hidden_dim = num_heads * head_dim
        
        self.q_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
        self.o_proj = nn.Linear(total_hidden_dim, hidden_size, bias=False)
        
        self.register_buffer(
            "cos_cached",
            torch.ones((max_position_embeddings, head_dim), dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            "sin_cached",
            torch.ones((max_position_embeddings, head_dim), dtype=torch.float32),
            persistent=False,
        )
        
        self._setup_rope_embeddings()
    
    def _setup_rope_embeddings(self):
        position_ids = torch.arange(self.max_position_embeddings, dtype=torch.float32)
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim))
        
        freqs = torch.einsum("i,j->ij", position_ids, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        
        self.cos_cached = emb.cos().to(dtype=torch.float16)
        self.sin_cached = emb.sin().to(dtype=torch.float16)
    
    def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None, 
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        # Number of tokens already in the KV cache (e.g. a shared prefix)
        past_len = 0 if past_key_value is None else past_key_value[0].shape[2]
        
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)
        
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # New tokens occupy positions past_len .. past_len + seq_len - 1
        if position_ids is None:
            position_ids = torch.arange(past_len, past_len + seq_len, device=hidden_states.device)
        cos_to_use = self.cos_cached[position_ids].to(query.dtype)
        sin_to_use = self.sin_cached[position_ids].to(query.dtype)
        
        query, key = apply_rotary_pos_emb(query, key, cos_to_use, sin_to_use)
        
        if past_key_value is not None:
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)
        
        attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
        # Causal mask: query i (absolute position past_len + i) may attend keys 0 .. past_len + i
        causal_mask = torch.ones(seq_len, key.shape[2], dtype=torch.bool, device=query.device).tril(diagonal=past_len)
        attn_weights = attn_weights.masked_fill(~causal_mask, float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        
        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        
        output = self.o_proj(attn_output)
        return output

4. Sampling Kernel Implementation

Implement the token sampling system:

File: src/kernels/sampling.py

import torch

class SamplingKernel:
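    def sample(self, logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1) -> torch.Tensor:
        if temperature < 1e-5:
            # T ~ 0: greedy decoding (avoids dividing by ~0)
            return torch.argmax(logits, dim=-1)
        if temperature != 1.0:
            logits = logits / temperature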
        
        probs = torch.softmax(logits, dim=-1)
        batch_size, vocab_size = probs.shape
        
        if top_k > 0:
            top_k = min(top_k, vocab_size)
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
            
            new_probs = torch.zeros_like(probs)
            new_probs.scatter_(1, top_k_indices, top_k_probs)
            new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True)
            probs = new_probs
        
        if 0 < top_p < 1.0:
            sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            
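            # Keep tokens whose preceding cumulative mass is below top_p ("shift right by one"):
            # this keeps the first token that crosses the threshold, and always the top token
            mask = (cumulative_probs - sorted_probs) < top_p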
            
            filtered_probs = torch.zeros_like(probs)
            filtered_probs.scatter_(1, sorted_indices, mask.float() * sorted_probs)
            filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
            probs = filtered_probs
        
        sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
        return sampled_ids

CUDA Kernel Implementation

1. Memory Operations Kernels

File: src/kernels/cuda/memory_ops.cu

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <tuple>

// One CUDA block per (src, dst) mapping; its threads copy the block's elements.
__global__ void copy_blocks_kernel(
    const float* __restrict__ key_cache, const float* __restrict__ value_cache,
    float* __restrict__ new_key_cache, float* __restrict__ new_value_cache,
    const int* __restrict__ block_mapping,  // [src_block_id, dst_block_id] pairs
    int num_heads, int head_dim, int block_size,
    int num_mappings
) {
    int mapping_idx = blockIdx.x;
    if (mapping_idx >= num_mappings) return;
    
    int src_block_id = block_mapping[mapping_idx * 2];
    int dst_block_id = block_mapping[mapping_idx * 2 + 1];
    
    int total_elements_per_block = block_size * num_heads * head_dim;
    
    // Block-stride loop: a cache block usually holds more elements than there are threads
    for (int i = threadIdx.x; i < total_elements_per_block; i += blockDim.x) {
        int src_idx = src_block_id * total_elements_per_block + i;
        int dst_idx = dst_block_id * total_elements_per_block + i;
        new_key_cache[dst_idx] = key_cache[src_idx];
        new_value_cache[dst_idx] = value_cache[src_idx];
    }
}

// Assumes contiguous float32 caches and an int32 block_mapping of shape [num_mappings, 2]
std::tuple<torch::Tensor, torch::Tensor> copy_blocks_cuda(
    torch::Tensor key_cache, torch::Tensor value_cache,
    torch::Tensor block_mapping,
    int num_heads, int head_dim, int block_size
) {
    TORCH_CHECK(block_mapping.scalar_type() == torch::kInt32, "block_mapping must be int32");
    block_mapping = block_mapping.contiguous();
    int num_mappings = block_mapping.size(0);
    
    // Start from copies so blocks that are not remapped keep their contents
    auto new_key_cache = key_cache.clone();
    auto new_value_cache = value_cache.clone();
    if (num_mappings == 0) {
        return std::make_tuple(new_key_cache, new_value_cache);
    }
    
    dim3 grid(num_mappings);
    dim3 block(256);  // 256 threads per block
    
    copy_blocks_kernel<<<grid, block>>>(
        key_cache.data_ptr<float>(),
        value_cache.data_ptr<float>(),
        new_key_cache.data_ptr<float>(),
        new_value_cache.data_ptr<float>(),
        block_mapping.data_ptr<int>(),
        num_heads, head_dim, block_size,
        num_mappings
    );
    
    cudaDeviceSynchronize();
    return std::make_tuple(new_key_cache, new_value_cache);
}

2. Paged Attention Kernels

File: src/kernels/cuda/paged_attention.cu

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>

// Naive but correct paged attention: one thread per (sequence, query head, query token).
// Each thread walks its sequence's context through the block table and applies an
// online softmax, so no score matrix is ever materialized.
__global__ void paged_attention_kernel(
    float* __restrict__ output,             // [num_seqs, seq_len, num_heads, head_dim]
    const float* __restrict__ query,        // [num_seqs, seq_len, num_heads, head_dim]
    const float* __restrict__ key_cache,    // [num_blocks, block_size, num_kv_heads, head_dim]
    const float* __restrict__ value_cache,  // [num_blocks, block_size, num_kv_heads, head_dim]
    const int* __restrict__ block_tables,   // [num_seqs, max_blocks_per_seq]
    const int* __restrict__ context_lens,   // [num_seqs], including the seq_len new tokens
    const int seq_len,
    const int num_heads,
    const int num_kv_heads,
    const int num_queries_per_kv,
    const int head_dim,
    const int block_size,
    const int max_num_blocks_per_seq,
    const float scale                       // 1 / sqrt(head_dim)
) {
    int seq_idx = blockIdx.x;
    int q_head_idx = blockIdx.y;
    int q_pos = blockIdx.z * blockDim.x + threadIdx.x;  // query token within this step
    if (q_pos >= seq_len) return;

    // Grouped-query attention: several query heads share one KV head
    int kv_head_idx = q_head_idx / num_queries_per_kv;

    // The new tokens are the last seq_len tokens of the context
    int abs_pos = context_lens[seq_idx] - seq_len + q_pos;

    int row = (seq_idx * seq_len + q_pos) * num_heads + q_head_idx;
    const float* q = query + row * head_dim;
    float* out = output + row * head_dim;  // doubles as the unnormalized accumulator

    for (int d = 0; d < head_dim; d++) out[d] = 0.0f;
    float m = -INFINITY;  // running max score
    float l = 0.0f;       // running sum of exp(score - m)

    for (int t = 0; t <= abs_pos; t++) {  // causal: keys up to this query's position
        // Logical token t -> (physical block, offset) via the block table
        int physical_block = block_tables[seq_idx * max_num_blocks_per_seq + t / block_size];
        int cache_idx = ((physical_block * block_size + t % block_size) * num_kv_heads
                         + kv_head_idx) * head_dim;

        float s = 0.0f;
        for (int d = 0; d < head_dim; d++) s += q[d] * key_cache[cache_idx + d];
        s *= scale;

        // Online softmax update (exp(-inf) = 0 on the first key)
        float m_new = fmaxf(m, s);
        float correction = expf(m - m_new);
        float p = expf(s - m_new);
        l = l * correction + p;
        for (int d = 0; d < head_dim; d++) {
            out[d] = out[d] * correction + p * value_cache[cache_idx + d];
        }
        m = m_new;
    }

    for (int d = 0; d < head_dim; d++) out[d] /= l;
}

torch::Tensor paged_attention_cuda(
    torch::Tensor query, torch::Tensor key_cache, torch::Tensor value_cache,
    torch::Tensor block_tables, torch::Tensor context_lens,
    int num_kv_heads, int num_queries_per_kv
) {
    // Assumes contiguous float32 query/caches and int32 index tensors
    TORCH_CHECK(block_tables.scalar_type() == torch::kInt32, "block_tables must be int32");
    TORCH_CHECK(context_lens.scalar_type() == torch::kInt32, "context_lens must be int32");
    query = query.contiguous();

    int num_seqs = query.size(0);
    int seq_len = query.size(1);
    int num_heads = query.size(2);
    int head_dim = query.size(3);
    int block_size = key_cache.size(1);
    int max_blocks_per_seq = block_tables.size(1);

    auto output = torch::empty_like(query);

    const int threads = 128;
    dim3 grid(num_seqs, num_heads, (seq_len + threads - 1) / threads);
    dim3 block(threads);
    float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));

    paged_attention_kernel<<<grid, block>>>(
        output.data_ptr<float>(),
        query.data_ptr<float>(),
        key_cache.data_ptr<float>(),
        value_cache.data_ptr<float>(),
        block_tables.data_ptr<int>(),
        context_lens.data_ptr<int>(),
        seq_len, num_heads, num_kv_heads, num_queries_per_kv,
        head_dim, block_size, max_blocks_per_seq, scale
    );

    cudaDeviceSynchronize();
    return output;
}

3. Flash Attention Kernels (Simplified)

File: src/kernels/cuda/flash_attention.cu

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#define TILE_SIZE 32  // Small tile for educational purposes

// query/key/value/output layout: [num_seqs, max_seq_len, num_heads, head_dim], float32.
// One thread block handles one (sequence, head) pair. Thread t < TILE_SIZE owns
// query row t of the current Q tile, so blockDim.x must be >= TILE_SIZE.
// Softmax is computed online (running max m_i and running sum l_i per row), so each
// K/V tile rescales and accumulates into the output instead of overwriting it.
__global__ void flash_attention_kernel(
    float* output,
    const float* query,
    const float* key,
    const float* value,
    const int* seq_lens,
    const int num_seqs,
    const int num_heads,
    const int head_dim,
    const int max_seq_len
) {
    int seq_idx = blockIdx.x;
    int head_idx = blockIdx.y;
    
    if (seq_idx >= num_seqs || head_idx >= num_heads) return;
    
    int current_seq_len = seq_lens[seq_idx];
    float scale = rsqrtf((float)head_dim);  // 1 / sqrt(head_dim)
    
    // Offset of token 0 for this (sequence, head) and the stride between tokens
    int token_stride = num_heads * head_dim;
    int base = seq_idx * max_seq_len * token_stride + head_idx * head_dim;
    
    // Shared memory for tiles
    extern __shared__ float shared_mem[];
    float* s_Q = shared_mem;
    float* s_K = s_Q + TILE_SIZE * head_dim;
    float* s_V = s_K + TILE_SIZE * head_dim;
    float* s_scores = s_V + TILE_SIZE * head_dim;
    
    int q_local = threadIdx.x;
    
    for (int q_tile_start = 0; q_tile_start < current_seq_len; q_tile_start += TILE_SIZE) {
        __syncthreads();  // all threads are done reading the previous Q tile
        
        // Load Q tile to shared memory (zero-pad past the end of the sequence)
        for (int i = threadIdx.x; i < TILE_SIZE * head_dim; i += blockDim.x) {
            int q_row = q_tile_start + i / head_dim;
            s_Q[i] = (q_row < current_seq_len)
                         ? query[base + q_row * token_stride + i % head_dim]
                         : 0.0f;
        }
        
        int q_global = q_tile_start + q_local;
        bool active = (q_local < TILE_SIZE) && (q_global < current_seq_len);
        float m_i = -INFINITY;  // running max score for this query row
        float l_i = 0.0f;       // running softmax denominator
        // `output` starts as zeros (torch::zeros_like) and holds the unnormalized
        // accumulator for this row; each row is written by exactly one thread.
        
        // Causal mask: keys after the last query of this tile are never attended
        int k_end = min(current_seq_len, q_tile_start + TILE_SIZE);
        for (int k_tile_start = 0; k_tile_start < k_end; k_tile_start += TILE_SIZE) {
            __syncthreads();  // Q tile is visible; previous K/V tile is no longer in use
            
            // Load K and V tiles
            for (int i = threadIdx.x; i < TILE_SIZE * head_dim; i += blockDim.x) {
                int k_row = k_tile_start + i / head_dim;
                bool valid = k_row < current_seq_len;
                int k_idx = base + k_row * token_stride + i % head_dim;
                s_K[i] = valid ? key[k_idx] : 0.0f;
                s_V[i] = valid ? value[k_idx] : 0.0f;
            }
            
            __syncthreads();
            
            if (active) {
                float* row_scores = s_scores + q_local * TILE_SIZE;
                
                // Scaled scores against this K tile; masked positions get -inf
                float tile_max = -INFINITY;
                for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
                    int k_global = k_tile_start + k_local;
                    float score = -INFINITY;
                    if (k_global < current_seq_len && k_global <= q_global) {
                        score = 0.0f;
                        for (int d = 0; d < head_dim; d++) {
                            score += s_Q[q_local * head_dim + d] * s_K[k_local * head_dim + d];
                        }
                        score *= scale;
                    }
                    row_scores[k_local] = score;
                    tile_max = fmaxf(tile_max, score);
                }
                
                // Online softmax: rescale the previous state to the new running max
                float m_new = fmaxf(m_i, tile_max);
                float correction = (m_i == -INFINITY) ? 0.0f : __expf(m_i - m_new);
                float tile_sum = 0.0f;
                for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
                    float s = row_scores[k_local];
                    float p = (s == -INFINITY) ? 0.0f : __expf(s - m_new);
                    row_scores[k_local] = p;
                    tile_sum += p;
                }
                
                // acc = acc * correction + P @ V_tile
                float* out_row = output + base + q_global * token_stride;
                for (int d = 0; d < head_dim; d++) {
                    float pv = 0.0f;
                    for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
                        pv += row_scores[k_local] * s_V[k_local * head_dim + d];
                    }
                    out_row[d] = out_row[d] * correction + pv;
                }
                l_i = l_i * correction + tile_sum;
                m_i = m_new;
            }
        }
        
        // Final normalization: divide the accumulator by the softmax denominator
        if (active && l_i > 0.0f) {
            float* out_row = output + base + q_global * token_stride;
            for (int d = 0; d < head_dim; d++) out_row[d] /= l_i;
        }
    }
}

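torch::Tensor flash_attention_cuda(
    torch::Tensor query, torch::Tensor key, torch::Tensor value,
    torch::Tensor seq_lens
) {
    // This simplified kernel only handles float32 CUDA tensors shaped
    // [num_seqs, max_seq_len, num_heads, head_dim] and int32 seq_lens.
    TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda() && seq_lens.is_cuda(),
                "all inputs must be CUDA tensors");
    TORCH_CHECK(query.scalar_type() == torch::kFloat32, "query/key/value must be float32");
    TORCH_CHECK(seq_lens.scalar_type() == torch::kInt32, "seq_lens must be int32");
    TORCH_CHECK(query.sizes() == key.sizes() && query.sizes() == value.sizes(),
                "query, key and value must have the same shape");
    query = query.contiguous();
    key = key.contiguous();
    value = value.contiguous();
    seq_lens = seq_lens.contiguous();
    
    int num_seqs = query.size(0);
    int max_seq_len = query.size(1);
    int num_heads = query.size(2);
    int head_dim = query.size(3);
    
    auto output = torch::zeros_like(query);
    
    dim3 grid(num_seqs, num_heads);
    dim3 block(256);
    
    // Shared memory for 3 tiles (Q, K, V) + the TILE_SIZE x TILE_SIZE scores matrix
    size_t shared_mem_size = (3 * TILE_SIZE * head_dim + TILE_SIZE * TILE_SIZE) * sizeof(float);
    // Above 48 KB a kernel must opt in via cudaFuncSetAttribute. With TILE_SIZE 32 the
    // limit is exceeded for head_dim > 117 (head_dim 128 needs 52 KB).
    TORCH_CHECK(shared_mem_size <= 48 * 1024, "head_dim too large for 48 KB of shared memory");
    
    flash_attention_kernel<<<grid, block, shared_mem_size>>>(
        output.data_ptr<float>(),
        query.data_ptr<float>(),
        key.data_ptr<float>(),
        value.data_ptr<float>(),
        seq_lens.data_ptr<int>(),
        num_seqs,
        num_heads,
        head_dim,
        max_seq_len
    );
    
    // Surface launch failures (bad configuration, too much shared memory)
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "flash_attention_kernel launch failed: ", cudaGetErrorString(err));
    cudaDeviceSynchronize();
    return output;
}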
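FlashAttention-style kernels process K/V in tiles and keep, per query row, a running maximum score and a running softmax denominator; when a new tile raises the maximum, the partial output and denominator are rescaled by exp(old_max - new_max). A pure-Python sketch of that online-softmax update for a single query, useful as a reference when checking a tiled kernel. The function names are illustrative, and `keys`/`values` are assumed to be the causal prefix the query may attend to.

```python
import math
from typing import List

Vector = List[float]


def naive_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference: softmax(q . K / sqrt(d)) @ V computed in one pass."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def tiled_attention(q: Vector, keys: List[Vector], values: List[Vector],
                    tile_size: int = 2) -> Vector:
    """Same result, visiting K/V one tile at a time with an online softmax."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    m = -math.inf      # running max score
    l = 0.0            # running softmax denominator
    acc = [0.0] * dim  # unnormalized output
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - m_new) for s in scores]
        acc = [a * correction + sum(pi * v[d] for pi, v in zip(p, v_tile))
               for d, a in enumerate(acc)]
        l = l * correction + sum(p)
        m = m_new
    return [a / l for a in acc]
```

Because every tile rescales the earlier state rather than replacing it, the tile size only changes the order of floating-point operations, not the result.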

Integration and Testing

1. Python-CUDA Interface

Create the Python interface for CUDA kernels:

File: src/kernels/api.py

"""
API for accessing both Python and CUDA implementations of kernels
"""

import torch
import torch.nn.functional as F
from typing import Optional

# Try to import the CUDA extension. setup_kernels.py builds it as the
# top-level module `yaie_kernels`, not as a submodule of this package.
try:
    import yaie_kernels
    CUDA_AVAILABLE = True
except ImportError:
    print("CUDA extensions not available. Using Python implementations.")
    CUDA_AVAILABLE = False

def attention_forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                      use_cuda: bool = True, seq_lens: Optional[torch.Tensor] = None,
                      **kwargs):
    """Causal attention on [batch, seq_len, num_heads, head_dim] tensors.

    Uses the CUDA kernel (float32 only) when available, otherwise a PyTorch
    reference. seq_lens holds each sequence's real length (default: seq_len).
    """
    batch, seq_len = query.shape[0], query.shape[1]
    if seq_lens is None:
        seq_lens = torch.full((batch,), seq_len, dtype=torch.int32)
    seq_lens = seq_lens.to(device=query.device, dtype=torch.int32)

    if CUDA_AVAILABLE and use_cuda and query.is_cuda and query.dtype == torch.float32:
        return yaie_kernels.flash_attention_cuda(query, key, value, seq_lens)

    # Fallback: scaled_dot_product_attention expects [batch, heads, seq, dim].
    # With right padding, the causal mask already hides padded keys from real queries.
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    # Zero the padded query positions, as the CUDA kernel does
    valid = torch.arange(seq_len, device=query.device)[None, :] < seq_lens[:, None]
    return out * valid[:, :, None, None].to(out.dtype)

def paged_attention_forward(query: torch.Tensor, key_cache: torch.Tensor, 
                           value_cache: torch.Tensor, block_tables: torch.Tensor,
                           context_lens: torch.Tensor, use_cuda: bool = True, **kwargs):
    """Paged attention interface"""
    if CUDA_AVAILABLE and use_cuda and query.is_cuda:
        return yaie_kernels.paged_attention_cuda(
            query, key_cache, value_cache, block_tables, context_lens,
            kwargs.get('num_kv_heads', 1),
            kwargs.get('num_queries_per_kv', 1)
        )
    else:
        # Python fallback would go here
        raise NotImplementedError("Paged attention Python fallback not implemented")

def copy_blocks(key_cache: torch.Tensor, value_cache: torch.Tensor, 
                block_mapping: torch.Tensor, use_cuda: bool = True, **kwargs):
    """Memory copy interface"""
    if CUDA_AVAILABLE and use_cuda and key_cache.is_cuda:
        return yaie_kernels.copy_blocks_cuda(
            key_cache, value_cache, block_mapping,
            kwargs.get('num_heads', 1),
            kwargs.get('head_dim', 64),
            kwargs.get('block_size', 16)
        )
    else:
        # Python fallback would go here
        raise NotImplementedError("Copy blocks Python fallback not implemented")
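The missing paged-attention fallback boils down to one indexing rule: token `pos` of a sequence lives in physical block `block_table[pos // block_size]` at slot `pos % block_size`. A minimal pure-Python sketch of that gather followed by ordinary softmax attention for one query; the cache layout `[num_blocks][block_size][vector]` and both function names are assumptions for illustration, not the layout of Mini-YAIE's tensors.

```python
import math
from typing import List, Sequence

Vector = List[float]


def gather_from_blocks(cache: Sequence[Sequence[Vector]], block_table: Sequence[int],
                       context_len: int, block_size: int) -> List[Vector]:
    """Return the first `context_len` cached vectors of one sequence, in order."""
    out = []
    for pos in range(context_len):
        physical_block = block_table[pos // block_size]  # logical -> physical block
        offset = pos % block_size                         # slot inside that block
        out.append(cache[physical_block][offset])
    return out


def paged_attention_one_query(q: Vector, key_cache, value_cache, block_table: Sequence[int],
                              context_len: int, block_size: int) -> Vector:
    """Attention for one query over a sequence stored in non-contiguous blocks."""
    keys = gather_from_blocks(key_cache, block_table, context_len, block_size)
    values = gather_from_blocks(value_cache, block_table, context_len, block_size)
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]
```

A real fallback would vectorize the gather with tensor indexing, but the block-table arithmetic is the same.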

2. Testing Framework

Create comprehensive tests:

File: tests/test_kernels.py

import pytest
import torch
import numpy as np

from src.kernels.radix_tree import RadixTree
from src.kernels.kv_cache import KVCacheManager
from src.kernels.radix_attention import RadixAttentionBlock
from src.kernels.sampling import SamplingKernel
from src.kernels.api import attention_forward, paged_attention_forward, copy_blocks

class TestRadixTree:
    def test_basic_insertion_and_search(self):
        tree = RadixTree()
        
        # Insert requests
        tree.insert_request("req1", [1, 2, 3])
        tree.insert_request("req2", [1, 2, 4])  # Shares prefix [1, 2]
        tree.insert_request("req3", [5, 6, 7])  # No shared prefix
        
        # Test shared prefixes
        shared, length = tree.find_shared_prefixes([1, 2, 5])
        assert "req1" in shared
        assert "req2" in shared
        assert length == 2  # Common prefix [1, 2]
    
    def test_prefix_sharing_graph(self):
        tree = RadixTree()
        tree.insert_request("req1", [1, 2, 3])
        tree.insert_request("req2", [1, 2, 4])
        tree.insert_request("req3", [1, 5, 6])
        
        graph = tree.get_shared_computation_graph()
        # Should show shared computation at token [1]
        assert graph["request_count"] == 3  # All requests start with root
        
class TestKVCacheManager:
    def test_basic_allocation(self):
        cache_manager = KVCacheManager(
            num_blocks=100,
            block_size=16,
            num_heads=8,
            head_dim=64,
            dtype=torch.float16
        )
        
        # Allocate blocks for a request
        blocks = cache_manager.allocate_blocks("req1", 20)  # 20 tokens need ceil(20 / 16) = 2 blocks
        assert len(blocks) == 2
        
        # Verify the blocks exist and have proper tensors
        for block_id in blocks:
            block = cache_manager.blocks[block_id]
            assert block.keys is not None
            assert block.values is not None
            assert block.keys.shape == (16, 8, 64)  # block_size, num_heads, head_dim
    
    def test_block_reuse(self):
        cache_manager = KVCacheManager(
            num_blocks=10,
            block_size=16,
            num_heads=8,
            head_dim=64
        )
        
        # Allocate all blocks: 10 requests x 10 tokens, one 16-token block each
        req_ids = [f"req{i}" for i in range(10)]
        for req_id in req_ids:
            cache_manager.allocate_blocks(req_id, 10)
        
        assert len(cache_manager.free_block_list) == 0
        
        # Free some blocks
        cache_manager.free_blocks("req0")
        cache_manager.free_blocks("req1")
        
        assert len(cache_manager.free_block_list) == 2
        # Assumes blocks are handed out in ascending ID order (req0 -> 0, req1 -> 1)
        assert 0 in cache_manager.free_block_list
        assert 1 in cache_manager.free_block_list

class TestRadixAttention:
    def test_basic_attention_forward(self):
        hidden_size = 512
        num_heads = 8
        head_dim = hidden_size // num_heads
        
        attention = RadixAttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=head_dim,
            max_position_embeddings=256
        )
        
        batch_size = 2
        seq_len = 10
        # float32: half-precision CPU inputs would not match the module's float32 weights
        x = torch.randn(batch_size, seq_len, hidden_size)
        
        output = attention(x)
        assert output.shape == (batch_size, seq_len, hidden_size)
        assert not torch.isnan(output).any()
    
    def test_attention_with_past_key_value(self):
        hidden_size = 256
        num_heads = 4
        head_dim = hidden_size // num_heads
        
        attention = RadixAttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=head_dim
        )
        
        batch_size = 1
        seq_len = 5
        x = torch.randn(batch_size, seq_len, hidden_size)
        
        # First forward pass
        output1, _, past_kv = attention(x, use_cache=True)
        
        # Second forward pass with past key-value
        next_token = torch.randn(batch_size, 1, hidden_size)
        output2, _, _ = attention(next_token, past_key_value=past_kv)
        
        assert output2.shape == (batch_size, 1, hidden_size)

class TestSamplingKernel:
    def test_temperature_sampling(self):
        sampling = SamplingKernel()
        
        # Create logits with one clear winner
        logits = torch.tensor([[10.0, 1.0, 1.0, 1.0]])  # First token is dominant
        
        # High temperature flattens the distribution, so any token may be drawn;
        # only the output shape is checked
        sampled_high_temp = sampling.sample(logits, temperature=2.0)
        assert sampled_high_temp.shape == (1,)
        
        # Low temperature should favor dominant token
        sampled_low_temp = sampling.sample(logits, temperature=0.1)
        assert sampled_low_temp[0] == 0  # Should pick the dominant token
    
    def test_top_p_nucleus_sampling(self):
        sampling = SamplingKernel()
        
        # Softmax probabilities are about [0.51, 0.31, 0.19, 3e-6, 3e-6],
        # so the first 3 tokens hold more than 99.99% of the mass
        logits = torch.tensor([[2.0, 1.5, 1.0, -10.0, -10.0]])
        
        # With top_p=0.8 the usual rule keeps the smallest set reaching 80%,
        # i.e. tokens 0 and 1 (0.506 + 0.307 = 0.81); the check below is looser
        sampled = sampling.sample(logits, top_p=0.8)
        assert sampled[0] in [0, 1, 2]
    
    def test_top_k_sampling(self):
        sampling = SamplingKernel()
        
        # Create logits with clear ordering
        logits = torch.tensor([[5.0, 4.0, 3.0, 2.0, 1.0]])
        
        # Top-k = 2 should only consider first 2 tokens
        sampled = sampling.sample(logits, top_k=2)
        assert sampled[0] in [0, 1]  # Should be one of top 2 tokens

class TestIntegration:
    def test_full_inference_pipeline(self):
        """Test integration of all kernels in a simple pipeline"""
        # This test would simulate a full inference step
        batch_size = 2
        seq_len = 10
        hidden_size = 256
        num_heads = 4
        head_dim = hidden_size // num_heads
        
        # Create input
        x = torch.randn(batch_size, seq_len, hidden_size)
        
        # Apply attention
        attention = RadixAttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=head_dim
        )
        attn_output = attention(x)
        assert attn_output.shape == (batch_size, seq_len, hidden_size)
        
        # Apply sampling (on logits that would come from LM head)
        logits = torch.randn(batch_size, 1000)  # vocab_size = 1000
        sampling = SamplingKernel()
        sampled_tokens = sampling.sample(logits, temperature=0.7)
        assert sampled_tokens.shape == (batch_size,)

if __name__ == "__main__":
    pytest.main([__file__])
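The sampling tests above depend on how temperature, top-k and top-p reshape the distribution. A pure-Python sketch of one common filtering rule (temperature scaling, then keep the `top_k` most likely tokens, then keep the smallest prefix whose cumulative probability reaches `top_p`), which reproduces the worked numbers in the top-p test. `filtered_probs` and `sample` are illustrative names; Mini-YAIE's `SamplingKernel` may apply the rules differently, for example with an exclusive top-p boundary.

```python
import math
import random
from typing import List, Optional


def filtered_probs(logits: List[float], temperature: float = 1.0,
                   top_k: Optional[int] = None, top_p: Optional[float] = None) -> List[float]:
    """Probabilities after temperature, top-k and top-p filtering (renormalized)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k is not None:
        keep = set(order[:top_k])
    if top_p is not None:
        nucleus, cumulative = set(), 0.0
        for i in order:  # smallest prefix whose mass reaches top_p
            nucleus.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        keep &= nucleus

    kept_total = sum(probs[i] for i in keep)
    return [probs[i] / kept_total if i in keep else 0.0 for i in range(len(probs))]


def sample(logits: List[float], rng: random.Random = random.Random(), **kwargs) -> int:
    """Draw one token id from the filtered distribution."""
    probs = filtered_probs(logits, **kwargs)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```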

Building and Running

Setup Configuration

File: setup_kernels.py

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME

def get_extensions():
    """Return the CUDA extension if the CUDA toolkit is installed.

    Compiling needs nvcc (located via CUDA_HOME), not a visible GPU, so this
    checks CUDA_HOME instead of torch.cuda.is_available().
    """
    if CUDA_HOME is None:
        print("CUDA toolkit not found, building without CUDA extensions")
        return []

    return [
        CUDAExtension(
            name='yaie_kernels',
            sources=[
                'src/kernels/cuda/memory_ops.cu',
                'src/kernels/cuda/paged_attention.cu',
                'src/kernels/cuda/flash_attention.cu',
                'src/kernels/cuda/radix_ops.cu',
                'src/kernels/cuda/pybind.cpp',  # Python bindings
            ],
            extra_compile_args={
                'cxx': ['-O3'],
                # sm_70 targets Volta; set this to your GPU's compute capability
                'nvcc': ['-O3', '--use_fast_math', '-arch=sm_70'],
            },
        )
    ]

setup(
    name='yaie_kernels',
    ext_modules=get_extensions(),
    cmdclass={'build_ext': BuildExtension},
    zip_safe=False,
)

Performance Optimization Guidelines

CUDA Optimization Tips

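  1. Memory Coalescing: Have consecutive threads in a warp access consecutive addresses so their loads merge into a few wide memory transactions
  2. Shared Memory: Stage data that many threads reuse (such as Q/K/V tiles) in shared memory instead of re-reading global memory
  3. Occupancy: Keep enough warps resident per SM to hide memory latency; registers and shared memory used per block limit how many fit
  4. Reduction Operations: Use tree-based parallel reductions (for example, warp shuffles) instead of serial loops for sums and maxima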

Profiling and Benchmarking

Create benchmarking tools:

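import time
import torch
from torch.profiler import profile, record_function, ProfilerActivity

def benchmark_kernel(kernel_func, *args, warmup: int = 3, iters: int = 10, **kwargs):
    """Return (average seconds per call, last result) for a GPU kernel function"""
    # Warmup: the first calls pay one-time costs (CUDA context, allocator, autotuning)
    for _ in range(warmup):
        result = kernel_func(*args, **kwargs)
    
    # CUDA launches are asynchronous, so synchronize before and after the timed region
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    
    for _ in range(iters):  # Run multiple times for a stable average
        result = kernel_func(*args, **kwargs)
    
    torch.cuda.synchronize()
    avg_time = (time.perf_counter() - start_time) / iters
    return avg_time, result

def profile_kernel(kernel_func, *args, **kwargs):
    """Profile a kernel function and print the top operators by CUDA time"""
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with record_function("kernel_under_test"):  # labels this region in the trace
            result = kernel_func(*args, **kwargs)
        torch.cuda.synchronize()  # make sure the GPU work is captured
    
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    return result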

This comprehensive implementation guide provides everything needed to implement the core kernels for Mini-YAIE, following SGLang-style optimization principles while maintaining educational value.

API & Serving: OpenAI-Compatible Server

Overview

The API server in Mini-YAIE implements an OpenAI-compatible interface, allowing the engine to be used with existing applications and tools designed for OpenAI’s API. The server uses FastAPI to provide RESTful endpoints with proper request/response handling, streaming support, and health monitoring.

API Design Philosophy

The server follows OpenAI’s API specification to ensure compatibility with existing tools and applications while leveraging the advanced features of the SGLang-style inference engine. This approach allows users to:

  • Use existing OpenAI clients without modification
  • Take advantage of Mini-YAIE’s performance optimizations
  • Integrate with tools built for OpenAI’s API format
  • Maintain familiar request/response patterns

Core Architecture

FastAPI Application

The server is built using FastAPI for high-performance web serving:

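from fastapi import FastAPI

def create_app(model_name: str) -> FastAPI:
    app = FastAPI(title="YAIE API", version="0.1.0")
    engine = InferenceEngine(model_name)  # Mini-YAIE's inference engine

    # The route handlers shown in the following sections are defined here,
    # inside create_app, so they can use `app`, `engine`, and `model_name`.

    return app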

Main Components

  1. Inference Engine Integration: Connects API endpoints to the core inference engine
  2. Request Validation: Pydantic models for request/response validation
  3. Streaming Support: Server-sent events for real-time token streaming
  4. Error Handling: Proper HTTP error codes and message formatting
  5. Health Monitoring: Endpoints for system status and availability

API Endpoints

1. Chat Completions Endpoint

The main endpoint follows OpenAI’s /v1/chat/completions specification:

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    # Handles both streaming and non-streaming responses (see Streaming Implementation)
    ...

Request Schema

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False

Response Schema

class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Dict[str, int]

Supported Parameters

  • model: Model identifier (passed during server startup)
  • messages: List of message objects with role and content
  • temperature: Sampling temperature (0.0-2.0 recommended)
  • top_p: Nucleus sampling threshold (0.0-1.0)
  • max_tokens: Maximum tokens to generate
  • stream: Whether to stream responses (true/false)

2. Model Listing Endpoint

Lists the single model the server was started with:

@app.get("/v1/models")
async def list_models():
    return {
        "object": "list",
        "data": [
            {
                "id": model_name,
                "object": "model",
                "owned_by": "user",
                "created": int(time.time()),
            }
        ],
    }

3. Health Check Endpoint

Simple health monitoring:

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": model_name}

Streaming Implementation

Streaming vs Non-Streaming

The server supports both response formats using the same endpoint:

if request.stream:
    # Return streaming response
    return StreamingResponse(generate_stream(), media_type="text/event-stream")
else:
    # Return non-streaming response
    response = engine.chat_completion(messages_dicts, **kwargs)
    return response

Streaming Response Format

The streaming implementation generates Server-Sent Events (SSE):

def generate_stream():
    for chunk in engine.chat_completion_stream(messages_dicts, **kwargs):
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"

Each chunk follows OpenAI’s streaming format:

{
  "id": "chatcmpl-...",
  "object": "chat.completion.chunk",
  "created": 1234567890,
  "model": "model-name",
  "choices": [{
    "index": 0,
    "delta": {"content": "token"},
    "finish_reason": null
  }]
}
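The chunk format above can be produced by a small generator. A stdlib-only sketch that wraps each token in an OpenAI-style `chat.completion.chunk`, sends a final chunk with an empty delta and `finish_reason` set to `"stop"`, then the `[DONE]` sentinel. The token source is a plain iterable standing in for `engine.chat_completion_stream`, and `sse_chat_stream` is a hypothetical name; OpenAI's API also sends the assistant role in the first delta, which this sketch omits.

```python
import json
import time
import uuid
from typing import Iterable, Iterator


def sse_chat_stream(tokens: Iterable[str], model: str) -> Iterator[str]:
    """Yield Server-Sent Events lines for a streamed chat completion."""
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    def chunk(delta: dict, finish_reason) -> str:
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {json.dumps(payload)}\n\n"  # SSE: "data: ..." plus a blank line

    for token in tokens:
        yield chunk({"content": token}, None)
    yield chunk({}, "stop")   # tells the client generation ended normally
    yield "data: [DONE]\n\n"  # OpenAI-style end-of-stream sentinel
```

Every chunk of one response shares the same `id` and `created` values, which is how clients stitch the deltas back together.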

Integration with Inference Engine

Request Processing Flow

  1. API Request: Received through FastAPI endpoints
  2. Validation: Pydantic models validate request format
  3. Parameter Extraction: Convert API parameters to engine format
  4. Engine Processing: Call appropriate engine methods
  5. Response Formatting: Convert engine output to API format
  6. API Response: Return properly formatted responses

Parameter Mapping

API parameters are mapped to engine capabilities:

kwargs = {}
if request.max_tokens is not None:
    kwargs["max_tokens"] = request.max_tokens
if request.temperature is not None:
    kwargs["temperature"] = request.temperature
if request.top_p is not None:
    kwargs["top_p"] = request.top_p
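The mapping above forwards values unchecked. A stdlib-only sketch that also enforces the ranges listed under Supported Parameters before anything reaches the engine, raising `ValueError` that the server could turn into a 400 Bad Request. `build_engine_kwargs` is a hypothetical helper, and treating the recommended ranges as hard limits is an assumption.

```python
from typing import Any, Dict, Optional


def build_engine_kwargs(max_tokens: Optional[int] = None,
                        temperature: Optional[float] = None,
                        top_p: Optional[float] = None) -> Dict[str, Any]:
    """Validate API sampling parameters and drop the ones left unset."""
    if temperature is not None and not 0.0 <= temperature <= 2.0:
        raise ValueError("temperature must be between 0 and 2")
    if top_p is not None and not 0.0 <= top_p <= 1.0:
        raise ValueError("top_p must be between 0 and 1")
    if max_tokens is not None and max_tokens < 1:
        raise ValueError("max_tokens must be a positive integer")
    candidates = {"max_tokens": max_tokens, "temperature": temperature, "top_p": top_p}
    # Only pass explicit values, so the engine's own defaults apply otherwise
    return {name: value for name, value in candidates.items() if value is not None}
```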

Message Formatting

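The server converts OpenAI-style messages to plain dicts, and the engine turns them into a prompt string:

# In the API handler: Pydantic message objects -> plain dicts
messages_dicts = [
    {"role": msg.role, "content": msg.content}
    for msg in request.messages
]

# In the engine, `messages` is the list passed in by the server (messages_dicts).
# Check for an actual chat template: recent transformers versions raise an error
# from apply_chat_template when tokenizer.chat_template is None (e.g. GPT-2).
if getattr(self.tokenizer, "chat_template", None):
    formatted_prompt = self.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
else:
    # Fallback formatting, e.g. "User: Hello\n\nAssistant:"
    formatted_prompt = ""
    for message in messages:
        formatted_prompt += f"{message['role'].capitalize()}: {message['content']}\n"
    formatted_prompt += "\nAssistant:"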

Error Handling

HTTP Error Codes

The server properly handles various error conditions:

  • 400 Bad Request: Invalid request parameters
  • 429 Too Many Requests: Rate limiting (not implemented in basic version)
  • 500 Internal Server Error: Server-side errors during processing

Error Response Format

Standard OpenAI-compatible error format:

{
  "error": {
    "message": "Error description",
    "type": "server_error",
    "param": null,
    "code": null
  }
}

Exception Handling

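The server wraps request processing in a try/except block. Note that FastAPI renders an HTTPException as {"detail": "..."}, so returning the OpenAI-style error object above requires a custom exception handler:

import traceback
from fastapi import HTTPException

try:
    # Process request
    response = engine.chat_completion(messages_dicts, **kwargs)
    return response
except Exception as e:
    traceback.print_exc()  # log the full stack trace on the server
    raise HTTPException(status_code=500, detail=str(e))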
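To emit the OpenAI error object instead of FastAPI's default `{"detail": ...}` body, a handler registered with `@app.exception_handler(Exception)` can return `JSONResponse(status_code=status, content=body)` built from a helper like the stdlib-only sketch below. The exception-to-status mapping (`ValueError` as a 400 `invalid_request_error`, everything else as a 500 `server_error`) is an assumption, as is hiding internal exception text from clients on 500s.

```python
from typing import Any, Dict, Optional, Tuple


def openai_error_body(message: str, err_type: str = "server_error",
                      param: Optional[str] = None, code: Optional[str] = None) -> Dict[str, Any]:
    """Build the OpenAI-compatible error object."""
    return {"error": {"message": message, "type": err_type, "param": param, "code": code}}


def error_response(exc: Exception) -> Tuple[int, Dict[str, Any]]:
    """Map an exception to (HTTP status, error body)."""
    if isinstance(exc, ValueError):  # e.g. out-of-range sampling parameters
        return 400, openai_error_body(str(exc), "invalid_request_error")
    # Do not leak internal details; the stack trace is logged server-side
    return 500, openai_error_body("Internal server error", "server_error")
```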

Performance Considerations

Request Batching

The API integrates with the engine’s batching system:

  • Multiple API requests can be batched together in the engine
  • Continuous batching maintains high throughput
  • Batch size limited by engine configuration

Memory Management

The server runs the inference engine in the same process, so all API requests draw on one shared memory pool:

  • KV-cache is shared across API requests
  • Efficient memory reuse through paged cache system
  • Memory limits enforced by engine configuration

Concurrency

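FastAPI (served by uvicorn) handles many connections concurrently on an asyncio event loop:

  • async def handlers run on the event loop itself, so a blocking call inside one, such as a synchronous engine.chat_completion(...), stalls every other request until it returns
  • Handlers declared with plain def are run in a thread pool by FastAPI, which keeps the event loop responsive
  • Streaming responses built from a regular generator are also iterated in a thread pool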
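One way to keep an `async def` endpoint responsive is to move the blocking engine call into a worker thread with `asyncio.to_thread` (Python 3.9+). A self-contained sketch with a hypothetical `BlockingEngine` standing in for the real engine; it assumes the engine's `chat_completion` is safe to call from several threads, which a real GPU engine would typically ensure by queuing work into its own scheduler.

```python
import asyncio
import time


class BlockingEngine:
    """Hypothetical stand-in for a synchronous engine.chat_completion call."""

    def chat_completion(self, messages, **kwargs):
        time.sleep(0.3)  # simulate work that blocks the calling thread
        return {"choices": [{"message": {"role": "assistant", "content": "ok"}}]}


async def chat_completions(engine, messages, **kwargs):
    # Run the blocking call in a worker thread; the event loop stays free
    # to accept and serve other requests meanwhile.
    return await asyncio.to_thread(engine.chat_completion, messages, **kwargs)


async def serve_four_requests():
    engine = BlockingEngine()
    start = time.perf_counter()
    results = await asyncio.gather(
        *(chat_completions(engine, [{"role": "user", "content": "hi"}]) for _ in range(4))
    )
    return results, time.perf_counter() - start


if __name__ == "__main__":
    results, elapsed = asyncio.run(serve_four_requests())
    print(len(results), f"{elapsed:.2f}s")
```

The four simulated requests overlap instead of running back to back; calling `engine.chat_completion` directly inside the `async def` would serialize them.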

Security Considerations

Input Validation

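  • Pydantic models check the type of every request field
  • Malformed payloads (for example, a string where a number is expected) are rejected before reaching the engine (FastAPI responds with 422 by default)
  • The schema above does not yet limit message count, message length, or max_tokens; adding such limits would bound per-request resource use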

Rate Limiting

Rate limiting is not implemented in the basic version, but it can be added:

  • Per-IP rate limiting
  • Request quota management
  • Usage monitoring
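Per-IP rate limiting is often implemented as a token bucket: each client has a bucket refilled at `rate` tokens per second up to `capacity`, and each request spends one token. A minimal in-memory sketch (single process, no eviction of idle clients); `TokenBucketLimiter` is a hypothetical class, and a request for which `allow` returns `False` would be answered with 429 Too Many Requests. The injectable `clock` makes the behavior testable.

```python
import time
from typing import Callable, Dict, Tuple


class TokenBucketLimiter:
    """Per-client token bucket: refill `rate` tokens/s, bursts up to `capacity`."""

    def __init__(self, rate: float, capacity: int,
                 clock: Callable[[], float] = time.monotonic):
        self.rate = rate
        self.capacity = capacity
        self.clock = clock
        self._buckets: Dict[str, Tuple[float, float]] = {}  # client -> (tokens, last time)

    def allow(self, client_ip: str) -> bool:
        now = self.clock()
        tokens, last = self._buckets.get(client_ip, (float(self.capacity), now))
        # Refill for the time elapsed since this client's last request
        tokens = min(float(self.capacity), tokens + (now - last) * self.rate)
        if tokens >= 1.0:
            self._buckets[client_ip] = (tokens - 1.0, now)
            return True
        self._buckets[client_ip] = (tokens, now)
        return False
```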

Deployment Configuration

Server Startup

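The server is normally started through the CLI (see CLI Usage), which loads the model and builds the app with create_app:

yaie serve microsoft/DialoGPT-medium --host 0.0.0.0 --port 8000

Running uvicorn directly (uvicorn server.api:app --host 0.0.0.0 --port 8000) only works if server/api.py also exposes a module-level app built with create_app(...), because create_app needs the model name as an argument.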

Environment Configuration

The server supports environment-based configuration:

  • Model name via environment variables
  • Port and host configuration
  • Resource limits

SGLang-Style Features Integration

Continuous Batching

The API benefits from the engine’s continuous batching:

  • Requests are automatically batched
  • High throughput maintained
  • Short requests can finish and leave the batch without waiting for longer ones, which reduces queueing delay

Prefix Sharing

API requests with similar prefixes benefit from:

  • Shared computation through RadixAttention (prefix reuse in the KV-cache)
  • Reduced memory usage
  • Improved efficiency

Multi-Step Processing

The API leverages the engine’s multi-step capabilities:

  • Efficient prefill and decode phases
  • Optimized request scheduling
  • Memory-aware processing

Usage Examples

Basic Chat Request

curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "model-name",
    "messages": [
      {"role": "user", "content": "Hello, how are you?"}
    ],
    "temperature": 0.7
  }'

Streaming Request

curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "model-name",
    "messages": [
      {"role": "user", "content": "Write a short story"}
    ],
    "stream": true
  }'

Programmatic Usage

import openai

# Configure client to use local server
client = openai.OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="dummy"  # Required but ignored
)

# Use standard OpenAI SDK
response = client.chat.completions.create(
    model="model-name",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.7
)
print(response.choices[0].message.content)

Monitoring and Logging

Health Endpoints

The health endpoint can be used for monitoring:

curl http://localhost:8000/health

Performance Metrics

The server can be extended with:

  • Response time tracking
  • Request volume monitoring
  • Error rate monitoring
  • Resource utilization metrics

Advanced Features

Model Loading

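The server handles model loading automatically:

  • The model is loaded once at startup, when create_app constructs the InferenceEngine, and reused for every request
  • Weights are downloaded from the HuggingFace Hub on first use and read from the local cache afterwards
  • Models are loaded through the HuggingFace integration, by Hub ID or local path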

Response Caching

The system supports response caching for:

  • Repeated identical requests
  • Common prompt patterns
  • Improved response times for cached content

Logging and Debugging

Comprehensive logging can be added:

  • Request/response logging
  • Performance metrics
  • Error tracing
  • Usage analytics

This OpenAI-compatible API server enables Mini-YAIE to be integrated into existing ecosystems while providing the performance benefits of SGLang-style inference optimization.

CLI Usage: Interactive and Server Modes

Overview

Mini-YAIE provides a comprehensive command-line interface (CLI) that serves as the primary entry point for users. The CLI supports both interactive chat mode and server mode, making it suitable for both direct interaction and production deployment scenarios.

CLI Architecture

Entry Point Structure

The CLI is organized around different command verbs:

yaie <command> [options] [arguments]

Commands:
- serve: Start an OpenAI-compatible API server
- chat: Start an interactive chat session

Core Components

  1. Argument Parsing: Uses argparse for command-line option handling
  2. Model Integration: Connects CLI commands to the inference engine
  3. Interactive Interface: Provides user-friendly chat experience
  4. Server Integration: Launches API server with proper configuration
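The verb-based structure above maps naturally onto argparse subparsers. A minimal sketch of such a parser, following the positional-model usage shown in the examples and a subset of the options listed below; it is illustrative, not Mini-YAIE's actual CLI module, and the serve batching/cache options are left without defaults because none are documented.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical `yaie` parser with `serve` and `chat` subcommands."""
    parser = argparse.ArgumentParser(prog="yaie")
    sub = parser.add_subparsers(dest="command", required=True)

    serve = sub.add_parser("serve", help="Start an OpenAI-compatible API server")
    serve.add_argument("model", help="HuggingFace model ID or local path")
    serve.add_argument("--host", default="localhost")
    serve.add_argument("--port", type=int, default=8000)
    serve.add_argument("--max-batch-size", type=int)
    serve.add_argument("--num-blocks", type=int)
    serve.add_argument("--block-size", type=int)

    chat = sub.add_parser("chat", help="Start an interactive chat session")
    chat.add_argument("model", help="HuggingFace model ID or local path")
    chat.add_argument("--temperature", type=float, default=1.0)
    chat.add_argument("--top-p", type=float, default=1.0)
    chat.add_argument("--max-tokens", type=int, default=512)
    return parser
```

argparse turns `--top-p` into `args.top_p` and `--max-batch-size` into `args.max_batch_size`, so a dispatcher only needs to check `args.command` and hand the namespace to the server or chat loop.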

Server Mode

Basic Server Usage

Start the API server with a specific model:

yaie serve microsoft/DialoGPT-medium --host localhost --port 8000

Server Options

Model Selection

MODEL_NAME                  # Model to use, given as a positional argument (required)

Network Configuration

--host HOST                 # Server host (default: localhost)
--port PORT                 # Server port (default: 8000)
--workers WORKERS           # Number of server workers

Performance Options

--max-batch-size N          # Maximum batch size
--max-prefill-batch-size N  # Maximum prefill batch size
--max-decode-batch-size N   # Maximum decode batch size
--num-blocks N              # Number of KV-cache blocks
--block-size N              # Size of each cache block

Server Startup Process

  1. Model Loading: Download and load model from HuggingFace if not cached
  2. Engine Initialization: Create inference engine with specified parameters
  3. API Server Creation: Initialize FastAPI application with engine
  4. Server Launch: Start the web server on specified host/port

Example Server Commands

Basic Server

yaie serve microsoft/DialoGPT-medium

Production Server

yaie serve microsoft/DialoGPT-medium --host 0.0.0.0 --port 8000 --max-batch-size 16

Resource-Constrained Server

yaie serve microsoft/DialoGPT-medium --num-blocks 1000 --max-batch-size 4

Chat Mode

Basic Chat Usage

Start an interactive chat session:

yaie chat microsoft/DialoGPT-medium

Chat Options

Generation Parameters

--temperature TEMP          # Sampling temperature (default: 1.0)
--top-p TOP_P               # Nucleus sampling threshold (default: 1.0)
--max-tokens N              # Maximum tokens to generate (default: 512)
--stream                    # Stream responses in real-time (default: true)

Model Configuration

MODEL_NAME                  # Model to use, given as a positional argument (required)

Interactive Chat Experience

Session Flow

  1. Model Loading: Model is loaded if not cached
  2. Chat Initialization: Engine and tokenizer are set up
  3. Conversation Loop: User inputs are processed and responses generated
  4. Session Termination: Exit with Ctrl+C or quit command

User Interaction

The chat interface provides a conversational experience:

$ yaie chat microsoft/DialoGPT-medium
Model loaded successfully!
Starting chat session (press Ctrl+C to exit)...

User: Hello, how are you?
AI: I'm doing well, thank you for asking!

User: What can you help me with?
AI: I can have conversations, answer questions, and assist with various tasks.

Example Chat Commands

Basic Chat

yaie chat microsoft/DialoGPT-medium

Creative Chat

yaie chat microsoft/DialoGPT-medium --temperature 1.2 --top-p 0.9

Focused Chat

yaie chat microsoft/DialoGPT-medium --temperature 0.7 --max-tokens 128

Model Selection

Supported Model Formats

The CLI supports HuggingFace-compatible causal language models, from the Hub or from a local directory:

Pre-trained Models

yaie serve microsoft/DialoGPT-medium
yaie serve gpt2
yaie serve facebook/opt-1.3b

Local Models

yaie serve /path/to/local/model
yaie serve ./models/my-custom-model

Model Caching

Models are automatically downloaded and cached:

  • First run: Download from HuggingFace Hub
  • Subsequent runs: Use cached version
  • Cache location: Standard HuggingFace cache directory

Performance Tuning

Memory Configuration

Adjust memory settings based on available GPU memory:

# For 24GB+ GPU
yaie serve model --num-blocks 4000 --max-batch-size 32

# For 8-16GB GPU  
yaie serve model --num-blocks 1500 --max-batch-size 8

# For 4-8GB GPU
yaie serve model --num-blocks 800 --max-batch-size 4
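How many blocks fit depends on the model shape. Assuming each block stores both K and V for `block_size` tokens across every layer, one block costs 2 x num_layers x block_size x num_kv_heads x head_dim x bytes_per_element. A small calculator under that assumption; Mini-YAIE's real block layout (for example, separate per-layer pools) may change the constant. The example uses the GPT-2 medium shape behind DialoGPT-medium (24 layers, 16 heads, head_dim 64) in fp16.

```python
def kv_bytes_per_block(num_layers: int, num_kv_heads: int, head_dim: int,
                       block_size: int, dtype_bytes: int = 2) -> int:
    """Bytes for one KV-cache block (K and V, all layers, `block_size` tokens)."""
    return 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_bytes


def max_blocks(memory_budget_bytes: int, bytes_per_block: int) -> int:
    """How many whole blocks fit in the memory set aside for the KV-cache."""
    return memory_budget_bytes // bytes_per_block


per_block = kv_bytes_per_block(num_layers=24, num_kv_heads=16, head_dim=64, block_size=16)
print(per_block)                           # bytes per block for DialoGPT-medium
print(max_blocks(6 * 1024**3, per_block))  # blocks that fit in a 6 GiB budget
```

Each block works out to 1.5 MiB here, so `--num-blocks 4000` reserves roughly 5.9 GiB of KV-cache for this model, on top of the weights and activations.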

Batch Size Optimization

Tune batch sizes for optimal throughput:

# High throughput (more memory)
yaie serve model --max-batch-size 32 --max-prefill-batch-size 64

# Memory efficient (lower batch sizes)
yaie serve model --max-batch-size 4 --max-prefill-batch-size 8

Error Handling and Troubleshooting

Common Errors

Model Loading Errors

# If model name is invalid
Error: Model not found on HuggingFace Hub

# If network is unavailable during first load
Error: Failed to download model

Memory Errors

# If not enough GPU memory
CUDA out of memory error

# If KV-cache is too large
Memory allocation failed

Debugging Options

Verbose Output

yaie serve model --verbose  # Show detailed startup information

Configuration Validation

yaie serve model --debug    # Enable debugging features

Advanced CLI Features

Configuration Files

The CLI supports configuration files for complex setups:

yaie serve --config config.yaml model

Environment Variables

Several environment variables can customize behavior:

# Set default host
export YAIE_HOST=0.0.0.0

# Set default port  
export YAIE_PORT=9000

# Set memory limits
export YAIE_MAX_BLOCKS=2000
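The source does not state how these variables interact with command-line flags. A common pattern is: an explicit flag wins, then the environment variable, then a built-in default. The sketch below follows that assumed precedence. resolve_setting is a hypothetical helper, and the default values are placeholders, not Mini-YAIE's actual defaults.

```python
import os
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def resolve_setting(cli_value: Optional[T], env_name: str, default: T,
                    cast: Callable[[str], T]) -> T:
    """Assumed precedence: explicit CLI flag, then environment variable, then default."""
    if cli_value is not None:
        return cli_value
    raw = os.environ.get(env_name)
    if raw:  # ignore unset or empty variables
        return cast(raw)
    return default


# Placeholder defaults for illustration only
host = resolve_setting(None, "YAIE_HOST", "127.0.0.1", str)
port = resolve_setting(None, "YAIE_PORT", 8000, int)
max_blocks = resolve_setting(None, "YAIE_MAX_BLOCKS", 1000, int)
```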

Logging Configuration

Control logging verbosity and output:

# Enable detailed logging
yaie serve model --log-level DEBUG

# Log to file
yaie serve model --log-file server.log

Integration with SGLang Features

Batching Optimization

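The CLI exposes SGLang-style batching parameters, which cap how many requests are prefilled together and how many sequences decode together in one step: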

yaie serve model \
  --max-prefill-batch-size 16 \
  --max-decode-batch-size 256
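The two caps act on different phases. Prefill processes a new request's whole prompt, while decode advances every running sequence by one token. The sketch below shows one scheduler iteration under these assumptions. BatchPlanner and Request are hypothetical names, and removing finished sequences is left out for brevity.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, List, Tuple


@dataclass
class Request:
    request_id: str


@dataclass
class BatchPlanner:
    max_prefill_batch_size: int
    max_decode_batch_size: int
    waiting: Deque[Request] = field(default_factory=deque)
    running: List[Request] = field(default_factory=list)

    def step(self) -> Tuple[List[Request], List[Request]]:
        """Choose which requests to prefill and which to decode in this iteration."""
        # Sequences already running decode one token each, up to the decode cap.
        decode_batch = self.running[: self.max_decode_batch_size]
        # Admit new requests only while the prefill cap and decode capacity allow it.
        free_slots = self.max_decode_batch_size - len(self.running)
        n_admit = max(0, min(self.max_prefill_batch_size, free_slots, len(self.waiting)))
        prefill_batch = [self.waiting.popleft() for _ in range(n_admit)]
        # After prefill, admitted requests join the running set and decode next step.
        self.running.extend(prefill_batch)
        return prefill_batch, decode_batch
```

A small --max-prefill-batch-size keeps long prompts from stalling the decode steps of requests that are already streaming. A large --max-decode-batch-size lets more sequences share each decode step.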

Prefix Sharing Control

Parameters that affect prefix sharing efficiency:

yaie serve model \
  --max-seq-len 2048 \
  --block-size 16
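Block size matters for sharing if a cached prefix can only be reused in whole blocks, as in paged prefix caching. A radix cache may match at token granularity, so treat this as an assumption about the sharing unit. The hypothetical helpers below count how many full blocks two prompts could share.

```python
from typing import Sequence


def common_prefix_len(a: Sequence[int], b: Sequence[int]) -> int:
    """Number of leading token IDs the two sequences have in common."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def shareable_blocks(a: Sequence[int], b: Sequence[int], block_size: int) -> int:
    """Full KV-cache blocks both sequences could point at, assuming whole-block sharing."""
    return common_prefix_len(a, b) // block_size


system_prompt = list(range(40))           # 40 shared token IDs
a = system_prompt + [100, 101]
b = system_prompt + [200]
print(shareable_blocks(a, b, 16), shareable_blocks(a, b, 8))
```

With 40 shared tokens, a block size of 16 shares 2 blocks (32 tokens) and recomputes the remaining 8. A block size of 8 shares all 40, at the cost of longer block tables.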

Production Deployment

Server Management

Process Control

# Start server in background
nohup yaie serve model > server.log 2>&1 &

# Kill server process
pkill -f "yaie serve"

Process Monitoring

# Run as a systemd service (requires a yaie-server unit file)
systemctl start yaie-server

# Run under supervisor (requires a yaie-server program entry)
supervisorctl start yaie-server

Health Checks

The server provides health status:

# Check server status
curl http://localhost:8000/health

# Point your monitoring tool at this endpoint and set a check
# interval and a consecutive-failure threshold
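A monitoring tool usually probes the endpoint on a fixed interval and reports the server as down only after several consecutive failures. This avoids alerting on a single slow response. The sketch below assumes /health answers HTTP 200 when healthy. probe and HealthTracker are hypothetical names.

```python
import urllib.error
import urllib.request
from dataclasses import dataclass


def probe(url: str, timeout: float = 2.0) -> bool:
    """Return True if the endpoint answers HTTP 200 within the timeout."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):  # refused, timed out, or non-2xx status
        return False


@dataclass
class HealthTracker:
    """Report 'down' only after failure_threshold consecutive failed probes."""
    failure_threshold: int = 3
    consecutive_failures: int = 0

    def record(self, ok: bool) -> str:
        self.consecutive_failures = 0 if ok else self.consecutive_failures + 1
        return "down" if self.consecutive_failures >= self.failure_threshold else "up"


# Usage, polling every 10 seconds:
# tracker = HealthTracker(failure_threshold=3)
# while True:
#     print(tracker.record(probe("http://localhost:8000/health")))
#     time.sleep(10)
```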

Examples and Use Cases

Development Usage

# Quick test with small model
yaie chat gpt2

# Interactive development with verbose output
yaie serve gpt2 --port 8000 --verbose

Production Usage

# High-performance server for production
yaie serve microsoft/DialoGPT-medium \
  --host 0.0.0.0 \
  --port 8000 \
  --max-batch-size 16 \
  --num-blocks 2000

# Low-resource server for edge deployment
yaie serve gpt2 \
  --max-batch-size 2 \
  --num-blocks 500

Testing and Evaluation

# Test with various parameters
yaie chat model --temperature 0.5 --top-p 0.9

# Evaluate different models
yaie serve model1 --port 8001 &
yaie serve model2 --port 8002 &

The CLI gives access to Mini-YAIE's features, from interactive chat to high-throughput API serving with SGLang-style optimizations.

Production Deployment

While Mini-YAIE is primarily educational, understanding production considerations helps bridge the gap between learning and real-world deployment.

Deployment Architecture

graph TD
    LoadBalancer -->|HTTP| API1
    LoadBalancer -->|HTTP| API2
    LoadBalancer -->|HTTP| API3
    API1 -->|gRPC| Inference1
    API2 -->|gRPC| Inference2
    API3 -->|gRPC| Inference3
    Inference1 -->|GPU| Model1
    Inference2 -->|GPU| Model2
    Inference3 -->|GPU| Model3

1. Performance Optimization

Batching & Latency Management

  • Request Timeouts: Implement configurable timeouts for requests waiting in queue
  • Priority Queues: Support different priority levels for requests
  • Preemption: Pause low-priority requests and swap their KV-cache to CPU when high-priority requests arrive
  • Adaptive Batching: Dynamically adjust batch sizes based on current load and latency requirements

Continuous Batching Strategies

# Advanced scheduling policies
from enum import Enum

class SchedulingPolicy(Enum):
    FCFS = "fcfs"  # First-come-first-served
    PRIORITY = "priority"  # Priority-based
    FAIR = "fair"  # Fair sharing
    LATENCY_OPTIMIZED = "latency"  # Minimize latency
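FCFS and PRIORITY differ only in how the waiting queue is ordered. The sketch below implements both with a heap. It breaks priority ties by arrival order, so equal-priority requests stay first-come-first-served. RequestQueue is a hypothetical name, and lower priority values are assumed to run first. Fair sharing and latency-optimized policies need extra per-client or per-request state and are omitted.

```python
import heapq
import itertools
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(order=True)
class _Entry:
    sort_key: tuple
    request_id: str = field(compare=False)


class RequestQueue:
    """Waiting queue ordered by policy: 'fcfs' (arrival order) or 'priority' (lower first)."""

    def __init__(self, policy: str = "fcfs") -> None:
        if policy not in ("fcfs", "priority"):
            raise ValueError(f"unsupported policy: {policy}")
        self.policy = policy
        self._heap: List[_Entry] = []
        self._arrivals = itertools.count()  # monotonically increasing arrival index

    def push(self, request_id: str, priority: int = 0) -> None:
        arrival = next(self._arrivals)
        # FCFS ignores priority; PRIORITY uses arrival only to break ties.
        key = (arrival,) if self.policy == "fcfs" else (priority, arrival)
        heapq.heappush(self._heap, _Entry(key, request_id))

    def pop(self) -> Optional[str]:
        return heapq.heappop(self._heap).request_id if self._heap else None

    def __len__(self) -> int:
        return len(self._heap)
```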

2. Distributed Inference

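Tensor Parallelism

Tensor parallelism splits each weight matrix across GPUs. Every GPU holds a slice of every layer and computes its share of each matrix multiplication. The partial results are then combined with collective operations such as all-gather or all-reduce. HuggingFace's device_map does not do this. It places whole modules on different GPUs, which is the layer-wise split used by pipeline parallelism below.

Pipeline Parallelism

# Split layers across different stages
# Stage 1: Layers 0-10 on GPU 0
# Stage 2: Layers 11-20 on GPU 1
# Stage 3: Layers 21-30 on GPU 2
from transformers import AutoModelForCausalLM

# device_map="auto" (requires the accelerate package) spreads whole layers
# over the visible GPUs; only one stage is busy at a time unless
# micro-batches are pipelined through the stages
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
)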
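A stage assignment like Stage 1/2/3 above is just a set of contiguous ranges of layer indices. The hypothetical helper below computes near-equal stages, with earlier stages taking any remainder. For a 31-layer model on three GPUs, it reproduces the 0-10 / 11-20 / 21-30 split.

```python
from typing import List


def partition_layers(num_layers: int, num_stages: int) -> List[range]:
    """Split layer indices into contiguous, near-equal pipeline stages."""
    base, extra = divmod(num_layers, num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        size = base + (1 if s < extra else 0)  # first `extra` stages get one more layer
        stages.append(range(start, start + size))
        start += size
    return stages


for gpu, layers in enumerate(partition_layers(31, 3)):
    print(f"GPU {gpu}: layers {layers.start}-{layers.stop - 1}")
```

Equal layer counts only balance the stages when layers cost the same. In practice, the stages holding the embedding or output head are often given fewer transformer layers.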

Model Parallelism

  • Expert Parallelism: For mixture-of-experts models
  • Sequence Parallelism: For very long sequences
  • Hybrid Parallelism: Combining multiple strategies
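Tensor parallelism is the least obvious of these strategies, so here it is on a single linear layer. In the sketch below, plain Python lists stand in for per-GPU tensors. Column-parallel sharding gives each "GPU" a slice of the output columns, and the slices are concatenated (an all-gather on real hardware). Row-parallel sharding splits the input dimension, and the partial outputs are summed (an all-reduce). Both reproduce the unsharded result. Function names are illustrative only.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain x @ w on nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def column_parallel_matmul(x: Matrix, w: Matrix, parts: int) -> Matrix:
    n = len(w[0]) // parts  # assumes the output width divides evenly
    shards = [[row[i * n:(i + 1) * n] for row in w] for i in range(parts)]
    # Each shard sees the full input and produces its own slice of output columns...
    partials = [matmul(x, shard) for shard in shards]
    # ...and concatenating the slices (all-gather) rebuilds the full output.
    return [sum((p[r] for p in partials), []) for r in range(len(x))]


def row_parallel_matmul(x: Matrix, w: Matrix, parts: int) -> Matrix:
    k = len(w) // parts  # assumes the input width divides evenly
    # Shard i multiplies its slice of input features by its slice of weight rows...
    partials = [matmul([row[i * k:(i + 1) * k] for row in x], w[i * k:(i + 1) * k])
                for i in range(parts)]
    # ...and summing the partial outputs element-wise (all-reduce) gives the result.
    return [[sum(vals) for vals in zip(*(p[r] for p in partials))] for r in range(len(x))]
```

Transformer implementations typically chain the two: a column-parallel projection followed by a row-parallel one, so only one all-reduce is needed per pair.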

3. Memory Optimization

Quantization

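from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load model with 8-bit quantization (requires the bitsandbytes package)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto"
)

# Or use 4-bit quantization
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto"
)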
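The benefit of 8-bit and 4-bit loading is easy to estimate: weight storage scales with bits per parameter. The sketch below ignores the small overhead of quantization scales and zero-points, as well as activations and the KV-cache, so treat its numbers as lower bounds. weight_memory_gib is a hypothetical helper.

```python
def weight_memory_gib(num_params: float, bits_per_param: int) -> float:
    """Approximate weight storage in GiB, ignoring quantization metadata."""
    return num_params * bits_per_param / 8 / 1024**3


# facebook/opt-1.3b has roughly 1.3 billion parameters
for bits in (16, 8, 4):
    print(f"{bits:>2}-bit: {weight_memory_gib(1.3e9, bits):.2f} GiB")
```

For a 1.3B-parameter model this gives roughly 2.4 GiB in fp16, 1.2 GiB in 8-bit and 0.6 GiB in 4-bit. The memory saved on weights can go to more KV-cache blocks.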

Memory Management Strategies

  • Paged Attention: Efficient memory management for KV-cache
  • Memory Pooling: Reuse memory blocks across requests
  • Swapping: Move inactive KV-cache blocks to CPU memory
  • Compression: Compress KV-cache values with minimal quality loss
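Memory pooling and swapping come down to bookkeeping over block IDs. The sketch below keeps a free list of GPU blocks, hands them to sequences, and returns them for reuse. It can also move a preempted sequence's blocks to a CPU pool. BlockPool is a hypothetical class, not the kv_cache_manager API in Mini-YAIE. The actual GPU-to-CPU data copy is only marked by a comment, and swap-in is omitted.

```python
from typing import Dict, List


class BlockPool:
    """Fixed pool of KV-cache block IDs on GPU, plus a CPU pool for swapped-out sequences."""

    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        self.free_gpu: List[int] = list(range(num_gpu_blocks))
        self.free_cpu: List[int] = list(range(num_cpu_blocks))
        self.gpu_tables: Dict[str, List[int]] = {}  # seq_id -> GPU block IDs
        self.cpu_tables: Dict[str, List[int]] = {}  # seq_id -> CPU block IDs

    def allocate(self, seq_id: str, n: int) -> List[int]:
        if n > len(self.free_gpu):
            raise MemoryError(f"need {n} GPU blocks, only {len(self.free_gpu)} free")
        blocks = [self.free_gpu.pop() for _ in range(n)]
        self.gpu_tables.setdefault(seq_id, []).extend(blocks)
        return blocks

    def free(self, seq_id: str) -> None:
        # Returned blocks go back on the free list for later requests (pooling).
        self.free_gpu.extend(self.gpu_tables.pop(seq_id, []))

    def swap_out(self, seq_id: str) -> None:
        gpu_blocks = self.gpu_tables.pop(seq_id)
        if len(gpu_blocks) > len(self.free_cpu):
            self.gpu_tables[seq_id] = gpu_blocks  # undo: nothing was moved
            raise MemoryError("not enough CPU blocks to swap out")
        cpu_blocks = [self.free_cpu.pop() for _ in gpu_blocks]
        # A real engine would copy each GPU block's K/V data into its CPU block here.
        self.cpu_tables[seq_id] = cpu_blocks
        self.free_gpu.extend(gpu_blocks)
```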

4. Scalability & Reliability

Horizontal Scaling

# Multiple inference workers behind load balancer
workers = [
    InferenceWorker(model_name, gpu_id=0),
    InferenceWorker(model_name, gpu_id=1),
    InferenceWorker(model_name, gpu_id=2),
    InferenceWorker(model_name, gpu_id=3)
]

Load Balancing

  • Round Robin: Simple distribution across workers
  • Least Connections: Send to least busy worker
  • Latency-based: Send to worker with lowest current latency
  • Content-based: Route based on request characteristics
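The first two strategies fit in a few lines each. In the sketch below, workers are identified by strings such as their URLs. The least-connections balancer relies on the caller to report when a request finishes. Class names are illustrative.

```python
import itertools
from typing import Dict, List


class RoundRobinBalancer:
    def __init__(self, workers: List[str]) -> None:
        self._cycle = itertools.cycle(workers)

    def pick(self) -> str:
        return next(self._cycle)


class LeastConnectionsBalancer:
    def __init__(self, workers: List[str]) -> None:
        self.active: Dict[str, int] = {w: 0 for w in workers}  # in-flight requests per worker

    def pick(self) -> str:
        # min() returns the first worker with the fewest in-flight requests
        worker = min(self.active, key=self.active.__getitem__)
        self.active[worker] += 1
        return worker

    def done(self, worker: str) -> None:
        self.active[worker] -= 1
```

Round robin ignores how long requests take. Least connections adapts to it, which matters for LLM serving because generation lengths vary widely.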

Health Monitoring

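# Health check endpoints
import torch
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
def health_check():
    # `scheduler` is the engine's scheduler instance; the two methods
    # called on it below are assumed to exist
    return {
        "status": "healthy",
        "gpu_memory": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
        "active_requests": scheduler.get_active_request_count(),
        "queue_length": scheduler.get_queue_status()
    }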

5. Observability

Metrics Collection

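# Prometheus metrics
from prometheus_client import Counter, Histogram

REQUEST_LATENCY = Histogram(
    'request_latency_seconds',
    'Request latency in seconds',
    buckets=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0]
)

# Counters only go up; tokens per second is derived at query time,
# e.g. rate(generated_tokens_total[1m]) in PromQL
TOKENS_GENERATED = Counter(
    'generated_tokens',
    'Total tokens generated'
)

# Per completed request:
# REQUEST_LATENCY.observe(latency_seconds)
# TOKENS_GENERATED.inc(token_count)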

Logging

# Structured logging with structlog (keyword arguments become fields)
import structlog

logger = structlog.get_logger()

logger.info(
    "Request completed",
    request_id=request.id,
    latency=latency_seconds,
    tokens_generated=token_count,
    model=model_name
)

Tracing

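# Distributed tracing with OpenTelemetry
from opentelemetry import trace

tracer = trace.get_tracer(__name__)

# start_as_current_span makes this span the parent of any spans opened inside it
with tracer.start_as_current_span("generate_response") as span:
    span.set_attribute("model", model_name)
    span.set_attribute("request_id", request.id)
    response = engine.generate(prompt)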

6. Security

Authentication

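# API key authentication
from typing import Optional
from fastapi import FastAPI, Header, HTTPException

app = FastAPI()

# ChatCompletionRequest and validate_api_key are defined elsewhere in the server
@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    api_key: Optional[str] = Header(None)  # read from the "api-key" request header
):
    if not validate_api_key(api_key):
        raise HTTPException(status_code=401, detail="Unauthorized")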

Rate Limiting

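# Rate limiting per API key, sketched with the slowapi library
from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

def api_key_or_ip(request: Request) -> str:
    # Limit per API key; fall back to the client IP when no key is sent
    return request.headers.get("api-key") or get_remote_address(request)

limiter = Limiter(key_func=api_key_or_ip)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.post("/v1/chat/completions")
@limiter.limit("1000/minute")
async def chat_completions(request: Request, body: ChatCompletionRequest):
    ...  # Process request (slowapi requires the `request: Request` argument)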
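A per-minute rate combined with a burst allowance is usually implemented as a token bucket. The bucket refills continuously at the average rate and never holds more than the burst capacity. Each request spends one token. The sketch below is a standalone illustration (TokenBucket is a hypothetical name). It takes an injectable clock so it can be tested without waiting.

```python
import time
from typing import Callable


class TokenBucket:
    """Allow rate_per_minute requests on average, with bursts of up to burst_capacity."""

    def __init__(self, rate_per_minute: float, burst_capacity: int,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.refill_per_second = rate_per_minute / 60.0
        self.capacity = burst_capacity
        self.tokens = float(burst_capacity)  # start full so an initial burst is allowed
        self.clock = clock
        self.last = clock()

    def allow(self) -> bool:
        now = self.clock()
        # Refill in proportion to elapsed time, capped at the burst capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.refill_per_second)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False
```

For per-key limits, keep one bucket per API key in a dictionary, for example TokenBucket(1000, 100) for each new key. Reject the request with status 429 when allow() returns False.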

Input Validation

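# Validate and sanitize inputs
from fastapi import HTTPException

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    # Validate prompt length (in characters; counting tokens with the
    # model's tokenizer is a stricter check)
    if len(request.prompt) > MAX_PROMPT_LENGTH:
        raise HTTPException(status_code=400, detail="Prompt too long")

    # Sanitize inputs (sanitize_input is an application-defined helper)
    sanitized_prompt = sanitize_input(request.prompt)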
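Sampling parameters deserve the same range checks as the prompt, since an out-of-range value can cause errors or degenerate output. The sketch below uses assumed bounds: temperature in [0, 2] follows the common OpenAI-style convention, and the 2048-token ceiling echoes the --max-seq-len example. Adjust both to the engine's real limits. validate_sampling is a hypothetical helper.

```python
from typing import Optional


def validate_sampling(temperature: float = 1.0, top_p: float = 1.0,
                      max_tokens: int = 128, max_allowed_tokens: int = 2048) -> Optional[str]:
    """Return an error message for out-of-range sampling parameters, or None if valid."""
    if not 0.0 <= temperature <= 2.0:
        return "temperature must be between 0 and 2"
    if not 0.0 < top_p <= 1.0:
        return "top_p must be in (0, 1]"
    if not 1 <= max_tokens <= max_allowed_tokens:
        return f"max_tokens must be between 1 and {max_allowed_tokens}"
    return None
```

Inside the endpoint, a non-None result maps directly to HTTPException(status_code=400, detail=message).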

7. Configuration Management

Environment Variables

# Configure through environment variables
export MODEL_NAME="gpt2"
export MAX_BATCH_SIZE="16"
export GPU_MEMORY_UTIL="0.9"
export ENABLE_RADIX_CACHE="true"

Configuration Files

# YAML configuration
model:
  name: "gpt2"
  dtype: "float16"
  tensor_parallel_size: 1

scheduler:
  max_batch_size: 16
  max_prefill_batch_size: 32
  max_decode_batch_size: 256

memory:
  gpu_blocks: 2000
  cpu_blocks: 1000
  block_size: 16
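Once the YAML is parsed into a dictionary (for example with PyYAML's yaml.safe_load), mapping each section onto a typed dataclass catches misspelled keys and invalid values at startup rather than mid-request. The sketch below covers the scheduler and memory sections, with defaults mirroring the YAML above. The class and function names are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict


def _require_positive(section: object) -> None:
    for name, value in vars(section).items():
        if value <= 0:
            raise ValueError(f"{name} must be positive, got {value}")


@dataclass
class SchedulerConfig:
    max_batch_size: int = 16
    max_prefill_batch_size: int = 32
    max_decode_batch_size: int = 256

    def __post_init__(self) -> None:
        _require_positive(self)


@dataclass
class MemoryConfig:
    gpu_blocks: int = 2000
    cpu_blocks: int = 1000
    block_size: int = 16

    def __post_init__(self) -> None:
        _require_positive(self)


def load_config(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Turn the parsed YAML mapping into typed sections; an unknown key raises TypeError."""
    return {
        "model": dict(raw.get("model", {})),
        "scheduler": SchedulerConfig(**raw.get("scheduler", {})),
        "memory": MemoryConfig(**raw.get("memory", {})),
    }
```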

8. Deployment Strategies

Containerization

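# Dockerfile for Mini-YAIE
# A CUDA "devel" base image provides nvcc, which pip install -e . needs
# to build the custom CUDA kernels
FROM nvidia/cuda:12.1.1-devel-ubuntu22.04

RUN apt-get update && apt-get install -y --no-install-recommends python3 python3-pip \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app
COPY requirements.txt .
RUN python3 -m pip install --no-cache-dir -r requirements.txt

COPY . .
RUN python3 -m pip install --no-cache-dir -e .

EXPOSE 8000
CMD ["yaie", "serve", "gpt2", "--host", "0.0.0.0", "--port", "8000"]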

Kubernetes Deployment

# Kubernetes deployment
apiVersion: apps/v1
kind: Deployment
metadata:
  name: yaie-inference
spec:
  replicas: 3
  selector:
    matchLabels:
      app: yaie-inference
  template:
    metadata:
      labels:
        app: yaie-inference
    spec:
      containers:
      - name: yaie
        image: yaie:latest
        ports:
        - containerPort: 8000
        resources:
          limits:
            nvidia.com/gpu: 1

9. Performance Monitoring

Key Metrics to Track

  • Latency: Time from request to first token (TTFT) and time per token
  • Throughput: Tokens generated per second
  • GPU Utilization: Percentage of GPU time spent on computation
  • Memory Usage: GPU and CPU memory consumption
  • Queue Length: Number of requests waiting
  • Error Rates: Percentage of failed requests
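The latency and throughput metrics can be computed from two kinds of timestamps: when a request arrived and when each output token was emitted. The sketch below defines time to first token (TTFT), average time per output token after the first (TPOT), and aggregate throughput over the observed window. RequestTiming and throughput are hypothetical names.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class RequestTiming:
    arrival: float             # when the request was received (seconds)
    token_times: List[float]   # when each output token was emitted (seconds)

    @property
    def ttft(self) -> float:
        """Time to first token."""
        return self.token_times[0] - self.arrival

    @property
    def tpot(self) -> float:
        """Average time per output token after the first."""
        n = len(self.token_times)
        return (self.token_times[-1] - self.token_times[0]) / (n - 1) if n > 1 else 0.0


def throughput(timings: List[RequestTiming]) -> float:
    """Output tokens per second across all requests over the observed wall-clock window."""
    start = min(t.arrival for t in timings)
    end = max(t.token_times[-1] for t in timings)
    return sum(len(t.token_times) for t in timings) / (end - start)
```

TTFT is dominated by queueing and prefill, while TPOT reflects decode batch size. This is why the two move in different directions as --max-batch-size grows.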

Alerting

# Set up alerts for critical conditions
if gpu_memory_usage > 0.95:
    alert("High GPU memory usage")

if request_latency > 5.0:  # 5 seconds
    alert("High request latency")

if error_rate > 0.01:  # 1%
    alert("High error rate")

10. Cost Optimization

Resource Management

  • Auto-scaling: Scale workers based on demand
  • Spot Instances: Use cheaper spot instances for non-critical workloads
  • Right-sizing: Choose appropriate instance types for workload
  • Batch Processing: Process offline requests in batches during low-traffic periods

Model Selection

# Choose appropriate model size for use case
small_models = ["gpt2", "microsoft/DialoGPT-small"]  # Fast, low cost
medium_models = ["gpt2-medium", "microsoft/DialoGPT-medium"]  # Balanced
large_models = ["gpt2-large", "microsoft/DialoGPT-large"]  # High quality, expensive

Educational Focus

Understanding production considerations helps you:

  1. Bridge the gap between educational implementations and real-world systems
  2. Appreciate the complexity of production-grade inference engines
  3. Make informed decisions about trade-offs in your implementations
  4. Design for scalability from the beginning

From Mini-YAIE to Production

Mini-YAIE provides the foundation for understanding key concepts:

  • Continuous Batching: The core of efficient inference
  • Memory Management: Critical for handling multiple requests
  • Prefix Sharing: Advanced optimization for similar requests
  • API Design: Standard interfaces for integration

Production systems build on these foundations with:

  • Scalability: Handling thousands of concurrent requests
  • Reliability: High availability and fault tolerance
  • Observability: Comprehensive monitoring and logging
  • Security: Authentication, authorization, and input validation

By mastering the concepts in Mini-YAIE, you’ll be well-prepared to understand and contribute to production-grade inference systems like vLLM, SGLang, and TensorRT-LLM.

References

  1. SGLang: Efficient Execution of Structured Language Model Programs.
  2. vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention.
  3. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

Troubleshooting

“CUDA kernel not found”

  • Ensure you ran pip install -e . from the repository root.
  • Check that nvcc is on your PATH by running nvcc --version.
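When the kernels still fail to build or load, it helps to see the whole toolchain at once. Check whether nvcc is found, which version it reports, and which CUDA version PyTorch was built for. A mismatch between the last two is a common cause of extension build failures. The sketch below only inspects the environment. cuda_toolchain_report is a hypothetical helper, and PyTorch is treated as optional.

```python
import importlib.util
import shutil
import subprocess
from typing import Dict, Optional


def cuda_toolchain_report() -> Dict[str, Optional[str]]:
    """Collect what is needed to build and run custom CUDA kernels."""
    nvcc = shutil.which("nvcc")
    report: Dict[str, Optional[str]] = {"nvcc_path": nvcc, "nvcc_version": None,
                                        "torch_cuda": None}
    if nvcc:
        out = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
        lines = out.stdout.strip().splitlines()
        report["nvcc_version"] = lines[-1] if lines else None
    if importlib.util.find_spec("torch") is not None:
        import torch
        # torch.version.cuda is the CUDA version PyTorch was built with (None for CPU-only builds)
        report["torch_cuda"] = f"built for {torch.version.cuda}, available={torch.cuda.is_available()}"
    return report


print(cuda_toolchain_report())
```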

“OutOfMemoryError”

  • Decrease max_batch_size (--max-batch-size on the CLI).
  • Decrease the number of KV-cache blocks allocated by the kv_cache_manager (--num-blocks on the CLI).

“ImportError: attempted relative import…”

  • Run the yaie command, or run Python as a module: python -m src.cli.main.
  • Do not run scripts directly (for example, python src/engine.py). A file executed as a top-level script has no parent package, so its relative imports fail.