// ruvector_sparse_inference/integration/ruvector.rs
use crate::{
16 config::{ActivationType, SparsityConfig},
17 error::{Result, SparseInferenceError},
18 model::{GgufParser, InferenceConfig},
19 predictor::{LowRankPredictor, Predictor},
20 sparse::SparseFfn,
21 SparsityStats,
22};
23
/// Embedding provider backed by a sparsely-activated feed-forward network:
/// a low-rank predictor chooses which hidden neurons to compute, and the
/// FFN evaluates only those neurons.
pub struct SparseEmbeddingProvider {
    // Sparse FFN that maps input vectors to embedding vectors.
    ffn: SparseFfn,
    // Low-rank predictor selecting which hidden neurons to activate.
    predictor: LowRankPredictor,
    // Runtime inference settings (e.g. sparsity threshold).
    config: InferenceConfig,
    // Dimensionality of the produced embeddings.
    embed_dim: usize,
    // Activation-sparsity statistics reported via `sparsity_stats()`.
    stats: SparsityStats,
}
40
41impl SparseEmbeddingProvider {
42 pub fn new(
44 input_dim: usize,
45 hidden_dim: usize,
46 embed_dim: usize,
47 sparsity_ratio: f32,
48 ) -> Result<Self> {
49 let target_active = ((1.0 - sparsity_ratio) * hidden_dim as f32).max(1.0) as usize;
52 let sparsity_config = SparsityConfig {
53 threshold: None,
54 top_k: Some(target_active),
55 target_sparsity: Some(sparsity_ratio),
56 adaptive_threshold: false,
57 };
58
59 let predictor = LowRankPredictor::new(
60 input_dim,
61 hidden_dim,
62 hidden_dim / 32, sparsity_config,
64 )?;
65
66 let ffn = SparseFfn::new(
67 input_dim,
68 hidden_dim,
69 embed_dim,
70 ActivationType::Gelu,
71 )?;
72
73 Ok(Self {
74 ffn,
75 predictor,
76 config: InferenceConfig::default(),
77 embed_dim,
78 stats: SparsityStats {
79 average_active_ratio: 0.3,
80 min_active: 0,
81 max_active: hidden_dim,
82 },
83 })
84 }
85
86 #[cfg(not(target_arch = "wasm32"))]
88 pub fn from_gguf(path: &std::path::Path) -> Result<Self> {
89 use std::fs;
90
91 let data = fs::read(path).map_err(|e| {
92 SparseInferenceError::Model(crate::error::ModelError::LoadFailed(e.to_string()))
93 })?;
94
95 Self::from_gguf_bytes(&data)
96 }
97
98 pub fn from_gguf_bytes(data: &[u8]) -> Result<Self> {
100 let gguf = GgufParser::parse(data)?;
101
102 let hidden_dim = gguf.metadata.get("llama.embedding_length")
104 .and_then(|v| v.as_u32())
105 .unwrap_or(4096) as usize;
106
107 let intermediate_dim = gguf.metadata.get("llama.feed_forward_length")
108 .and_then(|v| v.as_u32())
109 .unwrap_or((hidden_dim * 4) as u32) as usize;
110
111 Self::new(hidden_dim, intermediate_dim, hidden_dim, 0.1)
112 }
113
114 pub fn embed(&self, input: &[f32]) -> Result<Vec<f32>> {
116 let active_neurons = self.predictor.predict(input)?;
118
119 let embedding = self.ffn.forward_sparse(input, &active_neurons)?;
121
122 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
124 let normalized: Vec<f32> = if norm > 1e-8 {
125 embedding.iter().map(|x| x / norm).collect()
126 } else {
127 embedding
128 };
129
130 Ok(normalized)
131 }
132
133 pub fn embed_batch(&self, inputs: &[Vec<f32>]) -> Result<Vec<Vec<f32>>> {
135 inputs.iter()
136 .map(|input| self.embed(input))
137 .collect()
138 }
139
140 pub fn embedding_dim(&self) -> usize {
142 self.embed_dim
143 }
144
145 pub fn sparsity_stats(&self) -> &SparsityStats {
147 &self.stats
148 }
149
150 pub fn set_sparsity_threshold(&mut self, threshold: f32) {
152 self.config.sparsity_threshold = threshold;
153 }
154
155 pub fn calibrate(&mut self, samples: &[Vec<f32>]) -> Result<()> {
157 let activations: Vec<Vec<f32>> = samples.iter()
159 .map(|s| self.ffn.forward_dense(s))
160 .collect::<Result<Vec<_>>>()?;
161
162 self.predictor.calibrate(samples, &activations)?;
164
165 Ok(())
166 }
167}
168
/// Common interface for components that produce embedding vectors.
pub trait EmbeddingProvider: Send + Sync {
    /// Embed raw text. May fail when no tokenizer integration is available.
    fn embed_text(&self, text: &str) -> Result<Vec<f32>>;

    /// Embed a pre-tokenized sequence of token ids.
    fn embed_tokens(&self, tokens: &[u32]) -> Result<Vec<f32>>;

    /// Dimensionality of the produced embeddings.
    fn dimension(&self) -> usize;

    /// Human-readable provider name.
    fn name(&self) -> &str;
}
183
184impl EmbeddingProvider for SparseEmbeddingProvider {
185 fn embed_text(&self, _text: &str) -> Result<Vec<f32>> {
186 Err(SparseInferenceError::Inference(
189 crate::error::InferenceError::InvalidInput(
190 "Text embedding requires tokenizer integration".to_string()
191 )
192 ))
193 }
194
195 fn embed_tokens(&self, tokens: &[u32]) -> Result<Vec<f32>> {
196 let input: Vec<f32> = tokens.iter()
198 .map(|&t| (t as f32) / 50000.0) .collect();
200
201 let padded: Vec<f32> = if input.len() >= self.embed_dim {
203 input[..self.embed_dim].to_vec()
204 } else {
205 let mut padded = input;
206 padded.resize(self.embed_dim, 0.0);
207 padded
208 };
209
210 self.embed(&padded)
211 }
212
213 fn dimension(&self) -> usize {
214 self.embed_dim
215 }
216
217 fn name(&self) -> &str {
218 "sparse-inference"
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with typical dimensions succeeds and the provider
    // reports the requested embedding dimension.
    #[test]
    fn test_provider_creation() {
        let provider = SparseEmbeddingProvider::new(512, 2048, 512, 0.1);
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.embedding_dim(), 512);
    }

    // A very low sparsity ratio (0.001) keeps nearly all neurons active,
    // so the embedding should be well-formed and L2-normalized.
    #[test]
    fn test_embed() {
        let provider = SparseEmbeddingProvider::new(64, 256, 64, 0.001).unwrap();
        let input: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 64.0).collect();

        let embedding = provider.embed(&input);
        assert!(embedding.is_ok(), "Embedding failed: {:?}", embedding.err());

        let embedding = embedding.unwrap();
        assert_eq!(embedding.len(), 64);

        // `embed` normalizes non-degenerate outputs, so the norm is ~1.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01, "Norm is {}", norm);
    }

    // Batch embedding returns one embedding per input, in order.
    #[test]
    fn test_batch_embed() {
        let provider = SparseEmbeddingProvider::new(64, 256, 64, 0.001).unwrap();
        let inputs = vec![
            (0..64).map(|i| i as f32 / 64.0).collect(),
            (0..64).map(|i| (i as f32).sin()).collect(),
            (0..64).map(|i| (i as f32).cos()).collect(),
        ];

        let embeddings = provider.embed_batch(&inputs);
        assert!(embeddings.is_ok(), "Batch embed failed: {:?}", embeddings.err());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.len(), 3);
    }
}