tritter_accel/core/training.rs

//! Training acceleration utilities.
//!
//! Provides tools for accelerating neural network training:
//! - Gradient compression for distributed training
//! - Mixed precision utilities
//! - Memory-efficient operations
//!
//! # Example
//!
//! ```rust,ignore
//! use tritter_accel::core::training::{GradientCompressor, TrainingConfig};
//!
//! let config = TrainingConfig::default();
//! let compressor = GradientCompressor::new(config);
//!
//! // Compress gradients for communication
//! let gradients = vec![0.1, -0.2, 0.3, -0.4];
//! let compressed = compressor.compress(&gradients, Some(0.1))?;
//!
//! // Decompress on the receiving end
//! let recovered = compressor.decompress(&compressed)?;
//! assert_eq!(recovered.len(), gradients.len());
//! ```

use thiserror::Error;

/// Errors from training operations.
#[derive(Debug, Error)]
pub enum TrainingError {
    /// Invalid compression ratio.
    #[error("invalid compression ratio {0}: must be in (0, 1]")]
    InvalidRatio(f32),

    /// Dimension mismatch.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// Seed mismatch during decompression.
    #[error("seed mismatch: compression used {compress}, decompression used {decompress}")]
    SeedMismatch { compress: u64, decompress: u64 },
}

/// Configuration for training acceleration.
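///
/// A minimal builder-style sketch using the setters defined below:
///
/// ```rust,ignore
/// use tritter_accel::core::training::TrainingConfig;
///
/// let config = TrainingConfig::default()
///     .with_compression_ratio(0.05)  // project to 5% of the original length
///     .with_seed(7)                  // shared seed across ranks
///     .with_gradient_clipping(1.0);  // clip to unit L2 norm before compressing
/// ```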
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Default compression ratio for gradient compression.
    pub default_compression_ratio: f32,
    /// Random seed for reproducibility.
    pub seed: u64,
    /// Enable gradient clipping.
    pub gradient_clipping: Option<f32>,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            default_compression_ratio: 0.1,
            seed: 42,
            gradient_clipping: None,
        }
    }
}

impl TrainingConfig {
    /// Set compression ratio.
    pub fn with_compression_ratio(mut self, ratio: f32) -> Self {
        self.default_compression_ratio = ratio;
        self
    }

    /// Set random seed.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Enable gradient clipping.
    pub fn with_gradient_clipping(mut self, max_norm: f32) -> Self {
        self.gradient_clipping = Some(max_norm);
        self
    }
}

/// Compressed gradient representation.
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    /// Compressed data.
    pub data: Vec<f32>,
    /// Original dimension.
    pub original_dim: usize,
    /// Random seed used for projection.
    pub seed: u64,
    /// Compression ratio achieved.
    pub ratio: f32,
}

/// Gradient compressor using random projection.
///
/// Uses Johnson-Lindenstrauss style random projection for
/// communication-efficient distributed training.
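///
/// A round-trip sketch; both endpoints must be built from configs with the
/// same `seed` so they regenerate the identical projection matrix, otherwise
/// `decompress` returns `TrainingError::SeedMismatch`:
///
/// ```rust,ignore
/// use tritter_accel::core::training::{GradientCompressor, TrainingConfig};
///
/// let config = TrainingConfig::default().with_seed(7);
/// let sender = GradientCompressor::new(config.clone());
/// let receiver = GradientCompressor::new(config);
///
/// let grads = vec![0.5_f32; 4096];
/// let packet = sender.compress(&grads, None)?;  // ~10x smaller at the default ratio
/// let approx = receiver.decompress(&packet)?;   // same length as `grads`, lossy values
/// assert_eq!(approx.len(), grads.len());
/// ```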
#[derive(Debug, Clone)]
pub struct GradientCompressor {
    config: TrainingConfig,
}

impl GradientCompressor {
    /// Create a new gradient compressor.
    pub fn new(config: TrainingConfig) -> Self {
        Self { config }
    }

    /// Compress gradients using random projection.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Original gradient vector
    /// * `ratio` - Compression ratio in `(0, 1]`; `None` uses the configured default
    ///
    /// # Returns
    ///
    /// Compressed gradient representation.
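    ///
    /// A sizing sketch (using a `compressor` built as in the type-level
    /// example): the compressed length is `ceil(len * ratio)` with a floor of
    /// 64, so very short gradients are not shrunk below that:
    ///
    /// ```rust,ignore
    /// // 1_000 elements at ratio 0.1 -> 100 compressed values
    /// let c = compressor.compress(&vec![0.0_f32; 1_000], Some(0.1))?;
    /// assert_eq!(c.data.len(), 100);
    ///
    /// // 200 elements at ratio 0.1 -> the floor of 64 applies
    /// let c = compressor.compress(&vec![0.0_f32; 200], Some(0.1))?;
    /// assert_eq!(c.data.len(), 64);
    /// ```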
    #[allow(clippy::cast_precision_loss)]
    pub fn compress(
        &self,
        gradients: &[f32],
        ratio: Option<f32>,
    ) -> Result<CompressedGradient, TrainingError> {
        let ratio = ratio.unwrap_or(self.config.default_compression_ratio);

        if ratio <= 0.0 || ratio > 1.0 {
            return Err(TrainingError::InvalidRatio(ratio));
        }

        let original_dim = gradients.len();
        let compressed_dim = ((original_dim as f32 * ratio).ceil() as usize).max(64);

        // Apply gradient clipping if configured
        let gradients = if let Some(max_norm) = self.config.gradient_clipping {
            clip_gradients(gradients, max_norm)
        } else {
            gradients.to_vec()
        };

        // Random projection (sparse for efficiency)
        let compressed = sparse_random_projection(&gradients, compressed_dim, self.config.seed);

        Ok(CompressedGradient {
            data: compressed,
            original_dim,
            seed: self.config.seed,
            ratio,
        })
    }

    /// Decompress gradients.
    ///
    /// # Arguments
    ///
    /// * `compressed` - Compressed gradient
    ///
    /// # Returns
    ///
    /// Reconstructed gradient vector (approximate).
    pub fn decompress(&self, compressed: &CompressedGradient) -> Result<Vec<f32>, TrainingError> {
        if compressed.seed != self.config.seed {
            return Err(TrainingError::SeedMismatch {
                compress: compressed.seed,
                decompress: self.config.seed,
            });
        }

        let recovered =
            sparse_random_projection_transpose(&compressed.data, compressed.original_dim, compressed.seed);

        Ok(recovered)
    }

    /// Compress and immediately quantize to ternary for maximum compression.
    ///
    /// At the default ratio of 0.1 this yields roughly 160x compression:
    /// ~10x from the random projection and ~16x from packing ternary values
    /// into 2 bits instead of 32-bit floats.
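    ///
    /// A sketch of the ternary path (using a `compressor` built as above; the
    /// printed ratio assumes 2-bit packing, as computed by `compression_ratio`):
    ///
    /// ```rust,ignore
    /// let grads = vec![0.3_f32; 10_000];
    /// let packed = compressor.compress_ternary(&grads, Some(0.1))?;
    /// assert!(packed.data.iter().all(|&t| t == -1 || t == 0 || t == 1));
    /// println!("compression: {:.0}x", packed.compression_ratio());
    ///
    /// let approx = compressor.decompress_ternary(&packed)?;
    /// assert_eq!(approx.len(), grads.len());
    /// ```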
    #[allow(clippy::cast_precision_loss)]
    pub fn compress_ternary(
        &self,
        gradients: &[f32],
        ratio: Option<f32>,
    ) -> Result<TernaryCompressedGradient, TrainingError> {
        let compressed = self.compress(gradients, ratio)?;

        // Quantize compressed representation to ternary
        let (ternary, scale) = quantize_to_ternary(&compressed.data);

        Ok(TernaryCompressedGradient {
            data: ternary,
            scale,
            original_dim: compressed.original_dim,
            compressed_dim: compressed.data.len(),
            seed: compressed.seed,
        })
    }

    /// Decompress ternary compressed gradients.
    pub fn decompress_ternary(
        &self,
        compressed: &TernaryCompressedGradient,
    ) -> Result<Vec<f32>, TrainingError> {
        if compressed.seed != self.config.seed {
            return Err(TrainingError::SeedMismatch {
                compress: compressed.seed,
                decompress: self.config.seed,
            });
        }

        // Dequantize from ternary
        let dequantized: Vec<f32> = compressed
            .data
            .iter()
            .map(|&t| f32::from(t) * compressed.scale)
            .collect();

        // Inverse projection
        let recovered = sparse_random_projection_transpose(&dequantized, compressed.original_dim, compressed.seed);

        Ok(recovered)
    }
}

/// Ternary compressed gradient (maximum compression).
#[derive(Debug, Clone)]
pub struct TernaryCompressedGradient {
    /// Ternary values (-1, 0, +1).
    pub data: Vec<i8>,
    /// Scale factor.
    pub scale: f32,
    /// Original dimension.
    pub original_dim: usize,
    /// Compressed dimension.
    pub compressed_dim: usize,
    /// Random seed.
    pub seed: u64,
}

impl TernaryCompressedGradient {
    /// Calculate compression ratio.
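    ///
    /// A worked example of the arithmetic below: 10_000 original f32 values
    /// compressed to 1_000 ternary values gives
    /// `(10_000 * 32) / (1_000 * 2 + 32) = 320_000 / 2_032 ≈ 157x`.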
    #[allow(clippy::cast_precision_loss)]
    pub fn compression_ratio(&self) -> f32 {
        // Original: f32 = 32 bits each
        // Compressed: 2 bits per ternary value (assumes 2-bit packing for
        // transmission; the in-memory `Vec<i8>` stores 8 bits each) + 32 bits for the scale
        let original_bits = self.original_dim * 32;
        let compressed_bits = self.data.len() * 2 + 32;
        original_bits as f32 / compressed_bits as f32
    }
}

// Helper functions

fn clip_gradients(gradients: &[f32], max_norm: f32) -> Vec<f32> {
    let norm: f32 = gradients.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm > max_norm {
        let scale = max_norm / norm;
        gradients.iter().map(|x| x * scale).collect()
    } else {
        gradients.to_vec()
    }
}

#[allow(clippy::cast_precision_loss)]
fn sparse_random_projection(input: &[f32], output_dim: usize, seed: u64) -> Vec<f32> {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut output = vec![0.0f32; output_dim];

    // Scale factor for Johnson-Lindenstrauss
    let scale = 1.0 / (input.len() as f32).sqrt();

    // Sparse random projection: ~68% zeros, 16% +1, 16% -1
    for &g in input {
        for o in output.iter_mut() {
            let r: f32 = rng.gen();
            if r < 0.16 {
                *o += g * scale;
            } else if r < 0.32 {
                *o -= g * scale;
            }
        }
    }

    output
}

#[allow(clippy::cast_precision_loss)]
fn sparse_random_projection_transpose(input: &[f32], output_dim: usize, seed: u64) -> Vec<f32> {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut output = vec![0.0f32; output_dim];

    let scale = 1.0 / (output_dim as f32).sqrt();

    // Transpose of sparse random projection
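    // Note: the loop order here (original index outer, compressed index inner)
    // matches the draw order in `sparse_random_projection`, so the shared seed
    // regenerates the same implicit random matrix entries, applied transposed.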
    for o in output.iter_mut() {
        for &c in input {
            let r: f32 = rng.gen();
            if r < 0.16 {
                *o += c * scale;
            } else if r < 0.32 {
                *o -= c * scale;
            }
        }
    }

    output
}

fn quantize_to_ternary(values: &[f32]) -> (Vec<i8>, f32) {
    // Use AbsMean scaling
    let abs_mean: f32 = values.iter().map(|x| x.abs()).sum::<f32>() / values.len() as f32;
    let scale = if abs_mean > 1e-10 { abs_mean } else { 1.0 };

    let ternary: Vec<i8> = values
        .iter()
        .map(|&v| {
            let normalized = v / scale;
            if normalized > 0.5 {
                1i8
            } else if normalized < -0.5 {
                -1i8
            } else {
                0i8
            }
        })
        .collect();

    (ternary, scale)
}

/// Gradient accumulator for memory-efficient training.
///
/// Sums gradients over several micro-batches and returns their average, so a
/// large effective batch size can be simulated with smaller per-step memory.
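///
/// A minimal accumulation sketch over two micro-batches:
///
/// ```rust,ignore
/// use tritter_accel::core::training::GradientAccumulator;
///
/// let mut acc = GradientAccumulator::new(3);
/// acc.accumulate(&[1.0, 2.0, 3.0])?;
/// acc.accumulate(&[3.0, 2.0, 1.0])?;
/// let averaged = acc.get_and_reset();  // [2.0, 2.0, 2.0]
/// assert_eq!(acc.count(), 0);
/// ```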
#[derive(Debug)]
pub struct GradientAccumulator {
    /// Accumulated gradients.
    accumulated: Vec<f32>,
    /// Number of accumulated batches.
    count: usize,
}

impl GradientAccumulator {
    /// Create a new accumulator.
    pub fn new(size: usize) -> Self {
        Self {
            accumulated: vec![0.0; size],
            count: 0,
        }
    }

    /// Add gradients to accumulator.
    pub fn accumulate(&mut self, gradients: &[f32]) -> Result<(), TrainingError> {
        if gradients.len() != self.accumulated.len() {
            return Err(TrainingError::DimensionMismatch {
                expected: self.accumulated.len(),
                actual: gradients.len(),
            });
        }

        for (acc, &g) in self.accumulated.iter_mut().zip(gradients.iter()) {
            *acc += g;
        }
        self.count += 1;

        Ok(())
    }

    /// Get averaged gradients and reset accumulator.
    #[allow(clippy::cast_precision_loss)]
    pub fn get_and_reset(&mut self) -> Vec<f32> {
        if self.count == 0 {
            return self.accumulated.clone();
        }

        let scale = 1.0 / self.count as f32;
        let result: Vec<f32> = self.accumulated.iter().map(|&x| x * scale).collect();

        // Reset
        self.accumulated.fill(0.0);
        self.count = 0;

        result
    }

    /// Get current count.
    pub fn count(&self) -> usize {
        self.count
    }
}

/// Mixed precision training utilities.
pub mod mixed_precision {
    use super::TrainingError;

    /// Convert f32 to bf16 representation (as u16).
    ///
    /// BF16 keeps the upper 16 bits of the f32 (sign, 8 exponent bits, top 7
    /// mantissa bits), preserving range but reducing precision; this is a
    /// truncating conversion rather than round-to-nearest.
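    ///
    /// A bit-level sketch: `1.0f32` is `0x3F80_0000`, so its bf16 encoding is
    /// the upper half `0x3F80` and it round-trips exactly; most values lose
    /// their low mantissa bits instead:
    ///
    /// ```rust,ignore
    /// assert_eq!(f32_to_bf16(1.0), 0x3F80);
    /// assert_eq!(bf16_to_f32(0x3F80), 1.0);
    /// assert!((bf16_to_f32(f32_to_bf16(3.14159)) - 3.14159).abs() < 0.01);
    /// ```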
    pub fn f32_to_bf16(value: f32) -> u16 {
        let bits = value.to_bits();
        (bits >> 16) as u16
    }

    /// Convert bf16 (as u16) back to f32.
    pub fn bf16_to_f32(value: u16) -> f32 {
        let bits = (value as u32) << 16;
        f32::from_bits(bits)
    }

    /// Convert slice of f32 to bf16.
    pub fn convert_to_bf16(values: &[f32]) -> Vec<u16> {
        values.iter().map(|&v| f32_to_bf16(v)).collect()
    }

    /// Convert slice of bf16 back to f32.
    pub fn convert_from_bf16(values: &[u16]) -> Vec<f32> {
        values.iter().map(|&v| bf16_to_f32(v)).collect()
    }

    /// Loss scaling for mixed precision training.
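    ///
    /// A typical step sketch: scale the loss before the backward pass, unscale
    /// the gradients, skip the optimizer step on overflow, and let the scaler
    /// adapt (`loss` and `grads` are assumed to come from the training loop):
    ///
    /// ```rust,ignore
    /// let mut scaler = LossScaler::default();
    /// let scaled_loss = scaler.scale_loss(loss);
    /// // ... backward pass on `scaled_loss` produces `grads` ...
    /// scaler.unscale_gradients(&mut grads);
    /// let overflow = LossScaler::check_overflow(&grads);
    /// if !overflow {
    ///     // optimizer step with `grads`
    /// }
    /// scaler.update(overflow);
    /// ```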
    #[derive(Debug, Clone)]
    pub struct LossScaler {
        scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
        steps_since_growth: usize,
    }

    impl Default for LossScaler {
        fn default() -> Self {
            Self {
                scale: 65536.0, // 2^16
                growth_factor: 2.0,
                backoff_factor: 0.5,
                growth_interval: 2000,
                steps_since_growth: 0,
            }
        }
    }

    impl LossScaler {
        /// Create with initial scale.
        pub fn with_initial_scale(scale: f32) -> Self {
            Self {
                scale,
                ..Default::default()
            }
        }

        /// Get current scale factor.
        pub fn scale(&self) -> f32 {
            self.scale
        }

        /// Scale loss for backward pass.
        pub fn scale_loss(&self, loss: f32) -> f32 {
            loss * self.scale
        }

        /// Unscale gradients after backward pass.
        pub fn unscale_gradients(&self, gradients: &mut [f32]) {
            let inv_scale = 1.0 / self.scale;
            for g in gradients.iter_mut() {
                *g *= inv_scale;
            }
        }

        /// Update scale based on whether overflow occurred.
        pub fn update(&mut self, overflow: bool) {
            if overflow {
                self.scale *= self.backoff_factor;
                self.steps_since_growth = 0;
            } else {
                self.steps_since_growth += 1;
                if self.steps_since_growth >= self.growth_interval {
                    self.scale *= self.growth_factor;
                    self.steps_since_growth = 0;
                }
            }
        }

        /// Check if gradients contain inf/nan (overflow).
        pub fn check_overflow(gradients: &[f32]) -> bool {
            gradients.iter().any(|&g| g.is_nan() || g.is_infinite())
        }
    }

    /// Check for NaN or Inf in tensor.
    pub fn has_nan_or_inf(values: &[f32]) -> bool {
        values.iter().any(|&v| v.is_nan() || v.is_infinite())
    }

    /// Clip values to prevent overflow.
    pub fn safe_clip(values: &mut [f32], min: f32, max: f32) -> Result<(), TrainingError> {
        for v in values.iter_mut() {
            if v.is_nan() {
                *v = 0.0;
            } else {
                *v = v.clamp(min, max);
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_compression_roundtrip() {
        let config = TrainingConfig::default();
        let compressor = GradientCompressor::new(config);

        let gradients: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 500.0).collect();

        let compressed = compressor.compress(&gradients, Some(0.1)).unwrap();
        let recovered = compressor.decompress(&compressed).unwrap();

        // Check dimensions
        assert_eq!(recovered.len(), gradients.len());

        // Check approximate reconstruction (lossy compression)
        let mse: f32 = gradients
            .iter()
            .zip(recovered.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / gradients.len() as f32;

        // MSE should be reasonable (not zero due to lossy compression)
        assert!(mse < 1.0);
    }

    #[test]
    fn test_ternary_compression() {
        let config = TrainingConfig::default();
        let compressor = GradientCompressor::new(config);

        let gradients: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 500.0).collect();

        let compressed = compressor.compress_ternary(&gradients, Some(0.1)).unwrap();

        // Check compression ratio is high
        assert!(compressed.compression_ratio() > 10.0);

        // Check ternary values
        for &t in &compressed.data {
            assert!([-1, 0, 1].contains(&t));
        }
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut acc = GradientAccumulator::new(4);

        acc.accumulate(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        acc.accumulate(&[2.0, 4.0, 6.0, 8.0]).unwrap();

        let result = acc.get_and_reset();

        // Average: [1.5, 3.0, 4.5, 6.0]
        assert!((result[0] - 1.5).abs() < 1e-6);
        assert!((result[1] - 3.0).abs() < 1e-6);
        assert!((result[2] - 4.5).abs() < 1e-6);
        assert!((result[3] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_mixed_precision_bf16() {
        use mixed_precision::{bf16_to_f32, f32_to_bf16};

        let original = 3.14159f32;
        let bf16 = f32_to_bf16(original);
        let recovered = bf16_to_f32(bf16);

        // BF16 has ~3 decimal digits of precision
        assert!((original - recovered).abs() < 0.01);
    }

    #[test]
    fn test_loss_scaler() {
        use mixed_precision::LossScaler;

        let mut scaler = LossScaler::default();
        let initial_scale = scaler.scale();

        // Simulate overflow
        scaler.update(true);
        assert!(scaler.scale() < initial_scale);

        // Simulate many successful steps
        for _ in 0..2000 {
            scaler.update(false);
        }
        // Scale should have grown back
        assert!(scaler.scale() > initial_scale * 0.5);
    }
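
    // Verifies the seed-mismatch guard in `decompress`: compressing with one
    // seed and decompressing with another must fail rather than silently
    // return a wrong reconstruction.
    #[test]
    fn test_seed_mismatch_is_rejected() {
        let sender = GradientCompressor::new(TrainingConfig::default().with_seed(1));
        let receiver = GradientCompressor::new(TrainingConfig::default().with_seed(2));

        let gradients = vec![0.1_f32; 256];
        let compressed = sender.compress(&gradients, Some(0.5)).unwrap();

        assert!(matches!(
            receiver.decompress(&compressed),
            Err(TrainingError::SeedMismatch { .. })
        ));
    }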
}