1use napi::bindgen_prelude::*;
13use napi_derive::napi;
14use ruvector_attention::{
15 attention::{MultiHeadAttention as RustMultiHead, ScaledDotProductAttention},
16 hyperbolic::{HyperbolicAttention as RustHyperbolic, HyperbolicAttentionConfig},
17 moe::{MoEAttention as RustMoE, MoEConfig as RustMoEConfig},
18 sparse::{
19 FlashAttention as RustFlash, LinearAttention as RustLinear,
20 LocalGlobalAttention as RustLocalGlobal,
21 },
22 traits::Attention,
23};
24
/// Attention settings declared as a plain-object binding (`#[napi(object)]`)
/// rather than a class.
///
/// No constructor visible in this part of the file accepts this struct, so the
/// field notes below are inferred from the field names only.
#[napi(object)]
pub struct AttentionConfig {
    /// Vector dimension. NOTE(review): presumably the embedding size; confirm.
    pub dim: u32,
    /// Optional head count. NOTE(review): presumably for multi-head attention; confirm.
    pub num_heads: Option<u32>,
    /// Optional dropout value. NOTE(review): expected range not visible here; confirm.
    pub dropout: Option<f64>,
    /// Optional scale value. NOTE(review): presumably the score scaling factor; confirm.
    pub scale: Option<f64>,
    /// Optional causal flag. NOTE(review): presumably enables causal masking; confirm.
    pub causal: Option<bool>,
}
34
35#[napi]
37pub struct DotProductAttention {
38 inner: ScaledDotProductAttention,
39}
40
41#[napi]
42impl DotProductAttention {
43 #[napi(constructor)]
48 pub fn new(dim: u32) -> Result<Self> {
49 Ok(Self {
50 inner: ScaledDotProductAttention::new(dim as usize),
51 })
52 }
53
54 #[napi]
61 pub fn compute(
62 &self,
63 query: Float32Array,
64 keys: Vec<Float32Array>,
65 values: Vec<Float32Array>,
66 ) -> Result<Float32Array> {
67 let query_slice = query.as_ref();
68 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
69 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
70 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
71 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
72
73 let result = self
74 .inner
75 .compute(query_slice, &keys_refs, &values_refs)
76 .map_err(|e| Error::from_reason(e.to_string()))?;
77
78 Ok(Float32Array::new(result))
79 }
80
81 #[napi]
89 pub fn compute_with_mask(
90 &self,
91 query: Float32Array,
92 keys: Vec<Float32Array>,
93 values: Vec<Float32Array>,
94 mask: Vec<bool>,
95 ) -> Result<Float32Array> {
96 let query_slice = query.as_ref();
97 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
98 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
99 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
100 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
101
102 let result = self
103 .inner
104 .compute_with_mask(query_slice, &keys_refs, &values_refs, Some(mask.as_slice()))
105 .map_err(|e| Error::from_reason(e.to_string()))?;
106
107 Ok(Float32Array::new(result))
108 }
109
110 #[napi(getter)]
112 pub fn dim(&self) -> u32 {
113 self.inner.dim() as u32
114 }
115}
116
117#[napi]
119pub struct MultiHeadAttention {
120 inner: RustMultiHead,
121 dim_value: usize,
122 num_heads_value: usize,
123}
124
125#[napi]
126impl MultiHeadAttention {
127 #[napi(constructor)]
133 pub fn new(dim: u32, num_heads: u32) -> Result<Self> {
134 let d = dim as usize;
135 let h = num_heads as usize;
136
137 if d % h != 0 {
138 return Err(Error::from_reason(format!(
139 "Dimension {} must be divisible by number of heads {}",
140 d, h
141 )));
142 }
143
144 Ok(Self {
145 inner: RustMultiHead::new(d, h),
146 dim_value: d,
147 num_heads_value: h,
148 })
149 }
150
151 #[napi]
153 pub fn compute(
154 &self,
155 query: Float32Array,
156 keys: Vec<Float32Array>,
157 values: Vec<Float32Array>,
158 ) -> Result<Float32Array> {
159 let query_slice = query.as_ref();
160 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
161 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
162 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
163 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
164
165 let result = self
166 .inner
167 .compute(query_slice, &keys_refs, &values_refs)
168 .map_err(|e| Error::from_reason(e.to_string()))?;
169
170 Ok(Float32Array::new(result))
171 }
172
173 #[napi(getter)]
175 pub fn num_heads(&self) -> u32 {
176 self.num_heads_value as u32
177 }
178
179 #[napi(getter)]
181 pub fn dim(&self) -> u32 {
182 self.dim_value as u32
183 }
184
185 #[napi(getter)]
187 pub fn head_dim(&self) -> u32 {
188 (self.dim_value / self.num_heads_value) as u32
189 }
190}
191
192#[napi]
194pub struct HyperbolicAttention {
195 inner: RustHyperbolic,
196 curvature_value: f32,
197 dim_value: usize,
198}
199
200#[napi]
201impl HyperbolicAttention {
202 #[napi(constructor)]
208 pub fn new(dim: u32, curvature: f64) -> Self {
209 let config = HyperbolicAttentionConfig {
210 dim: dim as usize,
211 curvature: curvature as f32,
212 ..Default::default()
213 };
214 Self {
215 inner: RustHyperbolic::new(config),
216 curvature_value: curvature as f32,
217 dim_value: dim as usize,
218 }
219 }
220
221 #[napi(factory)]
229 pub fn with_config(
230 dim: u32,
231 curvature: f64,
232 adaptive_curvature: bool,
233 temperature: f64,
234 ) -> Self {
235 let config = HyperbolicAttentionConfig {
236 dim: dim as usize,
237 curvature: curvature as f32,
238 adaptive_curvature,
239 temperature: temperature as f32,
240 frechet_max_iter: 100,
241 frechet_tol: 1e-6,
242 };
243 Self {
244 inner: RustHyperbolic::new(config),
245 curvature_value: curvature as f32,
246 dim_value: dim as usize,
247 }
248 }
249
250 #[napi]
252 pub fn compute(
253 &self,
254 query: Float32Array,
255 keys: Vec<Float32Array>,
256 values: Vec<Float32Array>,
257 ) -> Result<Float32Array> {
258 let query_slice = query.as_ref();
259 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
260 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
261 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
262 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
263
264 let result = self
265 .inner
266 .compute(query_slice, &keys_refs, &values_refs)
267 .map_err(|e| Error::from_reason(e.to_string()))?;
268
269 Ok(Float32Array::new(result))
270 }
271
272 #[napi(getter)]
274 pub fn curvature(&self) -> f64 {
275 self.curvature_value as f64
276 }
277
278 #[napi(getter)]
280 pub fn dim(&self) -> u32 {
281 self.dim_value as u32
282 }
283}
284
285#[napi]
287pub struct FlashAttention {
288 inner: RustFlash,
289 dim_value: usize,
290 block_size_value: usize,
291}
292
293#[napi]
294impl FlashAttention {
295 #[napi(constructor)]
301 pub fn new(dim: u32, block_size: u32) -> Self {
302 Self {
303 inner: RustFlash::new(dim as usize, block_size as usize),
304 dim_value: dim as usize,
305 block_size_value: block_size as usize,
306 }
307 }
308
309 #[napi]
311 pub fn compute(
312 &self,
313 query: Float32Array,
314 keys: Vec<Float32Array>,
315 values: Vec<Float32Array>,
316 ) -> Result<Float32Array> {
317 let query_slice = query.as_ref();
318 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
319 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
320 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
321 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
322
323 let result = self
324 .inner
325 .compute(query_slice, &keys_refs, &values_refs)
326 .map_err(|e| Error::from_reason(e.to_string()))?;
327
328 Ok(Float32Array::new(result))
329 }
330
331 #[napi(getter)]
333 pub fn dim(&self) -> u32 {
334 self.dim_value as u32
335 }
336
337 #[napi(getter)]
339 pub fn block_size(&self) -> u32 {
340 self.block_size_value as u32
341 }
342}
343
344#[napi]
346pub struct LinearAttention {
347 inner: RustLinear,
348 dim_value: usize,
349 num_features_value: usize,
350}
351
352#[napi]
353impl LinearAttention {
354 #[napi(constructor)]
360 pub fn new(dim: u32, num_features: u32) -> Self {
361 Self {
362 inner: RustLinear::new(dim as usize, num_features as usize),
363 dim_value: dim as usize,
364 num_features_value: num_features as usize,
365 }
366 }
367
368 #[napi]
370 pub fn compute(
371 &self,
372 query: Float32Array,
373 keys: Vec<Float32Array>,
374 values: Vec<Float32Array>,
375 ) -> Result<Float32Array> {
376 let query_slice = query.as_ref();
377 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
378 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
379 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
380 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
381
382 let result = self
383 .inner
384 .compute(query_slice, &keys_refs, &values_refs)
385 .map_err(|e| Error::from_reason(e.to_string()))?;
386
387 Ok(Float32Array::new(result))
388 }
389
390 #[napi(getter)]
392 pub fn dim(&self) -> u32 {
393 self.dim_value as u32
394 }
395
396 #[napi(getter)]
398 pub fn num_features(&self) -> u32 {
399 self.num_features_value as u32
400 }
401}
402
403#[napi]
405pub struct LocalGlobalAttention {
406 inner: RustLocalGlobal,
407 dim_value: usize,
408 local_window_value: usize,
409 global_tokens_value: usize,
410}
411
412#[napi]
413impl LocalGlobalAttention {
414 #[napi(constructor)]
421 pub fn new(dim: u32, local_window: u32, global_tokens: u32) -> Self {
422 Self {
423 inner: RustLocalGlobal::new(
424 dim as usize,
425 local_window as usize,
426 global_tokens as usize,
427 ),
428 dim_value: dim as usize,
429 local_window_value: local_window as usize,
430 global_tokens_value: global_tokens as usize,
431 }
432 }
433
434 #[napi]
436 pub fn compute(
437 &self,
438 query: Float32Array,
439 keys: Vec<Float32Array>,
440 values: Vec<Float32Array>,
441 ) -> Result<Float32Array> {
442 let query_slice = query.as_ref();
443 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
444 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
445 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
446 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
447
448 let result = self
449 .inner
450 .compute(query_slice, &keys_refs, &values_refs)
451 .map_err(|e| Error::from_reason(e.to_string()))?;
452
453 Ok(Float32Array::new(result))
454 }
455
456 #[napi(getter)]
458 pub fn dim(&self) -> u32 {
459 self.dim_value as u32
460 }
461
462 #[napi(getter)]
464 pub fn local_window(&self) -> u32 {
465 self.local_window_value as u32
466 }
467
468 #[napi(getter)]
470 pub fn global_tokens(&self) -> u32 {
471 self.global_tokens_value as u32
472 }
473}
474
475#[napi(object)]
477pub struct MoEConfig {
478 pub dim: u32,
479 pub num_experts: u32,
480 pub top_k: u32,
481 pub expert_capacity: Option<f64>,
482}
483
484#[napi]
486pub struct MoEAttention {
487 inner: RustMoE,
488 config: MoEConfig,
489}
490
491#[napi]
492impl MoEAttention {
493 #[napi(constructor)]
498 pub fn new(config: MoEConfig) -> Self {
499 let rust_config = RustMoEConfig::builder()
500 .dim(config.dim as usize)
501 .num_experts(config.num_experts as usize)
502 .top_k(config.top_k as usize)
503 .expert_capacity(config.expert_capacity.unwrap_or(1.25) as f32)
504 .build();
505
506 Self {
507 inner: RustMoE::new(rust_config),
508 config,
509 }
510 }
511
512 #[napi(factory)]
519 pub fn simple(dim: u32, num_experts: u32, top_k: u32) -> Self {
520 let config = MoEConfig {
521 dim,
522 num_experts,
523 top_k,
524 expert_capacity: Some(1.25),
525 };
526 Self::new(config)
527 }
528
529 #[napi]
531 pub fn compute(
532 &self,
533 query: Float32Array,
534 keys: Vec<Float32Array>,
535 values: Vec<Float32Array>,
536 ) -> Result<Float32Array> {
537 let query_slice = query.as_ref();
538 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
539 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
540 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
541 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
542
543 let result = self
544 .inner
545 .compute(query_slice, &keys_refs, &values_refs)
546 .map_err(|e| Error::from_reason(e.to_string()))?;
547
548 Ok(Float32Array::new(result))
549 }
550
551 #[napi(getter)]
553 pub fn dim(&self) -> u32 {
554 self.config.dim
555 }
556
557 #[napi(getter)]
559 pub fn num_experts(&self) -> u32 {
560 self.config.num_experts
561 }
562
563 #[napi(getter)]
565 pub fn top_k(&self) -> u32 {
566 self.config.top_k
567 }
568}
569
570#[napi]
574pub fn project_to_poincare_ball(vector: Float32Array, curvature: f64) -> Float32Array {
575 let v = vector.to_vec();
576 let projected = ruvector_attention::hyperbolic::project_to_ball(&v, curvature as f32, 1e-5);
577 Float32Array::new(projected)
578}
579
580#[napi]
582pub fn poincare_distance(a: Float32Array, b: Float32Array, curvature: f64) -> f64 {
583 let a_slice = a.as_ref();
584 let b_slice = b.as_ref();
585 ruvector_attention::hyperbolic::poincare_distance(a_slice, b_slice, curvature as f32) as f64
586}
587
588#[napi]
590pub fn mobius_addition(a: Float32Array, b: Float32Array, curvature: f64) -> Float32Array {
591 let a_slice = a.as_ref();
592 let b_slice = b.as_ref();
593 let result = ruvector_attention::hyperbolic::mobius_add(a_slice, b_slice, curvature as f32);
594 Float32Array::new(result)
595}
596
597#[napi]
599pub fn exp_map(base: Float32Array, tangent: Float32Array, curvature: f64) -> Float32Array {
600 let base_slice = base.as_ref();
601 let tangent_slice = tangent.as_ref();
602 let result =
603 ruvector_attention::hyperbolic::exp_map(base_slice, tangent_slice, curvature as f32);
604 Float32Array::new(result)
605}
606
607#[napi]
609pub fn log_map(base: Float32Array, point: Float32Array, curvature: f64) -> Float32Array {
610 let base_slice = base.as_ref();
611 let point_slice = point.as_ref();
612 let result = ruvector_attention::hyperbolic::log_map(base_slice, point_slice, curvature as f32);
613 Float32Array::new(result)
614}