use thiserror::Error;

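/// Errors that can occur during gradient compression, decompression,
/// or accumulation.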
#[derive(Debug, Error)]
pub enum TrainingError {
    #[error("invalid compression ratio {0}: must be in (0, 1]")]
    InvalidRatio(f32),

    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    #[error("seed mismatch: compression used {compress}, decompression used {decompress}")]
    SeedMismatch { compress: u64, decompress: u64 },
}

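/// Configuration for gradient compression: the default compression ratio,
/// the RNG seed shared by compression and decompression, and an optional
/// gradient-clipping max norm.
///
/// A minimal usage sketch (builder-style construction):
///
/// ```ignore
/// let config = TrainingConfig::default()
///     .with_compression_ratio(0.05)
///     .with_seed(7)
///     .with_gradient_clipping(1.0);
/// ```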
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub default_compression_ratio: f32,
    pub seed: u64,
    pub gradient_clipping: Option<f32>,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            default_compression_ratio: 0.1,
            seed: 42,
            gradient_clipping: None,
        }
    }
}

impl TrainingConfig {
    pub fn with_compression_ratio(mut self, ratio: f32) -> Self {
        self.default_compression_ratio = ratio;
        self
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    pub fn with_gradient_clipping(mut self, max_norm: f32) -> Self {
        self.gradient_clipping = Some(max_norm);
        self
    }
}

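/// A gradient vector projected down to a smaller dimension, along with the
/// metadata (original dimension, seed, ratio) needed to reconstruct it.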
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    pub data: Vec<f32>,
    pub original_dim: usize,
    pub seed: u64,
    pub ratio: f32,
}

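/// Compresses and decompresses gradients with a seeded sparse random
/// projection. Both directions regenerate the same projection matrix from
/// `config.seed`, so the matrix itself is never stored or transmitted.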
#[derive(Debug, Clone)]
pub struct GradientCompressor {
    config: TrainingConfig,
}

impl GradientCompressor {
    pub fn new(config: TrainingConfig) -> Self {
        Self { config }
    }

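    /// Compresses `gradients` with a sparse random projection, clipping first
    /// if `gradient_clipping` is configured. Uses `ratio` if given, otherwise
    /// the configured default; rejects ratios outside (0, 1]. Note that the
    /// compressed dimension is floored at 64, so very small inputs may not
    /// shrink.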
    #[allow(clippy::cast_precision_loss)]
    pub fn compress(
        &self,
        gradients: &[f32],
        ratio: Option<f32>,
    ) -> Result<CompressedGradient, TrainingError> {
        let ratio = ratio.unwrap_or(self.config.default_compression_ratio);

        if ratio <= 0.0 || ratio > 1.0 {
            return Err(TrainingError::InvalidRatio(ratio));
        }

        let original_dim = gradients.len();
        let compressed_dim = ((original_dim as f32 * ratio).ceil() as usize).max(64);

        let gradients = if let Some(max_norm) = self.config.gradient_clipping {
            clip_gradients(gradients, max_norm)
        } else {
            gradients.to_vec()
        };

        let compressed = sparse_random_projection(&gradients, compressed_dim, self.config.seed);

        Ok(CompressedGradient {
            data: compressed,
            original_dim,
            seed: self.config.seed,
            ratio,
        })
    }

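    /// Approximately reconstructs the original gradient by applying the
    /// transpose of the projection. Fails if this compressor's seed differs
    /// from the one recorded at compression time, since the two would
    /// regenerate different projection matrices.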
    pub fn decompress(&self, compressed: &CompressedGradient) -> Result<Vec<f32>, TrainingError> {
        if compressed.seed != self.config.seed {
            return Err(TrainingError::SeedMismatch {
                compress: compressed.seed,
                decompress: self.config.seed,
            });
        }

        let recovered = sparse_random_projection_transpose(
            &compressed.data,
            compressed.original_dim,
            compressed.seed,
        );

        Ok(recovered)
    }

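    /// Compresses as in [`Self::compress`], then quantizes the projected
    /// values to {-1, 0, +1} with a single f32 scale factor.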
    #[allow(clippy::cast_precision_loss)]
    pub fn compress_ternary(
        &self,
        gradients: &[f32],
        ratio: Option<f32>,
    ) -> Result<TernaryCompressedGradient, TrainingError> {
        let compressed = self.compress(gradients, ratio)?;

        let (ternary, scale) = quantize_to_ternary(&compressed.data);

        Ok(TernaryCompressedGradient {
            data: ternary,
            scale,
            original_dim: compressed.original_dim,
            compressed_dim: compressed.data.len(),
            seed: compressed.seed,
        })
    }

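    /// Dequantizes the ternary values back to f32 via the stored scale, then
    /// applies the transposed projection, as in [`Self::decompress`].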
    pub fn decompress_ternary(
        &self,
        compressed: &TernaryCompressedGradient,
    ) -> Result<Vec<f32>, TrainingError> {
        if compressed.seed != self.config.seed {
            return Err(TrainingError::SeedMismatch {
                compress: compressed.seed,
                decompress: self.config.seed,
            });
        }

        let dequantized: Vec<f32> = compressed
            .data
            .iter()
            .map(|&t| f32::from(t) * compressed.scale)
            .collect();

        let recovered = sparse_random_projection_transpose(
            &dequantized,
            compressed.original_dim,
            compressed.seed,
        );

        Ok(recovered)
    }
}

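/// A compressed gradient whose entries are quantized to {-1, 0, +1}, stored
/// alongside the f32 scale and the metadata needed for reconstruction.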
#[derive(Debug, Clone)]
pub struct TernaryCompressedGradient {
    pub data: Vec<i8>,
    pub scale: f32,
    pub original_dim: usize,
    pub compressed_dim: usize,
    pub seed: u64,
}

impl TernaryCompressedGradient {
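    /// Ratio of original size to compressed size, counting 32 bits per
    /// original f32 against 2 bits per ternary entry plus 32 bits for the
    /// scale. This reflects a packed encoding; the in-memory `Vec<i8>`
    /// actually uses 8 bits per entry.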
    #[allow(clippy::cast_precision_loss)]
    pub fn compression_ratio(&self) -> f32 {
        let original_bits = self.original_dim * 32;
        let compressed_bits = self.data.len() * 2 + 32;
        original_bits as f32 / compressed_bits as f32
    }
}

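/// Scales `gradients` down so their L2 norm is at most `max_norm`; returns an
/// unscaled copy when the norm is already within bounds.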
fn clip_gradients(gradients: &[f32], max_norm: f32) -> Vec<f32> {
    let norm: f32 = gradients.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm > max_norm {
        let scale = max_norm / norm;
        gradients.iter().map(|x| x * scale).collect()
    } else {
        gradients.to_vec()
    }
}

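/// Projects `input` down to `output_dim` dimensions using an Achlioptas-style
/// sparse random matrix: each entry is +1 with probability ~0.16, -1 with
/// probability ~0.16, and 0 otherwise, scaled by 1/sqrt(input.len()). The
/// matrix is never materialized; its entries are streamed from a ChaCha8 RNG
/// seeded with `seed`, so the same seed always regenerates the same matrix.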
#[allow(clippy::cast_precision_loss)]
fn sparse_random_projection(input: &[f32], output_dim: usize, seed: u64) -> Vec<f32> {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut output = vec![0.0f32; output_dim];

    let scale = 1.0 / (input.len() as f32).sqrt();

    for &g in input {
        for o in output.iter_mut() {
            let r: f32 = rng.gen();
            if r < 0.16 {
                *o += g * scale;
            } else if r < 0.32 {
                *o -= g * scale;
            }
        }
    }

    output
}

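/// Applies the transpose of the projection generated by
/// `sparse_random_projection` for the same `seed`. The loop order (original
/// index outer, compressed index inner) draws RNG values in the same sequence
/// as the forward pass, so both functions see identical matrix entries.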
#[allow(clippy::cast_precision_loss)]
fn sparse_random_projection_transpose(input: &[f32], output_dim: usize, seed: u64) -> Vec<f32> {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut output = vec![0.0f32; output_dim];

    let scale = 1.0 / (output_dim as f32).sqrt();

    for o in output.iter_mut() {
        for &c in input {
            let r: f32 = rng.gen();
            if r < 0.16 {
                *o += c * scale;
            } else if r < 0.32 {
                *o -= c * scale;
            }
        }
    }

    output
}

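/// Quantizes values to {-1, 0, +1} with threshold 0.5 after normalizing by
/// the mean absolute value, which is returned as the dequantization scale.
/// Falls back to a scale of 1.0 when the mean is effectively zero.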
fn quantize_to_ternary(values: &[f32]) -> (Vec<i8>, f32) {
    let abs_mean: f32 = values.iter().map(|x| x.abs()).sum::<f32>() / values.len() as f32;
    let scale = if abs_mean > 1e-10 { abs_mean } else { 1.0 };

    let ternary: Vec<i8> = values
        .iter()
        .map(|&v| {
            let normalized = v / scale;
            if normalized > 0.5 {
                1i8
            } else if normalized < -0.5 {
                -1i8
            } else {
                0i8
            }
        })
        .collect();

    (ternary, scale)
}

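/// Sums gradient vectors over multiple micro-batches and returns their mean,
/// the standard pattern for simulating a larger batch size.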
#[derive(Debug)]
pub struct GradientAccumulator {
    accumulated: Vec<f32>,
    count: usize,
}

impl GradientAccumulator {
    pub fn new(size: usize) -> Self {
        Self {
            accumulated: vec![0.0; size],
            count: 0,
        }
    }

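    /// Adds `gradients` element-wise into the running sum; errors if the
    /// length does not match the accumulator's size.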
    pub fn accumulate(&mut self, gradients: &[f32]) -> Result<(), TrainingError> {
        if gradients.len() != self.accumulated.len() {
            return Err(TrainingError::DimensionMismatch {
                expected: self.accumulated.len(),
                actual: gradients.len(),
            });
        }

        for (acc, &g) in self.accumulated.iter_mut().zip(gradients.iter()) {
            *acc += g;
        }
        self.count += 1;

        Ok(())
    }

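    /// Returns the mean of the accumulated gradients and clears the
    /// accumulator. If nothing was accumulated, returns the zero vector
    /// unchanged.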
    #[allow(clippy::cast_precision_loss)]
    pub fn get_and_reset(&mut self) -> Vec<f32> {
        if self.count == 0 {
            return self.accumulated.clone();
        }

        let scale = 1.0 / self.count as f32;
        let result: Vec<f32> = self.accumulated.iter().map(|&x| x * scale).collect();

        self.accumulated.fill(0.0);
        self.count = 0;

        result
    }

    pub fn count(&self) -> usize {
        self.count
    }
}

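/// Mixed-precision helpers: f32 <-> bf16 conversion and dynamic loss scaling.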
pub mod mixed_precision {
    use super::TrainingError;

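    /// Converts f32 to bf16 by truncating the low 16 mantissa bits. This is
    /// round-toward-zero, not round-to-nearest-even, so it loses slightly
    /// more precision than a rounding conversion would.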
    pub fn f32_to_bf16(value: f32) -> u16 {
        let bits = value.to_bits();
        (bits >> 16) as u16
    }

    pub fn bf16_to_f32(value: u16) -> f32 {
        let bits = (value as u32) << 16;
        f32::from_bits(bits)
    }

    pub fn convert_to_bf16(values: &[f32]) -> Vec<u16> {
        values.iter().map(|&v| f32_to_bf16(v)).collect()
    }

    pub fn convert_from_bf16(values: &[u16]) -> Vec<f32> {
        values.iter().map(|&v| bf16_to_f32(v)).collect()
    }

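    /// Dynamic loss scaler for mixed-precision training: multiplies the loss
    /// by a large scale so small gradients survive reduced precision, halves
    /// the scale on overflow, and doubles it after 2000 consecutive
    /// overflow-free steps.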
    #[derive(Debug, Clone)]
    pub struct LossScaler {
        scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
        steps_since_growth: usize,
    }

    impl Default for LossScaler {
        fn default() -> Self {
            Self {
                scale: 65536.0,
                growth_factor: 2.0,
                backoff_factor: 0.5,
                growth_interval: 2000,
                steps_since_growth: 0,
            }
        }
    }

    impl LossScaler {
        pub fn with_initial_scale(scale: f32) -> Self {
            Self {
                scale,
                ..Default::default()
            }
        }

        pub fn scale(&self) -> f32 {
            self.scale
        }

        pub fn scale_loss(&self, loss: f32) -> f32 {
            loss * self.scale
        }

        pub fn unscale_gradients(&self, gradients: &mut [f32]) {
            let inv_scale = 1.0 / self.scale;
            for g in gradients.iter_mut() {
                *g *= inv_scale;
            }
        }

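        /// Adjusts the scale after a step: backs off immediately on overflow,
        /// otherwise grows once `growth_interval` consecutive clean steps
        /// have elapsed.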
        pub fn update(&mut self, overflow: bool) {
            if overflow {
                self.scale *= self.backoff_factor;
                self.steps_since_growth = 0;
            } else {
                self.steps_since_growth += 1;
                if self.steps_since_growth >= self.growth_interval {
                    self.scale *= self.growth_factor;
                    self.steps_since_growth = 0;
                }
            }
        }

        pub fn check_overflow(gradients: &[f32]) -> bool {
            gradients.iter().any(|&g| g.is_nan() || g.is_infinite())
        }
    }

    pub fn has_nan_or_inf(values: &[f32]) -> bool {
        values.iter().any(|&v| v.is_nan() || v.is_infinite())
    }

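    /// Replaces NaNs with 0.0 and clamps the remaining values into
    /// [min, max]. Never returns an error, but panics if `min > max`
    /// (a precondition of `f32::clamp`).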
    pub fn safe_clip(values: &mut [f32], min: f32, max: f32) -> Result<(), TrainingError> {
        for v in values.iter_mut() {
            if v.is_nan() {
                *v = 0.0;
            } else {
                *v = v.clamp(min, max);
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_compression_roundtrip() {
        let config = TrainingConfig::default();
        let compressor = GradientCompressor::new(config);

        let gradients: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 500.0).collect();

        let compressed = compressor.compress(&gradients, Some(0.1)).unwrap();
        let recovered = compressor.decompress(&compressed).unwrap();

        assert_eq!(recovered.len(), gradients.len());

        let mse: f32 = gradients
            .iter()
            .zip(recovered.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / gradients.len() as f32;

        assert!(mse < 1.0);
    }

    #[test]
    fn test_ternary_compression() {
        let config = TrainingConfig::default();
        let compressor = GradientCompressor::new(config);

        let gradients: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 500.0).collect();

        let compressed = compressor.compress_ternary(&gradients, Some(0.1)).unwrap();

        assert!(compressed.compression_ratio() > 10.0);

        for &t in &compressed.data {
            assert!([-1, 0, 1].contains(&t));
        }
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut acc = GradientAccumulator::new(4);

        acc.accumulate(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        acc.accumulate(&[2.0, 4.0, 6.0, 8.0]).unwrap();

        let result = acc.get_and_reset();

        assert!((result[0] - 1.5).abs() < 1e-6);
        assert!((result[1] - 3.0).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
        assert!((result[3] - 6.0).abs() < 1e-6);
    }

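    // Sanity check: decompressing with a mismatched seed must fail, since
    // different seeds regenerate different projection matrices.
    #[test]
    fn test_seed_mismatch_rejected() {
        let compressor = GradientCompressor::new(TrainingConfig::default().with_seed(1));
        let other = GradientCompressor::new(TrainingConfig::default().with_seed(2));

        let gradients = vec![0.5f32; 256];
        let compressed = compressor.compress(&gradients, Some(0.25)).unwrap();

        assert!(matches!(
            other.decompress(&compressed),
            Err(TrainingError::SeedMismatch { .. })
        ));
    }
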
    #[test]
    fn test_mixed_precision_bf16() {
        use mixed_precision::{bf16_to_f32, f32_to_bf16};

        let original = 3.14159f32;
        let bf16 = f32_to_bf16(original);
        let recovered = bf16_to_f32(bf16);

        assert!((original - recovered).abs() < 0.01);
    }

    #[test]
    fn test_loss_scaler() {
        use mixed_precision::LossScaler;

        let mut scaler = LossScaler::default();
        let initial_scale = scaler.scale();

        scaler.update(true);
        assert!(scaler.scale() < initial_scale);

        for _ in 0..2000 {
            scaler.update(false);
        }
        assert!(scaler.scale() > initial_scale * 0.5);
    }
}