Skip to main content

ruvector_attention_node/
graph.rs

//! NAPI-RS bindings for graph attention mechanisms.
//!
//! Provides Node.js bindings for:
//! - Edge-featured attention (GATv2-style)
//! - Graph RoPE (Rotary Position Embeddings for graphs)
//! - Dual-space attention (Euclidean + Hyperbolic)

8use napi::bindgen_prelude::*;
9use napi_derive::napi;
10use ruvector_attention::graph::{
11    DualSpaceAttention as RustDualSpace, DualSpaceConfig as RustDualConfig,
12    EdgeFeaturedAttention as RustEdgeFeatured, EdgeFeaturedConfig as RustEdgeConfig,
13    GraphRoPE as RustGraphRoPE, RoPEConfig as RustRoPEConfig,
14};
15use ruvector_attention::traits::Attention;
16
17// ============================================================================
18// Edge-Featured Attention
19// ============================================================================
20
21/// Configuration for edge-featured attention
22#[napi(object)]
23pub struct EdgeFeaturedConfig {
24    pub node_dim: u32,
25    pub edge_dim: u32,
26    pub num_heads: u32,
27    pub concat_heads: Option<bool>,
28    pub add_self_loops: Option<bool>,
29    pub negative_slope: Option<f64>,
30}
31
32/// Edge-featured attention (GATv2-style)
33#[napi]
34pub struct EdgeFeaturedAttention {
35    inner: RustEdgeFeatured,
36    config: EdgeFeaturedConfig,
37}
38
39#[napi]
40impl EdgeFeaturedAttention {
41    /// Create a new edge-featured attention instance
42    ///
43    /// # Arguments
44    /// * `config` - Edge-featured attention configuration
45    #[napi(constructor)]
46    pub fn new(config: EdgeFeaturedConfig) -> Self {
47        let rust_config = RustEdgeConfig {
48            node_dim: config.node_dim as usize,
49            edge_dim: config.edge_dim as usize,
50            num_heads: config.num_heads as usize,
51            concat_heads: config.concat_heads.unwrap_or(true),
52            add_self_loops: config.add_self_loops.unwrap_or(true),
53            negative_slope: config.negative_slope.unwrap_or(0.2) as f32,
54            dropout: 0.0,
55        };
56        Self {
57            inner: RustEdgeFeatured::new(rust_config),
58            config,
59        }
60    }
61
62    /// Create with simple parameters
63    #[napi(factory)]
64    pub fn simple(node_dim: u32, edge_dim: u32, num_heads: u32) -> Self {
65        Self::new(EdgeFeaturedConfig {
66            node_dim,
67            edge_dim,
68            num_heads,
69            concat_heads: Some(true),
70            add_self_loops: Some(true),
71            negative_slope: Some(0.2),
72        })
73    }
74
75    /// Compute attention without edge features (standard attention)
76    #[napi]
77    pub fn compute(
78        &self,
79        query: Float32Array,
80        keys: Vec<Float32Array>,
81        values: Vec<Float32Array>,
82    ) -> Result<Float32Array> {
83        let query_slice = query.as_ref();
84        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
85        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
86        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
87        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
88
89        let result = self
90            .inner
91            .compute(query_slice, &keys_refs, &values_refs)
92            .map_err(|e| Error::from_reason(e.to_string()))?;
93
94        Ok(Float32Array::new(result))
95    }
96
97    /// Compute attention with edge features
98    ///
99    /// # Arguments
100    /// * `query` - Query vector
101    /// * `keys` - Array of key vectors
102    /// * `values` - Array of value vectors
103    /// * `edge_features` - Array of edge feature vectors (same length as keys)
104    #[napi]
105    pub fn compute_with_edges(
106        &self,
107        query: Float32Array,
108        keys: Vec<Float32Array>,
109        values: Vec<Float32Array>,
110        edge_features: Vec<Float32Array>,
111    ) -> Result<Float32Array> {
112        let query_slice = query.as_ref();
113        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
114        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
115        let edge_features_vec: Vec<Vec<f32>> =
116            edge_features.into_iter().map(|e| e.to_vec()).collect();
117
118        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
119        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
120        let edges_refs: Vec<&[f32]> = edge_features_vec.iter().map(|e| e.as_slice()).collect();
121
122        let result = self
123            .inner
124            .compute_with_edges(query_slice, &keys_refs, &values_refs, &edges_refs)
125            .map_err(|e| Error::from_reason(e.to_string()))?;
126
127        Ok(Float32Array::new(result))
128    }
129
130    /// Get the node dimension
131    #[napi(getter)]
132    pub fn node_dim(&self) -> u32 {
133        self.config.node_dim
134    }
135
136    /// Get the edge dimension
137    #[napi(getter)]
138    pub fn edge_dim(&self) -> u32 {
139        self.config.edge_dim
140    }
141
142    /// Get the number of heads
143    #[napi(getter)]
144    pub fn num_heads(&self) -> u32 {
145        self.config.num_heads
146    }
147}
148
149// ============================================================================
150// Graph RoPE Attention
151// ============================================================================
152
153/// Configuration for Graph RoPE attention
154#[napi(object)]
155pub struct RoPEConfig {
156    pub dim: u32,
157    pub max_position: u32,
158    pub base: Option<f64>,
159    pub scaling_factor: Option<f64>,
160}
161
162/// Graph RoPE attention (Rotary Position Embeddings for graphs)
163#[napi]
164pub struct GraphRoPEAttention {
165    inner: RustGraphRoPE,
166    config: RoPEConfig,
167}
168
169#[napi]
170impl GraphRoPEAttention {
171    /// Create a new Graph RoPE attention instance
172    ///
173    /// # Arguments
174    /// * `config` - RoPE configuration
175    #[napi(constructor)]
176    pub fn new(config: RoPEConfig) -> Self {
177        let rust_config = RustRoPEConfig {
178            dim: config.dim as usize,
179            max_position: config.max_position as usize,
180            base: config.base.unwrap_or(10000.0) as f32,
181            scaling_factor: config.scaling_factor.unwrap_or(1.0) as f32,
182        };
183        Self {
184            inner: RustGraphRoPE::new(rust_config),
185            config,
186        }
187    }
188
189    /// Create with simple parameters
190    #[napi(factory)]
191    pub fn simple(dim: u32, max_position: u32) -> Self {
192        Self::new(RoPEConfig {
193            dim,
194            max_position,
195            base: Some(10000.0),
196            scaling_factor: Some(1.0),
197        })
198    }
199
200    /// Compute attention without positional encoding
201    #[napi]
202    pub fn compute(
203        &self,
204        query: Float32Array,
205        keys: Vec<Float32Array>,
206        values: Vec<Float32Array>,
207    ) -> Result<Float32Array> {
208        let query_slice = query.as_ref();
209        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
210        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
211        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
212        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
213
214        let result = self
215            .inner
216            .compute(query_slice, &keys_refs, &values_refs)
217            .map_err(|e| Error::from_reason(e.to_string()))?;
218
219        Ok(Float32Array::new(result))
220    }
221
222    /// Compute attention with graph positions
223    ///
224    /// # Arguments
225    /// * `query` - Query vector
226    /// * `keys` - Array of key vectors
227    /// * `values` - Array of value vectors
228    /// * `query_position` - Position of query node
229    /// * `key_positions` - Positions of key nodes (e.g., hop distances)
230    #[napi]
231    pub fn compute_with_positions(
232        &self,
233        query: Float32Array,
234        keys: Vec<Float32Array>,
235        values: Vec<Float32Array>,
236        query_position: u32,
237        key_positions: Vec<u32>,
238    ) -> Result<Float32Array> {
239        let query_slice = query.as_ref();
240        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
241        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
242        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
243        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
244        let positions_usize: Vec<usize> = key_positions.into_iter().map(|p| p as usize).collect();
245
246        let result = self
247            .inner
248            .compute_with_positions(
249                query_slice,
250                &keys_refs,
251                &values_refs,
252                query_position as usize,
253                &positions_usize,
254            )
255            .map_err(|e| Error::from_reason(e.to_string()))?;
256
257        Ok(Float32Array::new(result))
258    }
259
260    /// Apply rotary embedding to a vector
261    #[napi]
262    pub fn apply_rotary(&self, vector: Float32Array, position: u32) -> Float32Array {
263        let v = vector.as_ref();
264        let result = self.inner.apply_rotary(v, position as usize);
265        Float32Array::new(result)
266    }
267
268    /// Convert graph distance to position bucket
269    #[napi]
270    pub fn distance_to_position(distance: u32, max_distance: u32) -> u32 {
271        RustGraphRoPE::distance_to_position(distance as usize, max_distance as usize) as u32
272    }
273
274    /// Get the dimension
275    #[napi(getter)]
276    pub fn dim(&self) -> u32 {
277        self.config.dim
278    }
279
280    /// Get the max position
281    #[napi(getter)]
282    pub fn max_position(&self) -> u32 {
283        self.config.max_position
284    }
285}
286
287// ============================================================================
288// Dual-Space Attention
289// ============================================================================
290
291/// Configuration for dual-space attention
292#[napi(object)]
293pub struct DualSpaceConfig {
294    pub dim: u32,
295    pub curvature: f64,
296    pub euclidean_weight: f64,
297    pub hyperbolic_weight: f64,
298    pub temperature: Option<f64>,
299}
300
301/// Dual-space attention (Euclidean + Hyperbolic)
302#[napi]
303pub struct DualSpaceAttention {
304    inner: RustDualSpace,
305    config: DualSpaceConfig,
306}
307
308#[napi]
309impl DualSpaceAttention {
310    /// Create a new dual-space attention instance
311    ///
312    /// # Arguments
313    /// * `config` - Dual-space configuration
314    #[napi(constructor)]
315    pub fn new(config: DualSpaceConfig) -> Self {
316        let rust_config = RustDualConfig {
317            dim: config.dim as usize,
318            curvature: config.curvature as f32,
319            euclidean_weight: config.euclidean_weight as f32,
320            hyperbolic_weight: config.hyperbolic_weight as f32,
321            learn_weights: false,
322            temperature: config.temperature.unwrap_or(1.0) as f32,
323        };
324        Self {
325            inner: RustDualSpace::new(rust_config),
326            config,
327        }
328    }
329
330    /// Create with simple parameters (equal weights)
331    #[napi(factory)]
332    pub fn simple(dim: u32, curvature: f64) -> Self {
333        Self::new(DualSpaceConfig {
334            dim,
335            curvature,
336            euclidean_weight: 0.5,
337            hyperbolic_weight: 0.5,
338            temperature: Some(1.0),
339        })
340    }
341
342    /// Create with custom weights
343    #[napi(factory)]
344    pub fn with_weights(
345        dim: u32,
346        curvature: f64,
347        euclidean_weight: f64,
348        hyperbolic_weight: f64,
349    ) -> Self {
350        Self::new(DualSpaceConfig {
351            dim,
352            curvature,
353            euclidean_weight,
354            hyperbolic_weight,
355            temperature: Some(1.0),
356        })
357    }
358
359    /// Compute dual-space attention
360    #[napi]
361    pub fn compute(
362        &self,
363        query: Float32Array,
364        keys: Vec<Float32Array>,
365        values: Vec<Float32Array>,
366    ) -> Result<Float32Array> {
367        let query_slice = query.as_ref();
368        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
369        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
370        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
371        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
372
373        let result = self
374            .inner
375            .compute(query_slice, &keys_refs, &values_refs)
376            .map_err(|e| Error::from_reason(e.to_string()))?;
377
378        Ok(Float32Array::new(result))
379    }
380
381    /// Get space contributions (Euclidean and Hyperbolic scores separately)
382    #[napi]
383    pub fn get_space_contributions(
384        &self,
385        query: Float32Array,
386        keys: Vec<Float32Array>,
387    ) -> SpaceContributions {
388        let query_slice = query.as_ref();
389        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
390        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
391
392        let (euc_scores, hyp_scores) = self.inner.get_space_contributions(query_slice, &keys_refs);
393
394        SpaceContributions {
395            euclidean_scores: Float32Array::new(euc_scores),
396            hyperbolic_scores: Float32Array::new(hyp_scores),
397        }
398    }
399
400    /// Get the dimension
401    #[napi(getter)]
402    pub fn dim(&self) -> u32 {
403        self.config.dim
404    }
405
406    /// Get the curvature
407    #[napi(getter)]
408    pub fn curvature(&self) -> f64 {
409        self.config.curvature
410    }
411
412    /// Get the Euclidean weight
413    #[napi(getter)]
414    pub fn euclidean_weight(&self) -> f64 {
415        self.config.euclidean_weight
416    }
417
418    /// Get the Hyperbolic weight
419    #[napi(getter)]
420    pub fn hyperbolic_weight(&self) -> f64 {
421        self.config.hyperbolic_weight
422    }
423}
424
425/// Space contribution scores
426#[napi(object)]
427pub struct SpaceContributions {
428    pub euclidean_scores: Float32Array,
429    pub hyperbolic_scores: Float32Array,
430}