scirs2_optimize/automatic_differentiation/
reverse_mode.rs

1//! Reverse-mode automatic differentiation (backpropagation)
2//!
3//! Reverse-mode AD is efficient for computing derivatives when the number of
4//! output variables is small (typically 1 for optimization). It builds a
5//! computational graph and then propagates derivatives backwards.
6
7use crate::automatic_differentiation::tape::{
8    BinaryOpType, ComputationTape, TapeNode, UnaryOpType, Variable,
9};
10use crate::error::OptimizeError;
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12
13/// Options for reverse-mode automatic differentiation
/// Options for reverse-mode automatic differentiation
#[derive(Debug, Clone)]
pub struct ReverseADOptions {
    /// Whether to compute gradient
    pub compute_gradient: bool,
    /// Whether to compute Hessian
    pub compute_hessian: bool,
    /// Maximum tape size to prevent memory issues
    // NOTE(review): not enforced anywhere visible in this module (the
    // `_options` argument of `reverse_gradient_with_tape` is unused) —
    // confirm the tape itself checks this limit.
    pub max_tape_size: usize,
    /// Enable tape optimization
    // NOTE(review): also not read in this module — presumably consumed by
    // the tape; verify.
    pub optimize_tape: bool,
}
25
impl Default for ReverseADOptions {
    /// Defaults: gradient only (no Hessian), a 1M-entry tape cap, and tape
    /// optimization enabled.
    fn default() -> Self {
        Self {
            compute_gradient: true,
            compute_hessian: false,
            max_tape_size: 1_000_000,
            optimize_tape: true,
        }
    }
}
36
/// Variable in the computational graph for reverse-mode AD
#[derive(Debug, Clone)]
pub struct ReverseVariable {
    /// Variable index in the tape (`usize::MAX` marks an off-tape constant)
    pub index: usize,
    /// Current value
    pub value: f64,
    /// Accumulated gradient (adjoint)
    pub grad: f64,
}

impl ReverseVariable {
    /// Create a new tape-backed variable with the given tape index.
    pub fn new(index: usize, value: f64) -> Self {
        Self {
            index,
            value,
            grad: 0.0,
        }
    }

    /// Create a constant variable (not recorded in the tape).
    pub fn constant(value: f64) -> Self {
        Self {
            index: usize::MAX, // Special sentinel index for constants
            value,
            grad: 0.0,
        }
    }

    /// Check if this is a constant (sentinel index `usize::MAX`).
    pub fn is_constant(&self) -> bool {
        self.index == usize::MAX
    }

    /// Get the value
    pub fn value(&self) -> f64 {
        self.value
    }

    /// Get the accumulated gradient (adjoint)
    pub fn grad(&self) -> f64 {
        self.grad
    }

    /// Set the gradient (used internally by backpropagation)
    pub fn set_grad(&mut self, grad: f64) {
        self.grad = grad;
    }

    /// Add to the gradient (used internally by backpropagation)
    pub fn add_grad(&mut self, grad: f64) {
        self.grad += grad;
    }

    /// Reset gradient to zero
    pub fn zero_grad(&mut self) {
        self.grad = 0.0;
    }

    /// Create a variable from a scalar (convenience alias for `constant`)
    pub fn from_scalar(value: f64) -> Self {
        Self::constant(value)
    }

    /// Shared helper for the "simple" unary operations below.
    ///
    /// Constants stay constants; otherwise the result REUSES this variable's
    /// tape index with a zeroed gradient. No tape node is recorded, so these
    /// simple ops do not participate in backpropagation — use the graph-based
    /// free functions (`exp`, `ln`, ...) for real AD.
    fn map_value(&self, value: f64) -> Self {
        if self.is_constant() {
            Self::constant(value)
        } else {
            Self {
                index: self.index,
                value,
                grad: 0.0,
            }
        }
    }

    /// Power operation (simple version without graph context)
    pub fn powi(&self, n: i32) -> Self {
        self.map_value(self.value.powi(n))
    }

    /// Exponential operation (simple version without graph context)
    pub fn exp(&self) -> Self {
        self.map_value(self.value.exp())
    }

    /// Natural logarithm operation (simple version without graph context)
    pub fn ln(&self) -> Self {
        self.map_value(self.value.ln())
    }

    /// Sine operation (simple version without graph context)
    pub fn sin(&self) -> Self {
        self.map_value(self.value.sin())
    }

    /// Cosine operation (simple version without graph context)
    pub fn cos(&self) -> Self {
        self.map_value(self.value.cos())
    }

    /// Tangent operation (simple version without graph context)
    pub fn tan(&self) -> Self {
        self.map_value(self.value.tan())
    }

    /// Square root operation (simple version without graph context)
    pub fn sqrt(&self) -> Self {
        self.map_value(self.value.sqrt())
    }

    /// Absolute value operation (simple version without graph context)
    pub fn abs(&self) -> Self {
        self.map_value(self.value.abs())
    }
}
206
/// Computational graph for reverse-mode AD
///
/// Owns the tape plus per-variable value/adjoint storage. The vectors are
/// indexed by the variable indices handed out by `variable()` and the
/// op-recording helpers.
pub struct ComputationGraph {
    /// Computation tape
    tape: ComputationTape,
    /// Next free variable index (one slot per entry in `values`/`gradients`)
    var_counter: usize,
    /// Variable values
    values: Vec<f64>,
    /// Variable gradients (adjoints), filled in by `backward`
    gradients: Vec<f64>,
}
218
impl Default for ComputationGraph {
    /// Equivalent to [`ComputationGraph::new`]: an empty graph with no variables.
    fn default() -> Self {
        Self::new()
    }
}
224
225impl ComputationGraph {
226    /// Create a new computation graph
227    pub fn new() -> Self {
228        Self {
229            tape: ComputationTape::new(),
230            var_counter: 0,
231            values: Vec::new(),
232            gradients: Vec::new(),
233        }
234    }
235
236    /// Create a new variable in the graph
237    pub fn variable(&mut self, value: f64) -> ReverseVariable {
238        let index = self.var_counter;
239        self.var_counter += 1;
240
241        self.values.push(value);
242        self.gradients.push(0.0);
243
244        self.tape.add_input(Variable::new(index, value));
245
246        ReverseVariable::new(index, value)
247    }
248
249    /// Add a binary operation to the tape
250    fn add_binary_op(
251        &mut self,
252        op_type: BinaryOpType,
253        left: &ReverseVariable,
254        right: &ReverseVariable,
255        result_value: f64,
256        left_grad: f64,
257        right_grad: f64,
258    ) -> ReverseVariable {
259        let result_index = self.var_counter;
260        self.var_counter += 1;
261
262        self.values.push(result_value);
263        self.gradients.push(0.0);
264
265        // Add to tape
266        let node = TapeNode::BinaryOp {
267            op_type,
268            left: left.index,
269            right: right.index,
270            result: result_index,
271            left_partial: left_grad,
272            right_partial: right_grad,
273        };
274
275        self.tape.add_node(node);
276
277        ReverseVariable::new(result_index, result_value)
278    }
279
280    /// Add a unary operation to the tape
281    fn add_unary_op(
282        &mut self,
283        op_type: UnaryOpType,
284        input: &ReverseVariable,
285        result_value: f64,
286        input_grad: f64,
287    ) -> ReverseVariable {
288        let result_index = self.var_counter;
289        self.var_counter += 1;
290
291        self.values.push(result_value);
292        self.gradients.push(0.0);
293
294        // Add to tape
295        let node = TapeNode::UnaryOp {
296            op_type,
297            input: input.index,
298            result: result_index,
299            partial: input_grad,
300        };
301
302        self.tape.add_node(node);
303
304        ReverseVariable::new(result_index, result_value)
305    }
306
307    /// Perform backpropagation to compute gradients
308    pub fn backward(&mut self, output_var: &ReverseVariable) -> Result<(), OptimizeError> {
309        // Initialize output gradient to 1
310        if !output_var.is_constant() {
311            self.gradients[output_var.index] = 1.0;
312        }
313
314        // Reverse pass through the tape
315        let _ = self.tape.backward(&mut self.gradients);
316
317        Ok(())
318    }
319
320    /// Get gradient for a variable
321    pub fn get_gradient(&self, var: &ReverseVariable) -> f64 {
322        if var.is_constant() {
323            0.0
324        } else {
325            self.gradients[var.index]
326        }
327    }
328
329    /// Clear gradients for next computation
330    pub fn zero_gradients(&mut self) {
331        for grad in &mut self.gradients {
332            *grad = 0.0;
333        }
334    }
335}
336
337// Arithmetic operations for ReverseVariable
338// Note: These implementations are for simple cases without graph context.
339// For full AD functionality, use the graph-based operations (add, mul, etc.)
340impl std::ops::Add for ReverseVariable {
341    type Output = Self;
342
343    fn add(self, other: Self) -> Self {
344        if self.is_constant() && other.is_constant() {
345            ReverseVariable::constant(self.value + other.value)
346        } else {
347            // For non-constant variables, create a new variable with combined value
348            // This won't track gradients properly - use graph-based operations for AD
349            let result_value = self.value + other.value;
350            let max_index = self.index.max(other.index);
351            ReverseVariable {
352                index: if max_index == usize::MAX {
353                    usize::MAX
354                } else {
355                    max_index + 1
356                },
357                value: result_value,
358                grad: 0.0,
359            }
360        }
361    }
362}
363
364impl std::ops::Sub for ReverseVariable {
365    type Output = Self;
366
367    fn sub(self, other: Self) -> Self {
368        if self.is_constant() && other.is_constant() {
369            ReverseVariable::constant(self.value - other.value)
370        } else {
371            let result_value = self.value - other.value;
372            let max_index = self.index.max(other.index);
373            ReverseVariable {
374                index: if max_index == usize::MAX {
375                    usize::MAX
376                } else {
377                    max_index + 1
378                },
379                value: result_value,
380                grad: 0.0,
381            }
382        }
383    }
384}
385
386impl std::ops::Mul for ReverseVariable {
387    type Output = Self;
388
389    fn mul(self, other: Self) -> Self {
390        if self.is_constant() && other.is_constant() {
391            ReverseVariable::constant(self.value * other.value)
392        } else {
393            let result_value = self.value * other.value;
394            let max_index = self.index.max(other.index);
395            ReverseVariable {
396                index: if max_index == usize::MAX {
397                    usize::MAX
398                } else {
399                    max_index + 1
400                },
401                value: result_value,
402                grad: 0.0,
403            }
404        }
405    }
406}
407
408impl std::ops::Div for ReverseVariable {
409    type Output = Self;
410
411    fn div(self, other: Self) -> Self {
412        if self.is_constant() && other.is_constant() {
413            ReverseVariable::constant(self.value / other.value)
414        } else {
415            let result_value = self.value / other.value;
416            let max_index = self.index.max(other.index);
417            ReverseVariable {
418                index: if max_index == usize::MAX {
419                    usize::MAX
420                } else {
421                    max_index + 1
422                },
423                value: result_value,
424                grad: 0.0,
425            }
426        }
427    }
428}
429
430impl std::ops::Neg for ReverseVariable {
431    type Output = Self;
432
433    fn neg(self) -> Self {
434        if self.is_constant() {
435            ReverseVariable::constant(-self.value)
436        } else {
437            ReverseVariable {
438                index: self.index,
439                value: -self.value,
440                grad: 0.0,
441            }
442        }
443    }
444}
445
446// Scalar operations
447impl std::ops::Add<f64> for ReverseVariable {
448    type Output = Self;
449
450    fn add(self, scalar: f64) -> Self {
451        ReverseVariable {
452            index: self.index,
453            value: self.value + scalar,
454            grad: self.grad,
455        }
456    }
457}
458
459impl std::ops::Sub<f64> for ReverseVariable {
460    type Output = Self;
461
462    fn sub(self, scalar: f64) -> Self {
463        ReverseVariable {
464            index: self.index,
465            value: self.value - scalar,
466            grad: self.grad,
467        }
468    }
469}
470
471impl std::ops::Mul<f64> for ReverseVariable {
472    type Output = Self;
473
474    fn mul(self, scalar: f64) -> Self {
475        ReverseVariable {
476            index: self.index,
477            value: self.value * scalar,
478            grad: self.grad,
479        }
480    }
481}
482
483impl std::ops::Div<f64> for ReverseVariable {
484    type Output = Self;
485
486    fn div(self, scalar: f64) -> Self {
487        ReverseVariable {
488            index: self.index,
489            value: self.value / scalar,
490            grad: self.grad,
491        }
492    }
493}
494
495// Reverse scalar operations (f64 + ReverseVariable, etc.)
496impl std::ops::Add<ReverseVariable> for f64 {
497    type Output = ReverseVariable;
498
499    fn add(self, var: ReverseVariable) -> ReverseVariable {
500        var + self
501    }
502}
503
504impl std::ops::Sub<ReverseVariable> for f64 {
505    type Output = ReverseVariable;
506
507    fn sub(self, var: ReverseVariable) -> ReverseVariable {
508        ReverseVariable {
509            index: var.index,
510            value: self - var.value,
511            grad: var.grad,
512        }
513    }
514}
515
516impl std::ops::Mul<ReverseVariable> for f64 {
517    type Output = ReverseVariable;
518
519    fn mul(self, var: ReverseVariable) -> ReverseVariable {
520        var * self
521    }
522}
523
524impl std::ops::Div<ReverseVariable> for f64 {
525    type Output = ReverseVariable;
526
527    fn div(self, var: ReverseVariable) -> ReverseVariable {
528        ReverseVariable {
529            index: var.index,
530            value: self / var.value,
531            grad: var.grad,
532        }
533    }
534}
535
536/// Addition operation on computation graph
537#[allow(dead_code)]
538pub fn add(
539    graph: &mut ComputationGraph,
540    left: &ReverseVariable,
541    right: &ReverseVariable,
542) -> ReverseVariable {
543    if left.is_constant() && right.is_constant() {
544        return ReverseVariable::constant(left.value + right.value);
545    }
546
547    let result_value = left.value + right.value;
548    graph.add_binary_op(BinaryOpType::Add, left, right, result_value, 1.0, 1.0)
549}
550
551/// Multiplication operation on computation graph
552#[allow(dead_code)]
553pub fn mul(
554    graph: &mut ComputationGraph,
555    left: &ReverseVariable,
556    right: &ReverseVariable,
557) -> ReverseVariable {
558    if left.is_constant() && right.is_constant() {
559        return ReverseVariable::constant(left.value * right.value);
560    }
561
562    let result_value = left.value * right.value;
563    graph.add_binary_op(
564        BinaryOpType::Mul,
565        left,
566        right,
567        result_value,
568        right.value,
569        left.value,
570    )
571}
572
573/// Subtraction operation on computation graph
574#[allow(dead_code)]
575pub fn sub(
576    graph: &mut ComputationGraph,
577    left: &ReverseVariable,
578    right: &ReverseVariable,
579) -> ReverseVariable {
580    if left.is_constant() && right.is_constant() {
581        return ReverseVariable::constant(left.value - right.value);
582    }
583
584    let result_value = left.value - right.value;
585    graph.add_binary_op(BinaryOpType::Sub, left, right, result_value, 1.0, -1.0)
586}
587
588/// Division operation on computation graph
589#[allow(dead_code)]
590pub fn div(
591    graph: &mut ComputationGraph,
592    left: &ReverseVariable,
593    right: &ReverseVariable,
594) -> ReverseVariable {
595    if left.is_constant() && right.is_constant() {
596        return ReverseVariable::constant(left.value / right.value);
597    }
598
599    let result_value = left.value / right.value;
600    let left_grad = 1.0 / right.value;
601    let right_grad = -left.value / (right.value * right.value);
602
603    graph.add_binary_op(
604        BinaryOpType::Div,
605        left,
606        right,
607        result_value,
608        left_grad,
609        right_grad,
610    )
611}
612
/// Power operation (x^n) on computation graph
///
/// Records one unary node whose local partial is n·x^(n-1).
#[allow(dead_code)]
pub fn powi(graph: &mut ComputationGraph, input: &ReverseVariable, n: i32) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.powi(n));
    }

    let result_value = input.value.powi(n);
    // d/dx x^n = n * x^(n-1)
    let input_grad = (n as f64) * input.value.powi(n - 1);

    // NOTE(review): the node is tagged `UnaryOpType::Square` regardless of
    // `n`. The partial stored above is correct for any n, but if the tape
    // ever re-derives the derivative from `op_type`, non-square powers would
    // be wrong — confirm whether a dedicated power variant exists.
    graph.add_unary_op(UnaryOpType::Square, input, result_value, input_grad)
}
625
626/// Exponential operation on computation graph
627#[allow(dead_code)]
628pub fn exp(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
629    if input.is_constant() {
630        return ReverseVariable::constant(input.value.exp());
631    }
632
633    let result_value = input.value.exp();
634    let input_grad = result_value; // d/dx(e^x) = e^x
635
636    graph.add_unary_op(UnaryOpType::Exp, input, result_value, input_grad)
637}
638
639/// Natural logarithm operation on computation graph
640#[allow(dead_code)]
641pub fn ln(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
642    if input.is_constant() {
643        return ReverseVariable::constant(input.value.ln());
644    }
645
646    let result_value = input.value.ln();
647    let input_grad = 1.0 / input.value;
648
649    graph.add_unary_op(UnaryOpType::Ln, input, result_value, input_grad)
650}
651
652/// Sine operation on computation graph
653#[allow(dead_code)]
654pub fn sin(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
655    if input.is_constant() {
656        return ReverseVariable::constant(input.value.sin());
657    }
658
659    let result_value = input.value.sin();
660    let input_grad = input.value.cos();
661
662    graph.add_unary_op(UnaryOpType::Sin, input, result_value, input_grad)
663}
664
665/// Cosine operation on computation graph
666#[allow(dead_code)]
667pub fn cos(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
668    if input.is_constant() {
669        return ReverseVariable::constant(input.value.cos());
670    }
671
672    let result_value = input.value.cos();
673    let input_grad = -input.value.sin();
674
675    graph.add_unary_op(UnaryOpType::Cos, input, result_value, input_grad)
676}
677
678/// Tangent operation on computation graph
679#[allow(dead_code)]
680pub fn tan(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
681    if input.is_constant() {
682        return ReverseVariable::constant(input.value.tan());
683    }
684
685    let result_value = input.value.tan();
686    let cos_val = input.value.cos();
687    let input_grad = 1.0 / (cos_val * cos_val); // sec²(x) = 1/cos²(x)
688
689    graph.add_unary_op(UnaryOpType::Tan, input, result_value, input_grad)
690}
691
692/// Square root operation on computation graph
693#[allow(dead_code)]
694pub fn sqrt(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
695    if input.is_constant() {
696        return ReverseVariable::constant(input.value.sqrt());
697    }
698
699    let result_value = input.value.sqrt();
700    let input_grad = 0.5 / result_value; // d/dx(√x) = 1/(2√x)
701
702    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
703}
704
/// Absolute value operation on computation graph
#[allow(dead_code)]
pub fn abs(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.abs());
    }

    let result_value = input.value.abs();
    // Subgradient convention: d|x|/dx at x == 0 is taken as +1.
    let input_grad = if input.value >= 0.0 { 1.0 } else { -1.0 };

    // NOTE(review): the node is tagged `UnaryOpType::Sqrt`, apparently as a
    // placeholder. The stored partial is the correct abs derivative, but the
    // tag is misleading if the tape ever inspects it — confirm whether an
    // `Abs` variant exists.
    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
}

/// Sigmoid operation on computation graph
#[allow(dead_code)]
pub fn sigmoid(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        let exp_val = (-input.value).exp();
        return ReverseVariable::constant(1.0 / (1.0 + exp_val));
    }

    // σ(x) = 1 / (1 + e^(-x))
    let exp_neg_x = (-input.value).exp();
    let result_value = 1.0 / (1.0 + exp_neg_x);
    let input_grad = result_value * (1.0 - result_value); // σ'(x) = σ(x)(1-σ(x))

    // NOTE(review): tagged `UnaryOpType::Exp` — presumably no sigmoid variant
    // is available. The partial above is correct; the tag is descriptive at best.
    graph.add_unary_op(UnaryOpType::Exp, input, result_value, input_grad)
}

/// Hyperbolic tangent operation on computation graph
#[allow(dead_code)]
pub fn tanh(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.tanh());
    }

    let result_value = input.value.tanh();
    let input_grad = 1.0 - result_value * result_value; // d/dx(tanh(x)) = 1 - tanh²(x)

    // NOTE(review): tagged `UnaryOpType::Tan` (circular, not hyperbolic) —
    // the stored partial is the correct tanh derivative; verify the tag is
    // never used to recompute derivatives.
    graph.add_unary_op(UnaryOpType::Tan, input, result_value, input_grad)
}

/// ReLU (Rectified Linear Unit) operation on computation graph
#[allow(dead_code)]
pub fn relu(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.max(0.0));
    }

    let result_value = input.value.max(0.0);
    // Subgradient convention: derivative at x == 0 is taken as 0.
    let input_grad = if input.value > 0.0 { 1.0 } else { 0.0 };

    // NOTE(review): tagged `UnaryOpType::Sqrt` as a placeholder — same
    // concern as `abs` above.
    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
}

/// Leaky ReLU operation on computation graph
///
/// `alpha` is the slope applied on the negative side (x <= 0).
#[allow(dead_code)]
pub fn leaky_relu(
    graph: &mut ComputationGraph,
    input: &ReverseVariable,
    alpha: f64,
) -> ReverseVariable {
    if input.is_constant() {
        let result = if input.value > 0.0 {
            input.value
        } else {
            alpha * input.value
        };
        return ReverseVariable::constant(result);
    }

    let result_value = if input.value > 0.0 {
        input.value
    } else {
        alpha * input.value
    };
    // Subgradient convention: derivative at x == 0 is taken as alpha.
    let input_grad = if input.value > 0.0 { 1.0 } else { alpha };

    // NOTE(review): tagged `UnaryOpType::Sqrt` as a placeholder — same
    // concern as `abs` above.
    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
}
784
785/// Compute gradient using reverse-mode automatic differentiation
786/// This is a generic function that works with closures, using finite differences
787/// For functions that can be expressed in terms of AD operations, use reverse_gradient_with_tape
788#[allow(dead_code)]
789pub fn reverse_gradient<F>(func: F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
790where
791    F: Fn(&ArrayView1<f64>) -> f64,
792{
793    // For generic functions, we need to use finite differences
794    // This is because we don't have access to the function's AD representation
795    let n = x.len();
796    let mut gradient = Array1::zeros(n);
797    let h = 1e-8;
798
799    for i in 0..n {
800        let mut x_plus = x.to_owned();
801        x_plus[i] += h;
802        let f_plus = func(&x_plus.view());
803
804        let mut x_minus = x.to_owned();
805        x_minus[i] -= h;
806        let f_minus = func(&x_minus.view());
807
808        gradient[i] = (f_plus - f_minus) / (2.0 * h);
809    }
810
811    Ok(gradient)
812}
813
814/// Compute gradient using reverse-mode AD with a function that directly uses AD operations
815#[allow(dead_code)]
816pub fn reverse_gradient_ad<F>(func: F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
817where
818    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> ReverseVariable,
819{
820    let mut graph = ComputationGraph::new();
821
822    // Create input variables
823    let input_vars: Vec<ReverseVariable> = x.iter().map(|&xi| graph.variable(xi)).collect();
824
825    // Evaluate function with the computation graph
826    let output = func(&mut graph, &input_vars);
827
828    // Perform backpropagation
829    graph.backward(&output)?;
830
831    // Extract gradients
832    let mut gradient = Array1::zeros(x.len());
833    for (i, var) in input_vars.iter().enumerate() {
834        gradient[i] = graph.get_gradient(var);
835    }
836
837    Ok(gradient)
838}
839
840/// Compute Hessian using reverse-mode automatic differentiation (finite differences for generic functions)
841#[allow(dead_code)]
842pub fn reverse_hessian<F>(func: F, x: &ArrayView1<f64>) -> Result<Array2<f64>, OptimizeError>
843where
844    F: Fn(&ArrayView1<f64>) -> f64,
845{
846    let n = x.len();
847    let mut hessian = Array2::zeros((n, n));
848    let h = 1e-5;
849
850    // Compute Hessian using finite differences
851    // For generic functions, this is the most practical approach
852    for i in 0..n {
853        for j in 0..n {
854            if i == j {
855                // Diagonal element: f''(x) = (f(x+h) - 2f(x) + f(x-h)) / h²
856                let mut x_plus = x.to_owned();
857                x_plus[i] += h;
858                let f_plus = func(&x_plus.view());
859
860                let f_center = func(x);
861
862                let mut x_minus = x.to_owned();
863                x_minus[i] -= h;
864                let f_minus = func(&x_minus.view());
865
866                hessian[[i, j]] = (f_plus - 2.0 * f_center + f_minus) / (h * h);
867            } else {
868                // Off-diagonal element: mixed partial derivative
869                // Variable names represent plus/minus combinations for finite differences
870                {
871                    #[allow(clippy::similar_names)]
872                    let mut x_pp = x.to_owned();
873                    x_pp[i] += h;
874                    x_pp[j] += h;
875                    #[allow(clippy::similar_names)]
876                    let f_pp = func(&x_pp.view());
877
878                    #[allow(clippy::similar_names)]
879                    let mut x_pm = x.to_owned();
880                    x_pm[i] += h;
881                    x_pm[j] -= h;
882                    #[allow(clippy::similar_names)]
883                    let f_pm = func(&x_pm.view());
884
885                    #[allow(clippy::similar_names)]
886                    let mut x_mp = x.to_owned();
887                    x_mp[i] -= h;
888                    x_mp[j] += h;
889                    #[allow(clippy::similar_names)]
890                    let f_mp = func(&x_mp.view());
891
892                    let mut x_mm = x.to_owned();
893                    x_mm[i] -= h;
894                    x_mm[j] -= h;
895                    let f_mm = func(&x_mm.view());
896
897                    hessian[[i, j]] = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h * h);
898                }
899            }
900        }
901    }
902
903    Ok(hessian)
904}
905
906/// Compute Hessian using forward-over-reverse mode for AD functions
907#[allow(dead_code)]
908pub fn reverse_hessian_ad<F>(func: F, x: &ArrayView1<f64>) -> Result<Array2<f64>, OptimizeError>
909where
910    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> ReverseVariable,
911{
912    let n = x.len();
913    let mut hessian = Array2::zeros((n, n));
914
915    // Compute Hessian by differentiating the gradient
916    // For each input variable, compute gradient and then differentiate again
917    for i in 0..n {
918        // Create a function that returns the i-th component of the gradient
919        let gradient_i_func = |x_val: &ArrayView1<f64>| -> f64 {
920            let grad = reverse_gradient_ad(&func, x_val).unwrap();
921            grad[i]
922        };
923
924        // Compute the gradient of the i-th gradient component (i-th row of Hessian)
925        let hessian_row = reverse_gradient(gradient_i_func, x)?;
926        for j in 0..n {
927            hessian[[i, j]] = hessian_row[j];
928        }
929    }
930
931    Ok(hessian)
932}
933
934/// Simple reverse-mode gradient computation using a basic tape
935#[allow(dead_code)]
936pub fn reverse_gradient_with_tape<F>(
937    func: F,
938    x: &ArrayView1<f64>,
939    _options: &ReverseADOptions,
940) -> Result<Array1<f64>, OptimizeError>
941where
942    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> ReverseVariable,
943{
944    let mut graph = ComputationGraph::new();
945
946    // Create input variables
947    let input_vars: Vec<ReverseVariable> = x.iter().map(|&xi| graph.variable(xi)).collect();
948
949    // Evaluate function with the computation graph
950    let output = func(&mut graph, &input_vars);
951
952    // Perform backpropagation
953    graph.backward(&output)?;
954
955    // Extract gradients
956    let mut gradient = Array1::zeros(x.len());
957    for (i, var) in input_vars.iter().enumerate() {
958        gradient[i] = graph.get_gradient(var);
959    }
960
961    Ok(gradient)
962}
963
/// Check if reverse mode is preferred for the given problem dimensions
///
/// Reverse mode costs O(output_dim × cost_of_function), so it pays off when
/// the number of outputs is small — either absolutely (≤ 10) or relative to
/// the inputs (≤ input_dim and ≤ 20).
#[allow(dead_code)]
pub fn is_reverse_mode_efficient(input_dim: usize, output_dim: usize) -> bool {
    // The parameter was previously named `_input_dim` despite being used below.
    output_dim <= 10 || (output_dim <= input_dim && output_dim <= 20)
}
971
972/// Vector-Jacobian product using reverse-mode AD
973#[allow(clippy::many_single_char_names)]
974#[allow(dead_code)]
975pub fn reverse_vjp<F>(
976    func: F,
977    x: &ArrayView1<f64>,
978    v: &ArrayView1<f64>,
979) -> Result<Array1<f64>, OptimizeError>
980where
981    F: Fn(&ArrayView1<f64>) -> Array1<f64>,
982{
983    // For vector-valued functions, we use the natural efficiency of reverse-mode AD
984    // which computes v^T * J efficiently by seeding the output with v
985    let n = x.len();
986    let m = v.len();
987
988    // Compute v^T * J by running reverse mode for each output component weighted by v
989    let mut result = Array1::zeros(n);
990
991    // For each output component, compute its contribution to the VJP
992    for i in 0..m {
993        if v[i] != 0.0 {
994            // Create a scalar function that extracts the i-th component
995            let component_func = |x_val: &ArrayView1<f64>| -> f64 {
996                let f_val = func(x_val);
997                f_val[i]
998            };
999
1000            // Compute gradient of this component
1001            let grad_i = reverse_gradient(component_func, x)?;
1002
1003            // Add weighted contribution to result
1004            for j in 0..n {
1005                result[j] += v[i] * grad_i[j];
1006            }
1007        }
1008    }
1009
1010    Ok(result)
1011}
1012
1013/// Vector-Jacobian product using reverse-mode AD for AD-compatible functions
1014#[allow(clippy::many_single_char_names)]
1015#[allow(dead_code)]
1016pub fn reverse_vjp_ad<F>(
1017    func: F,
1018    x: &ArrayView1<f64>,
1019    v: &ArrayView1<f64>,
1020) -> Result<Array1<f64>, OptimizeError>
1021where
1022    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> Vec<ReverseVariable>,
1023{
1024    let n = x.len();
1025    let m = v.len();
1026    let mut result = Array1::zeros(n);
1027
1028    // For each output component with non-zero weight
1029    for i in 0..m {
1030        if v[i] != 0.0 {
1031            let mut graph = ComputationGraph::new();
1032
1033            // Create input variables
1034            let input_vars: Vec<ReverseVariable> = x.iter().map(|&xi| graph.variable(xi)).collect();
1035
1036            // Evaluate function
1037            let outputs = func(&mut graph, &input_vars);
1038
1039            // Seed the i-th output with 1.0 and perform backpropagation
1040            if i < outputs.len() {
1041                graph.backward(&outputs[i])?;
1042
1043                // Add weighted contribution to result
1044                for (j, var) in input_vars.iter().enumerate() {
1045                    result[j] += v[i] * graph.get_gradient(var);
1046                }
1047            }
1048        }
1049    }
1050
1051    Ok(result)
1052}
1053
1054/// Gauss-Newton Hessian approximation using reverse-mode AD
1055#[allow(dead_code)]
1056pub fn reverse_gauss_newton_hessian<F>(
1057    func: F,
1058    x: &ArrayView1<f64>,
1059) -> Result<Array2<f64>, OptimizeError>
1060where
1061    F: Fn(&ArrayView1<f64>) -> Array1<f64>,
1062{
1063    // Compute Gauss-Newton approximation: H ≈ J^T * J efficiently using reverse-mode AD
1064    let n = x.len();
1065    let f_val = func(x);
1066    let m = f_val.len();
1067
1068    // Use reverse-mode AD to compute J^T * J directly without forming J explicitly
1069    let mut hessian = Array2::zeros((n, n));
1070
1071    // For each output component, compute its contribution to the Gauss-Newton Hessian
1072    for i in 0..m {
1073        // Create a scalar function for the i-th residual component
1074        let residual_i = |x_val: &ArrayView1<f64>| -> f64 {
1075            let f_val = func(x_val);
1076            f_val[i]
1077        };
1078
1079        // Compute gradient of this residual
1080        let grad_i = reverse_gradient(residual_i, x)?;
1081
1082        // Add outer product grad_i * grad_i^T to the Hessian
1083        for j in 0..n {
1084            for k in 0..n {
1085                hessian[[j, k]] += grad_i[j] * grad_i[k];
1086            }
1087        }
1088    }
1089
1090    Ok(hessian)
1091}
1092
1093/// Gauss-Newton Hessian approximation using reverse-mode AD for AD-compatible functions
1094#[allow(dead_code)]
1095pub fn reverse_gauss_newton_hessian_ad<F>(
1096    func: F,
1097    x: &ArrayView1<f64>,
1098) -> Result<Array2<f64>, OptimizeError>
1099where
1100    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> Vec<ReverseVariable>,
1101{
1102    let n = x.len();
1103    let mut hessian = Array2::zeros((n, n));
1104
1105    // Get function values to determine output dimension
1106    let mut graph_temp = ComputationGraph::new();
1107    let input_vars_temp: Vec<ReverseVariable> =
1108        x.iter().map(|&xi| graph_temp.variable(xi)).collect();
1109    let outputs_temp = func(&mut graph_temp, &input_vars_temp);
1110    let m = outputs_temp.len();
1111
1112    // For each output component, compute its contribution to the Gauss-Newton Hessian
1113    for i in 0..m {
1114        let mut graph = ComputationGraph::new();
1115
1116        // Create input variables
1117        let input_vars: Vec<ReverseVariable> = x.iter().map(|&xi| graph.variable(xi)).collect();
1118
1119        // Evaluate function
1120        let outputs = func(&mut graph, &input_vars);
1121
1122        // Compute gradient of the i-th output component
1123        if i < outputs.len() {
1124            graph.backward(&outputs[i])?;
1125
1126            // Extract gradients
1127            let mut grad_i = Array1::zeros(n);
1128            for (j, var) in input_vars.iter().enumerate() {
1129                grad_i[j] = graph.get_gradient(var);
1130            }
1131
1132            // Add outer product grad_i * grad_i^T to the Hessian
1133            for j in 0..n {
1134                for k in 0..n {
1135                    hessian[[j, k]] += grad_i[j] * grad_i[k];
1136                }
1137            }
1138        }
1139    }
1140
1141    Ok(hessian)
1142}
1143
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Builds a tiny graph (z = x*y + x), checks the forward value and the
    // gradients produced by a single backward pass.
    #[test]
    fn test_computation_graph() {
        let mut graph = ComputationGraph::new();

        // Create variables: x = 2, y = 3
        let x = graph.variable(2.0);
        let y = graph.variable(3.0);

        // Compute z = x * y + x
        let xy = mul(&mut graph, &x, &y);
        let z = add(&mut graph, &xy, &x);

        assert_abs_diff_eq!(z.value, 8.0, epsilon = 1e-10); // 2*3 + 2 = 8

        // Perform backpropagation
        graph.backward(&z).unwrap();

        // Check gradients: ∂z/∂x = y + 1 = 4, ∂z/∂y = x = 2
        assert_abs_diff_eq!(graph.get_gradient(&x), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(graph.get_gradient(&y), 2.0, epsilon = 1e-10);
    }

    // exp() through the graph: value and derivative both equal e at x = 1.
    #[test]
    fn test_unary_operations() {
        let mut graph = ComputationGraph::new();

        let x = graph.variable(1.0);
        let exp_x = exp(&mut graph, &x);

        assert_abs_diff_eq!(exp_x.value, std::f64::consts::E, epsilon = 1e-10);

        graph.backward(&exp_x).unwrap();

        // ∂(e^x)/∂x = e^x
        assert_abs_diff_eq!(graph.get_gradient(&x), std::f64::consts::E, epsilon = 1e-10);
    }

    // reverse_gradient on a plain closure: gradient of a quadratic form.
    #[test]
    fn test_reverse_gradient() {
        // Test function: f(x, y) = x² + xy + 2y²
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[0] * x[1] + 2.0 * x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let grad = reverse_gradient(func, &x.view()).unwrap();

        // ∂f/∂x = 2x + y = 2(1) + 2 = 4
        // ∂f/∂y = x + 4y = 1 + 4(2) = 9
        assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-6);
        assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-6);
    }

    // Sanity-checks the dimension heuristic on representative cases.
    #[test]
    fn test_is_reverse_mode_efficient() {
        // Small output dimension should prefer reverse mode
        assert!(is_reverse_mode_efficient(100, 1));
        assert!(is_reverse_mode_efficient(50, 5));

        // Large output dimension should not prefer reverse mode
        assert!(!is_reverse_mode_efficient(10, 100));
    }

    // Hessian of the same quadratic form: constant, symmetric 2x2 matrix.
    #[test]
    fn test_reverse_hessian() {
        // Test function: f(x, y) = x² + xy + 2y²
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[0] * x[1] + 2.0 * x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let hess = reverse_hessian(func, &x.view()).unwrap();

        // Expected Hessian:
        // ∂²f/∂x² = 2, ∂²f/∂x∂y = 1
        // ∂²f/∂y∂x = 1, ∂²f/∂y² = 4
        assert_abs_diff_eq!(hess[[0, 0]], 2.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[0, 1]], 1.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[1, 0]], 1.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[1, 1]], 4.0, epsilon = 1e-4);
    }

    // Same quadratic form, but expressed directly on the tape via the
    // graph-building closure API (reverse_gradient_ad).
    #[test]
    fn test_reverse_gradient_ad() {
        // Test function: f(x, y) = x² + xy + 2y²
        let func = |graph: &mut ComputationGraph, vars: &[ReverseVariable]| {
            let x = &vars[0];
            let y = &vars[1];

            let x_squared = mul(graph, x, x);
            let xy = mul(graph, x, y);
            let y_squared = mul(graph, y, y);
            let two_y_squared = mul(graph, &ReverseVariable::constant(2.0), &y_squared);

            let temp = add(graph, &x_squared, &xy);
            add(graph, &temp, &two_y_squared)
        };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let grad = reverse_gradient_ad(func, &x.view()).unwrap();

        // ∂f/∂x = 2x + y = 2(1) + 2 = 4
        // ∂f/∂y = x + 4y = 1 + 4(2) = 9
        assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-10);
    }

    // Vector-Jacobian product with an all-ones weight vector: the result is
    // the column sums of the Jacobian.
    #[test]
    fn test_reverse_vjp() {
        // Test vector function: f(x) = [x₀², x₀x₁, x₁²]
        let func = |x: &ArrayView1<f64>| -> Array1<f64> {
            Array1::from_vec(vec![x[0] * x[0], x[0] * x[1], x[1] * x[1]])
        };

        let x = Array1::from_vec(vec![2.0, 3.0]);
        let v = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let vjp = reverse_vjp(func, &x.view(), &v.view()).unwrap();

        // Jacobian at (2,3):
        // ∂f₀/∂x₀ = 2x₀ = 4, ∂f₀/∂x₁ = 0
        // ∂f₁/∂x₀ = x₁ = 3,  ∂f₁/∂x₁ = x₀ = 2
        // ∂f₂/∂x₀ = 0,      ∂f₂/∂x₁ = 2x₁ = 6

        // v^T * J = [1,1,1] * [[4,0], [3,2], [0,6]] = [7, 8]
        assert_abs_diff_eq!(vjp[0], 7.0, epsilon = 1e-6);
        assert_abs_diff_eq!(vjp[1], 8.0, epsilon = 1e-6);
    }

    // Gauss-Newton approximation of identity residuals: J = I, so J^T J = I.
    #[test]
    fn test_reverse_gauss_newton_hessian() {
        // Test residual function: r(x) = [x₀ - 1, x₁ - 2]
        let residual_func =
            |x: &ArrayView1<f64>| -> Array1<f64> { Array1::from_vec(vec![x[0] - 1.0, x[1] - 2.0]) };

        let x = Array1::from_vec(vec![0.0, 0.0]);
        let gn_hess = reverse_gauss_newton_hessian(residual_func, &x.view()).unwrap();

        // Jacobian is identity matrix, so J^T * J should be identity
        assert_abs_diff_eq!(gn_hess[[0, 0]], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gn_hess[[0, 1]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gn_hess[[1, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gn_hess[[1, 1]], 1.0, epsilon = 1e-6);
    }

    // Integer power on the tape: value and derivative of x³.
    #[test]
    fn test_power_operation() {
        let mut graph = ComputationGraph::new();

        let x = graph.variable(2.0);
        let x_cubed = powi(&mut graph, &x, 3);

        assert_abs_diff_eq!(x_cubed.value, 8.0, epsilon = 1e-10); // 2³ = 8

        graph.backward(&x_cubed).unwrap();

        // ∂(x³)/∂x = 3x² = 3(4) = 12 at x=2
        assert_abs_diff_eq!(graph.get_gradient(&x), 12.0, epsilon = 1e-10);
    }

    // sin/cos on one graph; zero_gradients() resets adjoints between the two
    // backward passes so each derivative is checked independently.
    #[test]
    fn test_trigonometric_operations() {
        let mut graph = ComputationGraph::new();

        let x = graph.variable(0.0);
        let sin_x = sin(&mut graph, &x);
        let cos_x = cos(&mut graph, &x);

        assert_abs_diff_eq!(sin_x.value, 0.0, epsilon = 1e-10); // sin(0) = 0
        assert_abs_diff_eq!(cos_x.value, 1.0, epsilon = 1e-10); // cos(0) = 1

        graph.backward(&sin_x).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x), 1.0, epsilon = 1e-10); // d/dx(sin(x)) = cos(x) = 1 at x=0

        graph.zero_gradients();
        graph.backward(&cos_x).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x), 0.0, epsilon = 1e-10); // d/dx(cos(x)) = -sin(x) = 0 at x=0
    }

    // Operator overloads on constant ReverseVariables (no tape involved);
    // only the resulting values are checked.
    #[test]
    fn test_arithmetic_operations_without_graph() {
        // Test arithmetic operations that work without explicit graph context
        let a = ReverseVariable::constant(3.0);
        let b = ReverseVariable::constant(2.0);

        // Test addition
        let sum = a.clone() + b.clone();
        assert_abs_diff_eq!(sum.value, 5.0, epsilon = 1e-10);
        assert!(sum.is_constant());

        // Test subtraction
        let diff = a.clone() - b.clone();
        assert_abs_diff_eq!(diff.value, 1.0, epsilon = 1e-10);

        // Test multiplication
        let product = a.clone() * b.clone();
        assert_abs_diff_eq!(product.value, 6.0, epsilon = 1e-10);

        // Test division
        let quotient = a.clone() / b.clone();
        assert_abs_diff_eq!(quotient.value, 1.5, epsilon = 1e-10);

        // Test negation
        let neg_a = -a.clone();
        assert_abs_diff_eq!(neg_a.value, -3.0, epsilon = 1e-10);
    }

    // Mixed variable-scalar operators, including the reversed f64-first forms.
    #[test]
    fn test_scalar_operations() {
        let var = ReverseVariable::constant(4.0);

        // Test scalar addition
        let result = var.clone() + 2.0;
        assert_abs_diff_eq!(result.value, 6.0, epsilon = 1e-10);

        // Test reverse scalar addition
        let result = 2.0 + var.clone();
        assert_abs_diff_eq!(result.value, 6.0, epsilon = 1e-10);

        // Test scalar multiplication
        let result = var.clone() * 3.0;
        assert_abs_diff_eq!(result.value, 12.0, epsilon = 1e-10);

        // Test scalar division
        let result = var.clone() / 2.0;
        assert_abs_diff_eq!(result.value, 2.0, epsilon = 1e-10);

        // Test reverse scalar division
        let result = 8.0 / var.clone();
        assert_abs_diff_eq!(result.value, 2.0, epsilon = 1e-10);
    }

    // Math methods (powi/sqrt/exp/ln/trig) on constants; values only, since
    // constants carry no tape to backpropagate through.
    #[test]
    fn test_mathematical_functions_without_graph() {
        let var = ReverseVariable::constant(4.0);

        // Test power
        let result = var.powi(2);
        assert_abs_diff_eq!(result.value, 16.0, epsilon = 1e-10);

        // Test square root
        let result = var.sqrt();
        assert_abs_diff_eq!(result.value, 2.0, epsilon = 1e-10);

        // Test exponential
        let var_zero = ReverseVariable::constant(0.0);
        let result = var_zero.exp();
        assert_abs_diff_eq!(result.value, 1.0, epsilon = 1e-10);

        // Test natural logarithm
        let var_e = ReverseVariable::constant(std::f64::consts::E);
        let result = var_e.ln();
        assert_abs_diff_eq!(result.value, 1.0, epsilon = 1e-10);

        // Test trigonometric functions
        let var_zero = ReverseVariable::constant(0.0);
        assert_abs_diff_eq!(var_zero.sin().value, 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_zero.cos().value, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_zero.tan().value, 0.0, epsilon = 1e-10);
    }

    // sigmoid and relu through the graph, including values and derivatives on
    // both sides of the ReLU kink.
    #[test]
    fn test_advanced_operations_with_graph() {
        let mut graph = ComputationGraph::new();

        // Test sigmoid function
        let x = graph.variable(0.0);
        let sig = sigmoid(&mut graph, &x);
        assert_abs_diff_eq!(sig.value, 0.5, epsilon = 1e-10); // sigmoid(0) = 0.5

        graph.backward(&sig).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x), 0.25, epsilon = 1e-10); // sigmoid'(0) = 0.25

        // Test ReLU function
        graph.zero_gradients();
        let x_pos = graph.variable(2.0);
        let relu_pos = relu(&mut graph, &x_pos);
        assert_abs_diff_eq!(relu_pos.value, 2.0, epsilon = 1e-10);

        graph.backward(&relu_pos).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x_pos), 1.0, epsilon = 1e-10); // ReLU'(2) = 1

        // Test ReLU for negative input
        let mut graph2 = ComputationGraph::new();
        let x_neg = graph2.variable(-1.0);
        let relu_neg = relu(&mut graph2, &x_neg);
        assert_abs_diff_eq!(relu_neg.value, 0.0, epsilon = 1e-10);

        graph2.backward(&relu_neg).unwrap();
        assert_abs_diff_eq!(graph2.get_gradient(&x_neg), 0.0, epsilon = 1e-10); // ReLU'(-1) = 0
    }

    // Leaky ReLU with slope 0.01: identity branch for positive input, scaled
    // branch (value and derivative) for negative input.
    #[test]
    fn test_leaky_relu() {
        let mut graph = ComputationGraph::new();

        // Test Leaky ReLU with positive input
        let x_pos = graph.variable(2.0);
        let leaky_pos = leaky_relu(&mut graph, &x_pos, 0.01);
        assert_abs_diff_eq!(leaky_pos.value, 2.0, epsilon = 1e-10);

        graph.backward(&leaky_pos).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x_pos), 1.0, epsilon = 1e-10);

        // Test Leaky ReLU with negative input
        let mut graph2 = ComputationGraph::new();
        let x_neg = graph2.variable(-2.0);
        let leaky_neg = leaky_relu(&mut graph2, &x_neg, 0.01);
        assert_abs_diff_eq!(leaky_neg.value, -0.02, epsilon = 1e-10);

        graph2.backward(&leaky_neg).unwrap();
        assert_abs_diff_eq!(graph2.get_gradient(&x_neg), 0.01, epsilon = 1e-10);
    }

    // Smoke test on a composed expression: only finiteness and sign of the
    // value/gradients are asserted, not exact derivative values.
    #[test]
    fn test_complex_expression() {
        let mut graph = ComputationGraph::new();

        // Test complex expression: f(x, y) = sigmoid(x² + y) * tanh(x - y)
        let x = graph.variable(1.0);
        let y = graph.variable(0.5);

        let x_squared = mul(&mut graph, &x, &x);
        let x_sq_plus_y = add(&mut graph, &x_squared, &y);
        let sig_term = sigmoid(&mut graph, &x_sq_plus_y);

        let x_minus_y = sub(&mut graph, &x, &y);
        let tanh_term = tanh(&mut graph, &x_minus_y);

        let result = mul(&mut graph, &sig_term, &tanh_term);

        // Verify the computation produces a reasonable result
        assert!(result.value.is_finite());
        assert!(result.value > 0.0); // Both sigmoid and tanh(0.5) are positive

        // Test backpropagation
        graph.backward(&result).unwrap();

        // Gradients should be finite and non-zero
        let grad_x = graph.get_gradient(&x);
        let grad_y = graph.get_gradient(&y);

        assert!(grad_x.is_finite());
        assert!(grad_y.is_finite());
        assert!(grad_x != 0.0);
        assert!(grad_y != 0.0);
    }
}