Skip to main content

trustformers_optim/zero/
mod.rs

//! ZeRO (Zero Redundancy Optimizer) implementation for TrustformeRS.
//!
//! ZeRO is a memory-efficient training technique that partitions optimizer states,
//! gradients, and parameters across devices to reduce memory usage while maintaining
//! training efficiency.
//!
//! This module implements three stages, each building on the previous one:
//! - Stage 1: Partition optimizer states
//! - Stage 2: Partition optimizer states + gradients
//! - Stage 3: Partition optimizer states + gradients + parameters
11
12pub mod zero_optimizer;
13pub mod zero_stage1;
14pub mod zero_stage2;
15pub mod zero_stage3;
16pub mod zero_utils;
17
18pub use zero_optimizer::{ZeROConfig, ZeROOptimizer, ZeROStage};
19pub use zero_stage1::ZeROStage1;
20pub use zero_stage2::ZeROStage2;
21pub use zero_stage3::ZeROStage3;
22pub use zero_utils::{
23    all_gather_gradients, gather_parameters, partition_gradients, partition_parameters,
24    reduce_scatter_gradients, GradientBuffer, ParameterGroup, ParameterPartition, ZeROState,
25};
26
/// ZeRO optimization stages.
///
/// The stages are cumulative: each one partitions everything the previous
/// stage partitions, plus one additional kind of training state.
///
/// The type is `Copy`, comparable, and hashable, so it can be passed by value
/// and used as a key in `HashMap`/`HashSet`.
// NOTE(review): `ZeROStage` is also re-exported from `zero_optimizer` by this
// module; confirm whether both enums are meant to coexist.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ZeROImplementationStage {
    /// Stage 1: partition optimizer states only.
    Stage1,
    /// Stage 2: partition optimizer states and gradients.
    Stage2,
    /// Stage 3: partition optimizer states, gradients, and parameters.
    Stage3,
}
37
/// Memory statistics for ZeRO optimization.
///
/// All fields are public counters that callers fill in directly.
/// `total_memory_saved` is derived from the three `*_memory_saved` fields but
/// is not kept in sync automatically: call [`ZeROMemoryStats::update_totals`]
/// after changing any of them.
// NOTE(review): the unit of these counters (presumably bytes) is not stated in
// this module; confirm against the code that populates them.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ZeROMemoryStats {
    /// Memory saved by partitioning optimizer states
    pub optimizer_memory_saved: usize,
    /// Memory saved by partitioning gradients
    pub gradient_memory_saved: usize,
    /// Memory saved by partitioning parameters
    pub parameter_memory_saved: usize,
    /// Total memory saved (sum of the three fields above, as of the last
    /// call to [`ZeROMemoryStats::update_totals`])
    pub total_memory_saved: usize,
    /// Memory overhead from communication buffers (not included in
    /// `total_memory_saved`)
    pub communication_overhead: usize,
}

impl ZeROMemoryStats {
    /// Creates statistics with every counter set to zero.
    ///
    /// Equivalent to [`ZeROMemoryStats::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Recomputes `total_memory_saved` as the sum of the optimizer, gradient,
    /// and parameter savings.
    ///
    /// `communication_overhead` is left untouched and is not subtracted from
    /// the total.
    ///
    /// The sum saturates at `usize::MAX` rather than overflowing, so the
    /// result is the same in debug and release builds (a plain `+` would
    /// panic in debug and wrap around in release).
    pub fn update_totals(&mut self) {
        self.total_memory_saved = self
            .optimizer_memory_saved
            .saturating_add(self.gradient_memory_saved)
            .saturating_add(self.parameter_memory_saved);
    }
}