//! trustformers_core/autodiff/engine.rs — automatic differentiation engine.
#![allow(unused_variables)]
use super::graph::ComputationGraph;
9use super::tape::GradientTape;
10use super::variable::{GraphRef, Variable};
11use crate::errors::{tensor_op_error, Result};
12use crate::tensor::Tensor;
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex, OnceLock};
15
/// Strategy used to propagate derivatives through the computation graph.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GradientMode {
    /// Forward-mode (tangent) accumulation.
    /// NOTE(review): the forward path currently delegates to the same
    /// `ComputationGraph::backward` as reverse mode — see `forward_mode_backward`.
    Forward,
    /// Reverse-mode (adjoint) accumulation, i.e. standard backpropagation.
    Reverse,
    /// Picks forward or reverse per backward call based on graph size.
    Mixed,
}
26
/// Runtime configuration for the autodiff engine.
#[derive(Debug, Clone)]
pub struct AutodiffConfig {
    /// Differentiation strategy used by `AutodiffEngine::backward`.
    pub mode: GradientMode,
    /// Master switch for gradient tracking.
    pub enabled: bool,
    /// When true, `check_anomalies` rejects NaN/Inf gradient values.
    pub detect_anomalies: bool,
    /// Keep the graph alive after backward — NOTE(review): not read anywhere
    /// in this file; presumably consumed by the graph/tape code. TODO confirm.
    pub retain_graph: bool,
    /// Capacity bound for the compiled-operation cache — NOTE(review): not
    /// enforced in this file.
    pub max_cache_size: usize,
    /// Enables the (currently stubbed) graph optimization passes.
    pub optimize_graph: bool,
    /// Gradient-checkpointing flag; only stored and queried in this file.
    pub gradient_checkpointing: bool,
}
45
46impl Default for AutodiffConfig {
47 fn default() -> Self {
48 Self {
49 mode: GradientMode::Reverse,
50 enabled: true,
51 detect_anomalies: false,
52 retain_graph: false,
53 max_cache_size: 10000,
54 optimize_graph: true,
55 gradient_checkpointing: false,
56 }
57 }
58}
59
/// Central coordinator for automatic differentiation: owns the computation
/// graph, the gradient tape, a compiled-operation cache, and run-time counters.
/// All shared state sits behind `Arc<Mutex<..>>`, so handles cloned out of the
/// engine can be used from other threads.
#[derive(Debug)]
pub struct AutodiffEngine {
    /// Behavioral switches (mode, enable flag, anomaly detection, ...).
    config: AutodiffConfig,
    /// Shared computation graph; also handed out via `graph()`.
    graph: GraphRef,
    /// Gradient tape; only cleared in this file — presumably recorded to
    /// elsewhere. TODO confirm.
    tape: Arc<Mutex<GradientTape>>,
    /// Cache of compiled operations; never read or written in this file.
    #[allow(dead_code)]
    operation_cache: Arc<Mutex<HashMap<String, CompiledOperation>>>,
    /// Counters and timings returned by `stats()`.
    stats: Arc<Mutex<AutodiffStats>>,
}
75
/// A pre-compiled forward/backward function pair for a single operation,
/// keyed by `id` in the engine's operation cache.
#[derive(Debug, Clone)]
pub struct CompiledOperation {
    /// Cache key / identifier for this operation.
    pub id: String,
    /// Computes the output tensor from borrowed input tensors.
    pub forward_fn: fn(&[&Tensor]) -> Result<Tensor>,
    /// Computes per-input gradients from the output gradient and the inputs.
    pub backward_fn: fn(&Tensor, &[&Tensor]) -> Result<Vec<Tensor>>,
    /// Shape and cost metadata describing the operation.
    pub metadata: OperationMetadata,
}
88
/// Descriptive metadata attached to a `CompiledOperation`.
#[derive(Debug, Clone)]
pub struct OperationMetadata {
    /// Operation kind, e.g. a name of the underlying op.
    pub op_type: String,
    /// Shape of each input tensor.
    pub input_shapes: Vec<Vec<usize>>,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
    /// Number of trainable parameters involved.
    pub num_parameters: usize,
    /// Rough floating-point-operation count for cost estimation.
    pub estimated_flops: usize,
}
103
/// Run-time counters maintained by the engine (see `AutodiffEngine::stats`).
#[derive(Debug, Default, Clone)]
pub struct AutodiffStats {
    /// Completed forward passes — NOTE(review): never incremented in this
    /// file; stays 0 unless updated elsewhere. TODO confirm.
    pub forward_passes: u64,
    /// Completed backward passes (incremented in `AutodiffEngine::backward`).
    pub backward_passes: u64,
    /// Total operations executed — not updated in this file.
    pub total_operations: u64,
    /// Operation-cache hits — not updated in this file.
    pub cache_hits: u64,
    /// Operation-cache misses — not updated in this file.
    pub cache_misses: u64,
    /// Cumulative forward time in microseconds — not updated in this file.
    pub forward_time_us: u64,
    /// Cumulative backward time in microseconds.
    pub backward_time_us: u64,
    /// High-water memory mark in bytes — not updated in this file.
    pub peak_memory_usage: usize,
    /// Current memory usage in bytes — not updated in this file.
    pub current_memory_usage: usize,
}
126
127impl Default for AutodiffEngine {
128 fn default() -> Self {
129 Self::new(AutodiffConfig::default())
130 }
131}
132
133impl AutodiffEngine {
134 pub fn new(config: AutodiffConfig) -> Self {
136 let graph = Arc::new(Mutex::new(ComputationGraph::new()));
137 let tape = Arc::new(Mutex::new(GradientTape::new()));
138 let operation_cache = Arc::new(Mutex::new(HashMap::new()));
139 let stats = Arc::new(Mutex::new(AutodiffStats::default()));
140
141 Self {
142 config,
143 graph,
144 tape,
145 operation_cache,
146 stats,
147 }
148 }
149
150 pub fn enable_grad(&mut self) {
152 self.config.enabled = true;
153 }
154
155 pub fn disable_grad(&mut self) {
157 self.config.enabled = false;
158 }
159
160 pub fn is_grad_enabled(&self) -> bool {
162 self.config.enabled
163 }
164
165 pub fn set_mode(&mut self, mode: GradientMode) {
167 self.config.mode = mode;
168 }
169
170 pub fn mode(&self) -> GradientMode {
172 self.config.mode
173 }
174
175 pub fn enable_anomaly_detection(&mut self) {
177 self.config.detect_anomalies = true;
178 }
179
180 pub fn disable_anomaly_detection(&mut self) {
182 self.config.detect_anomalies = false;
183 }
184
185 pub fn variable(&self, tensor: Tensor, requires_grad: bool) -> Variable {
187 Variable::from_graph(
188 self.graph.clone(),
189 {
190 let mut graph = self.graph.lock().expect("lock should not be poisoned");
191 graph.add_node(tensor, requires_grad, None)
192 },
193 requires_grad,
194 )
195 }
196
197 pub fn variable_with_name(
199 &self,
200 tensor: Tensor,
201 requires_grad: bool,
202 name: String,
203 ) -> Variable {
204 Variable::from_graph(
205 self.graph.clone(),
206 {
207 let mut graph = self.graph.lock().expect("lock should not be poisoned");
208 graph.add_node(tensor, requires_grad, Some(name))
209 },
210 requires_grad,
211 )
212 }
213
214 pub fn backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
216 let start_time = std::time::Instant::now();
217
218 match self.config.mode {
219 GradientMode::Forward => self.forward_mode_backward(output, grad_output),
220 GradientMode::Reverse => self.reverse_mode_backward(output, grad_output),
221 GradientMode::Mixed => self.mixed_mode_backward(output, grad_output),
222 }?;
223
224 let mut stats = self.stats.lock().expect("lock should not be poisoned");
226 stats.backward_passes += 1;
227 stats.backward_time_us += start_time.elapsed().as_micros() as u64;
228
229 Ok(())
230 }
231
232 fn forward_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
234 let mut graph = self.graph.lock().expect("lock should not be poisoned");
237 graph.backward(output.node_id(), grad_output)
238 }
239
240 fn reverse_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
242 let mut graph = self.graph.lock().expect("lock should not be poisoned");
243 graph.backward(output.node_id(), grad_output)
244 }
245
246 fn mixed_mode_backward(&self, output: &Variable, grad_output: Option<Tensor>) -> Result<()> {
248 let graph = self.graph.lock().expect("lock should not be poisoned");
250 let num_nodes = graph.num_nodes();
251
252 if num_nodes < 100 {
254 drop(graph);
255 self.forward_mode_backward(output, grad_output)
256 } else {
257 drop(graph);
258 self.reverse_mode_backward(output, grad_output)
259 }
260 }
261
262 pub fn zero_grad(&self) {
264 let mut graph = self.graph.lock().expect("lock should not be poisoned");
265 graph.zero_grad();
266 }
267
268 pub fn get_grad(&self, variable: &Variable) -> Result<Option<Tensor>> {
270 let graph = self.graph.lock().expect("lock should not be poisoned");
271 Ok(graph.get_gradient(variable.node_id()).cloned())
272 }
273
274 pub fn clear_graph(&self) {
276 let mut graph = self.graph.lock().expect("lock should not be poisoned");
277 *graph = ComputationGraph::new();
278
279 let mut tape = self.tape.lock().expect("lock should not be poisoned");
280 tape.clear();
281 }
282
283 pub fn stats(&self) -> AutodiffStats {
285 let stats = self.stats.lock().expect("lock should not be poisoned");
286 stats.clone()
287 }
288
289 pub fn reset_stats(&self) {
291 let mut stats = self.stats.lock().expect("lock should not be poisoned");
292 *stats = AutodiffStats::default();
293 }
294
295 pub fn graph(&self) -> GraphRef {
297 self.graph.clone()
298 }
299
300 pub fn optimize_graph(&self) -> Result<()> {
302 if !self.config.optimize_graph {
303 return Ok(());
304 }
305
306 let mut graph = self.graph.lock().expect("lock should not be poisoned");
307
308 self.eliminate_dead_nodes(&mut graph)?;
310 self.fuse_operations(&mut graph)?;
311 self.optimize_memory_layout(&mut graph)?;
312
313 Ok(())
314 }
315
316 fn eliminate_dead_nodes(&self, graph: &mut ComputationGraph) -> Result<()> {
318 Ok(())
321 }
322
323 fn fuse_operations(&self, graph: &mut ComputationGraph) -> Result<()> {
325 Ok(())
328 }
329
330 fn optimize_memory_layout(&self, graph: &mut ComputationGraph) -> Result<()> {
332 Ok(())
335 }
336
337 pub fn no_grad<F, R>(&mut self, f: F) -> R
339 where
340 F: FnOnce() -> R,
341 {
342 let was_enabled = self.config.enabled;
343 self.config.enabled = false;
344
345 let result = f();
346
347 self.config.enabled = was_enabled;
349
350 result
351 }
352
353 pub fn with_grad<F, R>(&mut self, f: F) -> R
355 where
356 F: FnOnce() -> R,
357 {
358 let was_enabled = self.config.enabled;
359 self.config.enabled = true;
360
361 let result = f();
362
363 self.config.enabled = was_enabled;
365
366 result
367 }
368
369 pub fn check_anomalies(&self, variable: &Variable) -> Result<()> {
371 if !self.config.detect_anomalies {
372 return Ok(());
373 }
374
375 if let Some(grad) = self.get_grad(variable)? {
376 let grad_values = grad.to_vec_f32()?;
377
378 for &value in &grad_values {
379 if value.is_nan() {
380 return Err(tensor_op_error(
381 "AutodiffEngine::check_anomalies",
382 "NaN detected in gradient",
383 ));
384 }
385 if value.is_infinite() {
386 return Err(tensor_op_error(
387 "AutodiffEngine::check_anomalies",
388 "Infinite value detected in gradient",
389 ));
390 }
391 }
392 }
393
394 Ok(())
395 }
396
397 pub fn enable_checkpointing(&mut self) {
399 self.config.gradient_checkpointing = true;
400 }
401
402 pub fn disable_checkpointing(&mut self) {
404 self.config.gradient_checkpointing = false;
405 }
406
407 pub fn is_checkpointing_enabled(&self) -> bool {
409 self.config.gradient_checkpointing
410 }
411
412 pub fn export_graph(&self) -> Result<String> {
414 let graph = self.graph.lock().expect("lock should not be poisoned");
415 let graph_export = graph.export_graph();
416
417 let mut dot = String::from("digraph G {\n");
419 dot.push_str(" rankdir=TB;\n");
420
421 for node in &graph_export.nodes {
422 let node_label = if let Some(ref name) = node.name {
423 name.clone()
424 } else {
425 format!("node_{}", node.id)
426 };
427
428 let op_label = if let Some(ref op) = node.operation {
429 format!("{:?}", op)
430 } else {
431 "Variable".to_string()
432 };
433
434 dot.push_str(&format!(
435 " {} [label=\"{}\\n{}\\n{:?}\"];\n",
436 node.id, node_label, op_label, node.shape
437 ));
438
439 for parent_id in &node.parents {
440 dot.push_str(&format!(" {} -> {};\n", parent_id, node.id));
441 }
442 }
443
444 dot.push_str("}\n");
445 Ok(dot)
446 }
447
448 pub fn memory_info(&self) -> Result<MemoryInfo> {
450 let graph = self.graph.lock().expect("lock should not be poisoned");
451 let mut total_memory = 0;
452 let mut num_tensors = 0;
453
454 for node in graph.export_graph().nodes {
455 total_memory += node.value.memory_usage();
456 num_tensors += 1;
457
458 if let Some(ref grad) = node.gradient {
459 total_memory += grad.memory_usage();
460 num_tensors += 1;
461 }
462 }
463
464 Ok(MemoryInfo {
465 total_memory_bytes: total_memory,
466 num_tensors,
467 num_nodes: graph.num_nodes(),
468 })
469 }
470}
471
/// Memory report produced by `AutodiffEngine::memory_info`.
#[derive(Debug, Clone)]
pub struct MemoryInfo {
    /// Bytes used by all node values plus any stored gradients.
    pub total_memory_bytes: usize,
    /// Number of tensors counted (values and gradients).
    pub num_tensors: usize,
    /// Number of nodes currently in the graph.
    pub num_nodes: usize,
}
482
/// Lazily-initialized process-wide engine shared by `init_engine`/`get_engine`.
static GLOBAL_ENGINE: OnceLock<Arc<Mutex<AutodiffEngine>>> = OnceLock::new();

/// Installs a global engine built from `config`. If the global engine was
/// already initialized (by a prior call or by `get_engine`), the `set` result
/// is deliberately ignored and the existing engine is kept unchanged.
pub fn init_engine(config: AutodiffConfig) {
    let _ = GLOBAL_ENGINE.set(Arc::new(Mutex::new(AutodiffEngine::new(config))));
}
490
491pub fn get_engine() -> Arc<Mutex<AutodiffEngine>> {
493 GLOBAL_ENGINE
494 .get_or_init(|| Arc::new(Mutex::new(AutodiffEngine::new(AutodiffConfig::default()))))
495 .clone()
496}
497
/// RAII guard that toggles gradient tracking on the global engine and restores
/// the previous state when dropped.
pub struct GradContext {
    // Whether gradients were enabled before this guard was created.
    previous_state: bool,
}
502
503impl GradContext {
504 pub fn enable() -> Self {
506 let engine = get_engine();
507 let previous_state = engine.lock().expect("Lock poisoned").is_grad_enabled();
508 engine.lock().expect("Lock poisoned").enable_grad();
509
510 Self { previous_state }
511 }
512
513 pub fn disable() -> Self {
515 let engine = get_engine();
516 let previous_state = engine.lock().expect("Lock poisoned").is_grad_enabled();
517 engine.lock().expect("Lock poisoned").disable_grad();
518
519 Self { previous_state }
520 }
521}
522
523impl Drop for GradContext {
524 fn drop(&mut self) {
525 let engine = get_engine();
526 if self.previous_state {
527 engine.lock().expect("Lock poisoned").enable_grad();
528 } else {
529 engine.lock().expect("Lock poisoned").disable_grad();
530 }
531 }
532}
533
/// Runs the enclosed statements with gradient tracking disabled on the global
/// engine; the previous state is restored when the block's scope ends via
/// `GradContext`'s `Drop` impl.
///
/// NOTE(review): the `$($stmt:stmt)*` matcher accepts statements with no
/// separator token; some statement sequences may not parse as intended —
/// TODO confirm against call sites.
#[macro_export]
macro_rules! no_grad {
    ($($stmt:stmt)*) => {
        {
            let _ctx = $crate::autodiff::engine::GradContext::disable();
            $($stmt)*
        }
    };
}
544
/// Runs the enclosed statements with gradient tracking enabled on the global
/// engine; the previous state is restored when the block's scope ends via
/// `GradContext`'s `Drop` impl.
///
/// NOTE(review): same `$($stmt:stmt)*` matcher caveat as `no_grad!` —
/// separator-less statement repetition may not parse all inputs.
#[macro_export]
macro_rules! with_grad {
    ($($stmt:stmt)*) => {
        {
            let _ctx = $crate::autodiff::engine::GradContext::enable();
            $($stmt)*
        }
    };
}
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // A freshly built engine starts enabled and in reverse mode.
    #[test]
    fn test_engine_creation() {
        let config = AutodiffConfig::default();
        let engine = AutodiffEngine::new(config);

        assert!(engine.is_grad_enabled());
        assert_eq!(engine.mode(), GradientMode::Reverse);
    }

    // Registering a tensor yields a Variable that keeps shape and grad flag.
    #[test]
    fn test_variable_creation() {
        let engine = AutodiffEngine::default();
        let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let var = engine.variable(tensor, true);

        assert!(var.requires_grad());
        assert_eq!(var.shape().expect("operation failed in test"), vec![2, 3]);
    }

    // d(a*b)/da = b and d(a*b)/db = a for scalars a=2, b=3.
    #[test]
    fn test_gradient_computation() {
        let engine = AutodiffEngine::default();

        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
        let c = a.mul(&b).expect("Multiplication failed");

        engine.backward(&c, None).expect("operation failed in test");

        let grad_a = engine
            .get_grad(&a)
            .expect("operation failed in test")
            .expect("operation failed in test");
        let grad_b = engine
            .get_grad(&b)
            .expect("operation failed in test")
            .expect("operation failed in test");

        assert_eq!(grad_a.to_scalar().expect("operation failed in test"), 3.0);
        assert_eq!(grad_b.to_scalar().expect("operation failed in test"), 2.0);
    }

    // GradContext toggles the GLOBAL engine (not the local one) and restores
    // the previous state on drop. NOTE(review): relies on global-singleton
    // state, so it can interact with other tests using get_engine().
    #[test]
    fn test_grad_context() {
        let engine = AutodiffEngine::default();
        assert!(engine.is_grad_enabled());

        {
            let _ctx = GradContext::disable();
            assert!(!get_engine().lock().expect("Lock poisoned").is_grad_enabled());
        }

        assert!(get_engine().lock().expect("Lock poisoned").is_grad_enabled());
    }

    // A new engine starts with zeroed counters.
    #[test]
    fn test_engine_stats() {
        let engine = AutodiffEngine::default();
        let stats = engine.stats();

        assert_eq!(stats.forward_passes, 0);
        assert_eq!(stats.backward_passes, 0);
    }

    // memory_info reflects a registered tensor: non-zero bytes/tensors/nodes.
    #[test]
    fn test_memory_info() {
        let engine = AutodiffEngine::default();
        let tensor = Tensor::ones(&[100, 100]).expect("Failed to create ones tensor");
        let _var = engine.variable(tensor, true);

        let memory_info = engine.memory_info().expect("operation failed in test");
        assert!(memory_info.total_memory_bytes > 0);
        assert!(memory_info.num_tensors > 0);
        assert!(memory_info.num_nodes > 0);
    }

    // With detection on but no gradient computed, check_anomalies is a no-op.
    #[test]
    fn test_anomaly_detection() {
        let config = AutodiffConfig {
            detect_anomalies: true,
            ..Default::default()
        };
        let engine = AutodiffEngine::new(config);

        let var = engine.variable(Tensor::scalar(1.0).expect("tensor operation failed"), true);
        let result = engine.check_anomalies(&var);

        assert!(result.is_ok());
    }

    // DOT export contains the digraph header and at least one parent edge.
    #[test]
    fn test_graph_export() {
        let engine = AutodiffEngine::default();
        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
        let _c = a.mul(&b).expect("Multiplication failed");

        let dot_graph = engine.export_graph().expect("operation failed in test");
        assert!(dot_graph.contains("digraph G"));
        assert!(dot_graph.contains("->"));
    }
}