ruvector_sparse_inference/integration/
ruvllm.rs

//! RuvLLM InferenceBackend integration
//!
//! This module provides a sparse inference backend that integrates with
//! the RuvLLM language model framework for efficient text generation.
//!
//! # Example
//!
//! ```rust,ignore
//! use std::path::Path;
//! use ruvector_sparse_inference::integration::{InferenceBackend, SparseInferenceBackend};
//!
//! let mut backend = SparseInferenceBackend::from_gguf(Path::new("llama-7b.gguf"))?;
//! let output = backend.generate(&[1, 2, 3], 100)?;
//! ```

use crate::{
    config::{ActivationType, SparsityConfig, CacheConfig},
    error::{Result, SparseInferenceError},
    model::{GgufParser, InferenceConfig, ModelMetadata},
    memory::NeuronCache,
    predictor::{LowRankPredictor, Predictor},
    sparse::SparseFfn,
};

/// KV Cache for autoregressive generation
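///
/// A minimal usage sketch (the layer count and dimensions below are
/// illustrative, not tied to any particular model):
///
/// ```rust,ignore
/// let mut cache = KVCache::new(32, 2048, 128);
/// cache.append(0, vec![0.0; 128], vec![0.0; 128]);
/// assert_eq!(cache.len(), 1);
/// cache.clear();
/// assert!(cache.is_empty());
/// ```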
#[derive(Debug)]
pub struct KVCache {
    /// Key cache per layer
    keys: Vec<Vec<Vec<f32>>>,
    /// Value cache per layer
    values: Vec<Vec<Vec<f32>>>,
    /// Maximum sequence length
    max_length: usize,
    /// Current sequence length
    current_length: usize,
}

impl KVCache {
    /// Create a new KV cache.
    ///
    /// The head dimension parameter is accepted for API symmetry but is not
    /// used by this simple `Vec`-backed cache, which grows on demand.
    pub fn new(num_layers: usize, max_length: usize, _head_dim: usize) -> Self {
        Self {
            keys: vec![Vec::new(); num_layers],
            values: vec![Vec::new(); num_layers],
            max_length,
            current_length: 0,
        }
    }

    /// Clear the cache
    pub fn clear(&mut self) {
        for layer_keys in &mut self.keys {
            layer_keys.clear();
        }
        for layer_values in &mut self.values {
            layer_values.clear();
        }
        self.current_length = 0;
    }

    /// Get current sequence length
    pub fn len(&self) -> usize {
        self.current_length
    }

    /// Check if cache is empty
    pub fn is_empty(&self) -> bool {
        self.current_length == 0
    }

    /// Append key-value pair for a layer
    pub fn append(&mut self, layer: usize, key: Vec<f32>, value: Vec<f32>) {
        if layer < self.keys.len() {
            self.keys[layer].push(key);
            self.values[layer].push(value);
            if layer == 0 {
                self.current_length += 1;
            }
        }
    }
}

/// Generation configuration
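///
/// A minimal sketch of overriding a few fields while keeping the remaining
/// defaults (the values shown are illustrative):
///
/// ```rust,ignore
/// let config = GenerationConfig {
///     max_new_tokens: 32,
///     temperature: 0.8,
///     ..Default::default()
/// };
/// ```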
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum new tokens to generate
    pub max_new_tokens: usize,
    /// Temperature for sampling
    pub temperature: f32,
    /// Top-K sampling
    pub top_k: usize,
    /// Top-P (nucleus) sampling
    pub top_p: f32,
    /// Repetition penalty
    pub repetition_penalty: f32,
    /// Stop tokens
    pub stop_tokens: Vec<u32>,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            max_new_tokens: 100,
            temperature: 0.7,
            top_k: 50,
            top_p: 0.9,
            repetition_penalty: 1.1,
            stop_tokens: vec![2], // Default EOS token
        }
    }
}

/// Generation statistics
#[derive(Debug, Clone, Default)]
pub struct GenerationStats {
    /// Total tokens generated
    pub tokens_generated: usize,
    /// Average inference time per token (ms)
    pub avg_token_time_ms: f64,
    /// Average sparsity ratio
    pub avg_sparsity: f64,
    /// Total inference time (ms)
    pub total_time_ms: f64,
}

/// Sparse inference backend for RuvLLM integration
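///
/// A minimal end-to-end sketch using a small, synthetic configuration
/// (the dimensions and token ids are illustrative only):
///
/// ```rust,ignore
/// // Small configuration mirroring the unit tests below; the very low
/// // sparsity ratio keeps almost all neurons active.
/// let mut backend = SparseInferenceBackend::new(2, 64, 256, 1000, 0.001)?;
/// let config = GenerationConfig { max_new_tokens: 8, ..Default::default() };
/// let output = backend.generate(&[1, 2, 3], &config)?;
/// assert!(output.len() >= 3);
/// ```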
pub struct SparseInferenceBackend {
    /// Model metadata
    metadata: ModelMetadata,
    /// Layer predictors (one per layer)
    predictors: Vec<LowRankPredictor>,
    /// Layer FFNs (one per layer)
    ffns: Vec<SparseFfn>,
    /// Neuron cache for hot neurons
    neuron_cache: NeuronCache,
    /// Inference configuration
    config: InferenceConfig,
    /// Generation statistics
    stats: GenerationStats,
    /// Vocabulary size
    vocab_size: usize,
}

impl SparseInferenceBackend {
    /// Create a new sparse inference backend.
    ///
    /// `sparsity_ratio` is the target fraction of intermediate neurons left
    /// inactive per layer; the remaining neurons are selected via top-K.
    pub fn new(
        num_layers: usize,
        hidden_dim: usize,
        intermediate_dim: usize,
        vocab_size: usize,
        sparsity_ratio: f32,
    ) -> Result<Self> {
        // Use top-K selection based on the sparsity ratio for reliable activation
        let target_active = ((1.0 - sparsity_ratio) * intermediate_dim as f32).max(1.0) as usize;
        let sparsity_config = SparsityConfig {
            threshold: None,
            top_k: Some(target_active),
            target_sparsity: Some(sparsity_ratio),
            adaptive_threshold: false,
        };

        let cache_config = CacheConfig {
            hot_neuron_fraction: 0.2, // 20% hot neurons
            max_cold_cache_size: 1000,
            cache_strategy: crate::config::CacheStrategy::Lru,
            hot_neuron_count: (intermediate_dim as f32 * 0.2) as usize,
            lru_cache_size: 4096,
            use_mmap: false,
            hot_threshold: 0.5,
        };

        // Create predictors and FFNs for each layer
        let mut predictors = Vec::with_capacity(num_layers);
        let mut ffns = Vec::with_capacity(num_layers);

        for _ in 0..num_layers {
            let predictor = LowRankPredictor::new(
                hidden_dim,
                intermediate_dim,
                intermediate_dim / 32, // low-rank bottleneck dimension
                sparsity_config.clone(),
            )?;
            predictors.push(predictor);

            let ffn = SparseFfn::new(
                hidden_dim,
                intermediate_dim,
                hidden_dim,
                ActivationType::Silu, // Llama uses SiLU
            )?;
            ffns.push(ffn);
        }

        let neuron_cache = NeuronCache::new(intermediate_dim, cache_config);

        let metadata = ModelMetadata {
            hidden_size: hidden_dim,
            intermediate_size: intermediate_dim,
            num_layers,
            num_heads: (hidden_dim / 64).max(1), // Assuming head_dim = 64; at least one head
            num_key_value_heads: None,
            vocab_size,
            max_position_embeddings: 4096,
            architecture: crate::model::ModelArchitecture::Llama,
            quantization: None,
            rope_theta: Some(10000.0),
            rope_scaling: None,
        };

        Ok(Self {
            metadata,
            predictors,
            ffns,
            neuron_cache,
            config: InferenceConfig::default(),
            stats: GenerationStats::default(),
            vocab_size,
        })
    }

    /// Create from a GGUF model file
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_gguf(path: &std::path::Path) -> Result<Self> {
        use std::fs;

        let data = fs::read(path).map_err(|e| {
            SparseInferenceError::Model(crate::error::ModelError::LoadFailed(e.to_string()))
        })?;

        Self::from_gguf_bytes(&data)
    }

    /// Create from GGUF model bytes
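    ///
    /// A minimal sketch: feed an in-memory GGUF buffer, obtained however the
    /// host platform provides it (e.g. `std::fs::read` on native targets):
    ///
    /// ```rust,ignore
    /// let data = std::fs::read("llama-7b.gguf")?;
    /// let backend = SparseInferenceBackend::from_gguf_bytes(&data)?;
    /// println!("layers: {}", backend.metadata().num_layers);
    /// ```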
    pub fn from_gguf_bytes(data: &[u8]) -> Result<Self> {
        let gguf = GgufParser::parse(data)?;

        // Extract model configuration from GGUF metadata
        let hidden_dim = gguf.metadata.get("llama.embedding_length")
            .and_then(|v| v.as_u32())
            .unwrap_or(4096) as usize;

        let intermediate_dim = gguf.metadata.get("llama.feed_forward_length")
            .and_then(|v| v.as_u32())
            .unwrap_or((hidden_dim * 4) as u32) as usize;

        let num_layers = gguf.metadata.get("llama.block_count")
            .and_then(|v| v.as_u32())
            .unwrap_or(32) as usize;

        let vocab_size = gguf.metadata.get("llama.vocab_size")
            .and_then(|v| v.as_u32())
            .unwrap_or(32000) as usize;

        // Default to a 10% target sparsity ratio
        Self::new(num_layers, hidden_dim, intermediate_dim, vocab_size, 0.1)
    }

    /// Generate next token
    pub fn next_token(&mut self, input_ids: &[u32], _kv_cache: &mut KVCache) -> Result<u32> {
        // Simplified next-token prediction. In production, this would:
        // 1. Look up token embeddings
        // 2. Apply rotary position embeddings
        // 3. Run through transformer layers (attention + sparse FFN), updating the KV cache
        // 4. Compute logits and sample
        // The KV cache is threaded through the API but not yet consulted here.

        let hidden_dim = self.metadata.hidden_size;

        // Create a mock hidden state from the input tokens
        let mut hidden: Vec<f32> = input_ids.iter()
            .map(|&t| (t as f32) / (self.vocab_size as f32))
            .collect();
        hidden.resize(hidden_dim, 0.0);

        // Process through sparse FFN layers
        for (predictor, ffn) in self.predictors.iter().zip(self.ffns.iter()) {
            // Predict active neurons
            let active = predictor.predict(&hidden)?;

            // Sparse FFN forward
            hidden = ffn.forward_sparse(&hidden, &active)?;

            // Update cache stats
            self.neuron_cache.record_activations(&active);
        }

        // Derive a pseudo-token from the hidden state (a real implementation
        // would apply the output projection and sample from the logits)
        let logit_sum: f32 = hidden.iter().sum();
        let next_token = ((logit_sum.abs() * 1000.0) as u32) % (self.vocab_size as u32);

        self.stats.tokens_generated += 1;

        Ok(next_token)
    }

    /// Generate multiple tokens
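    ///
    /// A minimal sketch, assuming a `backend` constructed as in the
    /// struct-level example (generation stops early if a configured stop
    /// token is produced):
    ///
    /// ```rust,ignore
    /// let config = GenerationConfig { max_new_tokens: 16, ..Default::default() };
    /// let output = backend.generate(&[1, 2, 3], &config)?;
    /// assert!(output.len() <= 3 + 16);
    /// ```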
    pub fn generate(
        &mut self,
        input_ids: &[u32],
        config: &GenerationConfig,
    ) -> Result<Vec<u32>> {
        let mut output_ids = input_ids.to_vec();
        let mut kv_cache = KVCache::new(
            self.metadata.num_layers,
            config.max_new_tokens + input_ids.len(),
            self.metadata.hidden_size / self.metadata.num_heads,
        );

        let start_time = std::time::Instant::now();

        for _ in 0..config.max_new_tokens {
            let next_token = self.next_token(&output_ids, &mut kv_cache)?;

            // Check for stop token
            if config.stop_tokens.contains(&next_token) {
                break;
            }

            output_ids.push(next_token);
        }

        let elapsed = start_time.elapsed();
        self.stats.total_time_ms = elapsed.as_secs_f64() * 1000.0;
        if self.stats.tokens_generated > 0 {
            self.stats.avg_token_time_ms =
                self.stats.total_time_ms / self.stats.tokens_generated as f64;
        }

        Ok(output_ids)
    }

    /// Get model metadata
    pub fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }

    /// Get generation statistics
    pub fn generation_stats(&self) -> &GenerationStats {
        &self.stats
    }

    /// Set sparsity threshold
    pub fn set_sparsity(&mut self, threshold: f32) {
        self.config.sparsity_threshold = threshold;
    }

    /// Calibrate predictors with sample data
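    ///
    /// A minimal sketch, assuming a constructed `backend`; the samples are
    /// synthetic hidden-state vectors whose length matches the model's
    /// hidden dimension:
    ///
    /// ```rust,ignore
    /// let hidden_dim = backend.metadata().hidden_size;
    /// let samples: Vec<Vec<f32>> = (0..8)
    ///     .map(|i| vec![i as f32 * 0.01; hidden_dim])
    ///     .collect();
    /// backend.calibrate(&samples)?;
    /// ```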
    pub fn calibrate(&mut self, samples: &[Vec<f32>]) -> Result<()> {
        for (predictor, ffn) in self.predictors.iter_mut().zip(self.ffns.iter()) {
            // Generate dense activations for each sample
            let activations: Vec<Vec<f32>> = samples.iter()
                .map(|s| ffn.forward_dense(s))
                .collect::<Result<Vec<_>>>()?;

            predictor.calibrate(samples, &activations)?;
        }
        Ok(())
    }

    /// Reset generation statistics and the neuron cache (for a new
    /// conversation); the caller owns and clears its own `KVCache`.
    pub fn reset(&mut self) {
        self.stats = GenerationStats::default();
        self.neuron_cache.clear();
    }
}

/// Trait for inference backends (matches RuvLLM interface)
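///
/// A minimal sketch of code that is generic over any backend implementing
/// this trait (`run_generation` is a hypothetical helper, not part of the
/// crate):
///
/// ```rust,ignore
/// fn run_generation<B: InferenceBackend>(backend: &mut B, prompt: &[u32]) -> Result<Vec<u32>> {
///     println!("backend: {}, vocab: {}", backend.name(), backend.vocab_size());
///     backend.generate(prompt, 32)
/// }
/// ```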
pub trait InferenceBackend: Send + Sync {
    /// Generate next token probabilities
    fn forward(&mut self, input_ids: &[u32]) -> Result<Vec<f32>>;

    /// Generate tokens
    fn generate(&mut self, input_ids: &[u32], max_new_tokens: usize) -> Result<Vec<u32>>;

    /// Get vocabulary size
    fn vocab_size(&self) -> usize;

    /// Backend name
    fn name(&self) -> &str;
}

impl InferenceBackend for SparseInferenceBackend {
    fn forward(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
        // Simplified: runs the sparse FFN stack and returns the final hidden
        // state rather than true vocabulary logits (no output projection yet).
        let hidden_dim = self.metadata.hidden_size;
        let mut hidden: Vec<f32> = input_ids.iter()
            .map(|&t| (t as f32) / (self.vocab_size as f32))
            .collect();
        hidden.resize(hidden_dim, 0.0);

        for (predictor, ffn) in self.predictors.iter().zip(self.ffns.iter()) {
            let active = predictor.predict(&hidden)?;
            hidden = ffn.forward_sparse(&hidden, &active)?;
        }

        Ok(hidden)
    }

    fn generate(&mut self, input_ids: &[u32], max_new_tokens: usize) -> Result<Vec<u32>> {
        // Delegates to the inherent `generate` with a default configuration.
        let config = GenerationConfig {
            max_new_tokens,
            ..Default::default()
        };
        self.generate(input_ids, &config)
    }

    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn name(&self) -> &str {
        "sparse-inference"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backend_creation() {
        let backend = SparseInferenceBackend::new(4, 256, 1024, 32000, 0.1);
        assert!(backend.is_ok());

        let backend = backend.unwrap();
        assert_eq!(backend.metadata.num_layers, 4);
        assert_eq!(backend.vocab_size(), 32000);
    }

    #[test]
    fn test_next_token() {
        // Use a low sparsity ratio to ensure enough neurons are active
        let mut backend = SparseInferenceBackend::new(2, 64, 256, 1000, 0.001).unwrap();
        let mut kv_cache = KVCache::new(2, 100, 64);

        let result = backend.next_token(&[1, 2, 3], &mut kv_cache);
        assert!(result.is_ok(), "next_token failed: {:?}", result.err());

        let token = result.unwrap();
        assert!(token < 1000);
    }

    #[test]
    fn test_generate() {
        // Use a low sparsity ratio to ensure enough neurons are active
        let mut backend = SparseInferenceBackend::new(2, 64, 256, 1000, 0.001).unwrap();
        let config = GenerationConfig {
            max_new_tokens: 10,
            ..Default::default()
        };

        let result = backend.generate(&[1, 2, 3], &config);
        assert!(result.is_ok(), "generate failed: {:?}", result.err());

        let output = result.unwrap();
        assert!(output.len() >= 3); // At least the input tokens
        assert!(output.len() <= 13); // At most input + max_new_tokens
    }
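
    // A small additional check that the `InferenceBackend` trait methods
    // (`forward`, `generate`, `name`) work through the trait interface; the
    // dimensions mirror the other tests and are otherwise arbitrary.
    #[test]
    fn test_trait_interface() {
        let mut backend = SparseInferenceBackend::new(2, 64, 256, 1000, 0.001).unwrap();

        // `forward` returns the final hidden state (hidden_dim values).
        let hidden = InferenceBackend::forward(&mut backend, &[1, 2, 3]).unwrap();
        assert_eq!(hidden.len(), 64);

        // The trait-level `generate` takes a plain token budget.
        let output = InferenceBackend::generate(&mut backend, &[1, 2, 3], 5).unwrap();
        assert!(output.len() >= 3 && output.len() <= 8);

        assert_eq!(backend.name(), "sparse-inference");
    }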

    #[test]
    fn test_kv_cache() {
        let mut cache = KVCache::new(4, 100, 64);
        assert!(cache.is_empty());

        cache.append(0, vec![1.0; 64], vec![2.0; 64]);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
    }
}