1use crate::ndarray::compat::ArrayStatCompat;
20use ::ndarray::{Array, ArrayD, Dimension, IxDyn};
21
22use std::cell::RefCell;
23use std::collections::{HashMap, HashSet};
24use std::rc::Rc;
25
26use crate::array_protocol::operations::matmul;
27use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
28use crate::error::{CoreError, CoreResult, ErrorContext};
29
30#[derive(Clone)]
32pub struct GradientDict {
33 gradients: HashMap<String, Box<dyn ArrayProtocol>>,
34}
35
36impl std::fmt::Debug for GradientDict {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("GradientDict")
39 .field(
40 "gradients",
41 &format!("{{keys: {:?}}}", self.gradients.keys().collect::<Vec<_>>()),
42 )
43 .finish()
44 }
45}
46
47impl GradientDict {
48 pub fn new() -> Self {
50 Self {
51 gradients: HashMap::new(),
52 }
53 }
54
55 pub fn insert(&mut self, name: String, gradient: Box<dyn ArrayProtocol>) {
57 self.gradients.insert(name, gradient);
58 }
59
60 pub fn get(&self, name: &str) -> Option<&dyn ArrayProtocol> {
62 self.gradients.get(name).map(|b| b.as_ref())
63 }
64
65 pub fn get_mut(&mut self, name: &str) -> Option<&mut Box<dyn ArrayProtocol>> {
67 self.gradients.get_mut(name)
68 }
69
70 pub fn iter(&self) -> impl Iterator<Item = (&String, &Box<dyn ArrayProtocol>)> {
72 self.gradients.iter()
73 }
74
75 pub fn merge(&mut self, other: GradientDict) {
77 for (name, gradient) in other.gradients {
78 self.gradients.insert(name, gradient);
79 }
80 }
81
82 pub fn is_empty(&self) -> bool {
84 self.gradients.is_empty()
85 }
86
87 pub fn len(&self) -> usize {
89 self.gradients.len()
90 }
91
92 pub fn clear(&mut self) {
94 self.gradients.clear();
95 }
96
97 pub fn keys(&self) -> impl Iterator<Item = &String> {
99 self.gradients.keys()
100 }
101
102 pub fn values(&self) -> impl Iterator<Item = &Box<dyn ArrayProtocol>> {
104 self.gradients.values()
105 }
106}
107
108impl Default for GradientDict {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[allow(dead_code)]
116fn boxed_to_rc(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
117 let array_ref = boxed.as_ref();
122
123 if let Some(ndarray_wrapper) = array_ref
126 .as_any()
127 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
128 {
129 let array_clone = ndarray_wrapper.as_array().clone();
130 return Rc::new(NdarrayWrapper::new(array_clone));
131 }
132
133 let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
136 Rc::new(NdarrayWrapper::new(fallback_array))
137}
138
139#[allow(dead_code)]
141fn box_to_rc_array_protocol(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
142 boxed_to_rc(boxed)
143}
144
145#[allow(dead_code)]
147fn add(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
148 crate::array_protocol::operations::add(a, b).map_err(|e| e.into())
149}
150
151#[allow(dead_code)]
152fn multiply(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
153 crate::array_protocol::operations::multiply(a, b).map_err(|e| e.into())
154}
155
156#[allow(dead_code)]
157fn subtract(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
158 crate::array_protocol::operations::subtract(a, b).map_err(|e| e.into())
159}
160
161#[allow(dead_code)]
162fn ones_like(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
163 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
166 let shape = a_array.as_array().shape();
167 let ones = ArrayD::<f64>::ones(IxDyn(shape));
168 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
169 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
170 let shape = a_array.as_array().shape();
171 let ones = ArrayD::<f32>::ones(IxDyn(shape));
172 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
173 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
174 let shape = a_array.as_array().shape();
175 let ones = ArrayD::<i32>::ones(IxDyn(shape));
176 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
177 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
178 let shape = a_array.as_array().shape();
179 let ones = ArrayD::<i64>::ones(IxDyn(shape));
180 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
181 } else {
182 let shape = a.shape().to_vec();
184 let ones = ArrayD::<f64>::ones(IxDyn(&shape));
185 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
186 }
187}
188
189#[allow(dead_code)]
190fn broadcast_to(a: &dyn ArrayProtocol, shape: &[usize]) -> CoreResult<Box<dyn ArrayProtocol>> {
191 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
193 let array = a_array.as_array();
194 if array.len() == 1 {
196 let value = array.iter().next().cloned().unwrap_or(0.0);
197 let broadcasted = ArrayD::<f64>::from_elem(IxDyn(shape), value);
198 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
199 } else if array.shape() == shape {
200 Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
202 } else {
203 let inputshape = array.shape();
205 let _ndim_diff = shape.len().saturating_sub(inputshape.len());
206
207 let mut can_broadcast = true;
209 for i in 0..inputshape.len() {
210 let input_dim = inputshape[inputshape.len() - 1 - i];
211 let target_dim = shape[shape.len() - 1 - i];
212 if input_dim != 1 && input_dim != target_dim {
213 can_broadcast = false;
214 break;
215 }
216 }
217
218 if can_broadcast {
219 if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
221 let broadcasted = broadcasted_view.to_owned();
222 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
223 } else {
224 Err(CoreError::NotImplementedError(ErrorContext::new(
225 "Broadcasting failed for these shapes".to_string(),
226 )))
227 }
228 } else {
229 Err(CoreError::NotImplementedError(ErrorContext::new(
230 "Incompatible shapes for broadcasting".to_string(),
231 )))
232 }
233 }
234 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
235 let array = a_array.as_array();
236 if array.len() == 1 {
237 let value = array.iter().next().cloned().unwrap_or(0.0);
238 let broadcasted = ArrayD::<f32>::from_elem(IxDyn(shape), value);
239 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
240 } else if array.shape() == shape {
241 Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
242 } else if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
243 let broadcasted = broadcasted_view.to_owned();
244 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
245 } else {
246 Err(CoreError::NotImplementedError(ErrorContext::new(
247 "Broadcasting failed for these shapes".to_string(),
248 )))
249 }
250 } else {
251 let ones = ArrayD::<f64>::ones(IxDyn(shape));
253 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
254 }
255}
256
257#[derive(Clone)]
259struct Node {
260 value: Rc<dyn ArrayProtocol>,
262
263 grad: Option<Rc<dyn ArrayProtocol>>,
265
266 op: Option<String>,
268
269 inputs: Vec<GradientTensor>,
271
272 requiresgrad: bool,
274
275 is_leaf: bool,
277}
278
279impl Node {
280 fn leaf(requiresgrad: bool) -> Self {
282 Self {
283 value: Rc::new(NdarrayWrapper::new(
284 crate::ndarray::Array0::<f64>::zeros(()),
285 )) as Rc<dyn ArrayProtocol>,
286 grad: None,
287 op: None,
288 inputs: Vec::new(),
289 requiresgrad,
290 is_leaf: true,
291 }
292 }
293
294 fn new_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
296 let requiresgrad = inputs.iter().any(|x| x.requiresgrad());
297
298 Self {
299 value,
300 grad: None,
301 op: Some(op),
302 inputs,
303 requiresgrad,
304 is_leaf: false,
305 }
306 }
307}
308
309#[derive(Clone)]
311pub struct GradientTensor {
312 node: Rc<RefCell<Node>>,
314}
315
316impl GradientTensor {
317 pub fn new(value: Rc<dyn ArrayProtocol>, requiresgrad: bool) -> Self {
319 let mut node_inner = Node::leaf(requiresgrad);
320 node_inner.value = value;
321 node_inner.grad = None;
322 let node = Rc::new(RefCell::new(node_inner));
323 Self { node }
324 }
325
326 pub fn from_array<T, D>(array: Array<T, D>, requiresgrad: bool) -> Self
328 where
329 T: Clone + Send + Sync + 'static,
330 D: Dimension + Send + Sync + 'static,
331 {
332 let value = Rc::new(NdarrayWrapper::new(array)) as Rc<dyn ArrayProtocol>;
333 Self::new(value, requiresgrad)
334 }
335
336 pub fn value(&self) -> Rc<dyn ArrayProtocol> {
338 self.node.borrow().value.clone()
339 }
340
341 pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
343 self.node.borrow().grad.clone()
344 }
345
346 pub fn requiresgrad(&self) -> bool {
348 self.node.borrow().requiresgrad
349 }
350
351 pub fn set_requiresgrad(&mut self, requiresgrad: bool) {
353 self.node.borrow_mut().requiresgrad = requiresgrad;
354 }
355
356 pub fn is_leaf(&self) -> bool {
358 self.node.borrow().is_leaf
359 }
360
361 fn from_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
363 let node = Rc::new(RefCell::new(Node::new_op(value, op, inputs)));
364 Self { node }
365 }
366
367 pub fn set_value(&mut self, newvalue: Rc<dyn ArrayProtocol>) {
369 self.node.borrow_mut().grad = None; self.node.borrow_mut().value = newvalue;
371 }
372
373 pub fn backward(&self) -> CoreResult<()> {
375 let gradshape = if let Some(array) = self
377 .value()
378 .as_any()
379 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
380 {
381 array.as_array().raw_dim()
382 } else {
383 crate::ndarray::IxDyn(&[1])
385 };
386
387 let grad_array = Array::<f64, IxDyn>::ones(gradshape);
388 let grad = Rc::new(NdarrayWrapper::new(grad_array)) as Rc<dyn ArrayProtocol>;
389
390 self.backward_with_grad(grad)
392 }
393
394 fn backward_with_grad(&self, grad: Rc<dyn ArrayProtocol>) -> CoreResult<()> {
396 self.node.borrow_mut().grad = Some(grad.clone());
398
399 let mut visited = HashSet::new();
401 let mut topo = Vec::new();
402
403 fn build_topo(
405 tensor: &GradientTensor,
406 visited: &mut HashSet<*const RefCell<Node>>,
407 topo: &mut Vec<GradientTensor>,
408 ) {
409 let node_ptr = Rc::as_ptr(&tensor.node);
410 if !visited.contains(&node_ptr) {
411 visited.insert(node_ptr);
412
413 for input in &tensor.node.borrow().inputs {
415 build_topo(input, visited, topo);
416 }
417
418 topo.push(tensor.clone());
420 }
421 }
422
423 build_topo(self, &mut visited, &mut topo);
425
426 for node in topo.iter().rev() {
428 if !node.requiresgrad() {
430 continue;
431 }
432
433 let node_grad = match node.grad_2() {
435 Some(g) => g,
436 None => continue, };
438
439 if node.is_leaf() {
441 continue;
442 }
443
444 let op = match &node.node.borrow().op {
446 Some(op) => op.clone(),
447 None => continue, };
449
450 let inputs = node.node.borrow().inputs.clone();
451
452 match op.as_str() {
454 "add" => {
455 for input in &inputs {
457 if input.requiresgrad() {
458 let mut input_node = input.node.borrow_mut();
459 if let Some(input_grad) = &input_node.grad {
460 if let Ok(sum) = add(input_grad.as_ref(), node_grad.as_ref()) {
462 input_node.grad = Some(sum.into());
463 }
464 } else {
465 input_node.grad = Some(node_grad.clone());
466 }
467 }
468 }
469 }
470 "multiply" => {
471 if inputs.len() == 2 {
473 let (a, b) = (&inputs[0], &inputs[1]);
474
475 if a.requiresgrad() {
477 let b_value = b.value();
478 if let Ok(grad_a) = multiply(node_grad.as_ref(), b_value.as_ref()) {
479 let mut a_node = a.node.borrow_mut();
480 if let Some(a_grad) = &a_node.grad {
481 if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
483 a_node.grad = Some(box_to_rc_array_protocol(sum));
484 }
485 } else {
486 a_node.grad = Some(box_to_rc_array_protocol(grad_a));
487 }
488 }
489 }
490
491 if b.requiresgrad() {
493 let a_value = a.value();
494 if let Ok(grad_b) = multiply(node_grad.as_ref(), a_value.as_ref()) {
495 let mut b_node = b.node.borrow_mut();
496 if let Some(b_grad) = &b_node.grad {
497 if let Ok(sum) = add(b_grad.as_ref(), grad_b.as_ref()) {
499 b_node.grad = Some(box_to_rc_array_protocol(sum));
500 }
501 } else {
502 b_node.grad = Some(box_to_rc_array_protocol(grad_b));
503 }
504 }
505 }
506 }
507 }
508 "matmul" => {
509 if inputs.len() == 2 {
513 let (a, b) = (&inputs[0], &inputs[1]);
514
515 if a.requiresgrad() {
517 if let (Some(b_array), Some(grad_out_array)) = (
518 b.value()
519 .as_any()
520 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
521 node_grad
522 .as_any()
523 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
524 ) {
525 let b_array_val = b_array.as_array();
526 let grad_out_array_val = grad_out_array.as_array();
527
528 let b_t = b_array_val.t();
530
531 let grad_outshape = grad_out_array_val.shape();
534 let grad_out_rows = grad_outshape[0];
535 let grad_out_cols = if grad_outshape.len() > 1 {
536 grad_outshape.iter().skip(1).product()
537 } else {
538 1
539 };
540 let grad_out_2d = grad_out_array_val
541 .clone()
542 .into_shape_with_order((grad_out_rows, grad_out_cols))
543 .unwrap();
544
545 let b_tshape = b_t.shape();
546 let b_t_rows = b_tshape[0];
547 let b_t_cols = if b_tshape.len() > 1 {
548 b_tshape.iter().skip(1).product()
549 } else {
550 1
551 };
552 let b_t_2d = b_t
553 .clone()
554 .into_shape_with_order((b_t_rows, b_t_cols))
555 .unwrap();
556
557 let grad_a_val = grad_out_2d.dot(&b_t_2d);
559
560 let grad_a_dyn = grad_a_val.into_dyn();
562 let grad_a = NdarrayWrapper::new(grad_a_dyn);
563
564 let mut a_node = a.node.borrow_mut();
566 if let Some(a_grad) = &a_node.grad {
567 if let (Some(a_grad_array), Some(grad_a_array)) = (
568 a_grad
569 .as_any()
570 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
571 grad_a
572 .as_any()
573 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
574 ) {
575 let sum = a_grad_array.as_array() + grad_a_array.as_array();
577 a_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
578 }
579 } else {
580 a_node.grad = Some(Rc::new(grad_a));
582 }
583 }
584 }
585
586 if b.requiresgrad() {
588 if let (Some(a_array), Some(grad_out_array)) = (
589 a.value()
590 .as_any()
591 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
592 node_grad
593 .as_any()
594 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
595 ) {
596 let a_array_val = a_array.as_array();
597 let grad_out_array_val = grad_out_array.as_array();
598
599 let a_t = a_array_val.t();
601
602 let grad_outshape = grad_out_array_val.shape();
605 let grad_out_rows = grad_outshape[0];
606 let grad_out_cols = if grad_outshape.len() > 1 {
607 grad_outshape.iter().skip(1).product()
608 } else {
609 1
610 };
611 let grad_out_2d = grad_out_array_val
612 .clone()
613 .into_shape_with_order((grad_out_rows, grad_out_cols))
614 .unwrap();
615
616 let a_tshape = a_t.shape();
617 let a_t_rows = a_tshape[0];
618 let a_t_cols = if a_tshape.len() > 1 {
619 a_tshape.iter().skip(1).product()
620 } else {
621 1
622 };
623 let a_t_2d = a_t
624 .clone()
625 .into_shape_with_order((a_t_rows, a_t_cols))
626 .unwrap();
627
628 let grad_b_val = a_t_2d.dot(&grad_out_2d);
630
631 let grad_b_dyn = grad_b_val.into_dyn();
633 let grad_b = NdarrayWrapper::new(grad_b_dyn);
634
635 let mut b_node = b.node.borrow_mut();
637 if let Some(b_grad) = &b_node.grad {
638 if let (Some(b_grad_array), Some(grad_b_array)) = (
639 b_grad
640 .as_any()
641 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
642 grad_b
643 .as_any()
644 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
645 ) {
646 let sum = b_grad_array.as_array() + grad_b_array.as_array();
648 b_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
649 }
650 } else {
651 b_node.grad = Some(Rc::new(grad_b));
653 }
654 }
655 }
656 }
657 }
658 "subtract" => {
659 if inputs.len() == 2 {
661 let (a, b) = (&inputs[0], &inputs[1]);
662
663 if a.requiresgrad() {
665 let mut a_node = a.node.borrow_mut();
666 if let Some(a_grad) = &a_node.grad {
667 if let Ok(sum) = add(a_grad.as_ref(), node_grad.as_ref()) {
669 a_node.grad = Some(box_to_rc_array_protocol(sum));
670 }
671 } else {
672 a_node.grad = Some(node_grad.clone());
673 }
674 }
675
676 if b.requiresgrad() {
678 if let Ok(neg_grad) = multiply_by_scalar(node_grad.as_ref(), -1.0) {
679 let mut b_node = b.node.borrow_mut();
680 if let Some(b_grad) = &b_node.grad {
681 if let Ok(sum) = add(b_grad.as_ref(), neg_grad.as_ref()) {
683 b_node.grad = Some(box_to_rc_array_protocol(sum));
684 }
685 } else {
686 b_node.grad = Some(box_to_rc_array_protocol(neg_grad));
687 }
688 }
689 }
690 }
691 }
692 "divide" => {
693 if inputs.len() == 2 {
695 let (a, b) = (&inputs[0], &inputs[1]);
696
697 if a.requiresgrad() {
699 let b_value = b.value();
700 if let Ok(grad_a) = divide(node_grad.as_ref(), b_value.as_ref()) {
701 let mut a_node = a.node.borrow_mut();
702 if let Some(a_grad) = &a_node.grad {
703 if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
705 a_node.grad = Some(box_to_rc_array_protocol(sum));
706 }
707 } else {
708 a_node.grad = Some(box_to_rc_array_protocol(grad_a));
709 }
710 }
711 }
712
713 if b.requiresgrad() {
715 let a_value = a.value();
716 let b_value = b.value();
717
718 if let Ok(b_squared) = multiply(b_value.as_ref(), b_value.as_ref()) {
720 if let Ok(grad_times_a) =
722 multiply(node_grad.as_ref(), a_value.as_ref())
723 {
724 if let Ok(div_result) =
726 divide(grad_times_a.as_ref(), b_squared.as_ref())
727 {
728 if let Ok(grad_b) =
730 multiply_by_scalar(div_result.as_ref(), -1.0)
731 {
732 let mut b_node = b.node.borrow_mut();
733 if let Some(b_grad) = &b_node.grad {
734 if let Ok(sum) =
736 add(b_grad.as_ref(), grad_b.as_ref())
737 {
738 b_node.grad =
739 Some(box_to_rc_array_protocol(sum));
740 }
741 } else {
742 b_node.grad =
743 Some(box_to_rc_array_protocol(grad_b));
744 }
745 }
746 }
747 }
748 }
749 }
750 }
751 }
752 "sigmoid" => {
753 if inputs.len() == 1 {
755 let input = &inputs[0];
756
757 if input.requiresgrad() {
758 let sigmoid_value = node.value();
760
761 if let Ok(ones) = ones_like(sigmoid_value.as_ref()) {
763 if let Ok(one_minus_sigmoid) =
764 subtract(ones.as_ref(), sigmoid_value.as_ref())
765 {
766 if let Ok(sigmoid_deriv) =
768 multiply(sigmoid_value.as_ref(), one_minus_sigmoid.as_ref())
769 {
770 if let Ok(grad_input) =
772 multiply(node_grad.as_ref(), sigmoid_deriv.as_ref())
773 {
774 let mut input_node = input.node.borrow_mut();
775 if let Some(input_grad) = &input_node.grad {
776 if let Ok(sum) =
778 add(input_grad.as_ref(), grad_input.as_ref())
779 {
780 input_node.grad =
781 Some(box_to_rc_array_protocol(sum));
782 }
783 } else {
784 input_node.grad =
785 Some(box_to_rc_array_protocol(grad_input));
786 }
787 }
788 }
789 }
790 }
791 }
792 }
793 }
794 "mean" => {
795 if inputs.len() == 1 {
797 let input = &inputs[0];
798
799 if input.requiresgrad() {
800 let input_value = input.value();
802 if let Some(inputarray) = input_value
803 .as_any()
804 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
805 {
806 let n_elements = inputarray.as_array().len() as f64;
807
808 if let Ok(grad_input) =
810 multiply_by_scalar(node_grad.as_ref(), 1.0 / n_elements)
811 {
812 if let Ok(broadcasted_grad) = broadcast_to(
814 grad_input.as_ref(),
815 inputarray.as_array().shape(),
816 ) {
817 let mut input_node = input.node.borrow_mut();
818 if let Some(input_grad) = &input_node.grad {
819 if let Ok(sum) =
821 add(input_grad.as_ref(), broadcasted_grad.as_ref())
822 {
823 input_node.grad =
824 Some(box_to_rc_array_protocol(sum));
825 }
826 } else {
827 input_node.grad =
828 Some(box_to_rc_array_protocol(broadcasted_grad));
829 }
830 }
831 }
832 }
833 }
834 }
835 }
836 _ => {
837 }
839 }
840 }
841
842 Ok(())
843 }
844
845 pub fn detach(&self) -> Self {
847 GradientTensor::new(self.value(), false)
848 }
849}
850
851#[allow(dead_code)]
854pub fn grad_add(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
855 let a_value = a.value();
856 let b_value = b.value();
857
858 let result = add(a_value.as_ref(), b_value.as_ref())?;
860
861 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
863 Ok(GradientTensor::from_op(
864 result_rc,
865 "add".to_string(),
866 vec![a.clone(), b.clone()],
867 ))
868}
869
870#[allow(dead_code)]
872pub fn grad_multiply(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
873 let a_value = a.value();
874 let b_value = b.value();
875
876 let result = multiply(a_value.as_ref(), b_value.as_ref())?;
878
879 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
881 Ok(GradientTensor::from_op(
882 result_rc,
883 "multiply".to_string(),
884 vec![a.clone(), b.clone()],
885 ))
886}
887
888#[allow(dead_code)]
890pub fn grad_matmul(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
891 let a_value = a.value();
892 let b_value = b.value();
893
894 let result = matmul(a_value.as_ref(), b_value.as_ref())?;
896
897 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
899 Ok(GradientTensor::from_op(
900 result_rc,
901 "matmul".to_string(),
902 vec![a.clone(), b.clone()],
903 ))
904}
905
906#[allow(dead_code)]
908pub fn grad_subtract(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
909 let a_value = a.value();
910 let b_value = b.value();
911
912 let result = subtract(a_value.as_ref(), b_value.as_ref())?;
914
915 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
917 Ok(GradientTensor::from_op(
918 result_rc,
919 "subtract".to_string(),
920 vec![a.clone(), b.clone()],
921 ))
922}
923
924#[allow(dead_code)]
926pub fn grad_divide(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
927 let a_value = a.value();
928 let b_value = b.value();
929
930 let result = divide(a_value.as_ref(), b_value.as_ref())?;
932
933 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
935 Ok(GradientTensor::from_op(
936 result_rc,
937 "divide".to_string(),
938 vec![a.clone(), b.clone()],
939 ))
940}
941
942#[allow(dead_code)]
944pub fn grad_sigmoid(a: &GradientTensor) -> CoreResult<GradientTensor> {
945 let a_value = a.value();
946
947 if let Some(a_array) = a_value
949 .as_any()
950 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
951 {
952 let array = a_array.as_array();
953 let result = array.mapv(|x| 1.0 / (1.0 + (-x).exp()));
954 let result_wrapped = NdarrayWrapper::new(result);
955 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
956 Ok(GradientTensor::from_op(
957 result_rc,
958 "sigmoid".to_string(),
959 vec![a.clone()],
960 ))
961 } else if let Some(a_array) = a_value
962 .as_any()
963 .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
964 {
965 let array = a_array.as_array();
966 let result = array.mapv(|x| 1.0f32 / (1.0f32 + (-x).exp()));
967 let result_wrapped = NdarrayWrapper::new(result);
968 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
969 Ok(GradientTensor::from_op(
970 result_rc,
971 "sigmoid".to_string(),
972 vec![a.clone()],
973 ))
974 } else {
975 Err(CoreError::NotImplementedError(ErrorContext::new(
976 "sigmoid not implemented for this array type".to_string(),
977 )))
978 }
979}
980
981#[allow(dead_code)]
983pub fn grad_mean(a: &GradientTensor) -> CoreResult<GradientTensor> {
984 let a_value = a.value();
985
986 if let Some(a_array) = a_value
988 .as_any()
989 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
990 {
991 let array = a_array.as_array();
992 let mean_value = array.mean_or(0.0);
993 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
994 let result_wrapped = NdarrayWrapper::new(result);
995 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
996 Ok(GradientTensor::from_op(
997 result_rc,
998 "mean".to_string(),
999 vec![a.clone()],
1000 ))
1001 } else if let Some(a_array) = a_value
1002 .as_any()
1003 .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
1004 {
1005 let array = a_array.as_array();
1006 let mean_value = array.mean_or(0.0f32);
1007 let result = ArrayD::<f32>::from_elem(IxDyn(&[1]), mean_value);
1008 let result_wrapped = NdarrayWrapper::new(result);
1009 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
1010 Ok(GradientTensor::from_op(
1011 result_rc,
1012 "mean".to_string(),
1013 vec![a.clone()],
1014 ))
1015 } else {
1016 Err(CoreError::NotImplementedError(ErrorContext::new(
1017 "mean not implemented for this array type".to_string(),
1018 )))
1019 }
1020}
1021
1022pub struct Variable {
1024 tensor: GradientTensor,
1026
1027 name: String,
1029}
1030
1031impl Variable {
1032 pub fn new<T, D>(name: &str, array: Array<T, D>) -> Self
1034 where
1035 T: Clone + Send + Sync + 'static,
1036 D: Dimension + Send + Sync + 'static,
1037 {
1038 let tensor = GradientTensor::from_array(array, true);
1039 Self {
1040 tensor,
1041 name: name.to_string(),
1042 }
1043 }
1044
1045 pub const fn tensor(&self) -> &GradientTensor {
1047 &self.tensor
1048 }
1049
1050 pub fn value(&self) -> Rc<dyn ArrayProtocol> {
1052 self.tensor.value()
1053 }
1054
1055 pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
1057 self.tensor.grad_2()
1058 }
1059
1060 pub fn name(&self) -> &str {
1062 &self.name
1063 }
1064
1065 pub fn set_gradient(&mut self, gradient: Box<dyn ArrayProtocol>) -> CoreResult<()> {
1067 let gradient_rc = self.box_to_rc(gradient);
1069
1070 self.tensor.node.borrow_mut().grad = Some(gradient_rc);
1072 Ok(())
1073 }
1074
1075 pub fn set_value(&mut self, newvalue: Box<dyn ArrayProtocol>) {
1077 let newvalue_rc = self.box_to_rc(newvalue);
1078 self.tensor.set_value(newvalue_rc);
1079 }
1080
1081 fn box_to_rc(&self, boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
1083 if let Some(ndarray_wrapper) = boxed
1085 .as_ref()
1086 .as_any()
1087 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
1088 {
1089 let array_clone = ndarray_wrapper.as_array().clone();
1090 Rc::new(NdarrayWrapper::new(array_clone))
1091 } else {
1092 let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
1094 Rc::new(NdarrayWrapper::new(fallback_array))
1095 }
1096 }
1097}
1098
1099pub trait Optimizer {
1101 fn step(&mut self) -> CoreResult<()>;
1103
1104 fn zero_grad(&mut self);
1106
1107 fn add_variable(&mut self, var: Variable);
1109
1110 fn variables(&self) -> &[Variable];
1112
1113 fn accumulate_gradients(&mut self, gradients: &GradientDict) -> CoreResult<()> {
1115 for (param_name, gradient) in gradients.iter() {
1117 for var in self.variables_mut() {
1119 if var.name() == param_name {
1120 var.set_gradient(gradient.clone())?;
1121 break;
1122 }
1123 }
1124 }
1125 Ok(())
1126 }
1127
1128 fn variables_mut(&mut self) -> &mut [Variable] {
1130 &mut []
1133 }
1134}
1135
1136pub struct SGD {
1138 variables: Vec<Variable>,
1140
1141 learningrate: f64,
1143
1144 momentum: f64,
1146
1147 velocity: Vec<Option<Box<dyn ArrayProtocol>>>,
1149}
1150
1151impl SGD {
1152 pub fn new(learningrate: f64, momentum: Option<f64>) -> Self {
1154 Self {
1155 variables: Vec::new(),
1156 learningrate,
1157 momentum: momentum.unwrap_or(0.0),
1158 velocity: Vec::new(),
1159 }
1160 }
1161
1162 pub fn set_learningrate(&mut self, learningrate: f64) {
1164 self.learningrate = learningrate;
1165 }
1166}
1167
1168impl Optimizer for SGD {
1169 fn step(&mut self) -> CoreResult<()> {
1170 for (i, var) in self.variables.iter_mut().enumerate() {
1171 if let Some(grad) = var.grad_2() {
1172 let var_value = var.value();
1173
1174 let update = if self.momentum > 0.0 {
1176 if i >= self.velocity.len() {
1177 self.velocity.resize_with(i + 1, || None);
1178 }
1179
1180 if let Some(vel) = &self.velocity[i] {
1181 let scaled_grad = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
1183 let scaled_vel = multiply_by_scalar(vel.as_ref(), self.momentum)?;
1184 let update = add(scaled_vel.as_ref(), scaled_grad.as_ref())?;
1185 self.velocity[i] = Some(update.clone());
1186 update
1187 } else {
1188 let update = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
1190 self.velocity[i] = Some(update.clone());
1191 update
1192 }
1193 } else {
1194 multiply_by_scalar(grad.as_ref(), self.learningrate)?
1196 };
1197
1198 let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
1200 var.set_value(updated_value);
1201 }
1202 }
1203
1204 Ok(())
1205 }
1206
1207 fn zero_grad(&mut self) {
1208 for var in &self.variables {
1209 var.tensor.node.borrow_mut().grad = None;
1210 }
1211 }
1212
1213 fn add_variable(&mut self, var: Variable) {
1214 self.variables.push(var);
1215 self.velocity.push(None);
1216 }
1217
1218 fn variables(&self) -> &[Variable] {
1219 &self.variables
1220 }
1221
1222 fn variables_mut(&mut self) -> &mut [Variable] {
1223 &mut self.variables
1224 }
1225}
1226
1227pub struct Adam {
1229 variables: Vec<Variable>,
1231
1232 learningrate: f64,
1234
1235 beta1: f64,
1237
1238 beta2: f64,
1240
1241 epsilon: f64,
1243
1244 m: Vec<Option<Box<dyn ArrayProtocol>>>,
1246
1247 v: Vec<Option<Box<dyn ArrayProtocol>>>,
1249
1250 t: usize,
1252}
1253
1254impl Adam {
1255 pub fn new(
1257 learningrate: f64,
1258 beta1: Option<f64>,
1259 beta2: Option<f64>,
1260 epsilon: Option<f64>,
1261 ) -> Self {
1262 Self {
1263 variables: Vec::new(),
1264 learningrate,
1265 beta1: beta1.unwrap_or(0.9),
1266 beta2: beta2.unwrap_or(0.999),
1267 epsilon: epsilon.unwrap_or(1e-8),
1268 m: Vec::new(),
1269 v: Vec::new(),
1270 t: 0,
1271 }
1272 }
1273}
1274
1275impl Optimizer for Adam {
1276 fn step(&mut self) -> CoreResult<()> {
1277 self.t += 1;
1278
1279 for (i, var) in self.variables.iter_mut().enumerate() {
1280 if let Some(grad) = var.grad_2() {
1281 let var_value = var.value();
1282
1283 if i >= self.m.len() {
1285 self.m.resize_with(i + 1, || None);
1286 self.v.resize_with(i + 1, || None);
1287 }
1288
1289 let m = if let Some(m_prev) = &self.m[i] {
1291 let scaled_m = multiply_by_scalar(m_prev.as_ref(), self.beta1)?;
1293 let scaled_grad = multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?;
1294 add(scaled_m.as_ref(), scaled_grad.as_ref())?
1295 } else {
1296 multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?
1298 };
1299
1300 let v = if let Some(v_prev) = &self.v[i] {
1302 let scaled_v = multiply_by_scalar(v_prev.as_ref(), self.beta2)?;
1304 let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
1305 let scaled_grad_sq =
1306 multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?;
1307 add(scaled_v.as_ref(), scaled_grad_sq.as_ref())?
1308 } else {
1309 let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
1311 multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?
1312 };
1313
1314 self.m[i] = Some(m.clone());
1316 self.v[i] = Some(v.clone());
1317
1318 let m_hat =
1320 multiply_by_scalar(m.as_ref(), 1.0 / (1.0 - self.beta1.powi(self.t as i32)))?;
1321 let v_hat =
1322 multiply_by_scalar(v.as_ref(), 1.0 / (1.0 - self.beta2.powi(self.t as i32)))?;
1323
1324 let v_hat_sqrt = sqrt(v_hat.as_ref())?;
1326 let v_hat_sqrt_eps = add_scalar(v_hat_sqrt.as_ref(), self.epsilon)?;
1327 let update_dir = divide(m_hat.as_ref(), v_hat_sqrt_eps.as_ref())?;
1328 let update = multiply_by_scalar(update_dir.as_ref(), self.learningrate)?;
1329
1330 let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
1332 var.set_value(updated_value);
1333 }
1334 }
1335
1336 Ok(())
1337 }
1338
1339 fn zero_grad(&mut self) {
1340 for var in &self.variables {
1341 var.tensor.node.borrow_mut().grad = None;
1342 }
1343 }
1344
1345 fn add_variable(&mut self, var: Variable) {
1346 self.variables.push(var);
1347 self.m.push(None);
1348 self.v.push(None);
1349 }
1350
1351 fn variables(&self) -> &[Variable] {
1352 &self.variables
1353 }
1354
1355 fn variables_mut(&mut self) -> &mut [Variable] {
1356 &mut self.variables
1357 }
1358}
1359
1360#[allow(dead_code)]
1364fn multiply_by_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
1365 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1366 let inputarray = a_array.as_array();
1367 let result = inputarray.mapv(|x| x * scalar);
1368 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1369 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1370 let inputarray = a_array.as_array();
1371 let result = inputarray.mapv(|x| x * scalar as f32);
1372 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1373 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
1374 let inputarray = a_array.as_array();
1375 let result = inputarray.mapv(|x| (x as f64 * scalar) as i32);
1376 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1377 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
1378 let inputarray = a_array.as_array();
1379 let result = inputarray.mapv(|x| (x as f64 * scalar) as i64);
1380 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1381 } else {
1382 Err(CoreError::NotImplementedError(ErrorContext::new(
1383 "multiply_by_scalar not implemented for this array type".to_string(),
1384 )))
1385 }
1386}
1387
1388#[allow(dead_code)]
1390fn subtract_arrays(
1391 a: &dyn ArrayProtocol,
1392 b: &dyn ArrayProtocol,
1393) -> CoreResult<Box<dyn ArrayProtocol>> {
1394 if let (Some(a_wrapper), Some(b_array)) = (
1396 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1397 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1398 ) {
1399 let a_arr = a_wrapper.as_array();
1400 let b_arr = b_array.as_array();
1401 let result = a_arr - b_arr;
1402 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1403 } else if let (Some(a_wrapper), Some(b_array)) = (
1404 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1405 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1406 ) {
1407 let a_arr = a_wrapper.as_array();
1408 let b_arr = b_array.as_array();
1409 let result = a_arr - b_arr;
1410 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1411 } else if let (Some(a_wrapper), Some(b_array)) = (
1412 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1413 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1414 ) {
1415 let a_arr = a_wrapper.as_array();
1416 let b_arr = b_array.as_array();
1417 let result = a_arr - b_arr;
1418 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1419 } else if let (Some(a_wrapper), Some(b_array)) = (
1420 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1421 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1422 ) {
1423 let a_arr = a_wrapper.as_array();
1424 let b_arr = b_array.as_array();
1425 let result = a_arr - b_arr;
1426 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1427 } else {
1428 Err(CoreError::NotImplementedError(ErrorContext::new(
1429 "subtract_arrays not implemented for these array types".to_string(),
1430 )))
1431 }
1432}
1433
1434#[allow(dead_code)]
1436fn sqrt(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
1437 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1438 let result = a_array.as_array().mapv(|x| x.sqrt());
1439 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1440 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1441 let result = a_array.as_array().mapv(|x| x.sqrt());
1442 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1443 } else {
1444 Err(CoreError::NotImplementedError(ErrorContext::new(
1445 "sqrt not implemented for this array type".to_string(),
1446 )))
1447 }
1448}
1449
1450#[allow(dead_code)]
1452fn add_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
1453 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1454 let result = a_array.as_array().mapv(|x| x + scalar);
1455 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1456 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1457 let result = a_array.as_array().mapv(|x| x + scalar as f32);
1458 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1459 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
1460 let result = a_array.as_array().mapv(|x| x + scalar as i32);
1461 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1462 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
1463 let result = a_array.as_array().mapv(|x| x + scalar as i64);
1464 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1465 } else {
1466 Err(CoreError::NotImplementedError(ErrorContext::new(
1467 "add_scalar not implemented for this array type".to_string(),
1468 )))
1469 }
1470}
1471
1472#[allow(dead_code)]
1474fn divide(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
1475 if let (Some(a_array), Some(b_array)) = (
1476 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1477 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1478 ) {
1479 let result = a_array.as_array() / b_array.as_array();
1480 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1481 } else if let (Some(a_array), Some(b_array)) = (
1482 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1483 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1484 ) {
1485 let result = a_array.as_array() / b_array.as_array();
1486 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1487 } else {
1488 Err(CoreError::NotImplementedError(ErrorContext::new(
1489 "divide not implemented for these array types".to_string(),
1490 )))
1491 }
1492}
1493
1494#[cfg(test)]
1495mod tests {
1496 use super::*;
1497 use ::ndarray::{array, Array2, Ix2};
1498
1499 #[test]
1500 fn test_gradient_tensor_creation() {
1501 let array = Array2::<f64>::ones((2, 2));
1503 let tensor = GradientTensor::from_array(array, true);
1504
1505 assert!(tensor.requiresgrad());
1507 assert!(tensor.is_leaf());
1508 assert!(tensor.grad_2().is_none());
1509 }
1510
1511 #[test]
1512 fn test_gradient_computation_add() {
1513 #[allow(unused_imports)]
1515 use ::ndarray::array;
1516
1517 let a_array = Array2::<f64>::ones((2, 2));
1519 let b_array = Array2::<f64>::ones((2, 2)) * 2.0;
1520
1521 let a = GradientTensor::from_array(a_array, true);
1522 let b = GradientTensor::from_array(b_array, true);
1523
1524 let c = match grad_add(&a, &b) {
1526 Ok(c) => c,
1527 Err(e) => {
1528 println!("Skipping test_gradient_computationadd: {e}");
1529 return;
1530 }
1531 };
1532
1533 let c_value = c.value();
1535 let c_array = match c_value.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1536 Some(array) => array,
1537 None => {
1538 println!("Skipping test_gradient_computationadd: result is not the expected type");
1539 return;
1540 }
1541 };
1542 assert_eq!(c_array.as_array(), &array![[3.0, 3.0], [3.0, 3.0]]);
1543
1544 if let Err(e) = c.backward() {
1546 println!("Skipping test_gradient_computationadd: {e}");
1547 return;
1548 }
1549
1550 let a_grad = match a.grad_2() {
1552 Some(grad) => grad,
1553 None => {
1554 println!("Skipping test_gradient_computationadd: no gradient for a");
1555 return;
1556 }
1557 };
1558
1559 let a_grad_array = match a_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1560 Some(array) => array,
1561 None => {
1562 println!("Skipping test_gradient_computationadd: a_grad is not the expected type");
1563 return;
1564 }
1565 };
1566 assert_eq!(a_grad_array.as_array(), &array![[1.0, 1.0], [1.0, 1.0]]);
1567
1568 let b_grad = match b.grad_2() {
1569 Some(grad) => grad,
1570 None => {
1571 println!("Skipping test_gradient_computationadd: no gradient for b");
1572 return;
1573 }
1574 };
1575
1576 let b_grad_array = match b_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1577 Some(array) => array,
1578 None => {
1579 println!("Skipping test_gradient_computationadd: b_grad is not the expected type");
1580 return;
1581 }
1582 };
1583 assert_eq!(b_grad_array.as_array(), &array![[1.0, 1.0], [1.0, 1.0]]);
1584 }
1585
1586 #[test]
1587 fn test_gradient_computation_multiply() {
1588 #[allow(unused_imports)]
1590 use ::ndarray::array;
1591
1592 let a_array = Array2::<f64>::ones((2, 2)) * 2.0;
1594 let b_array = Array2::<f64>::ones((2, 2)) * 3.0;
1595
1596 let a = GradientTensor::from_array(a_array, true);
1597 let b = GradientTensor::from_array(b_array, true);
1598
1599 let c = match grad_multiply(&a, &b) {
1601 Ok(c) => c,
1602 Err(e) => {
1603 println!("Skipping test_gradient_computationmultiply: {e}");
1604 return;
1605 }
1606 };
1607
1608 let c_value = c.value();
1610 let c_array = match c_value.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1611 Some(array) => array,
1612 None => {
1613 println!(
1614 "Skipping test_gradient_computation_multiply: result is not the expected type"
1615 );
1616 return;
1617 }
1618 };
1619 assert_eq!(c_array.as_array(), &array![[6.0, 6.0], [6.0, 6.0]]);
1620
1621 if let Err(e) = c.backward() {
1623 println!("Skipping test_gradient_computationmultiply: {e}");
1624 return;
1625 }
1626
1627 let a_grad = match a.grad_2() {
1629 Some(grad) => grad,
1630 None => {
1631 println!("Skipping test_gradient_computationmultiply: no gradient for a");
1632 return;
1633 }
1634 };
1635
1636 let a_grad_array = match a_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1637 Some(array) => array,
1638 None => {
1639 println!(
1640 "Skipping test_gradient_computation_multiply: a_grad is not the expected type"
1641 );
1642 return;
1643 }
1644 };
1645 assert_eq!(a_grad_array.as_array(), &array![[3.0, 3.0], [3.0, 3.0]]);
1646
1647 let b_grad = match b.grad_2() {
1648 Some(grad) => grad,
1649 None => {
1650 println!("Skipping test_gradient_computationmultiply: no gradient for b");
1651 return;
1652 }
1653 };
1654
1655 let b_grad_array = match b_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1656 Some(array) => array,
1657 None => {
1658 println!(
1659 "Skipping test_gradient_computation_multiply: b_grad is not the expected type"
1660 );
1661 return;
1662 }
1663 };
1664 assert_eq!(b_grad_array.as_array(), &array![[2.0, 2.0], [2.0, 2.0]]);
1665 }
1666
1667 #[test]
1668 fn test_sgd_optimizer() {
1669 #[allow(unused_imports)]
1671 use ::ndarray::array;
1672
1673 let weight_array = Array2::<f64>::ones((2, 2));
1675 let weight = Variable::new("weight", weight_array);
1676
1677 let bias_array = Array2::<f64>::zeros((2, 2));
1678 let bias = Variable::new("bias", bias_array);
1679
1680 let mut optimizer = SGD::new(0.1, Some(0.9));
1682 optimizer.add_variable(weight);
1683 optimizer.add_variable(bias);
1684
1685 let weight_grad_array = Array2::<f64>::ones((2, 2));
1687 let weight_grad = NdarrayWrapper::new(weight_grad_array);
1688 optimizer.variables()[0].tensor.node.borrow_mut().grad = Some(Rc::new(weight_grad));
1689
1690 let bias_grad_array = Array2::<f64>::ones((2, 2)) * 2.0;
1691 let bias_grad = NdarrayWrapper::new(bias_grad_array);
1692 optimizer.variables()[1].tensor.node.borrow_mut().grad = Some(Rc::new(bias_grad));
1693
1694 match optimizer.step() {
1696 Ok(_) => {
1697 optimizer.zero_grad();
1699
1700 assert!(optimizer.variables()[0].grad_2().is_none());
1702 assert!(optimizer.variables()[1].grad_2().is_none());
1703 }
1704 Err(e) => {
1705 println!("Skipping test_sgd_optimizer - step failed: {e}");
1706 }
1707 }
1708 }
1709}