ruvector_gnn_node/lib.rs
1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12 compress::{
13 CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14 TensorCompress as RustTensorCompress,
15 },
16 layer::RuvectorLayer as RustRuvectorLayer,
17 search::{
18 differentiable_search as rust_differentiable_search,
19 hierarchical_forward as rust_hierarchical_forward,
20 },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28 inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33 /// Create a new Ruvector GNN layer
34 ///
35 /// # Arguments
36 /// * `input_dim` - Dimension of input node embeddings
37 /// * `hidden_dim` - Dimension of hidden representations
38 /// * `heads` - Number of attention heads
39 /// * `dropout` - Dropout rate (0.0 to 1.0)
40 ///
41 /// # Example
42 /// ```javascript
43 /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44 /// ```
45 #[napi(constructor)]
46 pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47 if dropout < 0.0 || dropout > 1.0 {
48 return Err(Error::new(
49 Status::InvalidArg,
50 "Dropout must be between 0.0 and 1.0".to_string(),
51 ));
52 }
53
54 Ok(Self {
55 inner: RustRuvectorLayer::new(
56 input_dim as usize,
57 hidden_dim as usize,
58 heads as usize,
59 dropout as f32,
60 ),
61 })
62 }
63
64 /// Forward pass through the GNN layer
65 ///
66 /// # Arguments
67 /// * `node_embedding` - Current node's embedding (Float32Array)
68 /// * `neighbor_embeddings` - Embeddings of neighbor nodes (Array of Float32Array)
69 /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
70 ///
71 /// # Returns
72 /// Updated node embedding as Float32Array
73 ///
74 /// # Example
75 /// ```javascript
76 /// const node = new Float32Array([1.0, 2.0, 3.0, 4.0]);
77 /// const neighbors = [new Float32Array([0.5, 1.0, 1.5, 2.0]), new Float32Array([2.0, 3.0, 4.0, 5.0])];
78 /// const weights = new Float32Array([0.3, 0.7]);
79 /// const output = layer.forward(node, neighbors, weights);
80 /// ```
81 #[napi]
82 pub fn forward(
83 &self,
84 node_embedding: Float32Array,
85 neighbor_embeddings: Vec<Float32Array>,
86 edge_weights: Float32Array,
87 ) -> Result<Float32Array> {
88 let node_slice = node_embedding.as_ref();
89 let neighbors_vec: Vec<Vec<f32>> = neighbor_embeddings
90 .into_iter()
91 .map(|arr| arr.to_vec())
92 .collect();
93 let weights_slice = edge_weights.as_ref();
94
95 let result = self.inner.forward(node_slice, &neighbors_vec, weights_slice);
96
97 Ok(Float32Array::new(result))
98 }
99
100 /// Serialize the layer to JSON
101 #[napi]
102 pub fn to_json(&self) -> Result<String> {
103 serde_json::to_string(&self.inner).map_err(|e| {
104 Error::new(
105 Status::GenericFailure,
106 format!("Serialization error: {}", e),
107 )
108 })
109 }
110
111 /// Deserialize the layer from JSON
112 #[napi(factory)]
113 pub fn from_json(json: String) -> Result<Self> {
114 let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
115 Error::new(
116 Status::GenericFailure,
117 format!("Deserialization error: {}", e),
118 )
119 })?;
120 Ok(Self { inner })
121 }
122}
123
124// ==================== TensorCompress Bindings ====================
125
126/// Compression level for tensor compression
127#[napi(object)]
128pub struct CompressionLevelConfig {
129 /// Type of compression: "none", "half", "pq8", "pq4", "binary"
130 pub level_type: String,
131 /// Scale factor (for "half" compression)
132 pub scale: Option<f64>,
133 /// Number of subvectors (for PQ compression)
134 pub subvectors: Option<u32>,
135 /// Number of centroids (for PQ8)
136 pub centroids: Option<u32>,
137 /// Outlier threshold (for PQ4)
138 pub outlier_threshold: Option<f64>,
139 /// Binary threshold (for binary compression)
140 pub threshold: Option<f64>,
141}
142
143impl CompressionLevelConfig {
144 fn to_rust(&self) -> Result<RustCompressionLevel> {
145 match self.level_type.as_str() {
146 "none" => Ok(RustCompressionLevel::None),
147 "half" => Ok(RustCompressionLevel::Half {
148 scale: self.scale.unwrap_or(1.0) as f32,
149 }),
150 "pq8" => Ok(RustCompressionLevel::PQ8 {
151 subvectors: self.subvectors.unwrap_or(8) as u8,
152 centroids: self.centroids.unwrap_or(16) as u8,
153 }),
154 "pq4" => Ok(RustCompressionLevel::PQ4 {
155 subvectors: self.subvectors.unwrap_or(8) as u8,
156 outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
157 }),
158 "binary" => Ok(RustCompressionLevel::Binary {
159 threshold: self.threshold.unwrap_or(0.0) as f32,
160 }),
161 _ => Err(Error::new(
162 Status::InvalidArg,
163 format!("Invalid compression level: {}", self.level_type),
164 )),
165 }
166 }
167}
168
169/// Tensor compressor with adaptive level selection
170#[napi]
171pub struct TensorCompress {
172 inner: RustTensorCompress,
173}
174
175#[napi]
176impl TensorCompress {
177 /// Create a new tensor compressor
178 ///
179 /// # Example
180 /// ```javascript
181 /// const compressor = new TensorCompress();
182 /// ```
183 #[napi(constructor)]
184 pub fn new() -> Self {
185 Self {
186 inner: RustTensorCompress::new(),
187 }
188 }
189
190 /// Compress an embedding based on access frequency
191 ///
192 /// # Arguments
193 /// * `embedding` - The input embedding vector (Float32Array)
194 /// * `access_freq` - Access frequency in range [0.0, 1.0]
195 ///
196 /// # Returns
197 /// Compressed tensor as JSON string
198 ///
199 /// # Example
200 /// ```javascript
201 /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
202 /// const compressed = compressor.compress(embedding, 0.5);
203 /// ```
204 #[napi]
205 pub fn compress(&self, embedding: Float32Array, access_freq: f64) -> Result<String> {
206 let embedding_slice = embedding.as_ref();
207
208 let compressed = self
209 .inner
210 .compress(embedding_slice, access_freq as f32)
211 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
212
213 serde_json::to_string(&compressed).map_err(|e| {
214 Error::new(
215 Status::GenericFailure,
216 format!("Serialization error: {}", e),
217 )
218 })
219 }
220
221 /// Compress with explicit compression level
222 ///
223 /// # Arguments
224 /// * `embedding` - The input embedding vector (Float32Array)
225 /// * `level` - Compression level configuration
226 ///
227 /// # Returns
228 /// Compressed tensor as JSON string
229 ///
230 /// # Example
231 /// ```javascript
232 /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
233 /// const level = { level_type: "half", scale: 1.0 };
234 /// const compressed = compressor.compressWithLevel(embedding, level);
235 /// ```
236 #[napi]
237 pub fn compress_with_level(
238 &self,
239 embedding: Float32Array,
240 level: CompressionLevelConfig,
241 ) -> Result<String> {
242 let embedding_slice = embedding.as_ref();
243 let rust_level = level.to_rust()?;
244
245 let compressed = self
246 .inner
247 .compress_with_level(embedding_slice, &rust_level)
248 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
249
250 serde_json::to_string(&compressed).map_err(|e| {
251 Error::new(
252 Status::GenericFailure,
253 format!("Serialization error: {}", e),
254 )
255 })
256 }
257
258 /// Decompress a compressed tensor
259 ///
260 /// # Arguments
261 /// * `compressed_json` - Compressed tensor as JSON string
262 ///
263 /// # Returns
264 /// Decompressed embedding vector as Float32Array
265 ///
266 /// # Example
267 /// ```javascript
268 /// const decompressed = compressor.decompress(compressed);
269 /// ```
270 #[napi]
271 pub fn decompress(&self, compressed_json: String) -> Result<Float32Array> {
272 let compressed: RustCompressedTensor =
273 serde_json::from_str(&compressed_json).map_err(|e| {
274 Error::new(
275 Status::GenericFailure,
276 format!("Deserialization error: {}", e),
277 )
278 })?;
279
280 let result = self.inner.decompress(&compressed).map_err(|e| {
281 Error::new(
282 Status::GenericFailure,
283 format!("Decompression error: {}", e),
284 )
285 })?;
286
287 Ok(Float32Array::new(result))
288 }
289}
290
291// ==================== Search Functions ====================
292
293/// Result from differentiable search
294#[napi(object)]
295pub struct SearchResult {
296 /// Indices of top-k candidates
297 pub indices: Vec<u32>,
298 /// Soft weights for top-k candidates
299 pub weights: Vec<f64>,
300}
301
302/// Differentiable search using soft attention mechanism
303///
304/// # Arguments
305/// * `query` - The query vector (Float32Array)
306/// * `candidate_embeddings` - List of candidate embedding vectors (Array of Float32Array)
307/// * `k` - Number of top results to return
308/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
309///
310/// # Returns
311/// Search result with indices and soft weights
312///
313/// # Example
314/// ```javascript
315/// const query = new Float32Array([1.0, 0.0, 0.0]);
316/// const candidates = [new Float32Array([1.0, 0.0, 0.0]), new Float32Array([0.9, 0.1, 0.0]), new Float32Array([0.0, 1.0, 0.0])];
317/// const result = differentiableSearch(query, candidates, 2, 1.0);
318/// console.log(result.indices); // [0, 1]
319/// console.log(result.weights); // [0.x, 0.y]
320/// ```
321#[napi]
322pub fn differentiable_search(
323 query: Float32Array,
324 candidate_embeddings: Vec<Float32Array>,
325 k: u32,
326 temperature: f64,
327) -> Result<SearchResult> {
328 let query_slice = query.as_ref();
329 let candidates_vec: Vec<Vec<f32>> = candidate_embeddings
330 .into_iter()
331 .map(|arr| arr.to_vec())
332 .collect();
333
334 let (indices, weights) =
335 rust_differentiable_search(query_slice, &candidates_vec, k as usize, temperature as f32);
336
337 Ok(SearchResult {
338 indices: indices.iter().map(|&i| i as u32).collect(),
339 weights: weights.iter().map(|&w| w as f64).collect(),
340 })
341}
342
343/// Hierarchical forward pass through GNN layers
344///
345/// # Arguments
346/// * `query` - The query vector (Float32Array)
347/// * `layer_embeddings` - Embeddings organized by layer (Array of Array of Float32Array)
348/// * `gnn_layers_json` - JSON array of serialized GNN layers
349///
350/// # Returns
351/// Final embedding after hierarchical processing as Float32Array
352///
353/// # Example
354/// ```javascript
355/// const query = new Float32Array([1.0, 0.0]);
356/// const layerEmbeddings = [[new Float32Array([1.0, 0.0]), new Float32Array([0.0, 1.0])]];
357/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
358/// const layers = [layer1.toJson()];
359/// const result = hierarchicalForward(query, layerEmbeddings, layers);
360/// ```
361#[napi]
362pub fn hierarchical_forward(
363 query: Float32Array,
364 layer_embeddings: Vec<Vec<Float32Array>>,
365 gnn_layers_json: Vec<String>,
366) -> Result<Float32Array> {
367 let query_slice = query.as_ref();
368
369 let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
370 .into_iter()
371 .map(|layer| {
372 layer
373 .into_iter()
374 .map(|arr| arr.to_vec())
375 .collect()
376 })
377 .collect();
378
379 let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
380 .iter()
381 .map(|json| {
382 serde_json::from_str(json).map_err(|e| {
383 Error::new(
384 Status::GenericFailure,
385 format!("Layer deserialization error: {}", e),
386 )
387 })
388 })
389 .collect::<Result<Vec<_>>>()?;
390
391 let result = rust_hierarchical_forward(query_slice, &embeddings_f32, &gnn_layers);
392
393 Ok(Float32Array::new(result))
394}
395
396// ==================== Helper Functions ====================
397
398/// Get the compression level that would be selected for a given access frequency
399///
400/// # Arguments
401/// * `access_freq` - Access frequency in range [0.0, 1.0]
402///
403/// # Returns
404/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
405///
406/// # Example
407/// ```javascript
408/// const level = getCompressionLevel(0.9); // "none" (hot data)
409/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
410/// ```
411#[napi]
412pub fn get_compression_level(access_freq: f64) -> String {
413 if access_freq > 0.8 {
414 "none".to_string()
415 } else if access_freq > 0.4 {
416 "half".to_string()
417 } else if access_freq > 0.1 {
418 "pq8".to_string()
419 } else if access_freq > 0.01 {
420 "pq4".to_string()
421 } else {
422 "binary".to_string()
423 }
424}
425
426/// Module initialization
427#[napi]
428pub fn init() -> String {
429 "Ruvector GNN Node.js bindings initialized".to_string()
430}