Skip to main content

unillm_kv/
lib.rs

1//! Hybrid KV cache combining RadixAttention and PagedAttention
2//!
3//! This crate implements UniLLM's innovative memory management system that combines:
4//! - SGLang's RadixAttention for token-level prefix sharing (L1)
5//! - vLLM's PagedAttention for block-level efficiency (L2)
6//! - Compressed storage for cold data (L3)
7
8mod hybrid_cache;
9mod gpu_memory;
10mod gpu_integrated_cache;
11
12pub use hybrid_cache::{
13    HybridKVCache, CacheHandle, CacheTier, KVTensorPair, TokenId, SequenceId,
14    RadixCache, AdaptiveCachePolicy, CachePolicy, HybridCacheStats, CacheAnalysis
15};
16
17pub use gpu_memory::{
18    GpuAwareMemoryPool, GpuMemoryBackend, GpuDevicePtr, GpuAllocation,
19    GpuMemoryError, GpuMemoryResult, GpuMemoryStats, GpuDeviceProperties,
20    CudaMemoryBackend, HipMemoryBackend
21};
22
23pub use gpu_integrated_cache::{
24    GpuIntegratedCache, GpuIntegratedCacheBuilder, GpuIntegratedCacheStats,
25    GpuBackendType
26};
27
28#[cfg(test)]
29mod tests;
30
31use std::collections::{HashMap, VecDeque, HashSet};
32use std::sync::{Arc, Mutex};
33use std::time::{Duration, Instant};
34
/// A single page of KV-cache memory together with its ownership bookkeeping.
///
/// A page starts out unowned; [`KvPage::allocate`] hands it to a sequence and
/// [`KvPage::free`] returns it to the unowned state.
#[derive(Debug, Clone)]
pub struct KvPage {
    /// Identifier of this page.
    pub id: u32,
    /// Device address at which this page's memory starts.
    pub device_ptr: u64,
    /// `true` while a sequence owns the page.
    pub allocated: bool,
    /// Id of the owning sequence, or `None` while the page is free.
    pub owner_seq_id: Option<u32>,
    /// Moment of the most recent allocation, or `None` while the page is free.
    pub allocated_at: Option<Instant>,
}

impl KvPage {
    /// Builds an unowned page with the given id and device address.
    pub fn new(id: u32, device_ptr: u64) -> Self {
        KvPage {
            allocated: false,
            owner_seq_id: None,
            allocated_at: None,
            id,
            device_ptr,
        }
    }

    /// Hands the page to sequence `seq_id` and stamps the allocation time.
    ///
    /// Calling this on an already-owned page simply overwrites the owner and
    /// the timestamp.
    pub fn allocate(&mut self, seq_id: u32) {
        let now = Instant::now();
        self.owner_seq_id = Some(seq_id);
        self.allocated_at = Some(now);
        self.allocated = true;
    }

    /// Releases the page: the owner and timestamp are cleared while the id
    /// and device address are kept.
    pub fn free(&mut self) {
        // A freed page is indistinguishable from a freshly constructed one.
        *self = Self::new(self.id, self.device_ptr);
    }
}
76
77/// A block of pages (typically 16 pages per block)
78#[derive(Debug)]
79pub struct KvBlock {
80    /// Block ID
81    pub id: u32,
82    /// Pages in this block
83    pub pages: Vec<KvPage>,
84    /// Number of allocated pages in this block
85    pub allocated_count: usize,
86    /// Whether the block is fully allocated
87    pub is_full: bool,
88}
89
90impl KvBlock {
91    /// Create a new KV block
92    pub fn new(id: u32, page_count: usize, base_device_ptr: u64, page_size: usize) -> Self {
93        let mut pages = Vec::new();
94        for i in 0..page_count {
95            let page_id = id * page_count as u32 + i as u32;
96            let device_ptr = base_device_ptr + (i * page_size) as u64;
97            pages.push(KvPage::new(page_id, device_ptr));
98        }
99        
100        Self {
101            id,
102            pages,
103            allocated_count: 0,
104            is_full: false,
105        }
106    }
107    
108    /// Allocate a page from this block
109    pub fn allocate_page(&mut self, seq_id: u32) -> Option<u32> {
110        if self.is_full {
111            return None;
112        }
113
114        let pages_len = self.pages.len();
115        for page in &mut self.pages {
116            if !page.allocated {
117                let page_id = page.id;
118                page.allocate(seq_id);
119                self.allocated_count += 1;
120                if self.allocated_count == pages_len {
121                    self.is_full = true;
122                }
123                return Some(page_id);
124            }
125        }
126
127        None
128    }
129    
130    /// Free a page in this block
131    pub fn free_page(&mut self, page_id: u32) -> bool {
132        for page in &mut self.pages {
133            if page.id == page_id && page.allocated {
134                page.free();
135                self.allocated_count -= 1;
136                self.is_full = false;
137                return true;
138            }
139        }
140        false
141    }
142}
143
/// Per-sequence bookkeeping for KV cache management.
///
/// Records which pages a sequence owns and the length limits that drive how
/// many pages it requires.
#[derive(Debug, Clone)]
pub struct KvSequence {
    /// Sequence ID
    pub seq_id: u32,
    /// Pages allocated to this sequence
    pub pages: Vec<u32>,
    /// Current sequence length
    pub length: usize,
    /// Maximum sequence length
    pub max_length: usize,
    /// Number of tokens processed
    pub tokens_processed: usize,
    /// Whether the sequence is active
    pub is_active: bool,
    /// Creation timestamp
    pub created_at: Instant,
}

impl KvSequence {
    /// Create a new, active sequence that owns no pages yet.
    ///
    /// `length` and `tokens_processed` start at zero and `created_at` is set
    /// to the current instant.
    pub fn new(seq_id: u32, max_length: usize) -> Self {
        Self {
            seq_id,
            pages: Vec::new(),
            length: 0,
            max_length,
            tokens_processed: 0,
            is_active: true,
            created_at: Instant::now(),
        }
    }

    /// Append `page_id` to the list of pages owned by this sequence.
    pub fn add_page(&mut self, page_id: u32) {
        self.pages.push(page_id);
    }

    /// Remove the first occurrence of `page_id` from this sequence.
    ///
    /// Returns `true` if the page was present and removed, `false` otherwise.
    pub fn remove_page(&mut self, page_id: u32) -> bool {
        match self.pages.iter().position(|&id| id == page_id) {
            Some(pos) => {
                self.pages.remove(pos);
                true
            }
            None => false,
        }
    }

    /// Check whether the owned pages hold fewer than `max_length` tokens.
    ///
    /// `page_size` is the number of tokens a single page can hold.
    pub fn needs_more_pages(&self, page_size: usize) -> bool {
        let current_capacity = self.pages.len() * page_size;
        current_capacity < self.max_length
    }

    /// Get the number of additional pages required to reach `max_length`.
    ///
    /// Returns `0` when the owned pages already cover `max_length`.
    ///
    /// # Panics
    /// Panics if `page_size` is zero.
    pub fn pages_needed(&self, page_size: usize) -> usize {
        let current_capacity = self.pages.len() * page_size;
        // A sequence may own more capacity than `max_length` (e.g. after it
        // has been given extra pages), so the shortfall must saturate at zero
        // instead of underflowing.
        let needed_capacity = self.max_length.saturating_sub(current_capacity);
        (needed_capacity + page_size - 1) / page_size // Ceiling division
    }
}
205
206/// Paged KV allocator implementation
207pub struct PagedKvAllocator {
208    /// All blocks in the allocator
209    blocks: Vec<KvBlock>,
210    /// Free pages available for allocation
211    free_pages: VecDeque<u32>,
212    /// Active sequences
213    sequences: HashMap<u32, KvSequence>,
214    /// Page size in tokens
215    page_size: usize,
216    /// Pages per block
217    pages_per_block: usize,
218    /// Total number of pages
219    total_pages: usize,
220    /// Number of allocated pages
221    allocated_pages: usize,
222    /// Base device pointer for the first page
223    base_device_ptr: u64,
224    /// Next sequence ID
225    next_seq_id: u32,
226}
227
228impl PagedKvAllocator {
229    /// Create a new paged KV allocator
230    /// 
231    /// # Arguments
232    /// * `total_pages` - Total number of pages to allocate
233    /// * `page_size` - Size of each page in tokens
234    /// * `pages_per_block` - Number of pages per block
235    /// * `base_device_ptr` - Base device pointer for memory allocation
236    pub fn new(
237        total_pages: usize,
238        page_size: usize,
239        pages_per_block: usize,
240        base_device_ptr: u64,
241    ) -> Self {
242        let num_blocks = (total_pages + pages_per_block - 1) / pages_per_block;
243        let mut blocks = Vec::new();
244        let mut free_pages = VecDeque::new();
245        
246        for block_id in 0..num_blocks {
247            let pages_in_block = std::cmp::min(pages_per_block, total_pages - block_id * pages_per_block);
248            let block_base_ptr = base_device_ptr + (block_id * pages_per_block * page_size) as u64;
249            
250            let block = KvBlock::new(block_id as u32, pages_in_block, block_base_ptr, page_size);
251            
252            // Add all pages to free list
253            for page in &block.pages {
254                free_pages.push_back(page.id);
255            }
256            
257            blocks.push(block);
258        }
259        
260        Self {
261            blocks,
262            free_pages,
263            sequences: HashMap::new(),
264            page_size,
265            pages_per_block,
266            total_pages,
267            allocated_pages: 0,
268            base_device_ptr,
269            next_seq_id: 0,
270        }
271    }
272    
273    /// Allocate pages for a new sequence
274    /// 
275    /// # Arguments
276    /// * `max_length` - Maximum sequence length
277    /// 
278    /// # Returns
279    /// Sequence ID and allocated page IDs
280    pub fn allocate_sequence(&mut self, max_length: usize) -> Result<(u32, Vec<u32>), Box<dyn std::error::Error>> {
281        let seq_id = self.next_seq_id;
282        self.next_seq_id += 1;
283        
284        let pages_needed = (max_length + self.page_size - 1) / self.page_size;
285        
286        if self.free_pages.len() < pages_needed {
287            return Err(format!("Not enough free pages: need {}, have {}", pages_needed, self.free_pages.len()).into());
288        }
289        
290        let mut allocated_pages = Vec::new();
291        
292        // Allocate pages
293        for _ in 0..pages_needed {
294            if let Some(page_id) = self.free_pages.pop_front() {
295                // Find the block containing this page
296                let block_id = page_id / self.pages_per_block as u32;
297                if let Some(block) = self.blocks.get_mut(block_id as usize) {
298                    if let Some(page_id) = block.allocate_page(seq_id) {
299                        allocated_pages.push(page_id);
300                        self.allocated_pages += 1;
301                    }
302                }
303            }
304        }
305        
306        // Create sequence
307        let mut sequence = KvSequence::new(seq_id, max_length);
308        sequence.pages = allocated_pages.clone();
309        self.sequences.insert(seq_id, sequence);
310        
311        println!("Allocated sequence {} with {} pages (max_length: {})", seq_id, allocated_pages.len(), max_length);
312        
313        Ok((seq_id, allocated_pages))
314    }
315    
316    /// Free a sequence and return its pages to the free pool
317    pub fn free_sequence(&mut self, seq_id: u32) -> Result<(), Box<dyn std::error::Error>> {
318        if let Some(sequence) = self.sequences.remove(&seq_id) {
319            let pages_count = sequence.pages.len();
320            for page_id in sequence.pages {
321                // Find the block containing this page
322                let block_id = page_id / self.pages_per_block as u32;
323                if let Some(block) = self.blocks.get_mut(block_id as usize) {
324                    if block.free_page(page_id) {
325                        self.free_pages.push_back(page_id);
326                        self.allocated_pages -= 1;
327                    }
328                }
329            }
330            
331            println!("Freed sequence {} with {} pages", seq_id, pages_count);
332        }
333        
334        Ok(())
335    }
336    
337    /// Extend a sequence with more pages
338    pub fn extend_sequence(&mut self, seq_id: u32, additional_length: usize) -> Result<Vec<u32>, Box<dyn std::error::Error>> {
339        if let Some(sequence) = self.sequences.get_mut(&seq_id) {
340            let additional_pages = (additional_length + self.page_size - 1) / self.page_size;
341            
342            if self.free_pages.len() < additional_pages {
343                return Err(format!("Not enough free pages for extension: need {}, have {}", additional_pages, self.free_pages.len()).into());
344            }
345            
346            let mut new_pages = Vec::new();
347            
348            // Allocate additional pages
349            for _ in 0..additional_pages {
350                if let Some(page_id) = self.free_pages.pop_front() {
351                    let block_id = page_id / self.pages_per_block as u32;
352                    if let Some(block) = self.blocks.get_mut(block_id as usize) {
353                        if let Some(page_id) = block.allocate_page(seq_id) {
354                            new_pages.push(page_id);
355                            sequence.pages.push(page_id);
356                            self.allocated_pages += 1;
357                        }
358                    }
359                }
360            }
361            
362            println!("Extended sequence {} with {} additional pages", seq_id, new_pages.len());
363            Ok(new_pages)
364        } else {
365            Err(format!("Sequence {} not found", seq_id).into())
366        }
367    }
368    
369    /// Get sequence information
370    pub fn get_sequence(&self, seq_id: u32) -> Option<&KvSequence> {
371        self.sequences.get(&seq_id)
372    }
373    
374    /// Get all active sequences
375    pub fn get_active_sequences(&self) -> Vec<&KvSequence> {
376        self.sequences.values().filter(|s| s.is_active).collect()
377    }
378    
379    /// Get memory usage statistics
380    pub fn get_stats(&self) -> KvAllocatorStats {
381        KvAllocatorStats {
382            total_pages: self.total_pages,
383            allocated_pages: self.allocated_pages,
384            free_pages: self.free_pages.len(),
385            active_sequences: self.sequences.len(),
386            memory_usage_percent: (self.allocated_pages as f64 / self.total_pages as f64) * 100.0,
387        }
388    }
389    
390    /// Defragment memory by moving sequences to contiguous pages
391    pub fn defragment(&mut self) -> Result<(), Box<dyn std::error::Error>> {
392        println!("Starting memory defragmentation...");
393        
394        // For now, this is a placeholder implementation
395        // In a real implementation, we would:
396        // 1. Identify fragmented sequences
397        // 2. Move pages to create contiguous blocks
398        // 3. Update device pointers accordingly
399        
400        println!("Memory defragmentation completed");
401        Ok(())
402    }
403}
404
/// Snapshot of the allocator's page and sequence counters.
#[derive(Debug, Clone)]
pub struct KvAllocatorStats {
    /// Number of pages managed by the allocator.
    pub total_pages: usize,
    /// Number of pages currently owned by sequences.
    pub allocated_pages: usize,
    /// Number of pages currently on the free list.
    pub free_pages: usize,
    /// Number of sequences tracked by the allocator.
    pub active_sequences: usize,
    /// Allocated share of all pages, expressed in percent.
    pub memory_usage_percent: f64,
}

impl std::fmt::Display for KvAllocatorStats {
    /// Renders the snapshot as a single human-readable line, with the usage
    /// percentage shown to one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_pages,
            allocated_pages,
            free_pages,
            active_sequences,
            memory_usage_percent,
        } = self;

        write!(
            f,
            "KV Allocator Stats: {}/{} pages allocated ({:.1}%), {} free pages, {} active sequences",
            allocated_pages, total_pages, memory_usage_percent, free_pages, active_sequences
        )
    }
}
420}