1use napi::bindgen_prelude::*;
9use napi_derive::napi;
10use ruvector_attention::graph::{
11 DualSpaceAttention as RustDualSpace, DualSpaceConfig as RustDualConfig,
12 EdgeFeaturedAttention as RustEdgeFeatured, EdgeFeaturedConfig as RustEdgeConfig,
13 GraphRoPE as RustGraphRoPE, RoPEConfig as RustRoPEConfig,
14};
15use ruvector_attention::traits::Attention;
16
17#[napi(object)]
23pub struct EdgeFeaturedConfig {
24 pub node_dim: u32,
25 pub edge_dim: u32,
26 pub num_heads: u32,
27 pub concat_heads: Option<bool>,
28 pub add_self_loops: Option<bool>,
29 pub negative_slope: Option<f64>,
30}
31
32#[napi]
34pub struct EdgeFeaturedAttention {
35 inner: RustEdgeFeatured,
36 config: EdgeFeaturedConfig,
37}
38
39#[napi]
40impl EdgeFeaturedAttention {
41 #[napi(constructor)]
46 pub fn new(config: EdgeFeaturedConfig) -> Self {
47 let rust_config = RustEdgeConfig {
48 node_dim: config.node_dim as usize,
49 edge_dim: config.edge_dim as usize,
50 num_heads: config.num_heads as usize,
51 concat_heads: config.concat_heads.unwrap_or(true),
52 add_self_loops: config.add_self_loops.unwrap_or(true),
53 negative_slope: config.negative_slope.unwrap_or(0.2) as f32,
54 dropout: 0.0,
55 };
56 Self {
57 inner: RustEdgeFeatured::new(rust_config),
58 config,
59 }
60 }
61
62 #[napi(factory)]
64 pub fn simple(node_dim: u32, edge_dim: u32, num_heads: u32) -> Self {
65 Self::new(EdgeFeaturedConfig {
66 node_dim,
67 edge_dim,
68 num_heads,
69 concat_heads: Some(true),
70 add_self_loops: Some(true),
71 negative_slope: Some(0.2),
72 })
73 }
74
75 #[napi]
77 pub fn compute(
78 &self,
79 query: Float32Array,
80 keys: Vec<Float32Array>,
81 values: Vec<Float32Array>,
82 ) -> Result<Float32Array> {
83 let query_slice = query.as_ref();
84 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
85 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
86 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
87 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
88
89 let result = self
90 .inner
91 .compute(query_slice, &keys_refs, &values_refs)
92 .map_err(|e| Error::from_reason(e.to_string()))?;
93
94 Ok(Float32Array::new(result))
95 }
96
97 #[napi]
105 pub fn compute_with_edges(
106 &self,
107 query: Float32Array,
108 keys: Vec<Float32Array>,
109 values: Vec<Float32Array>,
110 edge_features: Vec<Float32Array>,
111 ) -> Result<Float32Array> {
112 let query_slice = query.as_ref();
113 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
114 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
115 let edge_features_vec: Vec<Vec<f32>> =
116 edge_features.into_iter().map(|e| e.to_vec()).collect();
117
118 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
119 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
120 let edges_refs: Vec<&[f32]> = edge_features_vec.iter().map(|e| e.as_slice()).collect();
121
122 let result = self
123 .inner
124 .compute_with_edges(query_slice, &keys_refs, &values_refs, &edges_refs)
125 .map_err(|e| Error::from_reason(e.to_string()))?;
126
127 Ok(Float32Array::new(result))
128 }
129
130 #[napi(getter)]
132 pub fn node_dim(&self) -> u32 {
133 self.config.node_dim
134 }
135
136 #[napi(getter)]
138 pub fn edge_dim(&self) -> u32 {
139 self.config.edge_dim
140 }
141
142 #[napi(getter)]
144 pub fn num_heads(&self) -> u32 {
145 self.config.num_heads
146 }
147}
148
149#[napi(object)]
155pub struct RoPEConfig {
156 pub dim: u32,
157 pub max_position: u32,
158 pub base: Option<f64>,
159 pub scaling_factor: Option<f64>,
160}
161
162#[napi]
164pub struct GraphRoPEAttention {
165 inner: RustGraphRoPE,
166 config: RoPEConfig,
167}
168
169#[napi]
170impl GraphRoPEAttention {
171 #[napi(constructor)]
176 pub fn new(config: RoPEConfig) -> Self {
177 let rust_config = RustRoPEConfig {
178 dim: config.dim as usize,
179 max_position: config.max_position as usize,
180 base: config.base.unwrap_or(10000.0) as f32,
181 scaling_factor: config.scaling_factor.unwrap_or(1.0) as f32,
182 };
183 Self {
184 inner: RustGraphRoPE::new(rust_config),
185 config,
186 }
187 }
188
189 #[napi(factory)]
191 pub fn simple(dim: u32, max_position: u32) -> Self {
192 Self::new(RoPEConfig {
193 dim,
194 max_position,
195 base: Some(10000.0),
196 scaling_factor: Some(1.0),
197 })
198 }
199
200 #[napi]
202 pub fn compute(
203 &self,
204 query: Float32Array,
205 keys: Vec<Float32Array>,
206 values: Vec<Float32Array>,
207 ) -> Result<Float32Array> {
208 let query_slice = query.as_ref();
209 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
210 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
211 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
212 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
213
214 let result = self
215 .inner
216 .compute(query_slice, &keys_refs, &values_refs)
217 .map_err(|e| Error::from_reason(e.to_string()))?;
218
219 Ok(Float32Array::new(result))
220 }
221
222 #[napi]
231 pub fn compute_with_positions(
232 &self,
233 query: Float32Array,
234 keys: Vec<Float32Array>,
235 values: Vec<Float32Array>,
236 query_position: u32,
237 key_positions: Vec<u32>,
238 ) -> Result<Float32Array> {
239 let query_slice = query.as_ref();
240 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
241 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
242 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
243 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
244 let positions_usize: Vec<usize> = key_positions.into_iter().map(|p| p as usize).collect();
245
246 let result = self
247 .inner
248 .compute_with_positions(
249 query_slice,
250 &keys_refs,
251 &values_refs,
252 query_position as usize,
253 &positions_usize,
254 )
255 .map_err(|e| Error::from_reason(e.to_string()))?;
256
257 Ok(Float32Array::new(result))
258 }
259
260 #[napi]
262 pub fn apply_rotary(&self, vector: Float32Array, position: u32) -> Float32Array {
263 let v = vector.as_ref();
264 let result = self.inner.apply_rotary(v, position as usize);
265 Float32Array::new(result)
266 }
267
268 #[napi]
270 pub fn distance_to_position(distance: u32, max_distance: u32) -> u32 {
271 RustGraphRoPE::distance_to_position(distance as usize, max_distance as usize) as u32
272 }
273
274 #[napi(getter)]
276 pub fn dim(&self) -> u32 {
277 self.config.dim
278 }
279
280 #[napi(getter)]
282 pub fn max_position(&self) -> u32 {
283 self.config.max_position
284 }
285}
286
287#[napi(object)]
293pub struct DualSpaceConfig {
294 pub dim: u32,
295 pub curvature: f64,
296 pub euclidean_weight: f64,
297 pub hyperbolic_weight: f64,
298 pub temperature: Option<f64>,
299}
300
301#[napi]
303pub struct DualSpaceAttention {
304 inner: RustDualSpace,
305 config: DualSpaceConfig,
306}
307
308#[napi]
309impl DualSpaceAttention {
310 #[napi(constructor)]
315 pub fn new(config: DualSpaceConfig) -> Self {
316 let rust_config = RustDualConfig {
317 dim: config.dim as usize,
318 curvature: config.curvature as f32,
319 euclidean_weight: config.euclidean_weight as f32,
320 hyperbolic_weight: config.hyperbolic_weight as f32,
321 learn_weights: false,
322 temperature: config.temperature.unwrap_or(1.0) as f32,
323 };
324 Self {
325 inner: RustDualSpace::new(rust_config),
326 config,
327 }
328 }
329
330 #[napi(factory)]
332 pub fn simple(dim: u32, curvature: f64) -> Self {
333 Self::new(DualSpaceConfig {
334 dim,
335 curvature,
336 euclidean_weight: 0.5,
337 hyperbolic_weight: 0.5,
338 temperature: Some(1.0),
339 })
340 }
341
342 #[napi(factory)]
344 pub fn with_weights(
345 dim: u32,
346 curvature: f64,
347 euclidean_weight: f64,
348 hyperbolic_weight: f64,
349 ) -> Self {
350 Self::new(DualSpaceConfig {
351 dim,
352 curvature,
353 euclidean_weight,
354 hyperbolic_weight,
355 temperature: Some(1.0),
356 })
357 }
358
359 #[napi]
361 pub fn compute(
362 &self,
363 query: Float32Array,
364 keys: Vec<Float32Array>,
365 values: Vec<Float32Array>,
366 ) -> Result<Float32Array> {
367 let query_slice = query.as_ref();
368 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
369 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
370 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
371 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
372
373 let result = self
374 .inner
375 .compute(query_slice, &keys_refs, &values_refs)
376 .map_err(|e| Error::from_reason(e.to_string()))?;
377
378 Ok(Float32Array::new(result))
379 }
380
381 #[napi]
383 pub fn get_space_contributions(
384 &self,
385 query: Float32Array,
386 keys: Vec<Float32Array>,
387 ) -> SpaceContributions {
388 let query_slice = query.as_ref();
389 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
390 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
391
392 let (euc_scores, hyp_scores) = self.inner.get_space_contributions(query_slice, &keys_refs);
393
394 SpaceContributions {
395 euclidean_scores: Float32Array::new(euc_scores),
396 hyperbolic_scores: Float32Array::new(hyp_scores),
397 }
398 }
399
400 #[napi(getter)]
402 pub fn dim(&self) -> u32 {
403 self.config.dim
404 }
405
406 #[napi(getter)]
408 pub fn curvature(&self) -> f64 {
409 self.config.curvature
410 }
411
412 #[napi(getter)]
414 pub fn euclidean_weight(&self) -> f64 {
415 self.config.euclidean_weight
416 }
417
418 #[napi(getter)]
420 pub fn hyperbolic_weight(&self) -> f64 {
421 self.config.hyperbolic_weight
422 }
423}
424
425#[napi(object)]
427pub struct SpaceContributions {
428 pub euclidean_scores: Float32Array,
429 pub hyperbolic_scores: Float32Array,
430}