ruvector_attention/lib.rs

//! # ruvector-attention
//!
//! Attention mechanisms for ruvector, including geometric, graph, and sparse attention.
//!
//! This crate provides efficient implementations of various attention mechanisms:
//! - Scaled dot-product attention
//! - Multi-head attention with parallel processing
//! - Graph attention for GNN applications
//! - Geometric attention in hyperbolic spaces
//! - Sparse attention patterns
//!
//! ## Features
//!
//! - **SIMD Acceleration**: Optional SIMD optimizations for performance
//! - **Parallel Processing**: Rayon-based parallel head computation
//! - **WASM Support**: WebAssembly compilation support
//! - **NAPI Bindings**: Node.js bindings for JavaScript integration
//!
//! ## Example
//!
//! ```rust
//! use ruvector_attention::{
//!     attention::ScaledDotProductAttention,
//!     traits::Attention,
//! };
//!
//! // Create scaled dot-product attention
//! let attention = ScaledDotProductAttention::new(512);
//!
//! // Prepare inputs
//! let query = vec![1.0; 512];
//! let keys = vec![vec![0.5; 512], vec![0.3; 512]];
//! let values = vec![vec![1.0; 512], vec![2.0; 512]];
//!
//! let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
//! let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
//!
//! // Compute attention
//! let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
//! assert_eq!(output.len(), 512);
//! ```
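//!
//! ## Configuration
//!
//! Attention modules can also be assembled from an [`AttentionConfig`]. The
//! sketch below mirrors the builder usage in this crate's own unit tests and
//! shows only the `dim` and `num_heads` options.
//!
//! ```rust
//! use ruvector_attention::AttentionConfig;
//!
//! let config = AttentionConfig::builder()
//!     .dim(64)
//!     .num_heads(4)
//!     .build()
//!     .unwrap();
//!
//! // Each head operates on dim / num_heads = 16 dimensions.
//! assert_eq!(config.head_dim(), 16);
//! ```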

pub mod attention;
pub mod config;
pub mod error;
pub mod traits;
pub mod utils;
pub mod hyperbolic;
pub mod sparse;
pub mod moe;
pub mod graph;
pub mod training;
pub mod sdk;

// Re-export main types
pub use attention::{MultiHeadAttention, ScaledDotProductAttention};
pub use config::{AttentionConfig, GraphAttentionConfig, SparseAttentionConfig};
pub use error::{AttentionError, AttentionResult};
pub use traits::{
    Attention, EdgeInfo, GeometricAttention, Gradients, GraphAttention, SparseAttention,
    SparseMask, TrainableAttention,
};
pub use hyperbolic::{
    poincare_distance, mobius_add, exp_map, log_map, project_to_ball,
    HyperbolicAttention, HyperbolicAttentionConfig,
    MixedCurvatureAttention, MixedCurvatureConfig,
};

// Sparse attention exports
pub use sparse::{
    SparseMaskBuilder, AttentionMask,
    LocalGlobalAttention, LinearAttention, FlashAttention,
};

// MoE exports
pub use moe::{
    MoEAttention, MoEConfig,
    Expert, ExpertType, StandardExpert, HyperbolicExpert, LinearExpert,
    Router, LearnedRouter, TopKRouting,
};

// Graph attention exports
pub use graph::{
    EdgeFeaturedAttention, EdgeFeaturedConfig,
    GraphRoPE, RoPEConfig,
    DualSpaceAttention, DualSpaceConfig,
};

// Training exports
pub use training::{
    Loss, InfoNCELoss, LocalContrastiveLoss, SpectralRegularization, Reduction,
    Optimizer, SGD, Adam, AdamW,
    CurriculumScheduler, CurriculumStage, TemperatureAnnealing, DecayType,
    NegativeMiner, HardNegativeMiner, MiningStrategy,
};

// SDK exports
pub use sdk::{AttentionBuilder, AttentionPipeline, presets};

/// Library version
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version() {
        assert!(!VERSION.is_empty());
    }

    #[test]
    fn test_basic_attention_workflow() {
        let config = AttentionConfig::builder()
            .dim(64)
            .num_heads(4)
            .build()
            .unwrap();

        assert_eq!(config.dim, 64);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.head_dim(), 16);
    }
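
    // A minimal sketch mirroring the crate-level doc example: it assumes the
    // `ScaledDotProductAttention::new` and `Attention::compute` signatures
    // shown there and checks only the output dimensionality.
    #[test]
    fn test_scaled_dot_product_doc_example() {
        let attention = ScaledDotProductAttention::new(512);

        let query = vec![1.0; 512];
        let keys = vec![vec![0.5; 512], vec![0.3; 512]];
        let values = vec![vec![1.0; 512], vec![2.0; 512]];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(output.len(), 512);
    }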
}