scirs2_integrate/autodiff/reverse.rs

//! Reverse mode automatic differentiation (backpropagation)
//!
//! Reverse mode AD records a computation on a tape and then propagates
//! adjoints backwards through it. It is efficient for computing gradients
//! when the number of outputs is small compared to the number of inputs,
//! since a single backward pass yields the derivative with respect to
//! every input.
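//!
//! A minimal usage sketch (the import path for this module is assumed here;
//! adjust it to wherever the crate exports `reverse`):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array1;
//! use scirs2_integrate::autodiff::reverse::{reverse_gradient, Tape};
//!
//! // f(x, y) = x^2 + y^2, rebuilt on the tape node by node
//! let f = |tape: &mut Tape<f64>, vars: &[usize]| {
//!     let x_sq = tape.mul(vars[0], vars[0]);
//!     let y_sq = tape.mul(vars[1], vars[1]);
//!     tape.add(x_sq, y_sq)
//! };
//!
//! let x = Array1::from_vec(vec![3.0, 4.0]);
//! let grad = reverse_gradient(f, x.view()).unwrap(); // [2x, 2y] = [6.0, 8.0]
//! ```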

use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

/// Operations that can be recorded on the tape
#[derive(Debug, Clone)]
pub enum Operation<F: IntegrateFloat> {
    /// Variable input
    Variable(usize),
    /// Constant value
    Constant(F),
    /// Addition
    Add(usize, usize),
    /// Subtraction
    Sub(usize, usize),
    /// Multiplication
    Mul(usize, usize),
    /// Division
    Div(usize, usize),
    /// Negation
    Neg(usize),
    /// Power with a constant exponent
    Pow(usize, F),
    /// General power (base and exponent are both tape values)
    PowGeneral(usize, usize),
    /// Sine
    Sin(usize),
    /// Cosine
    Cos(usize),
    /// Tangent
    Tan(usize),
    /// Exponential
    Exp(usize),
    /// Natural logarithm
    Ln(usize),
    /// Square root
    Sqrt(usize),
    /// Hyperbolic tangent
    Tanh(usize),
    /// Hyperbolic sine
    Sinh(usize),
    /// Hyperbolic cosine
    Cosh(usize),
    /// Two-argument arctangent
    Atan2(usize, usize),
    /// Absolute value
    Abs(usize),
    /// Maximum of two values
    Max(usize, usize),
    /// Minimum of two values
    Min(usize, usize),
}

/// Node in the computation graph
pub struct TapeNode<F: IntegrateFloat> {
    /// The value at this node
    pub value: F,
    /// The operation that produced this value
    pub operation: Operation<F>,
    /// The gradient accumulated at this node
    pub gradient: RefCell<F>,
}

impl<F: IntegrateFloat> TapeNode<F> {
    /// Create a new tape node
    pub fn new(value: F, operation: Operation<F>) -> Self {
        TapeNode {
            value,
            operation,
            gradient: RefCell::new(F::zero()),
        }
    }
}

/// Reverse mode AD tape for recording operations
pub struct Tape<F: IntegrateFloat> {
    /// Nodes in the computation graph
    nodes: Vec<Rc<TapeNode<F>>>,
    /// Mapping from variable indices to node indices
    var_map: HashMap<usize, usize>,
}

impl<F: IntegrateFloat> Tape<F> {
    /// Create a new tape
    pub fn new() -> Self {
        Tape {
            nodes: Vec::new(),
            var_map: HashMap::new(),
        }
    }

    /// Add a variable to the tape
    pub fn variable(&mut self, idx: usize, value: F) -> usize {
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Variable(idx))));
        self.var_map.insert(idx, node_idx);
        node_idx
    }

    /// Add a constant to the tape
    pub fn constant(&mut self, value: F) -> usize {
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Constant(value))));
        node_idx
    }

    /// Record addition
    pub fn add(&mut self, a: usize, b: usize) -> usize {
        let value = self.nodes[a].value + self.nodes[b].value;
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Add(a, b))));
        node_idx
    }

    /// Record subtraction
    pub fn sub(&mut self, a: usize, b: usize) -> usize {
        let value = self.nodes[a].value - self.nodes[b].value;
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Sub(a, b))));
        node_idx
    }

    /// Record multiplication
    pub fn mul(&mut self, a: usize, b: usize) -> usize {
        let value = self.nodes[a].value * self.nodes[b].value;
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Mul(a, b))));
        node_idx
    }

    /// Record division
    pub fn div(&mut self, a: usize, b: usize) -> usize {
        let value = self.nodes[a].value / self.nodes[b].value;
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Div(a, b))));
        node_idx
    }

    /// Record negation
    pub fn neg(&mut self, a: usize) -> usize {
        let value = -self.nodes[a].value;
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Neg(a))));
        node_idx
    }

    /// Record power
    pub fn pow(&mut self, a: usize, n: F) -> usize {
        let value = self.nodes[a].value.powf(n);
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Pow(a, n))));
        node_idx
    }

    /// Record sin
    pub fn sin(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.sin();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Sin(a))));
        node_idx
    }

    /// Record cos
    pub fn cos(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.cos();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Cos(a))));
        node_idx
    }

    /// Record exp
    pub fn exp(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.exp();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Exp(a))));
        node_idx
    }

    /// Record ln
    pub fn ln(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.ln();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Ln(a))));
        node_idx
    }

    /// Record sqrt
    pub fn sqrt(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.sqrt();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Sqrt(a))));
        node_idx
    }

    /// Record general power where both base and exponent are variables
    pub fn pow_general(&mut self, a: usize, b: usize) -> usize {
        let value = self.nodes[a].value.powf(self.nodes[b].value);
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::PowGeneral(a, b))));
        node_idx
    }

    /// Record tan
    pub fn tan(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.tan();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Tan(a))));
        node_idx
    }

    /// Record tanh
    pub fn tanh(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.tanh();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Tanh(a))));
        node_idx
    }

    /// Record sinh
    pub fn sinh(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.sinh();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Sinh(a))));
        node_idx
    }

    /// Record cosh
    pub fn cosh(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.cosh();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Cosh(a))));
        node_idx
    }

    /// Record atan2
    pub fn atan2(&mut self, y: usize, x: usize) -> usize {
        let value = self.nodes[y].value.atan2(self.nodes[x].value);
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Atan2(y, x))));
        node_idx
    }

    /// Record abs
    pub fn abs(&mut self, a: usize) -> usize {
        let value = self.nodes[a].value.abs();
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Abs(a))));
        node_idx
    }

    /// Record max
    pub fn max(&mut self, a: usize, b: usize) -> usize {
        let value = self.nodes[a].value.max(self.nodes[b].value);
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Max(a, b))));
        node_idx
    }

    /// Record min
    pub fn min(&mut self, a: usize, b: usize) -> usize {
        let value = self.nodes[a].value.min(self.nodes[b].value);
        let node_idx = self.nodes.len();
        self.nodes
            .push(Rc::new(TapeNode::new(value, Operation::Min(a, b))));
        node_idx
    }

    /// Get the value at a node
    pub fn value(&self, idx: usize) -> F {
        self.nodes[idx].value
    }

    /// Backward pass to compute gradients
    pub fn backward(&mut self, output_idx: usize, nvars: usize) -> Array1<F> {
        // Re-initialize all gradients to zero so the tape can be reused
        for node in &self.nodes {
            *node.gradient.borrow_mut() = F::zero();
        }

        // Seed the output with gradient 1
        *self.nodes[output_idx].gradient.borrow_mut() = F::one();

        // Nodes are stored in evaluation order, so a reverse sweep visits
        // each node after every node that consumes it
        for i in (0..=output_idx).rev() {
            let node = &self.nodes[i];
            let grad = *node.gradient.borrow();

            // Nothing to propagate; skip only exact zeros, since a tolerance
            // here would silently drop small but genuine contributions
            if grad == F::zero() {
                continue;
            }

            match &node.operation {
                Operation::Variable(_) | Operation::Constant(_) => {}
                Operation::Add(a, b) => {
                    *self.nodes[*a].gradient.borrow_mut() += grad;
                    *self.nodes[*b].gradient.borrow_mut() += grad;
                }
                Operation::Sub(a, b) => {
                    *self.nodes[*a].gradient.borrow_mut() += grad;
                    *self.nodes[*b].gradient.borrow_mut() -= grad;
                }
                Operation::Mul(a, b) => {
                    *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*b].value;
                    *self.nodes[*b].gradient.borrow_mut() += grad * self.nodes[*a].value;
                }
                Operation::Div(a, b) => {
                    let b_val = self.nodes[*b].value;
                    *self.nodes[*a].gradient.borrow_mut() += grad / b_val;
                    *self.nodes[*b].gradient.borrow_mut() -=
                        grad * self.nodes[*a].value / (b_val * b_val);
                }
                Operation::Neg(a) => {
                    *self.nodes[*a].gradient.borrow_mut() -= grad;
                }
                Operation::Pow(a, n) => {
                    // d/dx(x^n) = n * x^(n-1)
                    *self.nodes[*a].gradient.borrow_mut() +=
                        grad * *n * self.nodes[*a].value.powf(*n - F::one());
                }
                Operation::Sin(a) => {
                    *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.cos();
                }
                Operation::Cos(a) => {
                    *self.nodes[*a].gradient.borrow_mut() -= grad * self.nodes[*a].value.sin();
                }
                Operation::Exp(a) => {
                    // d/dx(e^x) = e^x, which is this node's own value
                    *self.nodes[*a].gradient.borrow_mut() += grad * node.value;
                }
                Operation::Ln(a) => {
                    *self.nodes[*a].gradient.borrow_mut() += grad / self.nodes[*a].value;
                }
                Operation::Sqrt(a) => {
                    // d/dx(sqrt(x)) = 1 / (2 * sqrt(x))
                    *self.nodes[*a].gradient.borrow_mut() +=
                        grad / (F::from(2.0).unwrap() * node.value);
                }
                Operation::PowGeneral(a, b) => {
                    // d/da(a^b) = b * a^(b-1)
                    // d/db(a^b) = a^b * ln(a)
                    let a_val = self.nodes[*a].value;
                    let b_val = self.nodes[*b].value;
                    *self.nodes[*a].gradient.borrow_mut() +=
                        grad * b_val * a_val.powf(b_val - F::one());
                    *self.nodes[*b].gradient.borrow_mut() += grad * node.value * a_val.ln();
                }
                Operation::Tan(a) => {
                    // d/dx(tan(x)) = sec²(x) = 1/cos²(x)
                    let cos_val = self.nodes[*a].value.cos();
                    *self.nodes[*a].gradient.borrow_mut() += grad / (cos_val * cos_val);
                }
                Operation::Tanh(a) => {
                    // d/dx(tanh(x)) = 1 - tanh²(x)
                    let tanh_val = node.value;
                    *self.nodes[*a].gradient.borrow_mut() +=
                        grad * (F::one() - tanh_val * tanh_val);
                }
                Operation::Sinh(a) => {
                    // d/dx(sinh(x)) = cosh(x)
                    *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.cosh();
                }
                Operation::Cosh(a) => {
                    // d/dx(cosh(x)) = sinh(x)
                    *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.sinh();
                }
                Operation::Atan2(y, x) => {
                    // d/dy(atan2(y,x)) = x/(x² + y²)
                    // d/dx(atan2(y,x)) = -y/(x² + y²)
                    let x_val = self.nodes[*x].value;
                    let y_val = self.nodes[*y].value;
                    let denom = x_val * x_val + y_val * y_val;
                    *self.nodes[*y].gradient.borrow_mut() += grad * x_val / denom;
                    *self.nodes[*x].gradient.borrow_mut() -= grad * y_val / denom;
                }
                Operation::Abs(a) => {
                    // d/dx(|x|) = sign(x); the subgradient +1 is used at x = 0
                    let sign = if self.nodes[*a].value >= F::zero() {
                        F::one()
                    } else {
                        -F::one()
                    };
                    *self.nodes[*a].gradient.borrow_mut() += grad * sign;
                }
                Operation::Max(a, b) => {
                    // Gradient flows to the larger input (ties go to `a`)
                    if self.nodes[*a].value >= self.nodes[*b].value {
                        *self.nodes[*a].gradient.borrow_mut() += grad;
                    } else {
                        *self.nodes[*b].gradient.borrow_mut() += grad;
                    }
                }
                Operation::Min(a, b) => {
                    // Gradient flows to the smaller input (ties go to `a`)
                    if self.nodes[*a].value <= self.nodes[*b].value {
                        *self.nodes[*a].gradient.borrow_mut() += grad;
                    } else {
                        *self.nodes[*b].gradient.borrow_mut() += grad;
                    }
                }
            }
        }

        // Collect gradients for the independent variables
        let mut gradients = Array1::zeros(nvars);
        for (var_idx, &node_idx) in &self.var_map {
            if *var_idx < nvars {
                gradients[*var_idx] = *self.nodes[node_idx].gradient.borrow();
            }
        }

        gradients
    }
}

impl<F: IntegrateFloat> Default for Tape<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Checkpointing strategy for memory-efficient gradient computation
#[derive(Debug, Clone, Copy)]
pub enum CheckpointStrategy {
    /// No checkpointing (store everything)
    None,
    /// Fixed interval checkpointing
    FixedInterval(usize),
    /// Logarithmic checkpointing
    Logarithmic,
    /// Memory-based checkpointing
    MemoryBased { max_nodes: usize },
}

/// Reverse mode automatic differentiation engine
pub struct ReverseAD<F: IntegrateFloat> {
    /// Number of independent variables
    nvars: usize,
    /// Checkpointing strategy (recorded, but not yet consulted by the
    /// current implementation, which always stores the full tape)
    checkpoint_strategy: CheckpointStrategy,
    _phantom: std::marker::PhantomData<F>,
}

impl<F: IntegrateFloat> ReverseAD<F> {
    /// Create a new reverse AD engine
    pub fn new(nvars: usize) -> Self {
        ReverseAD {
            nvars,
            checkpoint_strategy: CheckpointStrategy::None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Set the checkpointing strategy
    pub fn with_checkpoint_strategy(mut self, strategy: CheckpointStrategy) -> Self {
        self.checkpoint_strategy = strategy;
        self
    }

    /// Compute the gradient using reverse mode AD
    pub fn gradient<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array1<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> usize,
    {
        if x.len() != self.nvars {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables, got {}",
                self.nvars,
                x.len()
            )));
        }

        let mut tape = Tape::new();
        let mut var_indices = Vec::new();

        // Add variables to the tape
        for (i, &val) in x.iter().enumerate() {
            let idx = tape.variable(i, val);
            var_indices.push(idx);
        }

        // Record the function on the tape
        let output_idx = f(&mut tape, &var_indices);

        // Backward pass
        Ok(tape.backward(output_idx, self.nvars))
    }

    /// Compute the Jacobian using reverse mode AD (one backward pass per output)
    pub fn jacobian<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
    {
        if x.len() != self.nvars {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables, got {}",
                self.nvars,
                x.len()
            )));
        }

        let mut tape = Tape::new();
        let mut var_indices = Vec::new();

        // Add variables to the tape
        for (i, &val) in x.iter().enumerate() {
            let idx = tape.variable(i, val);
            var_indices.push(idx);
        }

        // Record the function on the tape
        let output_indices = f(&mut tape, &var_indices);
        let m = output_indices.len();

        let mut jacobian = Array2::zeros((m, self.nvars));

        // One backward pass per output row; backward re-zeroes all gradients
        // at the start of each sweep, so the tape can be reused
        for (i, &output_idx) in output_indices.iter().enumerate() {
            let grad = tape.backward(output_idx, self.nvars);
            jacobian.row_mut(i).assign(&grad);
        }

        Ok(jacobian)
    }

    /// Compute the Hessian (second derivatives) by finite differences of
    /// reverse-mode gradients
    pub fn hessian<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> usize + Clone,
    {
        if x.len() != self.nvars {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables, got {}",
                self.nvars,
                x.len()
            )));
        }

        let mut hessian = Array2::zeros((self.nvars, self.nvars));
        let eps = F::from(1e-8).unwrap();

        // The base gradient is the same for every column, so compute it once
        let grad_base = self.gradient(f.clone(), x)?;

        // Hessian column j = (grad(x + eps*e_j) - grad(x)) / eps
        for j in 0..self.nvars {
            // Perturb x[j]
            let mut x_plus = x.to_owned();
            x_plus[j] += eps;

            let grad_plus = self.gradient(f.clone(), x_plus.view())?;

            for i in 0..self.nvars {
                hessian[[i, j]] = (grad_plus[i] - grad_base[i]) / eps;
            }
        }

        // Make the Hessian symmetric (average upper and lower triangular parts)
        for i in 0..self.nvars {
            for j in (i + 1)..self.nvars {
                let avg = (hessian[[i, j]] + hessian[[j, i]]) / F::from(2.0).unwrap();
                hessian[[i, j]] = avg;
                hessian[[j, i]] = avg;
            }
        }

        Ok(hessian)
    }

    /// Compute gradients for multiple inputs in a batch
    pub fn batch_gradient<Func>(
        &mut self,
        f: Func,
        x_batch: &[Array1<F>],
    ) -> IntegrateResult<Vec<Array1<F>>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> usize + Clone,
    {
        let mut gradients = Vec::with_capacity(x_batch.len());

        for x in x_batch {
            gradients.push(self.gradient(f.clone(), x.view())?);
        }

        Ok(gradients)
    }

    /// Approximate a Jacobian-vector product without forming the full Jacobian
    ///
    /// The product is estimated with a forward finite difference,
    /// J·v ≈ (f(x + eps·v) - f(x)) / eps, which costs two function evaluations
    pub fn jvp<Func>(
        &mut self,
        f: Func,
        x: ArrayView1<F>,
        v: ArrayView1<F>,
    ) -> IntegrateResult<Array1<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
    {
        if x.len() != self.nvars || v.len() != self.nvars {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables for both x and v",
                self.nvars
            )));
        }

        let eps = F::from(1e-8).unwrap();
        let x_perturbed = &x + &(v.to_owned() * eps);

        // Evaluate f at x on its own tape; this tape must stay alive so the
        // base values remain addressable after the second evaluation
        let mut tape_base = Tape::new();
        let mut var_indices = Vec::new();
        for (i, &val) in x.iter().enumerate() {
            var_indices.push(tape_base.variable(i, val));
        }
        let output_base = f(&mut tape_base, &var_indices);

        // Evaluate f at x + eps*v on a second tape
        let mut tape_perturbed = Tape::new();
        let mut var_indices_perturbed = Vec::new();
        for (i, &val) in x_perturbed.iter().enumerate() {
            var_indices_perturbed.push(tape_perturbed.variable(i, val));
        }
        let output_perturbed = f(&mut tape_perturbed, &var_indices_perturbed);

        // JVP component i = (f_i(x + eps*v) - f_i(x)) / eps
        let mut jvp = Array1::zeros(output_base.len());
        for (i, (&idx_base, &idx_pert)) in
            output_base.iter().zip(output_perturbed.iter()).enumerate()
        {
            jvp[i] = (tape_perturbed.value(idx_pert) - tape_base.value(idx_base)) / eps;
        }

        Ok(jvp)
    }

    /// Compute a vector-Jacobian product v^T·J (useful for backpropagation)
    pub fn vjp<Func>(
        &mut self,
        f: Func,
        x: ArrayView1<F>,
        v: ArrayView1<F>,
    ) -> IntegrateResult<Array1<F>>
    where
        Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
    {
        if x.len() != self.nvars {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Expected {} variables",
                self.nvars
            )));
        }

        let mut tape = Tape::new();
        let mut var_indices = Vec::new();

        // Add variables to the tape
        for (i, &val) in x.iter().enumerate() {
            let idx = tape.variable(i, val);
            var_indices.push(idx);
        }

        // Record the function on the tape
        let output_indices = f(&mut tape, &var_indices);

        if v.len() != output_indices.len() {
            return Err(IntegrateError::DimensionMismatch(format!(
                "Vector v length {} doesn't match output dimension {}",
                v.len(),
                output_indices.len()
            )));
        }

        // Record the weighted sum v^T·f(x); its gradient is exactly v^T·J
        let mut weighted_sum = tape.constant(F::zero());
        for (i, &output_idx) in output_indices.iter().enumerate() {
            let v_i = tape.constant(v[i]);
            let term = tape.mul(v_i, output_idx);
            weighted_sum = tape.add(weighted_sum, term);
        }

        // A single backward pass from the weighted sum yields v^T·J
        Ok(tape.backward(weighted_sum, self.nvars))
    }
}

/// Compute a gradient using reverse mode AD (convenience function)
#[allow(dead_code)]
pub fn reverse_gradient<F, Func>(f: Func, x: ArrayView1<F>) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
    Func: Fn(&mut Tape<F>, &[usize]) -> usize,
{
    let mut ad = ReverseAD::new(x.len());
    ad.gradient(f, x)
}

/// Compute a Jacobian using reverse mode AD (convenience function)
#[allow(dead_code)]
pub fn reverse_jacobian<F, Func>(f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
where
    F: IntegrateFloat,
    Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
{
    let mut ad = ReverseAD::new(x.len());
    ad.jacobian(f, x)
}

#[cfg(test)]
mod tests {
    use super::*;
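
    // Drive the Tape API directly and verify the product/chain rule on
    // f(x) = x * sin(x), whose derivative is sin(x) + x*cos(x)
    #[test]
    fn test_tape_product_rule() {
        let mut tape = Tape::new();
        let x_idx = tape.variable(0, 2.0_f64);
        let s = tape.sin(x_idx);
        let y = tape.mul(x_idx, s);

        // Forward value: f(2) = 2 * sin(2)
        assert!((tape.value(y) - 2.0 * 2.0_f64.sin()).abs() < 1e-12);

        // Backward pass: f'(2) = sin(2) + 2 * cos(2)
        let grad = tape.backward(y, 1);
        let expected = 2.0_f64.sin() + 2.0 * 2.0_f64.cos();
        assert!((grad[0] - expected).abs() < 1e-12);
    }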

    #[test]
    fn test_reverse_gradient() {
        // Test gradient of f(x,y) = x^2 + y^2
        let f = |tape: &mut Tape<f64>, vars: &[usize]| {
            let x_sq = tape.mul(vars[0], vars[0]);
            let y_sq = tape.mul(vars[1], vars[1]);
            tape.add(x_sq, y_sq)
        };

        let x = Array1::from_vec(vec![3.0, 4.0]);
        let grad = reverse_gradient(f, x.view()).unwrap();

        // Gradient should be [2x, 2y] = [6, 8]
        assert!((grad[0] - 6.0).abs() < 1e-10);
        assert!((grad[1] - 8.0).abs() < 1e-10);
    }
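
    // The Hessian of f(x,y) = x^2 + y^2 is 2*I. Since `hessian` differentiates
    // reverse-mode gradients by finite differences, compare with a loose
    // tolerance
    #[test]
    fn test_reverse_hessian() {
        let f = |tape: &mut Tape<f64>, vars: &[usize]| {
            let x_sq = tape.mul(vars[0], vars[0]);
            let y_sq = tape.mul(vars[1], vars[1]);
            tape.add(x_sq, y_sq)
        };

        let x = Array1::from_vec(vec![1.0, -2.0]);
        let mut ad = ReverseAD::new(2);
        let hess = ad.hessian(f, x.view()).unwrap();

        assert!((hess[[0, 0]] - 2.0).abs() < 1e-5);
        assert!((hess[[1, 1]] - 2.0).abs() < 1e-5);
        assert!(hess[[0, 1]].abs() < 1e-5);
        assert!(hess[[1, 0]].abs() < 1e-5);
    }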

    #[test]
    fn test_reverse_jacobian() {
        // Test Jacobian of f(x,y) = [x^2, x*y, y^2]
        let f = |tape: &mut Tape<f64>, vars: &[usize]| {
            let x_sq = tape.mul(vars[0], vars[0]);
            let xy = tape.mul(vars[0], vars[1]);
            let y_sq = tape.mul(vars[1], vars[1]);
            vec![x_sq, xy, y_sq]
        };

        let x = Array1::from_vec(vec![2.0, 3.0]);
        let jac = reverse_jacobian(f, x.view()).unwrap();

        // Jacobian should be:
        // [[2x, 0 ],
        //  [y,  x ],
        //  [0,  2y]]
        assert!((jac[[0, 0]] - 4.0).abs() < 1e-10); // 2*2
        assert!((jac[[0, 1]] - 0.0).abs() < 1e-10);
        assert!((jac[[1, 0]] - 3.0).abs() < 1e-10); // y
        assert!((jac[[1, 1]] - 2.0).abs() < 1e-10); // x
        assert!((jac[[2, 0]] - 0.0).abs() < 1e-10);
        assert!((jac[[2, 1]] - 6.0).abs() < 1e-10); // 2*3
    }
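
    // For f(x,y) = [x^2, x*y] at (2, 3) the Jacobian is [[4, 0], [3, 2]].
    // `jvp` is a finite-difference estimate (loose tolerance); `vjp` runs a
    // single exact backward pass over the tape
    #[test]
    fn test_jvp_and_vjp() {
        let f = |tape: &mut Tape<f64>, vars: &[usize]| {
            let x_sq = tape.mul(vars[0], vars[0]);
            let xy = tape.mul(vars[0], vars[1]);
            vec![x_sq, xy]
        };

        let x = Array1::from_vec(vec![2.0, 3.0]);
        let v = Array1::from_vec(vec![1.0, 1.0]);
        let mut ad = ReverseAD::new(2);

        // J*v = [4, 5]
        let jv = ad.jvp(f, x.view(), v.view()).unwrap();
        assert!((jv[0] - 4.0).abs() < 1e-5);
        assert!((jv[1] - 5.0).abs() < 1e-5);

        // v^T*J = [7, 2]
        let vj = ad.vjp(f, x.view(), v.view()).unwrap();
        assert!((vj[0] - 7.0).abs() < 1e-10);
        assert!((vj[1] - 2.0).abs() < 1e-10);
    }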
}