1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct VSAConfig {
24 pub dimension: usize,
26
27 pub compression_ratio: f32,
30
31 pub use_ternary: bool,
33
34 pub seed: u64,
36}
37
38impl Default for VSAConfig {
39 fn default() -> Self {
40 Self {
41 dimension: 8192,
42 compression_ratio: 0.1,
43 use_ternary: true,
44 seed: 42,
45 }
46 }
47}
48
49impl VSAConfig {
50 #[must_use]
52 pub const fn with_compression_ratio(mut self, ratio: f32) -> Self {
53 self.compression_ratio = ratio;
54 self
55 }
56
57 #[must_use]
59 pub const fn with_ternary(mut self, use_ternary: bool) -> Self {
60 self.use_ternary = use_ternary;
61 self
62 }
63
64 #[must_use]
66 pub const fn with_seed(mut self, seed: u64) -> Self {
67 self.seed = seed;
68 self
69 }
70
71 #[must_use]
73 pub const fn with_dimension(mut self, dimension: usize) -> Self {
74 self.dimension = dimension;
75 self
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct TernaryConfig {
92 pub accumulation_steps: usize,
94
95 pub ternary_threshold: f32,
97
98 pub scale_learning_rate: f32,
100
101 pub use_stochastic_rounding: bool,
103}
104
105impl Default for TernaryConfig {
106 fn default() -> Self {
107 Self {
108 accumulation_steps: 8,
109 ternary_threshold: 0.5,
110 scale_learning_rate: 0.01,
111 use_stochastic_rounding: true,
112 }
113 }
114}
115
116impl TernaryConfig {
117 #[must_use]
119 pub const fn with_accumulation_steps(mut self, steps: usize) -> Self {
120 self.accumulation_steps = steps;
121 self
122 }
123
124 #[must_use]
126 pub const fn with_stochastic_rounding(mut self, stochastic: bool) -> Self {
127 self.use_stochastic_rounding = stochastic;
128 self
129 }
130
131 #[must_use]
133 pub const fn with_threshold(mut self, threshold: f32) -> Self {
134 self.ternary_threshold = threshold;
135 self
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct PredictionConfig {
152 pub history_size: usize,
154
155 pub prediction_steps: usize,
157
158 pub momentum: f32,
160
161 pub correction_weight: f32,
163
164 pub min_correlation: f32,
166}
167
168impl Default for PredictionConfig {
169 fn default() -> Self {
170 Self {
171 history_size: 5,
172 prediction_steps: 4,
173 momentum: 0.9,
174 correction_weight: 0.5,
175 min_correlation: 0.8,
176 }
177 }
178}
179
180impl PredictionConfig {
181 #[must_use]
183 pub const fn with_history_size(mut self, size: usize) -> Self {
184 self.history_size = size;
185 self
186 }
187
188 #[must_use]
190 pub const fn with_prediction_steps(mut self, steps: usize) -> Self {
191 self.prediction_steps = steps;
192 self
193 }
194
195 #[must_use]
197 pub const fn with_momentum(mut self, momentum: f32) -> Self {
198 self.momentum = momentum;
199 self
200 }
201
202 #[must_use]
204 pub const fn with_correction_weight(mut self, weight: f32) -> Self {
205 self.correction_weight = weight;
206 self
207 }
208}
209
/// Top-level phase settings, serializable and deserializable through serde.
///
/// Bundles the phase-related scalar settings together with one
/// [`PredictionConfig`], one [`TernaryConfig`] and one [`VSAConfig`].
/// All fields are public and are stored exactly as given; this type performs
/// no validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PhaseConfig {
    /// Number of full steps.
    // NOTE(review): presumably the length of the full-computation phase — confirm.
    pub full_steps: usize,

    /// Number of predict steps.
    // NOTE(review): presumably the length of the prediction phase — confirm.
    pub predict_steps: usize,

    /// Correction interval.
    // NOTE(review): presumably "apply a correction every N steps" — confirm units with the consumer.
    pub correct_every: usize,

    /// Nested prediction settings.
    pub prediction_config: PredictionConfig,

    /// Nested ternary settings.
    pub ternary_config: TernaryConfig,

    /// Nested VSA settings.
    pub vsa_config: VSAConfig,

    /// Gradient accumulation setting.
    // NOTE(review): presumably a step count; relation to
    // `TernaryConfig::accumulation_steps` is not visible here — confirm.
    pub gradient_accumulation: usize,

    /// Maximum gradient norm.
    // NOTE(review): presumably the clipping bound for gradients — confirm.
    pub max_grad_norm: f32,

    /// Whether adaptive phases are switched on.
    pub adaptive_phases: bool,

    /// Loss threshold setting.
    // NOTE(review): what crossing this threshold triggers is decided by the consumer — confirm.
    pub loss_threshold: f32,
}
255
256impl Default for PhaseConfig {
257 fn default() -> Self {
258 Self {
259 full_steps: 10,
260 predict_steps: 40,
261 correct_every: 10,
262 prediction_config: PredictionConfig::default(),
263 ternary_config: TernaryConfig::default(),
264 vsa_config: VSAConfig::default(),
265 gradient_accumulation: 1,
266 max_grad_norm: 1.0,
267 adaptive_phases: true,
268 loss_threshold: 0.1,
269 }
270 }
271}
272
273impl PhaseConfig {
274 #[must_use]
276 pub const fn with_full_steps(mut self, steps: usize) -> Self {
277 self.full_steps = steps;
278 self
279 }
280
281 #[must_use]
283 pub const fn with_predict_steps(mut self, steps: usize) -> Self {
284 self.predict_steps = steps;
285 self
286 }
287
288 #[must_use]
290 pub const fn with_correct_every(mut self, every: usize) -> Self {
291 self.correct_every = every;
292 self
293 }
294
295 #[must_use]
297 pub const fn with_max_grad_norm(mut self, norm: f32) -> Self {
298 self.max_grad_norm = norm;
299 self
300 }
301
302 #[must_use]
304 pub const fn with_adaptive_phases(mut self, adaptive: bool) -> Self {
305 self.adaptive_phases = adaptive;
306 self
307 }
308
309 #[must_use]
311 pub fn with_prediction_config(mut self, config: PredictionConfig) -> Self {
312 self.prediction_config = config;
313 self
314 }
315
316 #[must_use]
318 pub fn with_ternary_config(mut self, config: TernaryConfig) -> Self {
319 self.ternary_config = config;
320 self
321 }
322
323 #[must_use]
325 pub fn with_vsa_config(mut self, config: VSAConfig) -> Self {
326 self.vsa_config = config;
327 self
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` settings with expected values.
    const TOLERANCE: f32 = 0.001;

    /// Returns `true` when `actual` is strictly within `TOLERANCE` of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// `VSAConfig::default()` yields the documented baseline values.
    #[test]
    fn test_vsa_config_defaults() {
        let VSAConfig {
            dimension,
            compression_ratio,
            use_ternary,
            seed,
        } = VSAConfig::default();

        assert_eq!(dimension, 8192);
        assert!(approx(compression_ratio, 0.1));
        assert!(use_ternary);
        assert_eq!(seed, 42);
    }

    /// Chained `VSAConfig` builders override the fields they name.
    #[test]
    fn test_vsa_config_builder() {
        let base = VSAConfig::default();
        let built = base
            .with_compression_ratio(0.2)
            .with_ternary(false)
            .with_seed(123);

        assert!(approx(built.compression_ratio, 0.2));
        assert!(!built.use_ternary);
        assert_eq!(built.seed, 123);
    }

    /// `TernaryConfig::default()` yields the documented baseline values.
    #[test]
    fn test_ternary_config_defaults() {
        let TernaryConfig {
            accumulation_steps,
            use_stochastic_rounding,
            ..
        } = TernaryConfig::default();

        assert_eq!(accumulation_steps, 8);
        assert!(use_stochastic_rounding);
    }

    /// `PredictionConfig::default()` yields the documented baseline values.
    #[test]
    fn test_prediction_config_defaults() {
        let PredictionConfig {
            history_size,
            prediction_steps,
            momentum,
            ..
        } = PredictionConfig::default();

        assert_eq!(history_size, 5);
        assert_eq!(prediction_steps, 4);
        assert!(approx(momentum, 0.9));
    }

    /// `PhaseConfig::default()` yields the documented baseline values.
    #[test]
    fn test_phase_config_defaults() {
        let defaults = PhaseConfig::default();

        assert_eq!(
            (
                defaults.full_steps,
                defaults.predict_steps,
                defaults.correct_every
            ),
            (10, 40, 10)
        );
        assert!(defaults.adaptive_phases);
    }

    /// Chained `PhaseConfig` builders override the fields they name.
    #[test]
    fn test_phase_config_builder() {
        let base = PhaseConfig::default();
        let built = base
            .with_full_steps(5)
            .with_predict_steps(20)
            .with_correct_every(5)
            .with_adaptive_phases(false);

        assert_eq!(
            (built.full_steps, built.predict_steps, built.correct_every),
            (5, 20, 5)
        );
        assert!(!built.adaptive_phases);
    }
}