Radix Attention Module (kernels/radix_attention.py)
1. Concept: Connecting the Dots
We have a Radix Tree (prefix matching) and a Paged KV Cache (memory management). The RadixAttentionWithPagedKVCache class is the glue that runs on the CPU (Python side) to manage these resources before we launch the GPU kernels.
It doesn’t run the attention math (that’s the CUDA kernel’s job). Instead, it manages the metadata:
- “Request A needs to append ‘cat’ to its sequence.”
- “Does ‘cat’ already exist in the Radix Tree?”
- “If yes, reuse the block.”
- “If no, allocate a new block.”
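The decision the class keeps making, written out as a small Python sketch (the helper names `has_prefix`, `get_block`, and `insert` are illustrative, not the actual APIs built earlier):

```python
def handle_new_tokens(radix_tree, kv_cache_manager, token_path):
    """Decide whether to reuse a cached block or allocate a fresh one."""
    # has_prefix / get_block / insert are hypothetical helper names.
    if radix_tree.has_prefix(token_path):
        return radix_tree.get_block(token_path)     # hit: reuse the cached block
    block = kv_cache_manager.allocate_blocks()      # miss: allocate a new block
    radix_tree.insert(token_path, block)            # register it for future reuse
    return block
```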
2. Implementation Guide
Open src/kernels/radix_attention.py.
Step 1: Initialization
You need to initialize the two sub-components we built earlier.
```python
class RadixAttentionWithPagedKVCache:
    def __init__(self, ...):
        # ...
        self.radix_tree = RadixTree()
        self.kv_cache_manager = KVCacheManager(...)
```
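If you want something more concrete than the skeleton above, a possible constructor is sketched below; the parameter names (`num_blocks`, `block_size`, `num_heads`, `head_dim`) are assumptions about what `KVCacheManager` needs, not its real signature.

```python
class RadixAttentionWithPagedKVCache:
    # RadixTree and KVCacheManager come from the modules built earlier;
    # the constructor parameters here are assumed, not the actual API.
    def __init__(self, num_blocks: int, block_size: int, num_heads: int, head_dim: int):
        # Prefix index: which token paths are already cached, and by whom.
        self.radix_tree = RadixTree()
        # Physical storage: fixed-size blocks of K/V slots on the GPU.
        self.kv_cache_manager = KVCacheManager(
            num_blocks=num_blocks,
            block_size=block_size,
            num_heads=num_heads,
            head_dim=head_dim,
        )
```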
Step 2: append_slot (The Critical Logic)
This method is called when we want to add a new token (or tokens) to a request.
Signature:
```python
def append_slot(self, key: torch.Tensor, value: torch.Tensor, request_id: str):
```
- `key` / `value`: The computed K/V tensors for the new token(s).
Algorithm:
- Check Tree: Use `self.radix_tree` to see if this `(request_id + new_token)` path already exists.
  - Note: In a real system, we check before computing K/V. Here, we might just be managing the cache storage.
- Allocate: If we need new space, call `self.kv_cache_manager.allocate_blocks()`.
- Store: We need to perform the copy.
  - Ideally, we just return the indices of where to write, and the GPU kernel does the writing.
  - For this Python simulation, you might simulate the copy or just track the metadata (see the sketch below).
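A minimal sketch of what `append_slot` could look like in this Python-side version; the tensor layout, the `allocate_blocks` argument list, and the return value are assumptions, and the actual copy is left to the GPU kernel.

```python
def append_slot(self, key: torch.Tensor, value: torch.Tensor, request_id: str):
    """Reserve cache slots for the new token(s) of this request."""
    # Assumed layout: key / value are [num_new_tokens, num_heads, head_dim].
    num_new_tokens = key.shape[0]

    # In a real system the radix-tree check happens *before* K/V are computed;
    # at this point we only need storage for tokens that were actually computed.
    # (Passing request_id and a count to allocate_blocks is an assumption.)
    slot_ids = self.kv_cache_manager.allocate_blocks(request_id, num_new_tokens)

    # Ideally we just return the physical indices and let the GPU kernel write
    # key/value into them; in this Python simulation we only track the metadata.
    return slot_ids
```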
Step 3: get_kv_cache
The scheduler asks: “I am about to run requests [R1, R2]. Where is their data?”
Algorithm:
- Loop through `request_ids`.
- For each, ask `self.kv_cache_manager` for its block table (a list of integers).
- Pack these lists into a single tensor `block_tables`.
- Return `block_tables` to the Engine (see the sketch below).
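One way to assemble `block_tables`, assuming a `get_block_table(request_id)` accessor on the cache manager (a hypothetical name) and -1 as the padding value for short tables:

```python
def get_kv_cache(self, request_ids: list[str]) -> torch.Tensor:
    """Pack each request's block table into one [batch, max_blocks] tensor."""
    # get_block_table is a hypothetical accessor returning a list of block ids.
    tables = [self.kv_cache_manager.get_block_table(rid) for rid in request_ids]

    # Requests own different numbers of blocks, so pad to the longest table
    # before stacking; -1 marks unused entries for the kernel to skip.
    max_len = max(len(t) for t in tables)
    padded = [t + [-1] * (max_len - len(t)) for t in tables]
    return torch.tensor(padded, dtype=torch.int32)
```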
Step 4: free_request
When a request is done:
- `self.radix_tree.remove_request(request_id)` (Decrement ref counts).
- `self.kv_cache_manager.free_blocks(request_id)` (Reclaim memory).
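Put together, the method is little more than those two calls:

```python
def free_request(self, request_id: str):
    """Release everything this finished request was holding."""
    self.radix_tree.remove_request(request_id)      # decrement prefix ref counts
    self.kv_cache_manager.free_blocks(request_id)   # return blocks to the free pool
```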
3. The RadixAttentionBlock (Model Layer)
The `RadixAttentionBlock` class is the PyTorch module that sits in the model.
Task:
In `forward()`:
- Compute Q, K, V projections.
- Apply RoPE (Rotary Position Embeddings) to Q and K.
- If Prefill: Use Flash Attention (or a standard attention) on the new tokens.
- If Decode:
  - Call `append_slot` to save the new K/V.
  - Call `paged_attention_kernel` (the CUDA op) to attend to the entire history using the block tables.
Exercise: Since we don’t have the full model weight loading for this specific block, focus on the logic flow in the comments.
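A comment-heavy sketch of that flow, for reference. Here `self.cache` is assumed to be the `RadixAttentionWithPagedKVCache` instance held by the block, and `apply_rope`, `flash_attention`, and the `paged_attention_kernel` signature are placeholders for whatever the surrounding code actually provides.

```python
def forward(self, hidden_states, request_ids, is_prefill: bool, block_tables=None):
    # 1. Project the hidden states into Q, K, V.
    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states)
    v = self.v_proj(hidden_states)

    # 2. Apply rotary position embeddings to Q and K (apply_rope is a placeholder).
    q, k = apply_rope(q, k)

    if is_prefill:
        # 3a. Prefill: attend over the new tokens with Flash (or standard) attention.
        out = flash_attention(q, k, v)  # placeholder call
    else:
        # 3b. Decode: store each request's single new K/V, then let the paged
        #     attention CUDA kernel walk the full history via the block tables.
        for i, rid in enumerate(request_ids):
            self.cache.append_slot(k[i : i + 1], v[i : i + 1], rid)
        out = paged_attention_kernel(q, self.cache, block_tables)  # placeholder signature

    # 4. Project back to the model dimension.
    return self.o_proj(out)
```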