// tensorlogic_infer — strategy.rs
//! Execution strategy configuration and policies.

/// Execution mode for graph evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionMode {
    /// Execute graph immediately on each operation
    Eager,
    /// Build computation graph and execute lazily
    Lazy,
    /// Hybrid: lazy within subgraphs, eager between stages
    Hybrid,
}
13
/// Gradient computation strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientStrategy {
    /// No gradient computation
    None,
    /// Full gradient computation (standard backprop)
    Full,
    /// Checkpointing (recompute forward pass to save memory)
    Checkpointed {
        /// Checkpoint every N nodes
        checkpoint_interval: usize,
    },
    /// Gradient accumulation across multiple steps
    Accumulated {
        /// Number of steps to accumulate before updating
        accumulation_steps: usize,
    },
}
32
/// Precision mode for computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrecisionMode {
    /// Full precision (F64)
    Full,
    /// Single precision (F32)
    Single,
    /// Mixed precision (automatic F16/F32 selection)
    Mixed,
    /// Half precision (F16)
    Half,
}
45
/// Memory management strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryStrategy {
    /// Standard allocation/deallocation
    Standard,
    /// Reuse tensors aggressively
    Pooled,
    /// Cache intermediate results
    Cached,
    /// Minimize peak memory usage
    MinimalPeak,
}
58
/// Parallelism strategy.
///
/// The worker/device/stage counts are interpreted by
/// `ExecutionStrategy::num_workers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelismStrategy {
    /// No parallelism (sequential execution)
    None,
    /// Data parallelism across batches; `num_workers` independent workers
    DataParallel { num_workers: usize },
    /// Model parallelism across `num_devices` devices
    ModelParallel { num_devices: usize },
    /// Pipeline parallelism with `num_stages` pipeline stages
    PipelineParallel { num_stages: usize },
    /// Automatic selection based on graph structure
    Automatic,
}
73
/// Complete execution strategy configuration.
///
/// Construct via [`ExecutionStrategy::new`] (or a preset such as
/// `training()` / `inference()`) and refine with the `with_*` builder
/// methods.
#[derive(Debug, Clone)]
pub struct ExecutionStrategy {
    /// How the computation graph is evaluated (eager/lazy/hybrid).
    pub mode: ExecutionMode,
    /// Gradient computation policy (none, full backprop, checkpointed, …).
    pub gradient: GradientStrategy,
    /// Numeric precision used for computation.
    pub precision: PrecisionMode,
    /// Tensor memory-management policy.
    pub memory: MemoryStrategy,
    /// Parallel execution policy.
    pub parallelism: ParallelismStrategy,
    /// Whether operator fusion is enabled.
    pub enable_fusion: bool,
    /// Whether runtime profiling is enabled.
    pub enable_profiling: bool,
}
85
86impl ExecutionStrategy {
87    /// Create a new execution strategy with defaults
88    pub fn new() -> Self {
89        ExecutionStrategy {
90            mode: ExecutionMode::Eager,
91            gradient: GradientStrategy::None,
92            precision: PrecisionMode::Full,
93            memory: MemoryStrategy::Standard,
94            parallelism: ParallelismStrategy::None,
95            enable_fusion: false,
96            enable_profiling: false,
97        }
98    }
99
100    /// Training strategy with full gradients and profiling
101    pub fn training() -> Self {
102        ExecutionStrategy {
103            mode: ExecutionMode::Lazy,
104            gradient: GradientStrategy::Full,
105            precision: PrecisionMode::Single,
106            memory: MemoryStrategy::Pooled,
107            parallelism: ParallelismStrategy::Automatic,
108            enable_fusion: true,
109            enable_profiling: true,
110        }
111    }
112
113    /// Inference strategy optimized for speed
114    pub fn inference() -> Self {
115        ExecutionStrategy {
116            mode: ExecutionMode::Eager,
117            gradient: GradientStrategy::None,
118            precision: PrecisionMode::Single,
119            memory: MemoryStrategy::Cached,
120            parallelism: ParallelismStrategy::Automatic,
121            enable_fusion: true,
122            enable_profiling: false,
123        }
124    }
125
126    /// Memory-efficient strategy for large models
127    pub fn memory_efficient() -> Self {
128        ExecutionStrategy {
129            mode: ExecutionMode::Hybrid,
130            gradient: GradientStrategy::Checkpointed {
131                checkpoint_interval: 10,
132            },
133            precision: PrecisionMode::Mixed,
134            memory: MemoryStrategy::MinimalPeak,
135            parallelism: ParallelismStrategy::None,
136            enable_fusion: false,
137            enable_profiling: false,
138        }
139    }
140
141    /// High-throughput strategy for batch processing
142    pub fn high_throughput() -> Self {
143        ExecutionStrategy {
144            mode: ExecutionMode::Lazy,
145            gradient: GradientStrategy::None,
146            precision: PrecisionMode::Single,
147            memory: MemoryStrategy::Pooled,
148            parallelism: ParallelismStrategy::DataParallel { num_workers: 4 },
149            enable_fusion: true,
150            enable_profiling: false,
151        }
152    }
153
154    /// Development/debugging strategy with profiling enabled
155    pub fn debug() -> Self {
156        ExecutionStrategy {
157            mode: ExecutionMode::Eager,
158            gradient: GradientStrategy::Full,
159            precision: PrecisionMode::Full,
160            memory: MemoryStrategy::Standard,
161            parallelism: ParallelismStrategy::None,
162            enable_fusion: false,
163            enable_profiling: true,
164        }
165    }
166
167    // Builder methods
168    pub fn with_mode(mut self, mode: ExecutionMode) -> Self {
169        self.mode = mode;
170        self
171    }
172
173    pub fn with_gradient(mut self, gradient: GradientStrategy) -> Self {
174        self.gradient = gradient;
175        self
176    }
177
178    pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
179        self.precision = precision;
180        self
181    }
182
183    pub fn with_memory(mut self, memory: MemoryStrategy) -> Self {
184        self.memory = memory;
185        self
186    }
187
188    pub fn with_parallelism(mut self, parallelism: ParallelismStrategy) -> Self {
189        self.parallelism = parallelism;
190        self
191    }
192
193    pub fn enable_fusion(mut self) -> Self {
194        self.enable_fusion = true;
195        self
196    }
197
198    pub fn enable_profiling(mut self) -> Self {
199        self.enable_profiling = true;
200        self
201    }
202
203    /// Check if gradient computation is enabled
204    pub fn computes_gradients(&self) -> bool {
205        !matches!(self.gradient, GradientStrategy::None)
206    }
207
208    /// Check if strategy uses checkpointing
209    pub fn uses_checkpointing(&self) -> bool {
210        matches!(self.gradient, GradientStrategy::Checkpointed { .. })
211    }
212
213    /// Check if strategy is optimized for inference
214    pub fn is_inference_mode(&self) -> bool {
215        matches!(self.gradient, GradientStrategy::None)
216    }
217
218    /// Get checkpoint interval if using checkpointing
219    pub fn checkpoint_interval(&self) -> Option<usize> {
220        match self.gradient {
221            GradientStrategy::Checkpointed {
222                checkpoint_interval,
223            } => Some(checkpoint_interval),
224            _ => None,
225        }
226    }
227
228    /// Get gradient accumulation steps if using accumulation
229    pub fn accumulation_steps(&self) -> Option<usize> {
230        match self.gradient {
231            GradientStrategy::Accumulated { accumulation_steps } => Some(accumulation_steps),
232            _ => None,
233        }
234    }
235
236    /// Get number of parallel workers
237    pub fn num_workers(&self) -> usize {
238        match self.parallelism {
239            ParallelismStrategy::None => 1,
240            ParallelismStrategy::DataParallel { num_workers } => num_workers,
241            ParallelismStrategy::ModelParallel { num_devices } => num_devices,
242            ParallelismStrategy::PipelineParallel { num_stages } => num_stages,
243            ParallelismStrategy::Automatic => num_cpus::get().min(8),
244        }
245    }
246
247    /// Summary description of the strategy
248    pub fn summary(&self) -> String {
249        format!(
250            "Execution Strategy:\n\
251             - Mode: {:?}\n\
252             - Gradient: {:?}\n\
253             - Precision: {:?}\n\
254             - Memory: {:?}\n\
255             - Parallelism: {:?}\n\
256             - Fusion: {}\n\
257             - Profiling: {}",
258            self.mode,
259            self.gradient,
260            self.precision,
261            self.memory,
262            self.parallelism,
263            self.enable_fusion,
264            self.enable_profiling
265        )
266    }
267}
268
269impl Default for ExecutionStrategy {
270    fn default() -> Self {
271        Self::new()
272    }
273}
274
275/// Strategy optimizer for automatic strategy selection
276pub struct StrategyOptimizer;
277
278impl StrategyOptimizer {
279    /// Recommend strategy based on workload characteristics
280    pub fn recommend(
281        batch_size: usize,
282        model_size_mb: usize,
283        available_memory_mb: usize,
284        is_training: bool,
285    ) -> ExecutionStrategy {
286        let memory_pressure = (model_size_mb * batch_size) as f64 / available_memory_mb as f64;
287
288        if is_training {
289            if memory_pressure > 0.8 {
290                // High memory pressure: use checkpointing
291                ExecutionStrategy::training().with_gradient(GradientStrategy::Checkpointed {
292                    checkpoint_interval: 5,
293                })
294            } else if batch_size >= 64 {
295                // Large batch: use accumulation
296                ExecutionStrategy::training().with_gradient(GradientStrategy::Accumulated {
297                    accumulation_steps: 4,
298                })
299            } else {
300                ExecutionStrategy::training()
301            }
302        } else {
303            // Inference
304            if batch_size >= 32 {
305                ExecutionStrategy::high_throughput()
306            } else {
307                ExecutionStrategy::inference()
308            }
309        }
310    }
311
312    /// Estimate memory overhead for a strategy
313    pub fn estimate_memory_overhead(strategy: &ExecutionStrategy) -> f64 {
314        let mut overhead = 1.0;
315
316        // Execution mode overhead
317        overhead *= match strategy.mode {
318            ExecutionMode::Eager => 1.0,
319            ExecutionMode::Lazy => 1.2, // Graph storage
320            ExecutionMode::Hybrid => 1.1,
321        };
322
323        // Gradient overhead
324        overhead *= match strategy.gradient {
325            GradientStrategy::None => 1.0,
326            GradientStrategy::Full => 3.0, // Forward + backward + gradients
327            GradientStrategy::Checkpointed { .. } => 2.0, // Reduced memory
328            GradientStrategy::Accumulated { .. } => 3.5, // Extra gradient buffers
329        };
330
331        // Memory strategy adjustment
332        overhead *= match strategy.memory {
333            MemoryStrategy::Standard => 1.0,
334            MemoryStrategy::Pooled => 1.1,      // Pool overhead
335            MemoryStrategy::Cached => 1.3,      // Cache overhead
336            MemoryStrategy::MinimalPeak => 0.8, // Reduced peak
337        };
338
339        overhead
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    // Each preset should exhibit the behavior its name promises.
    #[test]
    fn test_execution_strategy_presets() {
        let training = ExecutionStrategy::training();
        assert!(training.computes_gradients());
        assert!(training.enable_fusion);

        let inference = ExecutionStrategy::inference();
        assert!(!inference.computes_gradients());
        assert!(inference.is_inference_mode());

        let memory_eff = ExecutionStrategy::memory_efficient();
        assert!(memory_eff.uses_checkpointing());

        let throughput = ExecutionStrategy::high_throughput();
        assert!(throughput.num_workers() > 1);

        let debug = ExecutionStrategy::debug();
        assert!(debug.enable_profiling);
    }

    // Builder methods should compose and record every setting.
    #[test]
    fn test_execution_strategy_builder() {
        let strategy = ExecutionStrategy::new()
            .with_mode(ExecutionMode::Lazy)
            .with_precision(PrecisionMode::Single)
            .enable_fusion()
            .enable_profiling();

        assert_eq!(strategy.mode, ExecutionMode::Lazy);
        assert_eq!(strategy.precision, PrecisionMode::Single);
        assert!(strategy.enable_fusion);
        assert!(strategy.enable_profiling);
    }

    // Gradient predicates and accessors for each GradientStrategy variant.
    #[test]
    fn test_gradient_strategies() {
        let no_grad = ExecutionStrategy::new().with_gradient(GradientStrategy::None);
        assert!(!no_grad.computes_gradients());

        let full_grad = ExecutionStrategy::new().with_gradient(GradientStrategy::Full);
        assert!(full_grad.computes_gradients());

        let checkpointed = ExecutionStrategy::new().with_gradient(GradientStrategy::Checkpointed {
            checkpoint_interval: 10,
        });
        assert!(checkpointed.uses_checkpointing());
        assert_eq!(checkpointed.checkpoint_interval(), Some(10));

        let accumulated = ExecutionStrategy::new().with_gradient(GradientStrategy::Accumulated {
            accumulation_steps: 4,
        });
        assert_eq!(accumulated.accumulation_steps(), Some(4));
    }

    // num_workers() should reflect the configured parallelism and always
    // be at least 1 (Automatic depends on the host's CPU count).
    #[test]
    fn test_parallelism_strategies() {
        let sequential = ExecutionStrategy::new().with_parallelism(ParallelismStrategy::None);
        assert_eq!(sequential.num_workers(), 1);

        let data_parallel = ExecutionStrategy::new()
            .with_parallelism(ParallelismStrategy::DataParallel { num_workers: 4 });
        assert_eq!(data_parallel.num_workers(), 4);

        let automatic = ExecutionStrategy::new().with_parallelism(ParallelismStrategy::Automatic);
        assert!(automatic.num_workers() >= 1);
    }

    // recommend() should adapt to memory pressure and training/inference.
    #[test]
    fn test_strategy_optimizer_recommendations() {
        // Low memory, training
        let strategy1 = StrategyOptimizer::recommend(32, 1000, 2000, true);
        assert!(strategy1.computes_gradients());

        // High memory pressure, training
        let strategy2 = StrategyOptimizer::recommend(64, 2000, 2000, true);
        assert!(strategy2.uses_checkpointing() || strategy2.accumulation_steps().is_some());

        // Inference, large batch
        let strategy3 = StrategyOptimizer::recommend(64, 500, 4000, false);
        assert!(!strategy3.computes_gradients());

        // Inference, small batch
        let strategy4 = StrategyOptimizer::recommend(8, 500, 4000, false);
        assert!(!strategy4.computes_gradients());
    }

    // Overhead estimates: baseline 1.0; training >> baseline;
    // memory-efficient < training.
    #[test]
    fn test_memory_overhead_estimation() {
        let eager_no_grad = ExecutionStrategy::new();
        let overhead1 = StrategyOptimizer::estimate_memory_overhead(&eager_no_grad);
        assert_eq!(overhead1, 1.0); // Baseline

        let training = ExecutionStrategy::training();
        let overhead2 = StrategyOptimizer::estimate_memory_overhead(&training);
        assert!(overhead2 > 2.0); // Should have significant overhead

        let memory_eff = ExecutionStrategy::memory_efficient();
        let overhead3 = StrategyOptimizer::estimate_memory_overhead(&memory_eff);
        assert!(overhead3 < overhead2); // Should be more efficient than full training
    }

    // ExecutionMode derives PartialEq/Eq; equality is by variant.
    #[test]
    fn test_execution_modes() {
        assert_eq!(ExecutionMode::Eager, ExecutionMode::Eager);
        assert_ne!(ExecutionMode::Eager, ExecutionMode::Lazy);
    }

    // with_precision() should round-trip every PrecisionMode variant.
    #[test]
    fn test_precision_modes() {
        let modes = vec![
            PrecisionMode::Full,
            PrecisionMode::Single,
            PrecisionMode::Mixed,
            PrecisionMode::Half,
        ];

        for mode in modes {
            let strategy = ExecutionStrategy::new().with_precision(mode);
            assert_eq!(strategy.precision, mode);
        }
    }

    // summary() should mention each major component by label.
    #[test]
    fn test_strategy_summary() {
        let strategy = ExecutionStrategy::training();
        let summary = strategy.summary();

        assert!(summary.contains("Execution Strategy"));
        assert!(summary.contains("Mode"));
        assert!(summary.contains("Gradient"));
        assert!(summary.contains("Precision"));
    }
}
479}