ruvector_gnn_node/
lib.rs

1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12    compress::{
13        CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14        TensorCompress as RustTensorCompress,
15    },
16    layer::RuvectorLayer as RustRuvectorLayer,
17    search::{
18        differentiable_search as rust_differentiable_search,
19        hierarchical_forward as rust_hierarchical_forward,
20    },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28    inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33    /// Create a new Ruvector GNN layer
34    ///
35    /// # Arguments
36    /// * `input_dim` - Dimension of input node embeddings
37    /// * `hidden_dim` - Dimension of hidden representations
38    /// * `heads` - Number of attention heads
39    /// * `dropout` - Dropout rate (0.0 to 1.0)
40    ///
41    /// # Example
42    /// ```javascript
43    /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44    /// ```
45    #[napi(constructor)]
46    pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47        if dropout < 0.0 || dropout > 1.0 {
48            return Err(Error::new(
49                Status::InvalidArg,
50                "Dropout must be between 0.0 and 1.0".to_string(),
51            ));
52        }
53
54        Ok(Self {
55            inner: RustRuvectorLayer::new(
56                input_dim as usize,
57                hidden_dim as usize,
58                heads as usize,
59                dropout as f32,
60            ),
61        })
62    }
63
64    /// Forward pass through the GNN layer
65    ///
66    /// # Arguments
67    /// * `node_embedding` - Current node's embedding
68    /// * `neighbor_embeddings` - Embeddings of neighbor nodes
69    /// * `edge_weights` - Weights of edges to neighbors
70    ///
71    /// # Returns
72    /// Updated node embedding
73    ///
74    /// # Example
75    /// ```javascript
76    /// const node = [1.0, 2.0, 3.0, 4.0];
77    /// const neighbors = [[0.5, 1.0, 1.5, 2.0], [2.0, 3.0, 4.0, 5.0]];
78    /// const weights = [0.3, 0.7];
79    /// const output = layer.forward(node, neighbors, weights);
80    /// ```
81    #[napi]
82    pub fn forward(
83        &self,
84        node_embedding: Vec<f64>,
85        neighbor_embeddings: Vec<Vec<f64>>,
86        edge_weights: Vec<f64>,
87    ) -> Result<Vec<f64>> {
88        // Convert f64 to f32
89        let node_f32: Vec<f32> = node_embedding.iter().map(|&x| x as f32).collect();
90        let neighbors_f32: Vec<Vec<f32>> = neighbor_embeddings
91            .iter()
92            .map(|v| v.iter().map(|&x| x as f32).collect())
93            .collect();
94        let weights_f32: Vec<f32> = edge_weights.iter().map(|&x| x as f32).collect();
95
96        let result = self.inner.forward(&node_f32, &neighbors_f32, &weights_f32);
97
98        // Convert back to f64
99        Ok(result.iter().map(|&x| x as f64).collect())
100    }
101
102    /// Serialize the layer to JSON
103    #[napi]
104    pub fn to_json(&self) -> Result<String> {
105        serde_json::to_string(&self.inner).map_err(|e| {
106            Error::new(
107                Status::GenericFailure,
108                format!("Serialization error: {}", e),
109            )
110        })
111    }
112
113    /// Deserialize the layer from JSON
114    #[napi(factory)]
115    pub fn from_json(json: String) -> Result<Self> {
116        let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
117            Error::new(
118                Status::GenericFailure,
119                format!("Deserialization error: {}", e),
120            )
121        })?;
122        Ok(Self { inner })
123    }
124}
125
126// ==================== TensorCompress Bindings ====================
127
128/// Compression level for tensor compression
129#[napi(object)]
130pub struct CompressionLevelConfig {
131    /// Type of compression: "none", "half", "pq8", "pq4", "binary"
132    pub level_type: String,
133    /// Scale factor (for "half" compression)
134    pub scale: Option<f64>,
135    /// Number of subvectors (for PQ compression)
136    pub subvectors: Option<u32>,
137    /// Number of centroids (for PQ8)
138    pub centroids: Option<u32>,
139    /// Outlier threshold (for PQ4)
140    pub outlier_threshold: Option<f64>,
141    /// Binary threshold (for binary compression)
142    pub threshold: Option<f64>,
143}
144
145impl CompressionLevelConfig {
146    fn to_rust(&self) -> Result<RustCompressionLevel> {
147        match self.level_type.as_str() {
148            "none" => Ok(RustCompressionLevel::None),
149            "half" => Ok(RustCompressionLevel::Half {
150                scale: self.scale.unwrap_or(1.0) as f32,
151            }),
152            "pq8" => Ok(RustCompressionLevel::PQ8 {
153                subvectors: self.subvectors.unwrap_or(8) as u8,
154                centroids: self.centroids.unwrap_or(16) as u8,
155            }),
156            "pq4" => Ok(RustCompressionLevel::PQ4 {
157                subvectors: self.subvectors.unwrap_or(8) as u8,
158                outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
159            }),
160            "binary" => Ok(RustCompressionLevel::Binary {
161                threshold: self.threshold.unwrap_or(0.0) as f32,
162            }),
163            _ => Err(Error::new(
164                Status::InvalidArg,
165                format!("Invalid compression level: {}", self.level_type),
166            )),
167        }
168    }
169}
170
171/// Tensor compressor with adaptive level selection
172#[napi]
173pub struct TensorCompress {
174    inner: RustTensorCompress,
175}
176
177#[napi]
178impl TensorCompress {
179    /// Create a new tensor compressor
180    ///
181    /// # Example
182    /// ```javascript
183    /// const compressor = new TensorCompress();
184    /// ```
185    #[napi(constructor)]
186    pub fn new() -> Self {
187        Self {
188            inner: RustTensorCompress::new(),
189        }
190    }
191
192    /// Compress an embedding based on access frequency
193    ///
194    /// # Arguments
195    /// * `embedding` - The input embedding vector
196    /// * `access_freq` - Access frequency in range [0.0, 1.0]
197    ///
198    /// # Returns
199    /// Compressed tensor as JSON string
200    ///
201    /// # Example
202    /// ```javascript
203    /// const embedding = [1.0, 2.0, 3.0, 4.0];
204    /// const compressed = compressor.compress(embedding, 0.5);
205    /// ```
206    #[napi]
207    pub fn compress(&self, embedding: Vec<f64>, access_freq: f64) -> Result<String> {
208        let embedding_f32: Vec<f32> = embedding.iter().map(|&x| x as f32).collect();
209
210        let compressed = self
211            .inner
212            .compress(&embedding_f32, access_freq as f32)
213            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
214
215        serde_json::to_string(&compressed).map_err(|e| {
216            Error::new(
217                Status::GenericFailure,
218                format!("Serialization error: {}", e),
219            )
220        })
221    }
222
223    /// Compress with explicit compression level
224    ///
225    /// # Arguments
226    /// * `embedding` - The input embedding vector
227    /// * `level` - Compression level configuration
228    ///
229    /// # Returns
230    /// Compressed tensor as JSON string
231    ///
232    /// # Example
233    /// ```javascript
234    /// const embedding = [1.0, 2.0, 3.0, 4.0];
235    /// const level = { level_type: "half", scale: 1.0 };
236    /// const compressed = compressor.compressWithLevel(embedding, level);
237    /// ```
238    #[napi]
239    pub fn compress_with_level(
240        &self,
241        embedding: Vec<f64>,
242        level: CompressionLevelConfig,
243    ) -> Result<String> {
244        let embedding_f32: Vec<f32> = embedding.iter().map(|&x| x as f32).collect();
245        let rust_level = level.to_rust()?;
246
247        let compressed = self
248            .inner
249            .compress_with_level(&embedding_f32, &rust_level)
250            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
251
252        serde_json::to_string(&compressed).map_err(|e| {
253            Error::new(
254                Status::GenericFailure,
255                format!("Serialization error: {}", e),
256            )
257        })
258    }
259
260    /// Decompress a compressed tensor
261    ///
262    /// # Arguments
263    /// * `compressed_json` - Compressed tensor as JSON string
264    ///
265    /// # Returns
266    /// Decompressed embedding vector
267    ///
268    /// # Example
269    /// ```javascript
270    /// const decompressed = compressor.decompress(compressed);
271    /// ```
272    #[napi]
273    pub fn decompress(&self, compressed_json: String) -> Result<Vec<f64>> {
274        let compressed: RustCompressedTensor =
275            serde_json::from_str(&compressed_json).map_err(|e| {
276                Error::new(
277                    Status::GenericFailure,
278                    format!("Deserialization error: {}", e),
279                )
280            })?;
281
282        let result = self.inner.decompress(&compressed).map_err(|e| {
283            Error::new(
284                Status::GenericFailure,
285                format!("Decompression error: {}", e),
286            )
287        })?;
288
289        Ok(result.iter().map(|&x| x as f64).collect())
290    }
291}
292
293// ==================== Search Functions ====================
294
295/// Result from differentiable search
296#[napi(object)]
297pub struct SearchResult {
298    /// Indices of top-k candidates
299    pub indices: Vec<u32>,
300    /// Soft weights for top-k candidates
301    pub weights: Vec<f64>,
302}
303
304/// Differentiable search using soft attention mechanism
305///
306/// # Arguments
307/// * `query` - The query vector
308/// * `candidate_embeddings` - List of candidate embedding vectors
309/// * `k` - Number of top results to return
310/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
311///
312/// # Returns
313/// Search result with indices and soft weights
314///
315/// # Example
316/// ```javascript
317/// const query = [1.0, 0.0, 0.0];
318/// const candidates = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]];
319/// const result = differentiableSearch(query, candidates, 2, 1.0);
320/// console.log(result.indices); // [0, 1]
321/// console.log(result.weights); // [0.x, 0.y]
322/// ```
323#[napi]
324pub fn differentiable_search(
325    query: Vec<f64>,
326    candidate_embeddings: Vec<Vec<f64>>,
327    k: u32,
328    temperature: f64,
329) -> Result<SearchResult> {
330    let query_f32: Vec<f32> = query.iter().map(|&x| x as f32).collect();
331    let candidates_f32: Vec<Vec<f32>> = candidate_embeddings
332        .iter()
333        .map(|v| v.iter().map(|&x| x as f32).collect())
334        .collect();
335
336    let (indices, weights) =
337        rust_differentiable_search(&query_f32, &candidates_f32, k as usize, temperature as f32);
338
339    Ok(SearchResult {
340        indices: indices.iter().map(|&i| i as u32).collect(),
341        weights: weights.iter().map(|&w| w as f64).collect(),
342    })
343}
344
345/// Hierarchical forward pass through GNN layers
346///
347/// # Arguments
348/// * `query` - The query vector
349/// * `layer_embeddings` - Embeddings organized by layer
350/// * `gnn_layers_json` - JSON array of serialized GNN layers
351///
352/// # Returns
353/// Final embedding after hierarchical processing
354///
355/// # Example
356/// ```javascript
357/// const query = [1.0, 0.0];
358/// const layerEmbeddings = [[[1.0, 0.0], [0.0, 1.0]]];
359/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
360/// const layers = [layer1.toJson()];
361/// const result = hierarchicalForward(query, layerEmbeddings, layers);
362/// ```
363#[napi]
364pub fn hierarchical_forward(
365    query: Vec<f64>,
366    layer_embeddings: Vec<Vec<Vec<f64>>>,
367    gnn_layers_json: Vec<String>,
368) -> Result<Vec<f64>> {
369    let query_f32: Vec<f32> = query.iter().map(|&x| x as f32).collect();
370
371    let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
372        .iter()
373        .map(|layer| {
374            layer
375                .iter()
376                .map(|v| v.iter().map(|&x| x as f32).collect())
377                .collect()
378        })
379        .collect();
380
381    let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
382        .iter()
383        .map(|json| {
384            serde_json::from_str(json).map_err(|e| {
385                Error::new(
386                    Status::GenericFailure,
387                    format!("Layer deserialization error: {}", e),
388                )
389            })
390        })
391        .collect::<Result<Vec<_>>>()?;
392
393    let result = rust_hierarchical_forward(&query_f32, &embeddings_f32, &gnn_layers);
394
395    Ok(result.iter().map(|&x| x as f64).collect())
396}
397
398// ==================== Helper Functions ====================
399
400/// Get the compression level that would be selected for a given access frequency
401///
402/// # Arguments
403/// * `access_freq` - Access frequency in range [0.0, 1.0]
404///
405/// # Returns
406/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
407///
408/// # Example
409/// ```javascript
410/// const level = getCompressionLevel(0.9); // "none" (hot data)
411/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
412/// ```
413#[napi]
pub fn get_compression_level(access_freq: f64) -> String {
    // Tiers ordered from hottest to coldest. Each lower bound is exclusive:
    // the first tier whose bound is strictly below `access_freq` wins.
    const TIERS: [(f64, &str); 4] = [
        (0.8, "none"),
        (0.4, "half"),
        (0.1, "pq8"),
        (0.01, "pq4"),
    ];

    // Anything that clears no bound (including NaN, for which every
    // comparison is false) falls through to the coldest tier.
    TIERS
        .iter()
        .find(|&&(lower_bound, _)| access_freq > lower_bound)
        .map_or("binary", |&(_, label)| label)
        .to_owned()
}
427
428/// Module initialization
429#[napi]
pub fn init() -> String {
    // Fixed status banner; no other work is performed here.
    String::from("Ruvector GNN Node.js bindings initialized")
}