ruvector_gnn_wasm/
lib.rs

//! WebAssembly bindings for RuVector GNN
//!
//! This module provides high-performance browser bindings for Graph Neural Network
//! operations on HNSW topology, including:
//! - GNN layer forward passes
//! - Tensor compression with adaptive level selection
//! - Differentiable search with soft attention
//! - Hierarchical forward propagation
9
10use ruvector_gnn::{
11    differentiable_search as core_differentiable_search,
12    hierarchical_forward as core_hierarchical_forward, CompressedTensor, CompressionLevel,
13    RuvectorLayer, TensorCompress,
14};
15use serde::{Deserialize, Serialize};
16use wasm_bindgen::prelude::*;
17
18/// Initialize panic hook for better error messages
19#[wasm_bindgen(start)]
20pub fn init() {
21    #[cfg(feature = "console_error_panic_hook")]
22    console_error_panic_hook::set_once();
23}
24
25// ============================================================================
26// Type Definitions for WASM
27// ============================================================================
28
29/// Query configuration for differentiable search
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[wasm_bindgen]
32pub struct SearchConfig {
33    /// Number of top results to return
34    pub k: usize,
35    /// Temperature for softmax (lower = sharper, higher = smoother)
36    pub temperature: f32,
37}
38
39#[wasm_bindgen]
40impl SearchConfig {
41    /// Create a new search configuration
42    #[wasm_bindgen(constructor)]
43    pub fn new(k: usize, temperature: f32) -> Self {
44        Self { k, temperature }
45    }
46}
47
48/// Search results with indices and weights (internal)
49#[derive(Debug, Clone, Serialize, Deserialize)]
50struct SearchResultInternal {
51    /// Indices of top-k candidates
52    indices: Vec<usize>,
53    /// Soft weights for each result
54    weights: Vec<f32>,
55}
56
57// ============================================================================
58// JsRuvectorLayer - GNN Layer Wrapper
59// ============================================================================
60
61/// Graph Neural Network layer for HNSW topology
62#[wasm_bindgen]
63pub struct JsRuvectorLayer {
64    inner: RuvectorLayer,
65    hidden_dim: usize,
66}
67
68#[wasm_bindgen]
69impl JsRuvectorLayer {
70    /// Create a new GNN layer
71    ///
72    /// # Arguments
73    /// * `input_dim` - Dimension of input node embeddings
74    /// * `hidden_dim` - Dimension of hidden representations
75    /// * `heads` - Number of attention heads
76    /// * `dropout` - Dropout rate (0.0 to 1.0)
77    #[wasm_bindgen(constructor)]
78    pub fn new(
79        input_dim: usize,
80        hidden_dim: usize,
81        heads: usize,
82        dropout: f32,
83    ) -> Result<JsRuvectorLayer, JsValue> {
84        if dropout < 0.0 || dropout > 1.0 {
85            return Err(JsValue::from_str("Dropout must be between 0.0 and 1.0"));
86        }
87
88        Ok(JsRuvectorLayer {
89            inner: RuvectorLayer::new(input_dim, hidden_dim, heads, dropout),
90            hidden_dim,
91        })
92    }
93
94    /// Forward pass through the GNN layer
95    ///
96    /// # Arguments
97    /// * `node_embedding` - Current node's embedding (Float32Array)
98    /// * `neighbor_embeddings` - Embeddings of neighbor nodes (array of Float32Arrays)
99    /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
100    ///
101    /// # Returns
102    /// Updated node embedding (Float32Array)
103    #[wasm_bindgen]
104    pub fn forward(
105        &self,
106        node_embedding: Vec<f32>,
107        neighbor_embeddings: JsValue,
108        edge_weights: Vec<f32>,
109    ) -> Result<Vec<f32>, JsValue> {
110        // Convert neighbor embeddings from JS value
111        let neighbors: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(neighbor_embeddings)
112            .map_err(|e| {
113                JsValue::from_str(&format!("Failed to parse neighbor embeddings: {}", e))
114            })?;
115
116        // Validate inputs
117        if neighbors.len() != edge_weights.len() {
118            return Err(JsValue::from_str(&format!(
119                "Number of neighbors ({}) must match number of edge weights ({})",
120                neighbors.len(),
121                edge_weights.len()
122            )));
123        }
124
125        // Call core forward
126        let result = self
127            .inner
128            .forward(&node_embedding, &neighbors, &edge_weights);
129
130        Ok(result)
131    }
132
133    /// Get the output dimension of this layer
134    #[wasm_bindgen(getter, js_name = outputDim)]
135    pub fn output_dim(&self) -> usize {
136        self.hidden_dim
137    }
138}
139
140// ============================================================================
141// JsTensorCompress - Tensor Compression Wrapper
142// ============================================================================
143
144/// Tensor compressor with adaptive level selection
145#[wasm_bindgen]
146pub struct JsTensorCompress {
147    inner: TensorCompress,
148}
149
150#[wasm_bindgen]
151impl JsTensorCompress {
152    /// Create a new tensor compressor
153    #[wasm_bindgen(constructor)]
154    pub fn new() -> Self {
155        Self {
156            inner: TensorCompress::new(),
157        }
158    }
159
160    /// Compress an embedding based on access frequency
161    ///
162    /// # Arguments
163    /// * `embedding` - The input embedding vector (Float32Array)
164    /// * `access_freq` - Access frequency in range [0.0, 1.0]
165    ///   - f > 0.8: Full precision (hot data)
166    ///   - f > 0.4: Half precision (warm data)
167    ///   - f > 0.1: 8-bit PQ (cool data)
168    ///   - f > 0.01: 4-bit PQ (cold data)
169    ///   - f <= 0.01: Binary (archive)
170    ///
171    /// # Returns
172    /// Compressed tensor as JsValue
173    #[wasm_bindgen]
174    pub fn compress(&self, embedding: Vec<f32>, access_freq: f32) -> Result<JsValue, JsValue> {
175        let compressed = self
176            .inner
177            .compress(&embedding, access_freq)
178            .map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?;
179
180        // Serialize using serde_wasm_bindgen
181        serde_wasm_bindgen::to_value(&compressed)
182            .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
183    }
184
185    /// Compress with explicit compression level
186    ///
187    /// # Arguments
188    /// * `embedding` - The input embedding vector
189    /// * `level` - Compression level ("none", "half", "pq8", "pq4", "binary")
190    ///
191    /// # Returns
192    /// Compressed tensor as JsValue
193    #[wasm_bindgen(js_name = compressWithLevel)]
194    pub fn compress_with_level(
195        &self,
196        embedding: Vec<f32>,
197        level: &str,
198    ) -> Result<JsValue, JsValue> {
199        let compression_level = match level {
200            "none" => CompressionLevel::None,
201            "half" => CompressionLevel::Half { scale: 1.0 },
202            "pq8" => CompressionLevel::PQ8 {
203                subvectors: 8,
204                centroids: 16,
205            },
206            "pq4" => CompressionLevel::PQ4 {
207                subvectors: 8,
208                outlier_threshold: 3.0,
209            },
210            "binary" => CompressionLevel::Binary { threshold: 0.0 },
211            _ => {
212                return Err(JsValue::from_str(&format!(
213                    "Unknown compression level: {}",
214                    level
215                )))
216            }
217        };
218
219        let compressed = self
220            .inner
221            .compress_with_level(&embedding, &compression_level)
222            .map_err(|e| JsValue::from_str(&format!("Compression failed: {}", e)))?;
223
224        // Serialize using serde_wasm_bindgen
225        serde_wasm_bindgen::to_value(&compressed)
226            .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
227    }
228
229    /// Decompress a compressed tensor
230    ///
231    /// # Arguments
232    /// * `compressed` - Serialized compressed tensor (JsValue)
233    ///
234    /// # Returns
235    /// Decompressed embedding vector (Float32Array)
236    #[wasm_bindgen]
237    pub fn decompress(&self, compressed: JsValue) -> Result<Vec<f32>, JsValue> {
238        let compressed_tensor: CompressedTensor = serde_wasm_bindgen::from_value(compressed)
239            .map_err(|e| JsValue::from_str(&format!("Deserialization failed: {}", e)))?;
240
241        let decompressed = self
242            .inner
243            .decompress(&compressed_tensor)
244            .map_err(|e| JsValue::from_str(&format!("Decompression failed: {}", e)))?;
245
246        Ok(decompressed)
247    }
248
249    /// Get compression ratio estimate for a given access frequency
250    ///
251    /// # Arguments
252    /// * `access_freq` - Access frequency in range [0.0, 1.0]
253    ///
254    /// # Returns
255    /// Estimated compression ratio (original_size / compressed_size)
256    #[wasm_bindgen(js_name = getCompressionRatio)]
257    pub fn get_compression_ratio(&self, access_freq: f32) -> f32 {
258        if access_freq > 0.8 {
259            1.0 // No compression
260        } else if access_freq > 0.4 {
261            2.0 // Half precision
262        } else if access_freq > 0.1 {
263            4.0 // 8-bit PQ
264        } else if access_freq > 0.01 {
265            8.0 // 4-bit PQ
266        } else {
267            32.0 // Binary
268        }
269    }
270}
271
272// ============================================================================
273// Standalone Functions
274// ============================================================================
275
276/// Differentiable search using soft attention mechanism
277///
278/// # Arguments
279/// * `query` - The query vector (Float32Array)
280/// * `candidate_embeddings` - List of candidate embedding vectors (array of Float32Arrays)
281/// * `config` - Search configuration (k and temperature)
282///
283/// # Returns
284/// Object with indices and weights for top-k candidates
285#[wasm_bindgen(js_name = differentiableSearch)]
286pub fn differentiable_search(
287    query: Vec<f32>,
288    candidate_embeddings: JsValue,
289    config: &SearchConfig,
290) -> Result<JsValue, JsValue> {
291    // Convert candidate embeddings from JS value
292    let candidates: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(candidate_embeddings)
293        .map_err(|e| JsValue::from_str(&format!("Failed to parse candidate embeddings: {}", e)))?;
294
295    // Call core search function
296    let (indices, weights) =
297        core_differentiable_search(&query, &candidates, config.k, config.temperature);
298
299    let result = SearchResultInternal { indices, weights };
300    serde_wasm_bindgen::to_value(&result)
301        .map_err(|e| JsValue::from_str(&format!("Failed to serialize result: {}", e)))
302}
303
304/// Hierarchical forward pass through multiple GNN layers
305///
306/// # Arguments
307/// * `query` - The query vector (Float32Array)
308/// * `layer_embeddings` - Embeddings organized by layer (array of arrays of Float32Arrays)
309/// * `gnn_layers` - Array of GNN layers to process through
310///
311/// # Returns
312/// Final embedding after hierarchical processing (Float32Array)
313#[wasm_bindgen(js_name = hierarchicalForward)]
314pub fn hierarchical_forward(
315    query: Vec<f32>,
316    layer_embeddings: JsValue,
317    gnn_layers: Vec<JsRuvectorLayer>,
318) -> Result<Vec<f32>, JsValue> {
319    // Convert layer embeddings from JS value
320    let embeddings: Vec<Vec<Vec<f32>>> = serde_wasm_bindgen::from_value(layer_embeddings)
321        .map_err(|e| JsValue::from_str(&format!("Failed to parse layer embeddings: {}", e)))?;
322
323    // Extract inner layers
324    let core_layers: Vec<RuvectorLayer> = gnn_layers.iter().map(|l| l.inner.clone()).collect();
325
326    // Call core function
327    let result = core_hierarchical_forward(&query, &embeddings, &core_layers);
328
329    Ok(result)
330}
331
332// ============================================================================
333// Utility Functions
334// ============================================================================
335
336/// Get version information
337#[wasm_bindgen]
338pub fn version() -> String {
339    env!("CARGO_PKG_VERSION").to_string()
340}
341
342/// Compute cosine similarity between two vectors
343///
344/// # Arguments
345/// * `a` - First vector (Float32Array)
346/// * `b` - Second vector (Float32Array)
347///
348/// # Returns
349/// Cosine similarity score [-1.0, 1.0]
350#[wasm_bindgen(js_name = cosineSimilarity)]
351pub fn cosine_similarity(a: Vec<f32>, b: Vec<f32>) -> Result<f32, JsValue> {
352    if a.len() != b.len() {
353        return Err(JsValue::from_str(&format!(
354            "Vector dimensions must match: {} vs {}",
355            a.len(),
356            b.len()
357        )));
358    }
359
360    let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
361    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
362    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
363
364    if norm_a == 0.0 || norm_b == 0.0 {
365        Ok(0.0)
366    } else {
367        Ok(dot_product / (norm_a * norm_b))
368    }
369}
370
371// ============================================================================
372// Tests
373// ============================================================================
374
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// The reported crate version must be a non-empty string.
    #[wasm_bindgen_test]
    fn test_version() {
        let reported = version();
        assert!(!reported.is_empty());
    }

    /// A layer with an in-range dropout rate constructs successfully.
    #[wasm_bindgen_test]
    fn test_ruvector_layer_creation() {
        assert!(JsRuvectorLayer::new(4, 8, 2, 0.1).is_ok());
    }

    /// Ratio estimates for the two highest access-frequency bands.
    #[wasm_bindgen_test]
    fn test_tensor_compress_creation() {
        let compressor = JsTensorCompress::new();
        let hot = compressor.get_compression_ratio(1.0);
        let warm = compressor.get_compression_ratio(0.5);
        assert_eq!(hot, 1.0);
        assert_eq!(warm, 2.0);
    }

    /// Identical vectors have similarity 1 (within tolerance).
    #[wasm_bindgen_test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(unit_x.clone(), unit_x).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    /// The constructor stores both fields unchanged.
    #[wasm_bindgen_test]
    fn test_search_config() {
        let SearchConfig { k, temperature } = SearchConfig::new(5, 1.0);
        assert_eq!(k, 5);
        assert_eq!(temperature, 1.0);
    }
}