1use crate::ndarray::compat::ArrayStatCompat;
14use ::ndarray::{Array, ArrayD, Dimension, IxDyn};
15
16use std::cell::RefCell;
17use std::collections::{HashMap, HashSet};
18use std::rc::Rc;
19
20use crate::array_protocol::operations::matmul;
21use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
22use crate::error::{CoreError, CoreResult, ErrorContext};
23
24#[derive(Clone)]
26pub struct GradientDict {
27 gradients: HashMap<String, Box<dyn ArrayProtocol>>,
28}
29
30impl std::fmt::Debug for GradientDict {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("GradientDict")
33 .field(
34 "gradients",
35 &format!("{{keys: {:?}}}", self.gradients.keys().collect::<Vec<_>>()),
36 )
37 .finish()
38 }
39}
40
41impl GradientDict {
42 pub fn new() -> Self {
44 Self {
45 gradients: HashMap::new(),
46 }
47 }
48
49 pub fn insert(&mut self, name: String, gradient: Box<dyn ArrayProtocol>) {
51 self.gradients.insert(name, gradient);
52 }
53
54 pub fn get(&self, name: &str) -> Option<&dyn ArrayProtocol> {
56 self.gradients.get(name).map(|b| b.as_ref())
57 }
58
59 pub fn get_mut(&mut self, name: &str) -> Option<&mut Box<dyn ArrayProtocol>> {
61 self.gradients.get_mut(name)
62 }
63
64 pub fn iter(&self) -> impl Iterator<Item = (&String, &Box<dyn ArrayProtocol>)> {
66 self.gradients.iter()
67 }
68
69 pub fn merge(&mut self, other: GradientDict) {
71 for (name, gradient) in other.gradients {
72 self.gradients.insert(name, gradient);
73 }
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.gradients.is_empty()
79 }
80
81 pub fn len(&self) -> usize {
83 self.gradients.len()
84 }
85
86 pub fn clear(&mut self) {
88 self.gradients.clear();
89 }
90
91 pub fn keys(&self) -> impl Iterator<Item = &String> {
93 self.gradients.keys()
94 }
95
96 pub fn values(&self) -> impl Iterator<Item = &Box<dyn ArrayProtocol>> {
98 self.gradients.values()
99 }
100}
101
102impl Default for GradientDict {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108#[allow(dead_code)]
110fn boxed_to_rc(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
111 let array_ref = boxed.as_ref();
116
117 if let Some(ndarray_wrapper) = array_ref
120 .as_any()
121 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
122 {
123 let array_clone = ndarray_wrapper.as_array().clone();
124 return Rc::new(NdarrayWrapper::new(array_clone));
125 }
126
127 let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
130 Rc::new(NdarrayWrapper::new(fallback_array))
131}
132
133#[allow(dead_code)]
135fn box_to_rc_array_protocol(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
136 boxed_to_rc(boxed)
137}
138
139#[allow(dead_code)]
141fn add(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
142 crate::array_protocol::operations::add(a, b).map_err(|e| e.into())
143}
144
145#[allow(dead_code)]
146fn multiply(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
147 crate::array_protocol::operations::multiply(a, b).map_err(|e| e.into())
148}
149
150#[allow(dead_code)]
151fn subtract(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
152 crate::array_protocol::operations::subtract(a, b).map_err(|e| e.into())
153}
154
155#[allow(dead_code)]
156fn ones_like(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
157 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
160 let shape = a_array.as_array().shape();
161 let ones = ArrayD::<f64>::ones(IxDyn(shape));
162 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
163 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
164 let shape = a_array.as_array().shape();
165 let ones = ArrayD::<f32>::ones(IxDyn(shape));
166 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
167 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
168 let shape = a_array.as_array().shape();
169 let ones = ArrayD::<i32>::ones(IxDyn(shape));
170 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
171 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
172 let shape = a_array.as_array().shape();
173 let ones = ArrayD::<i64>::ones(IxDyn(shape));
174 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
175 } else {
176 let shape = a.shape().to_vec();
178 let ones = ArrayD::<f64>::ones(IxDyn(&shape));
179 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
180 }
181}
182
183#[allow(dead_code)]
184fn broadcast_to(a: &dyn ArrayProtocol, shape: &[usize]) -> CoreResult<Box<dyn ArrayProtocol>> {
185 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
187 let array = a_array.as_array();
188 if array.len() == 1 {
190 let value = array.iter().next().cloned().unwrap_or(0.0);
191 let broadcasted = ArrayD::<f64>::from_elem(IxDyn(shape), value);
192 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
193 } else if array.shape() == shape {
194 Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
196 } else {
197 let inputshape = array.shape();
199 let _ndim_diff = shape.len().saturating_sub(inputshape.len());
200
201 let mut can_broadcast = true;
203 for i in 0..inputshape.len() {
204 let input_dim = inputshape[inputshape.len() - 1 - i];
205 let target_dim = shape[shape.len() - 1 - i];
206 if input_dim != 1 && input_dim != target_dim {
207 can_broadcast = false;
208 break;
209 }
210 }
211
212 if can_broadcast {
213 if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
215 let broadcasted = broadcasted_view.to_owned();
216 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
217 } else {
218 Err(CoreError::NotImplementedError(ErrorContext::new(
219 "Broadcasting failed for these shapes".to_string(),
220 )))
221 }
222 } else {
223 Err(CoreError::NotImplementedError(ErrorContext::new(
224 "Incompatible shapes for broadcasting".to_string(),
225 )))
226 }
227 }
228 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
229 let array = a_array.as_array();
230 if array.len() == 1 {
231 let value = array.iter().next().cloned().unwrap_or(0.0);
232 let broadcasted = ArrayD::<f32>::from_elem(IxDyn(shape), value);
233 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
234 } else if array.shape() == shape {
235 Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
236 } else if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
237 let broadcasted = broadcasted_view.to_owned();
238 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
239 } else {
240 Err(CoreError::NotImplementedError(ErrorContext::new(
241 "Broadcasting failed for these shapes".to_string(),
242 )))
243 }
244 } else {
245 let ones = ArrayD::<f64>::ones(IxDyn(shape));
247 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
248 }
249}
250
251#[derive(Clone)]
253struct Node {
254 value: Rc<dyn ArrayProtocol>,
256
257 grad: Option<Rc<dyn ArrayProtocol>>,
259
260 op: Option<String>,
262
263 inputs: Vec<GradientTensor>,
265
266 requiresgrad: bool,
268
269 is_leaf: bool,
271}
272
273impl Node {
274 fn leaf(requiresgrad: bool) -> Self {
276 Self {
277 value: Rc::new(NdarrayWrapper::new(
278 crate::ndarray::Array0::<f64>::zeros(()),
279 )) as Rc<dyn ArrayProtocol>,
280 grad: None,
281 op: None,
282 inputs: Vec::new(),
283 requiresgrad,
284 is_leaf: true,
285 }
286 }
287
288 fn new_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
290 let requiresgrad = inputs.iter().any(|x| x.requiresgrad());
291
292 Self {
293 value,
294 grad: None,
295 op: Some(op),
296 inputs,
297 requiresgrad,
298 is_leaf: false,
299 }
300 }
301}
302
303#[derive(Clone)]
305pub struct GradientTensor {
306 node: Rc<RefCell<Node>>,
308}
309
310impl GradientTensor {
311 pub fn new(value: Rc<dyn ArrayProtocol>, requiresgrad: bool) -> Self {
313 let mut node_inner = Node::leaf(requiresgrad);
314 node_inner.value = value;
315 node_inner.grad = None;
316 let node = Rc::new(RefCell::new(node_inner));
317 Self { node }
318 }
319
320 pub fn from_array<T, D>(array: Array<T, D>, requiresgrad: bool) -> Self
322 where
323 T: Clone + Send + Sync + 'static,
324 D: Dimension + Send + Sync + 'static,
325 {
326 let value = Rc::new(NdarrayWrapper::new(array)) as Rc<dyn ArrayProtocol>;
327 Self::new(value, requiresgrad)
328 }
329
330 pub fn value(&self) -> Rc<dyn ArrayProtocol> {
332 self.node.borrow().value.clone()
333 }
334
335 pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
337 self.node.borrow().grad.clone()
338 }
339
340 pub fn requiresgrad(&self) -> bool {
342 self.node.borrow().requiresgrad
343 }
344
345 pub fn set_requiresgrad(&mut self, requiresgrad: bool) {
347 self.node.borrow_mut().requiresgrad = requiresgrad;
348 }
349
350 pub fn is_leaf(&self) -> bool {
352 self.node.borrow().is_leaf
353 }
354
355 fn from_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
357 let node = Rc::new(RefCell::new(Node::new_op(value, op, inputs)));
358 Self { node }
359 }
360
361 pub fn set_value(&mut self, newvalue: Rc<dyn ArrayProtocol>) {
363 self.node.borrow_mut().grad = None; self.node.borrow_mut().value = newvalue;
365 }
366
367 pub fn backward(&self) -> CoreResult<()> {
369 let gradshape = if let Some(array) = self
371 .value()
372 .as_any()
373 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
374 {
375 array.as_array().raw_dim()
376 } else {
377 crate::ndarray::IxDyn(&[1])
379 };
380
381 let grad_array = Array::<f64, IxDyn>::ones(gradshape);
382 let grad = Rc::new(NdarrayWrapper::new(grad_array)) as Rc<dyn ArrayProtocol>;
383
384 self.backward_with_grad(grad)
386 }
387
388 fn backward_with_grad(&self, grad: Rc<dyn ArrayProtocol>) -> CoreResult<()> {
390 self.node.borrow_mut().grad = Some(grad.clone());
392
393 let mut visited = HashSet::new();
395 let mut topo = Vec::new();
396
397 fn build_topo(
399 tensor: &GradientTensor,
400 visited: &mut HashSet<*const RefCell<Node>>,
401 topo: &mut Vec<GradientTensor>,
402 ) {
403 let node_ptr = Rc::as_ptr(&tensor.node);
404 if !visited.contains(&node_ptr) {
405 visited.insert(node_ptr);
406
407 for input in &tensor.node.borrow().inputs {
409 build_topo(input, visited, topo);
410 }
411
412 topo.push(tensor.clone());
414 }
415 }
416
417 build_topo(self, &mut visited, &mut topo);
419
420 for node in topo.iter().rev() {
422 if !node.requiresgrad() {
424 continue;
425 }
426
427 let node_grad = match node.grad_2() {
429 Some(g) => g,
430 None => continue, };
432
433 if node.is_leaf() {
435 continue;
436 }
437
438 let op = match &node.node.borrow().op {
440 Some(op) => op.clone(),
441 None => continue, };
443
444 let inputs = node.node.borrow().inputs.clone();
445
446 match op.as_str() {
448 "add" => {
449 for input in &inputs {
451 if input.requiresgrad() {
452 let mut input_node = input.node.borrow_mut();
453 if let Some(input_grad) = &input_node.grad {
454 if let Ok(sum) = add(input_grad.as_ref(), node_grad.as_ref()) {
456 input_node.grad = Some(sum.into());
457 }
458 } else {
459 input_node.grad = Some(node_grad.clone());
460 }
461 }
462 }
463 }
464 "multiply"
465 if inputs.len() == 2 => {
467 let (a, b) = (&inputs[0], &inputs[1]);
468
469 if a.requiresgrad() {
471 let b_value = b.value();
472 if let Ok(grad_a) = multiply(node_grad.as_ref(), b_value.as_ref()) {
473 let mut a_node = a.node.borrow_mut();
474 if let Some(a_grad) = &a_node.grad {
475 if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
477 a_node.grad = Some(box_to_rc_array_protocol(sum));
478 }
479 } else {
480 a_node.grad = Some(box_to_rc_array_protocol(grad_a));
481 }
482 }
483 }
484
485 if b.requiresgrad() {
487 let a_value = a.value();
488 if let Ok(grad_b) = multiply(node_grad.as_ref(), a_value.as_ref()) {
489 let mut b_node = b.node.borrow_mut();
490 if let Some(b_grad) = &b_node.grad {
491 if let Ok(sum) = add(b_grad.as_ref(), grad_b.as_ref()) {
493 b_node.grad = Some(box_to_rc_array_protocol(sum));
494 }
495 } else {
496 b_node.grad = Some(box_to_rc_array_protocol(grad_b));
497 }
498 }
499 }
500 }
501 "matmul"
502 if inputs.len() == 2 => {
506 let (a, b) = (&inputs[0], &inputs[1]);
507
508 if a.requiresgrad() {
510 if let (Some(b_array), Some(grad_out_array)) = (
511 b.value()
512 .as_any()
513 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
514 node_grad
515 .as_any()
516 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
517 ) {
518 let b_array_val = b_array.as_array();
519 let grad_out_array_val = grad_out_array.as_array();
520
521 let b_t = b_array_val.t();
523
524 let grad_outshape = grad_out_array_val.shape();
527 let grad_out_rows = grad_outshape[0];
528 let grad_out_cols = if grad_outshape.len() > 1 {
529 grad_outshape.iter().skip(1).product()
530 } else {
531 1
532 };
533 let grad_out_2d = grad_out_array_val
534 .clone()
535 .into_shape_with_order((grad_out_rows, grad_out_cols))
536 .expect("Operation failed");
537
538 let b_tshape = b_t.shape();
539 let b_t_rows = b_tshape[0];
540 let b_t_cols = if b_tshape.len() > 1 {
541 b_tshape.iter().skip(1).product()
542 } else {
543 1
544 };
545 let b_t_2d = b_t
546 .clone()
547 .into_shape_with_order((b_t_rows, b_t_cols))
548 .expect("Operation failed");
549
550 let grad_a_val = grad_out_2d.dot(&b_t_2d);
552
553 let grad_a_dyn = grad_a_val.into_dyn();
555 let grad_a = NdarrayWrapper::new(grad_a_dyn);
556
557 let mut a_node = a.node.borrow_mut();
559 if let Some(a_grad) = &a_node.grad {
560 if let (Some(a_grad_array), Some(grad_a_array)) = (
561 a_grad
562 .as_any()
563 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
564 grad_a
565 .as_any()
566 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
567 ) {
568 let sum = a_grad_array.as_array() + grad_a_array.as_array();
570 a_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
571 }
572 } else {
573 a_node.grad = Some(Rc::new(grad_a));
575 }
576 }
577 }
578
579 if b.requiresgrad() {
581 if let (Some(a_array), Some(grad_out_array)) = (
582 a.value()
583 .as_any()
584 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
585 node_grad
586 .as_any()
587 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
588 ) {
589 let a_array_val = a_array.as_array();
590 let grad_out_array_val = grad_out_array.as_array();
591
592 let a_t = a_array_val.t();
594
595 let grad_outshape = grad_out_array_val.shape();
598 let grad_out_rows = grad_outshape[0];
599 let grad_out_cols = if grad_outshape.len() > 1 {
600 grad_outshape.iter().skip(1).product()
601 } else {
602 1
603 };
604 let grad_out_2d = grad_out_array_val
605 .clone()
606 .into_shape_with_order((grad_out_rows, grad_out_cols))
607 .expect("Operation failed");
608
609 let a_tshape = a_t.shape();
610 let a_t_rows = a_tshape[0];
611 let a_t_cols = if a_tshape.len() > 1 {
612 a_tshape.iter().skip(1).product()
613 } else {
614 1
615 };
616 let a_t_2d = a_t
617 .clone()
618 .into_shape_with_order((a_t_rows, a_t_cols))
619 .expect("Operation failed");
620
621 let grad_b_val = a_t_2d.dot(&grad_out_2d);
623
624 let grad_b_dyn = grad_b_val.into_dyn();
626 let grad_b = NdarrayWrapper::new(grad_b_dyn);
627
628 let mut b_node = b.node.borrow_mut();
630 if let Some(b_grad) = &b_node.grad {
631 if let (Some(b_grad_array), Some(grad_b_array)) = (
632 b_grad
633 .as_any()
634 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
635 grad_b
636 .as_any()
637 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
638 ) {
639 let sum = b_grad_array.as_array() + grad_b_array.as_array();
641 b_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
642 }
643 } else {
644 b_node.grad = Some(Rc::new(grad_b));
646 }
647 }
648 }
649 }
650 "subtract"
651 if inputs.len() == 2 => {
653 let (a, b) = (&inputs[0], &inputs[1]);
654
655 if a.requiresgrad() {
657 let mut a_node = a.node.borrow_mut();
658 if let Some(a_grad) = &a_node.grad {
659 if let Ok(sum) = add(a_grad.as_ref(), node_grad.as_ref()) {
661 a_node.grad = Some(box_to_rc_array_protocol(sum));
662 }
663 } else {
664 a_node.grad = Some(node_grad.clone());
665 }
666 }
667
668 if b.requiresgrad() {
670 if let Ok(neg_grad) = multiply_by_scalar(node_grad.as_ref(), -1.0) {
671 let mut b_node = b.node.borrow_mut();
672 if let Some(b_grad) = &b_node.grad {
673 if let Ok(sum) = add(b_grad.as_ref(), neg_grad.as_ref()) {
675 b_node.grad = Some(box_to_rc_array_protocol(sum));
676 }
677 } else {
678 b_node.grad = Some(box_to_rc_array_protocol(neg_grad));
679 }
680 }
681 }
682 }
683 "divide"
684 if inputs.len() == 2 => {
686 let (a, b) = (&inputs[0], &inputs[1]);
687
688 if a.requiresgrad() {
690 let b_value = b.value();
691 if let Ok(grad_a) = divide(node_grad.as_ref(), b_value.as_ref()) {
692 let mut a_node = a.node.borrow_mut();
693 if let Some(a_grad) = &a_node.grad {
694 if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
696 a_node.grad = Some(box_to_rc_array_protocol(sum));
697 }
698 } else {
699 a_node.grad = Some(box_to_rc_array_protocol(grad_a));
700 }
701 }
702 }
703
704 if b.requiresgrad() {
706 let a_value = a.value();
707 let b_value = b.value();
708
709 if let Ok(b_squared) = multiply(b_value.as_ref(), b_value.as_ref()) {
711 if let Ok(grad_times_a) =
713 multiply(node_grad.as_ref(), a_value.as_ref())
714 {
715 if let Ok(div_result) =
717 divide(grad_times_a.as_ref(), b_squared.as_ref())
718 {
719 if let Ok(grad_b) =
721 multiply_by_scalar(div_result.as_ref(), -1.0)
722 {
723 let mut b_node = b.node.borrow_mut();
724 if let Some(b_grad) = &b_node.grad {
725 if let Ok(sum) =
727 add(b_grad.as_ref(), grad_b.as_ref())
728 {
729 b_node.grad =
730 Some(box_to_rc_array_protocol(sum));
731 }
732 } else {
733 b_node.grad =
734 Some(box_to_rc_array_protocol(grad_b));
735 }
736 }
737 }
738 }
739 }
740 }
741 }
742 "sigmoid"
743 if inputs.len() == 1 => {
745 let input = &inputs[0];
746
747 if input.requiresgrad() {
748 let sigmoid_value = node.value();
750
751 if let Ok(ones) = ones_like(sigmoid_value.as_ref()) {
753 if let Ok(one_minus_sigmoid) =
754 subtract(ones.as_ref(), sigmoid_value.as_ref())
755 {
756 if let Ok(sigmoid_deriv) =
758 multiply(sigmoid_value.as_ref(), one_minus_sigmoid.as_ref())
759 {
760 if let Ok(grad_input) =
762 multiply(node_grad.as_ref(), sigmoid_deriv.as_ref())
763 {
764 let mut input_node = input.node.borrow_mut();
765 if let Some(input_grad) = &input_node.grad {
766 if let Ok(sum) =
768 add(input_grad.as_ref(), grad_input.as_ref())
769 {
770 input_node.grad =
771 Some(box_to_rc_array_protocol(sum));
772 }
773 } else {
774 input_node.grad =
775 Some(box_to_rc_array_protocol(grad_input));
776 }
777 }
778 }
779 }
780 }
781 }
782 }
783 "mean"
784 if inputs.len() == 1 => {
786 let input = &inputs[0];
787
788 if input.requiresgrad() {
789 let input_value = input.value();
791 if let Some(inputarray) = input_value
792 .as_any()
793 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
794 {
795 let n_elements = inputarray.as_array().len() as f64;
796
797 if let Ok(grad_input) =
799 multiply_by_scalar(node_grad.as_ref(), 1.0 / n_elements)
800 {
801 if let Ok(broadcasted_grad) = broadcast_to(
803 grad_input.as_ref(),
804 inputarray.as_array().shape(),
805 ) {
806 let mut input_node = input.node.borrow_mut();
807 if let Some(input_grad) = &input_node.grad {
808 if let Ok(sum) =
810 add(input_grad.as_ref(), broadcasted_grad.as_ref())
811 {
812 input_node.grad =
813 Some(box_to_rc_array_protocol(sum));
814 }
815 } else {
816 input_node.grad =
817 Some(box_to_rc_array_protocol(broadcasted_grad));
818 }
819 }
820 }
821 }
822 }
823 }
824 _ => {
825 }
827 }
828 }
829
830 Ok(())
831 }
832
833 pub fn detach(&self) -> Self {
835 GradientTensor::new(self.value(), false)
836 }
837}
838
839#[allow(dead_code)]
842pub fn grad_add(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
843 let a_value = a.value();
844 let b_value = b.value();
845
846 let result = add(a_value.as_ref(), b_value.as_ref())?;
848
849 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
851 Ok(GradientTensor::from_op(
852 result_rc,
853 "add".to_string(),
854 vec![a.clone(), b.clone()],
855 ))
856}
857
858#[allow(dead_code)]
860pub fn grad_multiply(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
861 let a_value = a.value();
862 let b_value = b.value();
863
864 let result = multiply(a_value.as_ref(), b_value.as_ref())?;
866
867 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
869 Ok(GradientTensor::from_op(
870 result_rc,
871 "multiply".to_string(),
872 vec![a.clone(), b.clone()],
873 ))
874}
875
876#[allow(dead_code)]
878pub fn grad_matmul(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
879 let a_value = a.value();
880 let b_value = b.value();
881
882 let result = matmul(a_value.as_ref(), b_value.as_ref())?;
884
885 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
887 Ok(GradientTensor::from_op(
888 result_rc,
889 "matmul".to_string(),
890 vec![a.clone(), b.clone()],
891 ))
892}
893
894#[allow(dead_code)]
896pub fn grad_subtract(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
897 let a_value = a.value();
898 let b_value = b.value();
899
900 let result = subtract(a_value.as_ref(), b_value.as_ref())?;
902
903 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
905 Ok(GradientTensor::from_op(
906 result_rc,
907 "subtract".to_string(),
908 vec![a.clone(), b.clone()],
909 ))
910}
911
912#[allow(dead_code)]
914pub fn grad_divide(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
915 let a_value = a.value();
916 let b_value = b.value();
917
918 let result = divide(a_value.as_ref(), b_value.as_ref())?;
920
921 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
923 Ok(GradientTensor::from_op(
924 result_rc,
925 "divide".to_string(),
926 vec![a.clone(), b.clone()],
927 ))
928}
929
930#[allow(dead_code)]
932pub fn grad_sigmoid(a: &GradientTensor) -> CoreResult<GradientTensor> {
933 let a_value = a.value();
934
935 if let Some(a_array) = a_value
937 .as_any()
938 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
939 {
940 let array = a_array.as_array();
941 let result = array.mapv(|x| 1.0 / (1.0 + (-x).exp()));
942 let result_wrapped = NdarrayWrapper::new(result);
943 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
944 Ok(GradientTensor::from_op(
945 result_rc,
946 "sigmoid".to_string(),
947 vec![a.clone()],
948 ))
949 } else if let Some(a_array) = a_value
950 .as_any()
951 .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
952 {
953 let array = a_array.as_array();
954 let result = array.mapv(|x| 1.0f32 / (1.0f32 + (-x).exp()));
955 let result_wrapped = NdarrayWrapper::new(result);
956 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
957 Ok(GradientTensor::from_op(
958 result_rc,
959 "sigmoid".to_string(),
960 vec![a.clone()],
961 ))
962 } else {
963 Err(CoreError::NotImplementedError(ErrorContext::new(
964 "sigmoid not implemented for this array type".to_string(),
965 )))
966 }
967}
968
969#[allow(dead_code)]
971pub fn grad_mean(a: &GradientTensor) -> CoreResult<GradientTensor> {
972 let a_value = a.value();
973
974 if let Some(a_array) = a_value
976 .as_any()
977 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
978 {
979 let array = a_array.as_array();
980 let mean_value = array.mean_or(0.0);
981 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
982 let result_wrapped = NdarrayWrapper::new(result);
983 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
984 Ok(GradientTensor::from_op(
985 result_rc,
986 "mean".to_string(),
987 vec![a.clone()],
988 ))
989 } else if let Some(a_array) = a_value
990 .as_any()
991 .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
992 {
993 let array = a_array.as_array();
994 let mean_value = array.mean_or(0.0f32);
995 let result = ArrayD::<f32>::from_elem(IxDyn(&[1]), mean_value);
996 let result_wrapped = NdarrayWrapper::new(result);
997 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
998 Ok(GradientTensor::from_op(
999 result_rc,
1000 "mean".to_string(),
1001 vec![a.clone()],
1002 ))
1003 } else {
1004 Err(CoreError::NotImplementedError(ErrorContext::new(
1005 "mean not implemented for this array type".to_string(),
1006 )))
1007 }
1008}
1009
1010pub struct Variable {
1012 tensor: GradientTensor,
1014
1015 name: String,
1017}
1018
1019impl Variable {
1020 pub fn new<T, D>(name: &str, array: Array<T, D>) -> Self
1022 where
1023 T: Clone + Send + Sync + 'static,
1024 D: Dimension + Send + Sync + 'static,
1025 {
1026 let tensor = GradientTensor::from_array(array, true);
1027 Self {
1028 tensor,
1029 name: name.to_string(),
1030 }
1031 }
1032
1033 pub const fn tensor(&self) -> &GradientTensor {
1035 &self.tensor
1036 }
1037
1038 pub fn value(&self) -> Rc<dyn ArrayProtocol> {
1040 self.tensor.value()
1041 }
1042
1043 pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
1045 self.tensor.grad_2()
1046 }
1047
1048 pub fn name(&self) -> &str {
1050 &self.name
1051 }
1052
1053 pub fn set_gradient(&mut self, gradient: Box<dyn ArrayProtocol>) -> CoreResult<()> {
1055 let gradient_rc = self.box_to_rc(gradient);
1057
1058 self.tensor.node.borrow_mut().grad = Some(gradient_rc);
1060 Ok(())
1061 }
1062
1063 pub fn set_value(&mut self, newvalue: Box<dyn ArrayProtocol>) {
1065 let newvalue_rc = self.box_to_rc(newvalue);
1066 self.tensor.set_value(newvalue_rc);
1067 }
1068
1069 fn box_to_rc(&self, boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
1071 if let Some(ndarray_wrapper) = boxed
1073 .as_ref()
1074 .as_any()
1075 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
1076 {
1077 let array_clone = ndarray_wrapper.as_array().clone();
1078 Rc::new(NdarrayWrapper::new(array_clone))
1079 } else {
1080 let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
1082 Rc::new(NdarrayWrapper::new(fallback_array))
1083 }
1084 }
1085}
1086
1087pub trait Optimizer {
1089 fn step(&mut self) -> CoreResult<()>;
1091
1092 fn zero_grad(&mut self);
1094
1095 fn add_variable(&mut self, var: Variable);
1097
1098 fn variables(&self) -> &[Variable];
1100
1101 fn accumulate_gradients(&mut self, gradients: &GradientDict) -> CoreResult<()> {
1103 for (param_name, gradient) in gradients.iter() {
1105 for var in self.variables_mut() {
1107 if var.name() == param_name {
1108 var.set_gradient(gradient.clone())?;
1109 break;
1110 }
1111 }
1112 }
1113 Ok(())
1114 }
1115
1116 fn variables_mut(&mut self) -> &mut [Variable] {
1118 &mut []
1121 }
1122}
1123
1124pub struct SGD {
1126 variables: Vec<Variable>,
1128
1129 learningrate: f64,
1131
1132 momentum: f64,
1134
1135 velocity: Vec<Option<Box<dyn ArrayProtocol>>>,
1137}
1138
1139impl SGD {
1140 pub fn new(learningrate: f64, momentum: Option<f64>) -> Self {
1142 Self {
1143 variables: Vec::new(),
1144 learningrate,
1145 momentum: momentum.unwrap_or(0.0),
1146 velocity: Vec::new(),
1147 }
1148 }
1149
1150 pub fn set_learningrate(&mut self, learningrate: f64) {
1152 self.learningrate = learningrate;
1153 }
1154}
1155
1156impl Optimizer for SGD {
1157 fn step(&mut self) -> CoreResult<()> {
1158 for (i, var) in self.variables.iter_mut().enumerate() {
1159 if let Some(grad) = var.grad_2() {
1160 let var_value = var.value();
1161
1162 let update = if self.momentum > 0.0 {
1164 if i >= self.velocity.len() {
1165 self.velocity.resize_with(i + 1, || None);
1166 }
1167
1168 if let Some(vel) = &self.velocity[i] {
1169 let scaled_grad = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
1171 let scaled_vel = multiply_by_scalar(vel.as_ref(), self.momentum)?;
1172 let update = add(scaled_vel.as_ref(), scaled_grad.as_ref())?;
1173 self.velocity[i] = Some(update.clone());
1174 update
1175 } else {
1176 let update = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
1178 self.velocity[i] = Some(update.clone());
1179 update
1180 }
1181 } else {
1182 multiply_by_scalar(grad.as_ref(), self.learningrate)?
1184 };
1185
1186 let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
1188 var.set_value(updated_value);
1189 }
1190 }
1191
1192 Ok(())
1193 }
1194
1195 fn zero_grad(&mut self) {
1196 for var in &self.variables {
1197 var.tensor.node.borrow_mut().grad = None;
1198 }
1199 }
1200
1201 fn add_variable(&mut self, var: Variable) {
1202 self.variables.push(var);
1203 self.velocity.push(None);
1204 }
1205
1206 fn variables(&self) -> &[Variable] {
1207 &self.variables
1208 }
1209
1210 fn variables_mut(&mut self) -> &mut [Variable] {
1211 &mut self.variables
1212 }
1213}
1214
1215pub struct Adam {
1217 variables: Vec<Variable>,
1219
1220 learningrate: f64,
1222
1223 beta1: f64,
1225
1226 beta2: f64,
1228
1229 epsilon: f64,
1231
1232 m: Vec<Option<Box<dyn ArrayProtocol>>>,
1234
1235 v: Vec<Option<Box<dyn ArrayProtocol>>>,
1237
1238 t: usize,
1240}
1241
1242impl Adam {
1243 pub fn new(
1245 learningrate: f64,
1246 beta1: Option<f64>,
1247 beta2: Option<f64>,
1248 epsilon: Option<f64>,
1249 ) -> Self {
1250 Self {
1251 variables: Vec::new(),
1252 learningrate,
1253 beta1: beta1.unwrap_or(0.9),
1254 beta2: beta2.unwrap_or(0.999),
1255 epsilon: epsilon.unwrap_or(1e-8),
1256 m: Vec::new(),
1257 v: Vec::new(),
1258 t: 0,
1259 }
1260 }
1261}
1262
1263impl Optimizer for Adam {
1264 fn step(&mut self) -> CoreResult<()> {
1265 self.t += 1;
1266
1267 for (i, var) in self.variables.iter_mut().enumerate() {
1268 if let Some(grad) = var.grad_2() {
1269 let var_value = var.value();
1270
1271 if i >= self.m.len() {
1273 self.m.resize_with(i + 1, || None);
1274 self.v.resize_with(i + 1, || None);
1275 }
1276
1277 let m = if let Some(m_prev) = &self.m[i] {
1279 let scaled_m = multiply_by_scalar(m_prev.as_ref(), self.beta1)?;
1281 let scaled_grad = multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?;
1282 add(scaled_m.as_ref(), scaled_grad.as_ref())?
1283 } else {
1284 multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?
1286 };
1287
1288 let v = if let Some(v_prev) = &self.v[i] {
1290 let scaled_v = multiply_by_scalar(v_prev.as_ref(), self.beta2)?;
1292 let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
1293 let scaled_grad_sq =
1294 multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?;
1295 add(scaled_v.as_ref(), scaled_grad_sq.as_ref())?
1296 } else {
1297 let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
1299 multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?
1300 };
1301
1302 self.m[i] = Some(m.clone());
1304 self.v[i] = Some(v.clone());
1305
1306 let m_hat =
1308 multiply_by_scalar(m.as_ref(), 1.0 / (1.0 - self.beta1.powi(self.t as i32)))?;
1309 let v_hat =
1310 multiply_by_scalar(v.as_ref(), 1.0 / (1.0 - self.beta2.powi(self.t as i32)))?;
1311
1312 let v_hat_sqrt = sqrt(v_hat.as_ref())?;
1314 let v_hat_sqrt_eps = add_scalar(v_hat_sqrt.as_ref(), self.epsilon)?;
1315 let update_dir = divide(m_hat.as_ref(), v_hat_sqrt_eps.as_ref())?;
1316 let update = multiply_by_scalar(update_dir.as_ref(), self.learningrate)?;
1317
1318 let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
1320 var.set_value(updated_value);
1321 }
1322 }
1323
1324 Ok(())
1325 }
1326
1327 fn zero_grad(&mut self) {
1328 for var in &self.variables {
1329 var.tensor.node.borrow_mut().grad = None;
1330 }
1331 }
1332
1333 fn add_variable(&mut self, var: Variable) {
1334 self.variables.push(var);
1335 self.m.push(None);
1336 self.v.push(None);
1337 }
1338
1339 fn variables(&self) -> &[Variable] {
1340 &self.variables
1341 }
1342
1343 fn variables_mut(&mut self) -> &mut [Variable] {
1344 &mut self.variables
1345 }
1346}
1347
1348#[allow(dead_code)]
1352fn multiply_by_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
1353 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1354 let inputarray = a_array.as_array();
1355 let result = inputarray.mapv(|x| x * scalar);
1356 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1357 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1358 let inputarray = a_array.as_array();
1359 let result = inputarray.mapv(|x| x * scalar as f32);
1360 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1361 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
1362 let inputarray = a_array.as_array();
1363 let result = inputarray.mapv(|x| (x as f64 * scalar) as i32);
1364 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1365 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
1366 let inputarray = a_array.as_array();
1367 let result = inputarray.mapv(|x| (x as f64 * scalar) as i64);
1368 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1369 } else {
1370 Err(CoreError::NotImplementedError(ErrorContext::new(
1371 "multiply_by_scalar not implemented for this array type".to_string(),
1372 )))
1373 }
1374}
1375
1376#[allow(dead_code)]
1378fn subtract_arrays(
1379 a: &dyn ArrayProtocol,
1380 b: &dyn ArrayProtocol,
1381) -> CoreResult<Box<dyn ArrayProtocol>> {
1382 if let (Some(a_wrapper), Some(b_array)) = (
1384 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1385 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1386 ) {
1387 let a_arr = a_wrapper.as_array();
1388 let b_arr = b_array.as_array();
1389 let result = a_arr - b_arr;
1390 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1391 } else if let (Some(a_wrapper), Some(b_array)) = (
1392 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1393 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1394 ) {
1395 let a_arr = a_wrapper.as_array();
1396 let b_arr = b_array.as_array();
1397 let result = a_arr - b_arr;
1398 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1399 } else if let (Some(a_wrapper), Some(b_array)) = (
1400 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1401 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1402 ) {
1403 let a_arr = a_wrapper.as_array();
1404 let b_arr = b_array.as_array();
1405 let result = a_arr - b_arr;
1406 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1407 } else if let (Some(a_wrapper), Some(b_array)) = (
1408 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1409 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1410 ) {
1411 let a_arr = a_wrapper.as_array();
1412 let b_arr = b_array.as_array();
1413 let result = a_arr - b_arr;
1414 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1415 } else {
1416 Err(CoreError::NotImplementedError(ErrorContext::new(
1417 "subtract_arrays not implemented for these array types".to_string(),
1418 )))
1419 }
1420}
1421
1422#[allow(dead_code)]
1424fn sqrt(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
1425 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1426 let result = a_array.as_array().mapv(|x| x.sqrt());
1427 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1428 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1429 let result = a_array.as_array().mapv(|x| x.sqrt());
1430 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1431 } else {
1432 Err(CoreError::NotImplementedError(ErrorContext::new(
1433 "sqrt not implemented for this array type".to_string(),
1434 )))
1435 }
1436}
1437
1438#[allow(dead_code)]
1440fn add_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
1441 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1442 let result = a_array.as_array().mapv(|x| x + scalar);
1443 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1444 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1445 let result = a_array.as_array().mapv(|x| x + scalar as f32);
1446 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1447 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
1448 let result = a_array.as_array().mapv(|x| x + scalar as i32);
1449 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1450 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
1451 let result = a_array.as_array().mapv(|x| x + scalar as i64);
1452 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1453 } else {
1454 Err(CoreError::NotImplementedError(ErrorContext::new(
1455 "add_scalar not implemented for this array type".to_string(),
1456 )))
1457 }
1458}
1459
1460#[allow(dead_code)]
1462fn divide(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
1463 if let (Some(a_array), Some(b_array)) = (
1464 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1465 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1466 ) {
1467 let result = a_array.as_array() / b_array.as_array();
1468 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1469 } else if let (Some(a_array), Some(b_array)) = (
1470 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1471 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1472 ) {
1473 let result = a_array.as_array() / b_array.as_array();
1474 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1475 } else {
1476 Err(CoreError::NotImplementedError(ErrorContext::new(
1477 "divide not implemented for these array types".to_string(),
1478 )))
1479 }
1480}
1481
1482#[cfg(test)]
1483mod tests {
1484 use super::*;
1485 use ::ndarray::{array, Array2, Ix2};
1486
1487 #[test]
1488 fn test_gradient_tensor_creation() {
1489 let array = Array2::<f64>::ones((2, 2));
1491 let tensor = GradientTensor::from_array(array, true);
1492
1493 assert!(tensor.requiresgrad());
1495 assert!(tensor.is_leaf());
1496 assert!(tensor.grad_2().is_none());
1497 }
1498
1499 #[test]
1500 fn test_gradient_computation_add() {
1501 #[allow(unused_imports)]
1503 use ::ndarray::array;
1504
1505 let a_array = Array2::<f64>::ones((2, 2));
1507 let b_array = Array2::<f64>::ones((2, 2)) * 2.0;
1508
1509 let a = GradientTensor::from_array(a_array, true);
1510 let b = GradientTensor::from_array(b_array, true);
1511
1512 let c = match grad_add(&a, &b) {
1514 Ok(c) => c,
1515 Err(e) => {
1516 println!("Skipping test_gradient_computationadd: {e}");
1517 return;
1518 }
1519 };
1520
1521 let c_value = c.value();
1523 let c_array = match c_value.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1524 Some(array) => array,
1525 None => {
1526 println!("Skipping test_gradient_computationadd: result is not the expected type");
1527 return;
1528 }
1529 };
1530 assert_eq!(c_array.as_array(), &array![[3.0, 3.0], [3.0, 3.0]]);
1531
1532 if let Err(e) = c.backward() {
1534 println!("Skipping test_gradient_computationadd: {e}");
1535 return;
1536 }
1537
1538 let a_grad = match a.grad_2() {
1540 Some(grad) => grad,
1541 None => {
1542 println!("Skipping test_gradient_computationadd: no gradient for a");
1543 return;
1544 }
1545 };
1546
1547 let a_grad_array = match a_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1548 Some(array) => array,
1549 None => {
1550 println!("Skipping test_gradient_computationadd: a_grad is not the expected type");
1551 return;
1552 }
1553 };
1554 assert_eq!(a_grad_array.as_array(), &array![[1.0, 1.0], [1.0, 1.0]]);
1555
1556 let b_grad = match b.grad_2() {
1557 Some(grad) => grad,
1558 None => {
1559 println!("Skipping test_gradient_computationadd: no gradient for b");
1560 return;
1561 }
1562 };
1563
1564 let b_grad_array = match b_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1565 Some(array) => array,
1566 None => {
1567 println!("Skipping test_gradient_computationadd: b_grad is not the expected type");
1568 return;
1569 }
1570 };
1571 assert_eq!(b_grad_array.as_array(), &array![[1.0, 1.0], [1.0, 1.0]]);
1572 }
1573
1574 #[test]
1575 fn test_gradient_computation_multiply() {
1576 #[allow(unused_imports)]
1578 use ::ndarray::array;
1579
1580 let a_array = Array2::<f64>::ones((2, 2)) * 2.0;
1582 let b_array = Array2::<f64>::ones((2, 2)) * 3.0;
1583
1584 let a = GradientTensor::from_array(a_array, true);
1585 let b = GradientTensor::from_array(b_array, true);
1586
1587 let c = match grad_multiply(&a, &b) {
1589 Ok(c) => c,
1590 Err(e) => {
1591 println!("Skipping test_gradient_computationmultiply: {e}");
1592 return;
1593 }
1594 };
1595
1596 let c_value = c.value();
1598 let c_array = match c_value.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1599 Some(array) => array,
1600 None => {
1601 println!(
1602 "Skipping test_gradient_computation_multiply: result is not the expected type"
1603 );
1604 return;
1605 }
1606 };
1607 assert_eq!(c_array.as_array(), &array![[6.0, 6.0], [6.0, 6.0]]);
1608
1609 if let Err(e) = c.backward() {
1611 println!("Skipping test_gradient_computationmultiply: {e}");
1612 return;
1613 }
1614
1615 let a_grad = match a.grad_2() {
1617 Some(grad) => grad,
1618 None => {
1619 println!("Skipping test_gradient_computationmultiply: no gradient for a");
1620 return;
1621 }
1622 };
1623
1624 let a_grad_array = match a_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1625 Some(array) => array,
1626 None => {
1627 println!(
1628 "Skipping test_gradient_computation_multiply: a_grad is not the expected type"
1629 );
1630 return;
1631 }
1632 };
1633 assert_eq!(a_grad_array.as_array(), &array![[3.0, 3.0], [3.0, 3.0]]);
1634
1635 let b_grad = match b.grad_2() {
1636 Some(grad) => grad,
1637 None => {
1638 println!("Skipping test_gradient_computationmultiply: no gradient for b");
1639 return;
1640 }
1641 };
1642
1643 let b_grad_array = match b_grad.as_any().downcast_ref::<NdarrayWrapper<f64, Ix2>>() {
1644 Some(array) => array,
1645 None => {
1646 println!(
1647 "Skipping test_gradient_computation_multiply: b_grad is not the expected type"
1648 );
1649 return;
1650 }
1651 };
1652 assert_eq!(b_grad_array.as_array(), &array![[2.0, 2.0], [2.0, 2.0]]);
1653 }
1654
1655 #[test]
1656 fn test_sgd_optimizer() {
1657 #[allow(unused_imports)]
1659 use ::ndarray::array;
1660
1661 let weight_array = Array2::<f64>::ones((2, 2));
1663 let weight = Variable::new("weight", weight_array);
1664
1665 let bias_array = Array2::<f64>::zeros((2, 2));
1666 let bias = Variable::new("bias", bias_array);
1667
1668 let mut optimizer = SGD::new(0.1, Some(0.9));
1670 optimizer.add_variable(weight);
1671 optimizer.add_variable(bias);
1672
1673 let weight_grad_array = Array2::<f64>::ones((2, 2));
1675 let weight_grad = NdarrayWrapper::new(weight_grad_array);
1676 optimizer.variables()[0].tensor.node.borrow_mut().grad = Some(Rc::new(weight_grad));
1677
1678 let bias_grad_array = Array2::<f64>::ones((2, 2)) * 2.0;
1679 let bias_grad = NdarrayWrapper::new(bias_grad_array);
1680 optimizer.variables()[1].tensor.node.borrow_mut().grad = Some(Rc::new(bias_grad));
1681
1682 match optimizer.step() {
1684 Ok(_) => {
1685 optimizer.zero_grad();
1687
1688 assert!(optimizer.variables()[0].grad_2().is_none());
1690 assert!(optimizer.variables()[1].grad_2().is_none());
1691 }
1692 Err(e) => {
1693 println!("Skipping test_sgd_optimizer - step failed: {e}");
1694 }
1695 }
1696 }
1697}