use std::collections::HashMap;

use tensorlogic_ir::EinsumGraph;

/// Strategy for accumulating gradients across multiple micro-batches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientAccumulationStrategy {
    /// Sum gradients over the accumulation window.
    Standard,
    /// Average gradients over the accumulation window.
    Average,
    /// Recompute activations during the backward pass to save memory.
    Checkpointing,
    /// Accumulate in full precision while computing in reduced precision.
    MixedPrecision,
}

/// Configuration for gradient accumulation.
#[derive(Debug, Clone)]
pub struct AccumulationConfig {
    pub strategy: GradientAccumulationStrategy,
    pub accumulation_steps: usize,
    pub clear_after_step: bool,
}

impl AccumulationConfig {
    pub fn new(strategy: GradientAccumulationStrategy, steps: usize) -> Self {
        AccumulationConfig {
            strategy,
            accumulation_steps: steps,
            clear_after_step: true,
        }
    }

    pub fn standard(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::Standard, steps)
    }

    pub fn average(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::Average, steps)
    }

    pub fn checkpointing(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::Checkpointing, steps)
    }

    pub fn mixed_precision(steps: usize) -> Self {
        Self::new(GradientAccumulationStrategy::MixedPrecision, steps)
    }
}

impl Default for AccumulationConfig {
    fn default() -> Self {
        Self::standard(1)
    }
}

/// How gradients are clipped before an optimizer step.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClippingStrategy {
    /// No clipping.
    None,
    /// Clamp each gradient value into `[min, max]`.
    ByValue { min: f64, max: f64 },
    /// Rescale all gradients so their global L2 norm is at most `max_norm`.
    ByGlobalNorm { max_norm: f64 },
    /// Rescale each layer's gradients so its L2 norm is at most `max_norm`.
    ByLayerNorm { max_norm: f64 },
}

/// Dynamic loss-scaling configuration for mixed-precision training.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GradientScaling {
    pub enabled: bool,
    pub initial_scale: f64,
    pub growth_factor: f64,
    pub backoff_factor: f64,
    pub growth_interval: usize,
}

impl GradientScaling {
    pub fn new(initial_scale: f64) -> Self {
        GradientScaling {
            enabled: true,
            initial_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
        }
    }

    pub fn disabled() -> Self {
        GradientScaling {
            enabled: false,
            initial_scale: 1.0,
            growth_factor: 1.0,
            backoff_factor: 1.0,
            growth_interval: 0,
        }
    }
}

impl Default for GradientScaling {
    fn default() -> Self {
        Self::disabled()
    }
}

/// Combined gradient-handling configuration: accumulation, clipping, and scaling.
#[derive(Debug, Clone)]
pub struct GradientConfig {
    pub accumulation: AccumulationConfig,
    pub clipping: ClippingStrategy,
    pub scaling: GradientScaling,
}

impl GradientConfig {
    pub fn new() -> Self {
        GradientConfig {
            accumulation: AccumulationConfig::default(),
            clipping: ClippingStrategy::None,
            scaling: GradientScaling::default(),
        }
    }

    pub fn with_accumulation(mut self, config: AccumulationConfig) -> Self {
        self.accumulation = config;
        self
    }

    pub fn with_clipping(mut self, strategy: ClippingStrategy) -> Self {
        self.clipping = strategy;
        self
    }

    pub fn with_scaling(mut self, scaling: GradientScaling) -> Self {
        self.scaling = scaling;
        self
    }
}

impl Default for GradientConfig {
    fn default() -> Self {
        Self::new()
    }
}
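
// A minimal sketch (illustrative values, not recommended defaults) showing how
// the builder methods above compose a full gradient-handling configuration.
#[allow(dead_code)]
fn example_gradient_config() -> GradientConfig {
    GradientConfig::new()
        .with_accumulation(AccumulationConfig::average(8))
        .with_clipping(ClippingStrategy::ByGlobalNorm { max_norm: 1.0 })
        .with_scaling(GradientScaling::new(65536.0))
}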

/// A custom backward function: given an operation's output and its inputs,
/// produce one gradient per input.
pub type BackwardFn<T, E> = Box<dyn Fn(&T, &[T]) -> Result<Vec<T>, E>>;

/// Registry mapping operation names to user-supplied backward functions.
pub struct CustomGradientRegistry<T, E> {
    gradients: HashMap<String, BackwardFn<T, E>>,
}

impl<T, E> CustomGradientRegistry<T, E> {
    pub fn new() -> Self {
        CustomGradientRegistry {
            gradients: HashMap::new(),
        }
    }

    pub fn register<F>(&mut self, operation_name: String, backward_fn: F)
    where
        F: Fn(&T, &[T]) -> Result<Vec<T>, E> + 'static,
    {
        self.gradients.insert(operation_name, Box::new(backward_fn));
    }

    pub fn get(&self, operation_name: &str) -> Option<&BackwardFn<T, E>> {
        self.gradients.get(operation_name)
    }

    pub fn has_custom_gradient(&self, operation_name: &str) -> bool {
        self.gradients.contains_key(operation_name)
    }

    pub fn unregister(&mut self, operation_name: &str) -> bool {
        self.gradients.remove(operation_name).is_some()
    }

    pub fn len(&self) -> usize {
        self.gradients.len()
    }

    pub fn is_empty(&self) -> bool {
        self.gradients.is_empty()
    }
}

impl<T, E> Default for CustomGradientRegistry<T, E> {
    fn default() -> Self {
        Self::new()
    }
}
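
// A minimal sketch of registering a hand-written backward rule for a scalar
// "square" operation (the gradient of x^2 is 2x). The operation name and the
// `String` error type are illustrative, not part of a fixed vocabulary.
#[allow(dead_code)]
fn example_register_square_gradient(registry: &mut CustomGradientRegistry<f64, String>) {
    registry.register("square".to_string(), |_output, inputs| {
        // One gradient per input: d(x^2)/dx = 2x.
        Ok(inputs.iter().map(|x| 2.0 * x).collect())
    });
}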

/// Summary statistics over a set of gradients, used for health checks.
#[derive(Debug, Clone)]
pub struct GradientStats {
    pub global_norm: f64,
    pub min_value: f64,
    pub max_value: f64,
    pub mean_value: f64,
    pub num_parameters: usize,
    pub num_finite: usize,
    pub num_infinite: usize,
    pub num_nan: usize,
}

impl GradientStats {
    pub fn new() -> Self {
        GradientStats {
            global_norm: 0.0,
            min_value: f64::INFINITY,
            max_value: f64::NEG_INFINITY,
            mean_value: 0.0,
            num_parameters: 0,
            num_finite: 0,
            num_infinite: 0,
            num_nan: 0,
        }
    }

    pub fn has_nan(&self) -> bool {
        self.num_nan > 0
    }

    pub fn has_inf(&self) -> bool {
        self.num_infinite > 0
    }

    pub fn is_healthy(&self) -> bool {
        !self.has_nan() && !self.has_inf()
    }

    pub fn finite_ratio(&self) -> f64 {
        if self.num_parameters == 0 {
            return 0.0;
        }
        (self.num_finite as f64) / (self.num_parameters as f64)
    }
}

impl Default for GradientStats {
    fn default() -> Self {
        Self::new()
    }
}
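
// A minimal sketch of how a backend might populate `GradientStats` from a flat
// slice of gradient values; a real engine would compute these over its own
// tensor type, so this helper is illustrative only.
#[allow(dead_code)]
fn stats_from_slice(gradients: &[f64]) -> GradientStats {
    let mut stats = GradientStats::new();
    stats.num_parameters = gradients.len();
    let mut sum = 0.0;
    let mut sum_sq = 0.0;
    for &g in gradients {
        if g.is_nan() {
            stats.num_nan += 1;
        } else if g.is_infinite() {
            stats.num_infinite += 1;
        } else {
            stats.num_finite += 1;
            stats.min_value = stats.min_value.min(g);
            stats.max_value = stats.max_value.max(g);
            sum += g;
            sum_sq += g * g;
        }
    }
    if stats.num_finite > 0 {
        stats.mean_value = sum / stats.num_finite as f64;
        stats.global_norm = sum_sq.sqrt();
    }
    stats
}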

/// Extended autodiff interface: accumulation-aware forward, clipping-aware
/// backward, loss scaling, gradient statistics, and custom gradients.
pub trait TlEnhancedAutodiff {
    type Tensor;
    type Tape;
    type Error;

    /// Run the forward pass under the given accumulation configuration.
    fn forward_with_accumulation(
        &mut self,
        graph: &EinsumGraph,
        config: &AccumulationConfig,
    ) -> Result<Self::Tensor, Self::Error>;

    /// Run the backward pass, clipping gradients with the given strategy.
    fn backward_with_clipping(
        &mut self,
        graph: &EinsumGraph,
        loss: &Self::Tensor,
        strategy: ClippingStrategy,
    ) -> Result<Self::Tape, Self::Error>;

    /// Apply loss scaling to the gradients in place.
    fn scale_gradients(
        &mut self,
        gradients: &mut Self::Tape,
        scaling: &GradientScaling,
    ) -> Result<(), Self::Error>;

    /// Compute summary statistics over the gradients.
    fn gradient_stats(&self, gradients: &Self::Tape) -> Result<GradientStats, Self::Error>;

    /// Register a custom backward function for the named operation.
    fn register_custom_gradient(
        &mut self,
        operation_name: String,
        backward_fn: BackwardFn<Self::Tensor, Self::Error>,
    );

    /// Whether a custom backward function is registered for the named operation.
    fn has_custom_gradient(&self, operation_name: &str) -> bool;
}
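
// A minimal sketch (not part of the original API) of a backend implementing
// `TlEnhancedAutodiff`, wiring the helper types in this module together. The
// scalar tensor, `Vec<f64>` tape, placeholder forward value, and placeholder
// gradients are all illustrative; a real backend would evaluate the
// `EinsumGraph`.
#[allow(dead_code)]
struct ScalarBackend {
    registry: CustomGradientRegistry<f64, String>,
}

impl TlEnhancedAutodiff for ScalarBackend {
    type Tensor = f64;
    type Tape = Vec<f64>;
    type Error = String;

    fn forward_with_accumulation(
        &mut self,
        _graph: &EinsumGraph,
        _config: &AccumulationConfig,
    ) -> Result<Self::Tensor, Self::Error> {
        // Placeholder loss; a real backend would evaluate the graph here.
        Ok(0.0)
    }

    fn backward_with_clipping(
        &mut self,
        _graph: &EinsumGraph,
        _loss: &Self::Tensor,
        strategy: ClippingStrategy,
    ) -> Result<Self::Tape, Self::Error> {
        // Placeholder gradients, clipped element-wise where the strategy allows.
        let mut clipper = GradientClipper::new(strategy);
        Ok(vec![0.5, -2.0, 1.5]
            .into_iter()
            .map(|g| clipper.clip_value(g))
            .collect())
    }

    fn scale_gradients(
        &mut self,
        gradients: &mut Self::Tape,
        scaling: &GradientScaling,
    ) -> Result<(), Self::Error> {
        let scaler = GradientScaler::new(*scaling);
        for g in gradients.iter_mut() {
            *g = scaler.scale(*g);
        }
        Ok(())
    }

    fn gradient_stats(&self, gradients: &Self::Tape) -> Result<GradientStats, Self::Error> {
        // Reuses the `stats_from_slice` sketch defined above.
        Ok(stats_from_slice(gradients))
    }

    fn register_custom_gradient(
        &mut self,
        operation_name: String,
        backward_fn: BackwardFn<Self::Tensor, Self::Error>,
    ) {
        // A boxed closure itself implements `Fn`, so it can be re-registered
        // directly (at the cost of double boxing, which is fine for a sketch).
        self.registry.register(operation_name, backward_fn);
    }

    fn has_custom_gradient(&self, operation_name: &str) -> bool {
        self.registry.has_custom_gradient(operation_name)
    }
}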

/// Accumulates gradients over multiple micro-batches before an optimizer step.
pub struct GradientAccumulator<T> {
    accumulated_gradients: Vec<T>,
    accumulation_count: usize,
    config: AccumulationConfig,
}

impl<T: Clone + std::ops::Add<Output = T>> GradientAccumulator<T> {
    pub fn new(config: AccumulationConfig) -> Self {
        GradientAccumulator {
            accumulated_gradients: Vec::new(),
            accumulation_count: 0,
            config,
        }
    }

    /// Add one micro-batch of gradients into the running accumulation.
    pub fn accumulate(&mut self, gradients: Vec<T>) {
        if self.accumulated_gradients.is_empty() {
            self.accumulated_gradients = gradients;
        } else {
            // Element-wise sum of the new gradients into the running total.
            self.accumulated_gradients = self
                .accumulated_gradients
                .iter()
                .cloned()
                .zip(gradients)
                .map(|(acc, g)| acc + g)
                .collect();
        }
        self.accumulation_count += 1;
    }

    /// Whether enough micro-batches have been accumulated for a step.
    pub fn is_ready(&self) -> bool {
        self.accumulation_count >= self.config.accumulation_steps
    }

    /// Return the accumulated gradients, clearing state if configured to.
    /// Any averaging implied by the configured strategy is left to the caller.
    pub fn step(&mut self) -> Vec<T> {
        let gradients = self.accumulated_gradients.clone();

        if self.config.clear_after_step {
            self.clear();
        }

        gradients
    }

    pub fn clear(&mut self) {
        self.accumulated_gradients.clear();
        self.accumulation_count = 0;
    }

    pub fn count(&self) -> usize {
        self.accumulation_count
    }

    pub fn config(&self) -> &AccumulationConfig {
        &self.config
    }
}
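
// A minimal sketch of the intended accumulation loop: accumulate micro-batch
// gradients until the configured step count is reached, then take one
// optimizer step. The gradient values are illustrative.
#[allow(dead_code)]
fn example_accumulation_loop() {
    let mut accumulator = GradientAccumulator::new(AccumulationConfig::standard(2));

    accumulator.accumulate(vec![1.0, 2.0]);
    accumulator.accumulate(vec![3.0, 4.0]);

    if accumulator.is_ready() {
        // With the Standard strategy this is the element-wise sum: [4.0, 6.0].
        let summed = accumulator.step();
        assert_eq!(summed, vec![4.0, 6.0]);
    }
}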

/// Applies a `ClippingStrategy` to gradient values and counts how often it fires.
pub struct GradientClipper {
    strategy: ClippingStrategy,
    num_clips: usize,
}

impl GradientClipper {
    pub fn new(strategy: ClippingStrategy) -> Self {
        GradientClipper {
            strategy,
            num_clips: 0,
        }
    }

    /// Whether a single value would be clipped. Norm-based strategies cannot
    /// be decided from one value in isolation, so they report `false` here.
    pub fn should_clip(&self, value: f64) -> bool {
        match self.strategy {
            ClippingStrategy::None => false,
            ClippingStrategy::ByValue { min, max } => value < min || value > max,
            ClippingStrategy::ByGlobalNorm { max_norm: _ } => {
                // Deciding norm-based clipping requires the full gradient set.
                false
            }
            ClippingStrategy::ByLayerNorm { max_norm: _ } => {
                // Deciding norm-based clipping requires the layer's gradients.
                false
            }
        }
    }

    /// Clip a single value. Norm-based strategies pass values through
    /// unchanged; see the global-norm sketch below for how they apply.
    pub fn clip_value(&mut self, value: f64) -> f64 {
        match self.strategy {
            ClippingStrategy::None => value,
            ClippingStrategy::ByValue { min, max } => {
                if value < min || value > max {
                    self.num_clips += 1;
                }
                value.clamp(min, max)
            }
            ClippingStrategy::ByGlobalNorm { max_norm: _ } => value,
            ClippingStrategy::ByLayerNorm { max_norm: _ } => value,
        }
    }

    pub fn num_clips(&self) -> usize {
        self.num_clips
    }

    pub fn reset(&mut self) {
        self.num_clips = 0;
    }

    pub fn strategy(&self) -> ClippingStrategy {
        self.strategy
    }
}
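
// A minimal sketch of what `ByGlobalNorm` implies over a whole gradient slice,
// which the per-value `GradientClipper` above cannot perform on scalars in
// isolation: if the global L2 norm exceeds `max_norm`, rescale every element
// by `max_norm / norm`.
#[allow(dead_code)]
fn clip_by_global_norm(gradients: &mut [f64], max_norm: f64) {
    let norm = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in gradients.iter_mut() {
            *g *= scale;
        }
    }
}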

/// Dynamic loss scaler: grows the scale while gradients stay finite and backs
/// off when they overflow.
pub struct GradientScaler {
    config: GradientScaling,
    current_scale: f64,
    growth_tracker: usize,
}

impl GradientScaler {
    pub fn new(config: GradientScaling) -> Self {
        let current_scale = config.initial_scale;
        GradientScaler {
            config,
            current_scale,
            growth_tracker: 0,
        }
    }

    /// Multiply a value (typically the loss) by the current scale.
    pub fn scale(&self, value: f64) -> f64 {
        if !self.config.enabled {
            return value;
        }
        value * self.current_scale
    }

    /// Divide a value (typically a gradient) by the current scale.
    pub fn unscale(&self, value: f64) -> f64 {
        if !self.config.enabled {
            return value;
        }
        value / self.current_scale
    }

    /// Update the scale after a step: grow after `growth_interval` healthy
    /// steps, back off immediately on overflow.
    pub fn update(&mut self, gradients_healthy: bool) {
        if !self.config.enabled {
            return;
        }

        if gradients_healthy {
            self.growth_tracker += 1;
            if self.growth_tracker >= self.config.growth_interval {
                self.current_scale *= self.config.growth_factor;
                self.growth_tracker = 0;
            }
        } else {
            // Unhealthy gradients: back off the scale and restart the tracker.
            self.current_scale *= self.config.backoff_factor;
            self.growth_tracker = 0;
        }
    }

    pub fn get_scale(&self) -> f64 {
        self.current_scale
    }

    pub fn config(&self) -> &GradientScaling {
        &self.config
    }
}
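
// A minimal sketch of one mixed-precision training step using the scaler and
// the `stats_from_slice` helper above: scale the loss, unscale the resulting
// gradients, skip the optimizer step if any gradient is non-finite, and update
// the scale. The `loss` and `gradients` arguments are illustrative placeholders.
#[allow(dead_code)]
fn example_scaled_step(scaler: &mut GradientScaler, loss: f64, gradients: &mut [f64]) {
    let _scaled_loss = scaler.scale(loss);
    // Backward would run on the scaled loss; here we just unscale in place.
    for g in gradients.iter_mut() {
        *g = scaler.unscale(*g);
    }
    let healthy = stats_from_slice(gradients).is_healthy();
    if healthy {
        // Apply the optimizer step with the unscaled gradients here.
    }
    scaler.update(healthy);
}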

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulation_config() {
        let config = AccumulationConfig::standard(4);
        assert_eq!(config.strategy, GradientAccumulationStrategy::Standard);
        assert_eq!(config.accumulation_steps, 4);
        assert!(config.clear_after_step);
    }

    #[test]
    fn test_clipping_strategy() {
        let none = ClippingStrategy::None;
        let by_value = ClippingStrategy::ByValue {
            min: -1.0,
            max: 1.0,
        };
        let by_norm = ClippingStrategy::ByGlobalNorm { max_norm: 1.0 };

        assert_eq!(none, ClippingStrategy::None);
        assert_ne!(by_value, none);
        assert_ne!(by_norm, by_value);
    }

    #[test]
    fn test_gradient_config() {
        let config = GradientConfig::new()
            .with_accumulation(AccumulationConfig::average(4))
            .with_clipping(ClippingStrategy::ByValue {
                min: -1.0,
                max: 1.0,
            });

        assert_eq!(
            config.accumulation.strategy,
            GradientAccumulationStrategy::Average
        );
        assert_eq!(config.accumulation.accumulation_steps, 4);
    }

    #[test]
    fn test_gradient_scaling() {
        let scaling = GradientScaling::new(1024.0);
        assert!(scaling.enabled);
        assert_eq!(scaling.initial_scale, 1024.0);
        assert_eq!(scaling.growth_factor, 2.0);

        let disabled = GradientScaling::disabled();
        assert!(!disabled.enabled);
    }

    #[test]
    fn test_gradient_stats() {
        let mut stats = GradientStats::new();
        stats.num_parameters = 100;
        stats.num_finite = 95;
        stats.num_nan = 5;
        stats.num_infinite = 0;

        assert!(stats.has_nan());
        assert!(!stats.has_inf());
        assert!(!stats.is_healthy());
        assert_eq!(stats.finite_ratio(), 0.95);
    }

    #[test]
    fn test_custom_gradient_registry() {
        let mut registry: CustomGradientRegistry<f64, String> = CustomGradientRegistry::new();

        registry.register("custom_op".to_string(), |_output, _inputs| {
            Ok(vec![1.0, 2.0, 3.0])
        });

        assert!(registry.has_custom_gradient("custom_op"));
        assert!(!registry.has_custom_gradient("other_op"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let removed = registry.unregister("custom_op");
        assert!(removed);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_gradient_accumulator() {
        let config = AccumulationConfig::standard(3);
        let mut accumulator: GradientAccumulator<f64> = GradientAccumulator::new(config);

        assert_eq!(accumulator.count(), 0);
        assert!(!accumulator.is_ready());

        accumulator.accumulate(vec![1.0, 2.0, 3.0]);
        assert_eq!(accumulator.count(), 1);
        assert!(!accumulator.is_ready());

        accumulator.accumulate(vec![4.0, 5.0, 6.0]);
        accumulator.accumulate(vec![7.0, 8.0, 9.0]);
        assert!(accumulator.is_ready());

        let _gradients = accumulator.step();
        assert_eq!(accumulator.count(), 0);
    }

    #[test]
    fn test_gradient_clipper() {
        let mut clipper = GradientClipper::new(ClippingStrategy::ByValue {
            min: -1.0,
            max: 1.0,
        });

        assert!(!clipper.should_clip(0.5));
        assert!(clipper.should_clip(2.0));
        assert!(clipper.should_clip(-2.0));

        let clipped = clipper.clip_value(2.0);
        assert_eq!(clipped, 1.0);
        assert_eq!(clipper.num_clips(), 1);

        let clipped = clipper.clip_value(-2.0);
        assert_eq!(clipped, -1.0);
        assert_eq!(clipper.num_clips(), 2);

        clipper.reset();
        assert_eq!(clipper.num_clips(), 0);
    }

    #[test]
    fn test_gradient_scaler() {
        let config = GradientScaling::new(1024.0);
        let mut scaler = GradientScaler::new(config);

        assert_eq!(scaler.get_scale(), 1024.0);

        let scaled = scaler.scale(2.0);
        assert_eq!(scaled, 2048.0);

        let unscaled = scaler.unscale(2048.0);
        assert_eq!(unscaled, 2.0);

        // Drive the tracker to the growth threshold, then report a healthy step.
        scaler.growth_tracker = config.growth_interval - 1;
        scaler.update(true);
        assert_eq!(scaler.get_scale(), 2048.0);

        // An unhealthy step backs the scale off by the backoff factor.
        scaler.update(false);
        assert_eq!(scaler.get_scale(), 1024.0);
    }

    #[test]
    fn test_gradient_scaler_disabled() {
        let config = GradientScaling::disabled();
        let scaler = GradientScaler::new(config);

        assert_eq!(scaler.scale(2.0), 2.0);
        assert_eq!(scaler.unscale(2.0), 2.0);
    }
}