// vsa_optim_rs/config.rs

//! Configuration types for VSA training optimization.
//!
//! This module provides configuration structs for all optimization components:
//! - [`VSAConfig`]: VSA gradient compression settings
//! - [`TernaryConfig`]: Ternary gradient accumulation settings
//! - [`PredictionConfig`]: Gradient prediction settings
//! - [`PhaseConfig`]: Phase-based training orchestration settings

use serde::{Deserialize, Serialize};

11/// Configuration for VSA gradient compression.
12///
13/// # Example
14///
15/// ```
16/// use vsa_optim_rs::VSAConfig;
17///
18/// let config = VSAConfig::default()
19///     .with_compression_ratio(0.1)
20///     .with_ternary(true);
21/// ```
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct VSAConfig {
24    /// Hypervector dimension for compression.
25    pub dimension: usize,
26
27    /// Target compression ratio (0.0 to 1.0).
28    /// A ratio of 0.1 means 90% compression.
29    pub compression_ratio: f32,
30
31    /// Whether to use ternary quantization on compressed gradients.
32    pub use_ternary: bool,
33
34    /// Random seed for reproducible projections.
35    pub seed: u64,
36}
37
38impl Default for VSAConfig {
39    fn default() -> Self {
40        Self {
41            dimension: 8192,
42            compression_ratio: 0.1,
43            use_ternary: true,
44            seed: 42,
45        }
46    }
47}
48
49impl VSAConfig {
50    /// Set the compression ratio.
51    #[must_use]
52    pub const fn with_compression_ratio(mut self, ratio: f32) -> Self {
53        self.compression_ratio = ratio;
54        self
55    }
56
57    /// Set whether to use ternary quantization.
58    #[must_use]
59    pub const fn with_ternary(mut self, use_ternary: bool) -> Self {
60        self.use_ternary = use_ternary;
61        self
62    }
63
64    /// Set the random seed.
65    #[must_use]
66    pub const fn with_seed(mut self, seed: u64) -> Self {
67        self.seed = seed;
68        self
69    }
70
71    /// Set the hypervector dimension.
72    #[must_use]
73    pub const fn with_dimension(mut self, dimension: usize) -> Self {
74        self.dimension = dimension;
75        self
76    }
77}
78
79/// Configuration for ternary gradient accumulation.
80///
81/// # Example
82///
83/// ```
84/// use vsa_optim_rs::TernaryConfig;
85///
86/// let config = TernaryConfig::default()
87///     .with_accumulation_steps(8)
88///     .with_stochastic_rounding(true);
89/// ```
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct TernaryConfig {
92    /// Number of gradient accumulation steps before optimizer update.
93    pub accumulation_steps: usize,
94
95    /// Threshold for ternary quantization (relative to mean abs).
96    pub ternary_threshold: f32,
97
98    /// Learning rate for scale parameters.
99    pub scale_learning_rate: f32,
100
101    /// Whether to use stochastic rounding (unbiased) or deterministic.
102    pub use_stochastic_rounding: bool,
103}
104
105impl Default for TernaryConfig {
106    fn default() -> Self {
107        Self {
108            accumulation_steps: 8,
109            ternary_threshold: 0.5,
110            scale_learning_rate: 0.01,
111            use_stochastic_rounding: true,
112        }
113    }
114}
115
116impl TernaryConfig {
117    /// Set the number of accumulation steps.
118    #[must_use]
119    pub const fn with_accumulation_steps(mut self, steps: usize) -> Self {
120        self.accumulation_steps = steps;
121        self
122    }
123
124    /// Set whether to use stochastic rounding.
125    #[must_use]
126    pub const fn with_stochastic_rounding(mut self, stochastic: bool) -> Self {
127        self.use_stochastic_rounding = stochastic;
128        self
129    }
130
131    /// Set the ternary threshold.
132    #[must_use]
133    pub const fn with_threshold(mut self, threshold: f32) -> Self {
134        self.ternary_threshold = threshold;
135        self
136    }
137}
138
139/// Configuration for gradient prediction.
140///
141/// # Example
142///
143/// ```
144/// use vsa_optim_rs::PredictionConfig;
145///
146/// let config = PredictionConfig::default()
147///     .with_history_size(5)
148///     .with_prediction_steps(4);
149/// ```
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct PredictionConfig {
152    /// Number of past gradients to keep in history.
153    pub history_size: usize,
154
155    /// Number of steps to predict before computing full gradients.
156    pub prediction_steps: usize,
157
158    /// Momentum factor for gradient extrapolation.
159    pub momentum: f32,
160
161    /// Weight applied to correction terms.
162    pub correction_weight: f32,
163
164    /// Minimum correlation threshold for using prediction.
165    pub min_correlation: f32,
166}
167
168impl Default for PredictionConfig {
169    fn default() -> Self {
170        Self {
171            history_size: 5,
172            prediction_steps: 4,
173            momentum: 0.9,
174            correction_weight: 0.5,
175            min_correlation: 0.8,
176        }
177    }
178}
179
180impl PredictionConfig {
181    /// Set the history size.
182    #[must_use]
183    pub const fn with_history_size(mut self, size: usize) -> Self {
184        self.history_size = size;
185        self
186    }
187
188    /// Set the number of prediction steps.
189    #[must_use]
190    pub const fn with_prediction_steps(mut self, steps: usize) -> Self {
191        self.prediction_steps = steps;
192        self
193    }
194
195    /// Set the momentum factor.
196    #[must_use]
197    pub const fn with_momentum(mut self, momentum: f32) -> Self {
198        self.momentum = momentum;
199        self
200    }
201
202    /// Set the correction weight.
203    #[must_use]
204    pub const fn with_correction_weight(mut self, weight: f32) -> Self {
205        self.correction_weight = weight;
206        self
207    }
208}
209
210/// Configuration for phase-based training.
211///
212/// The training cycle is: FULL → PREDICT → CORRECT → repeat
213///
214/// # Example
215///
216/// ```
217/// use vsa_optim_rs::PhaseConfig;
218///
219/// let config = PhaseConfig::default()
220///     .with_full_steps(10)
221///     .with_predict_steps(40);
222/// ```
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct PhaseConfig {
225    /// Number of full gradient computation steps per cycle.
226    pub full_steps: usize,
227
228    /// Number of predicted gradient steps per cycle.
229    pub predict_steps: usize,
230
231    /// Frequency of correction steps during predict phase.
232    pub correct_every: usize,
233
234    /// Sub-configuration for gradient prediction.
235    pub prediction_config: PredictionConfig,
236
237    /// Sub-configuration for ternary optimization.
238    pub ternary_config: TernaryConfig,
239
240    /// Sub-configuration for VSA compression.
241    pub vsa_config: VSAConfig,
242
243    /// Gradient accumulation steps.
244    pub gradient_accumulation: usize,
245
246    /// Maximum gradient norm for clipping.
247    pub max_grad_norm: f32,
248
249    /// Whether to adaptively adjust phase lengths based on loss.
250    pub adaptive_phases: bool,
251
252    /// Loss increase threshold for triggering more full steps.
253    pub loss_threshold: f32,
254}
255
256impl Default for PhaseConfig {
257    fn default() -> Self {
258        Self {
259            full_steps: 10,
260            predict_steps: 40,
261            correct_every: 10,
262            prediction_config: PredictionConfig::default(),
263            ternary_config: TernaryConfig::default(),
264            vsa_config: VSAConfig::default(),
265            gradient_accumulation: 1,
266            max_grad_norm: 1.0,
267            adaptive_phases: true,
268            loss_threshold: 0.1,
269        }
270    }
271}
272
273impl PhaseConfig {
274    /// Set the number of full training steps.
275    #[must_use]
276    pub const fn with_full_steps(mut self, steps: usize) -> Self {
277        self.full_steps = steps;
278        self
279    }
280
281    /// Set the number of prediction steps.
282    #[must_use]
283    pub const fn with_predict_steps(mut self, steps: usize) -> Self {
284        self.predict_steps = steps;
285        self
286    }
287
288    /// Set the correction frequency.
289    #[must_use]
290    pub const fn with_correct_every(mut self, every: usize) -> Self {
291        self.correct_every = every;
292        self
293    }
294
295    /// Set the maximum gradient norm for clipping.
296    #[must_use]
297    pub const fn with_max_grad_norm(mut self, norm: f32) -> Self {
298        self.max_grad_norm = norm;
299        self
300    }
301
302    /// Set whether to use adaptive phase scheduling.
303    #[must_use]
304    pub const fn with_adaptive_phases(mut self, adaptive: bool) -> Self {
305        self.adaptive_phases = adaptive;
306        self
307    }
308
309    /// Set the prediction sub-configuration.
310    #[must_use]
311    pub fn with_prediction_config(mut self, config: PredictionConfig) -> Self {
312        self.prediction_config = config;
313        self
314    }
315
316    /// Set the ternary sub-configuration.
317    #[must_use]
318    pub fn with_ternary_config(mut self, config: TernaryConfig) -> Self {
319        self.ternary_config = config;
320        self
321    }
322
323    /// Set the VSA sub-configuration.
324    #[must_use]
325    pub fn with_vsa_config(mut self, config: VSAConfig) -> Self {
326        self.vsa_config = config;
327        self
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing `f32` fields that were assigned from literals.
    const EPS: f32 = 1e-6;

    /// Returns `true` when `a` and `b` differ by less than [`EPS`].
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPS
    }

    #[test]
    fn test_vsa_config_defaults() {
        let config = VSAConfig::default();
        assert_eq!(config.dimension, 8192);
        assert!(approx_eq(config.compression_ratio, 0.1));
        assert!(config.use_ternary);
        assert_eq!(config.seed, 42);
    }

    #[test]
    fn test_vsa_config_builder() {
        let config = VSAConfig::default()
            .with_compression_ratio(0.2)
            .with_ternary(false)
            .with_seed(123)
            .with_dimension(1024);

        assert!(approx_eq(config.compression_ratio, 0.2));
        assert!(!config.use_ternary);
        assert_eq!(config.seed, 123);
        assert_eq!(config.dimension, 1024);
    }

    #[test]
    fn test_ternary_config_defaults() {
        let config = TernaryConfig::default();
        assert_eq!(config.accumulation_steps, 8);
        assert!(approx_eq(config.ternary_threshold, 0.5));
        assert!(approx_eq(config.scale_learning_rate, 0.01));
        assert!(config.use_stochastic_rounding);
    }

    #[test]
    fn test_ternary_config_builder() {
        let config = TernaryConfig::default()
            .with_accumulation_steps(4)
            .with_stochastic_rounding(false)
            .with_threshold(0.25);

        assert_eq!(config.accumulation_steps, 4);
        assert!(!config.use_stochastic_rounding);
        assert!(approx_eq(config.ternary_threshold, 0.25));
    }

    #[test]
    fn test_prediction_config_defaults() {
        let config = PredictionConfig::default();
        assert_eq!(config.history_size, 5);
        assert_eq!(config.prediction_steps, 4);
        assert!(approx_eq(config.momentum, 0.9));
        assert!(approx_eq(config.correction_weight, 0.5));
        assert!(approx_eq(config.min_correlation, 0.8));
    }

    #[test]
    fn test_prediction_config_builder() {
        let config = PredictionConfig::default()
            .with_history_size(3)
            .with_prediction_steps(2)
            .with_momentum(0.5)
            .with_correction_weight(0.75);

        assert_eq!(config.history_size, 3);
        assert_eq!(config.prediction_steps, 2);
        assert!(approx_eq(config.momentum, 0.5));
        assert!(approx_eq(config.correction_weight, 0.75));
    }

    #[test]
    fn test_phase_config_defaults() {
        let config = PhaseConfig::default();
        assert_eq!(config.full_steps, 10);
        assert_eq!(config.predict_steps, 40);
        assert_eq!(config.correct_every, 10);
        assert_eq!(config.gradient_accumulation, 1);
        assert!(approx_eq(config.max_grad_norm, 1.0));
        assert!(config.adaptive_phases);
        assert!(approx_eq(config.loss_threshold, 0.1));

        // Nested configurations come from their own `Default` impls.
        assert_eq!(config.prediction_config.history_size, 5);
        assert_eq!(config.ternary_config.accumulation_steps, 8);
        assert_eq!(config.vsa_config.dimension, 8192);
    }

    #[test]
    fn test_phase_config_builder() {
        let config = PhaseConfig::default()
            .with_full_steps(5)
            .with_predict_steps(20)
            .with_correct_every(5)
            .with_max_grad_norm(0.5)
            .with_adaptive_phases(false);

        assert_eq!(config.full_steps, 5);
        assert_eq!(config.predict_steps, 20);
        assert_eq!(config.correct_every, 5);
        assert!(approx_eq(config.max_grad_norm, 0.5));
        assert!(!config.adaptive_phases);
    }

    #[test]
    fn test_phase_config_sub_config_builders() {
        let config = PhaseConfig::default()
            .with_prediction_config(PredictionConfig::default().with_history_size(9))
            .with_ternary_config(TernaryConfig::default().with_accumulation_steps(2))
            .with_vsa_config(VSAConfig::default().with_seed(7));

        assert_eq!(config.prediction_config.history_size, 9);
        assert_eq!(config.ternary_config.accumulation_steps, 2);
        assert_eq!(config.vsa_config.seed, 7);
    }
}