Skip to main content

ruvector_attention_node/
attention.rs

1//! NAPI-RS bindings for attention mechanisms
2//!
3//! Provides Node.js bindings for all attention variants:
4//! - Scaled dot-product attention
5//! - Multi-head attention
6//! - Hyperbolic attention
7//! - Flash attention
8//! - Linear attention
9//! - Local-global attention
10//! - Mixture of Experts attention
11
12use napi::bindgen_prelude::*;
13use napi_derive::napi;
14use ruvector_attention::{
15    attention::{MultiHeadAttention as RustMultiHead, ScaledDotProductAttention},
16    hyperbolic::{HyperbolicAttention as RustHyperbolic, HyperbolicAttentionConfig},
17    moe::{MoEAttention as RustMoE, MoEConfig as RustMoEConfig},
18    sparse::{
19        FlashAttention as RustFlash, LinearAttention as RustLinear,
20        LocalGlobalAttention as RustLocalGlobal,
21    },
22    traits::Attention,
23};
24
/// Attention configuration object
///
/// Plain-data options bag marshalled to/from JavaScript via `#[napi(object)]`.
/// Not consumed by any constructor in this module — presumably used by
/// callers elsewhere; verify against the JS API surface.
#[napi(object)]
pub struct AttentionConfig {
    /// Embedding dimension
    pub dim: u32,
    /// Number of attention heads (optional)
    pub num_heads: Option<u32>,
    /// Dropout probability (optional; semantics defined by the consumer)
    pub dropout: Option<f64>,
    /// Score scaling factor (optional)
    pub scale: Option<f64>,
    /// Whether to apply causal masking (optional)
    pub causal: Option<bool>,
}
34
/// Scaled dot-product attention
///
/// Thin Node.js wrapper around the Rust `ScaledDotProductAttention`
/// implementation from the `ruvector_attention` crate.
#[napi]
pub struct DotProductAttention {
    // Underlying Rust attention implementation; all computation delegates here.
    inner: ScaledDotProductAttention,
}
40
41#[napi]
42impl DotProductAttention {
43    /// Create a new scaled dot-product attention instance
44    ///
45    /// # Arguments
46    /// * `dim` - Embedding dimension
47    #[napi(constructor)]
48    pub fn new(dim: u32) -> Result<Self> {
49        Ok(Self {
50            inner: ScaledDotProductAttention::new(dim as usize),
51        })
52    }
53
54    /// Compute attention output
55    ///
56    /// # Arguments
57    /// * `query` - Query vector
58    /// * `keys` - Array of key vectors
59    /// * `values` - Array of value vectors
60    #[napi]
61    pub fn compute(
62        &self,
63        query: Float32Array,
64        keys: Vec<Float32Array>,
65        values: Vec<Float32Array>,
66    ) -> Result<Float32Array> {
67        let query_slice = query.as_ref();
68        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
69        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
70        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
71        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
72
73        let result = self
74            .inner
75            .compute(query_slice, &keys_refs, &values_refs)
76            .map_err(|e| Error::from_reason(e.to_string()))?;
77
78        Ok(Float32Array::new(result))
79    }
80
81    /// Compute attention with mask
82    ///
83    /// # Arguments
84    /// * `query` - Query vector
85    /// * `keys` - Array of key vectors
86    /// * `values` - Array of value vectors
87    /// * `mask` - Boolean mask array (true = attend, false = mask)
88    #[napi]
89    pub fn compute_with_mask(
90        &self,
91        query: Float32Array,
92        keys: Vec<Float32Array>,
93        values: Vec<Float32Array>,
94        mask: Vec<bool>,
95    ) -> Result<Float32Array> {
96        let query_slice = query.as_ref();
97        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
98        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
99        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
100        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
101
102        let result = self
103            .inner
104            .compute_with_mask(query_slice, &keys_refs, &values_refs, Some(mask.as_slice()))
105            .map_err(|e| Error::from_reason(e.to_string()))?;
106
107        Ok(Float32Array::new(result))
108    }
109
110    /// Get the dimension
111    #[napi(getter)]
112    pub fn dim(&self) -> u32 {
113        self.inner.dim() as u32
114    }
115}
116
/// Multi-head attention mechanism
#[napi]
pub struct MultiHeadAttention {
    // Underlying Rust multi-head attention implementation.
    inner: RustMultiHead,
    // Cached embedding dimension, mirrored for the `dim` getter.
    dim_value: usize,
    // Cached head count, mirrored for the `num_heads` getter.
    num_heads_value: usize,
}
124
125#[napi]
126impl MultiHeadAttention {
127    /// Create a new multi-head attention instance
128    ///
129    /// # Arguments
130    /// * `dim` - Embedding dimension (must be divisible by num_heads)
131    /// * `num_heads` - Number of attention heads
132    #[napi(constructor)]
133    pub fn new(dim: u32, num_heads: u32) -> Result<Self> {
134        let d = dim as usize;
135        let h = num_heads as usize;
136
137        if d % h != 0 {
138            return Err(Error::from_reason(format!(
139                "Dimension {} must be divisible by number of heads {}",
140                d, h
141            )));
142        }
143
144        Ok(Self {
145            inner: RustMultiHead::new(d, h),
146            dim_value: d,
147            num_heads_value: h,
148        })
149    }
150
151    /// Compute multi-head attention
152    #[napi]
153    pub fn compute(
154        &self,
155        query: Float32Array,
156        keys: Vec<Float32Array>,
157        values: Vec<Float32Array>,
158    ) -> Result<Float32Array> {
159        let query_slice = query.as_ref();
160        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
161        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
162        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
163        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
164
165        let result = self
166            .inner
167            .compute(query_slice, &keys_refs, &values_refs)
168            .map_err(|e| Error::from_reason(e.to_string()))?;
169
170        Ok(Float32Array::new(result))
171    }
172
173    /// Get the number of heads
174    #[napi(getter)]
175    pub fn num_heads(&self) -> u32 {
176        self.num_heads_value as u32
177    }
178
179    /// Get the dimension
180    #[napi(getter)]
181    pub fn dim(&self) -> u32 {
182        self.dim_value as u32
183    }
184
185    /// Get the head dimension
186    #[napi(getter)]
187    pub fn head_dim(&self) -> u32 {
188        (self.dim_value / self.num_heads_value) as u32
189    }
190}
191
/// Hyperbolic attention in Poincaré ball model
#[napi]
pub struct HyperbolicAttention {
    // Underlying Rust hyperbolic attention implementation.
    inner: RustHyperbolic,
    // Cached curvature, mirrored for the `curvature` getter.
    curvature_value: f32,
    // Cached embedding dimension, mirrored for the `dim` getter.
    dim_value: usize,
}
199
200#[napi]
201impl HyperbolicAttention {
202    /// Create a new hyperbolic attention instance
203    ///
204    /// # Arguments
205    /// * `dim` - Embedding dimension
206    /// * `curvature` - Hyperbolic curvature (typically 1.0)
207    #[napi(constructor)]
208    pub fn new(dim: u32, curvature: f64) -> Self {
209        let config = HyperbolicAttentionConfig {
210            dim: dim as usize,
211            curvature: curvature as f32,
212            ..Default::default()
213        };
214        Self {
215            inner: RustHyperbolic::new(config),
216            curvature_value: curvature as f32,
217            dim_value: dim as usize,
218        }
219    }
220
221    /// Create with full configuration
222    ///
223    /// # Arguments
224    /// * `dim` - Embedding dimension
225    /// * `curvature` - Hyperbolic curvature
226    /// * `adaptive_curvature` - Whether to use adaptive curvature
227    /// * `temperature` - Temperature for softmax
228    #[napi(factory)]
229    pub fn with_config(
230        dim: u32,
231        curvature: f64,
232        adaptive_curvature: bool,
233        temperature: f64,
234    ) -> Self {
235        let config = HyperbolicAttentionConfig {
236            dim: dim as usize,
237            curvature: curvature as f32,
238            adaptive_curvature,
239            temperature: temperature as f32,
240            frechet_max_iter: 100,
241            frechet_tol: 1e-6,
242        };
243        Self {
244            inner: RustHyperbolic::new(config),
245            curvature_value: curvature as f32,
246            dim_value: dim as usize,
247        }
248    }
249
250    /// Compute hyperbolic attention
251    #[napi]
252    pub fn compute(
253        &self,
254        query: Float32Array,
255        keys: Vec<Float32Array>,
256        values: Vec<Float32Array>,
257    ) -> Result<Float32Array> {
258        let query_slice = query.as_ref();
259        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
260        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
261        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
262        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
263
264        let result = self
265            .inner
266            .compute(query_slice, &keys_refs, &values_refs)
267            .map_err(|e| Error::from_reason(e.to_string()))?;
268
269        Ok(Float32Array::new(result))
270    }
271
272    /// Get the curvature
273    #[napi(getter)]
274    pub fn curvature(&self) -> f64 {
275        self.curvature_value as f64
276    }
277
278    /// Get the dimension
279    #[napi(getter)]
280    pub fn dim(&self) -> u32 {
281        self.dim_value as u32
282    }
283}
284
/// Flash attention with tiled computation
#[napi]
pub struct FlashAttention {
    // Underlying Rust flash attention implementation.
    inner: RustFlash,
    // Cached embedding dimension, mirrored for the `dim` getter.
    dim_value: usize,
    // Cached tile size, mirrored for the `block_size` getter.
    block_size_value: usize,
}
292
293#[napi]
294impl FlashAttention {
295    /// Create a new flash attention instance
296    ///
297    /// # Arguments
298    /// * `dim` - Embedding dimension
299    /// * `block_size` - Block size for tiled computation
300    #[napi(constructor)]
301    pub fn new(dim: u32, block_size: u32) -> Self {
302        Self {
303            inner: RustFlash::new(dim as usize, block_size as usize),
304            dim_value: dim as usize,
305            block_size_value: block_size as usize,
306        }
307    }
308
309    /// Compute flash attention
310    #[napi]
311    pub fn compute(
312        &self,
313        query: Float32Array,
314        keys: Vec<Float32Array>,
315        values: Vec<Float32Array>,
316    ) -> Result<Float32Array> {
317        let query_slice = query.as_ref();
318        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
319        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
320        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
321        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
322
323        let result = self
324            .inner
325            .compute(query_slice, &keys_refs, &values_refs)
326            .map_err(|e| Error::from_reason(e.to_string()))?;
327
328        Ok(Float32Array::new(result))
329    }
330
331    /// Get the dimension
332    #[napi(getter)]
333    pub fn dim(&self) -> u32 {
334        self.dim_value as u32
335    }
336
337    /// Get the block size
338    #[napi(getter)]
339    pub fn block_size(&self) -> u32 {
340        self.block_size_value as u32
341    }
342}
343
/// Linear attention (Performer-style) with O(n) complexity
#[napi]
pub struct LinearAttention {
    // Underlying Rust linear attention implementation.
    inner: RustLinear,
    // Cached embedding dimension, mirrored for the `dim` getter.
    dim_value: usize,
    // Cached random-feature count, mirrored for the `num_features` getter.
    num_features_value: usize,
}
351
352#[napi]
353impl LinearAttention {
354    /// Create a new linear attention instance
355    ///
356    /// # Arguments
357    /// * `dim` - Embedding dimension
358    /// * `num_features` - Number of random features
359    #[napi(constructor)]
360    pub fn new(dim: u32, num_features: u32) -> Self {
361        Self {
362            inner: RustLinear::new(dim as usize, num_features as usize),
363            dim_value: dim as usize,
364            num_features_value: num_features as usize,
365        }
366    }
367
368    /// Compute linear attention
369    #[napi]
370    pub fn compute(
371        &self,
372        query: Float32Array,
373        keys: Vec<Float32Array>,
374        values: Vec<Float32Array>,
375    ) -> Result<Float32Array> {
376        let query_slice = query.as_ref();
377        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
378        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
379        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
380        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
381
382        let result = self
383            .inner
384            .compute(query_slice, &keys_refs, &values_refs)
385            .map_err(|e| Error::from_reason(e.to_string()))?;
386
387        Ok(Float32Array::new(result))
388    }
389
390    /// Get the dimension
391    #[napi(getter)]
392    pub fn dim(&self) -> u32 {
393        self.dim_value as u32
394    }
395
396    /// Get the number of random features
397    #[napi(getter)]
398    pub fn num_features(&self) -> u32 {
399        self.num_features_value as u32
400    }
401}
402
/// Local-global attention (Longformer-style)
#[napi]
pub struct LocalGlobalAttention {
    // Underlying Rust local-global attention implementation.
    inner: RustLocalGlobal,
    // Cached embedding dimension, mirrored for the `dim` getter.
    dim_value: usize,
    // Cached local window size, mirrored for the `local_window` getter.
    local_window_value: usize,
    // Cached global token count, mirrored for the `global_tokens` getter.
    global_tokens_value: usize,
}
411
412#[napi]
413impl LocalGlobalAttention {
414    /// Create a new local-global attention instance
415    ///
416    /// # Arguments
417    /// * `dim` - Embedding dimension
418    /// * `local_window` - Size of local attention window
419    /// * `global_tokens` - Number of global attention tokens
420    #[napi(constructor)]
421    pub fn new(dim: u32, local_window: u32, global_tokens: u32) -> Self {
422        Self {
423            inner: RustLocalGlobal::new(
424                dim as usize,
425                local_window as usize,
426                global_tokens as usize,
427            ),
428            dim_value: dim as usize,
429            local_window_value: local_window as usize,
430            global_tokens_value: global_tokens as usize,
431        }
432    }
433
434    /// Compute local-global attention
435    #[napi]
436    pub fn compute(
437        &self,
438        query: Float32Array,
439        keys: Vec<Float32Array>,
440        values: Vec<Float32Array>,
441    ) -> Result<Float32Array> {
442        let query_slice = query.as_ref();
443        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
444        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
445        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
446        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
447
448        let result = self
449            .inner
450            .compute(query_slice, &keys_refs, &values_refs)
451            .map_err(|e| Error::from_reason(e.to_string()))?;
452
453        Ok(Float32Array::new(result))
454    }
455
456    /// Get the dimension
457    #[napi(getter)]
458    pub fn dim(&self) -> u32 {
459        self.dim_value as u32
460    }
461
462    /// Get the local window size
463    #[napi(getter)]
464    pub fn local_window(&self) -> u32 {
465        self.local_window_value as u32
466    }
467
468    /// Get the number of global tokens
469    #[napi(getter)]
470    pub fn global_tokens(&self) -> u32 {
471        self.global_tokens_value as u32
472    }
473}
474
/// MoE attention configuration
#[napi(object)]
pub struct MoEConfig {
    /// Embedding dimension
    pub dim: u32,
    /// Number of expert networks
    pub num_experts: u32,
    /// Number of experts each token is routed to
    pub top_k: u32,
    /// Capacity factor per expert (defaults to 1.25 when omitted — see
    /// `MoEAttention::new`)
    pub expert_capacity: Option<f64>,
}
483
/// Mixture of Experts attention
#[napi]
pub struct MoEAttention {
    // Underlying Rust MoE attention implementation.
    inner: RustMoE,
    // Original JS-facing config, retained to serve the getters below.
    config: MoEConfig,
}
490
491#[napi]
492impl MoEAttention {
493    /// Create a new MoE attention instance
494    ///
495    /// # Arguments
496    /// * `config` - MoE configuration object
497    #[napi(constructor)]
498    pub fn new(config: MoEConfig) -> Self {
499        let rust_config = RustMoEConfig::builder()
500            .dim(config.dim as usize)
501            .num_experts(config.num_experts as usize)
502            .top_k(config.top_k as usize)
503            .expert_capacity(config.expert_capacity.unwrap_or(1.25) as f32)
504            .build();
505
506        Self {
507            inner: RustMoE::new(rust_config),
508            config,
509        }
510    }
511
512    /// Create with simple parameters
513    ///
514    /// # Arguments
515    /// * `dim` - Embedding dimension
516    /// * `num_experts` - Number of expert networks
517    /// * `top_k` - Number of experts to route to
518    #[napi(factory)]
519    pub fn simple(dim: u32, num_experts: u32, top_k: u32) -> Self {
520        let config = MoEConfig {
521            dim,
522            num_experts,
523            top_k,
524            expert_capacity: Some(1.25),
525        };
526        Self::new(config)
527    }
528
529    /// Compute MoE attention
530    #[napi]
531    pub fn compute(
532        &self,
533        query: Float32Array,
534        keys: Vec<Float32Array>,
535        values: Vec<Float32Array>,
536    ) -> Result<Float32Array> {
537        let query_slice = query.as_ref();
538        let keys_vec: Vec<Vec<f32>> = keys.into_iter().map(|k| k.to_vec()).collect();
539        let values_vec: Vec<Vec<f32>> = values.into_iter().map(|v| v.to_vec()).collect();
540        let keys_refs: Vec<&[f32]> = keys_vec.iter().map(|k| k.as_slice()).collect();
541        let values_refs: Vec<&[f32]> = values_vec.iter().map(|v| v.as_slice()).collect();
542
543        let result = self
544            .inner
545            .compute(query_slice, &keys_refs, &values_refs)
546            .map_err(|e| Error::from_reason(e.to_string()))?;
547
548        Ok(Float32Array::new(result))
549    }
550
551    /// Get the dimension
552    #[napi(getter)]
553    pub fn dim(&self) -> u32 {
554        self.config.dim
555    }
556
557    /// Get the number of experts
558    #[napi(getter)]
559    pub fn num_experts(&self) -> u32 {
560        self.config.num_experts
561    }
562
563    /// Get the top-k value
564    #[napi(getter)]
565    pub fn top_k(&self) -> u32 {
566        self.config.top_k
567    }
568}
569
570// Utility functions
571
572/// Project a vector into the Poincaré ball
573#[napi]
574pub fn project_to_poincare_ball(vector: Float32Array, curvature: f64) -> Float32Array {
575    let v = vector.to_vec();
576    let projected = ruvector_attention::hyperbolic::project_to_ball(&v, curvature as f32, 1e-5);
577    Float32Array::new(projected)
578}
579
580/// Compute hyperbolic (Poincaré) distance between two points
581#[napi]
582pub fn poincare_distance(a: Float32Array, b: Float32Array, curvature: f64) -> f64 {
583    let a_slice = a.as_ref();
584    let b_slice = b.as_ref();
585    ruvector_attention::hyperbolic::poincare_distance(a_slice, b_slice, curvature as f32) as f64
586}
587
588/// Möbius addition in hyperbolic space
589#[napi]
590pub fn mobius_addition(a: Float32Array, b: Float32Array, curvature: f64) -> Float32Array {
591    let a_slice = a.as_ref();
592    let b_slice = b.as_ref();
593    let result = ruvector_attention::hyperbolic::mobius_add(a_slice, b_slice, curvature as f32);
594    Float32Array::new(result)
595}
596
597/// Exponential map from tangent space to hyperbolic space
598#[napi]
599pub fn exp_map(base: Float32Array, tangent: Float32Array, curvature: f64) -> Float32Array {
600    let base_slice = base.as_ref();
601    let tangent_slice = tangent.as_ref();
602    let result =
603        ruvector_attention::hyperbolic::exp_map(base_slice, tangent_slice, curvature as f32);
604    Float32Array::new(result)
605}
606
607/// Logarithmic map from hyperbolic space to tangent space
608#[napi]
609pub fn log_map(base: Float32Array, point: Float32Array, curvature: f64) -> Float32Array {
610    let base_slice = base.as_ref();
611    let point_slice = point.as_ref();
612    let result = ruvector_attention::hyperbolic::log_map(base_slice, point_slice, curvature as f32);
613    Float32Array::new(result)
614}