ruvector_gnn_node/lib.rs
1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12 compress::{
13 CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14 TensorCompress as RustTensorCompress,
15 },
16 layer::RuvectorLayer as RustRuvectorLayer,
17 search::{
18 differentiable_search as rust_differentiable_search,
19 hierarchical_forward as rust_hierarchical_forward,
20 },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28 inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33 /// Create a new Ruvector GNN layer
34 ///
35 /// # Arguments
36 /// * `input_dim` - Dimension of input node embeddings
37 /// * `hidden_dim` - Dimension of hidden representations
38 /// * `heads` - Number of attention heads
39 /// * `dropout` - Dropout rate (0.0 to 1.0)
40 ///
41 /// # Example
42 /// ```javascript
43 /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44 /// ```
45 #[napi(constructor)]
46 pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47 let inner = RustRuvectorLayer::new(
48 input_dim as usize,
49 hidden_dim as usize,
50 heads as usize,
51 dropout as f32,
52 )
53 .map_err(|e| Error::new(Status::InvalidArg, e.to_string()))?;
54
55 Ok(Self { inner })
56 }
57
58 /// Forward pass through the GNN layer
59 ///
60 /// # Arguments
61 /// * `node_embedding` - Current node's embedding (Float32Array)
62 /// * `neighbor_embeddings` - Embeddings of neighbor nodes (Array of Float32Array)
63 /// * `edge_weights` - Weights of edges to neighbors (Float32Array)
64 ///
65 /// # Returns
66 /// Updated node embedding as Float32Array
67 ///
68 /// # Example
69 /// ```javascript
70 /// const node = new Float32Array([1.0, 2.0, 3.0, 4.0]);
71 /// const neighbors = [new Float32Array([0.5, 1.0, 1.5, 2.0]), new Float32Array([2.0, 3.0, 4.0, 5.0])];
72 /// const weights = new Float32Array([0.3, 0.7]);
73 /// const output = layer.forward(node, neighbors, weights);
74 /// ```
75 #[napi]
76 pub fn forward(
77 &self,
78 node_embedding: Float32Array,
79 neighbor_embeddings: Vec<Float32Array>,
80 edge_weights: Float32Array,
81 ) -> Result<Float32Array> {
82 let node_slice = node_embedding.as_ref();
83 let neighbors_vec: Vec<Vec<f32>> = neighbor_embeddings
84 .into_iter()
85 .map(|arr| arr.to_vec())
86 .collect();
87 let weights_slice = edge_weights.as_ref();
88
89 let result = self
90 .inner
91 .forward(node_slice, &neighbors_vec, weights_slice);
92
93 Ok(Float32Array::new(result))
94 }
95
96 /// Serialize the layer to JSON
97 #[napi]
98 pub fn to_json(&self) -> Result<String> {
99 serde_json::to_string(&self.inner).map_err(|e| {
100 Error::new(
101 Status::GenericFailure,
102 format!("Serialization error: {}", e),
103 )
104 })
105 }
106
107 /// Deserialize the layer from JSON
108 #[napi(factory)]
109 pub fn from_json(json: String) -> Result<Self> {
110 let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
111 Error::new(
112 Status::GenericFailure,
113 format!("Deserialization error: {}", e),
114 )
115 })?;
116 Ok(Self { inner })
117 }
118}
119
120// ==================== TensorCompress Bindings ====================
121
122/// Compression level for tensor compression
123#[napi(object)]
124pub struct CompressionLevelConfig {
125 /// Type of compression: "none", "half", "pq8", "pq4", "binary"
126 pub level_type: String,
127 /// Scale factor (for "half" compression)
128 pub scale: Option<f64>,
129 /// Number of subvectors (for PQ compression)
130 pub subvectors: Option<u32>,
131 /// Number of centroids (for PQ8)
132 pub centroids: Option<u32>,
133 /// Outlier threshold (for PQ4)
134 pub outlier_threshold: Option<f64>,
135 /// Binary threshold (for binary compression)
136 pub threshold: Option<f64>,
137}
138
139impl CompressionLevelConfig {
140 fn to_rust(&self) -> Result<RustCompressionLevel> {
141 match self.level_type.as_str() {
142 "none" => Ok(RustCompressionLevel::None),
143 "half" => Ok(RustCompressionLevel::Half {
144 scale: self.scale.unwrap_or(1.0) as f32,
145 }),
146 "pq8" => Ok(RustCompressionLevel::PQ8 {
147 subvectors: self.subvectors.unwrap_or(8) as u8,
148 centroids: self.centroids.unwrap_or(16) as u8,
149 }),
150 "pq4" => Ok(RustCompressionLevel::PQ4 {
151 subvectors: self.subvectors.unwrap_or(8) as u8,
152 outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
153 }),
154 "binary" => Ok(RustCompressionLevel::Binary {
155 threshold: self.threshold.unwrap_or(0.0) as f32,
156 }),
157 _ => Err(Error::new(
158 Status::InvalidArg,
159 format!("Invalid compression level: {}", self.level_type),
160 )),
161 }
162 }
163}
164
165/// Tensor compressor with adaptive level selection
166#[napi]
167pub struct TensorCompress {
168 inner: RustTensorCompress,
169}
170
171#[napi]
172impl TensorCompress {
173 /// Create a new tensor compressor
174 ///
175 /// # Example
176 /// ```javascript
177 /// const compressor = new TensorCompress();
178 /// ```
179 #[napi(constructor)]
180 pub fn new() -> Self {
181 Self {
182 inner: RustTensorCompress::new(),
183 }
184 }
185
186 /// Compress an embedding based on access frequency
187 ///
188 /// # Arguments
189 /// * `embedding` - The input embedding vector (Float32Array)
190 /// * `access_freq` - Access frequency in range [0.0, 1.0]
191 ///
192 /// # Returns
193 /// Compressed tensor as JSON string
194 ///
195 /// # Example
196 /// ```javascript
197 /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
198 /// const compressed = compressor.compress(embedding, 0.5);
199 /// ```
200 #[napi]
201 pub fn compress(&self, embedding: Float32Array, access_freq: f64) -> Result<String> {
202 let embedding_slice = embedding.as_ref();
203
204 let compressed = self
205 .inner
206 .compress(embedding_slice, access_freq as f32)
207 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
208
209 serde_json::to_string(&compressed).map_err(|e| {
210 Error::new(
211 Status::GenericFailure,
212 format!("Serialization error: {}", e),
213 )
214 })
215 }
216
217 /// Compress with explicit compression level
218 ///
219 /// # Arguments
220 /// * `embedding` - The input embedding vector (Float32Array)
221 /// * `level` - Compression level configuration
222 ///
223 /// # Returns
224 /// Compressed tensor as JSON string
225 ///
226 /// # Example
227 /// ```javascript
228 /// const embedding = new Float32Array([1.0, 2.0, 3.0, 4.0]);
229 /// const level = { level_type: "half", scale: 1.0 };
230 /// const compressed = compressor.compressWithLevel(embedding, level);
231 /// ```
232 #[napi]
233 pub fn compress_with_level(
234 &self,
235 embedding: Float32Array,
236 level: CompressionLevelConfig,
237 ) -> Result<String> {
238 let embedding_slice = embedding.as_ref();
239 let rust_level = level.to_rust()?;
240
241 let compressed = self
242 .inner
243 .compress_with_level(embedding_slice, &rust_level)
244 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
245
246 serde_json::to_string(&compressed).map_err(|e| {
247 Error::new(
248 Status::GenericFailure,
249 format!("Serialization error: {}", e),
250 )
251 })
252 }
253
254 /// Decompress a compressed tensor
255 ///
256 /// # Arguments
257 /// * `compressed_json` - Compressed tensor as JSON string
258 ///
259 /// # Returns
260 /// Decompressed embedding vector as Float32Array
261 ///
262 /// # Example
263 /// ```javascript
264 /// const decompressed = compressor.decompress(compressed);
265 /// ```
266 #[napi]
267 pub fn decompress(&self, compressed_json: String) -> Result<Float32Array> {
268 let compressed: RustCompressedTensor =
269 serde_json::from_str(&compressed_json).map_err(|e| {
270 Error::new(
271 Status::GenericFailure,
272 format!("Deserialization error: {}", e),
273 )
274 })?;
275
276 let result = self.inner.decompress(&compressed).map_err(|e| {
277 Error::new(
278 Status::GenericFailure,
279 format!("Decompression error: {}", e),
280 )
281 })?;
282
283 Ok(Float32Array::new(result))
284 }
285}
286
287// ==================== Search Functions ====================
288
289/// Result from differentiable search
290#[napi(object)]
291pub struct SearchResult {
292 /// Indices of top-k candidates
293 pub indices: Vec<u32>,
294 /// Soft weights for top-k candidates
295 pub weights: Vec<f64>,
296}
297
298/// Differentiable search using soft attention mechanism
299///
300/// # Arguments
301/// * `query` - The query vector (Float32Array)
302/// * `candidate_embeddings` - List of candidate embedding vectors (Array of Float32Array)
303/// * `k` - Number of top results to return
304/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
305///
306/// # Returns
307/// Search result with indices and soft weights
308///
309/// # Example
310/// ```javascript
311/// const query = new Float32Array([1.0, 0.0, 0.0]);
312/// const candidates = [new Float32Array([1.0, 0.0, 0.0]), new Float32Array([0.9, 0.1, 0.0]), new Float32Array([0.0, 1.0, 0.0])];
313/// const result = differentiableSearch(query, candidates, 2, 1.0);
314/// console.log(result.indices); // [0, 1]
315/// console.log(result.weights); // [0.x, 0.y]
316/// ```
317#[napi]
318pub fn differentiable_search(
319 query: Float32Array,
320 candidate_embeddings: Vec<Float32Array>,
321 k: u32,
322 temperature: f64,
323) -> Result<SearchResult> {
324 let query_slice = query.as_ref();
325 let candidates_vec: Vec<Vec<f32>> = candidate_embeddings
326 .into_iter()
327 .map(|arr| arr.to_vec())
328 .collect();
329
330 let (indices, weights) =
331 rust_differentiable_search(query_slice, &candidates_vec, k as usize, temperature as f32);
332
333 Ok(SearchResult {
334 indices: indices.iter().map(|&i| i as u32).collect(),
335 weights: weights.iter().map(|&w| w as f64).collect(),
336 })
337}
338
339/// Hierarchical forward pass through GNN layers
340///
341/// # Arguments
342/// * `query` - The query vector (Float32Array)
343/// * `layer_embeddings` - Embeddings organized by layer (Array of Array of Float32Array)
344/// * `gnn_layers_json` - JSON array of serialized GNN layers
345///
346/// # Returns
347/// Final embedding after hierarchical processing as Float32Array
348///
349/// # Example
350/// ```javascript
351/// const query = new Float32Array([1.0, 0.0]);
352/// const layerEmbeddings = [[new Float32Array([1.0, 0.0]), new Float32Array([0.0, 1.0])]];
353/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
354/// const layers = [layer1.toJson()];
355/// const result = hierarchicalForward(query, layerEmbeddings, layers);
356/// ```
357#[napi]
358pub fn hierarchical_forward(
359 query: Float32Array,
360 layer_embeddings: Vec<Vec<Float32Array>>,
361 gnn_layers_json: Vec<String>,
362) -> Result<Float32Array> {
363 let query_slice = query.as_ref();
364
365 let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
366 .into_iter()
367 .map(|layer| layer.into_iter().map(|arr| arr.to_vec()).collect())
368 .collect();
369
370 let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
371 .iter()
372 .map(|json| {
373 serde_json::from_str(json).map_err(|e| {
374 Error::new(
375 Status::GenericFailure,
376 format!("Layer deserialization error: {}", e),
377 )
378 })
379 })
380 .collect::<Result<Vec<_>>>()?;
381
382 let result = rust_hierarchical_forward(query_slice, &embeddings_f32, &gnn_layers);
383
384 Ok(Float32Array::new(result))
385}
386
387// ==================== Helper Functions ====================
388
389/// Get the compression level that would be selected for a given access frequency
390///
391/// # Arguments
392/// * `access_freq` - Access frequency in range [0.0, 1.0]
393///
394/// # Returns
395/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
396///
397/// # Example
398/// ```javascript
399/// const level = getCompressionLevel(0.9); // "none" (hot data)
400/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
401/// ```
402#[napi]
403pub fn get_compression_level(access_freq: f64) -> String {
404 if access_freq > 0.8 {
405 "none".to_string()
406 } else if access_freq > 0.4 {
407 "half".to_string()
408 } else if access_freq > 0.1 {
409 "pq8".to_string()
410 } else if access_freq > 0.01 {
411 "pq4".to_string()
412 } else {
413 "binary".to_string()
414 }
415}
416
417/// Module initialization
418#[napi]
419pub fn init() -> String {
420 "Ruvector GNN Node.js bindings initialized".to_string()
421}