ruvector_sparse_inference/integration/
ruvector.rs

//! Ruvector EmbeddingProvider integration
//!
//! This module provides a sparse-inference-based embedding provider that
//! integrates with the Ruvector vector database ecosystem.
//!
//! # Example
//!
//! ```rust,ignore
//! use ruvector_sparse_inference::integration::{EmbeddingProvider, SparseEmbeddingProvider};
//!
//! let provider = SparseEmbeddingProvider::from_gguf("model.gguf")?;
//! let embedding = provider.embed_tokens(&[101, 2023, 2003])?;
//! ```

use crate::{
    config::{ActivationType, SparsityConfig},
    error::{Result, SparseInferenceError},
    model::{GgufParser, InferenceConfig},
    predictor::{LowRankPredictor, Predictor},
    sparse::SparseFfn,
    SparsityStats,
};

/// Sparse embedding provider for Ruvector integration
///
/// Implements the EmbeddingProvider interface using PowerInfer-style
/// sparse inference for efficient embedding generation.
pub struct SparseEmbeddingProvider {
    /// Sparse FFN for inference
    ffn: SparseFfn,
    /// Activation predictor
    predictor: LowRankPredictor,
    /// Inference configuration
    config: InferenceConfig,
    /// Input dimension expected by the FFN
    input_dim: usize,
    /// Embedding dimension
    embed_dim: usize,
    /// Sparsity statistics
    stats: SparsityStats,
}

impl SparseEmbeddingProvider {
    /// Create a new sparse embedding provider with specified dimensions
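    ///
    /// # Example
    ///
    /// A minimal sketch; the dimensions are illustrative:
    ///
    /// ```rust,ignore
    /// // 512-dim inputs, 2048 hidden neurons, 512-dim embeddings,
    /// // targeting 10% sparsity in the hidden layer.
    /// let provider = SparseEmbeddingProvider::new(512, 2048, 512, 0.1)?;
    /// ```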
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        embed_dim: usize,
        sparsity_ratio: f32,
    ) -> Result<Self> {
        // Use top-K selection based on sparsity ratio for reliable activation
        // This ensures we always have some active neurons regardless of random init
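        // e.g. sparsity_ratio = 0.1 with hidden_dim = 2048 keeps the top 1843 neurons active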
        let target_active = ((1.0 - sparsity_ratio) * hidden_dim as f32).max(1.0) as usize;
        let sparsity_config = SparsityConfig {
            threshold: None,
            top_k: Some(target_active),
            target_sparsity: Some(sparsity_ratio),
            adaptive_threshold: false,
        };

        let predictor = LowRankPredictor::new(
            input_dim,
            hidden_dim,
            hidden_dim / 32, // rank = hidden_dim / 32
            sparsity_config,
        )?;

        let ffn = SparseFfn::new(
            input_dim,
            hidden_dim,
            embed_dim,
            ActivationType::Gelu,
        )?;

        Ok(Self {
            ffn,
            predictor,
            config: InferenceConfig::default(),
            input_dim,
            embed_dim,
            stats: SparsityStats {
                // Initial estimate derived from the configured sparsity ratio
                average_active_ratio: 1.0 - sparsity_ratio,
                min_active: 0,
                max_active: hidden_dim,
            },
        })
    }

    /// Create from a GGUF model file
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_gguf(path: impl AsRef<std::path::Path>) -> Result<Self> {
        use std::fs;

        let data = fs::read(path).map_err(|e| {
            SparseInferenceError::Model(crate::error::ModelError::LoadFailed(e.to_string()))
        })?;

        Self::from_gguf_bytes(&data)
    }

    /// Create from GGUF model bytes
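    ///
    /// Dimensions are taken from the `llama.embedding_length` and
    /// `llama.feed_forward_length` metadata keys, with fallbacks when absent.
    ///
    /// # Example
    ///
    /// A sketch; the file path is illustrative:
    ///
    /// ```rust,ignore
    /// let data = std::fs::read("model.gguf")?;
    /// let provider = SparseEmbeddingProvider::from_gguf_bytes(&data)?;
    /// ```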
    pub fn from_gguf_bytes(data: &[u8]) -> Result<Self> {
        let gguf = GgufParser::parse(data)?;

        // Extract dimensions from model metadata
        let hidden_dim = gguf.metadata.get("llama.embedding_length")
            .and_then(|v| v.as_u32())
            .unwrap_or(4096) as usize;

        let intermediate_dim = gguf.metadata.get("llama.feed_forward_length")
            .and_then(|v| v.as_u32())
            .unwrap_or((hidden_dim * 4) as u32) as usize;

        // Default to 10% target sparsity; embeddings share the model width
        Self::new(hidden_dim, intermediate_dim, hidden_dim, 0.1)
    }

    /// Generate an L2-normalized embedding for an input feature vector
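    ///
    /// # Example
    ///
    /// A sketch, assuming a provider with 512-dim inputs:
    ///
    /// ```rust,ignore
    /// let input = vec![0.1_f32; 512];
    /// let embedding = provider.embed(&input)?;
    /// assert_eq!(embedding.len(), provider.embedding_dim());
    /// ```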
    pub fn embed(&self, input: &[f32]) -> Result<Vec<f32>> {
        // Predict active neurons
        let active_neurons = self.predictor.predict(input)?;

        // Compute sparse forward pass
        let embedding = self.ffn.forward_sparse(input, &active_neurons)?;

        // Normalize embedding (L2 normalization)
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        let normalized: Vec<f32> = if norm > 1e-8 {
            embedding.iter().map(|x| x / norm).collect()
        } else {
            embedding
        };

        Ok(normalized)
    }

    /// Batch embed multiple inputs
    pub fn embed_batch(&self, inputs: &[Vec<f32>]) -> Result<Vec<Vec<f32>>> {
        inputs.iter()
            .map(|input| self.embed(input))
            .collect()
    }

    /// Get embedding dimension
    pub fn embedding_dim(&self) -> usize {
        self.embed_dim
    }

    /// Get sparsity statistics
    pub fn sparsity_stats(&self) -> &SparsityStats {
        &self.stats
    }

    /// Set sparsity threshold
    pub fn set_sparsity_threshold(&mut self, threshold: f32) {
        self.config.sparsity_threshold = threshold;
    }

    /// Calibrate the predictor with sample data
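    ///
    /// # Example
    ///
    /// A sketch; `collect_samples` is a hypothetical helper returning
    /// representative input vectors:
    ///
    /// ```rust,ignore
    /// let samples: Vec<Vec<f32>> = collect_samples();
    /// provider.calibrate(&samples)?;
    /// ```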
    pub fn calibrate(&mut self, samples: &[Vec<f32>]) -> Result<()> {
        // Generate activations for calibration
        let activations: Vec<Vec<f32>> = samples.iter()
            .map(|s| self.ffn.forward_dense(s))
            .collect::<Result<Vec<_>>>()?;

        // Calibrate predictor
        self.predictor.calibrate(samples, &activations)?;

        Ok(())
    }
}

/// Trait for embedding providers (matches Ruvector interface)
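///
/// # Example
///
/// A sketch of generic usage; `index_tokens` is an illustrative helper:
///
/// ```rust,ignore
/// fn index_tokens(provider: &dyn EmbeddingProvider, tokens: &[u32]) -> Result<Vec<f32>> {
///     provider.embed_tokens(tokens)
/// }
/// ```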
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding for text (requires tokenization)
    fn embed_text(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embedding for token ids
    fn embed_tokens(&self, tokens: &[u32]) -> Result<Vec<f32>>;

    /// Get embedding dimension
    fn dimension(&self) -> usize;

    /// Provider name
    fn name(&self) -> &str;
}

impl EmbeddingProvider for SparseEmbeddingProvider {
    fn embed_text(&self, _text: &str) -> Result<Vec<f32>> {
        // Note: this requires a tokenizer, so we return an error for now.
        // In production, integrate a tokenizer (e.g., tiktoken, sentencepiece).
        Err(SparseInferenceError::Inference(
            crate::error::InferenceError::InvalidInput(
                "Text embedding requires tokenizer integration".to_string()
            )
        ))
    }

    fn embed_tokens(&self, tokens: &[u32]) -> Result<Vec<f32>> {
        // Convert token ids to a crude feature vector (simplified; a real
        // implementation needs a token embedding lookup)
        let input: Vec<f32> = tokens.iter()
            .map(|&t| (t as f32) / 50000.0) // Normalize token ids
            .collect();

        // Pad or truncate to the expected input dimension
        let padded: Vec<f32> = if input.len() >= self.input_dim {
            input[..self.input_dim].to_vec()
        } else {
            let mut padded = input;
            padded.resize(self.input_dim, 0.0);
            padded
        };

        self.embed(&padded)
    }

    fn dimension(&self) -> usize {
        self.embed_dim
    }

    fn name(&self) -> &str {
        "sparse-inference"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_creation() {
        let provider = SparseEmbeddingProvider::new(512, 2048, 512, 0.1);
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.embedding_dim(), 512);
    }

    #[test]
    fn test_embed() {
        // Use a very low sparsity ratio so nearly all neurons stay active
        let provider = SparseEmbeddingProvider::new(64, 256, 64, 0.001).unwrap();
        // Use varied input to get more neuron activations
        let input: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 64.0).collect();

        let embedding = provider.embed(&input);
        assert!(embedding.is_ok(), "Embedding failed: {:?}", embedding.err());

        let embedding = embedding.unwrap();
        assert_eq!(embedding.len(), 64);

        // Check L2 normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01, "Norm is {}", norm);
    }

    #[test]
    fn test_batch_embed() {
        // Use a very low sparsity ratio so nearly all neurons stay active
        let provider = SparseEmbeddingProvider::new(64, 256, 64, 0.001).unwrap();
        let inputs: Vec<Vec<f32>> = vec![
            (0..64).map(|i| i as f32 / 64.0).collect(),
            (0..64).map(|i| (i as f32).sin()).collect(),
            (0..64).map(|i| (i as f32).cos()).collect(),
        ];

        let embeddings = provider.embed_batch(&inputs);
        assert!(embeddings.is_ok(), "Batch embed failed: {:?}", embeddings.err());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.len(), 3);
    }
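
    #[test]
    fn test_embed_tokens_via_trait() {
        // A sketch exercising the EmbeddingProvider trait path with arbitrary
        // token ids; assumes the same small dimensions as the tests above.
        let provider = SparseEmbeddingProvider::new(64, 256, 64, 0.001).unwrap();

        let embedding = provider.embed_tokens(&[1, 42, 1000, 49999]);
        assert!(embedding.is_ok(), "Token embed failed: {:?}", embedding.err());
        assert_eq!(embedding.unwrap().len(), provider.dimension());
    }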
}