// ruvector_attention/lib.rs — crate root
//! # ruvector-attention
//!
//! Attention mechanisms for ruvector, including geometric, graph, and sparse attention.
//!
//! This crate provides efficient implementations of various attention mechanisms:
//! - Scaled dot-product attention
//! - Multi-head attention with parallel processing
//! - Graph attention for GNN applications
//! - Geometric attention in hyperbolic spaces
//! - Sparse attention patterns
//!
//! ## Features
//!
//! - **SIMD Acceleration**: Optional SIMD optimizations for performance
//! - **Parallel Processing**: Rayon-based parallel head computation
//! - **WASM Support**: WebAssembly compilation support
//! - **NAPI Bindings**: Node.js bindings for JavaScript integration
//!
//! ## Example
//!
//! ```rust
//! use ruvector_attention::{
//!     attention::ScaledDotProductAttention,
//!     traits::Attention,
//! };
//!
//! // Create scaled dot-product attention
//! let attention = ScaledDotProductAttention::new(512);
//!
//! // Prepare inputs
//! let query = vec![1.0; 512];
//! let keys = vec![vec![0.5; 512], vec![0.3; 512]];
//! let values = vec![vec![1.0; 512], vec![2.0; 512]];
//!
//! let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
//! let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
//!
//! // Compute attention
//! let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
//! assert_eq!(output.len(), 512);
//! ```

43pub mod attention;
44pub mod config;
45pub mod error;
46pub mod graph;
47pub mod hyperbolic;
48pub mod moe;
49pub mod sdk;
50pub mod sparse;
51pub mod training;
52pub mod traits;
53pub mod utils;
54
55// Advanced attention mechanisms
56pub mod curvature;
57pub mod topology;
58pub mod transport;
59
60// Mathematical foundations
61pub mod info_bottleneck;
62pub mod info_geometry;
63pub mod pde_attention;
64pub mod unified_report;
65
66// Sheaf attention (Coherence-Gated Transformer per ADR-015)
67#[cfg(feature = "sheaf")]
68pub mod sheaf;
69
70// Re-export main types
71pub use attention::{MLACache, MLAConfig, MLALayer, MemoryComparison};
72pub use attention::{MultiHeadAttention, ScaledDotProductAttention};
73pub use config::{AttentionConfig, GraphAttentionConfig, SparseAttentionConfig};
74pub use error::{AttentionError, AttentionResult};
75pub use hyperbolic::{
76    exp_map, log_map, mobius_add, poincare_distance, project_to_ball, HyperbolicAttention,
77    HyperbolicAttentionConfig, MixedCurvatureAttention, MixedCurvatureConfig,
78};
79pub use traits::{
80    Attention, EdgeInfo, GeometricAttention, Gradients, GraphAttention, SparseAttention,
81    SparseMask, TrainableAttention,
82};
83
84// Sparse attention exports
85pub use sparse::{
86    AttentionMask, FlashAttention, LinearAttention, LocalGlobalAttention, SparseMaskBuilder,
87};
88
89// MoE exports
90pub use moe::{
91    Expert, ExpertType, HyperbolicExpert, LearnedRouter, LinearExpert, MoEAttention, MoEConfig,
92    Router, StandardExpert, TopKRouting,
93};
94
95// Graph attention exports
96pub use graph::{
97    DualSpaceAttention, DualSpaceConfig, EdgeFeaturedAttention, EdgeFeaturedConfig, GraphRoPE,
98    RoPEConfig,
99};
100
101// Training exports
102pub use training::{
103    Adam, AdamW, CurriculumScheduler, CurriculumStage, DecayType, HardNegativeMiner, InfoNCELoss,
104    LocalContrastiveLoss, Loss, MiningStrategy, NegativeMiner, Optimizer, Reduction,
105    SpectralRegularization, TemperatureAnnealing, SGD,
106};
107
108// SDK exports
109pub use sdk::{presets, AttentionBuilder, AttentionPipeline};
110
111// Transport (OT-based attention) exports
112pub use transport::{
113    CentroidCache, CentroidOTAttention, CentroidOTConfig, ProjectionCache,
114    SlicedWassersteinAttention, SlicedWassersteinConfig, WindowCache,
115};
116
117// Curvature (Mixed curvature attention) exports
118pub use curvature::{
119    ComponentQuantizer, FusedCurvatureConfig, MixedCurvatureCache, MixedCurvatureFusedAttention,
120    QuantizationConfig, QuantizedVector, TangentSpaceConfig, TangentSpaceMapper,
121};
122
123// Topology (Gated attention) exports
124pub use topology::{
125    AttentionMode, AttentionPolicy, CoherenceMetric, PolicyConfig, TopologyGatedAttention,
126    TopologyGatedConfig, WindowCoherence,
127};
128
129// Information Geometry exports
130pub use info_geometry::{FisherConfig, FisherMetric, NaturalGradient, NaturalGradientConfig};
131
132// Information Bottleneck exports
133pub use info_bottleneck::{DiagonalGaussian, IBConfig, InformationBottleneck, KLDivergence};
134
135// PDE Attention exports
136pub use pde_attention::{DiffusionAttention, DiffusionConfig, GraphLaplacian, LaplacianType};
137
138// Sheaf Attention exports (Coherence-Gated Transformer per ADR-015)
139#[cfg(feature = "sheaf")]
140pub use sheaf::{
141    process_with_early_exit, ComputeLane, EarlyExit, EarlyExitConfig, EarlyExitResult,
142    EarlyExitStatistics, ExitReason, LaneStatistics, ResidualSparseMask, RestrictionMap,
143    RestrictionMapConfig, RoutingDecision, SheafAttention, SheafAttentionConfig,
144    SparseResidualAttention, SparseResidualConfig, SparsityStatistics, TokenRouter,
145    TokenRouterConfig,
146};
147
148// Unified Report exports
149pub use unified_report::{
150    AttentionRecommendation, GeometryReport, MetricType, MetricValue, ReportBuilder, ReportConfig,
151};
152
153/// Library version
154pub const VERSION: &str = env!("CARGO_PKG_VERSION");

#[cfg(test)]
mod tests {
    use super::*;

    /// `VERSION` comes from Cargo metadata and must never be empty.
    #[test]
    fn test_version() {
        assert!(!VERSION.is_empty());
    }

    /// Smoke test for the config builder: `head_dim()` should be
    /// `dim / num_heads` (64 / 4 = 16).
    #[test]
    fn test_basic_attention_workflow() {
        let config = AttentionConfig::builder()
            .dim(64)
            .num_heads(4)
            .build()
            .unwrap();

        assert_eq!(config.dim, 64);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.head_dim(), 16);
    }
}
177}