/// Execution mode selected by an `ExecutionStrategy` through its `mode` field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecutionMode {
    /// Mode used by the `new`, `inference`, and `debug` presets.
    Eager,
    /// Mode used by the `training` and `high_throughput` presets.
    Lazy,
    /// Mode used by the `memory_efficient` preset.
    Hybrid,
}
13
/// Gradient handling selected by an `ExecutionStrategy` through its `gradient` field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GradientStrategy {
    /// No gradients. This is the only variant for which
    /// `ExecutionStrategy::computes_gradients` returns `false`.
    None,
    /// Variant used by the `training` and `debug` presets.
    Full,
    /// Variant for which `ExecutionStrategy::uses_checkpointing` returns `true`.
    Checkpointed {
        /// Value reported by `ExecutionStrategy::checkpoint_interval`.
        // NOTE(review): the unit of the interval is not established in this file — confirm.
        checkpoint_interval: usize,
    },
    /// Variant carrying a gradient accumulation step count.
    Accumulated {
        /// Value reported by `ExecutionStrategy::accumulation_steps`.
        accumulation_steps: usize,
    },
}
32
/// Precision mode selected by an `ExecutionStrategy` through its `precision` field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrecisionMode {
    /// Mode used by the `new` and `debug` presets.
    Full,
    /// Mode used by the `training`, `inference`, and `high_throughput` presets.
    Single,
    /// Mode used by the `memory_efficient` preset.
    Mixed,
    /// Not used by any preset in this file; available through `with_precision`.
    Half,
}
45
/// Memory handling selected by an `ExecutionStrategy` through its `memory` field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryStrategy {
    /// Strategy used by the `new` and `debug` presets.
    Standard,
    /// Strategy used by the `training` and `high_throughput` presets.
    Pooled,
    /// Strategy used by the `inference` preset.
    Cached,
    /// Strategy used by the `memory_efficient` preset.
    MinimalPeak,
}
58
/// Parallelism selected by an `ExecutionStrategy` through its `parallelism` field.
///
/// `ExecutionStrategy::num_workers` maps each variant to a worker count.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ParallelismStrategy {
    /// No parallelism; `num_workers` reports 1.
    None,
    /// `num_workers` reports the carried `num_workers` value.
    DataParallel {
        /// Explicitly configured worker count.
        num_workers: usize,
    },
    /// `num_workers` reports the carried `num_devices` value.
    ModelParallel {
        /// Explicitly configured device count.
        num_devices: usize,
    },
    /// `num_workers` reports the carried `num_stages` value.
    PipelineParallel {
        /// Explicitly configured stage count.
        num_stages: usize,
    },
    /// `num_workers` derives the count from the CPU count, capped at 8.
    Automatic,
}
73
74#[derive(Debug, Clone)]
76pub struct ExecutionStrategy {
77 pub mode: ExecutionMode,
78 pub gradient: GradientStrategy,
79 pub precision: PrecisionMode,
80 pub memory: MemoryStrategy,
81 pub parallelism: ParallelismStrategy,
82 pub enable_fusion: bool,
83 pub enable_profiling: bool,
84}
85
86impl ExecutionStrategy {
87 pub fn new() -> Self {
89 ExecutionStrategy {
90 mode: ExecutionMode::Eager,
91 gradient: GradientStrategy::None,
92 precision: PrecisionMode::Full,
93 memory: MemoryStrategy::Standard,
94 parallelism: ParallelismStrategy::None,
95 enable_fusion: false,
96 enable_profiling: false,
97 }
98 }
99
100 pub fn training() -> Self {
102 ExecutionStrategy {
103 mode: ExecutionMode::Lazy,
104 gradient: GradientStrategy::Full,
105 precision: PrecisionMode::Single,
106 memory: MemoryStrategy::Pooled,
107 parallelism: ParallelismStrategy::Automatic,
108 enable_fusion: true,
109 enable_profiling: true,
110 }
111 }
112
113 pub fn inference() -> Self {
115 ExecutionStrategy {
116 mode: ExecutionMode::Eager,
117 gradient: GradientStrategy::None,
118 precision: PrecisionMode::Single,
119 memory: MemoryStrategy::Cached,
120 parallelism: ParallelismStrategy::Automatic,
121 enable_fusion: true,
122 enable_profiling: false,
123 }
124 }
125
126 pub fn memory_efficient() -> Self {
128 ExecutionStrategy {
129 mode: ExecutionMode::Hybrid,
130 gradient: GradientStrategy::Checkpointed {
131 checkpoint_interval: 10,
132 },
133 precision: PrecisionMode::Mixed,
134 memory: MemoryStrategy::MinimalPeak,
135 parallelism: ParallelismStrategy::None,
136 enable_fusion: false,
137 enable_profiling: false,
138 }
139 }
140
141 pub fn high_throughput() -> Self {
143 ExecutionStrategy {
144 mode: ExecutionMode::Lazy,
145 gradient: GradientStrategy::None,
146 precision: PrecisionMode::Single,
147 memory: MemoryStrategy::Pooled,
148 parallelism: ParallelismStrategy::DataParallel { num_workers: 4 },
149 enable_fusion: true,
150 enable_profiling: false,
151 }
152 }
153
154 pub fn debug() -> Self {
156 ExecutionStrategy {
157 mode: ExecutionMode::Eager,
158 gradient: GradientStrategy::Full,
159 precision: PrecisionMode::Full,
160 memory: MemoryStrategy::Standard,
161 parallelism: ParallelismStrategy::None,
162 enable_fusion: false,
163 enable_profiling: true,
164 }
165 }
166
167 pub fn with_mode(mut self, mode: ExecutionMode) -> Self {
169 self.mode = mode;
170 self
171 }
172
173 pub fn with_gradient(mut self, gradient: GradientStrategy) -> Self {
174 self.gradient = gradient;
175 self
176 }
177
178 pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
179 self.precision = precision;
180 self
181 }
182
183 pub fn with_memory(mut self, memory: MemoryStrategy) -> Self {
184 self.memory = memory;
185 self
186 }
187
188 pub fn with_parallelism(mut self, parallelism: ParallelismStrategy) -> Self {
189 self.parallelism = parallelism;
190 self
191 }
192
193 pub fn enable_fusion(mut self) -> Self {
194 self.enable_fusion = true;
195 self
196 }
197
198 pub fn enable_profiling(mut self) -> Self {
199 self.enable_profiling = true;
200 self
201 }
202
203 pub fn computes_gradients(&self) -> bool {
205 !matches!(self.gradient, GradientStrategy::None)
206 }
207
208 pub fn uses_checkpointing(&self) -> bool {
210 matches!(self.gradient, GradientStrategy::Checkpointed { .. })
211 }
212
213 pub fn is_inference_mode(&self) -> bool {
215 matches!(self.gradient, GradientStrategy::None)
216 }
217
218 pub fn checkpoint_interval(&self) -> Option<usize> {
220 match self.gradient {
221 GradientStrategy::Checkpointed {
222 checkpoint_interval,
223 } => Some(checkpoint_interval),
224 _ => None,
225 }
226 }
227
228 pub fn accumulation_steps(&self) -> Option<usize> {
230 match self.gradient {
231 GradientStrategy::Accumulated { accumulation_steps } => Some(accumulation_steps),
232 _ => None,
233 }
234 }
235
236 pub fn num_workers(&self) -> usize {
238 match self.parallelism {
239 ParallelismStrategy::None => 1,
240 ParallelismStrategy::DataParallel { num_workers } => num_workers,
241 ParallelismStrategy::ModelParallel { num_devices } => num_devices,
242 ParallelismStrategy::PipelineParallel { num_stages } => num_stages,
243 ParallelismStrategy::Automatic => num_cpus::get().min(8),
244 }
245 }
246
247 pub fn summary(&self) -> String {
249 format!(
250 "Execution Strategy:\n\
251 - Mode: {:?}\n\
252 - Gradient: {:?}\n\
253 - Precision: {:?}\n\
254 - Memory: {:?}\n\
255 - Parallelism: {:?}\n\
256 - Fusion: {}\n\
257 - Profiling: {}",
258 self.mode,
259 self.gradient,
260 self.precision,
261 self.memory,
262 self.parallelism,
263 self.enable_fusion,
264 self.enable_profiling
265 )
266 }
267}
268
269impl Default for ExecutionStrategy {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
275pub struct StrategyOptimizer;
277
278impl StrategyOptimizer {
279 pub fn recommend(
281 batch_size: usize,
282 model_size_mb: usize,
283 available_memory_mb: usize,
284 is_training: bool,
285 ) -> ExecutionStrategy {
286 let memory_pressure = (model_size_mb * batch_size) as f64 / available_memory_mb as f64;
287
288 if is_training {
289 if memory_pressure > 0.8 {
290 ExecutionStrategy::training().with_gradient(GradientStrategy::Checkpointed {
292 checkpoint_interval: 5,
293 })
294 } else if batch_size >= 64 {
295 ExecutionStrategy::training().with_gradient(GradientStrategy::Accumulated {
297 accumulation_steps: 4,
298 })
299 } else {
300 ExecutionStrategy::training()
301 }
302 } else {
303 if batch_size >= 32 {
305 ExecutionStrategy::high_throughput()
306 } else {
307 ExecutionStrategy::inference()
308 }
309 }
310 }
311
312 pub fn estimate_memory_overhead(strategy: &ExecutionStrategy) -> f64 {
314 let mut overhead = 1.0;
315
316 overhead *= match strategy.mode {
318 ExecutionMode::Eager => 1.0,
319 ExecutionMode::Lazy => 1.2, ExecutionMode::Hybrid => 1.1,
321 };
322
323 overhead *= match strategy.gradient {
325 GradientStrategy::None => 1.0,
326 GradientStrategy::Full => 3.0, GradientStrategy::Checkpointed { .. } => 2.0, GradientStrategy::Accumulated { .. } => 3.5, };
330
331 overhead *= match strategy.memory {
333 MemoryStrategy::Standard => 1.0,
334 MemoryStrategy::Pooled => 1.1, MemoryStrategy::Cached => 1.3, MemoryStrategy::MinimalPeak => 0.8, };
338
339 overhead
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Every preset exposes the property it is named for.
    #[test]
    fn test_execution_strategy_presets() {
        let training = ExecutionStrategy::training();
        assert!(training.computes_gradients());
        assert!(training.enable_fusion);

        let inference = ExecutionStrategy::inference();
        assert!(!inference.computes_gradients());
        assert!(inference.is_inference_mode());

        assert!(ExecutionStrategy::memory_efficient().uses_checkpointing());
        assert!(ExecutionStrategy::high_throughput().num_workers() > 1);
        assert!(ExecutionStrategy::debug().enable_profiling);
    }

    /// Chained builder calls each update exactly the field they name.
    #[test]
    fn test_execution_strategy_builder() {
        let built = ExecutionStrategy::new()
            .with_mode(ExecutionMode::Lazy)
            .with_precision(PrecisionMode::Single)
            .enable_fusion()
            .enable_profiling();

        assert_eq!(built.mode, ExecutionMode::Lazy);
        assert_eq!(built.precision, PrecisionMode::Single);
        assert!(built.enable_fusion);
        assert!(built.enable_profiling);
    }

    /// Gradient predicates and accessors agree with the configured variant.
    #[test]
    fn test_gradient_strategies() {
        let with = |gradient| ExecutionStrategy::new().with_gradient(gradient);

        assert!(!with(GradientStrategy::None).computes_gradients());
        assert!(with(GradientStrategy::Full).computes_gradients());

        let checkpointed = with(GradientStrategy::Checkpointed {
            checkpoint_interval: 10,
        });
        assert!(checkpointed.uses_checkpointing());
        assert_eq!(checkpointed.checkpoint_interval(), Some(10));

        let accumulated = with(GradientStrategy::Accumulated {
            accumulation_steps: 4,
        });
        assert_eq!(accumulated.accumulation_steps(), Some(4));
    }

    /// `num_workers` reflects the configured parallelism.
    #[test]
    fn test_parallelism_strategies() {
        let with = |parallelism| ExecutionStrategy::new().with_parallelism(parallelism);

        assert_eq!(with(ParallelismStrategy::None).num_workers(), 1);
        assert_eq!(
            with(ParallelismStrategy::DataParallel { num_workers: 4 }).num_workers(),
            4
        );
        assert!(with(ParallelismStrategy::Automatic).num_workers() >= 1);
    }

    /// Recommendations compute gradients exactly when training is requested.
    #[test]
    fn test_strategy_optimizer_recommendations() {
        let small_training = StrategyOptimizer::recommend(32, 1000, 2000, true);
        assert!(small_training.computes_gradients());

        let large_training = StrategyOptimizer::recommend(64, 2000, 2000, true);
        assert!(
            large_training.uses_checkpointing() || large_training.accumulation_steps().is_some()
        );

        let large_inference = StrategyOptimizer::recommend(64, 500, 4000, false);
        assert!(!large_inference.computes_gradients());

        let small_inference = StrategyOptimizer::recommend(8, 500, 4000, false);
        assert!(!small_inference.computes_gradients());
    }

    /// The baseline costs exactly 1.0; training costs more than the
    /// memory-efficient preset.
    #[test]
    fn test_memory_overhead_estimation() {
        let baseline = StrategyOptimizer::estimate_memory_overhead(&ExecutionStrategy::new());
        assert_eq!(baseline, 1.0);

        let training = StrategyOptimizer::estimate_memory_overhead(&ExecutionStrategy::training());
        assert!(training > 2.0);

        let memory_efficient =
            StrategyOptimizer::estimate_memory_overhead(&ExecutionStrategy::memory_efficient());
        assert!(memory_efficient < training);
    }

    /// Derived equality distinguishes execution modes.
    #[test]
    fn test_execution_modes() {
        assert_eq!(ExecutionMode::Eager, ExecutionMode::Eager);
        assert_ne!(ExecutionMode::Eager, ExecutionMode::Lazy);
    }

    /// `with_precision` stores every precision mode unchanged.
    #[test]
    fn test_precision_modes() {
        for mode in [
            PrecisionMode::Full,
            PrecisionMode::Single,
            PrecisionMode::Mixed,
            PrecisionMode::Half,
        ] {
            let strategy = ExecutionStrategy::new().with_precision(mode);
            assert_eq!(strategy.precision, mode);
        }
    }

    /// The summary mentions its heading and the main settings.
    #[test]
    fn test_strategy_summary() {
        let summary = ExecutionStrategy::training().summary();

        for expected in ["Execution Strategy", "Mode", "Gradient", "Precision"] {
            assert!(summary.contains(expected));
        }
    }
}