use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;

use ndarray::{Array, ArrayD, Dimension, IxDyn};

use crate::array_protocol::operations::matmul;
use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
use crate::error::{CoreError, CoreResult, ErrorContext};

/// A dictionary mapping parameter names to gradient arrays.
#[derive(Clone)]
pub struct GradientDict {
    gradients: HashMap<String, Box<dyn ArrayProtocol>>,
}

impl std::fmt::Debug for GradientDict {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GradientDict")
            .field(
                "gradients",
                &format!("{{keys: {:?}}}", self.gradients.keys().collect::<Vec<_>>()),
            )
            .finish()
    }
}

impl GradientDict {
    /// Creates an empty gradient dictionary.
    pub fn new() -> Self {
        Self {
            gradients: HashMap::new(),
        }
    }

    /// Inserts a gradient for the given parameter name, replacing any existing entry.
    pub fn insert(&mut self, name: String, gradient: Box<dyn ArrayProtocol>) {
        self.gradients.insert(name, gradient);
    }

    /// Returns the gradient for the given parameter name, if present.
    pub fn get(&self, name: &str) -> Option<&dyn ArrayProtocol> {
        self.gradients.get(name).map(|b| b.as_ref())
    }

    /// Returns a mutable reference to the gradient for the given parameter name.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut Box<dyn ArrayProtocol>> {
        self.gradients.get_mut(name)
    }

    /// Iterates over (name, gradient) pairs.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &Box<dyn ArrayProtocol>)> {
        self.gradients.iter()
    }

    /// Merges another dictionary into this one, overwriting entries with the same name.
    pub fn merge(&mut self, other: GradientDict) {
        for (name, gradient) in other.gradients {
            self.gradients.insert(name, gradient);
        }
    }

    /// Returns whether the dictionary is empty.
    pub fn is_empty(&self) -> bool {
        self.gradients.is_empty()
    }

    /// Returns the number of stored gradients.
    pub fn len(&self) -> usize {
        self.gradients.len()
    }

    /// Removes all gradients.
    pub fn clear(&mut self) {
        self.gradients.clear();
    }

    /// Iterates over parameter names.
    pub fn keys(&self) -> impl Iterator<Item = &String> {
        self.gradients.keys()
    }

    /// Iterates over stored gradients.
    pub fn values(&self) -> impl Iterator<Item = &Box<dyn ArrayProtocol>> {
        self.gradients.values()
    }
}

impl Default for GradientDict {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(dead_code)]
fn boxed_to_rc(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
    let array_ref = boxed.as_ref();

    // Clone the underlying f64 array when the concrete type is recognized.
    if let Some(ndarray_wrapper) = array_ref
        .as_any()
        .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
    {
        let array_clone = ndarray_wrapper.as_array().clone();
        return Rc::new(NdarrayWrapper::new(array_clone));
    }

    // Fallback for unrecognized array types: substitute a small zero array.
    let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
    Rc::new(NdarrayWrapper::new(fallback_array))
}

#[allow(dead_code)]
fn box_to_rc_array_protocol(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
    boxed_to_rc(boxed)
}

/// Thin wrapper around `operations::add` that converts the error into a `CoreResult`.
#[allow(dead_code)]
fn add(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    crate::array_protocol::operations::add(a, b).map_err(|e| e.into())
}

/// Thin wrapper around `operations::multiply` that converts the error into a `CoreResult`.
#[allow(dead_code)]
fn multiply(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    crate::array_protocol::operations::multiply(a, b).map_err(|e| e.into())
}

/// Thin wrapper around `operations::subtract` that converts the error into a `CoreResult`.
#[allow(dead_code)]
fn subtract(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    crate::array_protocol::operations::subtract(a, b).map_err(|e| e.into())
}

/// Creates an array of ones with the same shape as `a`, preserving the element type
/// for the recognized wrappers and defaulting to `f64` otherwise.
#[allow(dead_code)]
fn ones_like(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
        let shape = a_array.as_array().shape();
        let ones = ArrayD::<f64>::ones(IxDyn(shape));
        Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
        let shape = a_array.as_array().shape();
        let ones = ArrayD::<f32>::ones(IxDyn(shape));
        Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
        let shape = a_array.as_array().shape();
        let ones = ArrayD::<i32>::ones(IxDyn(shape));
        Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
        let shape = a_array.as_array().shape();
        let ones = ArrayD::<i64>::ones(IxDyn(shape));
        Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
    } else {
        let shape = a.shape().to_vec();
        let ones = ArrayD::<f64>::ones(IxDyn(&shape));
        Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
    }
}

/// Broadcasts `a` to the target `shape` for the supported floating-point wrappers.
#[allow(dead_code)]
fn broadcast_to(a: &dyn ArrayProtocol, shape: &[usize]) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
        let array = a_array.as_array();
        if array.len() == 1 {
            // A single-element array broadcasts to any shape.
            let value = array.iter().next().cloned().unwrap_or(0.0);
            let broadcasted = ArrayD::<f64>::from_elem(IxDyn(shape), value);
            Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
        } else if array.shape() == shape {
            Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
        } else {
            let inputshape = array.shape();
            let _ndim_diff = shape.len().saturating_sub(inputshape.len());

            // Check compatibility from the trailing dimensions.
            let mut can_broadcast = true;
            for i in 0..inputshape.len() {
                let input_dim = inputshape[inputshape.len() - 1 - i];
                let target_dim = shape[shape.len() - 1 - i];
                if input_dim != 1 && input_dim != target_dim {
                    can_broadcast = false;
                    break;
                }
            }

            if can_broadcast {
                if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
                    let broadcasted = broadcasted_view.to_owned();
                    Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Broadcasting failed for these shapes".to_string(),
                    )))
                }
            } else {
                Err(CoreError::NotImplementedError(ErrorContext::new(
                    "Incompatible shapes for broadcasting".to_string(),
                )))
            }
        }
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
        let array = a_array.as_array();
        if array.len() == 1 {
            let value = array.iter().next().cloned().unwrap_or(0.0);
            let broadcasted = ArrayD::<f32>::from_elem(IxDyn(shape), value);
            Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
        } else if array.shape() == shape {
            Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
        } else if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
            let broadcasted = broadcasted_view.to_owned();
            Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
        } else {
            Err(CoreError::NotImplementedError(ErrorContext::new(
                "Broadcasting failed for these shapes".to_string(),
            )))
        }
    } else {
        // Fallback for unsupported array types: return ones with the target shape.
        let ones = ArrayD::<f64>::ones(IxDyn(shape));
        Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
    }
}

/// A node in the computation graph.
#[derive(Clone)]
struct Node {
    /// The value produced at this node.
    value: Rc<dyn ArrayProtocol>,

    /// The accumulated gradient with respect to this node's value.
    grad: Option<Rc<dyn ArrayProtocol>>,

    /// The operation that produced this node, if any.
    op: Option<String>,

    /// The inputs to the operation.
    inputs: Vec<GradientTensor>,

    /// Whether gradients should be computed for this node.
    requiresgrad: bool,

    /// Whether this node is a leaf (i.e. not the result of an operation).
    is_leaf: bool,
}

impl Node {
    /// Creates a leaf node with a placeholder scalar value.
    fn leaf(requiresgrad: bool) -> Self {
        Self {
            value: Rc::new(NdarrayWrapper::new(ndarray::Array0::<f64>::zeros(())))
                as Rc<dyn ArrayProtocol>,
            grad: None,
            op: None,
            inputs: Vec::new(),
            requiresgrad,
            is_leaf: true,
        }
    }

    /// Creates a node produced by an operation; it requires a gradient if any input does.
    fn new_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
        let requiresgrad = inputs.iter().any(|x| x.requiresgrad());

        Self {
            value,
            grad: None,
            op: Some(op),
            inputs,
            requiresgrad,
            is_leaf: false,
        }
    }
}

/// A tensor that participates in reverse-mode automatic differentiation.
#[derive(Clone)]
pub struct GradientTensor {
    node: Rc<RefCell<Node>>,
}

impl GradientTensor {
    /// Creates a new leaf tensor from an existing array value.
    pub fn new(value: Rc<dyn ArrayProtocol>, requiresgrad: bool) -> Self {
        let mut node_inner = Node::leaf(requiresgrad);
        node_inner.value = value;
        node_inner.grad = None;
        let node = Rc::new(RefCell::new(node_inner));
        Self { node }
    }

    /// Creates a leaf tensor from an `ndarray` array.
    pub fn from_array<T, D>(array: Array<T, D>, requiresgrad: bool) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D: Dimension + Send + Sync + 'static,
    {
        let value = Rc::new(NdarrayWrapper::new(array)) as Rc<dyn ArrayProtocol>;
        Self::new(value, requiresgrad)
    }

    /// Returns the tensor's value.
    pub fn value(&self) -> Rc<dyn ArrayProtocol> {
        self.node.borrow().value.clone()
    }

    /// Returns the accumulated gradient, if one has been computed.
    pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
        self.node.borrow().grad.clone()
    }

    /// Returns whether gradients are computed for this tensor.
    pub fn requiresgrad(&self) -> bool {
        self.node.borrow().requiresgrad
    }

    /// Sets whether gradients are computed for this tensor.
    pub fn set_requiresgrad(&mut self, requiresgrad: bool) {
        self.node.borrow_mut().requiresgrad = requiresgrad;
    }

    /// Returns whether this tensor is a leaf of the computation graph.
    pub fn is_leaf(&self) -> bool {
        self.node.borrow().is_leaf
    }

    fn from_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
        let node = Rc::new(RefCell::new(Node::new_op(value, op, inputs)));
        Self { node }
    }

    /// Replaces the tensor's value, clearing any stale gradient first.
    pub fn set_value(&mut self, newvalue: Rc<dyn ArrayProtocol>) {
        self.node.borrow_mut().grad = None;
        self.node.borrow_mut().value = newvalue;
    }

    /// Computes gradients by backpropagation, seeding the output gradient with ones.
    pub fn backward(&self) -> CoreResult<()> {
        let gradshape = if let Some(array) = self
            .value()
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
        {
            array.as_array().raw_dim()
        } else {
            ndarray::IxDyn(&[1])
        };

        let grad_array = Array::<f64, IxDyn>::ones(gradshape);
        let grad = Rc::new(NdarrayWrapper::new(grad_array)) as Rc<dyn ArrayProtocol>;

        self.backward_with_grad(grad)
    }

    fn backward_with_grad(&self, grad: Rc<dyn ArrayProtocol>) -> CoreResult<()> {
        self.node.borrow_mut().grad = Some(grad.clone());

        let mut visited = HashSet::new();
        let mut topo = Vec::new();

        // Build a topological ordering of the graph reachable from this tensor.
        fn build_topo(
            tensor: &GradientTensor,
            visited: &mut HashSet<*const RefCell<Node>>,
            topo: &mut Vec<GradientTensor>,
        ) {
            let node_ptr = Rc::as_ptr(&tensor.node);
            if !visited.contains(&node_ptr) {
                visited.insert(node_ptr);

                for input in &tensor.node.borrow().inputs {
                    build_topo(input, visited, topo);
                }

                topo.push(tensor.clone());
            }
        }

        build_topo(self, &mut visited, &mut topo);

        // Propagate gradients in reverse topological order.
        for node in topo.iter().rev() {
            if !node.requiresgrad() {
                continue;
            }

            let node_grad = match node.grad_2() {
                Some(g) => g,
                None => continue,
            };

            if node.is_leaf() {
                continue;
            }

            let op = match &node.node.borrow().op {
                Some(op) => op.clone(),
                None => continue,
            };

            let inputs = node.node.borrow().inputs.clone();

            match op.as_str() {
                "add" => {
                    // d(a + b)/da = 1 and d(a + b)/db = 1: pass the gradient through.
                    for input in &inputs {
                        if input.requiresgrad() {
                            let mut input_node = input.node.borrow_mut();
                            if let Some(input_grad) = &input_node.grad {
                                if let Ok(sum) = add(input_grad.as_ref(), node_grad.as_ref()) {
                                    input_node.grad = Some(sum.into());
                                }
                            } else {
                                input_node.grad = Some(node_grad.clone());
                            }
                        }
                    }
                }
                "multiply" => {
                    // d(a * b)/da = b and d(a * b)/db = a.
                    if inputs.len() == 2 {
                        let (a, b) = (&inputs[0], &inputs[1]);

                        if a.requiresgrad() {
                            let b_value = b.value();
                            if let Ok(grad_a) = multiply(node_grad.as_ref(), b_value.as_ref()) {
                                let mut a_node = a.node.borrow_mut();
                                if let Some(a_grad) = &a_node.grad {
                                    if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
                                        a_node.grad = Some(box_to_rc_array_protocol(sum));
                                    }
                                } else {
                                    a_node.grad = Some(box_to_rc_array_protocol(grad_a));
                                }
                            }
                        }

                        if b.requiresgrad() {
                            let a_value = a.value();
                            if let Ok(grad_b) = multiply(node_grad.as_ref(), a_value.as_ref()) {
                                let mut b_node = b.node.borrow_mut();
                                if let Some(b_grad) = &b_node.grad {
                                    if let Ok(sum) = add(b_grad.as_ref(), grad_b.as_ref()) {
                                        b_node.grad = Some(box_to_rc_array_protocol(sum));
                                    }
                                } else {
                                    b_node.grad = Some(box_to_rc_array_protocol(grad_b));
                                }
                            }
                        }
                    }
                }
                "matmul" => {
                    // For c = a.dot(b): dc/da = grad_out . b^T and dc/db = a^T . grad_out.
                    if inputs.len() == 2 {
                        let (a, b) = (&inputs[0], &inputs[1]);

                        if a.requiresgrad() {
                            if let (Some(b_array), Some(grad_out_array)) = (
                                b.value()
                                    .as_any()
                                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                                node_grad
                                    .as_any()
                                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                            ) {
                                let b_array_val = b_array.as_array();
                                let grad_out_array_val = grad_out_array.as_array();

                                let b_t = b_array_val.t();

                                // Reshape the incoming gradient and b^T to 2-D for the dot product.
                                let grad_outshape = grad_out_array_val.shape();
                                let grad_out_rows = grad_outshape[0];
                                let grad_out_cols = if grad_outshape.len() > 1 {
                                    grad_outshape.iter().skip(1).product()
                                } else {
                                    1
                                };
                                let grad_out_2d = grad_out_array_val
                                    .clone()
                                    .into_shape_with_order((grad_out_rows, grad_out_cols))
                                    .unwrap();

                                let b_tshape = b_t.shape();
                                let b_t_rows = b_tshape[0];
                                let b_t_cols = if b_tshape.len() > 1 {
                                    b_tshape.iter().skip(1).product()
                                } else {
                                    1
                                };
                                let b_t_2d = b_t
                                    .clone()
                                    .into_shape_with_order((b_t_rows, b_t_cols))
                                    .unwrap();

                                let grad_a_val = grad_out_2d.dot(&b_t_2d);

                                let grad_a_dyn = grad_a_val.into_dyn();
                                let grad_a = NdarrayWrapper::new(grad_a_dyn);

                                let mut a_node = a.node.borrow_mut();
                                if let Some(a_grad) = &a_node.grad {
                                    if let (Some(a_grad_array), Some(grad_a_array)) = (
                                        a_grad
                                            .as_any()
                                            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                                        grad_a
                                            .as_any()
                                            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                                    ) {
                                        let sum = a_grad_array.as_array() + grad_a_array.as_array();
                                        a_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
                                    }
                                } else {
                                    a_node.grad = Some(Rc::new(grad_a));
                                }
                            }
                        }

                        if b.requiresgrad() {
                            if let (Some(a_array), Some(grad_out_array)) = (
                                a.value()
                                    .as_any()
                                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                                node_grad
                                    .as_any()
                                    .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                            ) {
                                let a_array_val = a_array.as_array();
                                let grad_out_array_val = grad_out_array.as_array();

                                let a_t = a_array_val.t();

                                let grad_outshape = grad_out_array_val.shape();
                                let grad_out_rows = grad_outshape[0];
                                let grad_out_cols = if grad_outshape.len() > 1 {
                                    grad_outshape.iter().skip(1).product()
                                } else {
                                    1
                                };
                                let grad_out_2d = grad_out_array_val
                                    .clone()
                                    .into_shape_with_order((grad_out_rows, grad_out_cols))
                                    .unwrap();

                                let a_tshape = a_t.shape();
                                let a_t_rows = a_tshape[0];
                                let a_t_cols = if a_tshape.len() > 1 {
                                    a_tshape.iter().skip(1).product()
                                } else {
                                    1
                                };
                                let a_t_2d = a_t
                                    .clone()
                                    .into_shape_with_order((a_t_rows, a_t_cols))
                                    .unwrap();

                                let grad_b_val = a_t_2d.dot(&grad_out_2d);

                                let grad_b_dyn = grad_b_val.into_dyn();
                                let grad_b = NdarrayWrapper::new(grad_b_dyn);

                                let mut b_node = b.node.borrow_mut();
                                if let Some(b_grad) = &b_node.grad {
                                    if let (Some(b_grad_array), Some(grad_b_array)) = (
                                        b_grad
                                            .as_any()
                                            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                                        grad_b
                                            .as_any()
                                            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
                                    ) {
                                        let sum = b_grad_array.as_array() + grad_b_array.as_array();
                                        b_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
                                    }
                                } else {
                                    b_node.grad = Some(Rc::new(grad_b));
                                }
                            }
                        }
                    }
                }
                "subtract" => {
                    // d(a - b)/da = 1 and d(a - b)/db = -1.
                    if inputs.len() == 2 {
                        let (a, b) = (&inputs[0], &inputs[1]);

                        if a.requiresgrad() {
                            let mut a_node = a.node.borrow_mut();
                            if let Some(a_grad) = &a_node.grad {
                                if let Ok(sum) = add(a_grad.as_ref(), node_grad.as_ref()) {
                                    a_node.grad = Some(box_to_rc_array_protocol(sum));
                                }
                            } else {
                                a_node.grad = Some(node_grad.clone());
                            }
                        }

                        if b.requiresgrad() {
                            if let Ok(neg_grad) = multiply_by_scalar(node_grad.as_ref(), -1.0) {
                                let mut b_node = b.node.borrow_mut();
                                if let Some(b_grad) = &b_node.grad {
                                    if let Ok(sum) = add(b_grad.as_ref(), neg_grad.as_ref()) {
                                        b_node.grad = Some(box_to_rc_array_protocol(sum));
                                    }
                                } else {
                                    b_node.grad = Some(box_to_rc_array_protocol(neg_grad));
                                }
                            }
                        }
                    }
                }
                "divide" => {
                    // d(a / b)/da = 1 / b and d(a / b)/db = -a / b^2.
                    if inputs.len() == 2 {
                        let (a, b) = (&inputs[0], &inputs[1]);

                        if a.requiresgrad() {
                            let b_value = b.value();
                            if let Ok(grad_a) = divide(node_grad.as_ref(), b_value.as_ref()) {
                                let mut a_node = a.node.borrow_mut();
                                if let Some(a_grad) = &a_node.grad {
                                    if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
                                        a_node.grad = Some(box_to_rc_array_protocol(sum));
                                    }
                                } else {
                                    a_node.grad = Some(box_to_rc_array_protocol(grad_a));
                                }
                            }
                        }

                        if b.requiresgrad() {
                            let a_value = a.value();
                            let b_value = b.value();

                            if let Ok(b_squared) = multiply(b_value.as_ref(), b_value.as_ref()) {
                                if let Ok(grad_times_a) =
                                    multiply(node_grad.as_ref(), a_value.as_ref())
                                {
                                    if let Ok(div_result) =
                                        divide(grad_times_a.as_ref(), b_squared.as_ref())
                                    {
                                        if let Ok(grad_b) =
                                            multiply_by_scalar(div_result.as_ref(), -1.0)
                                        {
                                            let mut b_node = b.node.borrow_mut();
                                            if let Some(b_grad) = &b_node.grad {
                                                if let Ok(sum) =
                                                    add(b_grad.as_ref(), grad_b.as_ref())
                                                {
                                                    b_node.grad =
                                                        Some(box_to_rc_array_protocol(sum));
                                                }
                                            } else {
                                                b_node.grad =
                                                    Some(box_to_rc_array_protocol(grad_b));
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                "sigmoid" => {
                    // d sigma(x)/dx = sigma(x) * (1 - sigma(x)), reusing the forward output.
                    if inputs.len() == 1 {
                        let input = &inputs[0];

                        if input.requiresgrad() {
                            let sigmoid_value = node.value();

                            if let Ok(ones) = ones_like(sigmoid_value.as_ref()) {
                                if let Ok(one_minus_sigmoid) =
                                    subtract(ones.as_ref(), sigmoid_value.as_ref())
                                {
                                    if let Ok(sigmoid_deriv) =
                                        multiply(sigmoid_value.as_ref(), one_minus_sigmoid.as_ref())
                                    {
                                        if let Ok(grad_input) =
                                            multiply(node_grad.as_ref(), sigmoid_deriv.as_ref())
                                        {
                                            let mut input_node = input.node.borrow_mut();
                                            if let Some(input_grad) = &input_node.grad {
                                                if let Ok(sum) =
                                                    add(input_grad.as_ref(), grad_input.as_ref())
                                                {
                                                    input_node.grad =
                                                        Some(box_to_rc_array_protocol(sum));
                                                }
                                            } else {
                                                input_node.grad =
                                                    Some(box_to_rc_array_protocol(grad_input));
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                "mean" => {
                    // The gradient of the mean distributes 1/n to every input element.
                    if inputs.len() == 1 {
                        let input = &inputs[0];

                        if input.requiresgrad() {
                            let input_value = input.value();
                            if let Some(inputarray) = input_value
                                .as_any()
                                .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
                            {
                                let n_elements = inputarray.as_array().len() as f64;

                                if let Ok(grad_input) =
                                    multiply_by_scalar(node_grad.as_ref(), 1.0 / n_elements)
                                {
                                    if let Ok(broadcasted_grad) = broadcast_to(
                                        grad_input.as_ref(),
                                        inputarray.as_array().shape(),
                                    ) {
                                        let mut input_node = input.node.borrow_mut();
                                        if let Some(input_grad) = &input_node.grad {
                                            if let Ok(sum) =
                                                add(input_grad.as_ref(), broadcasted_grad.as_ref())
                                            {
                                                input_node.grad =
                                                    Some(box_to_rc_array_protocol(sum));
                                            }
                                        } else {
                                            input_node.grad =
                                                Some(box_to_rc_array_protocol(broadcasted_grad));
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                _ => {
                    // No gradient rule is implemented for this operation.
                }
            }
        }

        Ok(())
    }

    /// Returns a copy of this tensor detached from the graph (no gradient tracking).
    pub fn detach(&self) -> Self {
        GradientTensor::new(self.value(), false)
    }
}

/// Adds two gradient tensors, recording the operation for backpropagation.
#[allow(dead_code)]
pub fn grad_add(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();
    let b_value = b.value();

    let result = add(a_value.as_ref(), b_value.as_ref())?;

    let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
    Ok(GradientTensor::from_op(
        result_rc,
        "add".to_string(),
        vec![a.clone(), b.clone()],
    ))
}

/// Multiplies two gradient tensors element-wise, recording the operation.
#[allow(dead_code)]
pub fn grad_multiply(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();
    let b_value = b.value();

    let result = multiply(a_value.as_ref(), b_value.as_ref())?;

    let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
    Ok(GradientTensor::from_op(
        result_rc,
        "multiply".to_string(),
        vec![a.clone(), b.clone()],
    ))
}

/// Matrix-multiplies two gradient tensors, recording the operation.
#[allow(dead_code)]
pub fn grad_matmul(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();
    let b_value = b.value();

    let result = matmul(a_value.as_ref(), b_value.as_ref())?;

    let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
    Ok(GradientTensor::from_op(
        result_rc,
        "matmul".to_string(),
        vec![a.clone(), b.clone()],
    ))
}

/// Subtracts one gradient tensor from another, recording the operation.
#[allow(dead_code)]
pub fn grad_subtract(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();
    let b_value = b.value();

    let result = subtract(a_value.as_ref(), b_value.as_ref())?;

    let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
    Ok(GradientTensor::from_op(
        result_rc,
        "subtract".to_string(),
        vec![a.clone(), b.clone()],
    ))
}

/// Divides one gradient tensor by another element-wise, recording the operation.
#[allow(dead_code)]
pub fn grad_divide(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();
    let b_value = b.value();

    let result = divide(a_value.as_ref(), b_value.as_ref())?;

    let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
    Ok(GradientTensor::from_op(
        result_rc,
        "divide".to_string(),
        vec![a.clone(), b.clone()],
    ))
}

/// Applies the sigmoid function element-wise, recording the operation.
#[allow(dead_code)]
pub fn grad_sigmoid(a: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();

    if let Some(a_array) = a_value
        .as_any()
        .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
    {
        let array = a_array.as_array();
        let result = array.mapv(|x| 1.0 / (1.0 + (-x).exp()));
        let result_wrapped = NdarrayWrapper::new(result);
        let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
        Ok(GradientTensor::from_op(
            result_rc,
            "sigmoid".to_string(),
            vec![a.clone()],
        ))
    } else if let Some(a_array) = a_value
        .as_any()
        .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
    {
        let array = a_array.as_array();
        let result = array.mapv(|x| 1.0f32 / (1.0f32 + (-x).exp()));
        let result_wrapped = NdarrayWrapper::new(result);
        let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
        Ok(GradientTensor::from_op(
            result_rc,
            "sigmoid".to_string(),
            vec![a.clone()],
        ))
    } else {
        Err(CoreError::NotImplementedError(ErrorContext::new(
            "sigmoid not implemented for this array type".to_string(),
        )))
    }
}

/// Computes the mean of all elements, recording the operation.
#[allow(dead_code)]
pub fn grad_mean(a: &GradientTensor) -> CoreResult<GradientTensor> {
    let a_value = a.value();

    if let Some(a_array) = a_value
        .as_any()
        .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
    {
        let array = a_array.as_array();
        let mean_value = array.mean().unwrap_or(0.0);
        let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
        let result_wrapped = NdarrayWrapper::new(result);
        let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
        Ok(GradientTensor::from_op(
            result_rc,
            "mean".to_string(),
            vec![a.clone()],
        ))
    } else if let Some(a_array) = a_value
        .as_any()
        .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
    {
        let array = a_array.as_array();
        let mean_value = array.mean().unwrap_or(0.0f32);
        let result = ArrayD::<f32>::from_elem(IxDyn(&[1]), mean_value);
        let result_wrapped = NdarrayWrapper::new(result);
        let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
        Ok(GradientTensor::from_op(
            result_rc,
            "mean".to_string(),
            vec![a.clone()],
        ))
    } else {
        Err(CoreError::NotImplementedError(ErrorContext::new(
            "mean not implemented for this array type".to_string(),
        )))
    }
}

/// A named variable wrapping a `GradientTensor`, used by the optimizers below.
pub struct Variable {
    /// The underlying gradient tensor.
    tensor: GradientTensor,

    /// The variable's name.
    name: String,
}

impl Variable {
    /// Creates a new variable (with gradient tracking enabled) from an array.
    pub fn new<T, D>(name: &str, array: Array<T, D>) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D: Dimension + Send + Sync + 'static,
    {
        let tensor = GradientTensor::from_array(array, true);
        Self {
            tensor,
            name: name.to_string(),
        }
    }

    /// Returns the underlying gradient tensor.
    pub const fn tensor(&self) -> &GradientTensor {
        &self.tensor
    }

    /// Returns the variable's value.
    pub fn value(&self) -> Rc<dyn ArrayProtocol> {
        self.tensor.value()
    }

    /// Returns the variable's gradient, if any.
    pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
        self.tensor.grad_2()
    }

    /// Returns the variable's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Sets the variable's gradient directly.
    pub fn set_gradient(&mut self, gradient: Box<dyn ArrayProtocol>) -> CoreResult<()> {
        let gradient_rc = self.box_to_rc(gradient);

        self.tensor.node.borrow_mut().grad = Some(gradient_rc);
        Ok(())
    }

    /// Replaces the variable's value.
    pub fn set_value(&mut self, newvalue: Box<dyn ArrayProtocol>) {
        let newvalue_rc = self.box_to_rc(newvalue);
        self.tensor.set_value(newvalue_rc);
    }

    /// Converts a boxed array into an `Rc`, cloning `f64` arrays and otherwise
    /// falling back to a small zero array.
    fn box_to_rc(&self, boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
        if let Some(ndarray_wrapper) = boxed
            .as_ref()
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
        {
            let array_clone = ndarray_wrapper.as_array().clone();
            Rc::new(NdarrayWrapper::new(array_clone))
        } else {
            let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
            Rc::new(NdarrayWrapper::new(fallback_array))
        }
    }
}

/// A minimal optimizer interface over a set of named variables.
pub trait Optimizer {
    /// Performs a single optimization step using the stored gradients.
    fn step(&mut self) -> CoreResult<()>;

    /// Clears the gradients of all variables.
    fn zero_grad(&mut self);

    /// Adds a variable to be optimized.
    fn add_variable(&mut self, var: Variable);

    /// Returns the optimizer's variables.
    fn variables(&self) -> &[Variable];

    /// Copies gradients from a `GradientDict` onto variables with matching names.
    fn accumulate_gradients(&mut self, gradients: &GradientDict) -> CoreResult<()> {
        for (param_name, gradient) in gradients.iter() {
            for var in self.variables_mut() {
                if var.name() == param_name {
                    var.set_gradient(gradient.clone())?;
                    break;
                }
            }
        }
        Ok(())
    }

    /// Mutable access to the variables; the default returns an empty slice, so
    /// optimizers that support `accumulate_gradients` should override this.
    fn variables_mut(&mut self) -> &mut [Variable] {
        &mut []
    }
}

/// Stochastic gradient descent with optional momentum.
pub struct SGD {
    /// Variables being optimized.
    variables: Vec<Variable>,

    /// Learning rate.
    learningrate: f64,

    /// Momentum coefficient (0.0 disables momentum).
    momentum: f64,

    /// Per-variable velocity buffers used when momentum is enabled.
    velocity: Vec<Option<Box<dyn ArrayProtocol>>>,
}

impl SGD {
    /// Creates a new SGD optimizer; `momentum` defaults to 0.0.
    pub fn new(learningrate: f64, momentum: Option<f64>) -> Self {
        Self {
            variables: Vec::new(),
            learningrate,
            momentum: momentum.unwrap_or(0.0),
            velocity: Vec::new(),
        }
    }

    /// Updates the learning rate.
    pub fn set_learningrate(&mut self, learningrate: f64) {
        self.learningrate = learningrate;
    }
}

impl Optimizer for SGD {
    fn step(&mut self) -> CoreResult<()> {
        for (i, var) in self.variables.iter_mut().enumerate() {
            if let Some(grad) = var.grad_2() {
                let var_value = var.value();

                let update = if self.momentum > 0.0 {
                    if i >= self.velocity.len() {
                        self.velocity.resize_with(i + 1, || None);
                    }

                    if let Some(vel) = &self.velocity[i] {
                        // update = momentum * velocity + learning_rate * grad
                        let scaled_grad = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
                        let scaled_vel = multiply_by_scalar(vel.as_ref(), self.momentum)?;
                        let update = add(scaled_vel.as_ref(), scaled_grad.as_ref())?;
                        self.velocity[i] = Some(update.clone());
                        update
                    } else {
                        let update = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
                        self.velocity[i] = Some(update.clone());
                        update
                    }
                } else {
                    multiply_by_scalar(grad.as_ref(), self.learningrate)?
                };

                // parameter <- parameter - update
                let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
                var.set_value(updated_value);
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        for var in &self.variables {
            var.tensor.node.borrow_mut().grad = None;
        }
    }

    fn add_variable(&mut self, var: Variable) {
        self.variables.push(var);
        self.velocity.push(None);
    }

    fn variables(&self) -> &[Variable] {
        &self.variables
    }

    fn variables_mut(&mut self) -> &mut [Variable] {
        &mut self.variables
    }
}

/// Adam optimizer (adaptive moment estimation).
pub struct Adam {
    /// Variables being optimized.
    variables: Vec<Variable>,

    /// Learning rate.
    learningrate: f64,

    /// Exponential decay rate for the first moment estimates.
    beta1: f64,

    /// Exponential decay rate for the second moment estimates.
    beta2: f64,

    /// Small constant added for numerical stability.
    epsilon: f64,

    /// First moment estimates, one per variable.
    m: Vec<Option<Box<dyn ArrayProtocol>>>,

    /// Second moment estimates, one per variable.
    v: Vec<Option<Box<dyn ArrayProtocol>>>,

    /// Time step (number of calls to `step`).
    t: usize,
}

impl Adam {
    /// Creates a new Adam optimizer with the usual defaults
    /// (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8).
    pub fn new(
        learningrate: f64,
        beta1: Option<f64>,
        beta2: Option<f64>,
        epsilon: Option<f64>,
    ) -> Self {
        Self {
            variables: Vec::new(),
            learningrate,
            beta1: beta1.unwrap_or(0.9),
            beta2: beta2.unwrap_or(0.999),
            epsilon: epsilon.unwrap_or(1e-8),
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) -> CoreResult<()> {
        self.t += 1;

        for (i, var) in self.variables.iter_mut().enumerate() {
            if let Some(grad) = var.grad_2() {
                let var_value = var.value();

                if i >= self.m.len() {
                    self.m.resize_with(i + 1, || None);
                    self.v.resize_with(i + 1, || None);
                }

                // m_t = beta1 * m_{t-1} + (1 - beta1) * grad
                let m = if let Some(m_prev) = &self.m[i] {
                    let scaled_m = multiply_by_scalar(m_prev.as_ref(), self.beta1)?;
                    let scaled_grad = multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?;
                    add(scaled_m.as_ref(), scaled_grad.as_ref())?
                } else {
                    multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?
                };

                // v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
                let v = if let Some(v_prev) = &self.v[i] {
                    let scaled_v = multiply_by_scalar(v_prev.as_ref(), self.beta2)?;
                    let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
                    let scaled_grad_sq =
                        multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?;
                    add(scaled_v.as_ref(), scaled_grad_sq.as_ref())?
                } else {
                    let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
                    multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?
                };

                self.m[i] = Some(m.clone());
                self.v[i] = Some(v.clone());

                // Bias-corrected estimates and the final update:
                // update = lr * m_hat / (sqrt(v_hat) + epsilon)
                let m_hat =
                    multiply_by_scalar(m.as_ref(), 1.0 / (1.0 - self.beta1.powi(self.t as i32)))?;
                let v_hat =
                    multiply_by_scalar(v.as_ref(), 1.0 / (1.0 - self.beta2.powi(self.t as i32)))?;

                let v_hat_sqrt = sqrt(v_hat.as_ref())?;
                let v_hat_sqrt_eps = add_scalar(v_hat_sqrt.as_ref(), self.epsilon)?;
                let update_dir = divide(m_hat.as_ref(), v_hat_sqrt_eps.as_ref())?;
                let update = multiply_by_scalar(update_dir.as_ref(), self.learningrate)?;

                let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
                var.set_value(updated_value);
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        for var in &self.variables {
            var.tensor.node.borrow_mut().grad = None;
        }
    }

    fn add_variable(&mut self, var: Variable) {
        self.variables.push(var);
        self.m.push(None);
        self.v.push(None);
    }

    fn variables(&self) -> &[Variable] {
        &self.variables
    }

    fn variables_mut(&mut self) -> &mut [Variable] {
        &mut self.variables
    }
}

/// Multiplies every element of `a` by a scalar, preserving the element type.
#[allow(dead_code)]
fn multiply_by_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
        let inputarray = a_array.as_array();
        let result = inputarray.mapv(|x| x * scalar);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
        let inputarray = a_array.as_array();
        let result = inputarray.mapv(|x| x * scalar as f32);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
        let inputarray = a_array.as_array();
        let result = inputarray.mapv(|x| (x as f64 * scalar) as i32);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
        let inputarray = a_array.as_array();
        let result = inputarray.mapv(|x| (x as f64 * scalar) as i64);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else {
        Err(CoreError::NotImplementedError(ErrorContext::new(
            "multiply_by_scalar not implemented for this array type".to_string(),
        )))
    }
}

/// Subtracts `b` from `a` element-wise; both arrays must share the same element type.
#[allow(dead_code)]
fn subtract_arrays(
    a: &dyn ArrayProtocol,
    b: &dyn ArrayProtocol,
) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let (Some(a_wrapper), Some(b_array)) = (
        a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
        b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
    ) {
        let a_arr = a_wrapper.as_array();
        let b_arr = b_array.as_array();
        let result = a_arr - b_arr;
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let (Some(a_wrapper), Some(b_array)) = (
        a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
        b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
    ) {
        let a_arr = a_wrapper.as_array();
        let b_arr = b_array.as_array();
        let result = a_arr - b_arr;
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let (Some(a_wrapper), Some(b_array)) = (
        a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
        b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
    ) {
        let a_arr = a_wrapper.as_array();
        let b_arr = b_array.as_array();
        let result = a_arr - b_arr;
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let (Some(a_wrapper), Some(b_array)) = (
        a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
        b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
    ) {
        let a_arr = a_wrapper.as_array();
        let b_arr = b_array.as_array();
        let result = a_arr - b_arr;
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else {
        Err(CoreError::NotImplementedError(ErrorContext::new(
            "subtract_arrays not implemented for these array types".to_string(),
        )))
    }
}

/// Computes the element-wise square root for floating-point arrays.
#[allow(dead_code)]
fn sqrt(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
        let result = a_array.as_array().mapv(|x| x.sqrt());
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
        let result = a_array.as_array().mapv(|x| x.sqrt());
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else {
        Err(CoreError::NotImplementedError(ErrorContext::new(
            "sqrt not implemented for this array type".to_string(),
        )))
    }
}

/// Adds a scalar to every element of `a`, preserving the element type.
#[allow(dead_code)]
fn add_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
        let result = a_array.as_array().mapv(|x| x + scalar);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
        let result = a_array.as_array().mapv(|x| x + scalar as f32);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
        let result = a_array.as_array().mapv(|x| x + scalar as i32);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
        let result = a_array.as_array().mapv(|x| x + scalar as i64);
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else {
        Err(CoreError::NotImplementedError(ErrorContext::new(
            "add_scalar not implemented for this array type".to_string(),
        )))
    }
}

/// Divides `a` by `b` element-wise for floating-point arrays of the same type.
#[allow(dead_code)]
fn divide(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
    if let (Some(a_array), Some(b_array)) = (
        a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
        b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
    ) {
        let result = a_array.as_array() / b_array.as_array();
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else if let (Some(a_array), Some(b_array)) = (
        a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
        b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
    ) {
        let result = a_array.as_array() / b_array.as_array();
        Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
    } else {
        Err(CoreError::NotImplementedError(ErrorContext::new(
            "divide not implemented for these array types".to_string(),
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2, Ix2};

    #[test]
    fn test_gradient_tensor_creation() {
        let array = Array2::<f64>::ones((2, 2));
        let tensor = GradientTensor::from_array(array, true);

        assert!(tensor.requiresgrad());
        assert!(tensor.is_leaf());
        assert!(tensor.grad_2().is_none());
    }
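
    // Added illustrative test (sketch): exercises the `GradientDict` container defined
    // in this module using only the APIs declared above; the array contents are arbitrary.
    #[test]
    fn test_gradient_dict_basic_operations() {
        let mut dict = GradientDict::new();
        assert!(dict.is_empty());

        dict.insert(
            "weight".to_string(),
            Box::new(NdarrayWrapper::new(ArrayD::<f64>::ones(IxDyn(&[2, 2]))))
                as Box<dyn ArrayProtocol>,
        );
        assert_eq!(dict.len(), 1);
        assert!(dict.get("weight").is_some());
        assert!(dict.get("bias").is_none());

        // Merging moves the entries of `other` into `dict`.
        let mut other = GradientDict::new();
        other.insert(
            "bias".to_string(),
            Box::new(NdarrayWrapper::new(ArrayD::<f64>::zeros(IxDyn(&[2]))))
                as Box<dyn ArrayProtocol>,
        );
        dict.merge(other);
        assert_eq!(dict.len(), 2);
        assert_eq!(dict.keys().count(), 2);

        dict.clear();
        assert!(dict.is_empty());
    }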

    #[test]
    fn test_gradient_computation_add() {
        #[allow(unused_imports)]
        use ndarray::array;

        let a_array = Array2::<f64>::ones((2, 2));
        let b_array = Array2::<f64>::ones((2, 2)) * 2.0;

        let a = GradientTensor::from_array(a_array, true);
        let b = GradientTensor::from_array(b_array, true);

        let c = match grad_add(&a, &b) {
            Ok(c) => c,
            Err(e) => {
                println!("Skipping test_gradient_computation_add: {e}");
                return;
            }
        };

        let c_value = c.value();
        let c_array = match c_value.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
            Some(array) => array,
            None => {
                println!("Skipping test_gradient_computation_add: result is not the expected type");
                return;
            }
        };
        assert_eq!(c_array.as_array(), &array![[3.0, 3.0], [3.0, 3.0]]);

        if let Err(e) = c.backward() {
            println!("Skipping test_gradient_computation_add: {e}");
            return;
        }

        let a_grad = match a.grad_2() {
            Some(grad) => grad,
            None => {
                println!("Skipping test_gradient_computation_add: no gradient for a");
                return;
            }
        };

        let a_grad_array = match a_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
            Some(array) => array,
            None => {
                println!("Skipping test_gradient_computation_add: a_grad is not the expected type");
                return;
            }
        };
        assert_eq!(a_grad_array.as_array(), &array![[1.0, 1.0], [1.0, 1.0]]);

        let b_grad = match b.grad_2() {
            Some(grad) => grad,
            None => {
                println!("Skipping test_gradient_computation_add: no gradient for b");
                return;
            }
        };

        let b_grad_array = match b_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
            Some(array) => array,
            None => {
                println!("Skipping test_gradient_computation_add: b_grad is not the expected type");
                return;
            }
        };
        assert_eq!(b_grad_array.as_array(), &array![[1.0, 1.0], [1.0, 1.0]]);
    }
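
    // Added illustrative test (sketch): checks that the "subtract" backward rule above
    // assigns a gradient to both inputs. Exact values are not asserted because the
    // concrete type returned by the underlying subtract operation may vary.
    #[test]
    fn test_gradient_computation_subtract() {
        let a = GradientTensor::from_array(Array2::<f64>::ones((2, 2)) * 3.0, true);
        let b = GradientTensor::from_array(Array2::<f64>::ones((2, 2)), true);

        let c = match grad_subtract(&a, &b) {
            Ok(c) => c,
            Err(e) => {
                println!("Skipping test_gradient_computation_subtract: {e}");
                return;
            }
        };

        if let Err(e) = c.backward() {
            println!("Skipping test_gradient_computation_subtract: {e}");
            return;
        }

        // d(a - b)/da = 1 and d(a - b)/db = -1, so both inputs receive a gradient.
        assert!(a.grad_2().is_some());
        assert!(b.grad_2().is_some());
    }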

    #[test]
    fn test_gradient_computation_multiply() {
        #[allow(unused_imports)]
        use ndarray::array;

        let a_array = Array2::<f64>::ones((2, 2)) * 2.0;
        let b_array = Array2::<f64>::ones((2, 2)) * 3.0;

        let a = GradientTensor::from_array(a_array, true);
        let b = GradientTensor::from_array(b_array, true);

        let c = match grad_multiply(&a, &b) {
            Ok(c) => c,
            Err(e) => {
                println!("Skipping test_gradient_computation_multiply: {e}");
                return;
            }
        };

        let c_value = c.value();
        let c_array = match c_value.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
            Some(array) => array,
            None => {
                println!(
                    "Skipping test_gradient_computation_multiply: result is not the expected type"
                );
                return;
            }
        };
        assert_eq!(c_array.as_array(), &array![[6.0, 6.0], [6.0, 6.0]]);

        if let Err(e) = c.backward() {
            println!("Skipping test_gradient_computation_multiply: {e}");
            return;
        }

        let a_grad = match a.grad_2() {
            Some(grad) => grad,
            None => {
                println!("Skipping test_gradient_computation_multiply: no gradient for a");
                return;
            }
        };

        let a_grad_array = match a_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
            Some(array) => array,
            None => {
                println!(
                    "Skipping test_gradient_computation_multiply: a_grad is not the expected type"
                );
                return;
            }
        };
        assert_eq!(a_grad_array.as_array(), &array![[3.0, 3.0], [3.0, 3.0]]);

        let b_grad = match b.grad_2() {
            Some(grad) => grad,
            None => {
                println!("Skipping test_gradient_computation_multiply: no gradient for b");
                return;
            }
        };

        let b_grad_array = match b_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
            Some(array) => array,
            None => {
                println!(
                    "Skipping test_gradient_computation_multiply: b_grad is not the expected type"
                );
                return;
            }
        };
        assert_eq!(b_grad_array.as_array(), &array![[2.0, 2.0], [2.0, 2.0]]);
    }
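
    // Added illustrative test (sketch): the backward rule for "mean" above spreads the
    // incoming gradient uniformly, so every input element should receive 1/n = 0.25.
    // The value check is guarded by a downcast in case the stored gradient type differs.
    #[test]
    fn test_grad_mean_backward() {
        let a = GradientTensor::from_array(ArrayD::<f64>::from_elem(IxDyn(&[2, 2]), 3.0), true);

        let m = match grad_mean(&a) {
            Ok(m) => m,
            Err(e) => {
                println!("Skipping test_grad_mean_backward: {e}");
                return;
            }
        };

        if let Err(e) = m.backward() {
            println!("Skipping test_grad_mean_backward: {e}");
            return;
        }

        let a_grad = match a.grad_2() {
            Some(grad) => grad,
            None => {
                println!("Skipping test_grad_mean_backward: no gradient for a");
                return;
            }
        };

        if let Some(grad) = a_grad.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
            assert!(grad.as_array().iter().all(|&x| (x - 0.25).abs() < 1e-12));
        }
    }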

    #[test]
    fn test_sgd_optimizer() {
        #[allow(unused_imports)]
        use ndarray::array;

        let weight_array = Array2::<f64>::ones((2, 2));
        let weight = Variable::new("weight", weight_array);

        let bias_array = Array2::<f64>::zeros((2, 2));
        let bias = Variable::new("bias", bias_array);

        let mut optimizer = SGD::new(0.1, Some(0.9));
        optimizer.add_variable(weight);
        optimizer.add_variable(bias);

        let weight_grad_array = Array2::<f64>::ones((2, 2));
        let weight_grad = NdarrayWrapper::new(weight_grad_array);
        optimizer.variables()[0].tensor.node.borrow_mut().grad = Some(Rc::new(weight_grad));

        let bias_grad_array = Array2::<f64>::ones((2, 2)) * 2.0;
        let bias_grad = NdarrayWrapper::new(bias_grad_array);
        optimizer.variables()[1].tensor.node.borrow_mut().grad = Some(Rc::new(bias_grad));

        match optimizer.step() {
            Ok(_) => {
                optimizer.zero_grad();

                assert!(optimizer.variables()[0].grad_2().is_none());
                assert!(optimizer.variables()[1].grad_2().is_none());
            }
            Err(e) => {
                println!("Skipping test_sgd_optimizer - step failed: {e}");
            }
        }
    }
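
    // Added illustrative tests (sketch): a minimal Adam counterpart to the SGD test
    // above, plus a check of the element-wise helper functions defined in this module.
    // Both use dynamically shaped f64 arrays, the type every helper here supports.
    #[test]
    fn test_adam_optimizer() {
        let weight = Variable::new("weight", Array2::<f64>::ones((2, 2)).into_dyn());

        let mut optimizer = Adam::new(0.001, None, None, None);
        optimizer.add_variable(weight);

        // Attach a gradient directly to the underlying tensor, as the SGD test does.
        let grad = NdarrayWrapper::new(ArrayD::<f64>::ones(IxDyn(&[2, 2])));
        optimizer.variables()[0].tensor.node.borrow_mut().grad = Some(Rc::new(grad));

        match optimizer.step() {
            Ok(_) => {
                optimizer.zero_grad();
                assert!(optimizer.variables()[0].grad_2().is_none());
            }
            Err(e) => {
                println!("Skipping test_adam_optimizer - step failed: {e}");
            }
        }
    }

    #[test]
    fn test_elementwise_helpers() {
        let a = NdarrayWrapper::new(ArrayD::<f64>::from_elem(IxDyn(&[2, 2]), 4.0));
        let b = NdarrayWrapper::new(ArrayD::<f64>::from_elem(IxDyn(&[2, 2]), 2.0));

        // 4.0 * 0.5 = 2.0
        let scaled = multiply_by_scalar(&a, 0.5).unwrap();
        if let Some(arr) = scaled.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
            assert!(arr.as_array().iter().all(|&x| x == 2.0));
        }

        // sqrt(4.0) = 2.0
        let root = sqrt(&a).unwrap();
        if let Some(arr) = root.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
            assert!(arr.as_array().iter().all(|&x| x == 2.0));
        }

        // 4.0 + 1.0 = 5.0
        let shifted = add_scalar(&a, 1.0).unwrap();
        if let Some(arr) = shifted.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
            assert!(arr.as_array().iter().all(|&x| x == 5.0));
        }

        // 4.0 - 2.0 = 2.0 and 4.0 / 2.0 = 2.0
        let diff = subtract_arrays(&a, &b).unwrap();
        let quot = divide(&a, &b).unwrap();
        for result in [diff, quot] {
            if let Some(arr) = result.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
                assert!(arr.as_array().iter().all(|&x| x == 2.0));
            }
        }
    }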
}