// ruvector_attention/src/lib.rs
//! # ruvector-attention
//!
//! Attention mechanisms for ruvector, including geometric, graph, and sparse attention.
//!
//! This crate provides efficient implementations of various attention mechanisms:
//! - Scaled dot-product attention
//! - Multi-head attention with parallel processing
//! - Graph attention for GNN applications
//! - Geometric attention in hyperbolic spaces
//! - Sparse attention patterns
//!
//! ## Features
//!
//! - **SIMD Acceleration**: Optional SIMD optimizations for performance
//! - **Parallel Processing**: Rayon-based parallel head computation
//! - **WASM Support**: WebAssembly compilation support
//! - **NAPI Bindings**: Node.js bindings for JavaScript integration
//!
//! ## Example
//!
//! ```rust
//! use ruvector_attention::{
//!     attention::ScaledDotProductAttention,
//!     traits::Attention,
//! };
//!
//! // Create scaled dot-product attention
//! let attention = ScaledDotProductAttention::new(512);
//!
//! // Prepare inputs
//! let query = vec![1.0; 512];
//! let keys = vec![vec![0.5; 512], vec![0.3; 512]];
//! let values = vec![vec![1.0; 512], vec![2.0; 512]];
//!
//! let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
//! let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
//!
//! // Compute attention
//! let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
//! assert_eq!(output.len(), 512);
//! ```

43pub mod attention;
44pub mod config;
45pub mod error;
46pub mod graph;
47pub mod hyperbolic;
48pub mod moe;
49pub mod sdk;
50pub mod sparse;
51pub mod training;
52pub mod traits;
53pub mod utils;
54
55// Advanced attention mechanisms
56pub mod curvature;
57pub mod topology;
58pub mod transport;
59
60// Mathematical foundations
61pub mod info_bottleneck;
62pub mod info_geometry;
63pub mod pde_attention;
64pub mod unified_report;
65
66// Sheaf attention (Coherence-Gated Transformer per ADR-015)
67#[cfg(feature = "sheaf")]
68pub mod sheaf;
69
70// Re-export main types
71pub use attention::{MultiHeadAttention, ScaledDotProductAttention};
72pub use config::{AttentionConfig, GraphAttentionConfig, SparseAttentionConfig};
73pub use error::{AttentionError, AttentionResult};
74pub use hyperbolic::{
75    exp_map, log_map, mobius_add, poincare_distance, project_to_ball, HyperbolicAttention,
76    HyperbolicAttentionConfig, MixedCurvatureAttention, MixedCurvatureConfig,
77};
78pub use traits::{
79    Attention, EdgeInfo, GeometricAttention, Gradients, GraphAttention, SparseAttention,
80    SparseMask, TrainableAttention,
81};
82
83// Sparse attention exports
84pub use sparse::{
85    AttentionMask, FlashAttention, LinearAttention, LocalGlobalAttention, SparseMaskBuilder,
86};
87
88// MoE exports
89pub use moe::{
90    Expert, ExpertType, HyperbolicExpert, LearnedRouter, LinearExpert, MoEAttention, MoEConfig,
91    Router, StandardExpert, TopKRouting,
92};
93
94// Graph attention exports
95pub use graph::{
96    DualSpaceAttention, DualSpaceConfig, EdgeFeaturedAttention, EdgeFeaturedConfig, GraphRoPE,
97    RoPEConfig,
98};
99
100// Training exports
101pub use training::{
102    Adam, AdamW, CurriculumScheduler, CurriculumStage, DecayType, HardNegativeMiner, InfoNCELoss,
103    LocalContrastiveLoss, Loss, MiningStrategy, NegativeMiner, Optimizer, Reduction,
104    SpectralRegularization, TemperatureAnnealing, SGD,
105};
106
107// SDK exports
108pub use sdk::{presets, AttentionBuilder, AttentionPipeline};
109
110// Transport (OT-based attention) exports
111pub use transport::{
112    CentroidCache, CentroidOTAttention, CentroidOTConfig, ProjectionCache,
113    SlicedWassersteinAttention, SlicedWassersteinConfig, WindowCache,
114};
115
116// Curvature (Mixed curvature attention) exports
117pub use curvature::{
118    ComponentQuantizer, FusedCurvatureConfig, MixedCurvatureCache, MixedCurvatureFusedAttention,
119    QuantizationConfig, QuantizedVector, TangentSpaceConfig, TangentSpaceMapper,
120};
121
122// Topology (Gated attention) exports
123pub use topology::{
124    AttentionMode, AttentionPolicy, CoherenceMetric, PolicyConfig, TopologyGatedAttention,
125    TopologyGatedConfig, WindowCoherence,
126};
127
128// Information Geometry exports
129pub use info_geometry::{FisherConfig, FisherMetric, NaturalGradient, NaturalGradientConfig};
130
131// Information Bottleneck exports
132pub use info_bottleneck::{DiagonalGaussian, IBConfig, InformationBottleneck, KLDivergence};
133
134// PDE Attention exports
135pub use pde_attention::{DiffusionAttention, DiffusionConfig, GraphLaplacian, LaplacianType};
136
137// Sheaf Attention exports (Coherence-Gated Transformer per ADR-015)
138#[cfg(feature = "sheaf")]
139pub use sheaf::{
140    process_with_early_exit, ComputeLane, EarlyExit, EarlyExitConfig, EarlyExitResult,
141    EarlyExitStatistics, ExitReason, LaneStatistics, ResidualSparseMask, RestrictionMap,
142    RestrictionMapConfig, RoutingDecision, SheafAttention, SheafAttentionConfig,
143    SparseResidualAttention, SparseResidualConfig, SparsityStatistics, TokenRouter,
144    TokenRouterConfig,
145};
146
147// Unified Report exports
148pub use unified_report::{
149    AttentionRecommendation, GeometryReport, MetricType, MetricValue, ReportBuilder, ReportConfig,
150};
151
152/// Library version
153pub const VERSION: &str = env!("CARGO_PKG_VERSION");
#[cfg(test)]
mod tests {
    use super::*;

    /// `VERSION` comes from `CARGO_PKG_VERSION`, which Cargo always sets,
    /// so it must never be empty.
    #[test]
    fn test_version() {
        assert!(!VERSION.is_empty());
    }

    /// The config builder should preserve the requested dimensions and
    /// derive the per-head dimension as `dim / num_heads`.
    #[test]
    fn test_basic_attention_workflow() {
        let config = AttentionConfig::builder()
            .dim(64)
            .num_heads(4)
            .build()
            .unwrap();

        assert_eq!(config.dim, 64);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.head_dim(), 16);
    }
}