1use ruvector_gnn::{
11 differentiable_search as core_differentiable_search,
12 hierarchical_forward as core_hierarchical_forward, CompressedTensor, CompressionLevel,
13 RuvectorLayer, TensorCompress,
14};
15use serde::{Deserialize, Serialize};
16use wasm_bindgen::prelude::*;
17
18#[wasm_bindgen(start)]
20pub fn init() {
21 #[cfg(feature = "console_error_panic_hook")]
22 console_error_panic_hook::set_once();
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
31#[wasm_bindgen]
32pub struct SearchConfig {
33 pub k: usize,
35 pub temperature: f32,
37}
38
39#[wasm_bindgen]
40impl SearchConfig {
41 #[wasm_bindgen(constructor)]
43 pub fn new(k: usize, temperature: f32) -> Self {
44 Self { k, temperature }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50struct SearchResultInternal {
51 indices: Vec<usize>,
53 weights: Vec<f32>,
55}
56
57#[wasm_bindgen]
63pub struct JsRuvectorLayer {
64 inner: RuvectorLayer,
65 hidden_dim: usize,
66}
67
68#[wasm_bindgen]
69impl JsRuvectorLayer {
70 #[wasm_bindgen(constructor)]
78 pub fn new(
79 input_dim: usize,
80 hidden_dim: usize,
81 heads: usize,
82 dropout: f32,
83 ) -> Result<JsRuvectorLayer, JsValue> {
84 if dropout < 0.0 || dropout > 1.0 {
85 return Err(JsValue::from_str("Dropout must be between 0.0 and 1.0"));
86 }
87
88 Ok(JsRuvectorLayer {
89 inner: RuvectorLayer::new(input_dim, hidden_dim, heads, dropout),
90 hidden_dim,
91 })
92 }
93
94 #[wasm_bindgen]
104 pub fn forward(
105 &self,
106 node_embedding: Vec<f32>,
107 neighbor_embeddings: JsValue,
108 edge_weights: Vec<f32>,
109 ) -> Result<Vec<f32>, JsValue> {
110 let neighbors: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(neighbor_embeddings)
112 .map_err(|e| {
113 JsValue::from_str(&format!("Failed to parse neighbor embeddings: {}", e))
114 })?;
115
116 if neighbors.len() != edge_weights.len() {
118 return Err(JsValue::from_str(&format!(
119 "Number of neighbors ({}) must match number of edge weights ({})",
120 neighbors.len(),
121 edge_weights.len()
122 )));
123 }
124
125 let result = self
127 .inner
128 .forward(&node_embedding, &neighbors, &edge_weights);
129
130 Ok(result)
131 }
132
133 #[wasm_bindgen(getter, js_name = outputDim)]
135 pub fn output_dim(&self) -> usize {
136 self.hidden_dim
137 }
138}
139
140#[wasm_bindgen]
146pub struct JsTensorCompress {
147 inner: TensorCompress,
148}
149
150#[wasm_bindgen]
151impl JsTensorCompress {
152 #[wasm_bindgen(constructor)]
154 pub fn new() -> Self {
155 Self {
156 inner: TensorCompress::new(),
157 }
158 }
159
160 #[wasm_bindgen]
174 pub fn compress(&self, embedding: Vec<f32>, access_freq: f32) -> Result<JsValue, JsValue> {
175 let compressed = self
176 .inner
177 .compress(&embedding, access_freq)
178 .map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?;
179
180 serde_wasm_bindgen::to_value(&compressed)
182 .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
183 }
184
185 #[wasm_bindgen(js_name = compressWithLevel)]
194 pub fn compress_with_level(
195 &self,
196 embedding: Vec<f32>,
197 level: &str,
198 ) -> Result<JsValue, JsValue> {
199 let compression_level = match level {
200 "none" => CompressionLevel::None,
201 "half" => CompressionLevel::Half { scale: 1.0 },
202 "pq8" => CompressionLevel::PQ8 {
203 subvectors: 8,
204 centroids: 16,
205 },
206 "pq4" => CompressionLevel::PQ4 {
207 subvectors: 8,
208 outlier_threshold: 3.0,
209 },
210 "binary" => CompressionLevel::Binary { threshold: 0.0 },
211 _ => {
212 return Err(JsValue::from_str(&format!(
213 "Unknown compression level: {}",
214 level
215 )))
216 }
217 };
218
219 let compressed = self
220 .inner
221 .compress_with_level(&embedding, &compression_level)
222 .map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?;
223
224 serde_wasm_bindgen::to_value(&compressed)
226 .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
227 }
228
229 #[wasm_bindgen]
237 pub fn decompress(&self, compressed: JsValue) -> Result<Vec<f32>, JsValue> {
238 let compressed_tensor: CompressedTensor = serde_wasm_bindgen::from_value(compressed)
239 .map_err(|e| JsValue::from_str(&format!("Deserialization failed: {}", e)))?;
240
241 let decompressed = self
242 .inner
243 .decompress(&compressed_tensor)
244 .map_err(|e| JsValue::from_str(&format!("Decompression failed: {}", e)))?;
245
246 Ok(decompressed)
247 }
248
249 #[wasm_bindgen(js_name = getCompressionRatio)]
257 pub fn get_compression_ratio(&self, access_freq: f32) -> f32 {
258 if access_freq > 0.8 {
259 1.0 } else if access_freq > 0.4 {
261 2.0 } else if access_freq > 0.1 {
263 4.0 } else if access_freq > 0.01 {
265 8.0 } else {
267 32.0 }
269 }
270}
271
272#[wasm_bindgen(js_name = differentiableSearch)]
286pub fn differentiable_search(
287 query: Vec<f32>,
288 candidate_embeddings: JsValue,
289 config: &SearchConfig,
290) -> Result<JsValue, JsValue> {
291 let candidates: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(candidate_embeddings)
293 .map_err(|e| JsValue::from_str(&format!("Failed to parse candidate embeddings: {}", e)))?;
294
295 let (indices, weights) =
297 core_differentiable_search(&query, &candidates, config.k, config.temperature);
298
299 let result = SearchResultInternal { indices, weights };
300 serde_wasm_bindgen::to_value(&result)
301 .map_err(|e| JsValue::from_str(&format!("Failed to serialize result: {}", e)))
302}
303
304#[wasm_bindgen(js_name = hierarchicalForward)]
314pub fn hierarchical_forward(
315 query: Vec<f32>,
316 layer_embeddings: JsValue,
317 gnn_layers: Vec<JsRuvectorLayer>,
318) -> Result<Vec<f32>, JsValue> {
319 let embeddings: Vec<Vec<Vec<f32>>> = serde_wasm_bindgen::from_value(layer_embeddings)
321 .map_err(|e| JsValue::from_str(&format!("Failed to parse layer embeddings: {}", e)))?;
322
323 let core_layers: Vec<RuvectorLayer> = gnn_layers.iter().map(|l| l.inner.clone()).collect();
325
326 let result = core_hierarchical_forward(&query, &embeddings, &core_layers);
328
329 Ok(result)
330}
331
332#[wasm_bindgen]
338pub fn version() -> String {
339 env!("CARGO_PKG_VERSION").to_string()
340}
341
342#[wasm_bindgen(js_name = cosineSimilarity)]
351pub fn cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> Result<f32, JsValue> {
352 if a.len() != b.len() {
353 return Err(JsValue::from_str(&format!(
354 "Vector dimensions must match: {} vs {}",
355 a.len(),
356 b.len()
357 )));
358 }
359
360 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
361 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
362 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
363
364 if norm_a == 0.0 || norm_b == 0.0 {
365 Ok(0.0)
366 } else {
367 Ok(dot_product / (norm_a * norm_b))
368 }
369}
370
/// Browser-run smoke tests for the wasm bindings.
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    #[wasm_bindgen_test]
    fn test_version() {
        assert!(!version().is_empty());
    }

    #[wasm_bindgen_test]
    fn test_ruvector_layer_creation() {
        let layer = JsRuvectorLayer::new(4, 8, 2, 0.1);
        assert!(layer.is_ok());
    }

    #[wasm_bindgen_test]
    fn test_ruvector_layer_rejects_out_of_range_dropout() {
        assert!(JsRuvectorLayer::new(4, 8, 2, -0.1).is_err());
        assert!(JsRuvectorLayer::new(4, 8, 2, 1.5).is_err());
    }

    #[wasm_bindgen_test]
    fn test_tensor_compress_creation() {
        let compressor = JsTensorCompress::new();
        assert_eq!(compressor.get_compression_ratio(1.0), 1.0);
        assert_eq!(compressor.get_compression_ratio(0.5), 2.0);
    }

    #[wasm_bindgen_test]
    fn test_compression_ratio_lower_tiers() {
        let compressor = JsTensorCompress::new();
        // Thresholds are strict: exactly 0.8 falls into the next tier down.
        assert_eq!(compressor.get_compression_ratio(0.8), 2.0);
        assert_eq!(compressor.get_compression_ratio(0.2), 4.0);
        assert_eq!(compressor.get_compression_ratio(0.05), 8.0);
        assert_eq!(compressor.get_compression_ratio(0.0), 32.0);
    }

    #[wasm_bindgen_test]
    fn test_compress_with_unknown_level_fails() {
        let compressor = JsTensorCompress::new();
        assert!(compressor.compress_with_level(vec![1.0; 8], "bogus").is_err());
    }

    #[wasm_bindgen_test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(a, b).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[wasm_bindgen_test]
    fn test_cosine_similarity_edge_cases() {
        assert_eq!(cosine_similarity(vec![1.0, 0.0], vec![0.0, 1.0]).unwrap(), 0.0);
        assert_eq!(cosine_similarity(vec![0.0, 0.0], vec![1.0, 2.0]).unwrap(), 0.0);
        assert!(cosine_similarity(vec![1.0, 2.0], vec![1.0]).is_err());
    }

    #[wasm_bindgen_test]
    fn test_search_config() {
        let config = SearchConfig::new(5, 1.0);
        assert_eq!(config.k, 5);
        assert_eq!(config.temperature, 1.0);
    }
}