ruvector_gnn_node/
lib.rs

1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12    compress::{
13        CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14        TensorCompress as RustTensorCompress,
15    },
16    layer::RuvectorLayer as RustRuvectorLayer,
17    search::{
18        differentiable_search as rust_differentiable_search,
19        hierarchical_forward as rust_hierarchical_forward,
20    },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28    inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33    /// Create a new Ruvector GNN layer
34    ///
35    /// # Arguments
36    /// * `input_dim` - Dimension of input node embeddings
37    /// * `hidden_dim` - Dimension of hidden representations
38    /// * `heads` - Number of attention heads
39    /// * `dropout` - Dropout rate (0.0 to 1.0)
40    ///
41    /// # Example
42    /// ```javascript
43    /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44    /// ```
45    #[napi(constructor)]
46    pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47        if dropout < 0.0 || dropout > 1.0 {
48            return Err(Error::new(
49                Status::InvalidArg,
50                "Dropout must be between 0.0 and 1.0".to_string(),
51            ));
52        }
53
54        Ok(Self {
55            inner: RustRuvectorLayer::new(
56                input_dim as usize,
57                hidden_dim as usize,
58                heads as usize,
59                dropout as f32,
60            ),
61        })
62    }
63
64    /// Forward pass through the GNN layer
65    ///
66    /// # Arguments
67    /// * `node_embedding` - Current node's embedding (Float32Array)
68    /// * `neighbor_embeddings` - Embeddings of neighbor nodes (Array of Float32Array)
69    /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
70    ///
71    /// # Returns
72    /// Updated node embedding as Float32Array
73    ///
74    /// # Example
75    /// ```javascript
76    /// const node = new Float32Array([1.0, 2.0, 3.0, 4.0]);
77    /// const neighbors = [new Float32Array([0.5, 1.0, 1.5, 2.0]), new Float32Array([2.0, 3.0, 4.0, 5.0])];
78    /// const weights = new Float32Array([0.3, 0.7]);
79    /// const output = layer.forward(node, neighbors, weights);
80    /// ```
81    #[napi]
82    pub fn forward(
83        &self,
84        node_embedding: Float32Array,
85        neighbor_embeddings: Vec<Float32Array>,
86        edge_weights: Float32Array,
87    ) -> Result<Float32Array> {
88        let node_slice = node_embedding.as_ref();
89        let neighbors_vec: Vec<Vec<f32>> = neighbor_embeddings
90            .into_iter()
91            .map(|arr| arr.to_vec())
92            .collect();
93        let weights_slice = edge_weights.as_ref();
94
95        let result = self
96            .inner
97            .forward(node_slice, &neighbors_vec, weights_slice);
98
99        Ok(Float32Array::new(result))
100    }
101
102    /// Serialize the layer to JSON
103    #[napi]
104    pub fn to_json(&self) -> Result<String> {
105        serde_json::to_string(&self.inner).map_err(|e| {
106            Error::new(
107                Status::GenericFailure,
108                format!("Serialization error: {}", e),
109            )
110        })
111    }
112
113    /// Deserialize the layer from JSON
114    #[napi(factory)]
115    pub fn from_json(json: String) -> Result<Self> {
116        let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
117            Error::new(
118                Status::GenericFailure,
119                format!("Deserialization error: {}", e),
120            )
121        })?;
122        Ok(Self { inner })
123    }
124}
125
126// ==================== TensorCompress Bindings ====================
127
128/// Compression level for tensor compression
129#[napi(object)]
130pub struct CompressionLevelConfig {
131    /// Type of compression: "none", "half", "pq8", "pq4", "binary"
132    pub level_type: String,
133    /// Scale factor (for "half" compression)
134    pub scale: Option<f64>,
135    /// Number of subvectors (for PQ compression)
136    pub subvectors: Option<u32>,
137    /// Number of centroids (for PQ8)
138    pub centroids: Option<u32>,
139    /// Outlier threshold (for PQ4)
140    pub outlier_threshold: Option<f64>,
141    /// Binary threshold (for binary compression)
142    pub threshold: Option<f64>,
143}
144
145impl CompressionLevelConfig {
146    fn to_rust(&self) -> Result<RustCompressionLevel> {
147        match self.level_type.as_str() {
148            "none" => Ok(RustCompressionLevel::None),
149            "half" => Ok(RustCompressionLevel::Half {
150                scale: self.scale.unwrap_or(1.0) as f32,
151            }),
152            "pq8" => Ok(RustCompressionLevel::PQ8 {
153                subvectors: self.subvectors.unwrap_or(8) as u8,
154                centroids: self.centroids.unwrap_or(16) as u8,
155            }),
156            "pq4" => Ok(RustCompressionLevel::PQ4 {
157                subvectors: self.subvectors.unwrap_or(8) as u8,
158                outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
159            }),
160            "binary" => Ok(RustCompressionLevel::Binary {
161                threshold: self.threshold.unwrap_or(0.0) as f32,
162            }),
163            _ => Err(Error::new(
164                Status::InvalidArg,
165                format!("Invalid compression level: {}", self.level_type),
166            )),
167        }
168    }
169}
170
171/// Tensor compressor with adaptive level selection
172#[napi]
173pub struct TensorCompress {
174    inner: RustTensorCompress,
175}
176
177#[napi]
178impl TensorCompress {
179    /// Create a new tensor compressor
180    ///
181    /// # Example
182    /// ```javascript
183    /// const compressor = new TensorCompress();
184    /// ```
185    #[napi(constructor)]
186    pub fn new() -> Self {
187        Self {
188            inner: RustTensorCompress::new(),
189        }
190    }
191
192    /// Compress an embedding based on access frequency
193    ///
194    /// # Arguments
195    /// * `embedding` - The input embedding vector (Float32Array)
196    /// * `access_freq` - Access frequency in range [0.0, 1.0]
197    ///
198    /// # Returns
199    /// Compressed tensor as JSON string
200    ///
201    /// # Example
202    /// ```javascript
203    /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
204    /// const compressed = compressor.compress(embedding, 0.5);
205    /// ```
206    #[napi]
207    pub fn compress(&self, embedding: Float32Array, access_freq: f64) -> Result<String> {
208        let embedding_slice = embedding.as_ref();
209
210        let compressed = self
211            .inner
212            .compress(embedding_slice, access_freq as f32)
213            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
214
215        serde_json::to_string(&compressed).map_err(|e| {
216            Error::new(
217                Status::GenericFailure,
218                format!("Serialization error: {}", e),
219            )
220        })
221    }
222
223    /// Compress with explicit compression level
224    ///
225    /// # Arguments
226    /// * `embedding` - The input embedding vector (Float32Array)
227    /// * `level` - Compression level configuration
228    ///
229    /// # Returns
230    /// Compressed tensor as JSON string
231    ///
232    /// # Example
233    /// ```javascript
234    /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
235    /// const level = { level_type: "half", scale: 1.0 };
236    /// const compressed = compressor.compressWithLevel(embedding, level);
237    /// ```
238    #[napi]
239    pub fn compress_with_level(
240        &self,
241        embedding: Float32Array,
242        level: CompressionLevelConfig,
243    ) -> Result<String> {
244        let embedding_slice = embedding.as_ref();
245        let rust_level = level.to_rust()?;
246
247        let compressed = self
248            .inner
249            .compress_with_level(embedding_slice, &rust_level)
250            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
251
252        serde_json::to_string(&compressed).map_err(|e| {
253            Error::new(
254                Status::GenericFailure,
255                format!("Serialization error: {}", e),
256            )
257        })
258    }
259
260    /// Decompress a compressed tensor
261    ///
262    /// # Arguments
263    /// * `compressed_json` - Compressed tensor as JSON string
264    ///
265    /// # Returns
266    /// Decompressed embedding vector as Float32Array
267    ///
268    /// # Example
269    /// ```javascript
270    /// const decompressed = compressor.decompress(compressed);
271    /// ```
272    #[napi]
273    pub fn decompress(&self, compressed_json: String) -> Result<Float32Array> {
274        let compressed: RustCompressedTensor =
275            serde_json::from_str(&compressed_json).map_err(|e| {
276                Error::new(
277                    Status::GenericFailure,
278                    format!("Deserialization error: {}", e),
279                )
280            })?;
281
282        let result = self.inner.decompress(&compressed).map_err(|e| {
283            Error::new(
284                Status::GenericFailure,
285                format!("Decompression error: {}", e),
286            )
287        })?;
288
289        Ok(Float32Array::new(result))
290    }
291}
292
293// ==================== Search Functions ====================
294
295/// Result from differentiable search
296#[napi(object)]
297pub struct SearchResult {
298    /// Indices of top-k candidates
299    pub indices: Vec<u32>,
300    /// Soft weights for top-k candidates
301    pub weights: Vec<f64>,
302}
303
304/// Differentiable search using soft attention mechanism
305///
306/// # Arguments
307/// * `query` - The query vector (Float32Array)
308/// * `candidate_embeddings` - List of candidate embedding vectors (Array of Float32Array)
309/// * `k` - Number of top results to return
310/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
311///
312/// # Returns
313/// Search result with indices and soft weights
314///
315/// # Example
316/// ```javascript
317/// const query = new Float32Array([1.0, 0.0, 0.0]);
318/// const candidates = [new Float32Array([1.0, 0.0, 0.0]), new Float32Array([0.9, 0.1, 0.0]), new Float32Array([0.0, 1.0, 0.0])];
319/// const result = differentiableSearch(query, candidates, 2, 1.0);
320/// console.log(result.indices); // [0, 1]
321/// console.log(result.weights); // [0.x, 0.y]
322/// ```
323#[napi]
324pub fn differentiable_search(
325    query: Float32Array,
326    candidate_embeddings: Vec<Float32Array>,
327    k: u32,
328    temperature: f64,
329) -> Result<SearchResult> {
330    let query_slice = query.as_ref();
331    let candidates_vec: Vec<Vec<f32>> = candidate_embeddings
332        .into_iter()
333        .map(|arr| arr.to_vec())
334        .collect();
335
336    let (indices, weights) =
337        rust_differentiable_search(query_slice, &candidates_vec, k as usize, temperature as f32);
338
339    Ok(SearchResult {
340        indices: indices.iter().map(|&i| i as u32).collect(),
341        weights: weights.iter().map(|&w| w as f64).collect(),
342    })
343}
344
345/// Hierarchical forward pass through GNN layers
346///
347/// # Arguments
348/// * `query` - The query vector (Float32Array)
349/// * `layer_embeddings` - Embeddings organized by layer (Array of Array of Float32Array)
350/// * `gnn_layers_json` - JSON array of serialized GNN layers
351///
352/// # Returns
353/// Final embedding after hierarchical processing as Float32Array
354///
355/// # Example
356/// ```javascript
357/// const query = new Float32Array([1.0, 0.0]);
358/// const layerEmbeddings = [[new Float32Array([1.0, 0.0]), new Float32Array([0.0, 1.0])]];
359/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
360/// const layers = [layer1.toJson()];
361/// const result = hierarchicalForward(query, layerEmbeddings, layers);
362/// ```
363#[napi]
364pub fn hierarchical_forward(
365    query: Float32Array,
366    layer_embeddings: Vec<Vec<Float32Array>>,
367    gnn_layers_json: Vec<String>,
368) -> Result<Float32Array> {
369    let query_slice = query.as_ref();
370
371    let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
372        .into_iter()
373        .map(|layer| layer.into_iter().map(|arr| arr.to_vec()).collect())
374        .collect();
375
376    let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
377        .iter()
378        .map(|json| {
379            serde_json::from_str(json).map_err(|e| {
380                Error::new(
381                    Status::GenericFailure,
382                    format!("Layer deserialization error: {}", e),
383                )
384            })
385        })
386        .collect::<Result<Vec<_>>>()?;
387
388    let result = rust_hierarchical_forward(query_slice, &embeddings_f32, &gnn_layers);
389
390    Ok(Float32Array::new(result))
391}
392
393// ==================== Helper Functions ====================
394
395/// Get the compression level that would be selected for a given access frequency
396///
397/// # Arguments
398/// * `access_freq` - Access frequency in range [0.0, 1.0]
399///
400/// # Returns
401/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
402///
403/// # Example
404/// ```javascript
405/// const level = getCompressionLevel(0.9); // "none" (hot data)
406/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
407/// ```
408#[napi]
pub fn get_compression_level(access_freq: f64) -> String {
    // Tiers ordered hottest-first. A tier applies when the frequency is
    // strictly above its lower bound; anything that matches no tier
    // (including NaN, for which every comparison is false) is "binary".
    // NOTE(review): presumably mirrors the thresholds used by the library's
    // adaptive `compress` — confirm against ruvector_gnn.
    const TIERS: [(f64, &str); 4] = [
        (0.8, "none"),
        (0.4, "half"),
        (0.1, "pq8"),
        (0.01, "pq4"),
    ];

    TIERS
        .iter()
        .find(|(lower_bound, _)| access_freq > *lower_bound)
        .map_or("binary", |&(_, label)| label)
        .to_string()
}
422
423/// Module initialization
424#[napi]
pub fn init() -> String {
    // Fixed banner returned to the caller.
    String::from("Ruvector GNN Node.js bindings initialized")
}