1use napi::bindgen_prelude::*;
13use napi_derive::napi;
14use ruvector_attention::{
15 attention::{ScaledDotProductAttention, MultiHeadAttention as RustMultiHead},
16 sparse::{FlashAttention as RustFlash, LinearAttention as RustLinear, LocalGlobalAttention as RustLocalGlobal},
17 hyperbolic::{HyperbolicAttention as RustHyperbolic, HyperbolicAttentionConfig},
18 moe::{MoEAttention as RustMoE, MoEConfig as RustMoEConfig},
19 traits::Attention,
20};
21
22#[napi(object)]
24pub struct AttentionConfig {
25 pub dim: u32,
26 pub num_heads: Option<u32>,
27 pub dropout: Option<f64>,
28 pub scale: Option<f64>,
29 pub causal: Option<bool>,
30}
31
32#[napi]
34pub struct DotProductAttention {
35 inner: ScaledDotProductAttention,
36}
37
38#[napi]
39impl DotProductAttention {
40 #[napi(constructor)]
45 pub fn new(dim: u32) -> Result<Self> {
46 Ok(Self {
47 inner: ScaledDotProductAttention::new(dim as usize),
48 })
49 }
50
51 #[napi]
58 pub fn compute(
59 &self,
60 query: Float32Array,
61 keys: Vec<Float32Array>,
62 values: Vec<Float32Array>,
63 ) -> Result<Float32Array> {
64 let query_slice = query.as_ref();
65 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
66 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
67 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
68 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
69
70 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
71 .map_err(|e| Error::from_reason(e.to_string()))?;
72
73 Ok(Float32Array::new(result))
74 }
75
76 #[napi]
84 pub fn compute_with_mask(
85 &self,
86 query: Float32Array,
87 keys: Vec<Float32Array>,
88 values: Vec<Float32Array>,
89 mask: Vec<bool>,
90 ) -> Result<Float32Array> {
91 let query_slice = query.as_ref();
92 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
93 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
94 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
95 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
96
97 let result = self.inner.compute_with_mask(query_slice, &keys_refs, &values_refs, Some(mask.as_slice()))
98 .map_err(|e| Error::from_reason(e.to_string()))?;
99
100 Ok(Float32Array::new(result))
101 }
102
103 #[napi(getter)]
105 pub fn dim(&self) -> u32 {
106 self.inner.dim() as u32
107 }
108}
109
110#[napi]
112pub struct MultiHeadAttention {
113 inner: RustMultiHead,
114 dim_value: usize,
115 num_heads_value: usize,
116}
117
118#[napi]
119impl MultiHeadAttention {
120 #[napi(constructor)]
126 pub fn new(dim: u32, num_heads: u32) -> Result<Self> {
127 let d = dim as usize;
128 let h = num_heads as usize;
129
130 if d % h != 0 {
131 return Err(Error::from_reason(format!(
132 "Dimension {} must be divisible by number of heads {}",
133 d, h
134 )));
135 }
136
137 Ok(Self {
138 inner: RustMultiHead::new(d, h),
139 dim_value: d,
140 num_heads_value: h,
141 })
142 }
143
144 #[napi]
146 pub fn compute(
147 &self,
148 query: Float32Array,
149 keys: Vec<Float32Array>,
150 values: Vec<Float32Array>,
151 ) -> Result<Float32Array> {
152 let query_slice = query.as_ref();
153 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
154 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
155 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
156 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
157
158 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
159 .map_err(|e| Error::from_reason(e.to_string()))?;
160
161 Ok(Float32Array::new(result))
162 }
163
164 #[napi(getter)]
166 pub fn num_heads(&self) -> u32 {
167 self.num_heads_value as u32
168 }
169
170 #[napi(getter)]
172 pub fn dim(&self) -> u32 {
173 self.dim_value as u32
174 }
175
176 #[napi(getter)]
178 pub fn head_dim(&self) -> u32 {
179 (self.dim_value / self.num_heads_value) as u32
180 }
181}
182
183#[napi]
185pub struct HyperbolicAttention {
186 inner: RustHyperbolic,
187 curvature_value: f32,
188 dim_value: usize,
189}
190
191#[napi]
192impl HyperbolicAttention {
193 #[napi(constructor)]
199 pub fn new(dim: u32, curvature: f64) -> Self {
200 let config = HyperbolicAttentionConfig {
201 dim: dim as usize,
202 curvature: curvature as f32,
203 ..Default::default()
204 };
205 Self {
206 inner: RustHyperbolic::new(config),
207 curvature_value: curvature as f32,
208 dim_value: dim as usize,
209 }
210 }
211
212 #[napi(factory)]
220 pub fn with_config(dim: u32, curvature: f64, adaptive_curvature: bool, temperature: f64) -> Self {
221 let config = HyperbolicAttentionConfig {
222 dim: dim as usize,
223 curvature: curvature as f32,
224 adaptive_curvature,
225 temperature: temperature as f32,
226 frechet_max_iter: 100,
227 frechet_tol: 1e-6,
228 };
229 Self {
230 inner: RustHyperbolic::new(config),
231 curvature_value: curvature as f32,
232 dim_value: dim as usize,
233 }
234 }
235
236 #[napi]
238 pub fn compute(
239 &self,
240 query: Float32Array,
241 keys: Vec<Float32Array>,
242 values: Vec<Float32Array>,
243 ) -> Result<Float32Array> {
244 let query_slice = query.as_ref();
245 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
246 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
247 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
248 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
249
250 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
251 .map_err(|e| Error::from_reason(e.to_string()))?;
252
253 Ok(Float32Array::new(result))
254 }
255
256 #[napi(getter)]
258 pub fn curvature(&self) -> f64 {
259 self.curvature_value as f64
260 }
261
262 #[napi(getter)]
264 pub fn dim(&self) -> u32 {
265 self.dim_value as u32
266 }
267}
268
269#[napi]
271pub struct FlashAttention {
272 inner: RustFlash,
273 dim_value: usize,
274 block_size_value: usize,
275}
276
277#[napi]
278impl FlashAttention {
279 #[napi(constructor)]
285 pub fn new(dim: u32, block_size: u32) -> Self {
286 Self {
287 inner: RustFlash::new(dim as usize, block_size as usize),
288 dim_value: dim as usize,
289 block_size_value: block_size as usize,
290 }
291 }
292
293 #[napi]
295 pub fn compute(
296 &self,
297 query: Float32Array,
298 keys: Vec<Float32Array>,
299 values: Vec<Float32Array>,
300 ) -> Result<Float32Array> {
301 let query_slice = query.as_ref();
302 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
303 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
304 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
305 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
306
307 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
308 .map_err(|e| Error::from_reason(e.to_string()))?;
309
310 Ok(Float32Array::new(result))
311 }
312
313 #[napi(getter)]
315 pub fn dim(&self) -> u32 {
316 self.dim_value as u32
317 }
318
319 #[napi(getter)]
321 pub fn block_size(&self) -> u32 {
322 self.block_size_value as u32
323 }
324}
325
326#[napi]
328pub struct LinearAttention {
329 inner: RustLinear,
330 dim_value: usize,
331 num_features_value: usize,
332}
333
334#[napi]
335impl LinearAttention {
336 #[napi(constructor)]
342 pub fn new(dim: u32, num_features: u32) -> Self {
343 Self {
344 inner: RustLinear::new(dim as usize, num_features as usize),
345 dim_value: dim as usize,
346 num_features_value: num_features as usize,
347 }
348 }
349
350 #[napi]
352 pub fn compute(
353 &self,
354 query: Float32Array,
355 keys: Vec<Float32Array>,
356 values: Vec<Float32Array>,
357 ) -> Result<Float32Array> {
358 let query_slice = query.as_ref();
359 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
360 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
361 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
362 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
363
364 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
365 .map_err(|e| Error::from_reason(e.to_string()))?;
366
367 Ok(Float32Array::new(result))
368 }
369
370 #[napi(getter)]
372 pub fn dim(&self) -> u32 {
373 self.dim_value as u32
374 }
375
376 #[napi(getter)]
378 pub fn num_features(&self) -> u32 {
379 self.num_features_value as u32
380 }
381}
382
383#[napi]
385pub struct LocalGlobalAttention {
386 inner: RustLocalGlobal,
387 dim_value: usize,
388 local_window_value: usize,
389 global_tokens_value: usize,
390}
391
392#[napi]
393impl LocalGlobalAttention {
394 #[napi(constructor)]
401 pub fn new(dim: u32, local_window: u32, global_tokens: u32) -> Self {
402 Self {
403 inner: RustLocalGlobal::new(dim as usize, local_window as usize, global_tokens as usize),
404 dim_value: dim as usize,
405 local_window_value: local_window as usize,
406 global_tokens_value: global_tokens as usize,
407 }
408 }
409
410 #[napi]
412 pub fn compute(
413 &self,
414 query: Float32Array,
415 keys: Vec<Float32Array>,
416 values: Vec<Float32Array>,
417 ) -> Result<Float32Array> {
418 let query_slice = query.as_ref();
419 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
420 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
421 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
422 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
423
424 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
425 .map_err(|e| Error::from_reason(e.to_string()))?;
426
427 Ok(Float32Array::new(result))
428 }
429
430 #[napi(getter)]
432 pub fn dim(&self) -> u32 {
433 self.dim_value as u32
434 }
435
436 #[napi(getter)]
438 pub fn local_window(&self) -> u32 {
439 self.local_window_value as u32
440 }
441
442 #[napi(getter)]
444 pub fn global_tokens(&self) -> u32 {
445 self.global_tokens_value as u32
446 }
447}
448
449#[napi(object)]
451pub struct MoEConfig {
452 pub dim: u32,
453 pub num_experts: u32,
454 pub top_k: u32,
455 pub expert_capacity: Option<f64>,
456}
457
458#[napi]
460pub struct MoEAttention {
461 inner: RustMoE,
462 config: MoEConfig,
463}
464
465#[napi]
466impl MoEAttention {
467 #[napi(constructor)]
472 pub fn new(config: MoEConfig) -> Self {
473 let rust_config = RustMoEConfig::builder()
474 .dim(config.dim as usize)
475 .num_experts(config.num_experts as usize)
476 .top_k(config.top_k as usize)
477 .expert_capacity(config.expert_capacity.unwrap_or(1.25) as f32)
478 .build();
479
480 Self {
481 inner: RustMoE::new(rust_config),
482 config,
483 }
484 }
485
486 #[napi(factory)]
493 pub fn simple(dim: u32, num_experts: u32, top_k: u32) -> Self {
494 let config = MoEConfig {
495 dim,
496 num_experts,
497 top_k,
498 expert_capacity: Some(1.25),
499 };
500 Self::new(config)
501 }
502
503 #[napi]
505 pub fn compute(
506 &self,
507 query: Float32Array,
508 keys: Vec<Float32Array>,
509 values: Vec<Float32Array>,
510 ) -> Result<Float32Array> {
511 let query_slice = query.as_ref();
512 let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
513 let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
514 let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
515 let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
516
517 let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
518 .map_err(|e| Error::from_reason(e.to_string()))?;
519
520 Ok(Float32Array::new(result))
521 }
522
523 #[napi(getter)]
525 pub fn dim(&self) -> u32 {
526 self.config.dim
527 }
528
529 #[napi(getter)]
531 pub fn num_experts(&self) -> u32 {
532 self.config.num_experts
533 }
534
535 #[napi(getter)]
537 pub fn top_k(&self) -> u32 {
538 self.config.top_k
539 }
540}
541
542#[napi]
546pub fn project_to_poincare_ball(vector: Float32Array, curvature: f64) -> Float32Array {
547 let v = vector.to_vec();
548 let projected = ruvector_attention::hyperbolic::project_to_ball(&v, curvature as f32, 1e-5);
549 Float32Array::new(projected)
550}
551
552#[napi]
554pub fn poincare_distance(a: Float32Array, b: Float32Array, curvature: f64) -> f64 {
555 let a_slice = a.as_ref();
556 let b_slice = b.as_ref();
557 ruvector_attention::hyperbolic::poincare_distance(a_slice, b_slice, curvature as f32) as f64
558}
559
560#[napi]
562pub fn mobius_addition(a: Float32Array, b: Float32Array, curvature: f64) -> Float32Array {
563 let a_slice = a.as_ref();
564 let b_slice = b.as_ref();
565 let result = ruvector_attention::hyperbolic::mobius_add(a_slice, b_slice, curvature as f32);
566 Float32Array::new(result)
567}
568
569#[napi]
571pub fn exp_map(base: Float32Array, tangent: Float32Array, curvature: f64) -> Float32Array {
572 let base_slice = base.as_ref();
573 let tangent_slice = tangent.as_ref();
574 let result = ruvector_attention::hyperbolic::exp_map(base_slice, tangent_slice, curvature as f32);
575 Float32Array::new(result)
576}
577
578#[napi]
580pub fn log_map(base: Float32Array, point: Float32Array, curvature: f64) -> Float32Array {
581 let base_slice = base.as_ref();
582 let point_slice = point.as_ref();
583 let result = ruvector_attention::hyperbolic::log_map(base_slice, point_slice, curvature as f32);
584 Float32Array::new(result)
585}