1use napi::bindgen_prelude::*;
9use napi_derive::napi;
10use ruvector_attention::graph::{
11 EdgeFeaturedAttention as RustEdgeFeatured,
12 EdgeFeaturedConfig as RustEdgeConfig,
13 GraphRoPE as RustGraphRoPE,
14 RoPEConfig as RustRoPEConfig,
15 DualSpaceAttention as RustDualSpace,
16 DualSpaceConfig as RustDualConfig,
17};
18use ruvector_attention::traits::Attention;
19
20#[napi(object)]
26pub struct EdgeFeaturedConfig {
27 pub node_dim: u32,
28 pub edge_dim: u32,
29 pub num_heads: u32,
30 pub concat_heads: Option<bool>,
31 pub add_self_loops: Option<bool>,
32 pub negative_slope: Option<f64>,
33}
34
35#[napi]
37pub struct EdgeFeaturedAttention {
38 inner: RustEdgeFeatured,
39 config: EdgeFeaturedConfig,
40}
41
42#[napi]
43impl EdgeFeaturedAttention {
44 #[napi(constructor)]
49 pub fn new(config: EdgeFeaturedConfig) -> Self {
50 let rust_config = RustEdgeConfig {
51 node_dim: config.node_dim as usize,
52 edge_dim: config.edge_dim as usize,
53 num_heads: config.num_heads as usize,
54 concat_heads: config.concat_heads.unwrap_or(true),
55 add_self_loops: config.add_self_loops.unwrap_or(true),
56 negative_slope: config.negative_slope.unwrap_or(0.2) as f32,
57 dropout: 0.0,
58 };
59 Self {
60 inner: RustEdgeFeatured::new(rust_config),
61 config,
62 }
63 }
64
65 #[napi(factory)]
67 pub fn simple(node_dim: u32, edge_dim: u32, num_heads: u32) -> Self {
68 Self::new(EdgeFeaturedConfig {
69 node_dim,
70 edge_dim,
71 num_heads,
72 concat_heads: Some(true),
73 add_self_loops: Some(true),
74 negative_slope: Some(0.2),
75 })
76 }
77
78 #[napi]
80 pub fn compute(
81 &self,
82 query: Float32Array,
83 keys: Vec<Float32Array>,
84 values: Vec<Float32Array>,
85 ) -> Result<Float32Array> {
86 let query_slice = query.as_ref();
87 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
88 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
89 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
90 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
91
92 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
93 .map_err(|e| Error::from_reason(e.to_string()))?;
94
95 Ok(Float32Array::new(result))
96 }
97
98 #[napi]
106 pub fn compute_with_edges(
107 &self,
108 query: Float32Array,
109 keys: Vec<Float32Array>,
110 values: Vec<Float32Array>,
111 edge_features: Vec<Float32Array>,
112 ) -> Result<Float32Array> {
113 let query_slice = query.as_ref();
114 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
115 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
116 let edge_features_vec: Vec<Vec<f32>> = edge_features.into_iter().map(|e| e.to_vec()).collect();
117
118 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
119 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
120 let edges_refs: Vec<&[f32]> = edge_features_vec.iter().map(|e| e.as_slice()).collect();
121
122 let result = self.inner.compute_with_edges(query_slice, &keys_refs, &values_refs, &edges_refs)
123 .map_err(|e| Error::from_reason(e.to_string()))?;
124
125 Ok(Float32Array::new(result))
126 }
127
128 #[napi(getter)]
130 pub fn node_dim(&self) -> u32 {
131 self.config.node_dim
132 }
133
134 #[napi(getter)]
136 pub fn edge_dim(&self) -> u32 {
137 self.config.edge_dim
138 }
139
140 #[napi(getter)]
142 pub fn num_heads(&self) -> u32 {
143 self.config.num_heads
144 }
145}
146
147#[napi(object)]
153pub struct RoPEConfig {
154 pub dim: u32,
155 pub max_position: u32,
156 pub base: Option<f64>,
157 pub scaling_factor: Option<f64>,
158}
159
160#[napi]
162pub struct GraphRoPEAttention {
163 inner: RustGraphRoPE,
164 config: RoPEConfig,
165}
166
167#[napi]
168impl GraphRoPEAttention {
169 #[napi(constructor)]
174 pub fn new(config: RoPEConfig) -> Self {
175 let rust_config = RustRoPEConfig {
176 dim: config.dim as usize,
177 max_position: config.max_position as usize,
178 base: config.base.unwrap_or(10000.0) as f32,
179 scaling_factor: config.scaling_factor.unwrap_or(1.0) as f32,
180 };
181 Self {
182 inner: RustGraphRoPE::new(rust_config),
183 config,
184 }
185 }
186
187 #[napi(factory)]
189 pub fn simple(dim: u32, max_position: u32) -> Self {
190 Self::new(RoPEConfig {
191 dim,
192 max_position,
193 base: Some(10000.0),
194 scaling_factor: Some(1.0),
195 })
196 }
197
198 #[napi]
200 pub fn compute(
201 &self,
202 query: Float32Array,
203 keys: Vec<Float32Array>,
204 values: Vec<Float32Array>,
205 ) -> Result<Float32Array> {
206 let query_slice = query.as_ref();
207 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
208 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
209 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
210 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
211
212 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
213 .map_err(|e| Error::from_reason(e.to_string()))?;
214
215 Ok(Float32Array::new(result))
216 }
217
218 #[napi]
227 pub fn compute_with_positions(
228 &self,
229 query: Float32Array,
230 keys: Vec<Float32Array>,
231 values: Vec<Float32Array>,
232 query_position: u32,
233 key_positions: Vec<u32>,
234 ) -> Result<Float32Array> {
235 let query_slice = query.as_ref();
236 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
237 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
238 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
239 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
240 let positions_usize: Vec<usize> = key_positions.into_iter().map(|p| p as usize).collect();
241
242 let result = self.inner.compute_with_positions(
243 query_slice,
244 &keys_refs,
245 &values_refs,
246 query_position as usize,
247 &positions_usize
248 ).map_err(|e| Error::from_reason(e.to_string()))?;
249
250 Ok(Float32Array::new(result))
251 }
252
253 #[napi]
255 pub fn apply_rotary(&self, vector: Float32Array, position: u32) -> Float32Array {
256 let v = vector.as_ref();
257 let result = self.inner.apply_rotary(v, position as usize);
258 Float32Array::new(result)
259 }
260
261 #[napi]
263 pub fn distance_to_position(distance: u32, max_distance: u32) -> u32 {
264 RustGraphRoPE::distance_to_position(distance as usize, max_distance as usize) as u32
265 }
266
267 #[napi(getter)]
269 pub fn dim(&self) -> u32 {
270 self.config.dim
271 }
272
273 #[napi(getter)]
275 pub fn max_position(&self) -> u32 {
276 self.config.max_position
277 }
278}
279
280#[napi(object)]
286pub struct DualSpaceConfig {
287 pub dim: u32,
288 pub curvature: f64,
289 pub euclidean_weight: f64,
290 pub hyperbolic_weight: f64,
291 pub temperature: Option<f64>,
292}
293
294#[napi]
296pub struct DualSpaceAttention {
297 inner: RustDualSpace,
298 config: DualSpaceConfig,
299}
300
301#[napi]
302impl DualSpaceAttention {
303 #[napi(constructor)]
308 pub fn new(config: DualSpaceConfig) -> Self {
309 let rust_config = RustDualConfig {
310 dim: config.dim as usize,
311 curvature: config.curvature as f32,
312 euclidean_weight: config.euclidean_weight as f32,
313 hyperbolic_weight: config.hyperbolic_weight as f32,
314 learn_weights: false,
315 temperature: config.temperature.unwrap_or(1.0) as f32,
316 };
317 Self {
318 inner: RustDualSpace::new(rust_config),
319 config,
320 }
321 }
322
323 #[napi(factory)]
325 pub fn simple(dim: u32, curvature: f64) -> Self {
326 Self::new(DualSpaceConfig {
327 dim,
328 curvature,
329 euclidean_weight: 0.5,
330 hyperbolic_weight: 0.5,
331 temperature: Some(1.0),
332 })
333 }
334
335 #[napi(factory)]
337 pub fn with_weights(dim: u32, curvature: f64, euclidean_weight: f64, hyperbolic_weight: f64) -> Self {
338 Self::new(DualSpaceConfig {
339 dim,
340 curvature,
341 euclidean_weight,
342 hyperbolic_weight,
343 temperature: Some(1.0),
344 })
345 }
346
347 #[napi]
349 pub fn compute(
350 &self,
351 query: Float32Array,
352 keys: Vec<Float32Array>,
353 values: Vec<Float32Array>,
354 ) -> Result<Float32Array> {
355 let query_slice = query.as_ref();
356 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
357 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
358 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
359 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
360
361 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
362 .map_err(|e| Error::from_reason(e.to_string()))?;
363
364 Ok(Float32Array::new(result))
365 }
366
367 #[napi]
369 pub fn get_space_contributions(
370 &self,
371 query: Float32Array,
372 keys: Vec<Float32Array>,
373 ) -> SpaceContributions {
374 let query_slice = query.as_ref();
375 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
376 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
377
378 let (euc_scores, hyp_scores) = self.inner.get_space_contributions(query_slice, &keys_refs);
379
380 SpaceContributions {
381 euclidean_scores: Float32Array::new(euc_scores),
382 hyperbolic_scores: Float32Array::new(hyp_scores),
383 }
384 }
385
386 #[napi(getter)]
388 pub fn dim(&self) -> u32 {
389 self.config.dim
390 }
391
392 #[napi(getter)]
394 pub fn curvature(&self) -> f64 {
395 self.config.curvature
396 }
397
398 #[napi(getter)]
400 pub fn euclidean_weight(&self) -> f64 {
401 self.config.euclidean_weight
402 }
403
404 #[napi(getter)]
406 pub fn hyperbolic_weight(&self) -> f64 {
407 self.config.hyperbolic_weight
408 }
409}
410
411#[napi(object)]
413pub struct SpaceContributions {
414 pub euclidean_scores: Float32Array,
415 pub hyperbolic_scores: Float32Array,
416}