// ruvector_attention_node/attention.rs

1//! NAPI-RS bindings for attention mechanisms
2//!
3//! Provides Node.js bindings for all attention variants:
4//! - Scaled dot-product attention
5//! - Multi-head attention
6//! - Hyperbolic attention
7//! - Flash attention
8//! - Linear attention
9//! - Local-global attention
10//! - Mixture of Experts attention
11
12use napi::bindgen_prelude::*;
13use napi_derive::napi;
14use ruvector_attention::{
15    attention::{ScaledDotProductAttention, MultiHeadAttention as RustMultiHead},
16    sparse::{FlashAttention as RustFlash, LinearAttention as RustLinear, LocalGlobalAttention as RustLocalGlobal},
17    hyperbolic::{HyperbolicAttention as RustHyperbolic, HyperbolicAttentionConfig},
18    moe::{MoEAttention as RustMoE, MoEConfig as RustMoEConfig},
19    traits::Attention,
20};
21
22/// Attention configuration object
/// Attention configuration object
///
/// Plain-data configuration shape exposed to JavaScript. Only `dim` is
/// required; the optional fields are left to implementations to default.
/// NOTE(review): not consumed by any constructor visible in this file —
/// presumably used by callers elsewhere; confirm before removing.
#[napi(object)]
pub struct AttentionConfig {
    // Embedding dimension (required).
    pub dim: u32,
    // Number of attention heads, for multi-head variants.
    pub num_heads: Option<u32>,
    // Dropout probability.
    pub dropout: Option<f64>,
    // Override for the attention score scaling factor.
    pub scale: Option<f64>,
    // Whether to apply a causal (autoregressive) mask.
    pub causal: Option<bool>,
}
31
32/// Scaled dot-product attention
/// Scaled dot-product attention
///
/// Thin Node.js wrapper around the Rust `ScaledDotProductAttention`
/// implementation; all computation is delegated to `inner`.
#[napi]
pub struct DotProductAttention {
    // The underlying Rust attention implementation.
    inner: ScaledDotProductAttention,
}
37
38#[napi]
39impl DotProductAttention {
40    /// Create a new scaled dot-product attention instance
41    ///
42    /// # Arguments
43    /// * `dim` - Embedding dimension
44    #[napi(constructor)]
45    pub fn new(dim: u32) -> Result<Self> {
46        Ok(Self {
47            inner: ScaledDotProductAttention::new(dim as usize),
48        })
49    }
50
51    /// Compute attention output
52    ///
53    /// # Arguments
54    /// * `query` - Query vector
55    /// * `keys` - Array of key vectors
56    /// * `values` - Array of value vectors
57    #[napi]
58    pub fn compute(
59        &self,
60        query: Float32Array,
61        keys: Vec<Float32Array>,
62        values: Vec<Float32Array>,
63    ) -> Result<Float32Array> {
64        let query_slice = query.as_ref();
65        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
66        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
67        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
68        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
69
70        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
71            .map_err(|e| Error::from_reason(e.to_string()))?;
72
73        Ok(Float32Array::new(result))
74    }
75
76    /// Compute attention with mask
77    ///
78    /// # Arguments
79    /// * `query` - Query vector
80    /// * `keys` - Array of key vectors
81    /// * `values` - Array of value vectors
82    /// * `mask` - Boolean mask array (true = attend, false = mask)
83    #[napi]
84    pub fn compute_with_mask(
85        &self,
86        query: Float32Array,
87        keys: Vec<Float32Array>,
88        values: Vec<Float32Array>,
89        mask: Vec<bool>,
90    ) -> Result<Float32Array> {
91        let query_slice = query.as_ref();
92        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
93        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
94        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
95        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
96
97        let result = self.inner.compute_with_mask(query_slice, &keys_refs, &values_refs, Some(mask.as_slice()))
98            .map_err(|e| Error::from_reason(e.to_string()))?;
99
100        Ok(Float32Array::new(result))
101    }
102
103    /// Get the dimension
104    #[napi(getter)]
105    pub fn dim(&self) -> u32 {
106        self.inner.dim() as u32
107    }
108}
109
110/// Multi-head attention mechanism
/// Multi-head attention mechanism
///
/// Wraps the Rust multi-head implementation and caches the constructor
/// parameters so the getters don't need to query the inner type.
#[napi]
pub struct MultiHeadAttention {
    // The underlying Rust attention implementation.
    inner: RustMultiHead,
    // Embedding dimension, cached from the constructor for the getters.
    dim_value: usize,
    // Number of heads, cached from the constructor for the getters.
    num_heads_value: usize,
}
117
118#[napi]
119impl MultiHeadAttention {
120    /// Create a new multi-head attention instance
121    ///
122    /// # Arguments
123    /// * `dim` - Embedding dimension (must be divisible by num_heads)
124    /// * `num_heads` - Number of attention heads
125    #[napi(constructor)]
126    pub fn new(dim: u32, num_heads: u32) -> Result<Self> {
127        let d = dim as usize;
128        let h = num_heads as usize;
129
130        if d % h != 0 {
131            return Err(Error::from_reason(format!(
132                "Dimension {} must be divisible by number of heads {}",
133                d, h
134            )));
135        }
136
137        Ok(Self {
138            inner: RustMultiHead::new(d, h),
139            dim_value: d,
140            num_heads_value: h,
141        })
142    }
143
144    /// Compute multi-head attention
145    #[napi]
146    pub fn compute(
147        &self,
148        query: Float32Array,
149        keys: Vec<Float32Array>,
150        values: Vec<Float32Array>,
151    ) -> Result<Float32Array> {
152        let query_slice = query.as_ref();
153        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
154        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
155        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
156        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
157
158        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
159            .map_err(|e| Error::from_reason(e.to_string()))?;
160
161        Ok(Float32Array::new(result))
162    }
163
164    /// Get the number of heads
165    #[napi(getter)]
166    pub fn num_heads(&self) -> u32 {
167        self.num_heads_value as u32
168    }
169
170    /// Get the dimension
171    #[napi(getter)]
172    pub fn dim(&self) -> u32 {
173        self.dim_value as u32
174    }
175
176    /// Get the head dimension
177    #[napi(getter)]
178    pub fn head_dim(&self) -> u32 {
179        (self.dim_value / self.num_heads_value) as u32
180    }
181}
182
183/// Hyperbolic attention in Poincaré ball model
/// Hyperbolic attention in Poincaré ball model
///
/// Wraps the Rust hyperbolic attention implementation; curvature and
/// dimension are cached from the constructor for the getters.
#[napi]
pub struct HyperbolicAttention {
    // The underlying Rust attention implementation.
    inner: RustHyperbolic,
    // Curvature, cached from the constructor for the getter.
    curvature_value: f32,
    // Embedding dimension, cached from the constructor for the getter.
    dim_value: usize,
}
190
191#[napi]
192impl HyperbolicAttention {
193    /// Create a new hyperbolic attention instance
194    ///
195    /// # Arguments
196    /// * `dim` - Embedding dimension
197    /// * `curvature` - Hyperbolic curvature (typically 1.0)
198    #[napi(constructor)]
199    pub fn new(dim: u32, curvature: f64) -> Self {
200        let config = HyperbolicAttentionConfig {
201            dim: dim as usize,
202            curvature: curvature as f32,
203            ..Default::default()
204        };
205        Self {
206            inner: RustHyperbolic::new(config),
207            curvature_value: curvature as f32,
208            dim_value: dim as usize,
209        }
210    }
211
212    /// Create with full configuration
213    ///
214    /// # Arguments
215    /// * `dim` - Embedding dimension
216    /// * `curvature` - Hyperbolic curvature
217    /// * `adaptive_curvature` - Whether to use adaptive curvature
218    /// * `temperature` - Temperature for softmax
219    #[napi(factory)]
220    pub fn with_config(dim: u32, curvature: f64, adaptive_curvature: bool, temperature: f64) -> Self {
221        let config = HyperbolicAttentionConfig {
222            dim: dim as usize,
223            curvature: curvature as f32,
224            adaptive_curvature,
225            temperature: temperature as f32,
226            frechet_max_iter: 100,
227            frechet_tol: 1e-6,
228        };
229        Self {
230            inner: RustHyperbolic::new(config),
231            curvature_value: curvature as f32,
232            dim_value: dim as usize,
233        }
234    }
235
236    /// Compute hyperbolic attention
237    #[napi]
238    pub fn compute(
239        &self,
240        query: Float32Array,
241        keys: Vec<Float32Array>,
242        values: Vec<Float32Array>,
243    ) -> Result<Float32Array> {
244        let query_slice = query.as_ref();
245        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
246        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
247        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
248        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
249
250        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
251            .map_err(|e| Error::from_reason(e.to_string()))?;
252
253        Ok(Float32Array::new(result))
254    }
255
256    /// Get the curvature
257    #[napi(getter)]
258    pub fn curvature(&self) -> f64 {
259        self.curvature_value as f64
260    }
261
262    /// Get the dimension
263    #[napi(getter)]
264    pub fn dim(&self) -> u32 {
265        self.dim_value as u32
266    }
267}
268
269/// Flash attention with tiled computation
/// Flash attention with tiled computation
///
/// Wraps the Rust flash attention implementation; dimension and block
/// size are cached from the constructor for the getters.
#[napi]
pub struct FlashAttention {
    // The underlying Rust attention implementation.
    inner: RustFlash,
    // Embedding dimension, cached from the constructor for the getter.
    dim_value: usize,
    // Tile/block size, cached from the constructor for the getter.
    block_size_value: usize,
}
276
277#[napi]
278impl FlashAttention {
279    /// Create a new flash attention instance
280    ///
281    /// # Arguments
282    /// * `dim` - Embedding dimension
283    /// * `block_size` - Block size for tiled computation
284    #[napi(constructor)]
285    pub fn new(dim: u32, block_size: u32) -> Self {
286        Self {
287            inner: RustFlash::new(dim as usize, block_size as usize),
288            dim_value: dim as usize,
289            block_size_value: block_size as usize,
290        }
291    }
292
293    /// Compute flash attention
294    #[napi]
295    pub fn compute(
296        &self,
297        query: Float32Array,
298        keys: Vec<Float32Array>,
299        values: Vec<Float32Array>,
300    ) -> Result<Float32Array> {
301        let query_slice = query.as_ref();
302        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
303        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
304        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
305        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
306
307        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
308            .map_err(|e| Error::from_reason(e.to_string()))?;
309
310        Ok(Float32Array::new(result))
311    }
312
313    /// Get the dimension
314    #[napi(getter)]
315    pub fn dim(&self) -> u32 {
316        self.dim_value as u32
317    }
318
319    /// Get the block size
320    #[napi(getter)]
321    pub fn block_size(&self) -> u32 {
322        self.block_size_value as u32
323    }
324}
325
326/// Linear attention (Performer-style) with O(n) complexity
/// Linear attention (Performer-style) with O(n) complexity
///
/// Wraps the Rust linear attention implementation; dimension and feature
/// count are cached from the constructor for the getters.
#[napi]
pub struct LinearAttention {
    // The underlying Rust attention implementation.
    inner: RustLinear,
    // Embedding dimension, cached from the constructor for the getter.
    dim_value: usize,
    // Number of random features, cached from the constructor for the getter.
    num_features_value: usize,
}
333
334#[napi]
335impl LinearAttention {
336    /// Create a new linear attention instance
337    ///
338    /// # Arguments
339    /// * `dim` - Embedding dimension
340    /// * `num_features` - Number of random features
341    #[napi(constructor)]
342    pub fn new(dim: u32, num_features: u32) -> Self {
343        Self {
344            inner: RustLinear::new(dim as usize, num_features as usize),
345            dim_value: dim as usize,
346            num_features_value: num_features as usize,
347        }
348    }
349
350    /// Compute linear attention
351    #[napi]
352    pub fn compute(
353        &self,
354        query: Float32Array,
355        keys: Vec<Float32Array>,
356        values: Vec<Float32Array>,
357    ) -> Result<Float32Array> {
358        let query_slice = query.as_ref();
359        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
360        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
361        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
362        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
363
364        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
365            .map_err(|e| Error::from_reason(e.to_string()))?;
366
367        Ok(Float32Array::new(result))
368    }
369
370    /// Get the dimension
371    #[napi(getter)]
372    pub fn dim(&self) -> u32 {
373        self.dim_value as u32
374    }
375
376    /// Get the number of random features
377    #[napi(getter)]
378    pub fn num_features(&self) -> u32 {
379        self.num_features_value as u32
380    }
381}
382
383/// Local-global attention (Longformer-style)
/// Local-global attention (Longformer-style)
///
/// Wraps the Rust local-global attention implementation; the constructor
/// parameters are cached for the getters.
#[napi]
pub struct LocalGlobalAttention {
    // The underlying Rust attention implementation.
    inner: RustLocalGlobal,
    // Embedding dimension, cached from the constructor for the getter.
    dim_value: usize,
    // Local attention window size, cached for the getter.
    local_window_value: usize,
    // Number of global attention tokens, cached for the getter.
    global_tokens_value: usize,
}
391
392#[napi]
393impl LocalGlobalAttention {
394    /// Create a new local-global attention instance
395    ///
396    /// # Arguments
397    /// * `dim` - Embedding dimension
398    /// * `local_window` - Size of local attention window
399    /// * `global_tokens` - Number of global attention tokens
400    #[napi(constructor)]
401    pub fn new(dim: u32, local_window: u32, global_tokens: u32) -> Self {
402        Self {
403            inner: RustLocalGlobal::new(dim as usize, local_window as usize, global_tokens as usize),
404            dim_value: dim as usize,
405            local_window_value: local_window as usize,
406            global_tokens_value: global_tokens as usize,
407        }
408    }
409
410    /// Compute local-global attention
411    #[napi]
412    pub fn compute(
413        &self,
414        query: Float32Array,
415        keys: Vec<Float32Array>,
416        values: Vec<Float32Array>,
417    ) -> Result<Float32Array> {
418        let query_slice = query.as_ref();
419        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
420        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
421        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
422        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
423
424        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
425            .map_err(|e| Error::from_reason(e.to_string()))?;
426
427        Ok(Float32Array::new(result))
428    }
429
430    /// Get the dimension
431    #[napi(getter)]
432    pub fn dim(&self) -> u32 {
433        self.dim_value as u32
434    }
435
436    /// Get the local window size
437    #[napi(getter)]
438    pub fn local_window(&self) -> u32 {
439        self.local_window_value as u32
440    }
441
442    /// Get the number of global tokens
443    #[napi(getter)]
444    pub fn global_tokens(&self) -> u32 {
445        self.global_tokens_value as u32
446    }
447}
448
449/// MoE attention configuration
/// MoE attention configuration
///
/// Plain-data configuration shape passed from JavaScript to the
/// `MoEAttention` constructor.
#[napi(object)]
pub struct MoEConfig {
    // Embedding dimension.
    pub dim: u32,
    // Number of expert networks.
    pub num_experts: u32,
    // Number of experts each token is routed to.
    pub top_k: u32,
    // Expert capacity factor; defaults to 1.25 when omitted
    // (see `MoEAttention::new`).
    pub expert_capacity: Option<f64>,
}
457
458/// Mixture of Experts attention
/// Mixture of Experts attention
///
/// Wraps the Rust MoE attention implementation and keeps the original
/// JS-side configuration so the getters can report it back.
#[napi]
pub struct MoEAttention {
    // The underlying Rust attention implementation.
    inner: RustMoE,
    // Configuration as supplied from JS, retained for the getters.
    config: MoEConfig,
}
464
465#[napi]
466impl MoEAttention {
467    /// Create a new MoE attention instance
468    ///
469    /// # Arguments
470    /// * `config` - MoE configuration object
471    #[napi(constructor)]
472    pub fn new(config: MoEConfig) -> Self {
473        let rust_config = RustMoEConfig::builder()
474            .dim(config.dim as usize)
475            .num_experts(config.num_experts as usize)
476            .top_k(config.top_k as usize)
477            .expert_capacity(config.expert_capacity.unwrap_or(1.25) as f32)
478            .build();
479
480        Self {
481            inner: RustMoE::new(rust_config),
482            config,
483        }
484    }
485
486    /// Create with simple parameters
487    ///
488    /// # Arguments
489    /// * `dim` - Embedding dimension
490    /// * `num_experts` - Number of expert networks
491    /// * `top_k` - Number of experts to route to
492    #[napi(factory)]
493    pub fn simple(dim: u32, num_experts: u32, top_k: u32) -> Self {
494        let config = MoEConfig {
495            dim,
496            num_experts,
497            top_k,
498            expert_capacity: Some(1.25),
499        };
500        Self::new(config)
501    }
502
503    /// Compute MoE attention
504    #[napi]
505    pub fn compute(
506        &self,
507        query: Float32Array,
508        keys: Vec<Float32Array>,
509        values: Vec<Float32Array>,
510    ) -> Result<Float32Array> {
511        let query_slice = query.as_ref();
512        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
513        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
514        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
515        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
516
517        let result = self.inner.compute(query_slice, &keys_refs, &values_refs)
518            .map_err(|e| Error::from_reason(e.to_string()))?;
519
520        Ok(Float32Array::new(result))
521    }
522
523    /// Get the dimension
524    #[napi(getter)]
525    pub fn dim(&self) -> u32 {
526        self.config.dim
527    }
528
529    /// Get the number of experts
530    #[napi(getter)]
531    pub fn num_experts(&self) -> u32 {
532        self.config.num_experts
533    }
534
535    /// Get the top-k value
536    #[napi(getter)]
537    pub fn top_k(&self) -> u32 {
538        self.config.top_k
539    }
540}
541
542// Utility functions
543
544/// Project a vector into the Poincaré ball
545#[napi]
546pub fn project_to_poincare_ball(vector: Float32Array, curvature: f64) -> Float32Array {
547    let v = vector.to_vec();
548    let projected = ruvector_attention::hyperbolic::project_to_ball(&v, curvature as f32, 1e-5);
549    Float32Array::new(projected)
550}
551
552/// Compute hyperbolic (Poincaré) distance between two points
553#[napi]
554pub fn poincare_distance(a: Float32Array, b: Float32Array, curvature: f64) -> f64 {
555    let a_slice = a.as_ref();
556    let b_slice = b.as_ref();
557    ruvector_attention::hyperbolic::poincare_distance(a_slice, b_slice, curvature as f32) as f64
558}
559
560/// Möbius addition in hyperbolic space
561#[napi]
562pub fn mobius_addition(a: Float32Array, b: Float32Array, curvature: f64) -> Float32Array {
563    let a_slice = a.as_ref();
564    let b_slice = b.as_ref();
565    let result = ruvector_attention::hyperbolic::mobius_add(a_slice, b_slice, curvature as f32);
566    Float32Array::new(result)
567}
568
569/// Exponential map from tangent space to hyperbolic space
570#[napi]
571pub fn exp_map(base: Float32Array, tangent: Float32Array, curvature: f64) -> Float32Array {
572    let base_slice = base.as_ref();
573    let tangent_slice = tangent.as_ref();
574    let result = ruvector_attention::hyperbolic::exp_map(base_slice, tangent_slice, curvature as f32);
575    Float32Array::new(result)
576}
577
578/// Logarithmic map from hyperbolic space to tangent space
579#[napi]
580pub fn log_map(base: Float32Array, point: Float32Array, curvature: f64) -> Float32Array {
581    let base_slice = base.as_ref();
582    let point_slice = point.as_ref();
583    let result = ruvector_attention::hyperbolic::log_map(base_slice, point_slice, curvature as f32);
584    Float32Array::new(result)
585}