ruvector_gnn_node/lib.rs
1//! Node.js bindings for Ruvector GNN via NAPI-RS
2//!
3//! This module provides JavaScript bindings for the Ruvector GNN library,
4//! enabling graph neural network operations, tensor compression, and
5//! differentiable search in Node.js applications.
6
7#![deny(clippy::all)]
8
9use napi::bindgen_prelude::*;
10use napi_derive::napi;
11use ruvector_gnn::{
12 compress::{
13 CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
14 TensorCompress as RustTensorCompress,
15 },
16 layer::RuvectorLayer as RustRuvectorLayer,
17 search::{
18 differentiable_search as rust_differentiable_search,
19 hierarchical_forward as rust_hierarchical_forward,
20 },
21};
22
23// ==================== RuvectorLayer Bindings ====================
24
25/// Graph Neural Network layer for HNSW topology
26#[napi]
27pub struct RuvectorLayer {
28 inner: RustRuvectorLayer,
29}
30
31#[napi]
32impl RuvectorLayer {
33 /// Create a new Ruvector GNN layer
34 ///
35 /// # Arguments
36 /// * `input_dim` - Dimension of input node embeddings
37 /// * `hidden_dim` - Dimension of hidden representations
38 /// * `heads` - Number of attention heads
39 /// * `dropout` - Dropout rate (0.0 to 1.0)
40 ///
41 /// # Example
42 /// ```javascript
43 /// const layer = new RuvectorLayer(128, 256, 4, 0.1);
44 /// ```
45 #[napi(constructor)]
46 pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
47 if dropout < 0.0 || dropout > 1.0 {
48 return Err(Error::new(
49 Status::InvalidArg,
50 "Dropout must be between 0.0 and 1.0".to_string(),
51 ));
52 }
53
54 Ok(Self {
55 inner: RustRuvectorLayer::new(
56 input_dim as usize,
57 hidden_dim as usize,
58 heads as usize,
59 dropout as f32,
60 ),
61 })
62 }
63
64 /// Forward pass through the GNN layer
65 ///
66 /// # Arguments
67 /// * `node_embedding` - Current node's embedding
68 /// * `neighbor_embeddings` - Embeddings of neighbor nodes
69 /// * `edge_weights` - Weights of edges to neighbors
70 ///
71 /// # Returns
72 /// Updated node embedding
73 ///
74 /// # Example
75 /// ```javascript
76 /// const node = [1.0, 2.0, 3.0, 4.0];
77 /// const neighbors = [[0.5, 1.0, 1.5, 2.0], [2.0, 3.0, 4.0, 5.0]];
78 /// const weights = [0.3, 0.7];
79 /// const output = layer.forward(node, neighbors, weights);
80 /// ```
81 #[napi]
82 pub fn forward(
83 &self,
84 node_embedding: Vec<f64>,
85 neighbor_embeddings: Vec<Vec<f64>>,
86 edge_weights: Vec<f64>,
87 ) -> Result<Vec<f64>> {
88 // Convert f64 to f32
89 let node_f32: Vec<f32> = node_embedding.iter().map(|&x| x as f32).collect();
90 let neighbors_f32: Vec<Vec<f32>> = neighbor_embeddings
91 .iter()
92 .map(|v| v.iter().map(|&x| x as f32).collect())
93 .collect();
94 let weights_f32: Vec<f32> = edge_weights.iter().map(|&x| x as f32).collect();
95
96 let result = self.inner.forward(&node_f32, &neighbors_f32, &weights_f32);
97
98 // Convert back to f64
99 Ok(result.iter().map(|&x| x as f64).collect())
100 }
101
102 /// Serialize the layer to JSON
103 #[napi]
104 pub fn to_json(&self) -> Result<String> {
105 serde_json::to_string(&self.inner).map_err(|e| {
106 Error::new(
107 Status::GenericFailure,
108 format!("Serialization error: {}", e),
109 )
110 })
111 }
112
113 /// Deserialize the layer from JSON
114 #[napi(factory)]
115 pub fn from_json(json: String) -> Result<Self> {
116 let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
117 Error::new(
118 Status::GenericFailure,
119 format!("Deserialization error: {}", e),
120 )
121 })?;
122 Ok(Self { inner })
123 }
124}
125
126// ==================== TensorCompress Bindings ====================
127
128/// Compression level for tensor compression
129#[napi(object)]
130pub struct CompressionLevelConfig {
131 /// Type of compression: "none", "half", "pq8", "pq4", "binary"
132 pub level_type: String,
133 /// Scale factor (for "half" compression)
134 pub scale: Option<f64>,
135 /// Number of subvectors (for PQ compression)
136 pub subvectors: Option<u32>,
137 /// Number of centroids (for PQ8)
138 pub centroids: Option<u32>,
139 /// Outlier threshold (for PQ4)
140 pub outlier_threshold: Option<f64>,
141 /// Binary threshold (for binary compression)
142 pub threshold: Option<f64>,
143}
144
145impl CompressionLevelConfig {
146 fn to_rust(&self) -> Result<RustCompressionLevel> {
147 match self.level_type.as_str() {
148 "none" => Ok(RustCompressionLevel::None),
149 "half" => Ok(RustCompressionLevel::Half {
150 scale: self.scale.unwrap_or(1.0) as f32,
151 }),
152 "pq8" => Ok(RustCompressionLevel::PQ8 {
153 subvectors: self.subvectors.unwrap_or(8) as u8,
154 centroids: self.centroids.unwrap_or(16) as u8,
155 }),
156 "pq4" => Ok(RustCompressionLevel::PQ4 {
157 subvectors: self.subvectors.unwrap_or(8) as u8,
158 outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
159 }),
160 "binary" => Ok(RustCompressionLevel::Binary {
161 threshold: self.threshold.unwrap_or(0.0) as f32,
162 }),
163 _ => Err(Error::new(
164 Status::InvalidArg,
165 format!("Invalid compression level: {}", self.level_type),
166 )),
167 }
168 }
169}
170
171/// Tensor compressor with adaptive level selection
172#[napi]
173pub struct TensorCompress {
174 inner: RustTensorCompress,
175}
176
177#[napi]
178impl TensorCompress {
179 /// Create a new tensor compressor
180 ///
181 /// # Example
182 /// ```javascript
183 /// const compressor = new TensorCompress();
184 /// ```
185 #[napi(constructor)]
186 pub fn new() -> Self {
187 Self {
188 inner: RustTensorCompress::new(),
189 }
190 }
191
192 /// Compress an embedding based on access frequency
193 ///
194 /// # Arguments
195 /// * `embedding` - The input embedding vector
196 /// * `access_freq` - Access frequency in range [0.0, 1.0]
197 ///
198 /// # Returns
199 /// Compressed tensor as JSON string
200 ///
201 /// # Example
202 /// ```javascript
203 /// const embedding = [1.0, 2.0, 3.0, 4.0];
204 /// const compressed = compressor.compress(embedding, 0.5);
205 /// ```
206 #[napi]
207 pub fn compress(&self, embedding: Vec<f64>, access_freq: f64) -> Result<String> {
208 let embedding_f32: Vec<f32> = embedding.iter().map(|&x| x as f32).collect();
209
210 let compressed = self
211 .inner
212 .compress(&embedding_f32, access_freq as f32)
213 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
214
215 serde_json::to_string(&compressed).map_err(|e| {
216 Error::new(
217 Status::GenericFailure,
218 format!("Serialization error: {}", e),
219 )
220 })
221 }
222
223 /// Compress with explicit compression level
224 ///
225 /// # Arguments
226 /// * `embedding` - The input embedding vector
227 /// * `level` - Compression level configuration
228 ///
229 /// # Returns
230 /// Compressed tensor as JSON string
231 ///
232 /// # Example
233 /// ```javascript
234 /// const embedding = [1.0, 2.0, 3.0, 4.0];
235 /// const level = { level_type: "half", scale: 1.0 };
236 /// const compressed = compressor.compressWithLevel(embedding, level);
237 /// ```
238 #[napi]
239 pub fn compress_with_level(
240 &self,
241 embedding: Vec<f64>,
242 level: CompressionLevelConfig,
243 ) -> Result<String> {
244 let embedding_f32: Vec<f32> = embedding.iter().map(|&x| x as f32).collect();
245 let rust_level = level.to_rust()?;
246
247 let compressed = self
248 .inner
249 .compress_with_level(&embedding_f32, &rust_level)
250 .map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
251
252 serde_json::to_string(&compressed).map_err(|e| {
253 Error::new(
254 Status::GenericFailure,
255 format!("Serialization error: {}", e),
256 )
257 })
258 }
259
260 /// Decompress a compressed tensor
261 ///
262 /// # Arguments
263 /// * `compressed_json` - Compressed tensor as JSON string
264 ///
265 /// # Returns
266 /// Decompressed embedding vector
267 ///
268 /// # Example
269 /// ```javascript
270 /// const decompressed = compressor.decompress(compressed);
271 /// ```
272 #[napi]
273 pub fn decompress(&self, compressed_json: String) -> Result<Vec<f64>> {
274 let compressed: RustCompressedTensor =
275 serde_json::from_str(&compressed_json).map_err(|e| {
276 Error::new(
277 Status::GenericFailure,
278 format!("Deserialization error: {}", e),
279 )
280 })?;
281
282 let result = self.inner.decompress(&compressed).map_err(|e| {
283 Error::new(
284 Status::GenericFailure,
285 format!("Decompression error: {}", e),
286 )
287 })?;
288
289 Ok(result.iter().map(|&x| x as f64).collect())
290 }
291}
292
293// ==================== Search Functions ====================
294
295/// Result from differentiable search
296#[napi(object)]
297pub struct SearchResult {
298 /// Indices of top-k candidates
299 pub indices: Vec<u32>,
300 /// Soft weights for top-k candidates
301 pub weights: Vec<f64>,
302}
303
304/// Differentiable search using soft attention mechanism
305///
306/// # Arguments
307/// * `query` - The query vector
308/// * `candidate_embeddings` - List of candidate embedding vectors
309/// * `k` - Number of top results to return
310/// * `temperature` - Temperature for softmax (lower = sharper, higher = smoother)
311///
312/// # Returns
313/// Search result with indices and soft weights
314///
315/// # Example
316/// ```javascript
317/// const query = [1.0, 0.0, 0.0];
318/// const candidates = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]];
319/// const result = differentiableSearch(query, candidates, 2, 1.0);
320/// console.log(result.indices); // [0, 1]
321/// console.log(result.weights); // [0.x, 0.y]
322/// ```
323#[napi]
324pub fn differentiable_search(
325 query: Vec<f64>,
326 candidate_embeddings: Vec<Vec<f64>>,
327 k: u32,
328 temperature: f64,
329) -> Result<SearchResult> {
330 let query_f32: Vec<f32> = query.iter().map(|&x| x as f32).collect();
331 let candidates_f32: Vec<Vec<f32>> = candidate_embeddings
332 .iter()
333 .map(|v| v.iter().map(|&x| x as f32).collect())
334 .collect();
335
336 let (indices, weights) =
337 rust_differentiable_search(&query_f32, &candidates_f32, k as usize, temperature as f32);
338
339 Ok(SearchResult {
340 indices: indices.iter().map(|&i| i as u32).collect(),
341 weights: weights.iter().map(|&w| w as f64).collect(),
342 })
343}
344
345/// Hierarchical forward pass through GNN layers
346///
347/// # Arguments
348/// * `query` - The query vector
349/// * `layer_embeddings` - Embeddings organized by layer
350/// * `gnn_layers_json` - JSON array of serialized GNN layers
351///
352/// # Returns
353/// Final embedding after hierarchical processing
354///
355/// # Example
356/// ```javascript
357/// const query = [1.0, 0.0];
358/// const layerEmbeddings = [[[1.0, 0.0], [0.0, 1.0]]];
359/// const layer1 = new RuvectorLayer(2, 2, 1, 0.0);
360/// const layers = [layer1.toJson()];
361/// const result = hierarchicalForward(query, layerEmbeddings, layers);
362/// ```
363#[napi]
364pub fn hierarchical_forward(
365 query: Vec<f64>,
366 layer_embeddings: Vec<Vec<Vec<f64>>>,
367 gnn_layers_json: Vec<String>,
368) -> Result<Vec<f64>> {
369 let query_f32: Vec<f32> = query.iter().map(|&x| x as f32).collect();
370
371 let embeddings_f32: Vec<Vec<Vec<f32>>> = layer_embeddings
372 .iter()
373 .map(|layer| {
374 layer
375 .iter()
376 .map(|v| v.iter().map(|&x| x as f32).collect())
377 .collect()
378 })
379 .collect();
380
381 let gnn_layers: Vec<RustRuvectorLayer> = gnn_layers_json
382 .iter()
383 .map(|json| {
384 serde_json::from_str(json).map_err(|e| {
385 Error::new(
386 Status::GenericFailure,
387 format!("Layer deserialization error: {}", e),
388 )
389 })
390 })
391 .collect::<Result<Vec<_>>>()?;
392
393 let result = rust_hierarchical_forward(&query_f32, &embeddings_f32, &gnn_layers);
394
395 Ok(result.iter().map(|&x| x as f64).collect())
396}
397
398// ==================== Helper Functions ====================
399
400/// Get the compression level that would be selected for a given access frequency
401///
402/// # Arguments
403/// * `access_freq` - Access frequency in range [0.0, 1.0]
404///
405/// # Returns
406/// String describing the compression level: "none", "half", "pq8", "pq4", or "binary"
407///
408/// # Example
409/// ```javascript
410/// const level = getCompressionLevel(0.9); // "none" (hot data)
411/// const level2 = getCompressionLevel(0.5); // "half" (warm data)
412/// ```
413#[napi]
414pub fn get_compression_level(access_freq: f64) -> String {
415 if access_freq > 0.8 {
416 "none".to_string()
417 } else if access_freq > 0.4 {
418 "half".to_string()
419 } else if access_freq > 0.1 {
420 "pq8".to_string()
421 } else if access_freq > 0.01 {
422 "pq4".to_string()
423 } else {
424 "binary".to_string()
425 }
426}
427
428/// Module initialization
429#[napi]
430pub fn init() -> String {
431 "Ruvector GNN Node.js bindings initialized".to_string()
432}