ruvector_sparse_inference/integration/ruvllm.rs

use crate::{
    config::{ActivationType, SparsityConfig, CacheConfig},
    error::{Result, SparseInferenceError},
    model::{GgufParser, InferenceConfig, ModelMetadata},
    memory::NeuronCache,
    predictor::{LowRankPredictor, Predictor},
    sparse::SparseFfn,
};

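/// Minimal per-layer key/value cache for autoregressive decoding.
///
/// A short usage sketch (illustrative; the dimensions below are arbitrary):
///
/// ```ignore
/// let mut cache = KVCache::new(/* num_layers */ 2, /* max_length */ 16, /* head_dim */ 64);
/// cache.append(0, vec![0.0; 64], vec![0.0; 64]);
/// assert_eq!(cache.len(), 1);
/// cache.clear();
/// assert!(cache.is_empty());
/// ```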
#[derive(Debug)]
pub struct KVCache {
    keys: Vec<Vec<Vec<f32>>>,
    values: Vec<Vec<Vec<f32>>>,
    max_length: usize,
    current_length: usize,
}

impl KVCache {
    // `_head_dim` is currently unused: per-vector sizes are implied by the
    // slices passed to `append`.
    pub fn new(num_layers: usize, max_length: usize, _head_dim: usize) -> Self {
        Self {
            keys: vec![Vec::new(); num_layers],
            values: vec![Vec::new(); num_layers],
            max_length,
            current_length: 0,
        }
    }

    pub fn clear(&mut self) {
        for layer_keys in &mut self.keys {
            layer_keys.clear();
        }
        for layer_values in &mut self.values {
            layer_values.clear();
        }
        self.current_length = 0;
    }

    pub fn len(&self) -> usize {
        self.current_length
    }

    pub fn is_empty(&self) -> bool {
        self.current_length == 0
    }

    pub fn append(&mut self, layer: usize, key: Vec<f32>, value: Vec<f32>) {
        // Ignore out-of-range layers and enforce the per-layer capacity bound.
        if layer < self.keys.len() && self.keys[layer].len() < self.max_length {
            self.keys[layer].push(key);
            self.values[layer].push(value);
            if layer == 0 {
                self.current_length += 1;
            }
        }
    }
}

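/// Sampling and stopping parameters for text generation.
///
/// Callers typically override a few fields and keep the rest of the defaults
/// (illustrative sketch):
///
/// ```ignore
/// let config = GenerationConfig {
///     max_new_tokens: 32,
///     temperature: 0.0,
///     ..Default::default()
/// };
/// ```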
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    pub max_new_tokens: usize,
    pub temperature: f32,
    pub top_k: usize,
    pub top_p: f32,
    pub repetition_penalty: f32,
    pub stop_tokens: Vec<u32>,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            max_new_tokens: 100,
            temperature: 0.7,
            top_k: 50,
            top_p: 0.9,
            repetition_penalty: 1.1,
            stop_tokens: vec![2], // 2 is the conventional </s> (EOS) id in Llama-style vocabularies
        }
    }
}

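/// Timing and sparsity statistics collected while generating tokens.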
#[derive(Debug, Clone, Default)]
pub struct GenerationStats {
    pub tokens_generated: usize,
    pub avg_token_time_ms: f64,
    pub avg_sparsity: f64,
    pub total_time_ms: f64,
}

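/// Sparse-FFN inference backend: one low-rank activation predictor and one
/// sparse FFN per layer, plus a neuron cache that records which neurons fire.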
pub struct SparseInferenceBackend {
    metadata: ModelMetadata,
    predictors: Vec<LowRankPredictor>,
    ffns: Vec<SparseFfn>,
    neuron_cache: NeuronCache,
    config: InferenceConfig,
    stats: GenerationStats,
    vocab_size: usize,
}

impl SparseInferenceBackend {
    pub fn new(
        num_layers: usize,
        hidden_dim: usize,
        intermediate_dim: usize,
        vocab_size: usize,
        sparsity_ratio: f32,
    ) -> Result<Self> {
        let target_active = ((1.0 - sparsity_ratio) * intermediate_dim as f32).max(1.0) as usize;

        let sparsity_config = SparsityConfig {
            threshold: None,
            top_k: Some(target_active),
            target_sparsity: Some(sparsity_ratio),
            adaptive_threshold: false,
        };

        let cache_config = CacheConfig {
            hot_neuron_fraction: 0.2,
            max_cold_cache_size: 1000,
            cache_strategy: crate::config::CacheStrategy::Lru,
            hot_neuron_count: (intermediate_dim as f32 * 0.2) as usize,
            lru_cache_size: 4096,
            use_mmap: false,
            hot_threshold: 0.5,
        };

        let mut predictors = Vec::with_capacity(num_layers);
        let mut ffns = Vec::with_capacity(num_layers);

        for _ in 0..num_layers {
            let predictor = LowRankPredictor::new(
                hidden_dim,
                intermediate_dim,
                intermediate_dim / 32,
                sparsity_config.clone(),
            )?;
            predictors.push(predictor);

            let ffn = SparseFfn::new(
                hidden_dim,
                intermediate_dim,
                hidden_dim,
                ActivationType::Silu,
            )?;
            ffns.push(ffn);
        }

        let neuron_cache = NeuronCache::new(intermediate_dim, cache_config);

        let metadata = ModelMetadata {
            hidden_size: hidden_dim,
            intermediate_size: intermediate_dim,
            num_layers,
            num_heads: hidden_dim / 64, // assumes 64-dimensional attention heads
            num_key_value_heads: None,
            vocab_size,
            max_position_embeddings: 4096,
            architecture: crate::model::ModelArchitecture::Llama,
            quantization: None,
            rope_theta: Some(10000.0),
            rope_scaling: None,
        };

        Ok(Self {
            metadata,
            predictors,
            ffns,
            neuron_cache,
            config: InferenceConfig::default(),
            stats: GenerationStats::default(),
            vocab_size,
        })
    }

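    /// Build a backend from a GGUF file on disk (native targets only); the
    /// byte-slice variant below does the actual metadata extraction.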
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_gguf(path: &std::path::Path) -> Result<Self> {
        use std::fs;

        let data = fs::read(path).map_err(|e| {
            SparseInferenceError::Model(crate::error::ModelError::LoadFailed(e.to_string()))
        })?;

        Self::from_gguf_bytes(&data)
    }

    pub fn from_gguf_bytes(data: &[u8]) -> Result<Self> {
        let gguf = GgufParser::parse(data)?;

        let hidden_dim = gguf.metadata.get("llama.embedding_length")
            .and_then(|v| v.as_u32())
            .unwrap_or(4096) as usize;

        let intermediate_dim = gguf.metadata.get("llama.feed_forward_length")
            .and_then(|v| v.as_u32())
            .unwrap_or((hidden_dim * 4) as u32) as usize;

        let num_layers = gguf.metadata.get("llama.block_count")
            .and_then(|v| v.as_u32())
            .unwrap_or(32) as usize;

        let vocab_size = gguf.metadata.get("llama.vocab_size")
            .and_then(|v| v.as_u32())
            .unwrap_or(32000) as usize;

        Self::new(num_layers, hidden_dim, intermediate_dim, vocab_size, 0.1)
    }

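    /// Run one decoding step over the current sequence and return the next
    /// token id. The embedding and logit computation here are simplified
    /// placeholders; only the predictor + sparse-FFN path is exercised, and
    /// the KV cache argument is not yet consulted.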
    pub fn next_token(&mut self, input_ids: &[u32], _kv_cache: &mut KVCache) -> Result<u32> {
        let hidden_dim = self.metadata.hidden_size;

        // Placeholder embedding: map token ids into a pseudo hidden state of
        // the model's hidden dimension.
        let mut hidden: Vec<f32> = input_ids.iter()
            .map(|&t| (t as f32) / (self.vocab_size as f32))
            .collect();
        hidden.resize(hidden_dim, 0.0);

        // Per layer: predict which FFN neurons will activate, then run only those.
        for (predictor, ffn) in self.predictors.iter().zip(self.ffns.iter()) {
            let active = predictor.predict(&hidden)?;

            hidden = ffn.forward_sparse(&hidden, &active)?;

            self.neuron_cache.record_activations(&active);
        }

        // Placeholder sampling: hash the final hidden state into the vocabulary range.
        let logit_sum: f32 = hidden.iter().sum();
        let next_token = ((logit_sum.abs() * 1000.0) as u32) % (self.vocab_size as u32);

        self.stats.tokens_generated += 1;

        Ok(next_token)
    }

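    /// Autoregressively extend `input_ids` by up to `config.max_new_tokens`
    /// tokens, stopping early when a stop token is produced. Returns the
    /// prompt followed by the generated tokens.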
    pub fn generate(
        &mut self,
        input_ids: &[u32],
        config: &GenerationConfig,
    ) -> Result<Vec<u32>> {
        let mut output_ids = input_ids.to_vec();
        let mut kv_cache = KVCache::new(
            self.metadata.num_layers,
            config.max_new_tokens + input_ids.len(),
            self.metadata.hidden_size / self.metadata.num_heads,
        );

        let start_time = std::time::Instant::now();

        for _ in 0..config.max_new_tokens {
            let next_token = self.next_token(&output_ids, &mut kv_cache)?;

            if config.stop_tokens.contains(&next_token) {
                break;
            }

            output_ids.push(next_token);
        }

        let elapsed = start_time.elapsed();
        self.stats.total_time_ms = elapsed.as_secs_f64() * 1000.0;
        // Guard against a zero token count (e.g. max_new_tokens == 0) so the
        // average never becomes NaN or infinite.
        if self.stats.tokens_generated > 0 {
            self.stats.avg_token_time_ms =
                self.stats.total_time_ms / self.stats.tokens_generated as f64;
        }

        Ok(output_ids)
    }

    pub fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }

    pub fn generation_stats(&self) -> &GenerationStats {
        &self.stats
    }

    pub fn set_sparsity(&mut self, threshold: f32) {
        self.config.sparsity_threshold = threshold;
    }

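    /// Fit each layer's activation predictor against dense FFN activations
    /// computed from the provided calibration samples.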
    pub fn calibrate(&mut self, samples: &[Vec<f32>]) -> Result<()> {
        for (predictor, ffn) in self.predictors.iter_mut().zip(self.ffns.iter()) {
            let activations: Vec<Vec<f32>> = samples.iter()
                .map(|s| ffn.forward_dense(s))
                .collect::<Result<Vec<_>>>()?;

            predictor.calibrate(samples, &activations)?;
        }
        Ok(())
    }

    pub fn reset(&mut self) {
        self.stats = GenerationStats::default();
        self.neuron_cache.clear();
    }
}

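/// Object-safe interface shared by inference backends, so callers can hold a
/// `dyn InferenceBackend` without depending on a concrete implementation.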
pub trait InferenceBackend: Send + Sync {
    fn forward(&mut self, input_ids: &[u32]) -> Result<Vec<f32>>;

    fn generate(&mut self, input_ids: &[u32], max_new_tokens: usize) -> Result<Vec<u32>>;

    fn vocab_size(&self) -> usize;

    fn name(&self) -> &str;
}

impl InferenceBackend for SparseInferenceBackend {
    fn forward(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
        let hidden_dim = self.metadata.hidden_size;
        let mut hidden: Vec<f32> = input_ids.iter()
            .map(|&t| (t as f32) / (self.vocab_size as f32))
            .collect();
        hidden.resize(hidden_dim, 0.0);

        for (predictor, ffn) in self.predictors.iter().zip(self.ffns.iter()) {
            let active = predictor.predict(&hidden)?;
            hidden = ffn.forward_sparse(&hidden, &active)?;
        }

        Ok(hidden)
    }

    fn generate(&mut self, input_ids: &[u32], max_new_tokens: usize) -> Result<Vec<u32>> {
        let config = GenerationConfig {
            max_new_tokens,
            ..Default::default()
        };
        // Method resolution picks the inherent `generate` (which takes a
        // `GenerationConfig`) over this trait method, so this call is not recursive.
        self.generate(input_ids, &config)
    }

    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn name(&self) -> &str {
        "sparse-inference"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backend_creation() {
        let backend = SparseInferenceBackend::new(4, 256, 1024, 32000, 0.1);
        assert!(backend.is_ok());

        let backend = backend.unwrap();
        assert_eq!(backend.metadata.num_layers, 4);
        assert_eq!(backend.vocab_size(), 32000);
    }

    #[test]
    fn test_next_token() {
        let mut backend = SparseInferenceBackend::new(2, 64, 256, 1000, 0.001).unwrap();
        let mut kv_cache = KVCache::new(2, 100, 64);

        let result = backend.next_token(&[1, 2, 3], &mut kv_cache);
        assert!(result.is_ok(), "next_token failed: {:?}", result.err());

        let token = result.unwrap();
        assert!(token < 1000);
    }

    #[test]
    fn test_generate() {
        let mut backend = SparseInferenceBackend::new(2, 64, 256, 1000, 0.001).unwrap();
        let config = GenerationConfig {
            max_new_tokens: 10,
            ..Default::default()
        };

        let result = backend.generate(&[1, 2, 3], &config);
        assert!(result.is_ok(), "generate failed: {:?}", result.err());

        let output = result.unwrap();
        assert!(output.len() >= 3); // the 3-token prompt is always returned
        assert!(output.len() <= 13); // plus at most 10 newly generated tokens
    }

    #[test]
    fn test_kv_cache() {
        let mut cache = KVCache::new(4, 100, 64);
        assert!(cache.is_empty());

        cache.append(0, vec![1.0; 64], vec![2.0; 64]);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
    }
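
    // Illustrative sketch: drive the backend through the object-safe
    // `InferenceBackend` trait rather than its inherent methods. Uses only
    // the APIs defined in this file.
    #[test]
    fn test_trait_generate() {
        let mut backend = SparseInferenceBackend::new(2, 64, 256, 1000, 0.001).unwrap();
        let backend: &mut dyn InferenceBackend = &mut backend;

        let output = backend.generate(&[1, 2, 3], 5).unwrap();
        assert!(output.len() >= 3 && output.len() <= 8);
        assert_eq!(backend.name(), "sparse-inference");
    }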
}