ruvector_attention/
lib.rs

//! # ruvector-attention
//!
//! Attention mechanisms for ruvector, including geometric, graph, and sparse attention.
//!
//! This crate provides efficient implementations of various attention mechanisms:
//! - Scaled dot-product attention
//! - Multi-head attention with parallel processing
//! - Graph attention for GNN applications
//! - Geometric attention in hyperbolic spaces
//! - Sparse attention patterns
//!
//! ## Features
//!
//! - **SIMD Acceleration**: Optional SIMD optimizations for performance
//! - **Parallel Processing**: Rayon-based parallel head computation
//! - **WASM Support**: Compiles to WebAssembly
//! - **NAPI Bindings**: Node.js bindings for JavaScript integration
//!
//! ## Example
//!
//! ```rust
//! use ruvector_attention::{
//!     attention::ScaledDotProductAttention,
//!     traits::Attention,
//! };
//!
//! // Create scaled dot-product attention
//! let attention = ScaledDotProductAttention::new(512);
//!
//! // Prepare inputs
//! let query = vec![1.0; 512];
//! let keys = vec![vec![0.5; 512], vec![0.3; 512]];
//! let values = vec![vec![1.0; 512], vec![2.0; 512]];
//!
//! let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
//! let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();
//!
//! // Compute attention
//! let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();
//! assert_eq!(output.len(), 512);
//! ```
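//!
//! ## Configuration
//!
//! Attention modules are configured through [`AttentionConfig`]. The sketch below mirrors
//! the builder usage exercised in this crate's own tests:
//!
//! ```rust
//! use ruvector_attention::AttentionConfig;
//!
//! let config = AttentionConfig::builder()
//!     .dim(64)
//!     .num_heads(4)
//!     .build()
//!     .unwrap();
//!
//! // Each head operates on an equal slice of the model dimension.
//! assert_eq!(config.head_dim(), 16);
//! ```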

pub mod attention;
pub mod config;
pub mod error;
pub mod graph;
pub mod hyperbolic;
pub mod moe;
pub mod sdk;
pub mod sparse;
pub mod training;
pub mod traits;
pub mod utils;

// Advanced attention mechanisms
pub mod curvature;
pub mod topology;
pub mod transport;

// Mathematical foundations
pub mod info_bottleneck;
pub mod info_geometry;
pub mod pde_attention;
pub mod unified_report;

// Re-export main types
pub use attention::{MultiHeadAttention, ScaledDotProductAttention};
pub use config::{AttentionConfig, GraphAttentionConfig, SparseAttentionConfig};
pub use error::{AttentionError, AttentionResult};
pub use hyperbolic::{
    exp_map, log_map, mobius_add, poincare_distance, project_to_ball, HyperbolicAttention,
    HyperbolicAttentionConfig, MixedCurvatureAttention, MixedCurvatureConfig,
};
pub use traits::{
    Attention, EdgeInfo, GeometricAttention, Gradients, GraphAttention, SparseAttention,
    SparseMask, TrainableAttention,
};

// Sparse attention exports
pub use sparse::{
    AttentionMask, FlashAttention, LinearAttention, LocalGlobalAttention, SparseMaskBuilder,
};

// MoE exports
pub use moe::{
    Expert, ExpertType, HyperbolicExpert, LearnedRouter, LinearExpert, MoEAttention, MoEConfig,
    Router, StandardExpert, TopKRouting,
};

// Graph attention exports
pub use graph::{
    DualSpaceAttention, DualSpaceConfig, EdgeFeaturedAttention, EdgeFeaturedConfig, GraphRoPE,
    RoPEConfig,
};

// Training exports
pub use training::{
    Adam, AdamW, CurriculumScheduler, CurriculumStage, DecayType, HardNegativeMiner, InfoNCELoss,
    LocalContrastiveLoss, Loss, MiningStrategy, NegativeMiner, Optimizer, Reduction,
    SpectralRegularization, TemperatureAnnealing, SGD,
};

// SDK exports
pub use sdk::{presets, AttentionBuilder, AttentionPipeline};

// Transport (OT-based attention) exports
pub use transport::{
    CentroidCache, CentroidOTAttention, CentroidOTConfig, ProjectionCache,
    SlicedWassersteinAttention, SlicedWassersteinConfig, WindowCache,
};

// Curvature (Mixed curvature attention) exports
pub use curvature::{
    ComponentQuantizer, FusedCurvatureConfig, MixedCurvatureCache,
    MixedCurvatureFusedAttention, QuantizationConfig, QuantizedVector, TangentSpaceMapper,
    TangentSpaceConfig,
};

// Topology (Gated attention) exports
pub use topology::{
    AttentionMode, AttentionPolicy, CoherenceMetric, PolicyConfig, TopologyGatedAttention,
    TopologyGatedConfig, WindowCoherence,
};

// Information Geometry exports
pub use info_geometry::{FisherConfig, FisherMetric, NaturalGradient, NaturalGradientConfig};

// Information Bottleneck exports
pub use info_bottleneck::{
    DiagonalGaussian, IBConfig, InformationBottleneck, KLDivergence,
};

// PDE Attention exports
pub use pde_attention::{DiffusionAttention, DiffusionConfig, GraphLaplacian, LaplacianType};

// Unified Report exports
pub use unified_report::{
    AttentionRecommendation, GeometryReport, MetricType, MetricValue, ReportBuilder, ReportConfig,
};

/// Library version
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version() {
        assert!(!VERSION.is_empty());
    }

    #[test]
    fn test_basic_attention_workflow() {
        let config = AttentionConfig::builder()
            .dim(64)
            .num_heads(4)
            .build()
            .unwrap();

        assert_eq!(config.dim, 64);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.head_dim(), 16);
    }
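
    // A smoke test mirroring the crate-level doc example at a smaller dimension.
    // The specific query/key/value numbers are arbitrary illustrative data, not
    // reference outputs of the attention kernel.
    #[test]
    fn test_scaled_dot_product_attention_smoke() {
        let attention = ScaledDotProductAttention::new(8);

        let query = vec![1.0; 8];
        let keys = vec![vec![0.5; 8], vec![0.3; 8]];
        let values = vec![vec![1.0; 8], vec![2.0; 8]];

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let output = attention.compute(&query, &keys_refs, &values_refs).unwrap();

        // The output is a weighted combination of the value vectors, so it keeps
        // their dimensionality.
        assert_eq!(output.len(), 8);
    }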
}