Skip to main content

ruvector_gnn_node/
lib.rs

1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12    compress::{
13        CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14        TensorCompress as RustTensorCompress,
15    },
16    layer::RuvectorLayer as RustRuvectorLayer,
17    search::{
18        differentiable_search as rust_differentiable_search,
19        hierarchical_forward as rust_hierarchical_forward,
20    },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28    inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33    /// Create a new Ruvector GNN layer
34    ///
35    /// # Arguments
36    /// * `input_dim` - Dimension of input node embeddings
37    /// * `hidden_dim` - Dimension of hidden representations
38    /// * `heads` - Number of attention heads
39    /// * `dropout` - Dropout rate (0.0 to 1.0)
40    ///
41    /// # Example
42    /// ```javascript
43    /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44    /// ```
45    #[napi(constructor)]
46    pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47        let inner = RustRuvectorLayer::new(
48            input_dim as usize,
49            hidden_dim as usize,
50            heads as usize,
51            dropout as f32,
52        )
53        .map_err(|e| Error::new(Status::InvalidArg, e.to_string()))?;
54
55        Ok(Self { inner })
56    }
57
58    /// Forward pass through the GNN layer
59    ///
60    /// # Arguments
61    /// * `node_embedding` - Current node's embedding (Float32Array)
62    /// * `neighbor_embeddings` - Embeddings of neighbor nodes (Array of Float32Array)
63    /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
64    ///
65    /// # Returns
66    /// Updated node embedding as Float32Array
67    ///
68    /// # Example
69    /// ```javascript
70    /// const node = new Float32Array([1.0, 2.0, 3.0, 4.0]);
71    /// const neighbors = [new Float32Array([0.5, 1.0, 1.5, 2.0]), new Float32Array([2.0, 3.0, 4.0, 5.0])];
72    /// const weights = new Float32Array([0.3, 0.7]);
73    /// const output = layer.forward(node, neighbors, weights);
74    /// ```
75    #[napi]
76    pub fn forward(
77        &self,
78        node_embedding: Float32Array,
79        neighbor_embeddings: Vec<Float32Array>,
80        edge_weights: Float32Array,
81    ) -> Result<Float32Array> {
82        let node_slice = node_embedding.as_ref();
83        let neighbors_vec: Vec<Vec<f32>> = neighbor_embeddings
84            .into_iter()
85            .map(|arr| arr.to_vec())
86            .collect();
87        let weights_slice = edge_weights.as_ref();
88
89        let result = self
90            .inner
91            .forward(node_slice, &neighbors_vec, weights_slice);
92
93        Ok(Float32Array::new(result))
94    }
95
96    /// Serialize the layer to JSON
97    #[napi]
98    pub fn to_json(&self) -> Result<String> {
99        serde_json::to_string(&self.inner).map_err(|e| {
100            Error::new(
101                Status::GenericFailure,
102                format!("Serialization error: {}", e),
103            )
104        })
105    }
106
107    /// Deserialize the layer from JSON
108    #[napi(factory)]
109    pub fn from_json(json: String) -> Result<Self> {
110        let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
111            Error::new(
112                Status::GenericFailure,
113                format!("Deserialization error: {}", e),
114            )
115        })?;
116        Ok(Self { inner })
117    }
118}
119
120// ==================== TensorCompress Bindings ====================
121
122/// Compression level for tensor compression
123#[napi(object)]
124pub struct CompressionLevelConfig {
125    /// Type of compression: "none", "half", "pq8", "pq4", "binary"
126    pub level_type: String,
127    /// Scale factor (for "half" compression)
128    pub scale: Option<f64>,
129    /// Number of subvectors (for PQ compression)
130    pub subvectors: Option<u32>,
131    /// Number of centroids (for PQ8)
132    pub centroids: Option<u32>,
133    /// Outlier threshold (for PQ4)
134    pub outlier_threshold: Option<f64>,
135    /// Binary threshold (for binary compression)
136    pub threshold: Option<f64>,
137}
138
139impl CompressionLevelConfig {
140    fn to_rust(&self) -> Result<RustCompressionLevel> {
141        match self.level_type.as_str() {
142            "none" => Ok(RustCompressionLevel::None),
143            "half" => Ok(RustCompressionLevel::Half {
144                scale: self.scale.unwrap_or(1.0) as f32,
145            }),
146            "pq8" => Ok(RustCompressionLevel::PQ8 {
147                subvectors: self.subvectors.unwrap_or(8) as u8,
148                centroids: self.centroids.unwrap_or(16) as u8,
149            }),
150            "pq4" => Ok(RustCompressionLevel::PQ4 {
151                subvectors: self.subvectors.unwrap_or(8) as u8,
152                outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
153            }),
154            "binary" => Ok(RustCompressionLevel::Binary {
155                threshold: self.threshold.unwrap_or(0.0) as f32,
156            }),
157            _ => Err(Error::new(
158                Status::InvalidArg,
159                format!("Invalid compression level: {}", self.level_type),
160            )),
161        }
162    }
163}
164
165/// Tensor compressor with adaptive level selection
166#[napi]
167pub struct TensorCompress {
168    inner: RustTensorCompress,
169}
170
171#[napi]
172impl TensorCompress {
173    /// Create a new tensor compressor
174    ///
175    /// # Example
176    /// ```javascript
177    /// const compressor = new TensorCompress();
178    /// ```
179    #[napi(constructor)]
180    pub fn new() -> Self {
181        Self {
182            inner: RustTensorCompress::new(),
183        }
184    }
185
186    /// Compress an embedding based on access frequency
187    ///
188    /// # Arguments
189    /// * `embedding` - The input embedding vector (Float32Array)
190    /// * `access_freq` - Access frequency in range [0.0, 1.0]
191    ///
192    /// # Returns
193    /// Compressed tensor as JSON string
194    ///
195    /// # Example
196    /// ```javascript
197    /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
198    /// const compressed = compressor.compress(embedding, 0.5);
199    /// ```
200    #[napi]
201    pub fn compress(&self, embedding: Float32Array, access_freq: f64) -> Result<String> {
202        let embedding_slice = embedding.as_ref();
203
204        let compressed = self
205            .inner
206            .compress(embedding_slice, access_freq as f32)
207            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
208
209        serde_json::to_string(&compressed).map_err(|e| {
210            Error::new(
211                Status::GenericFailure,
212                format!("Serialization error: {}", e),
213            )
214        })
215    }
216
217    /// Compress with explicit compression level
218    ///
219    /// # Arguments
220    /// * `embedding` - The input embedding vector (Float32Array)
221    /// * `level` - Compression level configuration
222    ///
223    /// # Returns
224    /// Compressed tensor as JSON string
225    ///
226    /// # Example
227    /// ```javascript
228    /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
229    /// const level = { level_type: "half", scale: 1.0 };
230    /// const compressed = compressor.compressWithLevel(embedding, level);
231    /// ```
232    #[napi]
233    pub fn compress_with_level(
234        &self,
235        embedding: Float32Array,
236        level: CompressionLevelConfig,
237    ) -> Result<String> {
238        let embedding_slice = embedding.as_ref();
239        let rust_level = level.to_rust()?;
240
241        let compressed = self
242            .inner
243            .compress_with_level(embedding_slice, &rust_level)
244            .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
245
246        serde_json::to_string(&compressed).map_err(|e| {
247            Error::new(
248                Status::GenericFailure,
249                format!("Serialization error: {}", e),
250            )
251        })
252    }
253
254    /// Decompress a compressed tensor
255    ///
256    /// # Arguments
257    /// * `compressed_json` - Compressed tensor as JSON string
258    ///
259    /// # Returns
260    /// Decompressed embedding vector as Float32Array
261    ///
262    /// # Example
263    /// ```javascript
264    /// const decompressed = compressor.decompress(compressed);
265    /// ```
266    #[napi]
267    pub fn decompress(&self, compressed_json: String) -> Result<Float32Array> {
268        let compressed: RustCompressedTensor =
269            serde_json::from_str(&compressed_json).map_err(|e| {
270                Error::new(
271                    Status::GenericFailure,
272                    format!("Deserialization error: {}", e),
273                )
274            })?;
275
276        let result = self.inner.decompress(&compressed).map_err(|e| {
277            Error::new(
278                Status::GenericFailure,
279                format!("Decompression error: {}", e),
280            )
281        })?;
282
283        Ok(Float32Array::new(result))
284    }
285}
286
287// ==================== Search Functions ====================
288
/// Result from differentiable search
///
/// Built by `differentiable_search` from the index and weight sequences
/// returned by the underlying Rust search.
#[napi(object)]
pub struct SearchResult {
    /// Indices of top-k candidates (cast to `u32` from the Rust result;
    /// presumably positions in the candidate list passed to the search)
    pub indices: Vec<u32>,
    /// Soft weights for top-k candidates (widened to `f64` from the Rust result)
    pub weights: Vec<f64>,
}
297
298/// Differentiable search using soft attention mechanism
299///
300/// # Arguments
301/// * `query` - The query vector (Float32Array)
302/// * `candidate_embeddings` - List of candidate embedding vectors (Array of Float32Array)
303/// * `k` - Number of top results to return
304/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
305///
306/// # Returns
307/// Search result with indices and soft weights
308///
309/// # Example
310/// ```javascript
311/// const query = new Float32Array([1.0, 0.0, 0.0]);
312/// const candidates = [new Float32Array([1.0, 0.0, 0.0]), new Float32Array([0.9, 0.1, 0.0]), new Float32Array([0.0, 1.0, 0.0])];
313/// const result = differentiableSearch(query, candidates, 2, 1.0);
314/// console.log(result.indices); // [0, 1]
315/// console.log(result.weights); // [0.x, 0.y]
316/// ```
317#[napi]
318pub fn differentiable_search(
319    query: Float32Array,
320    candidate_embeddings: Vec<Float32Array>,
321    k: u32,
322    temperature: f64,
323) -> Result<SearchResult> {
324    let query_slice = query.as_ref();
325    let candidates_vec: Vec<Vec<f32>> = candidate_embeddings
326        .into_iter()
327        .map(|arr| arr.to_vec())
328        .collect();
329
330    let (indices, weights) =
331        rust_differentiable_search(query_slice, &candidates_vec, k as usize, temperature as f32);
332
333    Ok(SearchResult {
334        indices: indices.iter().map(|&i| i as u32).collect(),
335        weights: weights.iter().map(|&w| w as f64).collect(),
336    })
337}
338
339/// Hierarchical forward pass through GNN layers
340///
341/// # Arguments
342/// * `query` - The query vector (Float32Array)
343/// * `layer_embeddings` - Embeddings organized by layer (Array of Array of Float32Array)
344/// * `gnn_layers_json` - JSON array of serialized GNN layers
345///
346/// # Returns
347/// Final embedding after hierarchical processing as Float32Array
348///
349/// # Example
350/// ```javascript
351/// const query = new Float32Array([1.0, 0.0]);
352/// const layerEmbeddings = [[new Float32Array([1.0, 0.0]), new Float32Array([0.0, 1.0])]];
353/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
354/// const layers = [layer1.toJson()];
355/// const result = hierarchicalForward(query, layerEmbeddings, layers);
356/// ```
357#[napi]
358pub fn hierarchical_forward(
359    query: Float32Array,
360    layer_embeddings: Vec<Vec<Float32Array>>,
361    gnn_layers_json: Vec<String>,
362) -> Result<Float32Array> {
363    let query_slice = query.as_ref();
364
365    let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
366        .into_iter()
367        .map(|layer| layer.into_iter().map(|arr| arr.to_vec()).collect())
368        .collect();
369
370    let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
371        .iter()
372        .map(|json| {
373            serde_json::from_str(json).map_err(|e| {
374                Error::new(
375                    Status::GenericFailure,
376                    format!("Layer deserialization error: {}", e),
377                )
378            })
379        })
380        .collect::<Result<Vec<_>>>()?;
381
382    let result = rust_hierarchical_forward(query_slice, &embeddings_f32, &gnn_layers);
383
384    Ok(Float32Array::new(result))
385}
386
387// ==================== Helper Functions ====================
388
389/// Get the compression level that would be selected for a given access frequency
390///
391/// # Arguments
392/// * `access_freq` - Access frequency in range [0.0, 1.0]
393///
394/// # Returns
395/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
396///
397/// # Example
398/// ```javascript
399/// const level = getCompressionLevel(0.9); // "none" (hot data)
400/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
401/// ```
402#[napi]
pub fn get_compression_level(access_freq: f64) -> String {
    // Exclusive lower bounds, ordered from hottest tier to coldest. The first
    // bound that `access_freq` strictly exceeds picks the tier.
    const TIERS: [(f64, &str); 4] = [
        (0.8, "none"),
        (0.4, "half"),
        (0.1, "pq8"),
        (0.01, "pq4"),
    ];

    // Anything that exceeds no bound (including NaN, for which every `>`
    // comparison is false) lands in the coldest tier.
    TIERS
        .iter()
        .find(|&&(floor, _)| access_freq > floor)
        .map_or("binary", |&(_, label)| label)
        .to_string()
}
416
417/// Module initialization
418#[napi]
pub fn init() -> String {
    // Fixed banner; no state is touched.
    String::from("Ruvector GNN Node.js bindings initialized")
}