trustformers_optim/zero/
mod.rs1pub mod zero_optimizer;
13pub mod zero_stage1;
14pub mod zero_stage2;
15pub mod zero_stage3;
16pub mod zero_utils;
17
18pub use zero_optimizer::{ZeROConfig, ZeROOptimizer, ZeROStage};
19pub use zero_stage1::ZeROStage1;
20pub use zero_stage2::ZeROStage2;
21pub use zero_stage3::ZeROStage3;
22pub use zero_utils::{
23 all_gather_gradients, gather_parameters, partition_gradients, partition_parameters,
24 reduce_scatter_gradients, GradientBuffer, ParameterGroup, ParameterPartition, ZeROState,
25};
26
/// Tag naming one of the three ZeRO implementation stages.
///
/// The variants carry no data; the type is `Copy` and comparable with `==`.
///
/// NOTE(review): this module also re-exports `zero_optimizer::ZeROStage`;
/// how the two enums relate is not visible from this file — confirm before
/// using them interchangeably.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZeROImplementationStage {
    /// First stage (presumably backed by `ZeROStage1` — confirm).
    Stage1,
    /// Second stage (presumably backed by `ZeROStage2` — confirm).
    Stage2,
    /// Third stage (presumably backed by `ZeROStage3` — confirm).
    Stage3,
}
37
/// Counters describing the memory effect of a ZeRO setup.
///
/// Every counter starts at zero, both via [`ZeROMemoryStats::new`] and via
/// `Default`. The unit is not fixed by this type (presumably bytes — confirm
/// against the code that fills the counters in).
#[derive(Debug, Clone, Default)]
pub struct ZeROMemoryStats {
    /// Memory saved on optimizer state.
    pub optimizer_memory_saved: usize,
    /// Memory saved on gradients.
    pub gradient_memory_saved: usize,
    /// Memory saved on parameters.
    pub parameter_memory_saved: usize,
    /// Sum of the three `*_memory_saved` counters as of the most recent
    /// [`update_totals`](Self::update_totals) call. It is NOT kept in sync
    /// automatically when the individual counters are modified.
    pub total_memory_saved: usize,
    /// Communication overhead; never folded into `total_memory_saved`.
    pub communication_overhead: usize,
}

impl ZeROMemoryStats {
    /// Creates a record with every counter set to zero.
    ///
    /// Equivalent to `ZeROMemoryStats::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Recomputes `total_memory_saved` from the optimizer, gradient and
    /// parameter counters.
    ///
    /// The sum saturates at `usize::MAX` instead of overflowing, so this
    /// method never panics (debug builds) or wraps around to a misleadingly
    /// small total (release builds).
    pub fn update_totals(&mut self) {
        self.total_memory_saved = self
            .optimizer_memory_saved
            .saturating_add(self.gradient_memory_saved)
            .saturating_add(self.parameter_memory_saved);
    }
}