// trustformers_core/autodiff/engine.rs
//! Automatic differentiation engine.
//!
//! This module provides the main engine for managing automatic differentiation,
//! including gradient computation modes and optimization settings.
6#![allow(unused_variables)] // Autodiff engine with reserved parameters
7
8use super::graph::ComputationGraph;
9use super::tape::GradientTape;
10use super::variable::{GraphRef, Variable};
11use crate::errors::{tensor_op_error, Result};
12use crate::tensor::Tensor;
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex, OnceLock};
15
/// Gradient computation modes supported by [`AutodiffEngine`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradientMode {
    /// Forward-mode automatic differentiation
    Forward,
    /// Reverse-mode automatic differentiation (backpropagation)
    Reverse,
    /// Mixed-mode: forward mode for small graphs (< 100 nodes), reverse
    /// mode for large graphs (see `mixed_mode_backward`)
    Mixed,
}

/// Configuration for the automatic differentiation engine.
///
/// Use [`AutodiffConfig::default`] for the standard setup (reverse mode,
/// gradients enabled, graph optimization on).
#[derive(Debug, Clone)]
pub struct AutodiffConfig {
    /// Gradient computation mode
    pub mode: GradientMode,
    /// Whether to enable gradient computation
    pub enabled: bool,
    /// Whether to detect anomalies (NaN / infinite values) in gradients
    pub detect_anomalies: bool,
    /// Whether to retain the computational graph after backward pass
    pub retain_graph: bool,
    /// Maximum number of operations to cache
    pub max_cache_size: usize,
    /// Whether to use graph optimization
    pub optimize_graph: bool,
    /// Whether to enable gradient checkpointing
    pub gradient_checkpointing: bool,
}

46impl Default for AutodiffConfig {
47    fn default() -> Self {
48        Self {
49            mode: GradientMode::Reverse,
50            enabled: true,
51            detect_anomalies: false,
52            retain_graph: false,
53            max_cache_size: 10000,
54            optimize_graph: true,
55            gradient_checkpointing: false,
56        }
57    }
58}
59
/// Main automatic differentiation engine.
///
/// All mutable state (graph, tape, cache, stats) is shared behind
/// `Arc<Mutex<_>>`, so handles can be cloned out (see `graph()`), while
/// the configuration itself is plain owned data mutated via `&mut self`.
#[derive(Debug)]
pub struct AutodiffEngine {
    /// Configuration
    config: AutodiffConfig,
    /// Current computation graph
    graph: GraphRef,
    /// Gradient tape for recording operations
    tape: Arc<Mutex<GradientTape>>,
    /// Cache for compiled operations
    // NOTE(review): not read or written anywhere in this file yet — reserved
    // for the operation-compilation path, hence the dead_code allowance.
    #[allow(dead_code)]
    operation_cache: Arc<Mutex<HashMap<String, CompiledOperation>>>,
    /// Statistics
    stats: Arc<Mutex<AutodiffStats>>,
}

/// Compiled operation for performance optimization.
///
/// Holds plain function pointers (not closures), so a compiled operation
/// is `Clone` + cheap to copy and carries no captured state.
#[derive(Debug, Clone)]
pub struct CompiledOperation {
    /// Operation ID
    pub id: String,
    /// Compiled forward function: inputs -> output tensor
    pub forward_fn: fn(&[&Tensor]) -> Result<Tensor>,
    /// Compiled backward function: (output gradient, inputs) -> input gradients
    pub backward_fn: fn(&Tensor, &[&Tensor]) -> Result<Vec<Tensor>>,
    /// Operation metadata
    pub metadata: OperationMetadata,
}

/// Metadata describing a compiled operation's shape and cost profile.
#[derive(Debug, Clone)]
pub struct OperationMetadata {
    /// Operation type
    pub op_type: String,
    /// Input shapes
    pub input_shapes: Vec<Vec<usize>>,
    /// Output shape
    pub output_shape: Vec<usize>,
    /// Number of parameters
    pub num_parameters: usize,
    /// Estimated FLOPS
    pub estimated_flops: usize,
}

/// Statistics for the autodiff engine.
///
/// NOTE(review): within this file only `backward_passes` and
/// `backward_time_us` are updated (in `AutodiffEngine::backward`); the
/// remaining counters appear to be maintained elsewhere — confirm.
#[derive(Debug, Default, Clone)]
pub struct AutodiffStats {
    /// Number of forward passes
    pub forward_passes: u64,
    /// Number of backward passes
    pub backward_passes: u64,
    /// Total operations executed
    pub total_operations: u64,
    /// Cache hits
    pub cache_hits: u64,
    /// Cache misses
    pub cache_misses: u64,
    /// Total time spent in forward pass (microseconds)
    pub forward_time_us: u64,
    /// Total time spent in backward pass (microseconds)
    pub backward_time_us: u64,
    /// Peak memory usage (bytes)
    pub peak_memory_usage: usize,
    /// Current memory usage (bytes)
    pub current_memory_usage: usize,
}

127impl Default for AutodiffEngine {
128    fn default() -> Self {
129        Self::new(AutodiffConfig::default())
130    }
131}
132
impl AutodiffEngine {
    /// Create a new autodiff engine with the given configuration.
    ///
    /// Allocates a fresh (empty) computation graph, gradient tape,
    /// operation cache, and zeroed statistics.
    pub fn new(config: AutodiffConfig) -> Self {
        let graph = Arc::new(Mutex::new(ComputationGraph::new()));
        let tape = Arc::new(Mutex::new(GradientTape::new()));
        let operation_cache = Arc::new(Mutex::new(HashMap::new()));
        let stats = Arc::new(Mutex::new(AutodiffStats::default()));

        Self {
            config,
            graph,
            tape,
            operation_cache,
            stats,
        }
    }

    /// Enable gradient computation
    pub fn enable_grad(&mut self) {
        self.config.enabled = true;
    }

    /// Disable gradient computation
    pub fn disable_grad(&mut self) {
        self.config.enabled = false;
    }

    /// Check if gradient computation is enabled
    pub fn is_grad_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Set gradient computation mode
    pub fn set_mode(&mut self, mode: GradientMode) {
        self.config.mode = mode;
    }

    /// Get current gradient computation mode
    pub fn mode(&self) -> GradientMode {
        self.config.mode
    }

    /// Enable anomaly detection
    pub fn enable_anomaly_detection(&mut self) {
        self.config.detect_anomalies = true;
    }

    /// Disable anomaly detection
    pub fn disable_anomaly_detection(&mut self) {
        self.config.detect_anomalies = false;
    }

    /// Create a new variable backed by this engine's graph.
    ///
    /// The inner block acquires the graph lock only long enough to
    /// register the node; the guard is released (block end) before
    /// `Variable::from_graph` runs.
    pub fn variable(&self, tensor: Tensor, requires_grad: bool) -> Variable {
        Variable::from_graph(
            self.graph.clone(),
            {
                let mut graph = self.graph.lock().expect("lock should not be poisoned");
                graph.add_node(tensor, requires_grad, None)
            },
            requires_grad,
        )
    }

    /// Create a new variable with a name (e.g. for graph export labels).
    pub fn variable_with_name(
        &self,
        tensor: Tensor,
        requires_grad: bool,
        name: String,
    ) -> Variable {
        Variable::from_graph(
            self.graph.clone(),
            {
                let mut graph = self.graph.lock().expect("lock should not be poisoned");
                graph.add_node(tensor, requires_grad, Some(name))
            },
            requires_grad,
        )
    }

    /// Compute gradients using the current mode.
    ///
    /// `grad_output` seeds the backward pass; `None` leaves the seed to
    /// the graph's default. On success, increments `backward_passes` and
    /// accumulates wall-clock time into `backward_time_us`.
    pub fn backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
        let start_time = std::time::Instant::now();

        match self.config.mode {
            GradientMode::Forward => self.forward_mode_backward(output, grad_output),
            GradientMode::Reverse => self.reverse_mode_backward(output, grad_output),
            GradientMode::Mixed => self.mixed_mode_backward(output, grad_output),
        }?;

        // Update statistics (only reached when the pass above succeeded —
        // a failed pass is not counted).
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        stats.backward_passes += 1;
        stats.backward_time_us += start_time.elapsed().as_micros() as u64;

        Ok(())
    }

    /// Forward-mode automatic differentiation
    fn forward_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
        // Forward-mode AD is typically used for computing derivatives with respect to few inputs
        // This is a simplified implementation: it currently delegates to the
        // same graph backward as reverse mode.
        let mut graph = self.graph.lock().expect("lock should not be poisoned");
        graph.backward(output.node_id(), grad_output)
    }

    /// Reverse-mode automatic differentiation (standard backpropagation)
    fn reverse_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
        let mut graph = self.graph.lock().expect("lock should not be poisoned");
        graph.backward(output.node_id(), grad_output)
    }

    /// Mixed-mode automatic differentiation
    fn mixed_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
        // Decide between forward and reverse mode based on graph characteristics
        let graph = self.graph.lock().expect("lock should not be poisoned");
        let num_nodes = graph.num_nodes();

        // Use forward mode for small graphs, reverse mode for large graphs.
        // The guard must be dropped explicitly before delegating: both
        // delegates re-lock `self.graph`, and `std::sync::Mutex` is not
        // reentrant, so holding the guard here would deadlock.
        if num_nodes < 100 {
            drop(graph);
            self.forward_mode_backward(output, grad_output)
        } else {
            drop(graph);
            self.reverse_mode_backward(output, grad_output)
        }
    }

    /// Zero all gradients in the computation graph
    pub fn zero_grad(&self) {
        let mut graph = self.graph.lock().expect("lock should not be poisoned");
        graph.zero_grad();
    }

    /// Get gradient for a variable, cloned out of the graph.
    /// Returns `Ok(None)` when no gradient has been computed for it.
    pub fn get_grad(&self, variable: &Variable) -> Result<Option<Tensor>> {
        let graph = self.graph.lock().expect("lock should not be poisoned");
        Ok(graph.get_gradient(variable.node_id()).cloned())
    }

    /// Clear the computation graph and the gradient tape.
    ///
    /// NOTE(review): existing `Variable`s still hold a `GraphRef` into the
    /// replaced graph object; their node ids refer to the old contents —
    /// confirm callers do not reuse variables across a clear.
    pub fn clear_graph(&self) {
        let mut graph = self.graph.lock().expect("lock should not be poisoned");
        *graph = ComputationGraph::new();

        let mut tape = self.tape.lock().expect("lock should not be poisoned");
        tape.clear();
    }

    /// Get a snapshot (clone) of the engine statistics.
    pub fn stats(&self) -> AutodiffStats {
        let stats = self.stats.lock().expect("lock should not be poisoned");
        stats.clone()
    }

    /// Reset statistics
    pub fn reset_stats(&self) {
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        *stats = AutodiffStats::default();
    }

    /// Get a cloned handle to the current graph (shared, not a copy).
    pub fn graph(&self) -> GraphRef {
        self.graph.clone()
    }

    /// Optimize the computation graph.
    ///
    /// No-op when `optimize_graph` is disabled in the configuration.
    /// The graph lock is held across all three passes.
    pub fn optimize_graph(&self) -> Result<()> {
        if !self.config.optimize_graph {
            return Ok(());
        }

        let mut graph = self.graph.lock().expect("lock should not be poisoned");

        // Perform various graph optimizations (all currently stubs — see below).
        self.eliminate_dead_nodes(&mut graph)?;
        self.fuse_operations(&mut graph)?;
        self.optimize_memory_layout(&mut graph)?;

        Ok(())
    }

    /// Eliminate dead nodes (nodes with no children).
    /// Currently a stub: performs no changes and always returns `Ok(())`.
    fn eliminate_dead_nodes(&self, graph: &mut ComputationGraph) -> Result<()> {
        // This is a simplified implementation
        // In practice, you would identify and remove nodes that don't contribute to the output
        Ok(())
    }

    /// Fuse operations where possible.
    /// Currently a stub: performs no changes and always returns `Ok(())`.
    fn fuse_operations(&self, graph: &mut ComputationGraph) -> Result<()> {
        // This is a simplified implementation
        // In practice, you would identify patterns like Add+Mul and fuse them into FusedAddMul
        Ok(())
    }

    /// Optimize memory layout.
    /// Currently a stub: performs no changes and always returns `Ok(())`.
    fn optimize_memory_layout(&self, graph: &mut ComputationGraph) -> Result<()> {
        // This is a simplified implementation
        // In practice, you would reorder operations to minimize memory usage
        Ok(())
    }

    /// Execute a function with gradient computation disabled.
    ///
    /// NOTE(review): not panic-safe — if `f` unwinds, the enabled flag is
    /// not restored. Consider a drop guard if callers may catch panics.
    pub fn no_grad<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let was_enabled = self.config.enabled;
        self.config.enabled = false;

        let result = f();

        // Restore original state
        self.config.enabled = was_enabled;

        result
    }

    /// Execute a function with gradients enabled.
    ///
    /// NOTE(review): same panic-safety caveat as `no_grad`.
    pub fn with_grad<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let was_enabled = self.config.enabled;
        self.config.enabled = true;

        let result = f();

        // Restore original state
        self.config.enabled = was_enabled;

        result
    }

    /// Check for gradient anomalies (NaN or infinite values).
    ///
    /// No-op unless `detect_anomalies` is set. Copies the whole gradient
    /// to a `Vec<f32>` and scans every element, returning the first
    /// anomaly found as an error.
    pub fn check_anomalies(&self, variable: &Variable) -> Result<()> {
        if !self.config.detect_anomalies {
            return Ok(());
        }

        if let Some(grad) = self.get_grad(variable)? {
            let grad_values = grad.to_vec_f32()?;

            for &value in &grad_values {
                if value.is_nan() {
                    return Err(tensor_op_error(
                        "AutodiffEngine::check_anomalies",
                        "NaN detected in gradient",
                    ));
                }
                if value.is_infinite() {
                    return Err(tensor_op_error(
                        "AutodiffEngine::check_anomalies",
                        "Infinite value detected in gradient",
                    ));
                }
            }
        }

        Ok(())
    }

    /// Enable gradient checkpointing
    pub fn enable_checkpointing(&mut self) {
        self.config.gradient_checkpointing = true;
    }

    /// Disable gradient checkpointing
    pub fn disable_checkpointing(&mut self) {
        self.config.gradient_checkpointing = false;
    }

    /// Check if gradient checkpointing is enabled
    pub fn is_checkpointing_enabled(&self) -> bool {
        self.config.gradient_checkpointing
    }

    /// Export computation graph in Graphviz DOT format for visualization.
    ///
    /// Each node is labelled with its name (or `node_<id>`), operation
    /// (or `Variable` for leaves), and shape; edges run parent -> child.
    pub fn export_graph(&self) -> Result<String> {
        let graph = self.graph.lock().expect("lock should not be poisoned");
        let graph_export = graph.export_graph();

        // Convert to DOT format for visualization
        let mut dot = String::from("digraph G {\n");
        dot.push_str("  rankdir=TB;\n");

        for node in &graph_export.nodes {
            let node_label = if let Some(ref name) = node.name {
                name.clone()
            } else {
                format!("node_{}", node.id)
            };

            let op_label = if let Some(ref op) = node.operation {
                format!("{:?}", op)
            } else {
                "Variable".to_string()
            };

            dot.push_str(&format!(
                "  {} [label=\"{}\\n{}\\n{:?}\"];\n",
                node.id, node_label, op_label, node.shape
            ));

            for parent_id in &node.parents {
                dot.push_str(&format!("  {} -> {};\n", parent_id, node.id));
            }
        }

        dot.push_str("}\n");
        Ok(dot)
    }

    /// Get memory usage information.
    ///
    /// Sums the memory of every node's value tensor plus its gradient
    /// tensor (when present); each counts as a separate tensor.
    pub fn memory_info(&self) -> Result<MemoryInfo> {
        let graph = self.graph.lock().expect("lock should not be poisoned");
        let mut total_memory = 0;
        let mut num_tensors = 0;

        for node in graph.export_graph().nodes {
            total_memory += node.value.memory_usage();
            num_tensors += 1;

            if let Some(ref grad) = node.gradient {
                total_memory += grad.memory_usage();
                num_tensors += 1;
            }
        }

        Ok(MemoryInfo {
            total_memory_bytes: total_memory,
            num_tensors,
            num_nodes: graph.num_nodes(),
        })
    }
}

/// Memory usage information reported by [`AutodiffEngine::memory_info`].
#[derive(Debug, Clone)]
pub struct MemoryInfo {
    /// Total memory usage in bytes (values + gradients)
    pub total_memory_bytes: usize,
    /// Number of tensors (each gradient counts separately from its value)
    pub num_tensors: usize,
    /// Number of graph nodes
    pub num_nodes: usize,
}

483/// Global autodiff engine instance
484static GLOBAL_ENGINE: OnceLock<Arc<Mutex<AutodiffEngine>>> = OnceLock::new();
485
486/// Initialize the global autodiff engine
487pub fn init_engine(config: AutodiffConfig) {
488    let _ = GLOBAL_ENGINE.set(Arc::new(Mutex::new(AutodiffEngine::new(config))));
489}
490
491/// Get the global autodiff engine
492pub fn get_engine() -> Arc<Mutex<AutodiffEngine>> {
493    GLOBAL_ENGINE
494        .get_or_init(|| Arc::new(Mutex::new(AutodiffEngine::new(AutodiffConfig::default()))))
495        .clone()
496}
497
/// Context manager (RAII guard) for gradient computation on the global
/// engine: captures the enabled/disabled state at creation and restores
/// it on drop.
pub struct GradContext {
    // Grad-enabled state of the global engine at context creation time.
    previous_state: bool,
}

503impl GradContext {
504    /// Create a new context with gradients enabled
505    pub fn enable() -> Self {
506        let engine = get_engine();
507        let previous_state = engine.lock().expect("Lock poisoned").is_grad_enabled();
508        engine.lock().expect("Lock poisoned").enable_grad();
509
510        Self { previous_state }
511    }
512
513    /// Create a new context with gradients disabled
514    pub fn disable() -> Self {
515        let engine = get_engine();
516        let previous_state = engine.lock().expect("Lock poisoned").is_grad_enabled();
517        engine.lock().expect("Lock poisoned").disable_grad();
518
519        Self { previous_state }
520    }
521}
522
523impl Drop for GradContext {
524    fn drop(&mut self) {
525        let engine = get_engine();
526        if self.previous_state {
527            engine.lock().expect("Lock poisoned").enable_grad();
528        } else {
529            engine.lock().expect("Lock poisoned").disable_grad();
530        }
531    }
532}
533
/// Convenience macro: run the given statements with gradients disabled on
/// the global engine. The previous state is restored when the inner scope
/// ends (via `GradContext`'s `Drop`).
#[macro_export]
macro_rules! no_grad {
    ($($stmt:stmt)*) => {
        {
            // Guard lives until the end of this block, then restores state.
            let _ctx = $crate::autodiff::engine::GradContext::disable();
            $($stmt)*
        }
    };
}

/// Convenience macro: run the given statements with gradients enabled on
/// the global engine. The previous state is restored when the inner scope
/// ends (via `GradContext`'s `Drop`).
#[macro_export]
macro_rules! with_grad {
    ($($stmt:stmt)*) => {
        {
            // Guard lives until the end of this block, then restores state.
            let _ctx = $crate::autodiff::engine::GradContext::enable();
            $($stmt)*
        }
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // Default engine: gradients enabled, reverse mode.
    #[test]
    fn test_engine_creation() {
        let config = AutodiffConfig::default();
        let engine = AutodiffEngine::new(config);

        assert!(engine.is_grad_enabled());
        assert_eq!(engine.mode(), GradientMode::Reverse);
    }

    // A variable created from a tensor keeps its shape and requires_grad flag.
    #[test]
    fn test_variable_creation() {
        let engine = AutodiffEngine::default();
        let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let var = engine.variable(tensor, true);

        assert!(var.requires_grad());
        assert_eq!(var.shape().expect("operation failed in test"), vec![2, 3]);
    }

    // c = a * b  =>  dc/da = b = 3, dc/db = a = 2.
    #[test]
    fn test_gradient_computation() {
        let engine = AutodiffEngine::default();

        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
        let c = a.mul(&b).expect("Multiplication failed");

        engine.backward(&c, None).expect("operation failed in test");

        let grad_a = engine
            .get_grad(&a)
            .expect("operation failed in test")
            .expect("operation failed in test");
        let grad_b = engine
            .get_grad(&b)
            .expect("operation failed in test")
            .expect("operation failed in test");

        assert_eq!(grad_a.to_scalar().expect("operation failed in test"), 3.0);
        assert_eq!(grad_b.to_scalar().expect("operation failed in test"), 2.0);
    }

    // GradContext toggles the *global* engine and restores state on drop.
    // NOTE(review): the local `engine` only checks the default; the context
    // assertions go through `get_engine()` (global, shared across tests).
    #[test]
    fn test_grad_context() {
        let engine = AutodiffEngine::default();
        assert!(engine.is_grad_enabled());

        {
            let _ctx = GradContext::disable();
            assert!(!get_engine().lock().expect("Lock poisoned").is_grad_enabled());
        }

        // Should be restored after context ends
        assert!(get_engine().lock().expect("Lock poisoned").is_grad_enabled());
    }

    // A fresh engine starts with zeroed statistics.
    #[test]
    fn test_engine_stats() {
        let engine = AutodiffEngine::default();
        let stats = engine.stats();

        assert_eq!(stats.forward_passes, 0);
        assert_eq!(stats.backward_passes, 0);
    }

    // Registering a variable makes the graph report non-zero memory/tensors/nodes.
    #[test]
    fn test_memory_info() {
        let engine = AutodiffEngine::default();
        let tensor = Tensor::ones(&[100, 100]).expect("Failed to create ones tensor");
        let _var = engine.variable(tensor, true);

        let memory_info = engine.memory_info().expect("operation failed in test");
        assert!(memory_info.total_memory_bytes > 0);
        assert!(memory_info.num_tensors > 0);
        assert!(memory_info.num_nodes > 0);
    }

    // A variable with no gradient (no backward run) passes the anomaly check.
    #[test]
    fn test_anomaly_detection() {
        let config = AutodiffConfig {
            detect_anomalies: true,
            ..Default::default()
        };
        let engine = AutodiffEngine::new(config);

        let var = engine.variable(Tensor::scalar(1.0).expect("tensor operation failed"), true);
        let result = engine.check_anomalies(&var);

        assert!(result.is_ok());
    }

    // DOT export contains the digraph header and at least one edge.
    #[test]
    fn test_graph_export() {
        let engine = AutodiffEngine::default();
        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
        let _c = a.mul(&b).expect("Multiplication failed");

        let dot_graph = engine.export_graph().expect("operation failed in test");
        assert!(dot_graph.contains("digraph G"));
        assert!(dot_graph.contains("->"));
    }
}