ruvector_gnn_node/
lib.rs

1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12    compress::{
13        CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14        TensorCompress as RustTensorCompress,
15    },
16    layer::RuvectorLayer as RustRuvectorLayer,
17    search::{
18        differentiable_search as rust_differentiable_search,
19        hierarchical_forward as rust_hierarchical_forward,
20    },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28    inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33    /// Create a new Ruvector GNN layer
34    ///
35    /// # Arguments
36    /// * `input_dim` - Dimension of input node embeddings
37    /// * `hidden_dim` - Dimension of hidden representations
38    /// * `heads` - Number of attention heads
39    /// * `dropout` - Dropout rate (0.0 to 1.0)
40    ///
41    /// # Example
42    /// ```javascript
43    /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44    /// ```
45    #[napi(constructor)]
46    pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47        if dropout < 0.0 || dropout > 1.0 {
48            return Err(Error::new(
49                Status::InvalidArg,
50                "Dropout must be between 0.0 and 1.0".to_string(),
51            ));
52        }
53
54        Ok(Self {
55            inner: RustRuvectorLayer::new(
56                input_dim as usize,
57                hidden_dim as usize,
58                heads as usize,
59                dropout as f32,
60            ),
61        })
62    }
63
64    /// Forward pass through the GNN layer
65    ///
66    /// # Arguments
67    /// * `node_embedding` - Current node's embedding (Float32Array)
68    /// * `neighbor_embeddings` - Embeddings of neighbor nodes (Array of Float32Array)
69    /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
70    ///
71    /// # Returns
72    /// Updated node embedding as Float32Array
73    ///
74    /// # Example
75    /// ```javascript
76    /// const node = new Float32Array([1.0, 2.0, 3.0, 4.0]);
77    /// const neighbors = [new Float32Array([0.5, 1.0, 1.5, 2.0]), new Float32Array([2.0, 3.0, 4.0, 5.0])];
78    /// const weights = new Float32Array([0.3, 0.7]);
79    /// const output = layer.forward(node, neighbors, weights);
80    /// ```
81    #[napi]
82    pub fn forward(
83        &self,
84        node_embedding: Float32Array,
85        neighbor_embeddings: Vec<Float32Array>,
86        edge_weights: Float32Array,
87    ) -> Result<Float32Array> {
88        let node_slice = node_embedding.as_ref();
89        let neighbors_vec: Vec<Vec<f32>> = neighbor_embeddings
90            .into_iter()
91            .map(|arr| arr.to_vec())
92            .collect();
93        let weights_slice = edge_weights.as_ref();
94
95        let result = self.inner.forward(node_slice, &neighbors_vec, weights_slice);
96
97        Ok(Float32Array::new(result))
98    }
99
100    /// Serialize the layer to JSON
101    #[napi]
102    pub fn to_json(&self) -> Result<String> {
103        serde_json::to_string(&self.inner).map_err(|e| {
104            Error::new(
105                Status::GenericFailure,
106                format!("Serialization error: {}", e),
107            )
108        })
109    }
110
111    /// Deserialize the layer from JSON
112    #[napi(factory)]
113    pub fn from_json(json: String) -> Result<Self> {
114        let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
115            Error::new(
116                Status::GenericFailure,
117                format!("Deserialization error: {}", e),
118            )
119        })?;
120        Ok(Self { inner })
121    }
122}
123
124// ==================== TensorCompress Bindings ====================
125
126/// Compression level for tensor compression
127#[napi(object)]
128pub struct CompressionLevelConfig {
129    /// Type of compression: "none", "half", "pq8", "pq4", "binary"
130    pub level_type: String,
131    /// Scale factor (for "half" compression)
132    pub scale: Option<f64>,
133    /// Number of subvectors (for PQ compression)
134    pub subvectors: Option<u32>,
135    /// Number of centroids (for PQ8)
136    pub centroids: Option<u32>,
137    /// Outlier threshold (for PQ4)
138    pub outlier_threshold: Option<f64>,
139    /// Binary threshold (for binary compression)
140    pub threshold: Option<f64>,
141}
142
143impl CompressionLevelConfig {
144    fn to_rust(&self) -> Result<RustCompressionLevel> {
145        match self.level_type.as_str() {
146            "none" => Ok(RustCompressionLevel::None),
147            "half" => Ok(RustCompressionLevel::Half {
148                scale: self.scale.unwrap_or(1.0) as f32,
149            }),
150            "pq8" => Ok(RustCompressionLevel::PQ8 {
151                subvectors: self.subvectors.unwrap_or(8) as u8,
152                centroids: self.centroids.unwrap_or(16) as u8,
153            }),
154            "pq4" => Ok(RustCompressionLevel::PQ4 {
155                subvectors: self.subvectors.unwrap_or(8) as u8,
156                outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
157            }),
158            "binary" => Ok(RustCompressionLevel::Binary {
159                threshold: self.threshold.unwrap_or(0.0) as f32,
160            }),
161            _ => Err(Error::new(
162                Status::InvalidArg,
163                format!("Invalid compression level: {}", self.level_type),
164            )),
165        }
166    }
167}
168
169/// Tensor compressor with adaptive level selection
170#[napi]
171pub struct TensorCompress {
172    inner: RustTensorCompress,
173}
174
175#[napi]
176impl TensorCompress {
177    /// Create a new tensor compressor
178    ///
179    /// # Example
180    /// ```javascript
181    /// const compressor = new TensorCompress();
182    /// ```
183    #[napi(constructor)]
184    pub fn new() -> Self {
185        Self {
186            inner: RustTensorCompress::new(),
187        }
188    }
189
190    /// Compress an embedding based on access frequency
191    ///
192    /// # Arguments
193    /// * `embedding` - The input embedding vector (Float32Array)
194    /// * `access_freq` - Access frequency in range [0.0, 1.0]
195    ///
196    /// # Returns
197    /// Compressed tensor as JSON string
198    ///
199    /// # Example
200    /// ```javascript
201    /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
202    /// const compressed = compressor.compress(embedding, 0.5);
203    /// ```
204    #[napi]
205    pub fn compress(&self, embedding: Float32Array, access_freq: f64) -> Result<String> {
206        let embedding_slice = embedding.as_ref();
207
208        let compressed = self
209            .inner
210            .compress(embedding_slice, access_freq as f32)
211            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
212
213        serde_json::to_string(&compressed).map_err(|e| {
214            Error::new(
215                Status::GenericFailure,
216                format!("Serialization error: {}", e),
217            )
218        })
219    }
220
221    /// Compress with explicit compression level
222    ///
223    /// # Arguments
224    /// * `embedding` - The input embedding vector (Float32Array)
225    /// * `level` - Compression level configuration
226    ///
227    /// # Returns
228    /// Compressed tensor as JSON string
229    ///
230    /// # Example
231    /// ```javascript
232    /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
233    /// const level = { level_type: "half", scale: 1.0 };
234    /// const compressed = compressor.compressWithLevel(embedding, level);
235    /// ```
236    #[napi]
237    pub fn compress_with_level(
238        &self,
239        embedding: Float32Array,
240        level: CompressionLevelConfig,
241    ) -> Result<String> {
242        let embedding_slice = embedding.as_ref();
243        let rust_level = level.to_rust()?;
244
245        let compressed = self
246            .inner
247            .compress_with_level(embedding_slice, &rust_level)
248            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
249
250        serde_json::to_string(&compressed).map_err(|e| {
251            Error::new(
252                Status::GenericFailure,
253                format!("Serialization error: {}", e),
254            )
255        })
256    }
257
258    /// Decompress a compressed tensor
259    ///
260    /// # Arguments
261    /// * `compressed_json` - Compressed tensor as JSON string
262    ///
263    /// # Returns
264    /// Decompressed embedding vector as Float32Array
265    ///
266    /// # Example
267    /// ```javascript
268    /// const decompressed = compressor.decompress(compressed);
269    /// ```
270    #[napi]
271    pub fn decompress(&self, compressed_json: String) -> Result<Float32Array> {
272        let compressed: RustCompressedTensor =
273            serde_json::from_str(&compressed_json).map_err(|e| {
274                Error::new(
275                    Status::GenericFailure,
276                    format!("Deserialization error: {}", e),
277                )
278            })?;
279
280        let result = self.inner.decompress(&compressed).map_err(|e| {
281            Error::new(
282                Status::GenericFailure,
283                format!("Decompression error: {}", e),
284            )
285        })?;
286
287        Ok(Float32Array::new(result))
288    }
289}
290
291// ==================== Search Functions ====================
292
293/// Result from differentiable search
294#[napi(object)]
295pub struct SearchResult {
296    /// Indices of top-k candidates
297    pub indices: Vec<u32>,
298    /// Soft weights for top-k candidates
299    pub weights: Vec<f64>,
300}
301
302/// Differentiable search using soft attention mechanism
303///
304/// # Arguments
305/// * `query` - The query vector (Float32Array)
306/// * `candidate_embeddings` - List of candidate embedding vectors (Array of Float32Array)
307/// * `k` - Number of top results to return
308/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
309///
310/// # Returns
311/// Search result with indices and soft weights
312///
313/// # Example
314/// ```javascript
315/// const query = new Float32Array([1.0, 0.0, 0.0]);
316/// const candidates = [new Float32Array([1.0, 0.0, 0.0]), new Float32Array([0.9, 0.1, 0.0]), new Float32Array([0.0, 1.0, 0.0])];
317/// const result = differentiableSearch(query, candidates, 2, 1.0);
318/// console.log(result.indices); // [0, 1]
319/// console.log(result.weights); // [0.x, 0.y]
320/// ```
321#[napi]
322pub fn differentiable_search(
323    query: Float32Array,
324    candidate_embeddings: Vec<Float32Array>,
325    k: u32,
326    temperature: f64,
327) -> Result<SearchResult> {
328    let query_slice = query.as_ref();
329    let candidates_vec: Vec<Vec<f32>> = candidate_embeddings
330        .into_iter()
331        .map(|arr| arr.to_vec())
332        .collect();
333
334    let (indices, weights) =
335        rust_differentiable_search(query_slice, &candidates_vec, k as usize, temperature as f32);
336
337    Ok(SearchResult {
338        indices: indices.iter().map(|&i| i as u32).collect(),
339        weights: weights.iter().map(|&w| w as f64).collect(),
340    })
341}
342
343/// Hierarchical forward pass through GNN layers
344///
345/// # Arguments
346/// * `query` - The query vector (Float32Array)
347/// * `layer_embeddings` - Embeddings organized by layer (Array of Array of Float32Array)
348/// * `gnn_layers_json` - JSON array of serialized GNN layers
349///
350/// # Returns
351/// Final embedding after hierarchical processing as Float32Array
352///
353/// # Example
354/// ```javascript
355/// const query = new Float32Array([1.0, 0.0]);
356/// const layerEmbeddings = [[new Float32Array([1.0, 0.0]), new Float32Array([0.0, 1.0])]];
357/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
358/// const layers = [layer1.toJson()];
359/// const result = hierarchicalForward(query, layerEmbeddings, layers);
360/// ```
361#[napi]
362pub fn hierarchical_forward(
363    query: Float32Array,
364    layer_embeddings: Vec<Vec<Float32Array>>,
365    gnn_layers_json: Vec<String>,
366) -> Result<Float32Array> {
367    let query_slice = query.as_ref();
368
369    let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
370        .into_iter()
371        .map(|layer| {
372            layer
373                .into_iter()
374                .map(|arr| arr.to_vec())
375                .collect()
376        })
377        .collect();
378
379    let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
380        .iter()
381        .map(|json| {
382            serde_json::from_str(json).map_err(|e| {
383                Error::new(
384                    Status::GenericFailure,
385                    format!("Layer deserialization error: {}", e),
386                )
387            })
388        })
389        .collect::<Result<Vec<_>>>()?;
390
391    let result = rust_hierarchical_forward(query_slice, &embeddings_f32, &gnn_layers);
392
393    Ok(Float32Array::new(result))
394}
395
396// ==================== Helper Functions ====================
397
398/// Get the compression level that would be selected for a given access frequency
399///
400/// # Arguments
401/// * `access_freq` - Access frequency in range [0.0, 1.0]
402///
403/// # Returns
404/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
405///
406/// # Example
407/// ```javascript
408/// const level = getCompressionLevel(0.9); // "none" (hot data)
409/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
410/// ```
411#[napi]
412pub fn get_compression_level(access_freq: f64) -> String {
413    if access_freq > 0.8 {
414        "none".to_string()
415    } else if access_freq > 0.4 {
416        "half".to_string()
417    } else if access_freq > 0.1 {
418        "pq8".to_string()
419    } else if access_freq > 0.01 {
420        "pq4".to_string()
421    } else {
422        "binary".to_string()
423    }
424}
425
426/// Module initialization
427#[napi]
428pub fn init() -> String {
429    "Ruvector GNN Node.js bindings initialized".to_string()
430}