// scirs2_neural/inference/mod.rs
//! LLM inference utilities: paged KV cache, block manager, prefix caching.
//!
//! This module implements a production-grade paged KV cache for LLM inference,
//! inspired by vLLM's PagedAttention architecture. Key features:
//!
//! - **Non-contiguous memory**: Keys/values stored in fixed-size pages, allowing
//!   flexible memory allocation without fragmentation.
//! - **Block manager**: Tracks page chains per sequence, handles allocation and eviction.
//! - **Prefix sharing**: Shared prefix cache enables KV reuse across requests with
//!   common prefixes (e.g., system prompts).
//! - **Paged attention**: Attention computation over non-contiguous page chains.
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────┐
//! │                     KvPagePool                      │
//! │  ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐        │
//! │  │ Page 0 │ │ Page 1 │ │ Page 2 │ │ Page 3 │  ...   │
//! │  │[bs,H,D]│ │[bs,H,D]│ │[bs,H,D]│ │[bs,H,D]│        │
//! │  └────────┘ └────────┘ └────────┘ └────────┘        │
//! └─────────────────────────────────────────────────────┘
//!                          ↑
//!       BlockManager maps SeqId → [PageId, PageId, ...]
//! ```
//!
//! ## Example
//!
//! ```rust
//! use scirs2_neural::inference::{
//!     KvPageConfig, KvPagePool, BlockManagerConfig, BlockManager,
//!     PagedAttentionConfig, PagedAttentionForward,
//! };
//!
//! // Configure pages: block_size=16 tokens, 8 heads, head_dim=64
//! let page_cfg = KvPageConfig {
//!     block_size: 16,
//!     num_heads: 8,
//!     head_dim: 64,
//!     dtype_bytes: 4,
//! };
//! let pool = KvPagePool::<f32>::new(128, page_cfg);
//! let bm_cfg = BlockManagerConfig {
//!     max_sequences: 32,
//!     max_pages_per_seq: 64,
//! };
//! let _manager = BlockManager::<f32>::new(pool, bm_cfg);
//! ```

50pub mod block_manager;
51pub mod kv_page;
52pub mod paged_attention;
53pub mod speculative;
54
55pub use block_manager::{BlockManager, BlockManagerConfig, SharedPrefixCache};
56pub use kv_page::{KvPage, KvPageConfig, KvPagePool, PageId};
57pub use paged_attention::{PagedAttentionConfig, PagedAttentionForward};
58
59use crate::NeuralError;
60
61/// Errors specific to inference / paged KV cache operations.
62#[non_exhaustive]
63#[derive(Debug, thiserror::Error)]
64pub enum InferenceError {
65 /// No free pages remaining in the pool.
66 #[error("Out of memory: no free pages in pool")]
67 Oom,
68
69 /// The requested sequence ID is not tracked by the block manager.
70 #[error("Sequence not found: {0}")]
71 SequenceNotFound(u64),
72
73 /// A page was freed more than once.
74 #[error("Page {0} already freed")]
75 DoubleFree(PageId),
76
77 /// Page ID is out of bounds for the pool.
78 #[error("Page index {0} out of bounds (pool size {1})")]
79 PageOutOfBounds(PageId, usize),
80
81 /// Slot position exceeds page capacity.
82 #[error("Slot {slot} out of range for page capacity {capacity}")]
83 SlotOutOfRange {
84 /// Requested slot index.
85 slot: usize,
86 /// Page capacity (block_size).
87 capacity: usize,
88 },
89
90 /// Shape mismatch when writing key/value tensors.
91 #[error(
92 "Shape mismatch: expected [{expected_heads}, {expected_dim}], got [{got_heads}, {got_dim}]"
93 )]
94 KvShapeMismatch {
95 /// Expected number of heads.
96 expected_heads: usize,
97 /// Expected head dimension.
98 expected_dim: usize,
99 /// Actual number of heads.
100 got_heads: usize,
101 /// Actual head dimension.
102 got_dim: usize,
103 },
104
105 /// The sequence already has the maximum number of pages allocated.
106 #[error("Sequence {0} has reached max pages per sequence")]
107 MaxPagesExceeded(u64),
108
109 /// A neural layer error propagated from elsewhere in the crate.
110 #[error("Neural error: {0}")]
111 Neural(#[from] NeuralError),
112}
113
114/// Convenience alias for `Result<T, InferenceError>`.
115pub type InferenceResult<T> = Result<T, InferenceError>;