use crate::automatic_differentiation::tape::{
    BinaryOpType, ComputationTape, TapeNode, UnaryOpType, Variable,
};
use crate::error::OptimizeError;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};

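/// Options for tape-based reverse-mode automatic differentiation.
///
/// Note that [`reverse_gradient_with_tape`] currently accepts these options
/// but does not yet act on them.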
#[derive(Debug, Clone)]
pub struct ReverseADOptions {
    /// Whether to compute the gradient.
    pub compute_gradient: bool,
    /// Whether to compute the Hessian.
    pub compute_hessian: bool,
    /// Maximum number of nodes the tape may record.
    pub max_tape_size: usize,
    /// Whether to optimize the recorded tape.
    pub optimize_tape: bool,
}

impl Default for ReverseADOptions {
    fn default() -> Self {
        Self {
            compute_gradient: true,
            compute_hessian: false,
            max_tape_size: 1_000_000,
            optimize_tape: true,
        }
    }
}

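/// A variable in reverse-mode AD: a tape index together with its current
/// value and accumulated adjoint (gradient). Constants carry the sentinel
/// index `usize::MAX` and never receive gradients.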
#[derive(Debug, Clone)]
pub struct ReverseVariable {
    pub index: usize,
    pub value: f64,
    pub grad: f64,
}

impl ReverseVariable {
    pub fn new(index: usize, value: f64) -> Self {
        Self {
            index,
            value,
            grad: 0.0,
        }
    }

    pub fn constant(value: f64) -> Self {
        Self {
            index: usize::MAX,
            value,
            grad: 0.0,
        }
    }

    pub fn is_constant(&self) -> bool {
        self.index == usize::MAX
    }

    pub fn value(&self) -> f64 {
        self.value
    }

    pub fn grad(&self) -> f64 {
        self.grad
    }

    pub fn set_grad(&mut self, grad: f64) {
        self.grad = grad;
    }

    pub fn add_grad(&mut self, grad: f64) {
        self.grad += grad;
    }

    pub fn zero_grad(&mut self) {
        self.grad = 0.0;
    }

    pub fn from_scalar(value: f64) -> Self {
        Self::constant(value)
    }

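    // NOTE: the math methods below are value-only conveniences. They keep the
    // operand's tape index but record no tape nodes, so they contribute
    // nothing to `ComputationGraph::backward`. Use the free functions that
    // take a `&mut ComputationGraph` (`powi`, `exp`, `ln`, ...) when
    // derivatives are needed.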
    pub fn powi(&self, n: i32) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.powi(n))
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.powi(n),
                grad: 0.0,
            }
        }
    }

    pub fn exp(&self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.exp())
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.exp(),
                grad: 0.0,
            }
        }
    }

    pub fn ln(&self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.ln())
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.ln(),
                grad: 0.0,
            }
        }
    }

    pub fn sin(&self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.sin())
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.sin(),
                grad: 0.0,
            }
        }
    }

    pub fn cos(&self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.cos())
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.cos(),
                grad: 0.0,
            }
        }
    }

    pub fn tan(&self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.tan())
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.tan(),
                grad: 0.0,
            }
        }
    }

    pub fn sqrt(&self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.sqrt())
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.sqrt(),
                grad: 0.0,
            }
        }
    }

    pub fn abs(&self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(self.value.abs())
        } else {
            ReverseVariable {
                index: self.index,
                value: self.value.abs(),
                grad: 0.0,
            }
        }
    }
}

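/// A tape-based computation graph for reverse-mode automatic differentiation.
///
/// Inputs are registered with [`variable`](Self::variable), operations are
/// recorded through the free functions (`add`, `mul`, `exp`, ...) together
/// with their precomputed partial derivatives, and
/// [`backward`](Self::backward) accumulates gradients in a single reverse
/// sweep over the tape.
///
/// A minimal usage sketch (mirroring `test_computation_graph` below):
///
/// ```ignore
/// let mut graph = ComputationGraph::new();
/// let x = graph.variable(2.0);
/// let y = graph.variable(3.0);
/// let xy = mul(&mut graph, &x, &y);
/// let z = add(&mut graph, &xy, &x); // z = x*y + x = 8
/// graph.backward(&z).unwrap();
/// let dz_dx = graph.get_gradient(&x); // y + 1 = 4
/// let dz_dy = graph.get_gradient(&y); // x = 2
/// ```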
pub struct ComputationGraph {
    tape: ComputationTape,
    var_counter: usize,
    values: Vec<f64>,
    gradients: Vec<f64>,
}

impl Default for ComputationGraph {
    fn default() -> Self {
        Self::new()
    }
}

impl ComputationGraph {
    pub fn new() -> Self {
        Self {
            tape: ComputationTape::new(),
            var_counter: 0,
            values: Vec::new(),
            gradients: Vec::new(),
        }
    }

    /// Registers a new input variable with the given value.
    pub fn variable(&mut self, value: f64) -> ReverseVariable {
        let index = self.var_counter;
        self.var_counter += 1;

        self.values.push(value);
        self.gradients.push(0.0);

        self.tape.add_input(Variable::new(index, value));

        ReverseVariable::new(index, value)
    }

    /// Records a binary operation with precomputed partial derivatives.
    fn add_binary_op(
        &mut self,
        op_type: BinaryOpType,
        left: &ReverseVariable,
        right: &ReverseVariable,
        result_value: f64,
        left_grad: f64,
        right_grad: f64,
    ) -> ReverseVariable {
        let result_index = self.var_counter;
        self.var_counter += 1;

        self.values.push(result_value);
        self.gradients.push(0.0);

        let node = TapeNode::BinaryOp {
            op_type,
            left: left.index,
            right: right.index,
            result: result_index,
            left_partial: left_grad,
            right_partial: right_grad,
        };

        self.tape.add_node(node);

        ReverseVariable::new(result_index, result_value)
    }

    /// Records a unary operation with a precomputed partial derivative.
    fn add_unary_op(
        &mut self,
        op_type: UnaryOpType,
        input: &ReverseVariable,
        result_value: f64,
        input_grad: f64,
    ) -> ReverseVariable {
        let result_index = self.var_counter;
        self.var_counter += 1;

        self.values.push(result_value);
        self.gradients.push(0.0);

        let node = TapeNode::UnaryOp {
            op_type,
            input: input.index,
            result: result_index,
            partial: input_grad,
        };

        self.tape.add_node(node);

        ReverseVariable::new(result_index, result_value)
    }

    /// Runs the backward pass, seeding the output's adjoint with 1.0.
    ///
    /// `output_var` must have been produced by this graph; constants are
    /// ignored.
    pub fn backward(&mut self, output_var: &ReverseVariable) -> Result<(), OptimizeError> {
        if !output_var.is_constant() {
            self.gradients[output_var.index] = 1.0;
        }

        // Propagate adjoints through the tape; tape errors are ignored here.
        let _ = self.tape.backward(&mut self.gradients);

        Ok(())
    }

    /// Returns the accumulated gradient for `var` (0.0 for constants).
    pub fn get_gradient(&self, var: &ReverseVariable) -> f64 {
        if var.is_constant() {
            0.0
        } else {
            self.gradients[var.index]
        }
    }

    /// Resets all accumulated gradients to zero.
    pub fn zero_gradients(&mut self) {
        for grad in &mut self.gradients {
            *grad = 0.0;
        }
    }
}

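// The operator overloads below are likewise value-only: they combine values
// (folding constants when both operands are constant) but record no tape
// nodes, so the results cannot be differentiated. The synthesized index
// (`max_index + 1`) is a placeholder that is not registered with any graph;
// use the graph-based free functions when gradients are needed.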
impl std::ops::Add for ReverseVariable {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        if self.is_constant() && other.is_constant() {
            ReverseVariable::constant(self.value + other.value)
        } else {
            let result_value = self.value + other.value;
            let max_index = self.index.max(other.index);
            ReverseVariable {
                index: if max_index == usize::MAX {
                    usize::MAX
                } else {
                    max_index + 1
                },
                value: result_value,
                grad: 0.0,
            }
        }
    }
}

impl std::ops::Sub for ReverseVariable {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        if self.is_constant() && other.is_constant() {
            ReverseVariable::constant(self.value - other.value)
        } else {
            let result_value = self.value - other.value;
            let max_index = self.index.max(other.index);
            ReverseVariable {
                index: if max_index == usize::MAX {
                    usize::MAX
                } else {
                    max_index + 1
                },
                value: result_value,
                grad: 0.0,
            }
        }
    }
}

impl std::ops::Mul for ReverseVariable {
    type Output = Self;

    fn mul(self, other: Self) -> Self {
        if self.is_constant() && other.is_constant() {
            ReverseVariable::constant(self.value * other.value)
        } else {
            let result_value = self.value * other.value;
            let max_index = self.index.max(other.index);
            ReverseVariable {
                index: if max_index == usize::MAX {
                    usize::MAX
                } else {
                    max_index + 1
                },
                value: result_value,
                grad: 0.0,
            }
        }
    }
}

impl std::ops::Div for ReverseVariable {
    type Output = Self;

    fn div(self, other: Self) -> Self {
        if self.is_constant() && other.is_constant() {
            ReverseVariable::constant(self.value / other.value)
        } else {
            let result_value = self.value / other.value;
            let max_index = self.index.max(other.index);
            ReverseVariable {
                index: if max_index == usize::MAX {
                    usize::MAX
                } else {
                    max_index + 1
                },
                value: result_value,
                grad: 0.0,
            }
        }
    }
}

impl std::ops::Neg for ReverseVariable {
    type Output = Self;

    fn neg(self) -> Self {
        if self.is_constant() {
            ReverseVariable::constant(-self.value)
        } else {
            ReverseVariable {
                index: self.index,
                value: -self.value,
                grad: 0.0,
            }
        }
    }
}

impl std::ops::Add<f64> for ReverseVariable {
    type Output = Self;

    fn add(self, scalar: f64) -> Self {
        ReverseVariable {
            index: self.index,
            value: self.value + scalar,
            grad: self.grad,
        }
    }
}

impl std::ops::Sub<f64> for ReverseVariable {
    type Output = Self;

    fn sub(self, scalar: f64) -> Self {
        ReverseVariable {
            index: self.index,
            value: self.value - scalar,
            grad: self.grad,
        }
    }
}

impl std::ops::Mul<f64> for ReverseVariable {
    type Output = Self;

    fn mul(self, scalar: f64) -> Self {
        ReverseVariable {
            index: self.index,
            value: self.value * scalar,
            grad: self.grad,
        }
    }
}

impl std::ops::Div<f64> for ReverseVariable {
    type Output = Self;

    fn div(self, scalar: f64) -> Self {
        ReverseVariable {
            index: self.index,
            value: self.value / scalar,
            grad: self.grad,
        }
    }
}

impl std::ops::Add<ReverseVariable> for f64 {
    type Output = ReverseVariable;

    fn add(self, var: ReverseVariable) -> ReverseVariable {
        var + self
    }
}

impl std::ops::Sub<ReverseVariable> for f64 {
    type Output = ReverseVariable;

    fn sub(self, var: ReverseVariable) -> ReverseVariable {
        ReverseVariable {
            index: var.index,
            value: self - var.value,
            grad: var.grad,
        }
    }
}

impl std::ops::Mul<ReverseVariable> for f64 {
    type Output = ReverseVariable;

    fn mul(self, var: ReverseVariable) -> ReverseVariable {
        var * self
    }
}

impl std::ops::Div<ReverseVariable> for f64 {
    type Output = ReverseVariable;

    fn div(self, var: ReverseVariable) -> ReverseVariable {
        ReverseVariable {
            index: var.index,
            value: self / var.value,
            grad: var.grad,
        }
    }
}

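/// Records `left + right` on the graph; both partials are `1.0`. Constant
/// operands are folded without touching the tape (as in the other graph
/// functions below).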
#[allow(dead_code)]
pub fn add(
    graph: &mut ComputationGraph,
    left: &ReverseVariable,
    right: &ReverseVariable,
) -> ReverseVariable {
    if left.is_constant() && right.is_constant() {
        return ReverseVariable::constant(left.value + right.value);
    }

    let result_value = left.value + right.value;
    graph.add_binary_op(BinaryOpType::Add, left, right, result_value, 1.0, 1.0)
}

/// Records `left * right` on the graph; by the product rule the partials are
/// `right.value` and `left.value`.
#[allow(dead_code)]
pub fn mul(
    graph: &mut ComputationGraph,
    left: &ReverseVariable,
    right: &ReverseVariable,
) -> ReverseVariable {
    if left.is_constant() && right.is_constant() {
        return ReverseVariable::constant(left.value * right.value);
    }

    let result_value = left.value * right.value;
    graph.add_binary_op(
        BinaryOpType::Mul,
        left,
        right,
        result_value,
        right.value,
        left.value,
    )
}

/// Records `left - right` on the graph; the partials are `1.0` and `-1.0`.
#[allow(dead_code)]
pub fn sub(
    graph: &mut ComputationGraph,
    left: &ReverseVariable,
    right: &ReverseVariable,
) -> ReverseVariable {
    if left.is_constant() && right.is_constant() {
        return ReverseVariable::constant(left.value - right.value);
    }

    let result_value = left.value - right.value;
    graph.add_binary_op(BinaryOpType::Sub, left, right, result_value, 1.0, -1.0)
}

/// Records `left / right` on the graph; by the quotient rule the partials are
/// `1/right` and `-left/right^2`.
#[allow(dead_code)]
pub fn div(
    graph: &mut ComputationGraph,
    left: &ReverseVariable,
    right: &ReverseVariable,
) -> ReverseVariable {
    if left.is_constant() && right.is_constant() {
        return ReverseVariable::constant(left.value / right.value);
    }

    let result_value = left.value / right.value;
    let left_grad = 1.0 / right.value;
    let right_grad = -left.value / (right.value * right.value);

    graph.add_binary_op(
        BinaryOpType::Div,
        left,
        right,
        result_value,
        left_grad,
        right_grad,
    )
}

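/// Records `input^n` on the graph; the partial is `n * input^(n-1)`.
///
/// The node is tagged `UnaryOpType::Square` regardless of `n`; this assumes
/// the tape's backward pass applies the stored partial rather than
/// re-deriving it from the op type.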
#[allow(dead_code)]
pub fn powi(graph: &mut ComputationGraph, input: &ReverseVariable, n: i32) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.powi(n));
    }

    let result_value = input.value.powi(n);
    // d/dx x^n = n * x^(n-1)
    let input_grad = (n as f64) * input.value.powi(n - 1);

    graph.add_unary_op(UnaryOpType::Square, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn exp(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.exp());
    }

    let result_value = input.value.exp();
    // d/dx exp(x) = exp(x)
    let input_grad = result_value;

    graph.add_unary_op(UnaryOpType::Exp, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn ln(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.ln());
    }

    let result_value = input.value.ln();
    let input_grad = 1.0 / input.value;

    graph.add_unary_op(UnaryOpType::Ln, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn sin(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.sin());
    }

    let result_value = input.value.sin();
    let input_grad = input.value.cos();

    graph.add_unary_op(UnaryOpType::Sin, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn cos(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.cos());
    }

    let result_value = input.value.cos();
    let input_grad = -input.value.sin();

    graph.add_unary_op(UnaryOpType::Cos, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn tan(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.tan());
    }

    let result_value = input.value.tan();
    let cos_val = input.value.cos();
    // d/dx tan(x) = 1 / cos^2(x)
    let input_grad = 1.0 / (cos_val * cos_val);

    graph.add_unary_op(UnaryOpType::Tan, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn sqrt(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.sqrt());
    }

    let result_value = input.value.sqrt();
    // d/dx sqrt(x) = 1 / (2 * sqrt(x))
    let input_grad = 0.5 / result_value;

    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn abs(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.abs());
    }

    let result_value = input.value.abs();
    // Subgradient convention: d/dx |x| = sign(x), taking +1 at x = 0.
    let input_grad = if input.value >= 0.0 { 1.0 } else { -1.0 };

    // `UnaryOpType::Sqrt` is reused here as a stand-in tag; the stored
    // partial carries the actual derivative.
    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
}

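/// Records the logistic sigmoid `s(x) = 1 / (1 + exp(-x))` on the graph; the
/// partial is `s * (1 - s)`.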
#[allow(dead_code)]
pub fn sigmoid(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        let exp_val = (-input.value).exp();
        return ReverseVariable::constant(1.0 / (1.0 + exp_val));
    }

    let exp_neg_x = (-input.value).exp();
    let result_value = 1.0 / (1.0 + exp_neg_x);
    let input_grad = result_value * (1.0 - result_value);

    // `UnaryOpType::Exp` is reused as a stand-in tag; the stored partial
    // carries the actual derivative.
    graph.add_unary_op(UnaryOpType::Exp, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn tanh(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.tanh());
    }

    let result_value = input.value.tanh();
    // tanh'(x) = 1 - tanh^2(x)
    let input_grad = 1.0 - result_value * result_value;

    // `UnaryOpType::Tan` is reused as a stand-in tag; the stored partial
    // carries the actual derivative.
    graph.add_unary_op(UnaryOpType::Tan, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn relu(graph: &mut ComputationGraph, input: &ReverseVariable) -> ReverseVariable {
    if input.is_constant() {
        return ReverseVariable::constant(input.value.max(0.0));
    }

    let result_value = input.value.max(0.0);
    // Subgradient convention: relu'(x) = 1 for x > 0, else 0.
    let input_grad = if input.value > 0.0 { 1.0 } else { 0.0 };

    // `UnaryOpType::Sqrt` is reused as a stand-in tag; the stored partial
    // carries the actual derivative.
    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
}

#[allow(dead_code)]
pub fn leaky_relu(
    graph: &mut ComputationGraph,
    input: &ReverseVariable,
    alpha: f64,
) -> ReverseVariable {
    if input.is_constant() {
        let result = if input.value > 0.0 {
            input.value
        } else {
            alpha * input.value
        };
        return ReverseVariable::constant(result);
    }

    let result_value = if input.value > 0.0 {
        input.value
    } else {
        alpha * input.value
    };
    let input_grad = if input.value > 0.0 { 1.0 } else { alpha };

    // `UnaryOpType::Sqrt` is reused as a stand-in tag; the stored partial
    // carries the actual derivative.
    graph.add_unary_op(UnaryOpType::Sqrt, input, result_value, input_grad)
}

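/// Approximates the gradient of `func` at `x` by central finite differences
/// with step `h = 1e-8` (two function evaluations per coordinate). Despite
/// the module's name, this fallback does not build a tape; see
/// [`reverse_gradient_ad`] for the tape-based version.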
#[allow(dead_code)]
pub fn reverse_gradient<F>(func: F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    let n = x.len();
    let mut gradient = Array1::zeros(n);
    let h = 1e-8;

    for i in 0..n {
        let mut x_plus = x.to_owned();
        x_plus[i] += h;
        let f_plus = func(&x_plus.view());

        let mut x_minus = x.to_owned();
        x_minus[i] -= h;
        let f_minus = func(&x_minus.view());

        gradient[i] = (f_plus - f_minus) / (2.0 * h);
    }

    Ok(gradient)
}

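/// Computes the gradient of a graph-building closure with a single backward
/// pass over the recorded tape.
///
/// A minimal sketch (same pattern as `test_reverse_gradient_ad` below):
///
/// ```ignore
/// // f(x0, x1) = x0 * x1  =>  grad = [x1, x0]
/// let f = |g: &mut ComputationGraph, v: &[ReverseVariable]| mul(g, &v[0], &v[1]);
/// let x = Array1::from_vec(vec![2.0, 3.0]);
/// let grad = reverse_gradient_ad(f, &x.view()).unwrap(); // [3.0, 2.0]
/// ```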
#[allow(dead_code)]
pub fn reverse_gradient_ad<F>(func: F, x: &ArrayView1<f64>) -> Result<Array1<f64>, OptimizeError>
where
    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> ReverseVariable,
{
    let mut graph = ComputationGraph::new();

    let input_vars: Vec<ReverseVariable> = x.iter().map(|&xi| graph.variable(xi)).collect();

    let output = func(&mut graph, &input_vars);

    graph.backward(&output)?;

    let mut gradient = Array1::zeros(x.len());
    for (i, var) in input_vars.iter().enumerate() {
        gradient[i] = graph.get_gradient(var);
    }

    Ok(gradient)
}

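/// Approximates the Hessian of `func` at `x` with second-order central
/// differences (step `h = 1e-5`):
///
/// - diagonal: `(f(x + h*e_i) - 2*f(x) + f(x - h*e_i)) / h^2`
/// - off-diagonal: `(f_pp - f_pm - f_mp + f_mm) / (4*h^2)`, where e.g.
///   `f_pp = f(x + h*e_i + h*e_j)`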
#[allow(dead_code)]
pub fn reverse_hessian<F>(func: F, x: &ArrayView1<f64>) -> Result<Array2<f64>, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> f64,
{
    let n = x.len();
    let mut hessian = Array2::zeros((n, n));
    let h = 1e-5;

    for i in 0..n {
        for j in 0..n {
            if i == j {
                let mut x_plus = x.to_owned();
                x_plus[i] += h;
                let f_plus = func(&x_plus.view());

                let f_center = func(x);

                let mut x_minus = x.to_owned();
                x_minus[i] -= h;
                let f_minus = func(&x_minus.view());

                hessian[[i, j]] = (f_plus - 2.0 * f_center + f_minus) / (h * h);
            } else {
                #[allow(clippy::similar_names)]
                let mut x_pp = x.to_owned();
                x_pp[i] += h;
                x_pp[j] += h;
                #[allow(clippy::similar_names)]
                let f_pp = func(&x_pp.view());

                #[allow(clippy::similar_names)]
                let mut x_pm = x.to_owned();
                x_pm[i] += h;
                x_pm[j] -= h;
                #[allow(clippy::similar_names)]
                let f_pm = func(&x_pm.view());

                #[allow(clippy::similar_names)]
                let mut x_mp = x.to_owned();
                x_mp[i] -= h;
                x_mp[j] += h;
                #[allow(clippy::similar_names)]
                let f_mp = func(&x_mp.view());

                let mut x_mm = x.to_owned();
                x_mm[i] -= h;
                x_mm[j] -= h;
                let f_mm = func(&x_mm.view());

                hessian[[i, j]] = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h * h);
            }
        }
    }

    Ok(hessian)
}

/// Approximates the Hessian by finite-differencing the tape-based gradient:
/// row `i` is the numerical gradient of `df/dx_i`.
#[allow(dead_code)]
pub fn reverse_hessian_ad<F>(func: F, x: &ArrayView1<f64>) -> Result<Array2<f64>, OptimizeError>
where
    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> ReverseVariable,
{
    let n = x.len();
    let mut hessian = Array2::zeros((n, n));

    for i in 0..n {
        // Scalar function returning the i-th gradient component.
        let gradient_i_func = |x_val: &ArrayView1<f64>| -> f64 {
            let grad = reverse_gradient_ad(&func, x_val).unwrap();
            grad[i]
        };

        let hessian_row = reverse_gradient(gradient_i_func, x)?;
        for j in 0..n {
            hessian[[i, j]] = hessian_row[j];
        }
    }

    Ok(hessian)
}

/// Tape-based gradient computation; `_options` is currently accepted but not
/// used.
#[allow(dead_code)]
pub fn reverse_gradient_with_tape<F>(
    func: F,
    x: &ArrayView1<f64>,
    _options: &ReverseADOptions,
) -> Result<Array1<f64>, OptimizeError>
where
    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> ReverseVariable,
{
    let mut graph = ComputationGraph::new();

    let input_vars: Vec<ReverseVariable> = x.iter().map(|&xi| graph.variable(xi)).collect();

    let output = func(&mut graph, &input_vars);

    graph.backward(&output)?;

    let mut gradient = Array1::zeros(x.len());
    for (i, var) in input_vars.iter().enumerate() {
        gradient[i] = graph.get_gradient(var);
    }

    Ok(gradient)
}

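/// Heuristic for choosing reverse mode: it needs one backward pass per
/// output, so it pays off when the number of outputs is small, either
/// absolutely or relative to the number of inputs.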
#[allow(dead_code)]
pub fn is_reverse_mode_efficient(input_dim: usize, output_dim: usize) -> bool {
    output_dim <= 10 || (output_dim <= input_dim && output_dim <= 20)
}

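/// Computes the vector-Jacobian product `v^T J` of `func` at `x`, taking a
/// finite-difference gradient of each output component whose coefficient in
/// `v` is nonzero.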
#[allow(clippy::many_single_char_names)]
#[allow(dead_code)]
pub fn reverse_vjp<F>(
    func: F,
    x: &ArrayView1<f64>,
    v: &ArrayView1<f64>,
) -> Result<Array1<f64>, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
    let n = x.len();
    let m = v.len();

    let mut result = Array1::zeros(n);

    for i in 0..m {
        if v[i] != 0.0 {
            let component_func = |x_val: &ArrayView1<f64>| -> f64 {
                let f_val = func(x_val);
                f_val[i]
            };

            let grad_i = reverse_gradient(component_func, x)?;

            for j in 0..n {
                result[j] += v[i] * grad_i[j];
            }
        }
    }

    Ok(result)
}

/// Tape-based vector-Jacobian product: rebuilds the graph once per output
/// component with a nonzero coefficient in `v` and runs a backward pass.
#[allow(clippy::many_single_char_names)]
#[allow(dead_code)]
pub fn reverse_vjp_ad<F>(
    func: F,
    x: &ArrayView1<f64>,
    v: &ArrayView1<f64>,
) -> Result<Array1<f64>, OptimizeError>
where
    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> Vec<ReverseVariable>,
{
    let n = x.len();
    let m = v.len();
    let mut result = Array1::zeros(n);

    for i in 0..m {
        if v[i] != 0.0 {
            let mut graph = ComputationGraph::new();

            let input_vars: Vec<ReverseVariable> =
                x.iter().map(|&xi| graph.variable(xi)).collect();

            let outputs = func(&mut graph, &input_vars);

            if i < outputs.len() {
                graph.backward(&outputs[i])?;

                for (j, var) in input_vars.iter().enumerate() {
                    result[j] += v[i] * graph.get_gradient(var);
                }
            }
        }
    }

    Ok(result)
}

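/// Gauss-Newton approximation of the Hessian for a residual function:
/// `H ~= J^T * J`, accumulated from the finite-difference gradient of each
/// residual component (the second-order residual terms are dropped).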
#[allow(dead_code)]
pub fn reverse_gauss_newton_hessian<F>(
    func: F,
    x: &ArrayView1<f64>,
) -> Result<Array2<f64>, OptimizeError>
where
    F: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
    let n = x.len();
    let f_val = func(x);
    let m = f_val.len();

    let mut hessian = Array2::zeros((n, n));

    for i in 0..m {
        let residual_i = |x_val: &ArrayView1<f64>| -> f64 {
            let f_val = func(x_val);
            f_val[i]
        };

        let grad_i = reverse_gradient(residual_i, x)?;

        for j in 0..n {
            for k in 0..n {
                hessian[[j, k]] += grad_i[j] * grad_i[k];
            }
        }
    }

    Ok(hessian)
}

/// Tape-based Gauss-Newton Hessian approximation (`J^T * J`), using one
/// backward pass per residual component.
#[allow(dead_code)]
pub fn reverse_gauss_newton_hessian_ad<F>(
    func: F,
    x: &ArrayView1<f64>,
) -> Result<Array2<f64>, OptimizeError>
where
    F: Fn(&mut ComputationGraph, &[ReverseVariable]) -> Vec<ReverseVariable>,
{
    let n = x.len();
    let mut hessian = Array2::zeros((n, n));

    // Probe evaluation to determine the number of residual components.
    let mut graph_temp = ComputationGraph::new();
    let input_vars_temp: Vec<ReverseVariable> =
        x.iter().map(|&xi| graph_temp.variable(xi)).collect();
    let outputs_temp = func(&mut graph_temp, &input_vars_temp);
    let m = outputs_temp.len();

    for i in 0..m {
        let mut graph = ComputationGraph::new();

        let input_vars: Vec<ReverseVariable> = x.iter().map(|&xi| graph.variable(xi)).collect();

        let outputs = func(&mut graph, &input_vars);

        if i < outputs.len() {
            graph.backward(&outputs[i])?;

            let mut grad_i = Array1::zeros(n);
            for (j, var) in input_vars.iter().enumerate() {
                grad_i[j] = graph.get_gradient(var);
            }

            // Accumulate the outer product grad_i * grad_i^T.
            for j in 0..n {
                for k in 0..n {
                    hessian[[j, k]] += grad_i[j] * grad_i[k];
                }
            }
        }
    }

    Ok(hessian)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_computation_graph() {
        let mut graph = ComputationGraph::new();

        let x = graph.variable(2.0);
        let y = graph.variable(3.0);

        let xy = mul(&mut graph, &x, &y);
        let z = add(&mut graph, &xy, &x);

        // z = x*y + x = 2*3 + 2 = 8
        assert_abs_diff_eq!(z.value, 8.0, epsilon = 1e-10);

        graph.backward(&z).unwrap();

        // dz/dx = y + 1 = 4, dz/dy = x = 2
        assert_abs_diff_eq!(graph.get_gradient(&x), 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(graph.get_gradient(&y), 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_unary_operations() {
        let mut graph = ComputationGraph::new();

        let x = graph.variable(1.0);
        let exp_x = exp(&mut graph, &x);

        assert_abs_diff_eq!(exp_x.value, std::f64::consts::E, epsilon = 1e-10);

        graph.backward(&exp_x).unwrap();

        // d/dx exp(x) = exp(1) = e
        assert_abs_diff_eq!(graph.get_gradient(&x), std::f64::consts::E, epsilon = 1e-10);
    }

    #[test]
    fn test_reverse_gradient() {
        // f(x) = x0^2 + x0*x1 + 2*x1^2
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[0] * x[1] + 2.0 * x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let grad = reverse_gradient(func, &x.view()).unwrap();

        // df/dx0 = 2*x0 + x1 = 4, df/dx1 = x0 + 4*x1 = 9
        assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-6);
        assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-6);
    }

    #[test]
    fn test_is_reverse_mode_efficient() {
        assert!(is_reverse_mode_efficient(100, 1));
        assert!(is_reverse_mode_efficient(50, 5));

        assert!(!is_reverse_mode_efficient(10, 100));
    }

    #[test]
    fn test_reverse_hessian() {
        // f(x) = x0^2 + x0*x1 + 2*x1^2
        let func = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[0] * x[1] + 2.0 * x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let hess = reverse_hessian(func, &x.view()).unwrap();

        // H = [[2, 1], [1, 4]]
        assert_abs_diff_eq!(hess[[0, 0]], 2.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[0, 1]], 1.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[1, 0]], 1.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[1, 1]], 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_reverse_gradient_ad() {
        // f(x, y) = x^2 + x*y + 2*y^2, built on the graph
        let func = |graph: &mut ComputationGraph, vars: &[ReverseVariable]| {
            let x = &vars[0];
            let y = &vars[1];

            let x_squared = mul(graph, x, x);
            let xy = mul(graph, x, y);
            let y_squared = mul(graph, y, y);
            let two_y_squared = mul(graph, &ReverseVariable::constant(2.0), &y_squared);

            let temp = add(graph, &x_squared, &xy);
            add(graph, &temp, &two_y_squared)
        };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let grad = reverse_gradient_ad(func, &x.view()).unwrap();

        // df/dx = 2*x + y = 4, df/dy = x + 4*y = 9
        assert_abs_diff_eq!(grad[0], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad[1], 9.0, epsilon = 1e-10);
    }

    #[test]
    fn test_reverse_vjp() {
        // f(x) = [x0^2, x0*x1, x1^2]
        let func = |x: &ArrayView1<f64>| -> Array1<f64> {
            Array1::from_vec(vec![x[0] * x[0], x[0] * x[1], x[1] * x[1]])
        };

        let x = Array1::from_vec(vec![2.0, 3.0]);
        let v = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let vjp = reverse_vjp(func, &x.view(), &v.view()).unwrap();

        // J^T v = [2*x0 + x1, x0 + 2*x1] = [7, 8]
        assert_abs_diff_eq!(vjp[0], 7.0, epsilon = 1e-6);
        assert_abs_diff_eq!(vjp[1], 8.0, epsilon = 1e-6);
    }

    #[test]
    fn test_reverse_gauss_newton_hessian() {
        // Residuals r(x) = [x0 - 1, x1 - 2] have Jacobian J = I, so J^T J = I.
        let residual_func =
            |x: &ArrayView1<f64>| -> Array1<f64> { Array1::from_vec(vec![x[0] - 1.0, x[1] - 2.0]) };

        let x = Array1::from_vec(vec![0.0, 0.0]);
        let gn_hess = reverse_gauss_newton_hessian(residual_func, &x.view()).unwrap();

        assert_abs_diff_eq!(gn_hess[[0, 0]], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gn_hess[[0, 1]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gn_hess[[1, 0]], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gn_hess[[1, 1]], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_power_operation() {
        let mut graph = ComputationGraph::new();

        let x = graph.variable(2.0);
        let x_cubed = powi(&mut graph, &x, 3);

        assert_abs_diff_eq!(x_cubed.value, 8.0, epsilon = 1e-10);

        graph.backward(&x_cubed).unwrap();

        // d/dx x^3 = 3*x^2 = 12 at x = 2
        assert_abs_diff_eq!(graph.get_gradient(&x), 12.0, epsilon = 1e-10);
    }

    #[test]
    fn test_trigonometric_operations() {
        let mut graph = ComputationGraph::new();

        let x = graph.variable(0.0);
        let sin_x = sin(&mut graph, &x);
        let cos_x = cos(&mut graph, &x);

        assert_abs_diff_eq!(sin_x.value, 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cos_x.value, 1.0, epsilon = 1e-10);

        // d/dx sin(x) = cos(0) = 1
        graph.backward(&sin_x).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x), 1.0, epsilon = 1e-10);

        // d/dx cos(x) = -sin(0) = 0
        graph.zero_gradients();
        graph.backward(&cos_x).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_arithmetic_operations_without_graph() {
        let a = ReverseVariable::constant(3.0);
        let b = ReverseVariable::constant(2.0);

        let sum = a.clone() + b.clone();
        assert_abs_diff_eq!(sum.value, 5.0, epsilon = 1e-10);
        assert!(sum.is_constant());

        let diff = a.clone() - b.clone();
        assert_abs_diff_eq!(diff.value, 1.0, epsilon = 1e-10);

        let product = a.clone() * b.clone();
        assert_abs_diff_eq!(product.value, 6.0, epsilon = 1e-10);

        let quotient = a.clone() / b.clone();
        assert_abs_diff_eq!(quotient.value, 1.5, epsilon = 1e-10);

        let neg_a = -a.clone();
        assert_abs_diff_eq!(neg_a.value, -3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_scalar_operations() {
        let var = ReverseVariable::constant(4.0);

        let result = var.clone() + 2.0;
        assert_abs_diff_eq!(result.value, 6.0, epsilon = 1e-10);

        let result = 2.0 + var.clone();
        assert_abs_diff_eq!(result.value, 6.0, epsilon = 1e-10);

        let result = var.clone() * 3.0;
        assert_abs_diff_eq!(result.value, 12.0, epsilon = 1e-10);

        let result = var.clone() / 2.0;
        assert_abs_diff_eq!(result.value, 2.0, epsilon = 1e-10);

        let result = 8.0 / var.clone();
        assert_abs_diff_eq!(result.value, 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mathematical_functions_without_graph() {
        let var = ReverseVariable::constant(4.0);

        let result = var.powi(2);
        assert_abs_diff_eq!(result.value, 16.0, epsilon = 1e-10);

        let result = var.sqrt();
        assert_abs_diff_eq!(result.value, 2.0, epsilon = 1e-10);

        let var_zero = ReverseVariable::constant(0.0);
        let result = var_zero.exp();
        assert_abs_diff_eq!(result.value, 1.0, epsilon = 1e-10);

        let var_e = ReverseVariable::constant(std::f64::consts::E);
        let result = var_e.ln();
        assert_abs_diff_eq!(result.value, 1.0, epsilon = 1e-10);

        let var_zero = ReverseVariable::constant(0.0);
        assert_abs_diff_eq!(var_zero.sin().value, 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_zero.cos().value, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(var_zero.tan().value, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_advanced_operations_with_graph() {
        let mut graph = ComputationGraph::new();

        // sigmoid(0) = 0.5 with gradient 0.25
        let x = graph.variable(0.0);
        let sig = sigmoid(&mut graph, &x);
        assert_abs_diff_eq!(sig.value, 0.5, epsilon = 1e-10);

        graph.backward(&sig).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x), 0.25, epsilon = 1e-10);

        // relu(2) = 2 with gradient 1
        graph.zero_gradients();
        let x_pos = graph.variable(2.0);
        let relu_pos = relu(&mut graph, &x_pos);
        assert_abs_diff_eq!(relu_pos.value, 2.0, epsilon = 1e-10);

        graph.backward(&relu_pos).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x_pos), 1.0, epsilon = 1e-10);

        // relu(-1) = 0 with gradient 0
        let mut graph2 = ComputationGraph::new();
        let x_neg = graph2.variable(-1.0);
        let relu_neg = relu(&mut graph2, &x_neg);
        assert_abs_diff_eq!(relu_neg.value, 0.0, epsilon = 1e-10);

        graph2.backward(&relu_neg).unwrap();
        assert_abs_diff_eq!(graph2.get_gradient(&x_neg), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_leaky_relu() {
        let mut graph = ComputationGraph::new();

        let x_pos = graph.variable(2.0);
        let leaky_pos = leaky_relu(&mut graph, &x_pos, 0.01);
        assert_abs_diff_eq!(leaky_pos.value, 2.0, epsilon = 1e-10);

        graph.backward(&leaky_pos).unwrap();
        assert_abs_diff_eq!(graph.get_gradient(&x_pos), 1.0, epsilon = 1e-10);

        let mut graph2 = ComputationGraph::new();
        let x_neg = graph2.variable(-2.0);
        let leaky_neg = leaky_relu(&mut graph2, &x_neg, 0.01);
        assert_abs_diff_eq!(leaky_neg.value, -0.02, epsilon = 1e-10);

        graph2.backward(&leaky_neg).unwrap();
        assert_abs_diff_eq!(graph2.get_gradient(&x_neg), 0.01, epsilon = 1e-10);
    }

    #[test]
    fn test_complex_expression() {
        let mut graph = ComputationGraph::new();

        // f(x, y) = sigmoid(x^2 + y) * tanh(x - y)
        let x = graph.variable(1.0);
        let y = graph.variable(0.5);

        let x_squared = mul(&mut graph, &x, &x);
        let x_sq_plus_y = add(&mut graph, &x_squared, &y);
        let sig_term = sigmoid(&mut graph, &x_sq_plus_y);

        let x_minus_y = sub(&mut graph, &x, &y);
        let tanh_term = tanh(&mut graph, &x_minus_y);

        let result = mul(&mut graph, &sig_term, &tanh_term);

        assert!(result.value.is_finite());
        assert!(result.value > 0.0);

        graph.backward(&result).unwrap();

        let grad_x = graph.get_gradient(&x);
        let grad_y = graph.get_gradient(&y);

        assert!(grad_x.is_finite());
        assert!(grad_y.is_finite());
        assert!(grad_x != 0.0);
        assert!(grad_y != 0.0);
    }
}