// ruvector_attention_node/graph.rs

//! NAPI-RS bindings for graph attention mechanisms.
//!
//! Provides Node.js bindings for:
//! - Edge-featured attention (GATv2-style)
//! - Graph RoPE (Rotary Position Embeddings for graphs)
//! - Dual-space attention (Euclidean + Hyperbolic)
7
8use napi::bindgen_prelude::*;
9use napi_derive::napi;
10use ruvector_attention::graph::{
11    EdgeFeaturedAttention as RustEdgeFeatured,
12    EdgeFeaturedConfig as RustEdgeConfig,
13    GraphRoPE as RustGraphRoPE,
14    RoPEConfig as RustRoPEConfig,
15    DualSpaceAttention as RustDualSpace,
16    DualSpaceConfig as RustDualConfig,
17};
18use ruvector_attention::traits::Attention;
19
// ============================================================================
// Edge-Featured Attention
// ============================================================================
23
24/// Configuration for edge-featured attention
25#[napi(object)]
26pub struct EdgeFeaturedConfig {
27    pub node_dim: u32,
28    pub edge_dim: u32,
29    pub num_heads: u32,
30    pub concat_heads: Option<bool>,
31    pub add_self_loops: Option<bool>,
32    pub negative_slope: Option<f64>,
33}
34
35/// Edge-featured attention (GATv2-style)
36#[napi]
37pub struct EdgeFeaturedAttention {
38    inner: RustEdgeFeatured,
39    config: EdgeFeaturedConfig,
40}
41
42#[napi]
43impl EdgeFeaturedAttention {
44    /// Create a new edge-featured attention instance
45    ///
46    /// # Arguments
47    /// * `config` - Edge-featured attention configuration
48    #[napi(constructor)]
49    pub fn new(config: EdgeFeaturedConfig) -> Self {
50        let rust_config = RustEdgeConfig {
51            node_dim: config.node_dim as usize,
52            edge_dim: config.edge_dim as usize,
53            num_heads: config.num_heads as usize,
54            concat_heads: config.concat_heads.unwrap_or(true),
55            add_self_loops: config.add_self_loops.unwrap_or(true),
56            negative_slope: config.negative_slope.unwrap_or(0.2) as f32,
57            dropout: 0.0,
58        };
59        Self {
60            inner: RustEdgeFeatured::new(rust_config),
61            config,
62        }
63    }
64
65    /// Create with simple parameters
66    #[napi(factory)]
67    pub fn simple(node_dim: u32, edge_dim: u32, num_heads: u32) -> Self {
68        Self::new(EdgeFeaturedConfig {
69            node_dim,
70            edge_dim,
71            num_heads,
72            concat_heads: Some(true),
73            add_self_loops: Some(true),
74            negative_slope: Some(0.2),
75        })
76    }
77
78    /// Compute attention without edge features (standard attention)
79    #[napi]
80    pub fn compute(
81        &self,
82        query: Float32Array,
83        keys: Vec<Float32Array>,
84        values: Vec<Float32Array>,
85    ) -> Result<Float32Array> {
86        let query_slice = query.as_ref();
87        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
88        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
89        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
90        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
91
92        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
93            .map_err(|e| Error::from_reason(e.to_string()))?;
94
95        Ok(Float32Array::new(result))
96    }
97
98    /// Compute attention with edge features
99    ///
100    /// # Arguments
101    /// * `query` - Query vector
102    /// * `keys` - Array of key vectors
103    /// * `values` - Array of value vectors
104    /// * `edge_features` - Array of edge feature vectors (same length as keys)
105    #[napi]
106    pub fn compute_with_edges(
107        &self,
108        query: Float32Array,
109        keys: Vec<Float32Array>,
110        values: Vec<Float32Array>,
111        edge_features: Vec<Float32Array>,
112    ) -> Result<Float32Array> {
113        let query_slice = query.as_ref();
114        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
115        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
116        let edge_features_vec: Vec<Vec<f32>> = edge_features.into_iter().map(|e| e.to_vec()).collect();
117
118        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
119        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
120        let edges_refs: Vec<&[f32]> = edge_features_vec.iter().map(|e| e.as_slice()).collect();
121
122        let result = self.inner.compute_with_edges(query_slice, &keys_refs, &values_refs, &edges_refs)
123            .map_err(|e| Error::from_reason(e.to_string()))?;
124
125        Ok(Float32Array::new(result))
126    }
127
128    /// Get the node dimension
129    #[napi(getter)]
130    pub fn node_dim(&self) -> u32 {
131        self.config.node_dim
132    }
133
134    /// Get the edge dimension
135    #[napi(getter)]
136    pub fn edge_dim(&self) -> u32 {
137        self.config.edge_dim
138    }
139
140    /// Get the number of heads
141    #[napi(getter)]
142    pub fn num_heads(&self) -> u32 {
143        self.config.num_heads
144    }
145}
146
// ============================================================================
// Graph RoPE Attention
// ============================================================================
150
151/// Configuration for Graph RoPE attention
152#[napi(object)]
153pub struct RoPEConfig {
154    pub dim: u32,
155    pub max_position: u32,
156    pub base: Option<f64>,
157    pub scaling_factor: Option<f64>,
158}
159
160/// Graph RoPE attention (Rotary Position Embeddings for graphs)
161#[napi]
162pub struct GraphRoPEAttention {
163    inner: RustGraphRoPE,
164    config: RoPEConfig,
165}
166
167#[napi]
168impl GraphRoPEAttention {
169    /// Create a new Graph RoPE attention instance
170    ///
171    /// # Arguments
172    /// * `config` - RoPE configuration
173    #[napi(constructor)]
174    pub fn new(config: RoPEConfig) -> Self {
175        let rust_config = RustRoPEConfig {
176            dim: config.dim as usize,
177            max_position: config.max_position as usize,
178            base: config.base.unwrap_or(10000.0) as f32,
179            scaling_factor: config.scaling_factor.unwrap_or(1.0) as f32,
180        };
181        Self {
182            inner: RustGraphRoPE::new(rust_config),
183            config,
184        }
185    }
186
187    /// Create with simple parameters
188    #[napi(factory)]
189    pub fn simple(dim: u32, max_position: u32) -> Self {
190        Self::new(RoPEConfig {
191            dim,
192            max_position,
193            base: Some(10000.0),
194            scaling_factor: Some(1.0),
195        })
196    }
197
198    /// Compute attention without positional encoding
199    #[napi]
200    pub fn compute(
201        &self,
202        query: Float32Array,
203        keys: Vec<Float32Array>,
204        values: Vec<Float32Array>,
205    ) -> Result<Float32Array> {
206        let query_slice = query.as_ref();
207        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
208        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
209        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
210        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
211
212        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
213            .map_err(|e| Error::from_reason(e.to_string()))?;
214
215        Ok(Float32Array::new(result))
216    }
217
218    /// Compute attention with graph positions
219    ///
220    /// # Arguments
221    /// * `query` - Query vector
222    /// * `keys` - Array of key vectors
223    /// * `values` - Array of value vectors
224    /// * `query_position` - Position of query node
225    /// * `key_positions` - Positions of key nodes (e.g., hop distances)
226    #[napi]
227    pub fn compute_with_positions(
228        &self,
229        query: Float32Array,
230        keys: Vec<Float32Array>,
231        values: Vec<Float32Array>,
232        query_position: u32,
233        key_positions: Vec<u32>,
234    ) -> Result<Float32Array> {
235        let query_slice = query.as_ref();
236        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
237        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
238        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
239        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
240        let positions_usize: Vec<usize> = key_positions.into_iter().map(|p| p as usize).collect();
241
242        let result = self.inner.compute_with_positions(
243            query_slice,
244            &keys_refs,
245            &values_refs,
246            query_position as usize,
247            &positions_usize
248        ).map_err(|e| Error::from_reason(e.to_string()))?;
249
250        Ok(Float32Array::new(result))
251    }
252
253    /// Apply rotary embedding to a vector
254    #[napi]
255    pub fn apply_rotary(&self, vector: Float32Array, position: u32) -> Float32Array {
256        let v = vector.as_ref();
257        let result = self.inner.apply_rotary(v, position as usize);
258        Float32Array::new(result)
259    }
260
261    /// Convert graph distance to position bucket
262    #[napi]
263    pub fn distance_to_position(distance: u32, max_distance: u32) -> u32 {
264        RustGraphRoPE::distance_to_position(distance as usize, max_distance as usize) as u32
265    }
266
267    /// Get the dimension
268    #[napi(getter)]
269    pub fn dim(&self) -> u32 {
270        self.config.dim
271    }
272
273    /// Get the max position
274    #[napi(getter)]
275    pub fn max_position(&self) -> u32 {
276        self.config.max_position
277    }
278}
279
// ============================================================================
// Dual-Space Attention
// ============================================================================
283
284/// Configuration for dual-space attention
285#[napi(object)]
286pub struct DualSpaceConfig {
287    pub dim: u32,
288    pub curvature: f64,
289    pub euclidean_weight: f64,
290    pub hyperbolic_weight: f64,
291    pub temperature: Option<f64>,
292}
293
294/// Dual-space attention (Euclidean + Hyperbolic)
295#[napi]
296pub struct DualSpaceAttention {
297    inner: RustDualSpace,
298    config: DualSpaceConfig,
299}
300
301#[napi]
302impl DualSpaceAttention {
303    /// Create a new dual-space attention instance
304    ///
305    /// # Arguments
306    /// * `config` - Dual-space configuration
307    #[napi(constructor)]
308    pub fn new(config: DualSpaceConfig) -> Self {
309        let rust_config = RustDualConfig {
310            dim: config.dim as usize,
311            curvature: config.curvature as f32,
312            euclidean_weight: config.euclidean_weight as f32,
313            hyperbolic_weight: config.hyperbolic_weight as f32,
314            learn_weights: false,
315            temperature: config.temperature.unwrap_or(1.0) as f32,
316        };
317        Self {
318            inner: RustDualSpace::new(rust_config),
319            config,
320        }
321    }
322
323    /// Create with simple parameters (equal weights)
324    #[napi(factory)]
325    pub fn simple(dim: u32, curvature: f64) -> Self {
326        Self::new(DualSpaceConfig {
327            dim,
328            curvature,
329            euclidean_weight: 0.5,
330            hyperbolic_weight: 0.5,
331            temperature: Some(1.0),
332        })
333    }
334
335    /// Create with custom weights
336    #[napi(factory)]
337    pub fn with_weights(dim: u32, curvature: f64, euclidean_weight: f64, hyperbolic_weight: f64) -> Self {
338        Self::new(DualSpaceConfig {
339            dim,
340            curvature,
341            euclidean_weight,
342            hyperbolic_weight,
343            temperature: Some(1.0),
344        })
345    }
346
347    /// Compute dual-space attention
348    #[napi]
349    pub fn compute(
350        &self,
351        query: Float32Array,
352        keys: Vec<Float32Array>,
353        values: Vec<Float32Array>,
354    ) -> Result<Float32Array> {
355        let query_slice = query.as_ref();
356        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
357        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
358        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
359        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
360
361        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
362            .map_err(|e| Error::from_reason(e.to_string()))?;
363
364        Ok(Float32Array::new(result))
365    }
366
367    /// Get space contributions (Euclidean and Hyperbolic scores separately)
368    #[napi]
369    pub fn get_space_contributions(
370        &self,
371        query: Float32Array,
372        keys: Vec<Float32Array>,
373    ) -> SpaceContributions {
374        let query_slice = query.as_ref();
375        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
376        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
377
378        let (euc_scores, hyp_scores) = self.inner.get_space_contributions(query_slice, &keys_refs);
379
380        SpaceContributions {
381            euclidean_scores: Float32Array::new(euc_scores),
382            hyperbolic_scores: Float32Array::new(hyp_scores),
383        }
384    }
385
386    /// Get the dimension
387    #[napi(getter)]
388    pub fn dim(&self) -> u32 {
389        self.config.dim
390    }
391
392    /// Get the curvature
393    #[napi(getter)]
394    pub fn curvature(&self) -> f64 {
395        self.config.curvature
396    }
397
398    /// Get the Euclidean weight
399    #[napi(getter)]
400    pub fn euclidean_weight(&self) -> f64 {
401        self.config.euclidean_weight
402    }
403
404    /// Get the Hyperbolic weight
405    #[napi(getter)]
406    pub fn hyperbolic_weight(&self) -> f64 {
407        self.config.hyperbolic_weight
408    }
409}
410
411/// Space contribution scores
412#[napi(object)]
413pub struct SpaceContributions {
414    pub euclidean_scores: Float32Array,
415    pub hyperbolic_scores: Float32Array,
416}