ruvector_gnn_node/lib.rs
1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12 compress::{
13 CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14 TensorCompress as RustTensorCompress,
15 },
16 layer::RuvectorLayer as RustRuvectorLayer,
17 search::{
18 differentiable_search as rust_differentiable_search,
19 hierarchical_forward as rust_hierarchical_forward,
20 },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28 inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33 /// Create a new Ruvector GNN layer
34 ///
35 /// # Arguments
36 /// * `input_dim` - Dimension of input node embeddings
37 /// * `hidden_dim` - Dimension of hidden representations
38 /// * `heads` - Number of attention heads
39 /// * `dropout` - Dropout rate (0.0 to 1.0)
40 ///
41 /// # Example
42 /// ```javascript
43 /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44 /// ```
45 #[napi(constructor)]
46 pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47 if dropout < 0.0 || dropout > 1.0 {
48 return Err(Error::new(
49 Status::InvalidArg,
50 "Dropout must be between 0.0 and 1.0".to_string(),
51 ));
52 }
53
54 Ok(Self {
55 inner: RustRuvectorLayer::new(
56 input_dim as usize,
57 hidden_dim as usize,
58 heads as usize,
59 dropout as f32,
60 ),
61 })
62 }
63
64 /// Forward pass through the GNN layer
65 ///
66 /// # Arguments
67 /// * `node_embedding` - Current node's embedding (Float32Array)
68 /// * `neighbor_embeddings` - Embeddings of neighbor nodes (Array of Float32Array)
69 /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
70 ///
71 /// # Returns
72 /// Updated node embedding as Float32Array
73 ///
74 /// # Example
75 /// ```javascript
76 /// const node = new Float32Array([1.0, 2.0, 3.0, 4.0]);
77 /// const neighbors = [new Float32Array([0.5, 1.0, 1.5, 2.0]), new Float32Array([2.0, 3.0, 4.0, 5.0])];
78 /// const weights = new Float32Array([0.3, 0.7]);
79 /// const output = layer.forward(node, neighbors, weights);
80 /// ```
81 #[napi]
82 pub fn forward(
83 &self,
84 node_embedding: Float32Array,
85 neighbor_embeddings: Vec<Float32Array>,
86 edge_weights: Float32Array,
87 ) -> Result<Float32Array> {
88 let node_slice = node_embedding.as_ref();
89 let neighbors_vec: Vec<Vec<f32>> = neighbor_embeddings
90 .into_iter()
91 .map(|arr| arr.to_vec())
92 .collect();
93 let weights_slice = edge_weights.as_ref();
94
95 let result = self
96 .inner
97 .forward(node_slice, &neighbors_vec, weights_slice);
98
99 Ok(Float32Array::new(result))
100 }
101
102 /// Serialize the layer to JSON
103 #[napi]
104 pub fn to_json(&self) -> Result<String> {
105 serde_json::to_string(&self.inner).map_err(|e| {
106 Error::new(
107 Status::GenericFailure,
108 format!("Serialization error: {}", e),
109 )
110 })
111 }
112
113 /// Deserialize the layer from JSON
114 #[napi(factory)]
115 pub fn from_json(json: String) -> Result<Self> {
116 let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
117 Error::new(
118 Status::GenericFailure,
119 format!("Deserialization error: {}", e),
120 )
121 })?;
122 Ok(Self { inner })
123 }
124}
125
126// ==================== TensorCompress Bindings ====================
127
128/// Compression level for tensor compression
129#[napi(object)]
130pub struct CompressionLevelConfig {
131 /// Type of compression: "none", "half", "pq8", "pq4", "binary"
132 pub level_type: String,
133 /// Scale factor (for "half" compression)
134 pub scale: Option<f64>,
135 /// Number of subvectors (for PQ compression)
136 pub subvectors: Option<u32>,
137 /// Number of centroids (for PQ8)
138 pub centroids: Option<u32>,
139 /// Outlier threshold (for PQ4)
140 pub outlier_threshold: Option<f64>,
141 /// Binary threshold (for binary compression)
142 pub threshold: Option<f64>,
143}
144
145impl CompressionLevelConfig {
146 fn to_rust(&self) -> Result<RustCompressionLevel> {
147 match self.level_type.as_str() {
148 "none" => Ok(RustCompressionLevel::None),
149 "half" => Ok(RustCompressionLevel::Half {
150 scale: self.scale.unwrap_or(1.0) as f32,
151 }),
152 "pq8" => Ok(RustCompressionLevel::PQ8 {
153 subvectors: self.subvectors.unwrap_or(8) as u8,
154 centroids: self.centroids.unwrap_or(16) as u8,
155 }),
156 "pq4" => Ok(RustCompressionLevel::PQ4 {
157 subvectors: self.subvectors.unwrap_or(8) as u8,
158 outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
159 }),
160 "binary" => Ok(RustCompressionLevel::Binary {
161 threshold: self.threshold.unwrap_or(0.0) as f32,
162 }),
163 _ => Err(Error::new(
164 Status::InvalidArg,
165 format!("Invalid compression level: {}", self.level_type),
166 )),
167 }
168 }
169}
170
171/// Tensor compressor with adaptive level selection
172#[napi]
173pub struct TensorCompress {
174 inner: RustTensorCompress,
175}
176
177#[napi]
178impl TensorCompress {
179 /// Create a new tensor compressor
180 ///
181 /// # Example
182 /// ```javascript
183 /// const compressor = new TensorCompress();
184 /// ```
185 #[napi(constructor)]
186 pub fn new() -> Self {
187 Self {
188 inner: RustTensorCompress::new(),
189 }
190 }
191
192 /// Compress an embedding based on access frequency
193 ///
194 /// # Arguments
195 /// * `embedding` - The input embedding vector (Float32Array)
196 /// * `access_freq` - Access frequency in range [0.0, 1.0]
197 ///
198 /// # Returns
199 /// Compressed tensor as JSON string
200 ///
201 /// # Example
202 /// ```javascript
203 /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
204 /// const compressed = compressor.compress(embedding, 0.5);
205 /// ```
206 #[napi]
207 pub fn compress(&self, embedding: Float32Array, access_freq: f64) -> Result<String> {
208 let embedding_slice = embedding.as_ref();
209
210 let compressed = self
211 .inner
212 .compress(embedding_slice, access_freq as f32)
213 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
214
215 serde_json::to_string(&compressed).map_err(|e| {
216 Error::new(
217 Status::GenericFailure,
218 format!("Serialization error: {}", e),
219 )
220 })
221 }
222
223 /// Compress with explicit compression level
224 ///
225 /// # Arguments
226 /// * `embedding` - The input embedding vector (Float32Array)
227 /// * `level` - Compression level configuration
228 ///
229 /// # Returns
230 /// Compressed tensor as JSON string
231 ///
232 /// # Example
233 /// ```javascript
234 /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
235 /// const level = { level_type: "half", scale: 1.0 };
236 /// const compressed = compressor.compressWithLevel(embedding, level);
237 /// ```
238 #[napi]
239 pub fn compress_with_level(
240 &self,
241 embedding: Float32Array,
242 level: CompressionLevelConfig,
243 ) -> Result<String> {
244 let embedding_slice = embedding.as_ref();
245 let rust_level = level.to_rust()?;
246
247 let compressed = self
248 .inner
249 .compress_with_level(embedding_slice, &rust_level)
250 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
251
252 serde_json::to_string(&compressed).map_err(|e| {
253 Error::new(
254 Status::GenericFailure,
255 format!("Serialization error: {}", e),
256 )
257 })
258 }
259
260 /// Decompress a compressed tensor
261 ///
262 /// # Arguments
263 /// * `compressed_json` - Compressed tensor as JSON string
264 ///
265 /// # Returns
266 /// Decompressed embedding vector as Float32Array
267 ///
268 /// # Example
269 /// ```javascript
270 /// const decompressed = compressor.decompress(compressed);
271 /// ```
272 #[napi]
273 pub fn decompress(&self, compressed_json: String) -> Result<Float32Array> {
274 let compressed: RustCompressedTensor =
275 serde_json::from_str(&compressed_json).map_err(|e| {
276 Error::new(
277 Status::GenericFailure,
278 format!("Deserialization error: {}", e),
279 )
280 })?;
281
282 let result = self.inner.decompress(&compressed).map_err(|e| {
283 Error::new(
284 Status::GenericFailure,
285 format!("Decompression error: {}", e),
286 )
287 })?;
288
289 Ok(Float32Array::new(result))
290 }
291}
292
293// ==================== Search Functions ====================
294
295/// Result from differentiable search
296#[napi(object)]
297pub struct SearchResult {
298 /// Indices of top-k candidates
299 pub indices: Vec<u32>,
300 /// Soft weights for top-k candidates
301 pub weights: Vec<f64>,
302}
303
304/// Differentiable search using soft attention mechanism
305///
306/// # Arguments
307/// * `query` - The query vector (Float32Array)
308/// * `candidate_embeddings` - List of candidate embedding vectors (Array of Float32Array)
309/// * `k` - Number of top results to return
310/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
311///
312/// # Returns
313/// Search result with indices and soft weights
314///
315/// # Example
316/// ```javascript
317/// const query = new Float32Array([1.0, 0.0, 0.0]);
318/// const candidates = [new Float32Array([1.0, 0.0, 0.0]), new Float32Array([0.9, 0.1, 0.0]), new Float32Array([0.0, 1.0, 0.0])];
319/// const result = differentiableSearch(query, candidates, 2, 1.0);
320/// console.log(result.indices); // [0, 1]
321/// console.log(result.weights); // [0.x, 0.y]
322/// ```
323#[napi]
324pub fn differentiable_search(
325 query: Float32Array,
326 candidate_embeddings: Vec<Float32Array>,
327 k: u32,
328 temperature: f64,
329) -> Result<SearchResult> {
330 let query_slice = query.as_ref();
331 let candidates_vec: Vec<Vec<f32>> = candidate_embeddings
332 .into_iter()
333 .map(|arr| arr.to_vec())
334 .collect();
335
336 let (indices, weights) =
337 rust_differentiable_search(query_slice, &candidates_vec, k as usize, temperature as f32);
338
339 Ok(SearchResult {
340 indices: indices.iter().map(|&i| i as u32).collect(),
341 weights: weights.iter().map(|&w| w as f64).collect(),
342 })
343}
344
345/// Hierarchical forward pass through GNN layers
346///
347/// # Arguments
348/// * `query` - The query vector (Float32Array)
349/// * `layer_embeddings` - Embeddings organized by layer (Array of Array of Float32Array)
350/// * `gnn_layers_json` - JSON array of serialized GNN layers
351///
352/// # Returns
353/// Final embedding after hierarchical processing as Float32Array
354///
355/// # Example
356/// ```javascript
357/// const query = new Float32Array([1.0, 0.0]);
358/// const layerEmbeddings = [[new Float32Array([1.0, 0.0]), new Float32Array([0.0, 1.0])]];
359/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
360/// const layers = [layer1.toJson()];
361/// const result = hierarchicalForward(query, layerEmbeddings, layers);
362/// ```
363#[napi]
364pub fn hierarchical_forward(
365 query: Float32Array,
366 layer_embeddings: Vec<Vec<Float32Array>>,
367 gnn_layers_json: Vec<String>,
368) -> Result<Float32Array> {
369 let query_slice = query.as_ref();
370
371 let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
372 .into_iter()
373 .map(|layer| layer.into_iter().map(|arr| arr.to_vec()).collect())
374 .collect();
375
376 let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
377 .iter()
378 .map(|json| {
379 serde_json::from_str(json).map_err(|e| {
380 Error::new(
381 Status::GenericFailure,
382 format!("Layer deserialization error: {}", e),
383 )
384 })
385 })
386 .collect::<Result<Vec<_>>>()?;
387
388 let result = rust_hierarchical_forward(query_slice, &embeddings_f32, &gnn_layers);
389
390 Ok(Float32Array::new(result))
391}
392
393// ==================== Helper Functions ====================
394
395/// Get the compression level that would be selected for a given access frequency
396///
397/// # Arguments
398/// * `access_freq` - Access frequency in range [0.0, 1.0]
399///
400/// # Returns
401/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
402///
403/// # Example
404/// ```javascript
405/// const level = getCompressionLevel(0.9); // "none" (hot data)
406/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
407/// ```
408#[napi]
409pub fn get_compression_level(access_freq: f64) -> String {
410 if access_freq > 0.8 {
411 "none".to_string()
412 } else if access_freq > 0.4 {
413 "half".to_string()
414 } else if access_freq > 0.1 {
415 "pq8".to_string()
416 } else if access_freq > 0.01 {
417 "pq4".to_string()
418 } else {
419 "binary".to_string()
420 }
421}
422
423/// Module initialization
424#[napi]
425pub fn init() -> String {
426 "Ruvector GNN Node.js bindings initialized".to_string()
427}