Skip to main content

ruvector_attention/
lib.rs

1//! # ruvector-attention
2//!
3//! Attention mechanisms for ruvector, including geometric, graph, and sparse attention.
4//!
5//! This crate provides efficient implementations of various attention mechanisms:
6//! - Scaled dot-product attention
7//! - Multi-head attention with parallel processing
8//! - Graph attention for GNN applications
9//! - Geometric attention in hyperbolic spaces
10//! - Sparse attention patterns
11//! - FlashAttention-3 IO-aware tiled attention
12//! - Multi-Head Latent Attention (MLA) with KV-cache compression
13//! - Selective State Space Models (Mamba)
14//! - Speculative decoding with draft/verify
15//!
16//! ## Features
17//!
18//! - **SIMD Acceleration**: Optional SIMD optimizations for performance
19//! - **Parallel Processing**: Rayon-based parallel head computation
20//! - **WASM Support**: WebAssembly compilation support
21//! - **NAPI Bindings**: Node.js bindings for JavaScript integration
22//!
23//! ## Example
24//!
25//! ```rust
26//! use ruvector_attention::{
27//!     attention::ScaledDotProductAttention,
28//!     traits::Attention,
29//! };
30//!
31//! // Create scaled dot-product attention
32//! let attention = ScaledDotProductAttention::new(512);
33//!
34//! // Prepare inputs
35//! let query = vec![1.0; 512];
36//! let keys = vec![vec![0.5; 512], vec![0.3; 512]];
37//! let values = vec![vec![1.0; 512], vec![2.0; 512]];
38//!
39//! let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
40//! let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
41//!
42//! // Compute attention
43//! let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
44//! assert_eq!(output.len(), 512);
45//! ```
46
47pub mod attention;
48pub mod config;
49pub mod error;
50pub mod graph;
51pub mod hyperbolic;
52pub mod moe;
53pub mod sdk;
54pub mod sparse;
55pub mod training;
56pub mod traits;
57pub mod utils;
58
59// Advanced attention mechanisms
60pub mod curvature;
61pub mod topology;
62pub mod transport;
63
64// Mathematical foundations
65pub mod info_bottleneck;
66pub mod info_geometry;
67pub mod pde_attention;
68pub mod unified_report;
69
70// Sheaf attention (Coherence-Gated Transformer per ADR-015)
71#[cfg(feature = "sheaf")]
72pub mod sheaf;
73
74// Re-export main types
75pub use attention::{MLACache, MLAConfig, MLALayer, MemoryComparison};
76pub use attention::{MultiHeadAttention, ScaledDotProductAttention};
77pub use config::{AttentionConfig, GraphAttentionConfig, SparseAttentionConfig};
78pub use error::{AttentionError, AttentionResult};
79pub use hyperbolic::{
80    exp_map, log_map, mobius_add, poincare_distance, project_to_ball, HyperbolicAttention,
81    HyperbolicAttentionConfig, MixedCurvatureAttention, MixedCurvatureConfig,
82};
83pub use traits::{
84    Attention, EdgeInfo, GeometricAttention, Gradients, GraphAttention, SparseAttention,
85    SparseMask, TrainableAttention,
86};
87
88// Sparse attention exports
89pub use sparse::{
90    AttentionMask, FlashAttention, LinearAttention, LocalGlobalAttention, SparseMaskBuilder,
91};
92
93// MoE exports
94pub use moe::{
95    Expert, ExpertType, HyperbolicExpert, LearnedRouter, LinearExpert, MoEAttention, MoEConfig,
96    Router, StandardExpert, TopKRouting,
97};
98
99// Graph attention exports
100pub use graph::{
101    DualSpaceAttention, DualSpaceConfig, EdgeFeaturedAttention, EdgeFeaturedConfig, GraphRoPE,
102    RoPEConfig,
103};
104
105// Training exports
106pub use training::{
107    Adam, AdamW, CurriculumScheduler, CurriculumStage, DecayType, HardNegativeMiner, InfoNCELoss,
108    LocalContrastiveLoss, Loss, MiningStrategy, NegativeMiner, Optimizer, Reduction,
109    SpectralRegularization, TemperatureAnnealing, SGD,
110};
111
112// SDK exports
113pub use sdk::{presets, AttentionBuilder, AttentionPipeline};
114
115// Transport (OT-based attention) exports
116pub use transport::{
117    CentroidCache, CentroidOTAttention, CentroidOTConfig, ProjectionCache,
118    SlicedWassersteinAttention, SlicedWassersteinConfig, WindowCache,
119};
120
121// Curvature (Mixed curvature attention) exports
122pub use curvature::{
123    ComponentQuantizer, FusedCurvatureConfig, MixedCurvatureCache, MixedCurvatureFusedAttention,
124    QuantizationConfig, QuantizedVector, TangentSpaceConfig, TangentSpaceMapper,
125};
126
127// Topology (Gated attention) exports
128pub use topology::{
129    AttentionMode, AttentionPolicy, CoherenceMetric, PolicyConfig, TopologyGatedAttention,
130    TopologyGatedConfig, WindowCoherence,
131};
132
133// Information Geometry exports
134pub use info_geometry::{FisherConfig, FisherMetric, NaturalGradient, NaturalGradientConfig};
135
136// Information Bottleneck exports
137pub use info_bottleneck::{DiagonalGaussian, IBConfig, InformationBottleneck, KLDivergence};
138
139// PDE Attention exports
140pub use pde_attention::{DiffusionAttention, DiffusionConfig, GraphLaplacian, LaplacianType};
141
142// Sheaf Attention exports (Coherence-Gated Transformer per ADR-015)
143#[cfg(feature = "sheaf")]
144pub use sheaf::{
145    process_with_early_exit, ComputeLane, EarlyExit, EarlyExitConfig, EarlyExitResult,
146    EarlyExitStatistics, ExitReason, LaneStatistics, ResidualSparseMask, RestrictionMap,
147    RestrictionMapConfig, RoutingDecision, SheafAttention, SheafAttentionConfig,
148    SparseResidualAttention, SparseResidualConfig, SparsityStatistics, TokenRouter,
149    TokenRouterConfig,
150};
151
152// Unified Report exports
153pub use unified_report::{
154    AttentionRecommendation, GeometryReport, MetricType, MetricValue, ReportBuilder, ReportConfig,
155};
156
157/// Library version
158pub const VERSION: &str = env!("CARGO_PKG_VERSION");
159
#[cfg(test)]
mod tests {
    use super::*;

    /// The compile-time version constant must never be an empty string.
    #[test]
    fn test_version() {
        assert!(!VERSION.is_empty());
    }

    /// Builds a 64-dimensional, 4-head configuration through the builder and
    /// checks the stored fields plus the per-head dimension it reports.
    #[test]
    fn test_basic_attention_workflow() {
        // Keep the builder chain in one expression; only the final result is bound.
        let built = AttentionConfig::builder().dim(64).num_heads(4).build();
        let cfg = built.unwrap();

        assert_eq!(cfg.dim, 64);
        assert_eq!(cfg.num_heads, 4);
        // 64 dimensions across 4 heads is expected to report 16 per head.
        assert_eq!(cfg.head_dim(), 16);
    }
}