Welcome to Mini-YAIE
Mini-YAIE (Yet Another Inference Engine) is an educational project designed to demystify modern Large Language Model (LLM) inference engines.
Driven by the need for efficiency, modern engines like SGLang, vLLM, and TensorRT-LLM use sophisticated techniques to maximize GPU throughput and minimize latency. Mini-YAIE provides a simplified, clean implementation of these concepts, focusing on:
- Continuous Batching
- Paged KV Caching
- Radix Attention (Prefix Sharing)
How to use this guide
This documentation is structured to take you from high-level concepts to low-level implementation.
- Core Concepts: Start here to understand the why and what of inference optimization.
- Architecture: Understand how the system components fit together.
- Implementation Guides: Step-by-step guides to implementing the missing “kernels” in Python and CUDA.
Your Mission
The codebase contains placeholders (NotImplementedError) for critical components. Your goal is to implement these components following this guide, turning Mini-YAIE from a skeleton into a fully functional inference engine.
Prerequisites
To successfully implement the kernels in Mini-YAIE, you should be familiar with:
Programming Languages
- Python (Intermediate): Understanding of classes, inheritance, type hinting, and PyTorch tensors.
- C++ (Basic): For reading and writing the CUDA kernels (though much of the boilerplate is provided).
- CUDA (Basic): Understanding of the GPU execution model (blocks, threads, shared memory).
Machine Learning Concepts
- Transformer Architecture: Queries, Keys, Values, Attention mechanism.
- Tensors: Shapes, dimensions, matrix multiplication.
Tools
- Git: For version control.
- Linux/Unix Shell: For running commands.
Environment Setup
1. Clone the Repository
git clone https://github.com/yourusername/Mini-YAIE.git
cd Mini-YAIE
2. Python Environment
It is highly recommended to use a virtual environment.
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
pip install -e .
3. CUDA Requirements (Optional)
To build and run the CUDA kernels, you need:
- NVIDIA GPU (Compute Capability 7.0+)
- CUDA Toolkit 11.8+
- PyTorch with CUDA support
If you do not have a GPU, you can still implement the Python logic and the CPU fallback kernels.
4. Documentation Setup
To serve this documentation locally:
Install mdbook:
# If you have Rust/Cargo installed:
cargo install mdbook
# Or download the binary from their GitHub releases.
Serve the docs:
mdbook serve docs
Then navigate to http://localhost:3000 in your browser.
Model Loading
YAIE supports loading models from HuggingFace Hub with automatic caching and local model support.
ModelLoader Class
The ModelLoader class in src/models/loader.py handles all model and tokenizer loading operations.
Initialization
from src.models.loader import ModelLoader
# Load from HuggingFace Hub
loader = ModelLoader("microsoft/DialoGPT-medium")
# Load from local path
loader = ModelLoader("/path/to/local/model")
Loading Models
# Load the model
model = loader.load_model()
# Load the tokenizer
tokenizer = loader.load_tokenizer()
Supported Model Sources
HuggingFace Hub Models
YAIE can load any compatible model from HuggingFace Hub:
# Popular conversational models
loader = ModelLoader("microsoft/DialoGPT-medium")
loader = ModelLoader("microsoft/DialoGPT-large")
# Code generation models
loader = ModelLoader("Salesforce/codegen-350M-mono")
# General purpose models
loader = ModelLoader("gpt2")
loader = ModelLoader("gpt2-medium")
Local Models
You can also load models from local directories:
# Load from local path
loader = ModelLoader("./models/my-custom-model")
Caching Behavior
Automatic Caching
Models are automatically cached in the standard HuggingFace cache directory:
- Linux/macOS: ~/.cache/huggingface/
- Windows: C:\Users\<username>\.cache\huggingface\
Cache Structure
~/.cache/huggingface/
├── hub/
│ └── models--microsoft--DialoGPT-medium/
│ ├── blobs/
│ ├── refs/
│ └── snapshots/
│ └── abc123.../
│ ├── config.json
│ ├── pytorch_model.bin
│ └── tokenizer.json
Cache Management
The loader automatically:
- Checks for existing models in cache
- Downloads missing models from HuggingFace Hub
- Uses cached models for subsequent loads
Model Configuration
Data Types
Models are loaded with optimized data types:
# Models are loaded with float16 by default for efficiency
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16, # Half precision
device_map="auto" # Automatic device placement
)
Device Placement
- Single GPU: Model is loaded directly to GPU
- Multi-GPU: Automatically distributed across available GPUs
- CPU: Falls back to CPU if no GPU available
Tokenizer Configuration
Automatic Pad Token
The loader ensures tokenizers have a pad token:
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
This is important for batch processing where sequences need to be padded to the same length.
Memory Optimization
Lazy Loading
Models are loaded on-demand, not at import time:
# Model is not loaded here
loader = ModelLoader("gpt2")
# Model is loaded here when requested
model = loader.load_model()
Memory Mapping
Large models use memory mapping to reduce RAM usage during loading.
Error Handling
Network Issues
If downloading fails, the loader will retry and provide clear error messages.
Incompatible Models
Models must be compatible with AutoModelForCausalLM. Incompatible models will raise clear errors.
Disk Space
Large models require significant disk space. The loader shows download progress and estimated sizes.
Performance Tips
Pre-download Models
For production deployments, pre-download models:
# This will cache the model
python -c "from src.models.loader import ModelLoader; loader = ModelLoader('microsoft/DialoGPT-medium'); loader.load_model()"
Cache Location
You can customize the cache location by setting the HF_HOME environment variable:
export HF_HOME=/path/to/custom/cache
Model Selection
Choose appropriate model sizes for your hardware:
- Small models (< 1GB): gpt2, DialoGPT-small
- Medium models (1-5GB): gpt2-medium, DialoGPT-medium
- Large models (> 5GB): gpt2-large, DialoGPT-large
Troubleshooting
Common Issues
“Model not found” errors:
- Check model name spelling
- Verify model exists on HuggingFace Hub
- Ensure internet connection for downloads
Out of memory errors:
- Try smaller models
- Reduce batch sizes in configuration
- Use CPU-only mode if GPU memory is insufficient
Tokenizer issues:
- Some models may require special token handling
- Check the model’s documentation on HuggingFace Hub
LLM Inference: The Basics
Large Language Model (LLM) inference is the process of generating text from a trained model. It consists of two distinct phases.
1. Prefill Phase (The “Prompt”)
- Input: The user’s prompt (e.g., “Write a poem about cats”).
- Operation: The model processes all input tokens in parallel.
- Output: The KV (Key-Value) cache for the prompt and the first generated token.
- Characteristic: Compute-bound. We maximize parallelism here.
The Process Visualized
sequenceDiagram
participant U as User
participant E as Engine
participant M as Model
rect rgb(200, 220, 255)
note right of U: Prefill Phase (Parallel)
U->>E: Prompt: "A B C"
E->>M: Forward(["A", "B", "C"])
M-->>E: KV Cache + Logits(C)
end
rect rgb(220, 255, 200)
note right of U: Decode Phase (Serial)
loop Until EOS
E->>M: Forward([Last Token])
M-->>E: Update KV + Logits
E->>E: Sample Next Token
end
end
E->>U: Response
2. Decode Phase (The “Generation”)
- Input: The previously generated token.
- Operation: The model generates one token at a time, autoregressively.
- Output: The next token and an updated KV cache.
- Characteristic: Memory-bound. We are limited by how fast we can move weights and KV cache from memory to the compute units.
The KV Cache
State management is crucial. Instead of re-computing the attention for all previous tokens at every step, we cache the Key and Value vectors for every token in the sequence. This is the KV Cache. Managing this cache efficiently is the main challenge of high-performance inference engines.
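To make this concrete, here is a minimal sketch of the two phases using the HuggingFace transformers API (gpt2 is used purely as an illustrative model): a parallel prefill builds the cache, then the decode loop feeds only the newest token and reuses the cached Keys/Values.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("Write a poem about cats:", return_tensors="pt").input_ids

with torch.no_grad():
    # Prefill: process the whole prompt in parallel and build the KV cache.
    out = model(input_ids, use_cache=True)
    past_key_values = out.past_key_values
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    # Decode: feed only the last token; attention reads the history from the cache.
    for _ in range(16):
        out = model(next_token, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))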
Continuous Batching
The Problem: Static Batching
In traditional deep learning (like training), we use static batches: all sequences in a batch must have the same length (padded to the max length).
- Waste: Padding wastes computation and memory.
- Latency: We must wait for the longest sequence to finish generating before finishing the batch.
Visualizing the Difference
gantt
title Static Batching (Inefficient)
dateFormat YYYY-MM-DD
axisFormat %H:%M
section Batch 1
Req A (Short) :done, a1, 2024-01-01, 2d
Padding :crit, 2024-01-03, 2d
Req B (Long) :active, b1, 2024-01-01, 4d
section Batch 2
Req C :c1, 2024-01-05, 2d
gantt
title Continuous Batching (Efficient)
dateFormat YYYY-MM-DD
axisFormat %H:%M
section GPU Stream
Req A (Short) :done, a1, 2024-01-01, 2d
Req C (New!) :active, c1, 2024-01-03, 2d
section GPU Stream 2
Req B (Long) :active, b1, 2024-01-01, 4d
The Solution: Continuous Batching (Orca)
Introduced by the Orca paper, Continuous Batching (also called Iteration-Level Batching) decouples scheduling from the lifetime of a batch: the engine makes scheduling decisions at every iteration rather than once per batch.
- Iteration Level: The engine runs one iteration (one forward pass) at a time.
- Dynamic Insertion: As soon as a request finishes, it enters the “Completed” state. A new request from the queue can immediately take its place in the next iteration.
- No Padding: We process only the valid tokens for each request.
This significantly improves throughput (requests per second) without hurting latency (time per token) for individual requests.
Radix Attention (SGLang)
Radix Attention is the core innovation of SGLang. It optimizes the Prefill Phase by reusing computation from previous requests.
The Intuition
If two users ask:
- “Write a Python script to scrape a website.”
- “Write a Python script to sort a list.”
They share the prefix “Write a Python script to “. In a standard engine, we would compute the KV cache for this prefix twice.
graph TD
classDef shared fill:#aaffaa,stroke:#333,stroke-width:2px;
classDef unique fill:#ffaaaa,stroke:#333,stroke-width:2px;
Root((Root)) --> Node1["Write a Python script to"]:::shared
Node1 --> Node2["scrape a website"]:::unique
Node1 --> Node3["sort a list"]:::unique
style Node1 fill:#aaffaa
The Radix Tree
SGLang maintains a Radix Tree (Trie) of all token sequences currently in the KV cache.
- Nodes: Sequences of tokens.
- Edges: Transitions to new tokens.
When a new request arrives, we map its prompt to the longest matching path in the Radix Tree.
- Hit: We reuse the KV Cache for the matched part. The prefill only needs to compute the new suffix.
- Miss: We compute from scratch.
Benefits
- Reduced Latency: “Time To First Token” (TTFT) is nearly zero for cached prefixes.
- Higher Throughput: Less computation required per request.
- Complex Workflows: Enables efficient multi-turn chat, few-shot learning, and tree-of-thought prompting.
Paged Attention (vLLM)
Paged Attention is the core innovation of vLLM. It optimizes the Decode Phase by managing memory like an Operating System.
The Problem: Memory Fragmentation
Before vLLM, engines allocated contiguous memory for the maximum possible length of a request.
- Internal Fragmentation: If a request was shorter than max length, memory was wasted.
- External Fragmentation: We couldn’t fit a new request even if total free memory was sufficient, because no single contiguous block was large enough.
The Solution: Paging
Inspired by virtual memory in OS:
- Blocks: Divide KV Cache into fixed-size blocks (e.g., 16 tokens per block).
- Non-Contiguous: Blocks can be stored anywhere in physical GPU memory.
- Mapping: A “Block Table” maps logical token positions to physical block addresses.
graph LR
subgraph Logical[Logical Sequence Request]
L0[Block 0: "Hello"]
L1[Block 1: "World"]
L2[Block 2: "!"]
end
subgraph Table[Page Table]
T0[0 -> 7]
T1[1 -> 2]
T2[2 -> 9]
end
subgraph Physical[GPU Memory Physical Blocks]
B0[Block 0]
B1[Block 1]
B2[Block 2: "World"]:::used
B3...
B7[Block 7: "Hello"]:::used
B8...
B9[Block 9: "!"]:::used
end
L0 --> T0 --> B7
L1 --> T1 --> B2
L2 --> T2 --> B9
classDef used fill:#aaffaa;
The Kernel
The Paged Attention kernel allows the Attention mechanism to read keys and values from these non-contiguous blocks on the fly, enabling near-zero memory waste.
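The gather that such a kernel performs can be sketched in plain PyTorch. This is only an illustration of the idea, not the fused CUDA kernel; the pool and table names here are hypothetical.
import torch

NUM_BLOCKS, BLOCK_SIZE, HEADS, DIM = 64, 16, 8, 64

# Physical pool of KV blocks: [num_blocks, block_size, heads, dim]
key_pool = torch.randn(NUM_BLOCKS, BLOCK_SIZE, HEADS, DIM)
value_pool = torch.randn(NUM_BLOCKS, BLOCK_SIZE, HEADS, DIM)

def gather_kv(block_table: torch.Tensor, seq_len: int):
    """Rebuild the logically contiguous K/V for one request from scattered physical blocks."""
    keys = key_pool[block_table].reshape(-1, HEADS, DIM)[:seq_len]
    values = value_pool[block_table].reshape(-1, HEADS, DIM)[:seq_len]
    return keys, values

# Example: a request stored in physical blocks 7, 2, 9 with 40 valid tokens.
keys, values = gather_kv(torch.tensor([7, 2, 9]), seq_len=40)
print(keys.shape)  # torch.Size([40, 8, 64])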
System Architecture Overview
Introduction
Mini-YAIE (Yet Another Inference Engine) is an educational implementation of modern LLM inference techniques, specifically designed to demonstrate concepts from state-of-the-art systems like SGLang, vLLM, and TensorRT-LLM. The architecture focuses on three core optimizations:
- Continuous Batching: Dynamically batching incoming requests to maximize GPU utilization
- Radix Attention: Efficient attention mechanism with prefix sharing for similar requests
- Paged KV-Cache: Memory-efficient key-value cache management
High-Level Architecture
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ API Layer │ │ Engine Core │ │ Model/Kernels │
│ (FastAPI) │◄──►│ (Scheduler, │◄──►│ (PyTorch/ │
│ │ │ Attention) │ │ CUDA) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
▲ ▲ ▲
│ │ │
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ CLI Layer │ │ Model Loading │ │ Memory Mgmt │
│ (yaie serve/ │ │ (HuggingFace │ │ (Paged Cache) │
│ yaie chat) │ │ Integration) │ │ │
└─────────────────┘ └─────────────────┘ └─────────────────┘
Core Components
1. Main Inference Engine (engine.py)
The main inference engine orchestrates all components and provides the high-level API for inference. It implements SGLang-style continuous batching with radix attention and prefix sharing.
Key Responsibilities:
- Request orchestration and management
- Integration between scheduler, attention mechanisms, and memory management
- API layer communication
- Model loading and tokenizer management
2. SGLang Scheduler (core/sglang_scheduler.py)
The SGLang-style scheduler implements advanced request scheduling with:
- Prefix-based request grouping: Groups requests with common prefixes for computation sharing
- Separate prefill and decode scheduling: Optimizes for the different computational patterns
- Memory-aware batch sizing: Considers available KV-cache memory when scheduling
- Continuous batching optimization: Maintains high GPU utilization
3. Radix Attention System (kernels/radix_attention.py)
Implements the radix attention mechanism with:
- Prefix sharing: Reduces redundant computation for requests with common prefixes
- Paged KV-cache integration: Efficient memory management for variable-length requests
- RoPE (Rotary Position Embeddings): Supports position-aware attention
4. Paged KV-Cache Management (kernels/kv_cache.py)
Efficient memory management using page-based allocation:
- Fixed-size blocks: Reduces memory fragmentation
- Request-to-block mapping: Tracks which blocks belong to which requests
- Dynamic allocation/deallocation: Manages memory based on request lifecycle
5. Radix Tree System (kernels/radix_tree.py)
Enables efficient prefix matching and computation sharing:
- Trie-based structure: Organizes token sequences hierarchically
- Request grouping: Identifies requests with shared prefixes
- Computation optimization: Provides information for scheduler optimization
6. Sampling Kernel (kernels/sampling.py)
Implements core sampling algorithms:
- Temperature scaling: Controls randomness in generation
- Top-K sampling: Limits selection to top K most probable tokens
- Top-P (Nucleus) sampling: Limits selection to the smallest set of top tokens whose cumulative probability reaches P
7. API Server (server/api.py)
Provides OpenAI-compatible API endpoints:
- RESTful design: Follows OpenAI’s API specification
- Streaming support: Real-time token streaming
- Health monitoring: Server status endpoints
Data Flow
The system processes requests in the following sequence:
1. Request Arrival: Client sends a request through the API layer
2. Request Scheduling: The SGLang scheduler groups requests with common prefixes
3. Prefill Phase: Process full prompt sequences using radix attention
4. Decode Phase: Generate tokens one-by-one with shared computation
5. KV-Cache Management: Allocate and share memory efficiently
6. Response Generation: Return results via the API layer
Key Design Principles
Modularity
Each component is designed to be independent, allowing for focused learning and experimentation.
Educational Focus
Clean, well-documented code with comprehensive explanations of key concepts.
SGLang-Style Optimization
Focus on prefix sharing and radix trees to maximize computational efficiency.
Memory Efficiency
Paged cache management to reduce memory fragmentation and maximize utilization.
Architecture Benefits
- High Throughput: Continuous batching and prefix sharing maximize GPU utilization
- Memory Efficiency: Paged KV-cache reduces fragmentation and enables larger batch sizes
- Scalability: Modular design allows for optimization of individual components
- Educational Value: Clean implementation of state-of-the-art techniques
Integration Points
The system integrates components through well-defined interfaces:
- Engine connects to scheduler for request management
- Scheduler connects to memory manager for KV-cache coordination
- Attention mechanisms access KV-cache through the memory manager
- Sampler provides token selection for generation
- API layer communicates with the engine for request processing
Configuration Management
Overview
Mini-YAIE uses a flexible configuration system that allows users to customize various aspects of the inference engine without modifying the code. The configuration system provides settings for memory management, scheduling, model loading, and performance optimization.
Configuration Structure
The configuration system is built around the SGLangConfig dataclass in src/config.py. The system supports:
- Dataclass-based configuration with type hints
- Default values for all parameters
- Dictionary-based overrides
- Component-specific configuration sections
Key Configuration Parameters
Scheduler Configuration
# Maximum batch size for processing requests
max_batch_size: int = 8
# Maximum batch size for prefill operations
max_prefill_batch_size: int = 16
# Maximum batch size for decode operations
max_decode_batch_size: int = 256
# Maximum sequence length allowed
max_seq_len: int = 2048
KV Cache Configuration
# Number of GPU memory blocks for KV-cache
num_gpu_blocks: int = 2000
# Number of CPU memory blocks for swapping
num_cpu_blocks: int = 1000
# Size of each memory block (in tokens)
block_size: int = 16
Model Configuration
# Data type for model weights and KV-cache
dtype: str = "float16" # Options: "float16", "float32", "bfloat16"
# Tensor parallelism size
tensor_parallel_size: int = 1
# GPU memory utilization fraction
gpu_memory_utilization: float = 0.9
# CPU swap space in GB
swap_space: int = 4
Generation Configuration
# Default maximum tokens to generate per request
default_max_tokens: int = 1024
# Default sampling temperature
default_temperature: float = 1.0
# Default top-p value
default_top_p: float = 1.0
SGLang-Specific Features
# Enable radix attention cache for prefix sharing
enable_radix_cache: bool = True
# Enable chunked prefill for long prompts
enable_chunked_prefill: bool = True
# Scheduling policy: "fcfs" (first-come-first-served)
schedule_policy: str = "fcfs"
# Enable prefix caching
enable_prefix_caching: bool = True
# Maximum scheduling steps before preemption
max_num_schedule_steps: int = 1000
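Taken together, the parameters above can be pictured as one dataclass plus an override helper. The sketch below is illustrative and may not match the exact fields of SGLangConfig in src/config.py.
from dataclasses import dataclass, replace

@dataclass
class SGLangConfig:
    # Scheduler
    max_batch_size: int = 8
    max_prefill_batch_size: int = 16
    max_decode_batch_size: int = 256
    max_seq_len: int = 2048
    # KV cache
    num_gpu_blocks: int = 2000
    num_cpu_blocks: int = 1000
    block_size: int = 16
    # Model
    dtype: str = "float16"
    gpu_memory_utilization: float = 0.9
    # SGLang-specific features
    enable_radix_cache: bool = True
    enable_chunked_prefill: bool = True
    schedule_policy: str = "fcfs"

def get_sglang_config(**overrides) -> SGLangConfig:
    """Return the default configuration with keyword overrides applied."""
    return replace(SGLangConfig(), **overrides)

config = get_sglang_config(max_batch_size=16, num_gpu_blocks=4000)
print(config.max_batch_size, config.num_gpu_blocks)  # 16 4000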
Configuration Loading
Default Configuration
When no explicit configuration is provided, the system uses sensible defaults that work well for most educational purposes:
- Conservative memory usage to work on most GPUs
- Balanced performance settings
- Safe batch sizes that avoid out-of-memory errors
Custom Configuration
Users can customize configurations by:
- Direct parameter passing to constructors
- Environment variables for deployment scenarios
- Configuration files (when implemented)
Configuration Best Practices
Performance Tuning
For production use, consider these configuration adjustments:
- Increase batch sizes based on available GPU memory
- Adjust block size for optimal cache utilization
- Tune memory pool size based on request patterns
Memory Management
Configure memory settings based on your hardware:
# For high-end GPUs (24GB+ VRAM)
num_blocks = 4000
max_batch_size = 32
# For mid-range GPUs (8-16GB VRAM)
num_blocks = 1000
max_batch_size = 8
# For entry-level GPUs (4-8GB VRAM)
num_blocks = 500
max_batch_size = 4
Integration with Components
Engine Integration
The main engine uses the SGLangConfig for initialization:
from src.config import SGLangConfig, get_sglang_config
# Use default config
config = get_sglang_config()
# Or override specific parameters
config = get_sglang_config(
max_batch_size=16,
num_gpu_blocks=4000
)
# Initialize components with config values
scheduler = SGLangScheduler(
max_batch_size=config.max_batch_size,
max_prefill_batch_size=config.max_prefill_batch_size,
max_decode_batch_size=config.max_decode_batch_size
)
Scheduler Configuration
The SGLang scheduler uses configuration for scheduling policies:
- Batch size limits
- Prefill/decode phase sizing
- Memory-aware scheduling decisions
Memory Manager Configuration
The KV-cache manager uses configuration for:
- Total memory pool size
- Block allocation strategies
- Memory optimization policies
Environment-Specific Configuration
Development Configuration
For development and learning:
- Conservative memory limits
- Detailed logging
- Debug information enabled
Production Configuration
For production deployment:
- Optimized batch sizes
- Performance-focused settings
- Minimal logging overhead
Configuration Validation
The system validates configuration parameters to prevent:
- Memory allocation failures
- Invalid parameter combinations
- Performance-degrading settings
Future Extensions
The configuration system is designed to accommodate:
- Model-specific optimizations
- Hardware-aware tuning
- Runtime configuration updates
- Performance auto-tuning
Configuration Examples
Basic Configuration
# Minimal configuration for learning
config = {
"max_batch_size": 4,
"num_blocks": 1000,
"block_size": 16
}
Performance Configuration
# Optimized for throughput
config = {
"max_batch_size": 32,
"max_decode_batch_size": 512,
"num_blocks": 4000,
"block_size": 32
}
Troubleshooting Configuration Issues
Memory Issues
If experiencing out-of-memory errors:
- Reduce num_blocks in the KV-cache
- Lower batch sizes
- Check available GPU memory
Performance Issues
If experiencing low throughput:
- Increase batch sizes
- Optimize block size for your model
- Verify CUDA availability and compatibility
Engine Orchestration
Overview
The Inference Engine serves as the main orchestrator of the Mini-YAIE system, coordinating between various components to provide a unified interface for LLM inference. The engine implements SGLang-style continuous batching with radix attention and prefix sharing to maximize efficiency and throughput.
Engine Architecture
The main engine is implemented in src/engine.py and follows a modular design pattern where each component is responsible for specific aspects of request processing:
┌─────────────────┐
│ API Layer │ ← Requests enter here
├─────────────────┤
│ Engine Orchestration │ ← Coordination happens here
├─────────────────┤
│ Scheduler │ ← Request scheduling
├─────────────────┤
│ Memory Manager │ ← KV-cache management
├─────────────────┤
│ Attention Core │ ← Radix attention computation
├─────────────────┤
│ Model/Kernel │ ← Forward pass execution
└─────────────────┘
Core Components
1. Model Loading Integration
The engine handles model and tokenizer loading through the ModelLoader component:
def __init__(self, model_name: str):
self.tokenizer: PreTrainedTokenizer = self._load_tokenizer()
self.model = self._load_model()
This ensures that models are properly loaded from HuggingFace or local cache with appropriate configuration.
2. SGLang-Style Scheduler
The engine integrates with the SGLangScheduler for advanced request scheduling:
self.scheduler = SGLangScheduler(
max_batch_size=8,
max_prefill_batch_size=16,
max_decode_batch_size=256
)
The scheduler implements prefix grouping and multi-step processing for computation sharing.
3. Radix Attention System
The engine connects to the radix attention mechanism:
self.radix_attention = RadixAttentionWithPagedKVCache(
num_layers=self.model.config.num_hidden_layers,
num_heads=self.model.config.num_attention_heads,
head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
)
4. KV-Cache Management
The engine manages memory through the KVCacheManager:
self.kv_cache_manager = KVCacheManager(
num_blocks=2000,
block_size=16,
num_heads=self.model.config.num_attention_heads,
head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
dtype=torch.float16,
)
Request Processing Flow
1. Request Addition
def generate(self, prompts: List[str], **kwargs) -> List[str]:
# Add requests to scheduler
request_ids = []
for prompt in prompts:
req_id = self.scheduler.add_request(prompt, **kwargs)
request_ids.append(req_id)
2. Generation Loop
The engine runs a main generation loop that processes requests:
def _run_generation_loop(self, request_ids: List[str]) -> List[str]:
# Process requests in batches
# Handle prefill and decode phases
# Manage KV-cache efficiently
3. Response Generation
The engine generates responses with proper tokenization and formatting:
# Generate response using the existing generate method
responses = self.generate([formatted_prompt], **kwargs)
generated_text = responses[0] if responses else ""
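A simplified sketch of how these steps might be wired together in the generation loop. The method names mirror the snippets above, but the return order of schedule_step and the "text" result key are assumptions rather than the engine's exact contract.
def _run_generation_loop(self, request_ids):
    """Alternate scheduling, prefill, and decode until every request has a result."""
    results = {}
    while len(results) < len(request_ids):
        # Ask the scheduler which requests to run this iteration.
        decode_batch, prefill_batch = self.scheduler.schedule_step()

        # Prefill: run full prompts and populate their KV cache.
        if prefill_batch:
            self.scheduler.process_prefill_batch(prefill_batch)

        # Decode: generate one token for each running request.
        if decode_batch:
            self.scheduler.process_decode_batch(decode_batch)

        # Collect finished requests and reclaim their cache blocks.
        for rid in request_ids:
            if rid in results:
                continue
            result = self.scheduler.get_request_result(rid)
            if result is not None:
                results[rid] = result
                self.kv_cache_manager.free_blocks(rid)

    return [results[rid].get("text", "") for rid in request_ids]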
SGLang-Style Optimization Features
1. Continuous Batching
The engine supports continuous batching where requests at different stages can be processed together:
- Prefill requests (processing full prompts)
- Decode requests (generating single tokens)
- Mixed batches combining both types
2. Prefix Sharing
The engine enables computation sharing for requests with common prefixes:
- Radix tree identifies shared prefixes
- Common computations are performed once
- Results are shared among multiple requests
3. Memory Efficiency
The engine optimizes memory usage through:
- Paged KV-cache management
- Block allocation strategies
- Memory reclamation for completed requests
API Integration
1. Chat Completion API
The engine provides OpenAI-compatible chat completion:
def chat_completion(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
# Format messages using chat template
# Process through generation pipeline
# Return in OpenAI format
2. Streaming Support
The engine supports streaming responses for real-time applications:
def chat_completion_stream(self, messages: List[Dict[str, str]], **kwargs):
# Generate tokens one by one
# Yield chunks immediately
# Maintain OpenAI stream format
Performance Optimization
1. Batch Size Management
The engine dynamically adjusts batch sizes based on available memory and request characteristics:
- Prefill batches: Optimized for prompt processing
- Decode batches: Optimized for token generation
- Mixed batches: Balanced between both phases
2. Memory Management
The engine coordinates memory usage across components:
# Connect scheduler to memory manager for optimization
self.scheduler.connect_memory_manager(self.kv_cache_manager)
3. Computation Sharing
The engine maximizes computation sharing through radix attention:
- Shared prefix processing
- Common token computations
- Reduced redundant calculations
Error Handling and Resilience
1. Request Validation
The engine validates requests before processing:
- Input format validation
- Parameter range checking
- Resource availability verification
2. Graceful Degradation
When resources are constrained, the engine gracefully degrades:
- Reduced batch sizes
- Fallback mechanisms
- Proper error reporting
3. Resource Management
The engine manages system resources effectively:
- GPU memory monitoring
- Request queue management
- Memory cleanup for completed requests
Integration Points
1. Model Interface
The engine interfaces with any HuggingFace-compatible model:
outputs = self.model(current_ids) # Standard HuggingFace model interface
2. Sampling Integration
The engine uses the sampling kernel for token generation:
sampling_kernel = SamplingKernel()
next_token_id = sampling_kernel.sample(
next_token_logits,
temperature=request.temperature,
top_p=request.top_p
)
3. Scheduler Integration
The engine coordinates closely with the scheduler:
# Add requests to scheduler
req_id = self.scheduler.add_request(prompt, **kwargs)
# Process in generation loop
responses = self._run_generation_loop(request_ids)
Engine Configuration
The engine supports various configuration options:
- Model selection and loading
- Batch size limits
- Memory allocation settings
- Performance optimization parameters
Future Extensions
The engine design supports:
- Additional optimization techniques
- New attention mechanisms
- Enhanced scheduling algorithms
- Advanced memory management strategies
Scheduler Logic: SGLang-Style Request Management
Overview
The SGLang-style scheduler (core/sglang_scheduler.py) implements advanced request scheduling with prefix grouping and computation sharing capabilities. Unlike traditional schedulers, this implementation focuses on maximizing computational efficiency by identifying and leveraging shared prefixes among different requests.
Key Concepts
Request States
The scheduler manages requests through several states:
- PENDING: New requests awaiting initial processing
- SCHEDULED_PREFILL: Requests ready for prefill phase
- RUNNING_PREFILL: Currently processing full prompts
- SCHEDULED_DECODE: Requests ready for token generation
- RUNNING_DECODE: Currently generating tokens
- COMPLETED: Finished requests
- CANCELLED: Cancelled requests
Prefix-Based Grouping
The scheduler uses prefix hashing to group requests with common prefixes:
def _calculate_prefix_hash(self, prompt: str) -> Optional[str]:
# Calculate hash to identify common prefixes
return hashlib.sha256(prompt.encode("utf-8")).hexdigest()
Architecture
Core Data Structures
Request Management
# Separate queues for different processing phases
self.pending_requests: List[Request] = []
self.prefill_requests: List[Request] = []
self.running_prefill: List[Request] = []
self.decode_requests: List[Request] = []
self.running_decode: List[Request] = []
self.completed_requests: List[Request] = []
Prefix Grouping
# Group requests by common prefixes for shared computation
self.prefix_groups: Dict[str, List[Request]] = defaultdict(list)
self.request_lookup: Dict[str, Request] = {}
Scheduling Strategy
The scheduler implements a SGLang-inspired strategy:
- Prioritize Decode Requests: Minimize token-to-token latency
- Maximize Prefill Efficiency: Process new requests efficiently
- Leverage Prefix Sharing: Share computation for similar requests
- Memory-Aware Scheduling: Respect KV-cache limitations
Detailed Implementation
Request Lifecycle
1. Request Addition
def add_request(self, prompt: str, max_tokens: int = 128, ...) -> str:
# Calculate prefix hash for grouping
prefix_hash = self._calculate_prefix_hash(prompt)
# Add to prefix group if applicable
if prefix_hash:
self.prefix_groups[prefix_hash].append(request)
request.request_group = prefix_hash
2. Scheduling Step
def schedule_step(self) -> Tuple[List[Request], List[Request]]:
# First, prioritize decode requests
decode_batch = []
prefill_batch = []
# Calculate remaining capacity after decode allocation
remaining_capacity = self.max_batch_size - len(decode_batch)
# Fill remaining capacity with prefill requests
if remaining_capacity > 0:
num_prefills = min(len(prefill_candidates), remaining_capacity, self.max_prefill_batch_size)
prefill_batch = prefill_candidates[:num_prefills]
Batch Selection Policy
The scheduler implements a multi-level priority system:
- Decode Priority: Continue existing generation to minimize latency
- Prefill Efficiency: Process new requests in efficient batches
- Memory Management: Ensure sufficient KV-cache for all requests
- Prefix Sharing: Group similar requests for computation sharing
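A compact sketch of a schedule_step that follows this priority order, using the queue attributes shown earlier (whether decode or prefill comes first in the returned tuple is an assumption here):
def schedule_step(self):
    """Pick this iteration's batches: decode requests first, prefill fills the rest."""
    # 1. Decode has priority: continue every running generation, up to its limit.
    decode_batch = self.decode_requests[: self.max_decode_batch_size]

    # 2. Whatever overall capacity remains goes to new (prefill) requests.
    remaining = max(0, self.max_batch_size - len(decode_batch))
    num_prefill = min(remaining, self.max_prefill_batch_size, len(self.pending_requests))
    prefill_batch = self.pending_requests[:num_prefill]

    # 3. A fuller version would also ask the memory manager whether the
    #    prefill batch's KV-cache blocks actually fit before admitting it.
    return decode_batch, prefill_batch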
Prefill Processing
Prefill requests undergo full prompt processing:
def process_prefill_batch(self, requests: List[Request]) -> List[Request]:
for req in requests:
# Process full prompt in one forward pass
req.status = RequestStatus.SCHEDULED_DECODE
# Initialize output sequence
if req.output_ids is None:
req.output_ids = []
Decode Processing
Decode requests generate tokens one-by-one:
def process_decode_batch(self, requests: List[Request]) -> List[Request]:
for req in requests:
# Get logits from model (simplified)
dummy_logits = torch.randn(1, 1000)
# Sample next token using kernel
next_token_tensor = self.sampling_kernel.sample(
dummy_logits,
temperature=req.temperature,
top_p=req.top_p
)
# Update position and check termination
req.current_position += 1
if req.current_position >= req.max_tokens:
req.status = RequestStatus.COMPLETED
SGLang-Style Optimizations
1. Computation Sharing
The scheduler identifies requests with shared prefixes:
def find_shared_prefixes(self, token_ids: List[int]) -> Tuple[List[str], List[int]]:
# Traverse radix tree to find matching prefixes
# Return requests that can share computation
2. Memory-Aware Scheduling
The scheduler connects to the memory manager for KV-cache coordination:
def connect_memory_manager(self, memory_manager):
self.memory_manager = memory_manager
3. Continuous Batching
The scheduler maintains continuous processing by balancing prefill and decode requests:
- Decode requests have higher priority (latency-sensitive)
- Prefill requests fill remaining batch capacity
- Memory requirements are considered during scheduling
Performance Considerations
Batch Size Optimization
The scheduler uses different batch size limits:
- max_prefill_batch_size: Limits prefill batch size for memory efficiency
- max_decode_batch_size: Larger limit for decode due to its smaller per-step memory footprint
- max_batch_size: Overall system limit
Memory Management
The scheduler coordinates with the KV-cache manager to:
- Allocate blocks for new requests
- Track memory usage during processing
- Ensure sufficient memory for scheduled requests
Integration with Other Components
Memory Manager Integration
def process_prefill_batch(self, requests: List[Request]) -> List[Request]:
if self.memory_manager:
# Allocate KV cache blocks for requests
pass
Sampling Kernel Integration
def process_decode_batch(self, requests: List[Request]) -> List[Request]:
# Use sampling kernel for token selection
next_token_tensor = self.sampling_kernel.sample(...)
Request Status Monitoring
Queue Status
The scheduler provides detailed status information:
def get_queue_status(self) -> Dict[str, int]:
return {
"pending": len(self.pending_requests),
"prefill_queue": len(self.prefill_requests),
"running_prefill": len(self.running_prefill),
"decode_queue": len(self.decode_requests),
"running_decode": len(self.running_decode),
"completed": len(self.completed_requests),
"total_active": self.get_active_request_count(),
}
Request Result Access
def get_request_result(self, req_id: str) -> Optional[Dict[str, Any]]:
# Check completed requests for results
Implementation Challenges
1. Prefix Hashing
For educational purposes, the implementation uses simple string hashing. In production:
- Use token ID sequences for more accurate prefix matching
- Implement more sophisticated similarity measures
- Consider semantic similarity for better grouping
2. Memory Allocation
The current implementation shows integration points for memory management. A full implementation would:
- Calculate precise memory requirements
- Implement cache eviction policies
- Handle memory fragmentation
3. Computation Sharing
The radix tree integration points exist but require full implementation of:
- Efficient tree traversal
- Shared computation tracking
- Result distribution to multiple requests
Scheduling Algorithm Details
Step-by-Step Process
1. Decode Prioritization: Schedule as many decode requests as possible
2. Capacity Calculation: Determine remaining batch capacity
3. Prefill Scheduling: Fill remaining capacity with prefill requests
4. Memory Verification: Confirm sufficient KV-cache availability
5. Batch Execution: Process scheduled requests
Optimization Strategies
The scheduler implements several optimization strategies:
- Temporal Multiplexing: Interleave prefill and decode for efficiency
- Spatial Multiplexing: Group similar requests for shared computation
- Memory Multiplexing: Optimize KV-cache usage across requests
Future Extensions
The scheduler design supports:
- Advanced prefix matching algorithms
- Dynamic batch size adjustment
- Request preemption and rescheduling
- Multi-GPU coordination
- Custom scheduling policies
Memory Management: Paged KV-Cache System
Overview
The memory management system in Mini-YAIE implements a paged KV-cache mechanism inspired by systems like vLLM and SGLang. This approach addresses the memory fragmentation challenges in LLM inference by using fixed-size memory blocks (pages) that can be allocated and deallocated independently for each request.
Core Concepts
Paged Memory Architecture
Traditional KV-cache management allocates contiguous memory blocks for each request, leading to fragmentation when requests have varying lengths. Paged KV-cache solves this by:
- Dividing the KV-cache into fixed-size blocks (pages)
- Allowing requests to use non-contiguous memory blocks
- Enabling efficient memory reuse and sharing
Key Benefits
- Reduced Fragmentation: Fixed-size blocks prevent memory fragmentation
- Efficient Memory Utilization: Unused blocks can be allocated to other requests
- Scalability: Supports variable-length requests without memory waste
- Computation Sharing: Enables shared prefixes to use the same memory blocks
Architecture
KVCacheBlock Class
Each KVCacheBlock represents a fixed-size memory block:
class KVCacheBlock:
def __init__(self, block_id: int, size: int, num_heads: int, head_dim: int, ...):
self.block_id = block_id # Unique identifier for the block
self.size = size # Number of tokens this block can hold
self.num_heads = num_heads
self.head_dim = head_dim
self.keys = None # [size, num_heads, head_dim] tensor
self.values = None # [size, num_heads, head_dim] tensor
KVCacheManager Class
The main manager orchestrates all memory operations:
class KVCacheManager:
def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int, ...):
self.num_blocks = num_blocks # Total number of blocks in the pool
self.block_size = block_size # Size of each block in tokens
self.num_heads = num_heads
self.head_dim = head_dim
self.blocks: List[KVCacheBlock] = [] # Pool of all blocks
self.free_block_list: List[int] = [] # Available blocks for allocation
self.block_tables: dict = {} # Maps request_id to list of block_ids
Memory Management Operations
1. Block Allocation
When a request needs KV-cache memory, the manager allocates the required number of blocks:
def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
# Calculate required blocks
num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
# Check availability
if len(self.free_block_list) < num_blocks_needed:
raise RuntimeError("Not enough free blocks")
# Allocate from free list
allocated_block_ids = []
for _ in range(num_blocks_needed):
block_id = self.free_block_list.pop(0) # Remove from free list
allocated_block_ids.append(block_id)
self.blocks[block_id].allocate() # Allocate GPU memory
# Track allocation
self.block_tables[request_id] = allocated_block_ids
return allocated_block_ids
Key Aspects:
- Calculates minimum blocks needed based on token count and block size
- Ensures sufficient free blocks before allocation
- Updates free block list and request tracking
- Actually allocates GPU memory for the blocks
2. Block Deallocation
When requests complete, their blocks are returned to the free pool:
def free_blocks(self, request_id: str):
if request_id in self.block_tables:
block_ids = self.block_tables[request_id]
self.free_block_list.extend(block_ids) # Return to free pool
self.free_block_list.sort() # Maintain sorted order
del self.block_tables[request_id] # Remove tracking entry
# Optionally clear tensors to free GPU memory
for block_id in block_ids:
self.blocks[block_id].keys = None
self.blocks[block_id].values = None
Key Aspects:
- Returns blocks to the free list for reuse
- Maintains sorted order for allocation efficiency
- Removes request tracking information
- Optionally clears GPU tensors to free memory
3. Block Copying
For advanced operations like request preemption or memory defragmentation:
def copy_blocks(self, src_block_ids: List[int], dst_block_ids: List[int]):
if len(src_block_ids) != len(dst_block_ids):
raise ValueError("Source and destination lists must have same length")
for src_id, dst_id in zip(src_block_ids, dst_block_ids):
src_block = self.blocks[src_id]
dst_block = self.blocks[dst_id]
# Allocate destination if needed
if dst_block.keys is None or dst_block.values is None:
dst_block.allocate()
# Copy data
with torch.no_grad():
dst_block.keys.copy_(src_block.keys)
dst_block.values.copy_(src_block.values)
Memory Layout and Access
Block Organization
The memory is organized as a collection of fixed-size blocks:
Global Memory Pool:
┌─────────┬─────────┬─────────┬─────────┬─────────┬─────────┐
│ Block 0 │ Block 1 │ Block 2 │ Block 3 │ Block 4 │ Block 5 │
│ [16xHxD]│ [16xHxD]│ [16xHxD]│ [16xHxD]│ [16xHxD]│ [16xHxD]│
└─────────┴─────────┴─────────┴─────────┴─────────┴─────────┘
Request A: Uses [Block 0, Block 2] for non-contiguous sequence storage
Request B: Uses [Block 1, Block 4, Block 5] for its sequence
Where H = num_heads and D = head_dim
Block Tables
Each request has an associated block table that maps logical token positions to physical blocks:
Request A Block Table:
Logical: [0-15] [16-31]
Physical: Block 0 Block 2
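With a block table, address translation is just integer division. A tiny sketch, assuming block_size = 16 and the hypothetical table above:
BLOCK_SIZE = 16

# Request A's block table: logical block index -> physical block id.
block_table = {0: 0, 1: 2}  # tokens 0-15 live in Block 0, tokens 16-31 in Block 2

def locate_token(position: int):
    """Translate a logical token position into (physical_block_id, offset_in_block)."""
    logical_block = position // BLOCK_SIZE
    offset = position % BLOCK_SIZE
    return block_table[logical_block], offset

print(locate_token(5))   # (0, 5)  -> Block 0, slot 5
print(locate_token(20))  # (2, 4)  -> Block 2, slot 4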
Integration with SGLang Features
Computation Sharing
The paged system enables computation sharing by allowing requests with shared prefixes to reference the same memory blocks:
- Requests with common prefixes can share the same KV-cache blocks
- Multiple requests can reference the same physical memory location
- Reduces redundant computation and memory usage
Memory Efficiency
By using fixed-size blocks:
- Memory fragmentation is eliminated
- Block reuse is maximized
- Memory utilization approaches optimal levels
Performance Considerations
Block Size Selection
The block size parameter is critical for performance:
- Smaller blocks: Less internal fragmentation, more overhead for block management
- Larger blocks: More internal fragmentation, less management overhead
- Typical values: 8-32 tokens per block work well in practice
Memory Allocation Strategy
The system uses a simple first-fit strategy:
# Allocate from beginning of free list
block_id = self.free_block_list.pop(0)
For production systems, more sophisticated strategies might include:
- Best-fit to minimize fragmentation
- Coalescing strategies to combine blocks
- Preallocation to reduce allocation overhead
Memory Safety and Management
GPU Memory Management
The system ensures proper GPU memory allocation:
def allocate(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.keys = torch.zeros(
self.size, self.num_heads, self.head_dim,
dtype=self.dtype, device=device
)
self.values = torch.zeros(
self.size, self.num_heads, self.head_dim,
dtype=self.dtype, device=device
)
Memory Cleanup
Proper cleanup prevents memory leaks:
- Free blocks when requests complete
- Clear GPU tensors to release memory
- Maintain consistent state in block tables
Advanced Features
Dynamic Block Resizing
For requests that need to extend beyond their initial allocation:
- Allocate additional blocks as needed
- Maintain logical sequence continuity
- Update block tables accordingly
Memory Pool Management
Advanced implementations might include:
- Block migration to reduce fragmentation
- Eviction policies for memory-constrained scenarios
- Prefetching strategies for better performance
Error Handling
Out of Memory Conditions
The system handles memory exhaustion gracefully:
if len(self.free_block_list) < num_blocks_needed:
raise RuntimeError(f"Not enough free blocks. Need {num_blocks_needed}, have {len(self.free_block_list)}")
Block Validation
Before operations, the system validates block states:
- Verify blocks are allocated before accessing
- Check for proper tensor dimensions
- Validate request associations
Future Enhancements
Memory Optimization
Potential improvements include:
- Compressed KV-cache storage
- Offloading to CPU memory when possible
- Cache eviction policies for long-running requests
Performance Optimization
Advanced techniques might include:
- Block prefetching for better cache performance
- Heterogeneous memory management (different memory types)
- Asynchronous memory operations
Implementation Variations
SGLang-Style Memory Management
For SGLang-specific optimizations:
- Prefix sharing memory management
- Radix tree integration for shared computation
- Advanced scheduling based on memory access patterns
Integration Points
The memory manager connects with other components:
- Scheduler: Provides memory availability information
- Attention modules: Access KV-cache through block tables
- Model execution: Uses paged cache for efficient attention computation
Python Kernels Guide
Overview
The Python kernels in Mini-YAIE implement the core computational components that enable SGLang-style inference optimization. These kernels provide the foundational functionality for attention mechanisms, memory management, and token sampling that make efficient LLM inference possible.
Kernel Architecture
Core Components
The kernel system consists of several interconnected modules:
- Radix Tree: Implements prefix matching for shared computation
- KV Cache Manager: Manages paged key-value storage
- Radix Attention Module: Implements radix attention with shared computation
- Sampling Module: Provides token selection algorithms
SGLang-Style Optimization
The kernels are designed to support SGLang’s key optimization strategies:
- Prefix Sharing: Share computation for requests with common prefixes
- Continuous Batching: Dynamically batch requests at different processing stages
- Paged Memory Management: Efficiently manage KV-cache memory using fixed-size blocks
- Radix Attention: Optimize attention computation for shared prefixes
Python Kernel Implementation
Design Philosophy
The Python kernels follow these design principles:
1. Educational Focus
- Clean, well-documented code
- Clear algorithm implementation
- Comprehensive comments explaining concepts
2. SGLang Compatibility
- Implement SGLang-style optimization techniques
- Support for radix attention and prefix sharing
- Continuous batching integration
3. Modularity
- Independent components that can be tested individually
- Clean interfaces between components
- Easy to extend and modify
4. Performance Considerations
- Efficient data structures
- Proper memory management
- Optimized algorithm implementations
Implementation Structure
Each kernel follows a similar pattern:
class KernelName:
def __init__(self, parameters):
# Initialize kernel with configuration
pass
def process(self, input_data):
# Core processing logic
pass
def update_state(self, new_data):
# State management for ongoing requests
pass
Integration with System Components
Engine Integration
The kernels integrate seamlessly with the main inference engine:
# Engine uses kernels for computation
self.radix_attention = RadixAttentionWithPagedKVCache(...)
self.kv_cache_manager = KVCacheManager(...)
self.sampling_kernel = SamplingKernel()
Scheduler Coordination
Kernels work with the SGLang scheduler:
- Provide computation sharing opportunities
- Manage memory allocation and deallocation
- Coordinate with scheduling policies
Memory Management
Kernels connect with the paged memory system:
- Request memory allocation through the manager
- Manage KV-cache blocks efficiently
- Support for shared memory blocks
Performance Characteristics
Computational Efficiency
The Python kernels provide:
- Efficient attention computation
- Optimized memory access patterns
- Shared computation for common prefixes
Memory Usage
Optimized memory management includes:
- Paged cache allocation
- Block-level memory sharing
- Efficient reuse of allocated blocks
Scalability
The kernel design supports:
- Variable batch sizes
- Multiple concurrent requests
- Scaled performance with more requests
Advanced Features
Computation Sharing
The radix tree and attention modules enable:
- Shared prefix identification
- Computation reuse across requests
- Efficient memory utilization
Adaptive Processing
Kernels adapt to:
- Different request patterns
- Variable sequence lengths
- Changing memory requirements
Testing and Validation
Unit Testing
Each kernel includes:
- Comprehensive unit tests
- Edge case validation
- Performance benchmarking
Integration Testing
Kernels are tested as part of:
- Full inference pipeline
- SGLang-style optimization scenarios
- Memory management validation
Extensibility
Adding New Kernels
The system supports:
- Easy addition of new kernel types
- Pluggable architecture for kernel replacement
- Backwards compatibility
Customization
Kernels can be customized for:
- Specific model architectures
- Hardware optimization
- Performance tuning
This Python kernel system forms the computational backbone of Mini-YAIE, implementing SGLang-style optimization techniques in an educational and accessible way.
Radix Tree (kernels/radix_tree.py)
1. Concept: The Prefix Tree
A Radix Tree (or Compressed Trie) is a data structure that succinctly stores sequences of tokens. Unlike a standard trie where each edge is a single character (or token), a Radix Tree allows edges to be sequences of tokens.
Optimization Goal
When two requests start with "The quick brown fox", we want to store that sequence once.
- Request A: "The quick brown fox jumps"
- Request B: "The quick brown fox sleeps"
In our tree, we should have a shared node for "The quick brown fox", which then branches into "jumps" and "sleeps".
flowchart TD
Root((Root)) --> Shared["The quick brown fox"]
Shared --> Branch1["jumps"]
Shared --> Branch2["sleeps"]
style Shared fill:#aaffaa
2. Implementation Guide
Open src/kernels/radix_tree.py. You will implement the RadixTree class step-by-step.
Step 1: Define the Tree Node
First, we need a node structure. Unlike a binary tree, a Radix Node can have many children.
class RadixTreeNode:
def __init__(self, prefix: List[int]):
self.prefix = prefix # The sequence of tokens on this edge
self.children: Dict[int, RadixTreeNode] = {} # Map: first_token -> Child Node
self.request_id: Optional[str] = None # If a request ends here, store its ID
self.lock_count = 0 # Reference counting (how many requests use this?)
Task: Locate the RadixTreeNode class and ensure it has these fields.
Step 2: Implement match_prefix
Before inserting, we need a way to see how much of a new prompt already exists in the tree.
Algorithm:
- Start at self.root.
- Compare the input token_ids with the edges in the tree.
- Traverse down as long as the tokens match exactly.
- Return the last matching Node and the number of matching tokens.
Your Turn:
Implement find_shared_prefixes(token_ids) in RadixTree.
Hint: Use a while loop. At each node, look at node.children[token_ids[current_idx]]. If it exists, check whether the full edge child.prefix matches the next chunk of your input. (A sketch of this traversal follows below.)
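One possible shape for this traversal, sketched against the node layout from Step 1. For simplicity it only follows whole-edge matches; handling a partial edge match is part of the splitting logic in Step 3.
def find_shared_prefixes(self, token_ids):
    """Walk the tree following edges that exactly match the next chunk of token_ids.

    Returns (last_matching_node, number_of_matched_tokens).
    """
    node = self.root
    matched = 0
    while matched < len(token_ids):
        child = node.children.get(token_ids[matched])
        if child is None:
            break  # no edge starts with the next token
        edge = child.prefix
        if token_ids[matched:matched + len(edge)] != edge:
            break  # the edge only partially matches; stop here
        node = child
        matched += len(edge)
    return node, matched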
Step 3: Implement insert_request (The Hard Part)
Now, inserting a new request. This involves splitting nodes if a partial match is found.
Scenario:
- Tree has edge [1, 2, 3, 4].
- You insert [1, 2, 5].
Algorithm:
- Trace the path as in Step 2.
- If you diverge in the middle of an edge (e.g., you matched 1, 2 but the tree has 3 where you have 5):
  - Split: Create a new parent node for [1, 2].
  - Make the old node [3, 4] a child of this new parent.
  - Create your new node [5] as another child.
flowchart TD
Start([Insert 1, 2, 5]) --> Match{Match 1, 2?}
Match -->|Yes| Diverge{Next is 3 vs 5}
Diverge --> Split[Split Edge]
Split --> Old[Child: 3, 4]
Split --> New[Child: 5]
Your Turn:
Implement insert_request(request_id, token_ids).
- Use your match_prefix logic as a helper.
- Handle the 3 cases: Exact Match, New Branch, or Split Edge. (A sketch covering all three follows below.)
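A sketch of insert_request covering the three cases, building on the RadixTreeNode fields from Step 1 and the traversal sketch above (reference counting and eviction bookkeeping are left out):
def insert_request(self, request_id, token_ids):
    """Insert token_ids, splitting an existing edge if the new sequence diverges mid-edge."""
    node, matched = self.find_shared_prefixes(token_ids)
    remaining = token_ids[matched:]

    if not remaining:                               # Case 1: exact match
        node.request_id = request_id
        return

    child = node.children.get(remaining[0])
    if child is None:                               # Case 2: brand-new branch
        new_node = RadixTreeNode(remaining)
        new_node.request_id = request_id
        node.children[remaining[0]] = new_node
        return

    # Case 3: split edge. Find how far the new tokens agree with the child's edge.
    common = 0
    while (common < len(child.prefix) and common < len(remaining)
           and child.prefix[common] == remaining[common]):
        common += 1

    parent = RadixTreeNode(child.prefix[:common])   # shared part becomes the new parent
    node.children[remaining[0]] = parent

    child.prefix = child.prefix[common:]            # old node keeps its unshared tail
    parent.children[child.prefix[0]] = child

    if common < len(remaining):                     # new request has its own tail
        leaf = RadixTreeNode(remaining[common:])
        leaf.request_id = request_id
        parent.children[leaf.prefix[0]] = leaf
    else:                                           # new request ends exactly at the split point
        parent.request_id = request_id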
Step 4: Verify
Create a test script tests/test_radix_manual.py:
tree = RadixTree()
tree.insert_request("req1", [1, 2, 3])
match, count = tree.find_shared_prefixes([1, 2, 3, 4])
print(f"Matched {count} tokens") # Should be 3!
KV Cache Manager (kernels/kv_cache.py)
1. Concept: Paged Attention
In a standard implementation, KV Cache is a huge contiguous tensor [MAX_SEQ_LEN, HEADS, DIM]. This wastes memory because most prompts are shorter than MAX_SEQ_LEN.
Paged Attention breaks this tensor into small fixed-size blocks (e.g., size 16).
- Physical Memory: A big pool of blocks [NUM_BLOCKS, 16, HEADS, DIM].
- Logical Memory: For each request, we just keep a list of block indices, e.g. [0, 5, 12].
Your job is to write the Allocator (like malloc in C).
2. Implementation Guide
Open src/kernels/kv_cache.py.
Step 1: Initialization
We need to track which blocks are free and which are used.
Task: In __init__:
- Create a list self.free_blocks. Initially, it should contain all integers from 0 to num_blocks - 1.
- Create a dictionary self.block_tables. This will map request_id -> List[int] (the list of blocks owned by that request).
# Hint
self.free_blocks = list(range(num_blocks))
Step 2: The allocate_blocks Method
When a request comes in (or generates new tokens), it needs memory.
Signature:
def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
Algorithm:
- Calculate how many blocks are needed: $N_{blocks} = \lceil \text{num\_tokens} / \text{block\_size} \rceil$
- Check that we have enough free_blocks. If len(free_blocks) < needed, raise an error (or handle OOM).
- Pop the blocks from free_blocks.
- Assign them to self.block_tables[request_id].
- Return the list of allocated block indices.
Your Turn: Implement this logic. Watch out for integer division!
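A minimal sketch that follows the algorithm above; the ceiling is done with integer arithmetic, and pop(0) keeps the allocation order consistent with the verify script below.
def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
    """Reserve enough fixed-size blocks for num_tokens and record them for this request."""
    # Ceiling division: e.g. 20 tokens with block_size 16 -> 2 blocks.
    num_needed = (num_tokens + self.block_size - 1) // self.block_size

    if len(self.free_blocks) < num_needed:
        raise RuntimeError(
            f"Out of KV-cache blocks: need {num_needed}, have {len(self.free_blocks)} free"
        )

    # Take blocks off the front of the free list and record ownership.
    allocated = [self.free_blocks.pop(0) for _ in range(num_needed)]
    self.block_tables[request_id] = allocated
    return allocated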
Step 3: The free_blocks Method
When a request finishes, we must reclaim memory.
Algorithm:
- Look up the blocks for request_id.
- Append them back to self.free_blocks.
- Delete the entry from self.block_tables.
Critical: Do not double-free! (A Python set would make duplicate checks easy, but a plain list works fine as a simple stack here.)
Step 4: Connecting to the Engine
The get_kv_tensors method is checking if you can translate the “Logical” view to the “Physical” view.
Task: Implement get_kv_tensors.
- It should return the GPU tensors backing the request's blocks.
- Note: In this Python simulation, returning the block indices is often enough for the Scheduler to know the mapping. The actual tensor access happens in the CUDA kernel.
Step 5: Verify
Create tests/test_kv_manual.py:
manager = KVCacheManager(num_blocks=10, block_size=16, ...)
# Alloc 20 tokens -> needs 2 blocks (indices 0, 1)
blocks = manager.allocate_blocks("req1", 20)
print(blocks)
# Free
manager.free_blocks("req1")
print(len(manager.free_blocks)) # Should be 10 again
Radix Attention Module (kernels/radix_attention.py)
1. Concept: Connecting the Dots
We have a Radix Tree (prefix matching) and a Paged KV Cache (memory management). The RadixAttentionWithPagedKVCache class is the glue that runs on the CPU (Python side) to manage these resources before we launch the GPU kernels.
It doesn’t run the attention math (that’s the CUDA kernel’s job). Instead, it manages the metadata:
- “Request A needs to append ‘cat’ to its sequence.”
- “Does ‘cat’ already exist in the Radix Tree?”
- “If yes, reuse the block.”
- “If no, allocate a new block.”
2. Implementation Guide
Open src/kernels/radix_attention.py.
Step 1: Initialization
You need to initialize the two sub-components we built earlier.
class RadixAttentionWithPagedKVCache:
def __init__(self, ...):
# ...
self.radix_tree = RadixTree()
self.kv_cache_manager = KVCacheManager(...)
Step 2: append_slot (The Critical Logic)
This method is called when we want to add a new token (or tokens) to a request.
Signature:
def append_slot(self, key: torch.Tensor, value: torch.Tensor, request_id: str):
key/value: The computed K/V tensors for the new token(s).
Algorithm:
- Check Tree: Use self.radix_tree to see if this (request_id + new_token) path already exists.
  - Note: In a real system, we check before computing K/V. Here, we might just be managing the cache storage.
- Allocate: If we need new space, call self.kv_cache_manager.allocate_blocks().
- Store: Perform the copy.
  - Ideally, we just return the indices of where to write, and the GPU kernel does the writing.
  - For this Python simulation, you might simulate the copy or just track the metadata.
Step 3: get_kv_cache
The scheduler asks: “I am about to run requests [R1, R2]. Where is their data?”
Algorithm:
- Loop through `request_ids`.
- For each, ask `self.kv_cache_manager` for its block table (a list of integers).
- Pack these lists into a single Tensor `block_tables` (see the sketch below).
- Return `block_tables` to the Engine.
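A minimal sketch of the packing step, assuming `block_tables` maps `request_id` to a plain Python list of block indices (the pad value of 0 is just an illustrative choice; real engines often use a sentinel such as -1):

import torch

def get_kv_cache(self, request_ids):
    tables = [self.kv_cache_manager.block_tables[rid] for rid in request_ids]
    # Right-pad each table so they stack into one [num_seqs, max_blocks] tensor.
    max_len = max(len(t) for t in tables)
    padded = [t + [0] * (max_len - len(t)) for t in tables]
    return torch.tensor(padded, dtype=torch.int32)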
Step 4: free_request
When a request is done:
- Call `self.radix_tree.remove_request(request_id)` (decrement ref counts).
- Call `self.kv_cache_manager.free_blocks(request_id)` (reclaim memory).
3. The RadixAttentionBlock (Model Layer)
The class RadixAttentionBlock is the PyTorch module that sits in the model.
Task:
In forward():
- Compute Q, K, V projections.
- Compute RoPE (Rotary Embeddings).
- If Prefill: Use Flash Attention (or a standard attention) on the new tokens.
- If Decode:
  - Call `append_slot` to save the new K/V.
  - Call `paged_attention_kernel` (the CUDA op) to attend to the entire history using the block tables.
Exercise: Since we don’t have the full model weight loading for this specific block, focus on the logic flow in the comments.
Sampling (kernels/sampling.py)
1. Concept: From Logits to Tokens
The model outputs logits: a vector of size [VOCAB_SIZE] (e.g., 50,000) containing raw scores for the next token.
We need to pick one token ID.
- Greedy: Just pick `argmax()`. Boring, repetitive.
- Sampling: Pick randomly based on probability. Creative!
We control the randomness with Temperature, Top-P (Nucleus), and Top-K.
2. Implementation Guide
Open src/kernels/sampling.py.
Step 1: Temperature Scaling
Temperature ($T$) controls confidence.
- $T < 1$: Makes peakier (more confident).
- $T > 1$: Makes flatter (more random).
Algorithm:
logits = logits / temperature
- Watch out: If $T$ is very close to 0, just do `argmax` to avoid division by zero!
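A minimal sketch of this guard (the multinomial draw from Step 5 is included only to make it self-contained):

import torch

def sample_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Temperature ~ 0 means "be deterministic": dividing by it would explode,
    # so fall back to greedy argmax instead of scaling.
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)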
Step 2: Softmax
Convert logits to probabilities (0.0 to 1.0).
probs = torch.softmax(logits, dim=-1)
Step 3: Top-K Filtering
Keep only the $K$ most likely tokens. Zero out the rest.
Algorithm:
- Find the value of the $K$-th highest score.
- Mask (set to $-\infty$) anything below that value in `logits` (or to 0 in `probs`).
- Re-normalize probabilities.
Step 4: Top-P (Nucleus) Filtering (The Tricky One)
Keep the smallest set of tokens whose cumulative probability adds up to $P$ (e.g., 0.9). This dynamically truncates the long tail of “nonsense” words.
Algorithm:
- Sort probabilities in descending order: `sorted_probs, sorted_indices = torch.sort(probs, descending=True)`.
- Compute the cumulative sum: `cumulative_probs = torch.cumsum(sorted_probs, dim=-1)`.
- Find the cut-off: mask where `cumulative_probs > top_p`.
  - Tip: You want to include the first token that crosses the threshold, so shift the mask right by one.
- Scatter the mask back to the original ordering.
- Re-normalize (see the sketch below).
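A minimal sketch of the shift-right trick, assuming `probs` is an already-normalized `[batch, vocab_size]` tensor:

import torch

def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop everything after the cumulative mass exceeds top_p...
    remove = cumulative > top_p
    # ...but shift right by one so the token that crosses the threshold is kept.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    # Scatter back to the original vocabulary order and re-normalize.
    filtered = torch.zeros_like(probs).scatter(-1, sorted_indices, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)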
Step 5: The Final Selection
Once you have your clean probability distribution:
next_token = torch.multinomial(probs, num_samples=1)
Your Turn: Implement sample in SamplingKernel. Start simple (just Temperature) and verify, then add Top-P.
Build System and CUDA Kernels
Overview
The Mini-YAIE project includes a comprehensive build system for compiling custom CUDA kernels that provide optimized performance for SGLang-style inference operations. The build system is designed to be both educational and production-ready, allowing users to learn CUDA kernel development while achieving high performance.
Build System Architecture
Build Scripts
The project provides multiple ways to build kernels:
Shell Script: build_kernels.sh
#!/bin/bash
# Comprehensive build script for CUDA kernels
./build_kernels.sh
Makefile Integration
# Build kernels using make
make build-kernels
Direct Python Build
# Direct build using Python setup
python setup_kernels.py build_ext --inplace
Build Dependencies
The build system requires:
- CUDA Toolkit: Version 11.8 or higher
- PyTorch with CUDA Support: For CUDA extensions
- NVIDIA GPU: With compute capability >= 7.0
- System Compiler: GCC/Clang with C++14 support
- Python Development Headers: For Python C API
CUDA Kernel Design
SGLang-Style Optimization Focus
The CUDA kernels are specifically designed to optimize SGLang-style inference operations:
- Radix Tree Operations: Efficient prefix matching on GPU
- Paged Attention: Optimized attention for paged KV-cache
- Memory Operations: High-performance memory management
- Radix Attention: GPU-optimized radix attention with prefix sharing
Performance Goals
The kernels target SGLang-specific optimizations:
- Memory Bandwidth Optimization: Minimize memory access overhead
- Computation Sharing: Implement efficient prefix sharing on GPU
- Batch Processing: Optimize for variable-length batch processing
- Cache Efficiency: Optimize for GPU cache hierarchies
CUDA Kernel Components
1. Memory Operations Kernels
Paged KV-Cache Operations
// Efficient operations on paged key-value cache
__global__ void copy_blocks_kernel(
float* dst_keys, float* dst_values,
float* src_keys, float* src_values,
int* block_mapping, int num_blocks
);
Block Management
- Block allocation and deallocation
- Memory copying between blocks
- Block state management
2. Flash Attention Kernels
Optimized Attention Computation
// Optimized attention for both prefill and decode phases
template<typename T>
__global__ void flash_attention_kernel(
const T* q, const T* k, const T* v,
T* output, float* lse, // logsumexp for numerical stability
int num_heads, int head_dim, int seq_len
);
Features
- Memory-efficient attention computation
- Numerical stability with logsumexp
- Support for variable sequence lengths
- Optimized memory access patterns
3. Paged Attention Kernels
Paged Memory Access
// Attention with paged key-value cache support
__global__ void paged_attention_kernel(
float* output,
const float* query,
const float* key_cache,
const float* value_cache,
const int* block_tables,
const int* context_lens,
int num_kv_heads, int head_dim, int block_size
);
Features
- Direct paged cache access patterns
- Efficient block index computation
- Memory coalescing optimization
- Support for shared prefix computation
4. Radix Operations Kernels
Prefix Matching Operations
// GPU-accelerated prefix matching for SGLang-style sharing
__global__ void radix_tree_lookup_kernel(
int* token_ids, int* request_ids,
int* prefix_matches, int batch_size
);
Features
- Parallel prefix matching
- Efficient tree traversal on GPU
- Shared computation identification
- Batch processing optimization
Build Process Details
Setup Configuration
The build system uses PyTorch’s setup.py for CUDA extension compilation:
# setup_kernels.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="yaie_kernels",
ext_modules=[
CUDAExtension(
name="yaie_kernels",
sources=[
"kernels/cuda/radix_attention.cu",
"kernels/cuda/paged_attention.cu",
"kernels/cuda/memory_ops.cu",
"kernels/cuda/radix_ops.cu"
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": ["-O3", "--use_fast_math"]
}
)
],
cmdclass={"build_ext": BuildExtension}
)
Compilation Flags
The build system uses optimization flags for performance:
- `-O3`: Maximum optimization
- `--use_fast_math`: Fast math operations
- `-arch=sm_60`: Target specific GPU architectures
- `-lineinfo`: Include debug line information
Architecture Targeting
The system supports multiple GPU architectures:
# Specify target architecture during build
TORCH_CUDA_ARCH_LIST="7.5" python setup_kernels.py build_ext --inplace  # Turing GPUs
TORCH_CUDA_ARCH_LIST="8.0" python setup_kernels.py build_ext --inplace  # Ampere GPUs
CUDA Kernel Implementation Guidelines
Memory Management
Unified Memory vs Regular Memory
// Use unified memory for easier management (if available)
cudaMallocManaged(&ptr, size);
// Or regular device memory for better performance
cudaMalloc(&ptr, size);
Memory Pooling
- Implement memory pooling for frequently allocated objects
- Reuse memory blocks across operations
- Batch memory operations when possible
Thread Organization
Block and Grid Sizing
// Optimize for your specific algorithm
dim3 blockSize(256);
dim3 gridSize((N + blockSize.x - 1) / blockSize.x);
Warp-Level Primitives
- Use warp-level operations for better efficiency
- Align memory accesses with warp boundaries
- Minimize warp divergence
Synchronization
Cooperative Groups
#include <cooperative_groups.h>
using namespace cooperative_groups;
// Use cooperative groups for complex synchronization
thread_block block = this_thread_block();
Memory Barriers
- Use appropriate memory barriers for consistency
- Minimize unnecessary synchronization overhead
Performance Optimization Strategies
Memory Bandwidth Optimization
Coalesced Access
// Ensure memory accesses are coalesced
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// Access data[tid] by threads in order for coalescing
Shared Memory Usage
- Use shared memory for frequently accessed data
- Implement tiling strategies for large operations
- Minimize global memory access
Computation Optimization
Warp-Level Operations
- Leverage warp-level primitives when possible
- Use vectorized operations (float4, int4)
- Minimize thread divergence within warps
Kernel Fusion
Combined Operations
- Fuse multiple operations into single kernels
- Reduce kernel launch overhead
- Improve memory locality
Integration with Python
PyTorch Extensions
The CUDA kernels integrate with PyTorch using extensions:
import torch
import yaie_kernels # Compiled extension
# Use kernel from Python
result = yaie_kernels.radix_attention_forward(
query, key, value,
radix_tree_info,
attention_mask
)
Automatic GPU Management
The integration handles:
- GPU memory allocation
- Device synchronization
- Error propagation
- Backpropagation support
Error Handling and Debugging
Build-Time Errors
Common build issues and solutions:
# CUDA toolkit not found
export CUDA_HOME=/usr/local/cuda
# Architecture mismatch
# Check GPU compute capability and adjust flag
# Missing PyTorch
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
Runtime Error Checking
Kernels should include error checking:
#define CUDA_CHECK(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s:%d - %s\n", __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(1); \
} \
} while(0)
Testing and Validation
Kernel Testing
Test kernels with:
- Unit tests for individual functions
- Integration tests with Python interface
- Performance benchmarks
- Memory correctness validation
SGLang-Specific Tests
Test SGLang optimization features:
- Prefix sharing correctness
- Memory management validation
- Performance gain verification
- Edge case handling
Development Workflow
Iterative Development
The development process includes:
- Kernel Design: Design algorithm for GPU execution
- Implementation: Write CUDA kernel code
- Building: Compile with build system
- Testing: Validate correctness and performance
- Optimization: Profile and optimize based on results
Profiling and Optimization
Use NVIDIA tools for optimization:
- Nsight Systems: Overall system profiling
- Nsight Compute: Detailed kernel analysis
- nvprof: Legacy profiling tool
Future Extensions
Advanced Features
Potential kernel enhancements:
- Quantized attention operations
- Sparse attention kernels
- Custom activation functions
- Advanced memory management
Hardware Support
Expand support for:
- Different GPU architectures
- Multi-GPU operations
- Heterogeneous computing
- Tensor Core optimizations
This build system and CUDA kernel architecture enables Mini-YAIE to achieve SGLang-style performance optimizations while maintaining educational value and extensibility.
Memory Operations (kernels/cuda/memory_ops.cu)
Concept
Moving data between different GPU memory locations is a frequent operation in Paged Attention.
Implementation Goal
Implement copy_blocks_kernel:
Signature
void copy_blocks_kernel(
torch::Tensor key_cache, // [num_blocks, block_size, head_dim]
torch::Tensor value_cache, // [num_blocks, block_size, head_dim]
torch::Tensor block_mapping, // [num_mappings, 2] (src, dst)
int num_mappings
);
Logic
- Parallelism: Launch one thread per token to copy.
- Indexing:
  - `mapping_idx = blockIdx.x`
  - `src_block = block_mapping[mapping_idx][0]`
  - `dst_block = block_mapping[mapping_idx][1]`
- Copy:
  - Read `key`/`value` from `src_block` at the `threadIdx` offset.
  - Write to `dst_block`.
Flash Attention (kernels/cuda/flash_attention.cu)
1. Concept: Memory Bandwidth
The main bottleneck in Attention is reading the huge $N \times N$ matrix from memory. Flash Attention breaks the problem into small “tiles” that fit into the GPU’s fast SRAM (Shared Memory). We compute everything for that tile without going back to slow Global Memory.
graph TB
subgraph GlobalMemory[Global Memory HBM]
Q[Matrix Q]
K[Matrix K]
V[Matrix V]
end
subgraph SRAM[Shared Memory SRAM]
TileQ[Tile Q]
TileK[Tile K]
TileV[Tile V]
Comp(("Compute QK^T * V"))
end
Q --> TileQ
K --> TileK
V --> TileV
TileQ --> Comp
TileK --> Comp
TileV --> Comp
2. Implementation Guide
We will implement a simplified version. Doing full FlashAttention v2 is extremely complex. We aim for “Tiled Attention”.
Step 0: The Setup
Open src/kernels/cuda/flash_attention.cu.
Identify the flash_attention_forward function.
You have pointers to:
- `query` (Q), `key` (K), and `value` (V), residing in Global Memory.
Step 1: Define Thread Layout
We want to process tiles.
- Grid: One block per query chunk.
- Block: Threads within the block handle individual heads or elements.
// Example
dim3 grid(num_batches, num_heads);
dim3 block(128); // 128 threads work together on one head
Step 2: Load Tiles to Shared Memory
You need __shared__ memory arrays.
__shared__ float s_Q[TILE_SIZE][HEAD_DIM];
__shared__ float s_K[TILE_SIZE][HEAD_DIM];
Use threadIdx.x to cooperatively load data from Global Q to Shared s_Q.
Remember: call __syncthreads() after loading!
Step 3: Compute $QK^T$ (Scores)
Iterate over your shared Q and K. Calculate the dot product. Store in a register (local variable).
Step 4: Softmax (The “Online” Trick)
In standard softmax, you need the max of the entire row. Here we only see a tile! Trick: Keep a running max ($m$) and running sum ($l$). Update them as you see new tiles.
- $m_{new} = \max(m_{old}, \max(\text{current tile}))$
- Adjust previous sums by multiplying by $e^{m_{old} - m_{new}}$.
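A tiny Python sketch of the running update (in the CUDA kernel, $m$ and $l$ live in registers per query row, and the partially accumulated output must be rescaled by the same $e^{m_{old} - m_{new}}$ factor before adding each new tile's contribution):

import math

def online_softmax_update(m_old, l_old, tile_scores):
    # Merge one tile of raw attention scores into the running (max, sum) state.
    m_new = max(m_old, max(tile_scores))
    l_new = l_old * math.exp(m_old - m_new) + sum(math.exp(s - m_new) for s in tile_scores)
    return m_new, l_new

# Start from m = -inf, l = 0 and feed tiles one at a time:
m, l = float("-inf"), 0.0
for tile in ([1.0, 2.0], [3.0, 0.5]):
    m, l = online_softmax_update(m, l, tile)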
Step 5: Compute Score $\times$ V
Once you have the probabilities for the tile, multiply by s_V (which you also loaded).
Accumulate into output.
3. Hints
- Start with a Naive kernel first! Forget shared memory. Just loops.
- Thread per query token.
- Loop over all key tokens.
- Compute.
- This is $O(N^2)$ memory reads but verifies your logic is correct.
- Only optimize to Shared Memory once logic works.
Paged Attention (kernels/cuda/paged_attention.cu)
1. Concept: Indirection
Paged Attention is just standard attention, but K and V are not contiguous.
We have to “gather” them using a Page Table.
graph LR
Thread -->|1. Get Logical idx| Logic[Token #42]
Logic -->|2. Lookup Table| Table[Block 2, Offset 10]
Table -->|3. Get Physical Addr| Phys[0xA000...]
Phys -->|4. Read| Data[Value]
2. Implementation Guide
Step 1: Understand the Block Table
You are passed block_tables tensor of shape [num_seqs, max_blocks].
- It holds integer indices of physical blocks.
- `block_tables[req_id][0]` is the first block of that request.
Step 2: Calculate Physical Address
Inside your kernel, you want the Key vector for token t of request r.
int block_idx = t / BLOCK_SIZE;
int block_offset = t % BLOCK_SIZE;
int physical_block_number = block_tables[r][block_idx];
// Pointer arithmetic
float* k_ptr = key_cache_base
+ physical_block_number * (BLOCK_SIZE * HEAD_DIM * NUM_HEADS)
+ ... // navigate to specific head and offset
Step 3: Load Data
Using the pointer k_ptr, load the vector into registers or shared memory.
Step 4: Compute Attention
Once loaded, the math is identical to standard Attention or Flash Attention. $Q \cdot K^T$, Softmax, $\cdot V$.
3. Your Task
Implement paged_attention_kernel in src/kernels/cuda/paged_attention.cu.
- Focus on the address calculation logic. That is the only difference!
- Use the `copy_blocks` kernel (Memory Ops) to help set up test data if needed.
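For a CPU reference to test your indexing against, the same gather can be written in a few lines of Python, assuming the `[num_blocks, block_size, num_heads, head_dim]` cache layout used in this guide:

def gather_key(key_cache, block_tables, req_idx, token_idx, block_size):
    # Logical token position -> (entry in the request's block table, offset inside it).
    logical_block = token_idx // block_size
    block_offset = token_idx % block_size
    physical_block = block_tables[req_idx][logical_block]
    # key_cache: [num_blocks, block_size, num_heads, head_dim]
    return key_cache[physical_block, block_offset]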
Radix Operations (kernels/cuda/radix_ops.cu)
Concept
If we have a Radix Tree, we can optimize attention even further by knowing exactly which tokens are shared.
Implementation Goal
This is an advanced extension.
Logic
- Tree Traversal on GPU: Mapping the Radix Tree structure to a GPU-friendly format (e.g., flattened arrays).
- Prefix Matching: A kernel that takes a batch of prompts and quickly identifies the longest common prefix node ID for each.
Note: In the simplified version, this logic is often kept in CPU (Python) and only the KV indices are passed to the GPU.
Kernels Implementation Guide
Overview
This guide provides a comprehensive walkthrough for implementing the core kernels in Mini-YAIE that enable SGLang-style inference optimization. The implementation focuses on three key areas:
- Python Implementations: Educational implementations of core algorithms
- CUDA Kernels: Performance-optimized GPU implementations
- Integration: Connecting kernels with the main inference engine
Implementation Roadmap
Phase 1: Core Python Kernels
Implement the educational Python versions first:
- Radix tree for prefix matching
- Basic attention mechanisms
- KV-cache management
- Sampling algorithms
Phase 2: CUDA Kernel Development
Develop optimized GPU versions:
- Memory operations kernels
- Paged attention implementation
- Flash attention optimization
- Radix operations acceleration
Phase 3: Integration and Optimization
Connect kernels to the main system:
- Engine integration
- Performance validation
- Correctness verification
Python Kernel Implementation
1. Radix Tree Implementation
Start with the radix tree that enables prefix sharing:
File: src/kernels/radix_tree.py
from typing import Dict, List, Optional, Tuple

class RadixTreeNode:
def __init__(self, token_id: Optional[int] = None):
self.token_id = token_id
self.children: Dict[int, "RadixTreeNode"] = {}
self.request_ids: List[str] = []
self.kv_cache_refs: List[str] = []
self.is_terminal = False
class RadixTree:
def __init__(self):
self.root = RadixTreeNode()
self.request_to_path: Dict[str, List[int]] = {}
self.path_to_node: Dict[str, RadixTreeNode] = {}
def insert_request(self, request_id: str, token_ids: List[int]):
"""Insert a request into the radix tree based on its token sequence"""
current = self.root
for token_id in token_ids:
if token_id not in current.children:
current.children[token_id] = RadixTreeNode(token_id)
current = current.children[token_id]
if request_id not in current.request_ids:
current.request_ids.append(request_id)
current.is_terminal = True
self.request_to_path[request_id] = token_ids
path_str = self._path_to_string(token_ids)
self.path_to_node[path_str] = current
def find_shared_prefixes(self, token_ids: List[int]) -> Tuple[List[str], int]:
"""Find requests that share prefixes with the given token sequence"""
current = self.root
matched_requests = []
prefix_length = 0
for i, token_id in enumerate(token_ids):
if token_id in current.children:
current = current.children[token_id]
matched_requests.extend(current.request_ids)
prefix_length = i + 1
else:
break
return list(set(matched_requests)), prefix_length
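Note that insert_request calls a _path_to_string helper that the listing above omits; any stable serialization of the token list works, for example:

    def _path_to_string(self, token_ids: List[int]) -> str:
        # Stable dictionary key for a token path, e.g. [1, 2, 3] -> "1,2,3".
        return ",".join(str(t) for t in token_ids)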
2. KV-Cache Management
Implement the paged KV-cache system:
File: src/kernels/kv_cache.py
from typing import List

import torch

class KVCacheBlock:
def __init__(self, block_id: int, size: int, num_heads: int, head_dim: int, dtype=torch.float16):
self.block_id = block_id
self.size = size
self.num_heads = num_heads
self.head_dim = head_dim
self.dtype = dtype
self.keys = None
self.values = None
def allocate(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.keys = torch.zeros(self.size, self.num_heads, self.head_dim, dtype=self.dtype, device=device)
self.values = torch.zeros(self.size, self.num_heads, self.head_dim, dtype=self.dtype, device=device)
class KVCacheManager:
def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int, dtype=torch.float16):
self.num_blocks = num_blocks
self.block_size = block_size
self.num_heads = num_heads
self.head_dim = head_dim
self.dtype = dtype
self.blocks: List[KVCacheBlock] = []
for i in range(num_blocks):
block = KVCacheBlock(i, block_size, num_heads, head_dim, dtype)
self.blocks.append(block)
self.free_block_list: List[int] = list(range(num_blocks))
self.block_tables: dict = {}
def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
if len(self.free_block_list) < num_blocks_needed:
raise RuntimeError(f"Not enough free blocks. Need {num_blocks_needed}, have {len(self.free_block_list)}")
allocated_block_ids = []
for _ in range(num_blocks_needed):
block_id = self.free_block_list.pop(0)
allocated_block_ids.append(block_id)
self.blocks[block_id].allocate()
self.block_tables[request_id] = allocated_block_ids
return allocated_block_ids
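The tests later in this guide also call free_blocks, which the listing above stops short of; a minimal version consistent with allocate_blocks might look like this:

    def free_blocks(self, request_id: str) -> None:
        # Return a finished request's blocks to the free list.
        block_ids = self.block_tables.pop(request_id, [])
        self.free_block_list.extend(block_ids)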
3. Radix Attention Implementation
Implement the radix attention mechanism:
File: src/kernels/radix_attention.py
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class RadixAttentionBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, head_dim: int, max_position_embeddings: int = 2048):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.max_position_embeddings = max_position_embeddings
total_hidden_dim = num_heads * head_dim
self.q_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
self.o_proj = nn.Linear(total_hidden_dim, hidden_size, bias=False)
self.register_buffer(
"cos_cached",
torch.ones((max_position_embeddings, head_dim), dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"sin_cached",
torch.ones((max_position_embeddings, head_dim), dtype=torch.float32),
persistent=False,
)
self._setup_rope_embeddings()
def _setup_rope_embeddings(self):
position_ids = torch.arange(self.max_position_embeddings, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim))
freqs = torch.einsum("i,j->ij", position_ids, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos().to(dtype=torch.float16)
self.sin_cached = emb.sin().to(dtype=torch.float16)
def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
batch_size, seq_len, _ = hidden_states.shape
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
cos_to_use = self.cos_cached[:seq_len].to(query.dtype)
sin_to_use = self.sin_cached[:seq_len].to(query.dtype)
query, key = apply_rotary_pos_emb(query, key, cos_to_use, sin_to_use, position_ids)
if past_key_value is not None:
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.head_dim)
output = self.o_proj(attn_output)
return output
4. Sampling Kernel Implementation
Implement the token sampling system:
File: src/kernels/sampling.py
import torch
class SamplingKernel:
def sample(self, logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1) -> torch.Tensor:
        # Near-zero temperature degenerates to greedy decoding; guard the division.
        if temperature < 1e-5:
            return torch.argmax(logits, dim=-1)
        if temperature != 1.0:
            logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
batch_size, vocab_size = probs.shape
if top_k > 0:
top_k = min(top_k, vocab_size)
top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
new_probs = torch.zeros_like(probs)
new_probs.scatter_(1, top_k_indices, top_k_probs)
new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True)
probs = new_probs
if 0 < top_p < 1.0:
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative_probs <= top_p
if mask.shape[-1] > 0:
mask[..., 0] = True
filtered_probs = torch.zeros_like(probs)
filtered_probs.scatter_(1, sorted_indices, mask.float() * sorted_probs)
filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
probs = filtered_probs
sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
return sampled_ids
CUDA Kernel Implementation
1. Memory Operations Kernels
File: src/kernels/cuda/memory_ops.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void copy_blocks_kernel(
float* key_cache, float* value_cache,
float* new_key_cache, float* new_value_cache,
int* block_mapping, // [src_block_id, dst_block_id] pairs
int num_heads, int head_dim, int block_size,
int num_mappings
) {
int mapping_idx = blockIdx.x;
if (mapping_idx >= num_mappings) return;
int src_block_id = block_mapping[mapping_idx * 2];
int dst_block_id = block_mapping[mapping_idx * 2 + 1];
int total_elements_per_block = block_size * num_heads * head_dim;
    // Grid-stride over the block's elements so blocks larger than blockDim.x
    // (256 threads here) are still copied completely.
    for (int i = threadIdx.x; i < total_elements_per_block; i += blockDim.x) {
        int src_idx = src_block_id * total_elements_per_block + i;
        int dst_idx = dst_block_id * total_elements_per_block + i;
        new_key_cache[dst_idx] = key_cache[src_idx];
        new_value_cache[dst_idx] = value_cache[src_idx];
    }
}
std::tuple<torch::Tensor, torch::Tensor> copy_blocks_cuda(
torch::Tensor key_cache, torch::Tensor value_cache,
torch::Tensor block_mapping,
int num_heads, int head_dim, int block_size
) {
int num_mappings = block_mapping.size(0);
auto options = key_cache.options();
auto new_key_cache = torch::zeros_like(key_cache);
auto new_value_cache = torch::zeros_like(value_cache);
dim3 grid(num_mappings);
dim3 block(256); // Use 256 threads per block
copy_blocks_kernel<<<grid, block>>>(
key_cache.data_ptr<float>(),
value_cache.data_ptr<float>(),
new_key_cache.data_ptr<float>(),
new_value_cache.data_ptr<float>(),
block_mapping.data_ptr<int>(),
num_heads, head_dim, block_size,
num_mappings
);
cudaDeviceSynchronize();
return std::make_tuple(new_key_cache, new_value_cache);
}
2. Paged Attention Kernels
File: src/kernels/cuda/paged_attention.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void paged_attention_kernel(
float* output, // [num_seqs, seq_len, num_heads, head_dim]
const float* query, // [num_seqs, seq_len, num_heads, head_dim]
const float* key_cache, // [num_blocks, block_size, num_kv_heads, head_dim]
const float* value_cache,// [num_blocks, block_size, num_kv_heads, head_dim]
const int* block_tables, // [num_seqs, max_blocks_per_seq]
const int* context_lens, // [num_seqs]
const int num_kv_heads,
const int num_queries_per_kv,
const int head_dim,
const int block_size,
const int max_num_blocks_per_seq
) {
int seq_idx = blockIdx.x;
int q_head_idx = blockIdx.y;
int token_idx = blockIdx.z * blockDim.x + threadIdx.x;
if (seq_idx >= gridDim.x || q_head_idx >= gridDim.y || token_idx >= context_lens[seq_idx]) {
return;
}
// Get corresponding KV head index
int kv_head_idx = q_head_idx / num_queries_per_kv;
// Get query vector
int query_idx = seq_idx * context_lens[seq_idx] * gridDim.y * head_dim +
token_idx * gridDim.y * head_dim +
q_head_idx * head_dim;
    // NOTE: every thread in this block handles a different token, so the query
    // vector cannot be staged in shared memory without a race; it is read
    // directly from global memory in the dot product below.
// Calculate which block and offset for this token
int block_idx = token_idx / block_size;
int block_offset = token_idx % block_size;
// Get physical block number from block table
int physical_block = block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
// Calculate the actual index in the cache
int cache_idx = physical_block * block_size * num_kv_heads * head_dim +
block_offset * num_kv_heads * head_dim +
kv_head_idx * head_dim;
// Perform attention computation
float sum = 0.0f;
for (int d = 0; d < head_dim; d++) {
        sum += query[query_idx + d] * key_cache[cache_idx + d];
}
// Apply softmax and multiply with value
float attention_weight = __expf(sum); // Simplified (real softmax needs normalization)
for (int d = 0; d < head_dim; d++) {
int output_idx = seq_idx * context_lens[seq_idx] * gridDim.y * head_dim +
token_idx * gridDim.y * head_dim +
q_head_idx * head_dim + d;
output[output_idx] += attention_weight * value_cache[cache_idx + d];
}
}
torch::Tensor paged_attention_cuda(
torch::Tensor query, torch::Tensor key_cache, torch::Tensor value_cache,
torch::Tensor block_tables, torch::Tensor context_lens,
int num_kv_heads, int num_queries_per_kv
) {
int num_seqs = query.size(0);
int seq_len = query.size(1);
int num_heads = query.size(2);
int head_dim = query.size(3);
int block_size = key_cache.size(1);
int max_blocks_per_seq = block_tables.size(1);
auto output = torch::zeros_like(query);
dim3 grid(num_seqs, num_heads, (seq_len + 255) / 256); // 256 threads per block
dim3 block(256);
    paged_attention_kernel<<<grid, block>>>(
output.data_ptr<float>(),
query.data_ptr<float>(),
key_cache.data_ptr<float>(),
value_cache.data_ptr<float>(),
block_tables.data_ptr<int>(),
context_lens.data_ptr<int>(),
num_kv_heads,
num_queries_per_kv,
head_dim,
block_size,
max_blocks_per_seq
);
cudaDeviceSynchronize();
return output;
}
3. Flash Attention Kernels (Simplified)
File: src/kernels/cuda/flash_attention.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#define TILE_SIZE 32 // Small tile for educational purposes
__global__ void flash_attention_kernel(
float* output,
const float* query,
const float* key,
const float* value,
const int* seq_lens,
const int num_seqs,
const int num_heads,
const int head_dim,
const int max_seq_len
) {
int seq_idx = blockIdx.x;
int head_idx = blockIdx.y;
if (seq_idx >= num_seqs || head_idx >= num_heads) return;
int current_seq_len = seq_lens[seq_idx];
// Shared memory for tiles
extern __shared__ float shared_mem[];
float* s_Q = shared_mem;
float* s_K = s_Q + TILE_SIZE * head_dim;
float* s_V = s_K + TILE_SIZE * head_dim;
float* s_scores = s_V + TILE_SIZE * head_dim;
// Process the sequence in tiles
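    // NOTE (simplification): this version re-runs softmax independently for each
    // K/V tile and overwrites `output`, so it is only exact when the whole
    // sequence fits in a single TILE_SIZE tile. A faithful version keeps the
    // running max/sum ("online softmax") described in the Flash Attention guide
    // and rescales the accumulated output between tiles.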
for (int q_tile_start = 0; q_tile_start < current_seq_len; q_tile_start += TILE_SIZE) {
// Load Q tile to shared memory
for (int i = threadIdx.x; i < TILE_SIZE * head_dim; i += blockDim.x) {
int q_row = q_tile_start + i / head_dim;
int q_col = i % head_dim;
if (q_row < current_seq_len) {
int q_idx = seq_idx * max_seq_len * num_heads * head_dim +
q_row * num_heads * head_dim +
head_idx * head_dim + q_col;
s_Q[i] = query[q_idx];
} else {
s_Q[i] = 0.0f;
}
}
__syncthreads();
// For each K/V tile
for (int k_tile_start = 0; k_tile_start < current_seq_len; k_tile_start += TILE_SIZE) {
// Load K and V tiles
for (int i = threadIdx.x; i < TILE_SIZE * head_dim; i += blockDim.x) {
int k_row = k_tile_start + i / head_dim;
int k_col = i % head_dim;
if (k_row < current_seq_len) {
int k_idx = seq_idx * max_seq_len * num_heads * head_dim +
k_row * num_heads * head_dim +
head_idx * head_dim + k_col;
s_K[i] = key[k_idx];
s_V[i] = value[k_idx];
} else {
s_K[i] = 0.0f;
s_V[i] = 0.0f;
}
}
__syncthreads();
// Compute attention scores for this tile
for (int q_local = threadIdx.x; q_local < TILE_SIZE; q_local += blockDim.x) {
int q_global = q_tile_start + q_local;
if (q_global >= current_seq_len) continue;
float score_sum = 0.0f;
float max_score = -INFINITY;
// Compute scores against K tile
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
int k_global = k_tile_start + k_local;
if (k_global >= current_seq_len) continue;
float score = 0.0f;
for (int d = 0; d < head_dim; d++) {
int q_offset = q_local * head_dim + d;
int k_offset = k_local * head_dim + d;
score += s_Q[q_offset] * s_K[k_offset];
}
// Apply causal mask
if (k_global > q_global) score = -INFINITY;
// Update max for numerical stability
if (score > max_score) max_score = score;
s_scores[q_local * TILE_SIZE + k_local] = score;
}
// Apply softmax with numerical stability
float exp_sum = 0.0f;
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
int k_global = k_tile_start + k_local;
if (k_global >= current_seq_len || k_global > q_global) {
s_scores[q_local * TILE_SIZE + k_local] = 0.0f;
} else {
float score = s_scores[q_local * TILE_SIZE + k_local];
float exp_score = __expf(score - max_score);
s_scores[q_local * TILE_SIZE + k_local] = exp_score;
exp_sum += exp_score;
}
}
// Normalize scores
if (exp_sum > 0.0f) {
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
s_scores[q_local * TILE_SIZE + k_local] /= exp_sum;
}
}
// Compute output = scores @ V
for (int d = 0; d < head_dim; d++) {
float output_val = 0.0f;
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
int v_offset = k_local * head_dim + d;
output_val += s_scores[q_local * TILE_SIZE + k_local] * s_V[v_offset];
}
int out_idx = seq_idx * max_seq_len * num_heads * head_dim +
q_global * num_heads * head_dim +
head_idx * head_dim + d;
output[out_idx] = output_val;
}
}
__syncthreads();
}
}
}
torch::Tensor flash_attention_cuda(
torch::Tensor query, torch::Tensor key, torch::Tensor value,
torch::Tensor seq_lens
) {
int num_seqs = query.size(0);
int num_heads = query.size(2);
int head_dim = query.size(3);
int max_seq_len = query.size(1);
auto output = torch::zeros_like(query);
dim3 grid(num_seqs, num_heads);
dim3 block(256);
// Shared memory for 3 tiles + scores matrix
int shared_mem_size = 3 * TILE_SIZE * head_dim * sizeof(float) +
TILE_SIZE * TILE_SIZE * sizeof(float);
flash_attention_kernel<<<grid, block, shared_mem_size>>>(
output.data_ptr<float>(),
query.data_ptr<float>(),
key.data_ptr<float>(),
value.data_ptr<float>(),
seq_lens.data_ptr<int>(),
num_seqs,
num_heads,
head_dim,
max_seq_len
);
cudaDeviceSynchronize();
return output;
}
Integration and Testing
1. Python-CUDA Interface
Create the Python interface for CUDA kernels:
File: src/kernels/api.py
"""
API for accessing both Python and CUDA implementations of kernels
"""
import torch
from typing import Optional
# Try to import CUDA extensions
try:
from . import yaie_kernels # This will be built from CUDA sources
CUDA_AVAILABLE = True
except ImportError:
print("CUDA extensions not available. Using Python implementations.")
CUDA_AVAILABLE = False
def attention_forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
use_cuda: bool = True, **kwargs):
"""Unified attention interface that can use CUDA or Python implementation"""
if CUDA_AVAILABLE and use_cuda and query.is_cuda:
return yaie_kernels.flash_attention_cuda(query, key, value,
kwargs.get('seq_lens', None))
else:
# Fallback to Python implementation
from .radix_attention import RadixAttentionBlock
attention_block = RadixAttentionBlock(
hidden_size=query.shape[-1],
num_heads=query.shape[-2],
head_dim=query.shape[-1] // query.shape[-2]
)
return attention_block(query)
def paged_attention_forward(query: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, block_tables: torch.Tensor,
context_lens: torch.Tensor, use_cuda: bool = True, **kwargs):
"""Paged attention interface"""
if CUDA_AVAILABLE and use_cuda and query.is_cuda:
return yaie_kernels.paged_attention_cuda(
query, key_cache, value_cache, block_tables, context_lens,
kwargs.get('num_kv_heads', 1),
kwargs.get('num_queries_per_kv', 1)
)
else:
# Python fallback would go here
raise NotImplementedError("Paged attention Python fallback not implemented")
def copy_blocks(key_cache: torch.Tensor, value_cache: torch.Tensor,
block_mapping: torch.Tensor, use_cuda: bool = True, **kwargs):
"""Memory copy interface"""
if CUDA_AVAILABLE and use_cuda and key_cache.is_cuda:
return yaie_kernels.copy_blocks_cuda(
key_cache, value_cache, block_mapping,
kwargs.get('num_heads', 1),
kwargs.get('head_dim', 64),
kwargs.get('block_size', 16)
)
else:
# Python fallback would go here
raise NotImplementedError("Copy blocks Python fallback not implemented")
2. Testing Framework
Create comprehensive tests:
File: tests/test_kernels.py
import pytest
import torch
import numpy as np
from src.kernels.radix_tree import RadixTree
from src.kernels.kv_cache import KVCacheManager
from src.kernels.radix_attention import RadixAttentionBlock
from src.kernels.sampling import SamplingKernel
from src.kernels.api import attention_forward, paged_attention_forward, copy_blocks
class TestRadixTree:
def test_basic_insertion_and_search(self):
tree = RadixTree()
# Insert requests
tree.insert_request("req1", [1, 2, 3])
tree.insert_request("req2", [1, 2, 4]) # Shares prefix [1, 2]
tree.insert_request("req3", [5, 6, 7]) # No shared prefix
# Test shared prefixes
shared, length = tree.find_shared_prefixes([1, 2, 5])
assert "req1" in shared
assert "req2" in shared
assert length == 2 # Common prefix [1, 2]
def test_prefix_sharing_graph(self):
tree = RadixTree()
tree.insert_request("req1", [1, 2, 3])
tree.insert_request("req2", [1, 2, 4])
tree.insert_request("req3", [1, 5, 6])
graph = tree.get_shared_computation_graph()
# Should show shared computation at token [1]
assert graph["request_count"] == 3 # All requests start with root
class TestKVCacheManager:
def test_basic_allocation(self):
cache_manager = KVCacheManager(
num_blocks=100,
block_size=16,
num_heads=8,
head_dim=64,
dtype=torch.float16
)
# Allocate blocks for a request
        blocks = cache_manager.allocate_blocks("req1", 20)  # Needs 2 blocks: ceil(20 / 16) = 2
assert len(blocks) == 2
# Verify the blocks exist and have proper tensors
for block_id in blocks:
block = cache_manager.blocks[block_id]
assert block.keys is not None
assert block.values is not None
assert block.keys.shape == (16, 8, 64) # block_size, num_heads, head_dim
def test_block_reuse(self):
cache_manager = KVCacheManager(
num_blocks=10,
block_size=16,
num_heads=8,
head_dim=64
)
# Allocate all blocks
req_ids = [f"req{i}" for i in range(10)]
for req_id in req_ids:
cache_manager.allocate_blocks(req_id, 10)
assert len(cache_manager.free_block_list) == 0
# Free some blocks
cache_manager.free_blocks("req0")
cache_manager.free_blocks("req1")
assert len(cache_manager.free_block_list) == 2
assert 0 in cache_manager.free_block_list
assert 1 in cache_manager.free_block_list
class TestRadixAttention:
def test_basic_attention_forward(self):
hidden_size = 512
num_heads = 8
head_dim = hidden_size // num_heads
attention = RadixAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
head_dim=head_dim,
max_position_embeddings=256
)
batch_size = 2
seq_len = 10
        # Keep the input in float32: the nn.Linear projections are created in float32.
        x = torch.randn(batch_size, seq_len, hidden_size)
output = attention(x)
assert output.shape == (batch_size, seq_len, hidden_size)
assert not torch.isnan(output).any()
def test_attention_with_past_key_value(self):
hidden_size = 256
num_heads = 4
head_dim = hidden_size // num_heads
attention = RadixAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
head_dim=head_dim
)
batch_size = 1
seq_len = 5
x = torch.randn(batch_size, seq_len, hidden_size)
        # First forward pass
        output1 = attention(x)
        assert output1.shape == (batch_size, seq_len, hidden_size)
        # Second forward pass with a manually constructed past key-value
        # (shape [batch, num_heads, past_len, head_dim], matching forward()).
        past_kv = (
            torch.randn(batch_size, num_heads, seq_len, head_dim),
            torch.randn(batch_size, num_heads, seq_len, head_dim),
        )
        next_token = torch.randn(batch_size, 1, hidden_size)
        output2 = attention(next_token, past_key_value=past_kv)
        assert output2.shape == (batch_size, 1, hidden_size)
class TestSamplingKernel:
def test_temperature_sampling(self):
sampling = SamplingKernel()
# Create logits with one clear winner
logits = torch.tensor([[10.0, 1.0, 1.0, 1.0]]) # First token is dominant
# High temperature should allow other tokens
sampled_high_temp = sampling.sample(logits, temperature=2.0)
assert sampled_high_temp.shape == (1,)
# Low temperature should favor dominant token
sampled_low_temp = sampling.sample(logits, temperature=0.1)
assert sampled_low_temp[0] == 0 # Should pick the dominant token
def test_top_p_nucleus_sampling(self):
sampling = SamplingKernel()
# Create logits where first 3 tokens account for ~90% of probability
logits = torch.tensor([[2.0, 1.5, 1.0, -10.0, -10.0]])
# Top-p = 0.8 should exclude the last two tokens
sampled = sampling.sample(logits, top_p=0.8)
# Should be one of the first 3 tokens
assert sampled[0] in [0, 1, 2]
def test_top_k_sampling(self):
sampling = SamplingKernel()
# Create logits with clear ordering
logits = torch.tensor([[5.0, 4.0, 3.0, 2.0, 1.0]])
# Top-k = 2 should only consider first 2 tokens
sampled = sampling.sample(logits, top_k=2)
assert sampled[0] in [0, 1] # Should be one of top 2 tokens
class TestIntegration:
def test_full_inference_pipeline(self):
"""Test integration of all kernels in a simple pipeline"""
# This test would simulate a full inference step
batch_size = 2
seq_len = 10
hidden_size = 256
num_heads = 4
head_dim = hidden_size // num_heads
# Create input
x = torch.randn(batch_size, seq_len, hidden_size)
# Apply attention
attention = RadixAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
head_dim=head_dim
)
attn_output = attention(x)
assert attn_output.shape == (batch_size, seq_len, hidden_size)
# Apply sampling (on logits that would come from LM head)
logits = torch.randn(batch_size, 1000) # vocab_size = 1000
sampling = SamplingKernel()
sampled_tokens = sampling.sample(logits, temperature=0.7)
assert sampled_tokens.shape == (batch_size,)
if __name__ == "__main__":
pytest.main([__file__])
Building and Running
Setup Configuration
File: setup_kernels.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
# Check if CUDA is available
def get_extensions():
extensions = []
# Check if CUDA is available
try:
import torch
if torch.cuda.is_available():
extensions.append(
CUDAExtension(
name='yaie_kernels',
sources=[
'src/kernels/cuda/memory_ops.cu',
'src/kernels/cuda/paged_attention.cu',
'src/kernels/cuda/flash_attention.cu',
'src/kernels/cuda/radix_ops.cu',
'src/kernels/cuda/pybind.cpp', # Python bindings
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': ['-O3', '--use_fast_math', '-arch=sm_70']
}
)
)
except:
print("CUDA not available, building without CUDA extensions")
return extensions
setup(
name='yaie_kernels',
ext_modules=get_extensions(),
cmdclass={'build_ext': BuildExtension},
zip_safe=False,
)
Performance Optimization Guidelines
CUDA Optimization Tips
- Memory Coalescing: Ensure threads in a warp access consecutive memory
- Shared Memory: Use for frequently accessed data
- Occupancy: Maximize number of active warps
- Reduction Operations: Use efficient parallel reduction algorithms
Profiling and Benchmarking
Create benchmarking tools:
import torch
import time
from torch.profiler import profile, record_function, ProfilerActivity
def benchmark_kernel(kernel_func, *args, **kwargs):
"""Benchmark a kernel function"""
# Warmup
for _ in range(3):
result = kernel_func(*args, **kwargs)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(10): # Run multiple times for average
result = kernel_func(*args, **kwargs)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / 10
return avg_time, result
def profile_kernel(kernel_func, *args, **kwargs):
"""Profile a kernel function"""
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
result = kernel_func(*args, **kwargs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
return result
This comprehensive implementation guide provides everything needed to implement the core kernels for Mini-YAIE, following SGLang-style optimization principles while maintaining educational value.
API & Serving: OpenAI-Compatible Server
Overview
The API server in Mini-YAIE implements an OpenAI-compatible interface, allowing the engine to be used with existing applications and tools designed for OpenAI’s API. The server uses FastAPI to provide RESTful endpoints with proper request/response handling, streaming support, and health monitoring.
API Design Philosophy
The server follows OpenAI’s API specification to ensure compatibility with existing tools and applications while leveraging the advanced features of the SGLang-style inference engine. This approach allows users to:
- Use existing OpenAI clients without modification
- Take advantage of Mini-YAIE’s performance optimizations
- Integrate with tools built for OpenAI’s API format
- Maintain familiar request/response patterns
Core Architecture
FastAPI Application
The server is built using FastAPI for high-performance web serving:
def create_app(model_name: str) -> FastAPI:
app = FastAPI(title="YAIE API", version="0.1.0")
engine = InferenceEngine(model_name)
return app
Main Components
- Inference Engine Integration: Connects API endpoints to the core inference engine
- Request Validation: Pydantic models for request/response validation
- Streaming Support: Server-sent events for real-time token streaming
- Error Handling: Proper HTTP error codes and message formatting
- Health Monitoring: Endpoints for system status and availability
API Endpoints
1. Chat Completions Endpoint
The main endpoint follows OpenAI’s /v1/chat/completions specification:
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
# Implementation handles both streaming and non-streaming responses
Request Schema
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
max_tokens: Optional[int] = None
stream: Optional[bool] = False
Response Schema
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[Choice]
usage: Dict[str, int]
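ChatMessage and Choice are referenced above but not shown; minimal Pydantic definitions consistent with the OpenAI schema might look like this:

from typing import Optional
from pydantic import BaseModel

class ChatMessage(BaseModel):
    role: str       # "system", "user", or "assistant"
    content: str

class Choice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None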
Supported Parameters
- model: Model identifier (passed during server startup)
- messages: List of message objects with role and content
- temperature: Sampling temperature (0.0-2.0 recommended)
- top_p: Nucleus sampling threshold (0.0-1.0)
- max_tokens: Maximum tokens to generate
- stream: Whether to stream responses (true/false)
2. Model Listing Endpoint
Lists the available model:
@app.get("/v1/models")
async def list_models():
return {
"object": "list",
"data": [
{
"id": model_name,
"object": "model",
"owned_by": "user",
"created": int(time.time()),
}
],
}
3. Health Check Endpoint
Simple health monitoring:
@app.get("/health")
async def health_check():
return {"status": "healthy", "model": model_name}
Streaming Implementation
Streaming vs Non-Streaming
The server supports both response formats using the same endpoint:
if request.stream:
# Return streaming response
return StreamingResponse(generate_stream(), media_type="text/event-stream")
else:
# Return non-streaming response
response = engine.chat_completion(messages_dicts, **kwargs)
return response
Streaming Response Format
The streaming implementation generates Server-Sent Events (SSE):
def generate_stream():
for chunk in engine.chat_completion_stream(messages_dicts, **kwargs):
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
Each chunk follows OpenAI’s streaming format:
{
"id": "chatcmpl-...",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "model-name",
"choices": [{
"index": 0,
"delta": {"content": "token"},
"finish_reason": null
}]
}
Integration with Inference Engine
Request Processing Flow
- API Request: Received through FastAPI endpoints
- Validation: Pydantic models validate request format
- Parameter Extraction: Convert API parameters to engine format
- Engine Processing: Call appropriate engine methods
- Response Formatting: Convert engine output to API format
- API Response: Return properly formatted responses
Parameter Mapping
API parameters are mapped to engine capabilities:
kwargs = {}
if request.max_tokens is not None:
kwargs["max_tokens"] = request.max_tokens
if request.temperature is not None:
kwargs["temperature"] = request.temperature
if request.top_p is not None:
kwargs["top_p"] = request.top_p
Message Formatting
The server handles OpenAI-style messages by converting them to a format the engine understands:
messages_dicts = [
{"role": msg.role, "content": msg.content}
for msg in request.messages
]
# Apply chat template if available
if hasattr(self.tokenizer, 'apply_chat_template'):
formatted_prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
else:
# Fallback formatting
formatted_prompt = ""
for message in messages:
formatted_prompt += f"{message['role'].capitalize()}: {message['content']}\n"
formatted_prompt += "\nAssistant:"
Error Handling
HTTP Error Codes
The server properly handles various error conditions:
- 400 Bad Request: Invalid request parameters
- 429 Too Many Requests: Rate limiting (not implemented in basic version)
- 500 Internal Server Error: Server-side errors during processing
Error Response Format
Standard OpenAI-compatible error format:
{
"error": {
"message": "Error description",
"type": "server_error",
"param": null,
"code": null
}
}
Exception Handling
The server wraps all processing in try-catch blocks:
try:
# Process request
response = engine.chat_completion(messages_dicts, **kwargs)
return response
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
Performance Considerations
Request Batching
The API integrates with the engine’s batching system:
- Multiple API requests can be batched together in the engine
- Continuous batching maintains high throughput
- Batch size limited by engine configuration
Memory Management
The server shares memory with the inference engine:
- KV-cache is shared across API requests
- Efficient memory reuse through paged cache system
- Memory limits enforced by engine configuration
Concurrency
FastAPI provides automatic concurrency handling:
- Async request processing
- Connection pooling
- Efficient handling of multiple simultaneous requests
Security Considerations
Input Validation
- Pydantic models validate all request parameters
- Type checking prevents injection attacks
- Length limits prevent excessive resource consumption
Rate Limiting
While not implemented in the basic version, can be added:
- Per-IP rate limiting
- Request quota management
- Usage monitoring
Deployment Configuration
Server Startup
The server can be started with a specific model:
uvicorn server.api:app --host 0.0.0.0 --port 8000
Environment Configuration
The server supports environment-based configuration:
- Model name via environment variables
- Port and host configuration
- Resource limits
SGLang-Style Features Integration
Continuous Batching
The API benefits from the engine’s continuous batching:
- Requests are automatically batched
- High throughput maintained
- Low latency for individual requests
Prefix Sharing
API requests with similar prefixes benefit from:
- Shared computation in radix attention
- Reduced memory usage
- Improved efficiency
Multi-Step Processing
The API leverages the engine’s multi-step capabilities:
- Efficient prefill and decode phases
- Optimized request scheduling
- Memory-aware processing
Usage Examples
Basic Chat Request
curl -X POST "http://localhost:8000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"model": "model-name",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
],
"temperature": 0.7
}'
Streaming Request
curl -X POST "http://localhost:8000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"model": "model-name",
"messages": [
{"role": "user", "content": "Write a short story"}
],
"stream": true
}'
Programmatic Usage
import openai
# Configure client to use local server
client = openai.OpenAI(
base_url="http://localhost:8000/v1",
api_key="dummy" # Required but ignored
)
# Use standard OpenAI SDK
response = client.chat.completions.create(
model="model-name",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.7
)
Monitoring and Logging
Health Endpoints
The health endpoint can be used for monitoring:
curl http://localhost:8000/health
Performance Metrics
The server can be extended with:
- Response time tracking
- Request volume monitoring
- Error rate monitoring
- Resource utilization metrics
Advanced Features
Model Loading
The server handles model loading automatically:
- Lazy loading when first accessed
- Caching for subsequent requests
- HuggingFace model integration
Response Caching
The system supports response caching for:
- Repeated identical requests
- Common prompt patterns
- Improved response times for cached content
Logging and Debugging
Comprehensive logging can be added:
- Request/response logging
- Performance metrics
- Error tracing
- Usage analytics
This OpenAI-compatible API server enables Mini-YAIE to be integrated into existing ecosystems while providing the performance benefits of SGLang-style inference optimization.
CLI Usage: Interactive and Server Modes
Overview
Mini-YAIE provides a comprehensive command-line interface (CLI) that serves as the primary entry point for users. The CLI supports both interactive chat mode and server mode, making it suitable for both direct interaction and production deployment scenarios.
CLI Architecture
Entry Point Structure
The CLI is organized around different command verbs:
yaie <command> [options] [arguments]
Commands:
- serve: Start an OpenAI-compatible API server
- chat: Start an interactive chat session
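As a rough sketch, the two commands could be wired with argparse along these lines (option names follow the examples in this chapter; the real CLI may differ):

import argparse

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="yaie")
    subparsers = parser.add_subparsers(dest="command", required=True)

    serve = subparsers.add_parser("serve", help="Start an OpenAI-compatible API server")
    serve.add_argument("model", help="HuggingFace model id or local path")
    serve.add_argument("--host", default="localhost")
    serve.add_argument("--port", type=int, default=8000)
    serve.add_argument("--max-batch-size", type=int, default=8)

    chat = subparsers.add_parser("chat", help="Start an interactive chat session")
    chat.add_argument("model", help="HuggingFace model id or local path")
    chat.add_argument("--temperature", type=float, default=1.0)
    chat.add_argument("--max-tokens", type=int, default=512)
    return parser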
Core Components
- Argument Parsing: Uses argparse for command-line option handling
- Model Integration: Connects CLI commands to the inference engine
- Interactive Interface: Provides user-friendly chat experience
- Server Integration: Launches API server with proper configuration
Server Mode
Basic Server Usage
Start the API server with a specific model:
yaie serve microsoft/DialoGPT-medium --host localhost --port 8000
Server Options
Model Selection
--model MODEL_NAME # Specify the model to use (required)
Network Configuration
--host HOST # Server host (default: localhost)
--port PORT # Server port (default: 8000)
--workers WORKERS # Number of server workers
Performance Options
--max-batch-size N # Maximum batch size
--max-prefill-batch-size N # Maximum prefill batch size
--max-decode-batch-size N # Maximum decode batch size
--num-blocks N # Number of KV-cache blocks
--block-size N # Size of each cache block
Server Startup Process
- Model Loading: Download and load model from HuggingFace if not cached
- Engine Initialization: Create inference engine with specified parameters
- API Server Creation: Initialize FastAPI application with engine
- Server Launch: Start the web server on specified host/port
Example Server Commands
Basic Server
yaie serve microsoft/DialoGPT-medium
Production Server
yaie serve microsoft/DialoGPT-medium --host 0.0.0.0 --port 8000 --max-batch-size 16
Resource-Constrained Server
yaie serve microsoft/DialoGPT-medium --num-blocks 1000 --max-batch-size 4
Chat Mode
Basic Chat Usage
Start an interactive chat session:
yaie chat microsoft/DialoGPT-medium
Chat Options
Generation Parameters
--temperature TEMP # Sampling temperature (default: 1.0)
--top-p TOP_P # Nucleus sampling threshold (default: 1.0)
--max-tokens N # Maximum tokens to generate (default: 512)
--stream # Stream responses in real-time (default: true)
Model Configuration
--model MODEL_NAME # Specify the model to use (required)
Interactive Chat Experience
Session Flow
- Model Loading: Model is loaded if not cached
- Chat Initialization: Engine and tokenizer are set up
- Conversation Loop: User inputs are processed and responses generated
- Session Termination: Exit with Ctrl+C or quit command
User Interaction
The chat interface provides a conversational experience:
$ yaie chat microsoft/DialoGPT-medium
Model loaded successfully!
Starting chat session (press Ctrl+C to exit)...
User: Hello, how are you?
AI: I'm doing well, thank you for asking!
User: What can you help me with?
AI: I can have conversations, answer questions, and assist with various tasks.
Example Chat Commands
Basic Chat
yaie chat microsoft/DialoGPT-medium
Creative Chat
yaie chat microsoft/DialoGPT-medium --temperature 1.2 --top-p 0.9
Focused Chat
yaie chat microsoft/DialoGPT-medium --temperature 0.7 --max-tokens 128
Model Selection
Supported Model Formats
The CLI supports any HuggingFace-compatible model:
Pre-trained Models
yaie serve microsoft/DialoGPT-medium
yaie serve gpt2
yaie serve facebook/opt-1.3b
Local Models
yaie serve /path/to/local/model
yaie serve ./models/my-custom-model
Model Caching
Models are automatically downloaded and cached:
- First run: Download from HuggingFace Hub
- Subsequent runs: Use cached version
- Cache location: Standard HuggingFace cache directory
Performance Tuning
Memory Configuration
Adjust memory settings based on available GPU memory:
# For 24GB+ GPU
yaie serve model --num-blocks 4000 --max-batch-size 32
# For 8-16GB GPU
yaie serve model --num-blocks 1500 --max-batch-size 8
# For 4-8GB GPU
yaie serve model --num-blocks 800 --max-batch-size 4
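The right block count is roughly the GPU memory left over after weights divided by the per-block KV footprint. A back-of-the-envelope helper, with example model dimensions (24 layers, 16 KV heads, head size 64, fp16) that you would replace with your model's config:
def estimate_num_blocks(
    free_gpu_bytes: float,      # memory left after weights and activations
    block_size: int = 16,       # tokens per KV-cache block
    num_layers: int = 24,       # example values, roughly a DialoGPT-medium-sized model
    num_kv_heads: int = 16,
    head_dim: int = 64,
    dtype_bytes: int = 2,       # fp16/bf16
) -> int:
    # Each cached token stores one key and one value vector per layer.
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    bytes_per_block = block_size * bytes_per_token
    return int(free_gpu_bytes // bytes_per_block)

# e.g. with 8 GiB to spare: estimate_num_blocks(8 * 1024**3) is roughly 5,400 blocks (~87k cached tokens)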
Batch Size Optimization
Tune batch sizes for optimal throughput:
# High throughput (more memory)
yaie serve model --max-batch-size 32 --max-prefill-batch-size 64
# Memory efficient (lower batch sizes)
yaie serve model --max-batch-size 4 --max-prefill-batch-size 8
Error Handling and Troubleshooting
Common Errors
Model Loading Errors
# If model name is invalid
Error: Model not found on HuggingFace Hub
# If network is unavailable during first load
Error: Failed to download model
Memory Errors
# If not enough GPU memory
CUDA out of memory error
# If KV-cache is too large
Memory allocation failed
Debugging Options
Verbose Output
yaie serve model --verbose # Show detailed startup information
Configuration Validation
yaie serve model --debug # Enable debugging features
Advanced CLI Features
Configuration Files
The CLI supports configuration files for complex setups:
yaie serve --config config.yaml model
Environment Variables
Several environment variables can customize behavior:
# Set default host
export YAIE_HOST=0.0.0.0
# Set default port
export YAIE_PORT=9000
# Set memory limits
export YAIE_MAX_BLOCKS=2000
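A CLI would typically fold these in as defaults that explicit flags can still override; a small sketch using the variable names listed above (the precedence shown is an assumption):
import os

def default_server_config() -> dict:
    # Environment variables act as defaults; explicit CLI flags should still win.
    return {
        "host": os.environ.get("YAIE_HOST", "localhost"),
        "port": int(os.environ.get("YAIE_PORT", "8000")),
        "num_blocks": int(os.environ.get("YAIE_MAX_BLOCKS", "2000")),
    }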
Logging Configuration
Control logging verbosity and output:
# Enable detailed logging
yaie serve model --log-level DEBUG
# Log to file
yaie serve model --log-file server.log
Integration with SGLang Features
Batching Optimization
The CLI exposes SGLang batching parameters:
yaie serve model \
--max-prefill-batch-size 16 \
--max-decode-batch-size 256
Prefix Sharing Control
Parameters that affect prefix sharing efficiency:
yaie serve model \
--max-seq-len 2048 \
--block-size 16
Production Deployment
Server Management
Process Control
# Start server in background
nohup yaie serve model > server.log 2>&1 &
# Kill server process
pkill -f "yaie serve"
Process Monitoring
# Monitor server with systemd
systemctl start yaie-server
# Monitor with supervisor
supervisorctl start yaie-server
Health Checks
The server provides health status:
# Check server status
curl http://localhost:8000/health
# Point your monitoring system (e.g. a Kubernetes liveness probe or an uptime checker)
# at the same endpoint, with an interval and failure threshold that match your latency goals
Examples and Use Cases
Development Usage
# Quick test with small model
yaie chat gpt2
# Interactive development with verbose output
yaie serve gpt2 --port 8000 --verbose
Production Usage
# High-performance server for production
yaie serve microsoft/DialoGPT-medium \
--host 0.0.0.0 \
--port 8000 \
--max-batch-size 16 \
--num-blocks 2000
# Low-resource server for edge deployment
yaie serve gpt2 \
--max-batch-size 2 \
--num-blocks 500
Testing and Evaluation
# Test with various parameters
yaie chat model --temperature 0.5 --top-p 0.9
# Evaluate different models
yaie serve model1 --port 8001 &
yaie serve model2 --port 8002 &
The CLI provides a comprehensive interface to access all of Mini-YAIE’s features, from simple interactive chat to high-performance API serving with SGLang-style optimizations.
Production Deployment
While Mini-YAIE is primarily educational, understanding production considerations helps bridge the gap between learning and real-world deployment.
Deployment Architecture
graph TD
LoadBalancer -->|HTTP| API1
LoadBalancer -->|HTTP| API2
LoadBalancer -->|HTTP| API3
API1 -->|gRPC| Inference1
API2 -->|gRPC| Inference2
API3 -->|gRPC| Inference3
Inference1 -->|GPU| Model1
Inference2 -->|GPU| Model2
Inference3 -->|GPU| Model3
1. Performance Optimization
Batching & Latency Management
- Request Timeouts: Implement configurable timeouts for requests waiting in queue
- Priority Queues: Support different priority levels for requests
- Preemption: Pause low-priority requests and swap their KV-cache to CPU when high-priority requests arrive
- Adaptive Batching: Dynamically adjust batch sizes based on current load and latency requirements
Continuous Batching Strategies
# Advanced scheduling policies
from enum import Enum

class SchedulingPolicy(Enum):
    FCFS = "fcfs"                     # First-come-first-served
    PRIORITY = "priority"             # Priority-based
    FAIR = "fair"                     # Fair sharing
    LATENCY_OPTIMIZED = "latency"     # Minimize latency
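To make the policies concrete, the sketch below combines a priority queue with per-request queueing timeouts, echoing the timeout and priority ideas listed above; the QueuedRequest fields are illustrative, not Mini-YAIE's scheduler API.
import heapq
import time
from dataclasses import dataclass, field

@dataclass(order=True)
class QueuedRequest:
    priority: int                         # lower value = more urgent
    enqueue_time: float
    request_id: str = field(compare=False)

class PriorityScheduler:
    def __init__(self, queue_timeout_s: float = 30.0):
        self.queue_timeout_s = queue_timeout_s
        self._heap = []                   # min-heap ordered by (priority, enqueue_time)

    def submit(self, request_id: str, priority: int = 10) -> None:
        heapq.heappush(self._heap, QueuedRequest(priority, time.monotonic(), request_id))

    def next_batch(self, max_batch_size: int) -> list:
        batch, now = [], time.monotonic()
        while self._heap and len(batch) < max_batch_size:
            req = heapq.heappop(self._heap)
            if now - req.enqueue_time > self.queue_timeout_s:
                continue                  # drop (or fail) requests that waited too long
            batch.append(req.request_id)
        return batch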
2. Distributed Inference
Tensor Parallelism
# Place whole layers on different GPUs via device_map (layer-wise placement; keys are illustrative).
# True tensor parallelism goes further and splits each weight matrix itself across GPUs,
# as in Megatron-style implementations.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={
        "layer.0": "cuda:0",
        "layer.1": "cuda:1",
        "layer.2": "cuda:2",
        "layer.3": "cuda:3"
    }
)
Pipeline Parallelism
# Split layers across different stages
# Stage 1: Layers 0-10 on GPU 0
# Stage 2: Layers 11-20 on GPU 1
# Stage 3: Layers 21-30 on GPU 2
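A toy stage split with plain PyTorch modules shows the placement idea; real pipeline-parallel frameworks (e.g. DeepSpeed or Megatron-LM) also schedule micro-batches across stages, which this omits.
import torch
import torch.nn as nn

class TwoStagePipeline(nn.Module):
    # Toy example: first half of the layers on cuda:0, second half on cuda:1.
    def __init__(self, hidden: int = 1024, layers_per_stage: int = 12):
        super().__init__()
        self.stage0 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(layers_per_stage)]).to("cuda:0")
        self.stage1 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(layers_per_stage)]).to("cuda:1")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stage0(x.to("cuda:0"))       # stage 1 on GPU 0
        return self.stage1(x.to("cuda:1"))    # stage 2 on GPU 1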
Model Parallelism
- Expert Parallelism: For mixture-of-experts models
- Sequence Parallelism: For very long sequences
- Hybrid Parallelism: Combining multiple strategies
3. Memory Optimization
Quantization
# Load the model with 8-bit quantization (requires the bitsandbytes package)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,   # 8-bit quantization
    device_map="auto"
)

# Or use 4-bit quantization
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map="auto"
)
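Recent transformers releases prefer passing a BitsAndBytesConfig instead of the bare flags; roughly (reusing model_name from the snippet above):
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",
)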
Memory Management Strategies
- Paged Attention: Efficient memory management for KV-cache
- Memory Pooling: Reuse memory blocks across requests
- Swapping: Move inactive KV-cache blocks to CPU memory
- Compression: Compress KV-cache values with minimal quality loss
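The pooling and swapping ideas can be sketched in isolation as a free-list of reusable GPU blocks plus a CPU staging area; Mini-YAIE's real KV-cache manager layers per-request block tables on top of something like this.
import torch

class BlockPool:
    def __init__(self, num_blocks: int, block_size: int, kv_dim: int, device: str = "cuda"):
        # One big GPU tensor; individual blocks are reused across requests instead of re-allocated.
        self.gpu_blocks = torch.empty(num_blocks, block_size, kv_dim, device=device, dtype=torch.float16)
        self.free_ids = list(range(num_blocks))
        self.cpu_swap = {}                   # block_id -> CPU tensor holding swapped-out contents

    def allocate(self) -> int:
        return self.free_ids.pop()           # raises IndexError when the pool is exhausted

    def free(self, block_id: int) -> None:
        self.free_ids.append(block_id)       # return the block to the pool for reuse

    def swap_out(self, block_id: int) -> None:
        # Copy an inactive block's contents to CPU memory and recycle the GPU slot.
        self.cpu_swap[block_id] = self.gpu_blocks[block_id].cpu()
        self.free(block_id)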
4. Scalability & Reliability
Horizontal Scaling
# Multiple inference workers behind a load balancer
# (InferenceWorker is an illustrative wrapper: one engine instance pinned to one GPU)
workers = [
    InferenceWorker(model_name, gpu_id=0),
    InferenceWorker(model_name, gpu_id=1),
    InferenceWorker(model_name, gpu_id=2),
    InferenceWorker(model_name, gpu_id=3)
]
Load Balancing
- Round Robin: Simple distribution across workers
- Least Connections: Send to least busy worker
- Latency-based: Send to worker with lowest current latency
- Content-based: Route based on request characteristics
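A least-connections picker takes only a few lines; the Worker class below is a stand-in that tracks its own active request count.
from dataclasses import dataclass

@dataclass
class Worker:
    url: str
    active_requests: int = 0

def pick_worker(workers: list) -> Worker:
    # Least connections: route to the worker currently serving the fewest requests.
    chosen = min(workers, key=lambda w: w.active_requests)
    chosen.active_requests += 1   # remember to decrement when the request completes
    return chosen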
Health Monitoring
# Health check endpoint (get_gpu_memory_usage and the scheduler accessors are
# application-level helpers exposed by the engine)
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "gpu_memory": get_gpu_memory_usage(),
        "active_requests": scheduler.get_active_request_count(),
        "queue_length": scheduler.get_queue_status()
    }
5. Observability
Metrics Collection
# Prometheus metrics (prometheus_client)
from prometheus_client import Counter, Histogram

REQUEST_LATENCY = Histogram(
    'request_latency_seconds',
    'Request latency in seconds',
    buckets=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0]
)
GENERATED_TOKENS = Counter(
    'generated_tokens_total',
    'Total tokens generated (per-second rate is derived in Prometheus with rate())'
)
Logging
# Structured logging (structlog-style keyword arguments; the stdlib logger does not accept these directly)
logger.info(
    "Request completed",
    request_id=request.id,
    latency=latency_seconds,
    tokens_generated=token_count,
    model=model_name
)
Tracing
# Distributed tracing (OpenTelemetry-style; tracer obtained via trace.get_tracer(__name__))
with tracer.start_as_current_span("generate_response") as span:
    span.set_attribute("model", model_name)
    span.set_attribute("request_id", request.id)
    response = engine.generate(prompt)
6. Security
Authentication
# API key authentication
from fastapi import Header, HTTPException

@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    api_key: str = Header(None)
):
    if not validate_api_key(api_key):   # validate_api_key is an application-defined check
        raise HTTPException(status_code=401, detail="Unauthorized")
    ...  # process the authenticated request
Rate Limiting
# Rate limiting per API key (RateLimiter is an illustrative interface;
# libraries such as slowapi provide similar decorators for FastAPI)
limiter = RateLimiter(
    requests_per_minute=1000,
    burst_capacity=100
)

@app.post("/v1/chat/completions")
@limiter.limit("1000/minute")
async def chat_completions(request: ChatCompletionRequest):
    ...  # process the request once it passes the rate limit
Input Validation
# Validate and sanitize inputs (MAX_PROMPT_LENGTH and sanitize_input are application-defined)
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    # Validate prompt length
    if len(request.prompt) > MAX_PROMPT_LENGTH:
        raise HTTPException(status_code=400, detail="Prompt too long")
    # Sanitize inputs before they reach the engine
    sanitized_prompt = sanitize_input(request.prompt)
    ...  # continue with generation using sanitized_prompt
7. Configuration Management
Environment Variables
# Configure through environment variables
export MODEL_NAME="gpt2"
export MAX_BATCH_SIZE="16"
export GPU_MEMORY_UTIL="0.9"
export ENABLE_RADIX_CACHE="true"
Configuration Files
# YAML configuration
model:
name: "gpt2"
dtype: "float16"
tensor_parallel_size: 1
scheduler:
max_batch_size: 16
max_prefill_batch_size: 32
max_decode_batch_size: 256
memory:
gpu_blocks: 2000
cpu_blocks: 1000
block_size: 16
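Loading such a file is straightforward with PyYAML; the keys follow the example above, and the fallback defaults are illustrative.
import yaml

def load_config(path: str) -> dict:
    with open(path) as f:
        config = yaml.safe_load(f)
    # Pull a few values out with sensible fallbacks.
    scheduler = config.get("scheduler", {})
    memory = config.get("memory", {})
    return {
        "model_name": config["model"]["name"],
        "max_batch_size": scheduler.get("max_batch_size", 8),
        "gpu_blocks": memory.get("gpu_blocks", 2000),
        "block_size": memory.get("block_size", 16),
    }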
8. Deployment Strategies
Containerization
# Dockerfile for Mini-YAIE
# (for GPU inference, start from a CUDA-enabled base image such as an
# nvidia/cuda or pytorch/pytorch image instead of python:3.9-slim)
FROM python:3.9-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
RUN pip install -e .
CMD ["yaie", "serve", "gpt2", "--host", "0.0.0.0", "--port", "8000"]
Kubernetes Deployment
# Kubernetes deployment
apiVersion: apps/v1
kind: Deployment
metadata:
name: yaie-inference
spec:
replicas: 3
selector:
matchLabels:
app: yaie-inference
template:
metadata:
labels:
app: yaie-inference
spec:
containers:
- name: yaie
image: yaie:latest
ports:
- containerPort: 8000
resources:
limits:
nvidia.com/gpu: 1
9. Performance Monitoring
Key Metrics to Track
- Latency: Time from request to first token (TTFT) and time per token
- Throughput: Tokens generated per second
- GPU Utilization: Percentage of GPU time spent on computation
- Memory Usage: GPU and CPU memory consumption
- Queue Length: Number of requests waiting
- Error Rates: Percentage of failed requests
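TTFT, time per output token, and throughput all fall out of three timestamps per request (enqueue, first token, completion), e.g. taken with time.monotonic(); a small helper:
def request_metrics(t_enqueue: float, t_first_token: float, t_done: float, tokens_generated: int) -> dict:
    ttft = t_first_token - t_enqueue                    # time to first token
    decode_time = t_done - t_first_token                # time spent producing the remaining tokens
    per_token = decode_time / max(tokens_generated - 1, 1)
    return {
        "ttft_s": ttft,
        "time_per_output_token_s": per_token,
        "throughput_tok_per_s": tokens_generated / (t_done - t_enqueue),
    }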
Alerting
# Set up alerts for critical conditions
# (alert() stands in for your alerting integration, e.g. Alertmanager or PagerDuty)
if gpu_memory_usage > 0.95:
    alert("High GPU memory usage")
if request_latency > 5.0:   # 5 seconds
    alert("High request latency")
if error_rate > 0.01:       # 1%
    alert("High error rate")
10. Cost Optimization
Resource Management
- Auto-scaling: Scale workers based on demand
- Spot Instances: Use cheaper spot instances for non-critical workloads
- Right-sizing: Choose appropriate instance types for workload
- Batch Processing: Process offline requests in batches during low-traffic periods
Model Selection
# Choose appropriate model size for use case
small_models = ["gpt2", "DialoGPT-small"] # Fast, low cost
medium_models = ["gpt2-medium", "DialoGPT-medium"] # Balanced
large_models = ["gpt2-large", "DialoGPT-large"] # High quality, expensive
Educational Focus
Understanding production considerations helps you:
- Bridge the gap between educational implementations and real-world systems
- Appreciate the complexity of production-grade inference engines
- Make informed decisions about trade-offs in your implementations
- Design for scalability from the beginning
From Mini-YAIE to Production
Mini-YAIE provides the foundation for understanding key concepts:
- Continuous Batching: The core of efficient inference
- Memory Management: Critical for handling multiple requests
- Prefix Sharing: Advanced optimization for similar requests
- API Design: Standard interfaces for integration
Production systems build on these foundations with:
- Scalability: Handling thousands of concurrent requests
- Reliability: High availability and fault tolerance
- Observability: Comprehensive monitoring and logging
- Security: Authentication, authorization, and input validation
By mastering the concepts in Mini-YAIE, you’ll be well-prepared to understand and contribute to production-grade inference systems like vLLM, SGLang, and TensorRT-LLM.
References
- SGLang: Efficient Execution of Structured Language Model Programs.
- vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention.
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.
Troubleshooting
“CUDA kernel not found”
- Ensure you ran pip install -e .
- Check if nvcc is in your path: nvcc --version
“OutOfMemoryError”
- Decrease max_batch_size.
- Decrease the kv_cache_manager block count.
“ImportError: attempted relative import…”
- Ensure you are running the yaie command, or run Python as a module: python -m src.cli.main
- Do not run scripts directly, e.g. python src/engine.py