// scirs2_neural/inference/mod.rs
//! LLM inference utilities: paged KV cache, block manager, prefix caching.
//!
//! This module implements a production-grade paged KV cache for LLM inference,
//! inspired by vLLM's PagedAttention architecture. Key features:
//!
//! - **Non-contiguous memory**: Keys/values stored in fixed-size pages, allowing
//!   flexible memory allocation without fragmentation.
//! - **Block manager**: Tracks page chains per sequence, handles allocation and eviction.
//! - **Prefix sharing**: Shared prefix cache enables KV reuse across requests with
//!   common prefixes (e.g., system prompts).
//! - **Paged attention**: Attention computation over non-contiguous page chains.
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────┐
//! │                  KvPagePool                         │
//! │  ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐      │
//! │  │ Page 0 │ │ Page 1 │ │ Page 2 │ │ Page 3 │ ...  │
//! │  │[bs,H,D]│ │[bs,H,D]│ │[bs,H,D]│ │[bs,H,D]│      │
//! │  └────────┘ └────────┘ └────────┘ └────────┘      │
//! └─────────────────────────────────────────────────────┘
//!            ↑
//!    BlockManager maps SeqId → [PageId, PageId, ...]
//! ```
//!
//! ## Example
//!
//! ```rust
//! use scirs2_neural::inference::{
//!     KvPageConfig, KvPagePool, BlockManagerConfig, BlockManager,
//!     PagedAttentionConfig, PagedAttentionForward,
//! };
//!
//! // Configure pages: block_size=16 tokens, 8 heads, head_dim=64
//! let page_cfg = KvPageConfig {
//!     block_size: 16,
//!     num_heads: 8,
//!     head_dim: 64,
//!     dtype_bytes: 4,
//! };
//! let pool = KvPagePool::<f32>::new(128, page_cfg);
//! let bm_cfg = BlockManagerConfig {
//!     max_sequences: 32,
//!     max_pages_per_seq: 64,
//! };
//! let _manager = BlockManager::<f32>::new(pool, bm_cfg);
//! ```
50pub mod block_manager;
51pub mod kv_page;
52pub mod paged_attention;
53pub mod speculative;
54
55pub use block_manager::{BlockManager, BlockManagerConfig, SharedPrefixCache};
56pub use kv_page::{KvPage, KvPageConfig, KvPagePool, PageId};
57pub use paged_attention::{PagedAttentionConfig, PagedAttentionForward};
58
59use crate::NeuralError;
60
61/// Errors specific to inference / paged KV cache operations.
62#[non_exhaustive]
63#[derive(Debug, thiserror::Error)]
64pub enum InferenceError {
65    /// No free pages remaining in the pool.
66    #[error("Out of memory: no free pages in pool")]
67    Oom,
68
69    /// The requested sequence ID is not tracked by the block manager.
70    #[error("Sequence not found: {0}")]
71    SequenceNotFound(u64),
72
73    /// A page was freed more than once.
74    #[error("Page {0} already freed")]
75    DoubleFree(PageId),
76
77    /// Page ID is out of bounds for the pool.
78    #[error("Page index {0} out of bounds (pool size {1})")]
79    PageOutOfBounds(PageId, usize),
80
81    /// Slot position exceeds page capacity.
82    #[error("Slot {slot} out of range for page capacity {capacity}")]
83    SlotOutOfRange {
84        /// Requested slot index.
85        slot: usize,
86        /// Page capacity (block_size).
87        capacity: usize,
88    },
89
90    /// Shape mismatch when writing key/value tensors.
91    #[error(
92        "Shape mismatch: expected [{expected_heads}, {expected_dim}], got [{got_heads}, {got_dim}]"
93    )]
94    KvShapeMismatch {
95        /// Expected number of heads.
96        expected_heads: usize,
97        /// Expected head dimension.
98        expected_dim: usize,
99        /// Actual number of heads.
100        got_heads: usize,
101        /// Actual head dimension.
102        got_dim: usize,
103    },
104
105    /// The sequence already has the maximum number of pages allocated.
106    #[error("Sequence {0} has reached max pages per sequence")]
107    MaxPagesExceeded(u64),
108
109    /// A neural layer error propagated from elsewhere in the crate.
110    #[error("Neural error: {0}")]
111    Neural(#[from] NeuralError),
112}
113
114/// Convenience alias for `Result<T, InferenceError>`.
115pub type InferenceResult<T> = Result<T, InferenceError>;