vsa_optim_rs/lib.rs
//! # vsa-optim-rs
//!
//! Deterministic training optimization using Vector Symbolic Architecture (VSA),
//! ternary quantization, and closed-form gradient prediction.
//!
//! This crate enables efficient fine-tuning of large models on consumer hardware through
//! mathematically principled gradient compression and prediction with guaranteed
//! reproducibility.
//!
//! ## Key Properties
//!
//! - **Deterministic**: Identical inputs produce identical outputs
//! - **Closed-form**: Weighted least squares solved with Cramer's rule; no iterative optimization
//! - **Memory-efficient**: ~90% gradient storage reduction via VSA compression
//! - **Compute-efficient**: ~80% backward-pass reduction via gradient prediction
//!
//! ## Quick Start
//!
//! The recommended entry point is [`DeterministicPhaseTrainer`], which orchestrates
//! training through four phases: WARMUP → FULL → PREDICT → CORRECT.
//!
//! ```ignore
//! use vsa_optim_rs::{DeterministicPhaseTrainer, DeterministicPhaseConfig, DeterministicPhase};
//! use candle_core::Device;
//!
//! let shapes = vec![
//!     ("layer1.weight".into(), vec![768, 768]),
//!     ("layer2.weight".into(), vec![768, 3072]),
//! ];
//!
//! let config = DeterministicPhaseConfig::default();
//! let mut trainer = DeterministicPhaseTrainer::new(&shapes, config, &Device::Cpu)?;
//!
//! for step in 0..100 {
//!     let info = trainer.begin_step()?;
//!
//!     if trainer.should_compute_full() {
//!         // Compute `gradients` via backpropagation in the surrounding training loop
//!         trainer.record_full_gradients(&gradients)?;
//!     } else {
//!         // Use deterministically predicted gradients instead of running a backward pass
//!         let predicted = trainer.get_predicted_gradients()?;
//!     }
//!
//!     trainer.end_step(loss)?;
//! }
//! # Ok::<(), vsa_optim_rs::error::OptimError>(())
//! ```
//!
//! ## Modules
//!
//! - [`config`]: Configuration types for all components
//! - [`error`]: Error types and result aliases
//! - [`phase`]: Phase-based training orchestration (deterministic and legacy)
//! - [`prediction`]: Gradient prediction (deterministic least squares and momentum)
//! - [`ternary`]: Ternary `{-1, 0, +1}` gradient accumulation
//! - [`vsa`]: VSA gradient compression with bind/bundle/unbind operations (conceptual sketch below)
//!
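//! As a rough illustration of the bind/bundle/unbind idea behind the [`vsa`] module,
//! the sketch below binds each gradient value to a deterministic ±1 key hypervector,
//! bundles the bound vectors by summation, and unbinds an approximate value with a
//! normalized dot product. The key derivation (a SplitMix64-style mixer) and the
//! dimensions here are illustrative assumptions, not this crate's API.
//!
//! ```
//! const DIM: usize = 1024;
//!
//! // SplitMix64-style finalizer used here as a stand-in deterministic key generator.
//! fn mix(mut z: u64) -> u64 {
//!     z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
//!     z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
//!     z ^ (z >> 31)
//! }
//!
//! // Deterministic bipolar key hypervector for one parameter index.
//! fn key(seed: u64) -> Vec<f32> {
//!     (0..DIM)
//!         .map(|i| {
//!             let bit = mix(seed.wrapping_mul(0x9e37_79b9_7f4a_7c15).wrapping_add(i as u64 + 1)) >> 63;
//!             if bit == 0 { 1.0 } else { -1.0 }
//!         })
//!         .collect()
//! }
//!
//! let grads = [0.25_f32, -0.10, 0.05];
//!
//! // Bind each gradient to its key (scale by ±1) and bundle by accumulation.
//! let mut bundle = vec![0.0_f32; DIM];
//! for (i, &g) in grads.iter().enumerate() {
//!     for (b, k) in bundle.iter_mut().zip(key(i as u64)) {
//!         *b += g * k;
//!     }
//! }
//!
//! // Unbind: correlate the bundle with a key to recover an approximation of one value.
//! let recovered = bundle.iter().zip(key(0)).map(|(b, k)| b * k).sum::<f32>() / DIM as f32;
//! assert!((recovered - grads[0]).abs() < 0.1); // approximate recovery, not exact
//! ```
//!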
//! ## Deterministic Gradient Prediction
//!
//! The core algorithm fits a linear gradient model using weighted least squares:
//!
//! ```text
//! g(t) = baseline + velocity × t + residual
//! ```
//!
//! - **baseline**: Weighted mean of historical gradients
//! - **velocity**: Rate of gradient change over steps (fitted via Cramer's rule)
//! - **residual**: Exponentially averaged prediction error used for drift correction
//!
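//! A minimal scalar sketch of this closed-form fit, with toy numbers rather than the
//! crate's tensors: given gradient samples `g[i]` observed at steps `t[i]` with
//! weights `w[i]`, the 2×2 weighted normal equations are solved directly via
//! Cramer's rule to obtain `baseline` and `velocity`.
//!
//! ```
//! let t = [0.0_f64, 1.0, 2.0, 3.0];
//! let g = [0.50_f64, 0.46, 0.41, 0.37]; // gradient history for one coordinate
//! let w = [0.4_f64, 0.6, 0.8, 1.0];     // recent steps weighted more heavily
//!
//! // Accumulate the weighted sums that form the normal equations.
//! let (mut sw, mut st, mut stt, mut sg, mut stg) = (0.0, 0.0, 0.0, 0.0, 0.0);
//! for i in 0..t.len() {
//!     sw += w[i];
//!     st += w[i] * t[i];
//!     stt += w[i] * t[i] * t[i];
//!     sg += w[i] * g[i];
//!     stg += w[i] * t[i] * g[i];
//! }
//!
//! // Cramer's rule on [sw st; st stt] * [baseline; velocity] = [sg; stg].
//! let det = sw * stt - st * st;
//! let baseline = (sg * stt - st * stg) / det;
//! let velocity = (sw * stg - st * sg) / det;
//!
//! // Predict the next-step gradient (residual drift correction omitted here).
//! let predicted = baseline + velocity * 4.0;
//! assert!(velocity < 0.0);
//! assert!(predicted < g[3]);
//! ```
//!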
//! ## References
//!
//! - Kanerva, P. (2009). Hyperdimensional Computing
//! - Johnson, W. & Lindenstrauss, J. (1984). Extensions of Lipschitz mappings
//! - Ma, S. et al. (2024). The Era of 1-bit LLMs

#![deny(unsafe_code)]
#![warn(missing_docs)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]

pub mod config;
pub mod error;
pub mod phase;
pub mod prediction;
pub mod ternary;
pub mod vsa;

// Re-export main types at crate root for convenience
pub use config::{PhaseConfig, PredictionConfig, TernaryConfig, VSAConfig};
pub use error::{OptimError, Result};
pub use phase::{PhaseTrainer, TrainingPhase};
pub use prediction::GradientPredictor;
pub use ternary::{TernaryGradientAccumulator, TernaryOptimizerWrapper};
pub use vsa::VSAGradientCompressor;

// Re-export deterministic training types (recommended for production)
pub use phase::{
    DeterministicPhase, DeterministicPhaseConfig, DeterministicPhaseTrainer,
    DeterministicStepInfo, DeterministicTrainerStats,
};
pub use prediction::{DeterministicPredictionConfig, DeterministicPredictor, PredictorStatistics};

#[cfg(feature = "python")]
mod python;

#[cfg(feature = "python")]
pub use python::*;