
tensorlogic_infer/autodiff.rs

//! Autodiff enhancements for training and optimization.
//!
//! This module extends the basic TlAutodiff trait with:
//! - Gradient accumulation strategies
//! - Custom gradient functions
//! - Gradient clipping and scaling

use std::collections::HashMap;

use tensorlogic_ir::EinsumGraph;

/// Strategy for accumulating gradients
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientAccumulationStrategy {
    /// Standard accumulation (sum gradients)
    Standard,
    /// Average gradients over accumulation steps
    Average,
    /// Gradient checkpointing to save memory
    Checkpointing,
    /// Mixed precision accumulation
    MixedPrecision,
}

/// Configuration for gradient accumulation
#[derive(Debug, Clone)]
pub struct AccumulationConfig {
    pub strategy: GradientAccumulationStrategy,
    pub accumulation_steps: usize,
    pub clear_after_step: bool,
}

impl AccumulationConfig {
    pub fn new(strategy: GradientAccumulationStrategy, steps: usize) -> Self {
        AccumulationConfig {
            strategy,
            accumulation_steps: steps,
            clear_after_step: true,
        }
    }

    pub fn standard(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::Standard, steps)
    }

    pub fn average(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::Average, steps)
    }

    pub fn checkpointing(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::Checkpointing, steps)
    }

    pub fn mixed_precision(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::MixedPrecision, steps)
    }
}

impl Default for AccumulationConfig {
    fn default() -> Self {
        Self::standard(1)
    }
}
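
// Hedged illustration of what the `Average` strategy is meant to achieve:
// dividing the summed gradients by `accumulation_steps`, so several small
// micro-batches behave like one larger batch. The helper below is
// hypothetical, operates on plain f64 values rather than executor tensors,
// and is not part of the public API.
#[allow(dead_code)]
fn average_accumulated_sketch(summed: &[f64], steps: usize) -> Vec<f64> {
    // Guard against a zero step count in this illustration.
    let steps = steps.max(1) as f64;
    summed.iter().map(|g| g / steps).collect()
}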

/// Gradient clipping strategy
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClippingStrategy {
    /// No clipping
    None,
    /// Clip by value (element-wise)
    ByValue { min: f64, max: f64 },
    /// Clip by global norm
    ByGlobalNorm { max_norm: f64 },
    /// Clip each layer's gradients by a per-layer norm
    ByLayerNorm { max_norm: f64 },
}

/// Gradient scaling configuration
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GradientScaling {
    pub enabled: bool,
    pub initial_scale: f64,
    pub growth_factor: f64,
    pub backoff_factor: f64,
    pub growth_interval: usize,
}

impl GradientScaling {
    pub fn new(initial_scale: f64) -> Self {
        GradientScaling {
            enabled: true,
            initial_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
        }
    }

    pub fn disabled() -> Self {
        GradientScaling {
            enabled: false,
            initial_scale: 1.0,
            growth_factor: 1.0,
            backoff_factor: 1.0,
            growth_interval: 0,
        }
    }
}

impl Default for GradientScaling {
    fn default() -> Self {
        Self::disabled()
    }
}

/// Complete gradient configuration
#[derive(Debug, Clone)]
pub struct GradientConfig {
    pub accumulation: AccumulationConfig,
    pub clipping: ClippingStrategy,
    pub scaling: GradientScaling,
}

impl GradientConfig {
    pub fn new() -> Self {
        GradientConfig {
            accumulation: AccumulationConfig::default(),
            clipping: ClippingStrategy::None,
            scaling: GradientScaling::default(),
        }
    }

    pub fn with_accumulation(mut self, config: AccumulationConfig) -> Self {
        self.accumulation = config;
        self
    }

    pub fn with_clipping(mut self, strategy: ClippingStrategy) -> Self {
        self.clipping = strategy;
        self
    }

    pub fn with_scaling(mut self, scaling: GradientScaling) -> Self {
        self.scaling = scaling;
        self
    }
}

impl Default for GradientConfig {
    fn default() -> Self {
        Self::new()
    }
}
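
// A minimal usage sketch of the builder above; the concrete values are
// illustrative rather than recommended defaults, and the function exists
// purely for documentation.
#[allow(dead_code)]
fn example_gradient_config() -> GradientConfig {
    GradientConfig::new()
        .with_accumulation(AccumulationConfig::average(4))
        .with_clipping(ClippingStrategy::ByGlobalNorm { max_norm: 1.0 })
        .with_scaling(GradientScaling::new(65536.0))
}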

/// Custom backward function for a tensor operation.
///
/// The callback receives the operation's output together with a slice of its
/// inputs and returns the gradients to propagate (or an error).
pub type BackwardFn<T, E> = Box<dyn Fn(&T, &[T]) -> Result<Vec<T>, E>>;

/// Registry for custom gradient functions
pub struct CustomGradientRegistry<T, E> {
    gradients: HashMap<String, BackwardFn<T, E>>,
}

impl<T, E> CustomGradientRegistry<T, E> {
    pub fn new() -> Self {
        CustomGradientRegistry {
            gradients: HashMap::new(),
        }
    }

    /// Register a custom backward function for an operation
    pub fn register<F>(&mut self, operation_name: String, backward_fn: F)
    where
        F: Fn(&T, &[T]) -> Result<Vec<T>, E> + 'static,
    {
        self.gradients.insert(operation_name, Box::new(backward_fn));
    }

    /// Get custom gradient function for an operation
    pub fn get(&self, operation_name: &str) -> Option<&BackwardFn<T, E>> {
        self.gradients.get(operation_name)
    }

    /// Check if custom gradient exists
    pub fn has_custom_gradient(&self, operation_name: &str) -> bool {
        self.gradients.contains_key(operation_name)
    }

    /// Remove custom gradient
    pub fn unregister(&mut self, operation_name: &str) -> bool {
        self.gradients.remove(operation_name).is_some()
    }

    /// Get number of registered gradients
    pub fn len(&self) -> usize {
        self.gradients.len()
    }

    pub fn is_empty(&self) -> bool {
        self.gradients.is_empty()
    }
}

impl<T, E> Default for CustomGradientRegistry<T, E> {
    fn default() -> Self {
        Self::new()
    }
}
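
// Sketch of how an executor might consult the registry during backward
// dispatch. `fallback_backward` is hypothetical and stands in for whatever
// built-in gradient rule the executor would otherwise apply; it is not part
// of the real executor interface.
#[allow(dead_code)]
fn dispatch_backward_sketch<T, E>(
    registry: &CustomGradientRegistry<T, E>,
    operation_name: &str,
    output: &T,
    inputs: &[T],
    fallback_backward: impl Fn(&T, &[T]) -> Result<Vec<T>, E>,
) -> Result<Vec<T>, E> {
    match registry.get(operation_name) {
        // A registered custom gradient takes priority over the default rule.
        Some(custom) => custom(output, inputs),
        None => fallback_backward(output, inputs),
    }
}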

/// Gradient statistics for monitoring
#[derive(Debug, Clone)]
pub struct GradientStats {
    pub global_norm: f64,
    pub min_value: f64,
    pub max_value: f64,
    pub mean_value: f64,
    pub num_parameters: usize,
    pub num_finite: usize,
    pub num_infinite: usize,
    pub num_nan: usize,
}

impl GradientStats {
    pub fn new() -> Self {
        GradientStats {
            global_norm: 0.0,
            min_value: f64::INFINITY,
            max_value: f64::NEG_INFINITY,
            mean_value: 0.0,
            num_parameters: 0,
            num_finite: 0,
            num_infinite: 0,
            num_nan: 0,
        }
    }

    pub fn has_nan(&self) -> bool {
        self.num_nan > 0
    }

    pub fn has_inf(&self) -> bool {
        self.num_infinite > 0
    }

    pub fn is_healthy(&self) -> bool {
        !self.has_nan() && !self.has_inf()
    }

    pub fn finite_ratio(&self) -> f64 {
        if self.num_parameters == 0 {
            return 0.0;
        }
        (self.num_finite as f64) / (self.num_parameters as f64)
    }
}

impl Default for GradientStats {
    fn default() -> Self {
        Self::new()
    }
}
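
// Illustrative helper that fills a `GradientStats` from a flat f64 buffer,
// included only to pin down what each field is expected to mean; real
// executors would compute these statistics on their own tensor types.
#[allow(dead_code)]
fn stats_from_slice_sketch(values: &[f64]) -> GradientStats {
    let mut stats = GradientStats::new();
    stats.num_parameters = values.len();
    let mut sum = 0.0;
    let mut squared_sum = 0.0;
    for &v in values {
        if v.is_nan() {
            stats.num_nan += 1;
        } else if v.is_infinite() {
            stats.num_infinite += 1;
        } else {
            stats.num_finite += 1;
            stats.min_value = stats.min_value.min(v);
            stats.max_value = stats.max_value.max(v);
            sum += v;
            squared_sum += v * v;
        }
    }
    if stats.num_finite > 0 {
        stats.mean_value = sum / stats.num_finite as f64;
    }
    // Global L2 norm over the finite values.
    stats.global_norm = squared_sum.sqrt();
    stats
}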

/// Trait for executors with enhanced autodiff capabilities
pub trait TlEnhancedAutodiff {
    type Tensor;
    type Tape;
    type Error;

    /// Execute forward pass with gradient accumulation
    fn forward_with_accumulation(
        &mut self,
        graph: &EinsumGraph,
        config: &AccumulationConfig,
    ) -> Result<Self::Tensor, Self::Error>;

    /// Execute backward pass with gradient clipping
    fn backward_with_clipping(
        &mut self,
        graph: &EinsumGraph,
        loss: &Self::Tensor,
        strategy: ClippingStrategy,
    ) -> Result<Self::Tape, Self::Error>;

    /// Apply gradient scaling
    fn scale_gradients(
        &mut self,
        gradients: &mut Self::Tape,
        scaling: &GradientScaling,
    ) -> Result<(), Self::Error>;

    /// Compute gradient statistics
    fn gradient_stats(&self, gradients: &Self::Tape) -> Result<GradientStats, Self::Error>;

    /// Register custom gradient function
    fn register_custom_gradient(
        &mut self,
        operation_name: String,
        backward_fn: BackwardFn<Self::Tensor, Self::Error>,
    );

    /// Check if custom gradient exists
    fn has_custom_gradient(&self, operation_name: &str) -> bool;
}
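
// Sketch of the call order the trait is intended to support: forward with
// accumulation, backward with clipping, gradient scaling, then statistics.
// This free function is illustrative only; it is generic over any implementor
// and adds no behavior of its own.
#[allow(dead_code)]
fn training_step_sketch<A: TlEnhancedAutodiff>(
    executor: &mut A,
    graph: &EinsumGraph,
    config: &GradientConfig,
) -> Result<GradientStats, A::Error> {
    let loss = executor.forward_with_accumulation(graph, &config.accumulation)?;
    let mut tape = executor.backward_with_clipping(graph, &loss, config.clipping)?;
    executor.scale_gradients(&mut tape, &config.scaling)?;
    executor.gradient_stats(&tape)
}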

/// Gradient accumulator for managing accumulated gradients
pub struct GradientAccumulator<T> {
    accumulated_gradients: Vec<T>,
    accumulation_count: usize,
    config: AccumulationConfig,
}

impl<T: Clone> GradientAccumulator<T> {
    pub fn new(config: AccumulationConfig) -> Self {
        GradientAccumulator {
            accumulated_gradients: Vec::new(),
            accumulation_count: 0,
            config,
        }
    }

    /// Add gradients to accumulator
    pub fn accumulate(&mut self, gradients: Vec<T>) {
        if self.accumulated_gradients.is_empty() {
            self.accumulated_gradients = gradients;
        } else {
            // Placeholder: a real executor would combine the new gradients with
            // the accumulated ones element-wise (summing for `Standard`, then
            // dividing by the step count for `Average`); here the most recent
            // gradients simply replace the previous ones.
            self.accumulated_gradients = gradients;
        }
        self.accumulation_count += 1;
    }

    /// Check if ready to step (accumulated enough)
    pub fn is_ready(&self) -> bool {
        self.accumulation_count >= self.config.accumulation_steps
    }

    /// Get accumulated gradients and optionally reset
    pub fn step(&mut self) -> Vec<T> {
        let gradients = self.accumulated_gradients.clone();

        if self.config.clear_after_step {
            self.clear();
        }

        gradients
    }

    /// Clear accumulated gradients
    pub fn clear(&mut self) {
        self.accumulated_gradients.clear();
        self.accumulation_count = 0;
    }

    /// Get current accumulation count
    pub fn count(&self) -> usize {
        self.accumulation_count
    }

    pub fn config(&self) -> &AccumulationConfig {
        &self.config
    }
}

/// Gradient clipper for applying clipping strategies
pub struct GradientClipper {
    strategy: ClippingStrategy,
    num_clips: usize,
}

impl GradientClipper {
    pub fn new(strategy: ClippingStrategy) -> Self {
        GradientClipper {
            strategy,
            num_clips: 0,
        }
    }

    /// Check if gradient value should be clipped
    pub fn should_clip(&self, value: f64) -> bool {
        match self.strategy {
            ClippingStrategy::None => false,
            ClippingStrategy::ByValue { min, max } => value < min || value > max,
            ClippingStrategy::ByGlobalNorm { max_norm: _ } => {
                // Would need full gradient to compute global norm
                false
            }
            ClippingStrategy::ByLayerNorm { max_norm: _ } => {
                // Would need layer gradients
                false
            }
        }
    }

    /// Clip a single gradient value
    pub fn clip_value(&mut self, value: f64) -> f64 {
        match self.strategy {
            ClippingStrategy::None => value,
            ClippingStrategy::ByValue { min, max } => {
                if value < min || value > max {
                    self.num_clips += 1;
                }
                value.clamp(min, max)
            }
            ClippingStrategy::ByGlobalNorm { max_norm: _ } => value,
            ClippingStrategy::ByLayerNorm { max_norm: _ } => value,
        }
    }

    /// Get number of clipped values
    pub fn num_clips(&self) -> usize {
        self.num_clips
    }

    /// Reset clip counter
    pub fn reset(&mut self) {
        self.num_clips = 0;
    }

    pub fn strategy(&self) -> ClippingStrategy {
        self.strategy
    }
}
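
// `ByGlobalNorm` cannot be applied one value at a time, so here is a minimal
// sketch of the whole-buffer version on a flat f64 slice. The helper is
// hypothetical and not part of the clipper's API; real backends would rescale
// tensors on device instead.
#[allow(dead_code)]
fn clip_by_global_norm_sketch(gradients: &mut [f64], max_norm: f64) {
    // Global L2 norm across every gradient element.
    let global_norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
    if global_norm > max_norm && global_norm > 0.0 {
        // Rescale all gradients so the combined norm equals `max_norm`.
        let scale = max_norm / global_norm;
        for g in gradients.iter_mut() {
            *g *= scale;
        }
    }
}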

/// Gradient scaler for mixed precision training
pub struct GradientScaler {
    config: GradientScaling,
    current_scale: f64,
    growth_tracker: usize,
}

impl GradientScaler {
    pub fn new(config: GradientScaling) -> Self {
        let current_scale = config.initial_scale;
        GradientScaler {
            config,
            current_scale,
            growth_tracker: 0,
        }
    }

    /// Scale gradients up
    pub fn scale(&self, value: f64) -> f64 {
        if !self.config.enabled {
            return value;
        }
        value * self.current_scale
    }

    /// Unscale gradients (for optimizer step)
    pub fn unscale(&self, value: f64) -> f64 {
        if !self.config.enabled {
            return value;
        }
        value / self.current_scale
    }

    /// Update scale based on gradient health
    pub fn update(&mut self, gradients_healthy: bool) {
        if !self.config.enabled {
            return;
        }

        if gradients_healthy {
            self.growth_tracker += 1;
            if self.growth_tracker >= self.config.growth_interval {
                self.current_scale *= self.config.growth_factor;
                self.growth_tracker = 0;
            }
        } else {
            // Backoff on unhealthy gradients
            self.current_scale *= self.config.backoff_factor;
            self.growth_tracker = 0;
        }
    }

    /// Get current scale factor
    pub fn get_scale(&self) -> f64 {
        self.current_scale
    }

    pub fn config(&self) -> &GradientScaling {
        &self.config
    }
}
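
// Sketch of the per-step loss-scaling flow the scaler is meant to drive:
// gradients produced under a scaled loss are unscaled before the optimizer
// step, the step is skipped when non-finite values appear, and the scaler
// adjusts its factor from that health signal. Plain f64 values stand in for
// tensors here, purely for illustration.
#[allow(dead_code)]
fn loss_scaling_step_sketch(
    scaler: &mut GradientScaler,
    scaled_gradients: &[f64],
) -> Option<Vec<f64>> {
    // Undo the loss scale so the optimizer sees true gradient magnitudes.
    let unscaled: Vec<f64> = scaled_gradients.iter().map(|g| scaler.unscale(*g)).collect();
    let healthy = unscaled.iter().all(|g| g.is_finite());
    // Grow the scale after a run of healthy steps, back off otherwise.
    scaler.update(healthy);
    if healthy {
        Some(unscaled)
    } else {
        // Caller should skip this optimizer step.
        None
    }
}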

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulation_config() {
        let config = AccumulationConfig::standard(4);
        assert_eq!(config.strategy, GradientAccumulationStrategy::Standard);
        assert_eq!(config.accumulation_steps, 4);
        assert!(config.clear_after_step);
    }

    #[test]
    fn test_clipping_strategy() {
        let none = ClippingStrategy::None;
        let by_value = ClippingStrategy::ByValue {
            min: -1.0,
            max: 1.0,
        };
        let by_norm = ClippingStrategy::ByGlobalNorm { max_norm: 1.0 };

        assert_eq!(none, ClippingStrategy::None);
        assert_ne!(by_value, none);
        assert_ne!(by_norm, by_value);
    }

    #[test]
    fn test_gradient_config() {
        let config = GradientConfig::new()
            .with_accumulation(AccumulationConfig::average(4))
            .with_clipping(ClippingStrategy::ByValue {
                min: -1.0,
                max: 1.0,
            });

        assert_eq!(
            config.accumulation.strategy,
            GradientAccumulationStrategy::Average
        );
        assert_eq!(config.accumulation.accumulation_steps, 4);
    }

    #[test]
    fn test_gradient_scaling() {
        let scaling = GradientScaling::new(1024.0);
        assert!(scaling.enabled);
        assert_eq!(scaling.initial_scale, 1024.0);
        assert_eq!(scaling.growth_factor, 2.0);

        let disabled = GradientScaling::disabled();
        assert!(!disabled.enabled);
    }

    #[test]
    fn test_gradient_stats() {
        let mut stats = GradientStats::new();
        stats.num_parameters = 100;
        stats.num_finite = 95;
        stats.num_nan = 5;
        stats.num_infinite = 0;

        assert!(stats.has_nan());
        assert!(!stats.has_inf());
        assert!(!stats.is_healthy());
        assert_eq!(stats.finite_ratio(), 0.95);
    }

    #[test]
    fn test_custom_gradient_registry() {
        let mut registry: CustomGradientRegistry<f64, String> = CustomGradientRegistry::new();

        registry.register("custom_op".to_string(), |_output, _inputs| {
            Ok(vec![1.0, 2.0, 3.0])
        });

        assert!(registry.has_custom_gradient("custom_op"));
        assert!(!registry.has_custom_gradient("other_op"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let removed = registry.unregister("custom_op");
        assert!(removed);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_gradient_accumulator() {
        let config = AccumulationConfig::standard(3);
        let mut accumulator: GradientAccumulator<f64> = GradientAccumulator::new(config);

        assert_eq!(accumulator.count(), 0);
        assert!(!accumulator.is_ready());

        accumulator.accumulate(vec![1.0, 2.0, 3.0]);
        assert_eq!(accumulator.count(), 1);
        assert!(!accumulator.is_ready());

        accumulator.accumulate(vec![4.0, 5.0, 6.0]);
        accumulator.accumulate(vec![7.0, 8.0, 9.0]);
        assert!(accumulator.is_ready());

        let _gradients = accumulator.step();
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_gradient_clipper() {
        let mut clipper = GradientClipper::new(ClippingStrategy::ByValue {
            min: -1.0,
            max: 1.0,
        });

        assert!(!clipper.should_clip(0.5));
        assert!(clipper.should_clip(2.0));
        assert!(clipper.should_clip(-2.0));

        let clipped = clipper.clip_value(2.0);
        assert_eq!(clipped, 1.0);
        assert_eq!(clipper.num_clips(), 1);

        let clipped = clipper.clip_value(-2.0);
        assert_eq!(clipped, -1.0);
        assert_eq!(clipper.num_clips(), 2);

        clipper.reset();
        assert_eq!(clipper.num_clips(), 0);
    }

    #[test]
    fn test_gradient_scaler() {
        let config = GradientScaling::new(1024.0);
        let mut scaler = GradientScaler::new(config);

        assert_eq!(scaler.get_scale(), 1024.0);

        let scaled = scaler.scale(2.0);
        assert_eq!(scaled, 2048.0);

        let unscaled = scaler.unscale(2048.0);
        assert_eq!(unscaled, 2.0);

        // Test growth
        scaler.growth_tracker = config.growth_interval - 1;
        scaler.update(true);
        assert_eq!(scaler.get_scale(), 2048.0); // Grew by factor of 2

        // Test backoff
        scaler.update(false);
        assert_eq!(scaler.get_scale(), 1024.0); // Backed off by factor of 0.5
    }

    #[test]
    fn test_gradient_scaler_disabled() {
        let config = GradientScaling::disabled();
        let scaler = GradientScaler::new(config);

        assert_eq!(scaler.scale(2.0), 2.0);
        assert_eq!(scaler.unscale(2.0), 2.0);
    }
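
    // Illustrative end-to-end check added for documentation purposes: it wires
    // the configuration, accumulator, clipper, and scaler together on plain
    // f64 values and exercises only behavior already covered above.
    #[test]
    fn test_components_work_together() {
        let config = GradientConfig::new()
            .with_accumulation(AccumulationConfig::standard(1))
            .with_clipping(ClippingStrategy::ByValue { min: -1.0, max: 1.0 })
            .with_scaling(GradientScaling::new(8.0));

        let mut accumulator: GradientAccumulator<f64> =
            GradientAccumulator::new(config.accumulation.clone());
        let mut clipper = GradientClipper::new(config.clipping);
        let scaler = GradientScaler::new(config.scaling);

        accumulator.accumulate(vec![0.25, -3.0]);
        assert!(accumulator.is_ready());

        let clipped: Vec<f64> = accumulator
            .step()
            .into_iter()
            .map(|g| clipper.clip_value(g))
            .collect();

        assert_eq!(clipped, vec![0.25, -1.0]);
        assert_eq!(clipper.num_clips(), 1);
        assert_eq!(scaler.scale(clipped[0]), 2.0);
    }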
}