Kernels Implementation Guide
Overview
This guide provides a comprehensive walkthrough for implementing the core kernels in Mini-YAIE that enable SGLang-style inference optimization. The implementation focuses on three key areas:
- Python Implementations: Educational implementations of core algorithms
- CUDA Kernels: Performance-optimized GPU implementations
- Integration: Connecting kernels with the main inference engine
Implementation Roadmap
Phase 1: Core Python Kernels
Implement the educational Python versions first:
- Radix tree for prefix matching
- Basic attention mechanisms
- KV-cache management
- Sampling algorithms
Phase 2: CUDA Kernel Development
Develop optimized GPU versions:
- Memory operations kernels
- Paged attention implementation
- Flash attention optimization
- Radix operations acceleration
Phase 3: Integration and Optimization
Connect kernels to the main system:
- Engine integration
- Performance validation
- Correctness verification
Python Kernel Implementation
1. Radix Tree Implementation
Start with the radix tree that enables prefix sharing:
File: src/kernels/radix_tree.py
from typing import Dict, List, Optional, Tuple
class RadixTreeNode:
def __init__(self, token_id: Optional[int] = None):
self.token_id = token_id
self.children: Dict[int, "RadixTreeNode"] = {}
self.request_ids: List[str] = []
self.kv_cache_refs: List[str] = []
self.is_terminal = False
class RadixTree:
def __init__(self):
self.root = RadixTreeNode()
self.request_to_path: Dict[str, List[int]] = {}
self.path_to_node: Dict[str, RadixTreeNode] = {}
def insert_request(self, request_id: str, token_ids: List[int]):
"""Insert a request into the radix tree based on its token sequence"""
current = self.root
        for token_id in token_ids:
            if token_id not in current.children:
                current.children[token_id] = RadixTreeNode(token_id)
            current = current.children[token_id]
            # Record the request on every node along its path so that
            # find_shared_prefixes can report which requests share a prefix.
            if request_id not in current.request_ids:
                current.request_ids.append(request_id)
        current.is_terminal = True
self.request_to_path[request_id] = token_ids
path_str = self._path_to_string(token_ids)
self.path_to_node[path_str] = current
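    def _path_to_string(self, token_ids: List[int]) -> str:
        # Helper assumed by insert_request above (not shown elsewhere in this guide):
        # a stable string key for a token path, used for path_to_node lookups.
        return ",".join(str(t) for t in token_ids)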
def find_shared_prefixes(self, token_ids: List[int]) -> Tuple[List[str], int]:
"""Find requests that share prefixes with the given token sequence"""
current = self.root
matched_requests = []
prefix_length = 0
for i, token_id in enumerate(token_ids):
if token_id in current.children:
current = current.children[token_id]
matched_requests.extend(current.request_ids)
prefix_length = i + 1
else:
break
return list(set(matched_requests)), prefix_length
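A quick usage check for the tree (this assumes the _path_to_string helper sketched above):
tree = RadixTree()
tree.insert_request("req1", [1, 2, 3])
tree.insert_request("req2", [1, 2, 4])
shared, prefix_len = tree.find_shared_prefixes([1, 2, 5])
print(sorted(shared), prefix_len)  # ['req1', 'req2'] 2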
2. KV-Cache Management
Implement the paged KV-cache system:
File: src/kernels/kv_cache.py
from typing import List
import torch
class KVCacheBlock:
def __init__(self, block_id: int, size: int, num_heads: int, head_dim: int, dtype=torch.float16):
self.block_id = block_id
self.size = size
self.num_heads = num_heads
self.head_dim = head_dim
self.dtype = dtype
self.keys = None
self.values = None
def allocate(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.keys = torch.zeros(self.size, self.num_heads, self.head_dim, dtype=self.dtype, device=device)
self.values = torch.zeros(self.size, self.num_heads, self.head_dim, dtype=self.dtype, device=device)
class KVCacheManager:
def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int, dtype=torch.float16):
self.num_blocks = num_blocks
self.block_size = block_size
self.num_heads = num_heads
self.head_dim = head_dim
self.dtype = dtype
self.blocks: List[KVCacheBlock] = []
for i in range(num_blocks):
block = KVCacheBlock(i, block_size, num_heads, head_dim, dtype)
self.blocks.append(block)
self.free_block_list: List[int] = list(range(num_blocks))
self.block_tables: dict = {}
def allocate_blocks(self, request_id: str, num_tokens: int) -> List[int]:
num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
if len(self.free_block_list) < num_blocks_needed:
raise RuntimeError(f"Not enough free blocks. Need {num_blocks_needed}, have {len(self.free_block_list)}")
allocated_block_ids = []
for _ in range(num_blocks_needed):
block_id = self.free_block_list.pop(0)
allocated_block_ids.append(block_id)
self.blocks[block_id].allocate()
self.block_tables[request_id] = allocated_block_ids
return allocated_block_ids
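The tests later in this guide also call free_blocks on KVCacheManager, which the excerpt above omits. A minimal sketch of such a method, assuming freed blocks simply drop their tensors and return to the free list:
    def free_blocks(self, request_id: str) -> None:
        """Release a request's blocks back to the free pool (sketch, not shown above)."""
        for block_id in self.block_tables.pop(request_id, []):
            self.blocks[block_id].keys = None
            self.blocks[block_id].values = None
            self.free_block_list.append(block_id)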
3. Radix Attention Implementation
Implement the radix attention mechanism:
File: src/kernels/radix_attention.py
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class RadixAttentionBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, head_dim: int, max_position_embeddings: int = 2048):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.max_position_embeddings = max_position_embeddings
total_hidden_dim = num_heads * head_dim
self.q_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, total_hidden_dim, bias=False)
self.o_proj = nn.Linear(total_hidden_dim, hidden_size, bias=False)
self.register_buffer(
"cos_cached",
torch.ones((max_position_embeddings, head_dim), dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"sin_cached",
torch.ones((max_position_embeddings, head_dim), dtype=torch.float32),
persistent=False,
)
self._setup_rope_embeddings()
def _setup_rope_embeddings(self):
position_ids = torch.arange(self.max_position_embeddings, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim))
freqs = torch.einsum("i,j->ij", position_ids, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos().to(dtype=torch.float16)
self.sin_cached = emb.sin().to(dtype=torch.float16)
def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> torch.Tensor:
batch_size, seq_len, _ = hidden_states.shape
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
cos_to_use = self.cos_cached[:seq_len].to(query.dtype)
sin_to_use = self.sin_cached[:seq_len].to(query.dtype)
query, key = apply_rotary_pos_emb(query, key, cos_to_use, sin_to_use, position_ids)
if past_key_value is not None:
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)
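        # NOTE: for clarity this educational block applies no causal mask;
        # a production decoder would mask out future positions before the softmax.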
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.head_dim)
output = self.o_proj(attn_output)
return output
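A quick shape check of the block with random weights and inputs:
attention = RadixAttentionBlock(hidden_size=256, num_heads=4, head_dim=64)
x = torch.randn(2, 10, 256)   # [batch, seq_len, hidden_size]
out = attention(x)
print(out.shape)              # torch.Size([2, 10, 256])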
4. Sampling Kernel Implementation
Implement the token sampling system:
File: src/kernels/sampling.py
import torch
class SamplingKernel:
def sample(self, logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1) -> torch.Tensor:
if temperature != 1.0:
logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
batch_size, vocab_size = probs.shape
if top_k > 0:
top_k = min(top_k, vocab_size)
top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
new_probs = torch.zeros_like(probs)
new_probs.scatter_(1, top_k_indices, top_k_probs)
new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True)
probs = new_probs
if 0 < top_p < 1.0:
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative_probs <= top_p
            # Always keep at least the highest-probability token
            mask[..., 0] = True
filtered_probs = torch.zeros_like(probs)
filtered_probs.scatter_(1, sorted_indices, mask.float() * sorted_probs)
filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
probs = filtered_probs
sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
return sampled_ids
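Usage is one call per decoding step; for example:
sampler = SamplingKernel()
logits = torch.randn(4, 32000)   # [batch_size, vocab_size]
tokens = sampler.sample(logits, temperature=0.8, top_p=0.95, top_k=50)
print(tokens.shape)              # torch.Size([4])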
CUDA Kernel Implementation
1. Memory Operations Kernels
File: src/kernels/cuda/memory_ops.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void copy_blocks_kernel(
float* key_cache, float* value_cache,
float* new_key_cache, float* new_value_cache,
int* block_mapping, // [src_block_id, dst_block_id] pairs
int num_heads, int head_dim, int block_size,
int num_mappings
) {
int mapping_idx = blockIdx.x;
if (mapping_idx >= num_mappings) return;
int src_block_id = block_mapping[mapping_idx * 2];
int dst_block_id = block_mapping[mapping_idx * 2 + 1];
int total_elements_per_block = block_size * num_heads * head_dim;
    // Each thread copies a strided slice of the block's elements so that
    // blocks larger than blockDim.x are still copied in full.
    for (int i = threadIdx.x; i < total_elements_per_block; i += blockDim.x) {
        int src_idx = src_block_id * total_elements_per_block + i;
        int dst_idx = dst_block_id * total_elements_per_block + i;
        new_key_cache[dst_idx] = key_cache[src_idx];
        new_value_cache[dst_idx] = value_cache[src_idx];
    }
}
std::tuple<torch::Tensor, torch::Tensor> copy_blocks_cuda(
torch::Tensor key_cache, torch::Tensor value_cache,
torch::Tensor block_mapping,
int num_heads, int head_dim, int block_size
) {
int num_mappings = block_mapping.size(0);
auto options = key_cache.options();
auto new_key_cache = torch::zeros_like(key_cache);
auto new_value_cache = torch::zeros_like(value_cache);
dim3 grid(num_mappings);
dim3 block(256); // Use 256 threads per block
copy_blocks_kernel<<<grid, block>>>(
key_cache.data_ptr<float>(),
value_cache.data_ptr<float>(),
new_key_cache.data_ptr<float>(),
new_value_cache.data_ptr<float>(),
block_mapping.data_ptr<int>(),
num_heads, head_dim, block_size,
num_mappings
);
cudaDeviceSynchronize();
return std::make_tuple(new_key_cache, new_value_cache);
}
2. Paged Attention Kernels
File: src/kernels/cuda/paged_attention.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void paged_attention_kernel(
float* output, // [num_seqs, seq_len, num_heads, head_dim]
const float* query, // [num_seqs, seq_len, num_heads, head_dim]
const float* key_cache, // [num_blocks, block_size, num_kv_heads, head_dim]
const float* value_cache,// [num_blocks, block_size, num_kv_heads, head_dim]
const int* block_tables, // [num_seqs, max_blocks_per_seq]
const int* context_lens, // [num_seqs]
const int num_kv_heads,
const int num_queries_per_kv,
const int head_dim,
const int block_size,
const int max_num_blocks_per_seq,
const int max_seq_len  // padded sequence length (stride) of the query/output tensors
) {
int seq_idx = blockIdx.x;
int q_head_idx = blockIdx.y;
int token_idx = blockIdx.z * blockDim.x + threadIdx.x;
if (seq_idx >= gridDim.x || q_head_idx >= gridDim.y || token_idx >= context_lens[seq_idx]) {
return;
}
// Get corresponding KV head index
int kv_head_idx = q_head_idx / num_queries_per_kv;
// Get query vector (query layout: [num_seqs, max_seq_len, num_heads(=gridDim.y), head_dim])
int query_idx = seq_idx * max_seq_len * gridDim.y * head_dim +
token_idx * gridDim.y * head_dim +
q_head_idx * head_dim;
// Shared memory for the current query
extern __shared__ float shared_mem[];
float* query_vec = shared_mem;
// Load query vector to shared memory
for (int d = 0; d < head_dim; d++) {
query_vec[d] = query[query_idx + d];
}
// Calculate which block and offset for this token
int block_idx = token_idx / block_size;
int block_offset = token_idx % block_size;
// Get physical block number from block table
int physical_block = block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
// Calculate the actual index in the cache
int cache_idx = physical_block * block_size * num_kv_heads * head_dim +
block_offset * num_kv_heads * head_dim +
kv_head_idx * head_dim;
// Perform attention computation
float sum = 0.0f;
for (int d = 0; d < head_dim; d++) {
sum += query_vec[d] * key_cache[cache_idx + d];
}
// Weight the value at this position (heavily simplified: a real kernel loops over
// all context positions and applies a properly normalized softmax)
float attention_weight = __expf(sum);
for (int d = 0; d < head_dim; d++) {
int output_idx = seq_idx * max_seq_len * gridDim.y * head_dim +
token_idx * gridDim.y * head_dim +
q_head_idx * head_dim + d;
output[output_idx] += attention_weight * value_cache[cache_idx + d];
}
}
torch::Tensor paged_attention_cuda(
torch::Tensor query, torch::Tensor key_cache, torch::Tensor value_cache,
torch::Tensor block_tables, torch::Tensor context_lens,
int num_kv_heads, int num_queries_per_kv
) {
int num_seqs = query.size(0);
int seq_len = query.size(1);
int num_heads = query.size(2);
int head_dim = query.size(3);
int block_size = key_cache.size(1);
int max_blocks_per_seq = block_tables.size(1);
auto output = torch::zeros_like(query);
dim3 grid(num_seqs, num_heads, (seq_len + 255) / 256); // 256 threads per block
dim3 block(256);
// Allocate shared memory for query vector
int shared_mem_size = head_dim * sizeof(float);
paged_attention_kernel<<<grid, block, shared_mem_size>>>(
output.data_ptr<float>(),
query.data_ptr<float>(),
key_cache.data_ptr<float>(),
value_cache.data_ptr<float>(),
block_tables.data_ptr<int>(),
context_lens.data_ptr<int>(),
num_kv_heads,
num_queries_per_kv,
head_dim,
block_size,
max_blocks_per_seq,
seq_len
);
cudaDeviceSynchronize();
return output;
}
3. Flash Attention Kernels (Simplified)
File: src/kernels/cuda/flash_attention.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#define TILE_SIZE 32 // Small tile for educational purposes
__global__ void flash_attention_kernel(
float* output,
const float* query,
const float* key,
const float* value,
const int* seq_lens,
const int num_seqs,
const int num_heads,
const int head_dim,
const int max_seq_len
) {
int seq_idx = blockIdx.x;
int head_idx = blockIdx.y;
if (seq_idx >= num_seqs || head_idx >= num_heads) return;
int current_seq_len = seq_lens[seq_idx];
// Shared memory for tiles
extern __shared__ float shared_mem[];
float* s_Q = shared_mem;
float* s_K = s_Q + TILE_SIZE * head_dim;
float* s_V = s_K + TILE_SIZE * head_dim;
float* s_scores = s_V + TILE_SIZE * head_dim;
// Process the sequence in tiles
for (int q_tile_start = 0; q_tile_start < current_seq_len; q_tile_start += TILE_SIZE) {
// Load Q tile to shared memory
for (int i = threadIdx.x; i < TILE_SIZE * head_dim; i += blockDim.x) {
int q_row = q_tile_start + i / head_dim;
int q_col = i % head_dim;
if (q_row < current_seq_len) {
int q_idx = seq_idx * max_seq_len * num_heads * head_dim +
q_row * num_heads * head_dim +
head_idx * head_dim + q_col;
s_Q[i] = query[q_idx];
} else {
s_Q[i] = 0.0f;
}
}
__syncthreads();
// For each K/V tile
for (int k_tile_start = 0; k_tile_start < current_seq_len; k_tile_start += TILE_SIZE) {
// Load K and V tiles
for (int i = threadIdx.x; i < TILE_SIZE * head_dim; i += blockDim.x) {
int k_row = k_tile_start + i / head_dim;
int k_col = i % head_dim;
if (k_row < current_seq_len) {
int k_idx = seq_idx * max_seq_len * num_heads * head_dim +
k_row * num_heads * head_dim +
head_idx * head_dim + k_col;
s_K[i] = key[k_idx];
s_V[i] = value[k_idx];
} else {
s_K[i] = 0.0f;
s_V[i] = 0.0f;
}
}
__syncthreads();
// Compute attention scores for this tile
for (int q_local = threadIdx.x; q_local < TILE_SIZE; q_local += blockDim.x) {
int q_global = q_tile_start + q_local;
if (q_global >= current_seq_len) continue;
float score_sum = 0.0f;
float max_score = -INFINITY;
// Compute scores against K tile
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
int k_global = k_tile_start + k_local;
if (k_global >= current_seq_len) continue;
float score = 0.0f;
for (int d = 0; d < head_dim; d++) {
int q_offset = q_local * head_dim + d;
int k_offset = k_local * head_dim + d;
score += s_Q[q_offset] * s_K[k_offset];
}
// Apply causal mask
if (k_global > q_global) score = -INFINITY;
// Update max for numerical stability
if (score > max_score) max_score = score;
s_scores[q_local * TILE_SIZE + k_local] = score;
}
// Apply softmax with numerical stability
float exp_sum = 0.0f;
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
int k_global = k_tile_start + k_local;
if (k_global >= current_seq_len || k_global > q_global) {
s_scores[q_local * TILE_SIZE + k_local] = 0.0f;
} else {
float score = s_scores[q_local * TILE_SIZE + k_local];
float exp_score = __expf(score - max_score);
s_scores[q_local * TILE_SIZE + k_local] = exp_score;
exp_sum += exp_score;
}
}
// Normalize scores
if (exp_sum > 0.0f) {
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
s_scores[q_local * TILE_SIZE + k_local] /= exp_sum;
}
}
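// NOTE: unlike true FlashAttention, this simplified kernel normalizes each K/V tile
// independently and overwrites the output below, so results are only exact when the
// whole sequence fits within a single TILE_SIZE tile. A faithful version keeps a
// running max and running sum across tiles.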
// Compute output = scores @ V
for (int d = 0; d < head_dim; d++) {
float output_val = 0.0f;
for (int k_local = 0; k_local < TILE_SIZE; k_local++) {
int v_offset = k_local * head_dim + d;
output_val += s_scores[q_local * TILE_SIZE + k_local] * s_V[v_offset];
}
int out_idx = seq_idx * max_seq_len * num_heads * head_dim +
q_global * num_heads * head_dim +
head_idx * head_dim + d;
output[out_idx] = output_val;
}
}
__syncthreads();
}
}
}
torch::Tensor flash_attention_cuda(
torch::Tensor query, torch::Tensor key, torch::Tensor value,
torch::Tensor seq_lens
) {
int num_seqs = query.size(0);
int num_heads = query.size(2);
int head_dim = query.size(3);
int max_seq_len = query.size(1);
auto output = torch::zeros_like(query);
dim3 grid(num_seqs, num_heads);
dim3 block(256);
// Shared memory for 3 tiles + scores matrix
int shared_mem_size = 3 * TILE_SIZE * head_dim * sizeof(float) +
TILE_SIZE * TILE_SIZE * sizeof(float);
flash_attention_kernel<<<grid, block, shared_mem_size>>>(
output.data_ptr<float>(),
query.data_ptr<float>(),
key.data_ptr<float>(),
value.data_ptr<float>(),
seq_lens.data_ptr<int>(),
num_seqs,
num_heads,
head_dim,
max_seq_len
);
cudaDeviceSynchronize();
return output;
}
Integration and Testing
1. Python-CUDA Interface
Create the Python interface for CUDA kernels:
File: src/kernels/api.py
"""
API for accessing both Python and CUDA implementations of kernels
"""
import torch
from typing import Optional
# Try to import CUDA extensions
try:
from . import yaie_kernels # This will be built from CUDA sources
CUDA_AVAILABLE = True
except ImportError:
print("CUDA extensions not available. Using Python implementations.")
CUDA_AVAILABLE = False
def attention_forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
use_cuda: bool = True, **kwargs):
"""Unified attention interface that can use CUDA or Python implementation"""
    if CUDA_AVAILABLE and use_cuda and query.is_cuda:
        # seq_lens is required by the CUDA kernel; default to full-length sequences
        seq_lens = kwargs.get('seq_lens')
        if seq_lens is None:
            seq_lens = torch.full((query.shape[0],), query.shape[1], dtype=torch.int32, device=query.device)
        return yaie_kernels.flash_attention_cuda(query, key, value, seq_lens)
    else:
        # Fallback: plain scaled dot-product attention in Python
        # query/key/value: [num_seqs, seq_len, num_heads, head_dim]
        q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, v).transpose(1, 2)
def paged_attention_forward(query: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, block_tables: torch.Tensor,
context_lens: torch.Tensor, use_cuda: bool = True, **kwargs):
"""Paged attention interface"""
if CUDA_AVAILABLE and use_cuda and query.is_cuda:
return yaie_kernels.paged_attention_cuda(
query, key_cache, value_cache, block_tables, context_lens,
kwargs.get('num_kv_heads', 1),
kwargs.get('num_queries_per_kv', 1)
)
else:
        # Python fallback would go here (a reference sketch is given after this file listing)
raise NotImplementedError("Paged attention Python fallback not implemented")
def copy_blocks(key_cache: torch.Tensor, value_cache: torch.Tensor,
block_mapping: torch.Tensor, use_cuda: bool = True, **kwargs):
"""Memory copy interface"""
if CUDA_AVAILABLE and use_cuda and key_cache.is_cuda:
return yaie_kernels.copy_blocks_cuda(
key_cache, value_cache, block_mapping,
kwargs.get('num_heads', 1),
kwargs.get('head_dim', 64),
kwargs.get('block_size', 16)
)
else:
# Python fallback would go here
raise NotImplementedError("Copy blocks Python fallback not implemented")
2. Testing Framework
Create comprehensive tests:
File: tests/test_kernels.py
import pytest
import torch
import numpy as np
from src.kernels.radix_tree import RadixTree
from src.kernels.kv_cache import KVCacheManager
from src.kernels.radix_attention import RadixAttentionBlock
from src.kernels.sampling import SamplingKernel
from src.kernels.api import attention_forward, paged_attention_forward, copy_blocks
class TestRadixTree:
def test_basic_insertion_and_search(self):
tree = RadixTree()
# Insert requests
tree.insert_request("req1", [1, 2, 3])
tree.insert_request("req2", [1, 2, 4]) # Shares prefix [1, 2]
tree.insert_request("req3", [5, 6, 7]) # No shared prefix
# Test shared prefixes
shared, length = tree.find_shared_prefixes([1, 2, 5])
assert "req1" in shared
assert "req2" in shared
assert length == 2 # Common prefix [1, 2]
def test_prefix_sharing_graph(self):
tree = RadixTree()
tree.insert_request("req1", [1, 2, 3])
tree.insert_request("req2", [1, 2, 4])
tree.insert_request("req3", [1, 5, 6])
        # get_shared_computation_graph is part of the full radix_tree.py
        # implementation (not shown in the excerpt above).
        graph = tree.get_shared_computation_graph()
# Should show shared computation at token [1]
assert graph["request_count"] == 3 # All requests start with root
class TestKVCacheManager:
def test_basic_allocation(self):
cache_manager = KVCacheManager(
num_blocks=100,
block_size=16,
num_heads=8,
head_dim=64,
dtype=torch.float16
)
# Allocate blocks for a request
        blocks = cache_manager.allocate_blocks("req1", 20)  # needs ceil(20 / 16) = 2 blocks
assert len(blocks) == 2
# Verify the blocks exist and have proper tensors
for block_id in blocks:
block = cache_manager.blocks[block_id]
assert block.keys is not None
assert block.values is not None
assert block.keys.shape == (16, 8, 64) # block_size, num_heads, head_dim
def test_block_reuse(self):
cache_manager = KVCacheManager(
num_blocks=10,
block_size=16,
num_heads=8,
head_dim=64
)
# Allocate all blocks
req_ids = [f"req{i}" for i in range(10)]
for req_id in req_ids:
cache_manager.allocate_blocks(req_id, 10)
assert len(cache_manager.free_block_list) == 0
# Free some blocks
cache_manager.free_blocks("req0")
cache_manager.free_blocks("req1")
assert len(cache_manager.free_block_list) == 2
assert 0 in cache_manager.free_block_list
assert 1 in cache_manager.free_block_list
class TestRadixAttention:
def test_basic_attention_forward(self):
hidden_size = 512
num_heads = 8
head_dim = hidden_size // num_heads
attention = RadixAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
head_dim=head_dim,
max_position_embeddings=256
)
batch_size = 2
seq_len = 10
        x = torch.randn(batch_size, seq_len, hidden_size)  # float32 to match the module's default parameter dtype
output = attention(x)
assert output.shape == (batch_size, seq_len, hidden_size)
assert not torch.isnan(output).any()
def test_attention_with_past_key_value(self):
hidden_size = 256
num_heads = 4
head_dim = hidden_size // num_heads
attention = RadixAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
head_dim=head_dim
)
batch_size = 1
seq_len = 5
x = torch.randn(batch_size, seq_len, hidden_size)
        # First forward pass over the prompt
        output1 = attention(x)
        assert output1.shape == (batch_size, seq_len, hidden_size)
        # The forward above returns only the output, so build a cached K/V of
        # matching shape by hand for the second, single-token pass.
        kv_shape = (batch_size, num_heads, seq_len, head_dim)
        past_kv = (torch.randn(kv_shape), torch.randn(kv_shape))
        next_token = torch.randn(batch_size, 1, hidden_size)
        output2 = attention(next_token, past_key_value=past_kv)
assert output2.shape == (batch_size, 1, hidden_size)
class TestSamplingKernel:
def test_temperature_sampling(self):
sampling = SamplingKernel()
# Create logits with one clear winner
logits = torch.tensor([[10.0, 1.0, 1.0, 1.0]]) # First token is dominant
# High temperature should allow other tokens
sampled_high_temp = sampling.sample(logits, temperature=2.0)
assert sampled_high_temp.shape == (1,)
# Low temperature should favor dominant token
sampled_low_temp = sampling.sample(logits, temperature=0.1)
assert sampled_low_temp[0] == 0 # Should pick the dominant token
def test_top_p_nucleus_sampling(self):
sampling = SamplingKernel()
        # Create logits where the last two tokens carry negligible probability mass
logits = torch.tensor([[2.0, 1.5, 1.0, -10.0, -10.0]])
# Top-p = 0.8 should exclude the last two tokens
sampled = sampling.sample(logits, top_p=0.8)
# Should be one of the first 3 tokens
assert sampled[0] in [0, 1, 2]
def test_top_k_sampling(self):
sampling = SamplingKernel()
# Create logits with clear ordering
logits = torch.tensor([[5.0, 4.0, 3.0, 2.0, 1.0]])
# Top-k = 2 should only consider first 2 tokens
sampled = sampling.sample(logits, top_k=2)
assert sampled[0] in [0, 1] # Should be one of top 2 tokens
class TestIntegration:
def test_full_inference_pipeline(self):
"""Test integration of all kernels in a simple pipeline"""
# This test would simulate a full inference step
batch_size = 2
seq_len = 10
hidden_size = 256
num_heads = 4
head_dim = hidden_size // num_heads
# Create input
x = torch.randn(batch_size, seq_len, hidden_size)
# Apply attention
attention = RadixAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
head_dim=head_dim
)
attn_output = attention(x)
assert attn_output.shape == (batch_size, seq_len, hidden_size)
# Apply sampling (on logits that would come from LM head)
logits = torch.randn(batch_size, 1000) # vocab_size = 1000
sampling = SamplingKernel()
sampled_tokens = sampling.sample(logits, temperature=0.7)
assert sampled_tokens.shape == (batch_size,)
if __name__ == "__main__":
pytest.main([__file__])
Building and Running
Setup Configuration
File: setup_kernels.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
# Check if CUDA is available
def get_extensions():
extensions = []
# Check if CUDA is available
try:
import torch
if torch.cuda.is_available():
extensions.append(
CUDAExtension(
name='yaie_kernels',
sources=[
'src/kernels/cuda/memory_ops.cu',
'src/kernels/cuda/paged_attention.cu',
'src/kernels/cuda/flash_attention.cu',
'src/kernels/cuda/radix_ops.cu',
'src/kernels/cuda/pybind.cpp', # Python bindings
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': ['-O3', '--use_fast_math', '-arch=sm_70']
}
)
)
    except Exception:
print("CUDA not available, building without CUDA extensions")
return extensions
setup(
name='yaie_kernels',
ext_modules=get_extensions(),
cmdclass={'build_ext': BuildExtension},
zip_safe=False,
)
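After adding the CUDA sources, compile the extension with python setup_kernels.py build_ext --inplace (or pip install -e .) and make sure the resulting yaie_kernels module is importable from src/kernels/ as api.py expects.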
Performance Optimization Guidelines
CUDA Optimization Tips
- Memory Coalescing: Ensure threads in a warp access consecutive memory addresses
- Shared Memory: Stage frequently reused data (e.g., query vectors, K/V tiles) in shared memory
- Occupancy: Maximize the number of active warps per SM
- Reduction Operations: Use efficient parallel reduction patterns for softmax sums and maxima
Profiling and Benchmarking
Create benchmarking tools:
import torch
import time
from torch.profiler import profile, record_function, ProfilerActivity
def benchmark_kernel(kernel_func, *args, **kwargs):
"""Benchmark a kernel function"""
# Warmup
for _ in range(3):
result = kernel_func(*args, **kwargs)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(10): # Run multiple times for average
result = kernel_func(*args, **kwargs)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / 10
return avg_time, result
def profile_kernel(kernel_func, *args, **kwargs):
"""Profile a kernel function"""
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
result = kernel_func(*args, **kwargs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
return result
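For example, to time the unified attention entry point from src/kernels/api.py on a CUDA device (shapes below are illustrative):
import torch
from src.kernels.api import attention_forward
q = torch.randn(2, 128, 8, 64, device="cuda")   # [num_seqs, seq_len, num_heads, head_dim]
k = torch.randn_like(q)
v = torch.randn_like(q)
seq_lens = torch.full((2,), 128, dtype=torch.int32, device="cuda")
avg_time, _ = benchmark_kernel(attention_forward, q, k, v, seq_lens=seq_lens)
print(f"attention_forward: {avg_time * 1e3:.3f} ms")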
This guide covers the core kernels needed to implement Mini-YAIE, following SGLang-style optimization principles while keeping the implementations educational.