1use crate::ndarray::compat::ArrayStatCompat;
14use ::ndarray::{Array, ArrayD, Dimension, IxDyn};
15
16use std::cell::RefCell;
17use std::collections::{HashMap, HashSet};
18use std::rc::Rc;
19
20use crate::array_protocol::operations::matmul;
21use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
22use crate::error::{CoreError, CoreResult, ErrorContext};
23
24#[derive(Clone)]
26pub struct GradientDict {
27 gradients: HashMap<String, Box<dyn ArrayProtocol>>,
28}
29
30impl std::fmt::Debug for GradientDict {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("GradientDict")
33 .field(
34 "gradients",
35 &format!("{{keys: {:?}}}", self.gradients.keys().collect::<Vec<_>>()),
36 )
37 .finish()
38 }
39}
40
41impl GradientDict {
42 pub fn new() -> Self {
44 Self {
45 gradients: HashMap::new(),
46 }
47 }
48
49 pub fn insert(&mut self, name: String, gradient: Box<dyn ArrayProtocol>) {
51 self.gradients.insert(name, gradient);
52 }
53
54 pub fn get(&self, name: &str) -> Option<&dyn ArrayProtocol> {
56 self.gradients.get(name).map(|b| b.as_ref())
57 }
58
59 pub fn get_mut(&mut self, name: &str) -> Option<&mut Box<dyn ArrayProtocol>> {
61 self.gradients.get_mut(name)
62 }
63
64 pub fn iter(&self) -> impl Iterator<Item = (&String, &Box<dyn ArrayProtocol>)> {
66 self.gradients.iter()
67 }
68
69 pub fn merge(&mut self, other: GradientDict) {
71 for (name, gradient) in other.gradients {
72 self.gradients.insert(name, gradient);
73 }
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.gradients.is_empty()
79 }
80
81 pub fn len(&self) -> usize {
83 self.gradients.len()
84 }
85
86 pub fn clear(&mut self) {
88 self.gradients.clear();
89 }
90
91 pub fn keys(&self) -> impl Iterator<Item = &String> {
93 self.gradients.keys()
94 }
95
96 pub fn values(&self) -> impl Iterator<Item = &Box<dyn ArrayProtocol>> {
98 self.gradients.values()
99 }
100}
101
102impl Default for GradientDict {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108#[allow(dead_code)]
110fn boxed_to_rc(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
111 let array_ref = boxed.as_ref();
116
117 if let Some(ndarray_wrapper) = array_ref
120 .as_any()
121 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
122 {
123 let array_clone = ndarray_wrapper.as_array().clone();
124 return Rc::new(NdarrayWrapper::new(array_clone));
125 }
126
127 let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
130 Rc::new(NdarrayWrapper::new(fallback_array))
131}
132
133#[allow(dead_code)]
135fn box_to_rc_array_protocol(boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
136 boxed_to_rc(boxed)
137}
138
139#[allow(dead_code)]
141fn add(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
142 crate::array_protocol::operations::add(a, b).map_err(|e| e.into())
143}
144
145#[allow(dead_code)]
146fn multiply(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
147 crate::array_protocol::operations::multiply(a, b).map_err(|e| e.into())
148}
149
150#[allow(dead_code)]
151fn subtract(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
152 crate::array_protocol::operations::subtract(a, b).map_err(|e| e.into())
153}
154
155#[allow(dead_code)]
156fn ones_like(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
157 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
160 let shape = a_array.as_array().shape();
161 let ones = ArrayD::<f64>::ones(IxDyn(shape));
162 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
163 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
164 let shape = a_array.as_array().shape();
165 let ones = ArrayD::<f32>::ones(IxDyn(shape));
166 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
167 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
168 let shape = a_array.as_array().shape();
169 let ones = ArrayD::<i32>::ones(IxDyn(shape));
170 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
171 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
172 let shape = a_array.as_array().shape();
173 let ones = ArrayD::<i64>::ones(IxDyn(shape));
174 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
175 } else {
176 let shape = a.shape().to_vec();
178 let ones = ArrayD::<f64>::ones(IxDyn(&shape));
179 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
180 }
181}
182
183#[allow(dead_code)]
184fn broadcast_to(a: &dyn ArrayProtocol, shape: &[usize]) -> CoreResult<Box<dyn ArrayProtocol>> {
185 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
187 let array = a_array.as_array();
188 if array.len() == 1 {
190 let value = array.iter().next().cloned().unwrap_or(0.0);
191 let broadcasted = ArrayD::<f64>::from_elem(IxDyn(shape), value);
192 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
193 } else if array.shape() == shape {
194 Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
196 } else {
197 let inputshape = array.shape();
199 let _ndim_diff = shape.len().saturating_sub(inputshape.len());
200
201 let mut can_broadcast = true;
203 for i in 0..inputshape.len() {
204 let input_dim = inputshape[inputshape.len() - 1 - i];
205 let target_dim = shape[shape.len() - 1 - i];
206 if input_dim != 1 && input_dim != target_dim {
207 can_broadcast = false;
208 break;
209 }
210 }
211
212 if can_broadcast {
213 if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
215 let broadcasted = broadcasted_view.to_owned();
216 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
217 } else {
218 Err(CoreError::NotImplementedError(ErrorContext::new(
219 "Broadcasting failed for these shapes".to_string(),
220 )))
221 }
222 } else {
223 Err(CoreError::NotImplementedError(ErrorContext::new(
224 "Incompatible shapes for broadcasting".to_string(),
225 )))
226 }
227 }
228 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
229 let array = a_array.as_array();
230 if array.len() == 1 {
231 let value = array.iter().next().cloned().unwrap_or(0.0);
232 let broadcasted = ArrayD::<f32>::from_elem(IxDyn(shape), value);
233 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
234 } else if array.shape() == shape {
235 Ok(Box::new(NdarrayWrapper::new(array.clone())) as Box<dyn ArrayProtocol>)
236 } else if let Some(broadcasted_view) = array.broadcast(IxDyn(shape)) {
237 let broadcasted = broadcasted_view.to_owned();
238 Ok(Box::new(NdarrayWrapper::new(broadcasted)) as Box<dyn ArrayProtocol>)
239 } else {
240 Err(CoreError::NotImplementedError(ErrorContext::new(
241 "Broadcasting failed for these shapes".to_string(),
242 )))
243 }
244 } else {
245 let ones = ArrayD::<f64>::ones(IxDyn(shape));
247 Ok(Box::new(NdarrayWrapper::new(ones)) as Box<dyn ArrayProtocol>)
248 }
249}
250
251#[derive(Clone)]
253struct Node {
254 value: Rc<dyn ArrayProtocol>,
256
257 grad: Option<Rc<dyn ArrayProtocol>>,
259
260 op: Option<String>,
262
263 inputs: Vec<GradientTensor>,
265
266 requiresgrad: bool,
268
269 is_leaf: bool,
271}
272
273impl Node {
274 fn leaf(requiresgrad: bool) -> Self {
276 Self {
277 value: Rc::new(NdarrayWrapper::new(
278 crate::ndarray::Array0::<f64>::zeros(()),
279 )) as Rc<dyn ArrayProtocol>,
280 grad: None,
281 op: None,
282 inputs: Vec::new(),
283 requiresgrad,
284 is_leaf: true,
285 }
286 }
287
288 fn new_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
290 let requiresgrad = inputs.iter().any(|x| x.requiresgrad());
291
292 Self {
293 value,
294 grad: None,
295 op: Some(op),
296 inputs,
297 requiresgrad,
298 is_leaf: false,
299 }
300 }
301}
302
303#[derive(Clone)]
305pub struct GradientTensor {
306 node: Rc<RefCell<Node>>,
308}
309
310impl GradientTensor {
311 pub fn new(value: Rc<dyn ArrayProtocol>, requiresgrad: bool) -> Self {
313 let mut node_inner = Node::leaf(requiresgrad);
314 node_inner.value = value;
315 node_inner.grad = None;
316 let node = Rc::new(RefCell::new(node_inner));
317 Self { node }
318 }
319
320 pub fn from_array<T, D>(array: Array<T, D>, requiresgrad: bool) -> Self
322 where
323 T: Clone + Send + Sync + 'static,
324 D: Dimension + Send + Sync + 'static,
325 {
326 let value = Rc::new(NdarrayWrapper::new(array)) as Rc<dyn ArrayProtocol>;
327 Self::new(value, requiresgrad)
328 }
329
330 pub fn value(&self) -> Rc<dyn ArrayProtocol> {
332 self.node.borrow().value.clone()
333 }
334
335 pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
337 self.node.borrow().grad.clone()
338 }
339
340 pub fn requiresgrad(&self) -> bool {
342 self.node.borrow().requiresgrad
343 }
344
345 pub fn set_requiresgrad(&mut self, requiresgrad: bool) {
347 self.node.borrow_mut().requiresgrad = requiresgrad;
348 }
349
350 pub fn is_leaf(&self) -> bool {
352 self.node.borrow().is_leaf
353 }
354
355 fn from_op(value: Rc<dyn ArrayProtocol>, op: String, inputs: Vec<GradientTensor>) -> Self {
357 let node = Rc::new(RefCell::new(Node::new_op(value, op, inputs)));
358 Self { node }
359 }
360
361 pub fn set_value(&mut self, newvalue: Rc<dyn ArrayProtocol>) {
363 self.node.borrow_mut().grad = None; self.node.borrow_mut().value = newvalue;
365 }
366
367 pub fn backward(&self) -> CoreResult<()> {
369 let gradshape = if let Some(array) = self
371 .value()
372 .as_any()
373 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
374 {
375 array.as_array().raw_dim()
376 } else {
377 crate::ndarray::IxDyn(&[1])
379 };
380
381 let grad_array = Array::<f64, IxDyn>::ones(gradshape);
382 let grad = Rc::new(NdarrayWrapper::new(grad_array)) as Rc<dyn ArrayProtocol>;
383
384 self.backward_with_grad(grad)
386 }
387
388 fn backward_with_grad(&self, grad: Rc<dyn ArrayProtocol>) -> CoreResult<()> {
390 self.node.borrow_mut().grad = Some(grad.clone());
392
393 let mut visited = HashSet::new();
395 let mut topo = Vec::new();
396
397 fn build_topo(
399 tensor: &GradientTensor,
400 visited: &mut HashSet<*const RefCell<Node>>,
401 topo: &mut Vec<GradientTensor>,
402 ) {
403 let node_ptr = Rc::as_ptr(&tensor.node);
404 if !visited.contains(&node_ptr) {
405 visited.insert(node_ptr);
406
407 for input in &tensor.node.borrow().inputs {
409 build_topo(input, visited, topo);
410 }
411
412 topo.push(tensor.clone());
414 }
415 }
416
417 build_topo(self, &mut visited, &mut topo);
419
420 for node in topo.iter().rev() {
422 if !node.requiresgrad() {
424 continue;
425 }
426
427 let node_grad = match node.grad_2() {
429 Some(g) => g,
430 None => continue, };
432
433 if node.is_leaf() {
435 continue;
436 }
437
438 let op = match &node.node.borrow().op {
440 Some(op) => op.clone(),
441 None => continue, };
443
444 let inputs = node.node.borrow().inputs.clone();
445
446 match op.as_str() {
448 "add" => {
449 for input in &inputs {
451 if input.requiresgrad() {
452 let mut input_node = input.node.borrow_mut();
453 if let Some(input_grad) = &input_node.grad {
454 if let Ok(sum) = add(input_grad.as_ref(), node_grad.as_ref()) {
456 input_node.grad = Some(sum.into());
457 }
458 } else {
459 input_node.grad = Some(node_grad.clone());
460 }
461 }
462 }
463 }
464 "multiply"
465 if inputs.len() == 2 => {
467 let (a, b) = (&inputs[0], &inputs[1]);
468
469 if a.requiresgrad() {
471 let b_value = b.value();
472 if let Ok(grad_a) = multiply(node_grad.as_ref(), b_value.as_ref()) {
473 let mut a_node = a.node.borrow_mut();
474 if let Some(a_grad) = &a_node.grad {
475 if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
477 a_node.grad = Some(box_to_rc_array_protocol(sum));
478 }
479 } else {
480 a_node.grad = Some(box_to_rc_array_protocol(grad_a));
481 }
482 }
483 }
484
485 if b.requiresgrad() {
487 let a_value = a.value();
488 if let Ok(grad_b) = multiply(node_grad.as_ref(), a_value.as_ref()) {
489 let mut b_node = b.node.borrow_mut();
490 if let Some(b_grad) = &b_node.grad {
491 if let Ok(sum) = add(b_grad.as_ref(), grad_b.as_ref()) {
493 b_node.grad = Some(box_to_rc_array_protocol(sum));
494 }
495 } else {
496 b_node.grad = Some(box_to_rc_array_protocol(grad_b));
497 }
498 }
499 }
500 }
501 "matmul"
502 if inputs.len() == 2 => {
506 let (a, b) = (&inputs[0], &inputs[1]);
507
508 if a.requiresgrad() {
510 if let (Some(b_array), Some(grad_out_array)) = (
511 b.value()
512 .as_any()
513 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
514 node_grad
515 .as_any()
516 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
517 ) {
518 let b_array_val = b_array.as_array();
519 let grad_out_array_val = grad_out_array.as_array();
520
521 let b_t = b_array_val.t();
523
524 let grad_outshape = grad_out_array_val.shape();
527 let grad_out_rows = grad_outshape[0];
528 let grad_out_cols = if grad_outshape.len() > 1 {
529 grad_outshape.iter().skip(1).product()
530 } else {
531 1
532 };
533 let grad_out_2d = grad_out_array_val
534 .clone()
535 .into_shape_with_order((grad_out_rows, grad_out_cols))
536 .expect("Operation failed");
537
538 let b_tshape = b_t.shape();
539 let b_t_rows = b_tshape[0];
540 let b_t_cols = if b_tshape.len() > 1 {
541 b_tshape.iter().skip(1).product()
542 } else {
543 1
544 };
545 let b_t_2d = b_t
546 .clone()
547 .into_shape_with_order((b_t_rows, b_t_cols))
548 .expect("Operation failed");
549
550 let grad_a_val = grad_out_2d.dot(&b_t_2d);
552
553 let grad_a_dyn = grad_a_val.into_dyn();
555 let grad_a = NdarrayWrapper::new(grad_a_dyn);
556
557 let mut a_node = a.node.borrow_mut();
559 if let Some(a_grad) = &a_node.grad {
560 if let (Some(a_grad_array), Some(grad_a_array)) = (
561 a_grad
562 .as_any()
563 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
564 grad_a
565 .as_any()
566 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
567 ) {
568 let sum = a_grad_array.as_array() + grad_a_array.as_array();
570 a_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
571 }
572 } else {
573 a_node.grad = Some(Rc::new(grad_a));
575 }
576 }
577 }
578
579 if b.requiresgrad() {
581 if let (Some(a_array), Some(grad_out_array)) = (
582 a.value()
583 .as_any()
584 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
585 node_grad
586 .as_any()
587 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
588 ) {
589 let a_array_val = a_array.as_array();
590 let grad_out_array_val = grad_out_array.as_array();
591
592 let a_t = a_array_val.t();
594
595 let grad_outshape = grad_out_array_val.shape();
598 let grad_out_rows = grad_outshape[0];
599 let grad_out_cols = if grad_outshape.len() > 1 {
600 grad_outshape.iter().skip(1).product()
601 } else {
602 1
603 };
604 let grad_out_2d = grad_out_array_val
605 .clone()
606 .into_shape_with_order((grad_out_rows, grad_out_cols))
607 .expect("Operation failed");
608
609 let a_tshape = a_t.shape();
610 let a_t_rows = a_tshape[0];
611 let a_t_cols = if a_tshape.len() > 1 {
612 a_tshape.iter().skip(1).product()
613 } else {
614 1
615 };
616 let a_t_2d = a_t
617 .clone()
618 .into_shape_with_order((a_t_rows, a_t_cols))
619 .expect("Operation failed");
620
621 let grad_b_val = a_t_2d.dot(&grad_out_2d);
623
624 let grad_b_dyn = grad_b_val.into_dyn();
626 let grad_b = NdarrayWrapper::new(grad_b_dyn);
627
628 let mut b_node = b.node.borrow_mut();
630 if let Some(b_grad) = &b_node.grad {
631 if let (Some(b_grad_array), Some(grad_b_array)) = (
632 b_grad
633 .as_any()
634 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
635 grad_b
636 .as_any()
637 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
638 ) {
639 let sum = b_grad_array.as_array() + grad_b_array.as_array();
641 b_node.grad = Some(Rc::new(NdarrayWrapper::new(sum)));
642 }
643 } else {
644 b_node.grad = Some(Rc::new(grad_b));
646 }
647 }
648 }
649 }
650 "subtract"
651 if inputs.len() == 2 => {
653 let (a, b) = (&inputs[0], &inputs[1]);
654
655 if a.requiresgrad() {
657 let mut a_node = a.node.borrow_mut();
658 if let Some(a_grad) = &a_node.grad {
659 if let Ok(sum) = add(a_grad.as_ref(), node_grad.as_ref()) {
661 a_node.grad = Some(box_to_rc_array_protocol(sum));
662 }
663 } else {
664 a_node.grad = Some(node_grad.clone());
665 }
666 }
667
668 if b.requiresgrad() {
670 if let Ok(neg_grad) = multiply_by_scalar(node_grad.as_ref(), -1.0) {
671 let mut b_node = b.node.borrow_mut();
672 if let Some(b_grad) = &b_node.grad {
673 if let Ok(sum) = add(b_grad.as_ref(), neg_grad.as_ref()) {
675 b_node.grad = Some(box_to_rc_array_protocol(sum));
676 }
677 } else {
678 b_node.grad = Some(box_to_rc_array_protocol(neg_grad));
679 }
680 }
681 }
682 }
683 "divide"
684 if inputs.len() == 2 => {
686 let (a, b) = (&inputs[0], &inputs[1]);
687
688 if a.requiresgrad() {
690 let b_value = b.value();
691 if let Ok(grad_a) = divide(node_grad.as_ref(), b_value.as_ref()) {
692 let mut a_node = a.node.borrow_mut();
693 if let Some(a_grad) = &a_node.grad {
694 if let Ok(sum) = add(a_grad.as_ref(), grad_a.as_ref()) {
696 a_node.grad = Some(box_to_rc_array_protocol(sum));
697 }
698 } else {
699 a_node.grad = Some(box_to_rc_array_protocol(grad_a));
700 }
701 }
702 }
703
704 if b.requiresgrad() {
706 let a_value = a.value();
707 let b_value = b.value();
708
709 if let Ok(b_squared) = multiply(b_value.as_ref(), b_value.as_ref()) {
711 if let Ok(grad_times_a) =
713 multiply(node_grad.as_ref(), a_value.as_ref())
714 {
715 if let Ok(div_result) =
717 divide(grad_times_a.as_ref(), b_squared.as_ref())
718 {
719 if let Ok(grad_b) =
721 multiply_by_scalar(div_result.as_ref(), -1.0)
722 {
723 let mut b_node = b.node.borrow_mut();
724 if let Some(b_grad) = &b_node.grad {
725 if let Ok(sum) =
727 add(b_grad.as_ref(), grad_b.as_ref())
728 {
729 b_node.grad =
730 Some(box_to_rc_array_protocol(sum));
731 }
732 } else {
733 b_node.grad =
734 Some(box_to_rc_array_protocol(grad_b));
735 }
736 }
737 }
738 }
739 }
740 }
741 }
742 "sigmoid"
743 if inputs.len() == 1 => {
745 let input = &inputs[0];
746
747 if input.requiresgrad() {
748 let sigmoid_value = node.value();
750
751 if let Ok(ones) = ones_like(sigmoid_value.as_ref()) {
753 if let Ok(one_minus_sigmoid) =
754 subtract(ones.as_ref(), sigmoid_value.as_ref())
755 {
756 if let Ok(sigmoid_deriv) =
758 multiply(sigmoid_value.as_ref(), one_minus_sigmoid.as_ref())
759 {
760 if let Ok(grad_input) =
762 multiply(node_grad.as_ref(), sigmoid_deriv.as_ref())
763 {
764 let mut input_node = input.node.borrow_mut();
765 if let Some(input_grad) = &input_node.grad {
766 if let Ok(sum) =
768 add(input_grad.as_ref(), grad_input.as_ref())
769 {
770 input_node.grad =
771 Some(box_to_rc_array_protocol(sum));
772 }
773 } else {
774 input_node.grad =
775 Some(box_to_rc_array_protocol(grad_input));
776 }
777 }
778 }
779 }
780 }
781 }
782 }
783 "mean"
784 if inputs.len() == 1 => {
786 let input = &inputs[0];
787
788 if input.requiresgrad() {
789 let input_value = input.value();
791 if let Some(inputarray) = input_value
792 .as_any()
793 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
794 {
795 let n_elements = inputarray.as_array().len() as f64;
796
797 if let Ok(grad_input) =
799 multiply_by_scalar(node_grad.as_ref(), 1.0 / n_elements)
800 {
801 if let Ok(broadcasted_grad) = broadcast_to(
803 grad_input.as_ref(),
804 inputarray.as_array().shape(),
805 ) {
806 let mut input_node = input.node.borrow_mut();
807 if let Some(input_grad) = &input_node.grad {
808 if let Ok(sum) =
810 add(input_grad.as_ref(), broadcasted_grad.as_ref())
811 {
812 input_node.grad =
813 Some(box_to_rc_array_protocol(sum));
814 }
815 } else {
816 input_node.grad =
817 Some(box_to_rc_array_protocol(broadcasted_grad));
818 }
819 }
820 }
821 }
822 }
823 }
824 _ => {
825 }
827 }
828 }
829
830 Ok(())
831 }
832
833 pub fn detach(&self) -> Self {
835 GradientTensor::new(self.value(), false)
836 }
837}
838
839#[allow(dead_code)]
842pub fn grad_add(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
843 let a_value = a.value();
844 let b_value = b.value();
845
846 let result = add(a_value.as_ref(), b_value.as_ref())?;
848
849 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
851 Ok(GradientTensor::from_op(
852 result_rc,
853 "add".to_string(),
854 vec![a.clone(), b.clone()],
855 ))
856}
857
858#[allow(dead_code)]
860pub fn grad_multiply(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
861 let a_value = a.value();
862 let b_value = b.value();
863
864 let result = multiply(a_value.as_ref(), b_value.as_ref())?;
866
867 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
869 Ok(GradientTensor::from_op(
870 result_rc,
871 "multiply".to_string(),
872 vec![a.clone(), b.clone()],
873 ))
874}
875
876#[allow(dead_code)]
878pub fn grad_matmul(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
879 let a_value = a.value();
880 let b_value = b.value();
881
882 let result = matmul(a_value.as_ref(), b_value.as_ref())?;
884
885 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
887 Ok(GradientTensor::from_op(
888 result_rc,
889 "matmul".to_string(),
890 vec![a.clone(), b.clone()],
891 ))
892}
893
894#[allow(dead_code)]
896pub fn grad_subtract(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
897 let a_value = a.value();
898 let b_value = b.value();
899
900 let result = subtract(a_value.as_ref(), b_value.as_ref())?;
902
903 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
905 Ok(GradientTensor::from_op(
906 result_rc,
907 "subtract".to_string(),
908 vec![a.clone(), b.clone()],
909 ))
910}
911
912#[allow(dead_code)]
914pub fn grad_divide(a: &GradientTensor, b: &GradientTensor) -> CoreResult<GradientTensor> {
915 let a_value = a.value();
916 let b_value = b.value();
917
918 let result = divide(a_value.as_ref(), b_value.as_ref())?;
920
921 let result_rc: Rc<dyn ArrayProtocol> = box_to_rc_array_protocol(result);
923 Ok(GradientTensor::from_op(
924 result_rc,
925 "divide".to_string(),
926 vec![a.clone(), b.clone()],
927 ))
928}
929
930#[allow(dead_code)]
932pub fn grad_sigmoid(a: &GradientTensor) -> CoreResult<GradientTensor> {
933 let a_value = a.value();
934
935 if let Some(a_array) = a_value
937 .as_any()
938 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
939 {
940 let array = a_array.as_array();
941 let result = array.mapv(|x| 1.0 / (1.0 + (-x).exp()));
942 let result_wrapped = NdarrayWrapper::new(result);
943 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
944 Ok(GradientTensor::from_op(
945 result_rc,
946 "sigmoid".to_string(),
947 vec![a.clone()],
948 ))
949 } else if let Some(a_array) = a_value
950 .as_any()
951 .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
952 {
953 let array = a_array.as_array();
954 let result = array.mapv(|x| 1.0f32 / (1.0f32 + (-x).exp()));
955 let result_wrapped = NdarrayWrapper::new(result);
956 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
957 Ok(GradientTensor::from_op(
958 result_rc,
959 "sigmoid".to_string(),
960 vec![a.clone()],
961 ))
962 } else {
963 Err(CoreError::NotImplementedError(ErrorContext::new(
964 "sigmoid not implemented for this array type".to_string(),
965 )))
966 }
967}
968
969#[allow(dead_code)]
971pub fn grad_mean(a: &GradientTensor) -> CoreResult<GradientTensor> {
972 let a_value = a.value();
973
974 if let Some(a_array) = a_value
976 .as_any()
977 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
978 {
979 let array = a_array.as_array();
980 let mean_value = array.mean_or(0.0);
981 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
982 let result_wrapped = NdarrayWrapper::new(result);
983 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
984 Ok(GradientTensor::from_op(
985 result_rc,
986 "mean".to_string(),
987 vec![a.clone()],
988 ))
989 } else if let Some(a_array) = a_value
990 .as_any()
991 .downcast_ref::<NdarrayWrapper<f32, IxDyn>>()
992 {
993 let array = a_array.as_array();
994 let mean_value = array.mean_or(0.0f32);
995 let result = ArrayD::<f32>::from_elem(IxDyn(&[1]), mean_value);
996 let result_wrapped = NdarrayWrapper::new(result);
997 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
998 Ok(GradientTensor::from_op(
999 result_rc,
1000 "mean".to_string(),
1001 vec![a.clone()],
1002 ))
1003 } else if let Some(a_array) = a_value
1004 .as_any()
1005 .downcast_ref::<NdarrayWrapper<i32, IxDyn>>()
1006 {
1007 let array = a_array.as_array();
1008 let mean_value = if array.is_empty() {
1009 0.0f64
1010 } else {
1011 array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
1012 };
1013 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
1014 let result_wrapped = NdarrayWrapper::new(result);
1015 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
1016 Ok(GradientTensor::from_op(
1017 result_rc,
1018 "mean".to_string(),
1019 vec![a.clone()],
1020 ))
1021 } else if let Some(a_array) = a_value
1022 .as_any()
1023 .downcast_ref::<NdarrayWrapper<i64, IxDyn>>()
1024 {
1025 let array = a_array.as_array();
1026 let mean_value = if array.is_empty() {
1027 0.0f64
1028 } else {
1029 array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
1030 };
1031 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
1032 let result_wrapped = NdarrayWrapper::new(result);
1033 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
1034 Ok(GradientTensor::from_op(
1035 result_rc,
1036 "mean".to_string(),
1037 vec![a.clone()],
1038 ))
1039 } else if let Some(a_array) = a_value.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
1040 let array = a_array.as_array();
1041 let mean_value = if array.is_empty() {
1042 0.0f64
1043 } else {
1044 array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
1045 };
1046 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
1047 let result_wrapped = NdarrayWrapper::new(result);
1048 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
1049 Ok(GradientTensor::from_op(
1050 result_rc,
1051 "mean".to_string(),
1052 vec![a.clone()],
1053 ))
1054 } else if let Some(a_array) = a_value
1055 .as_any()
1056 .downcast_ref::<NdarrayWrapper<u16, IxDyn>>()
1057 {
1058 let array = a_array.as_array();
1059 let mean_value = if array.is_empty() {
1060 0.0f64
1061 } else {
1062 array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
1063 };
1064 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
1065 let result_wrapped = NdarrayWrapper::new(result);
1066 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
1067 Ok(GradientTensor::from_op(
1068 result_rc,
1069 "mean".to_string(),
1070 vec![a.clone()],
1071 ))
1072 } else if let Some(a_array) = a_value
1073 .as_any()
1074 .downcast_ref::<NdarrayWrapper<u32, IxDyn>>()
1075 {
1076 let array = a_array.as_array();
1077 let mean_value = if array.is_empty() {
1078 0.0f64
1079 } else {
1080 array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
1081 };
1082 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
1083 let result_wrapped = NdarrayWrapper::new(result);
1084 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
1085 Ok(GradientTensor::from_op(
1086 result_rc,
1087 "mean".to_string(),
1088 vec![a.clone()],
1089 ))
1090 } else if let Some(a_array) = a_value
1091 .as_any()
1092 .downcast_ref::<NdarrayWrapper<u64, IxDyn>>()
1093 {
1094 let array = a_array.as_array();
1095 let mean_value = if array.is_empty() {
1096 0.0f64
1097 } else {
1098 array.iter().map(|&x| x as f64).sum::<f64>() / array.len() as f64
1099 };
1100 let result = ArrayD::<f64>::from_elem(IxDyn(&[1]), mean_value);
1101 let result_wrapped = NdarrayWrapper::new(result);
1102 let result_rc: Rc<dyn ArrayProtocol> = Rc::new(result_wrapped);
1103 Ok(GradientTensor::from_op(
1104 result_rc,
1105 "mean".to_string(),
1106 vec![a.clone()],
1107 ))
1108 } else {
1109 Err(CoreError::NotImplementedError(ErrorContext::new(
1110 "mean not implemented for this array type".to_string(),
1111 )))
1112 }
1113}
1114
1115pub struct Variable {
1117 tensor: GradientTensor,
1119
1120 name: String,
1122}
1123
1124impl Variable {
1125 pub fn new<T, D>(name: &str, array: Array<T, D>) -> Self
1127 where
1128 T: Clone + Send + Sync + 'static,
1129 D: Dimension + Send + Sync + 'static,
1130 {
1131 let tensor = GradientTensor::from_array(array, true);
1132 Self {
1133 tensor,
1134 name: name.to_string(),
1135 }
1136 }
1137
1138 pub const fn tensor(&self) -> &GradientTensor {
1140 &self.tensor
1141 }
1142
1143 pub fn value(&self) -> Rc<dyn ArrayProtocol> {
1145 self.tensor.value()
1146 }
1147
1148 pub fn grad_2(&self) -> Option<Rc<dyn ArrayProtocol>> {
1150 self.tensor.grad_2()
1151 }
1152
1153 pub fn name(&self) -> &str {
1155 &self.name
1156 }
1157
1158 pub fn set_gradient(&mut self, gradient: Box<dyn ArrayProtocol>) -> CoreResult<()> {
1160 let gradient_rc = self.box_to_rc(gradient);
1162
1163 self.tensor.node.borrow_mut().grad = Some(gradient_rc);
1165 Ok(())
1166 }
1167
1168 pub fn set_value(&mut self, newvalue: Box<dyn ArrayProtocol>) {
1170 let newvalue_rc = self.box_to_rc(newvalue);
1171 self.tensor.set_value(newvalue_rc);
1172 }
1173
1174 fn box_to_rc(&self, boxed: Box<dyn ArrayProtocol>) -> Rc<dyn ArrayProtocol> {
1176 if let Some(ndarray_wrapper) = boxed
1178 .as_ref()
1179 .as_any()
1180 .downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
1181 {
1182 let array_clone = ndarray_wrapper.as_array().clone();
1183 Rc::new(NdarrayWrapper::new(array_clone))
1184 } else {
1185 let fallback_array = ArrayD::<f64>::zeros(IxDyn(&[1, 1]));
1187 Rc::new(NdarrayWrapper::new(fallback_array))
1188 }
1189 }
1190}
1191
1192pub trait Optimizer {
1194 fn step(&mut self) -> CoreResult<()>;
1196
1197 fn zero_grad(&mut self);
1199
1200 fn add_variable(&mut self, var: Variable);
1202
1203 fn variables(&self) -> &[Variable];
1205
1206 fn accumulate_gradients(&mut self, gradients: &GradientDict) -> CoreResult<()> {
1208 for (param_name, gradient) in gradients.iter() {
1210 for var in self.variables_mut() {
1212 if var.name() == param_name {
1213 var.set_gradient(gradient.clone())?;
1214 break;
1215 }
1216 }
1217 }
1218 Ok(())
1219 }
1220
1221 fn variables_mut(&mut self) -> &mut [Variable] {
1223 &mut []
1226 }
1227}
1228
1229pub struct SGD {
1231 variables: Vec<Variable>,
1233
1234 learningrate: f64,
1236
1237 momentum: f64,
1239
1240 velocity: Vec<Option<Box<dyn ArrayProtocol>>>,
1242}
1243
1244impl SGD {
1245 pub fn new(learningrate: f64, momentum: Option<f64>) -> Self {
1247 Self {
1248 variables: Vec::new(),
1249 learningrate,
1250 momentum: momentum.unwrap_or(0.0),
1251 velocity: Vec::new(),
1252 }
1253 }
1254
1255 pub fn set_learningrate(&mut self, learningrate: f64) {
1257 self.learningrate = learningrate;
1258 }
1259}
1260
1261impl Optimizer for SGD {
1262 fn step(&mut self) -> CoreResult<()> {
1263 for (i, var) in self.variables.iter_mut().enumerate() {
1264 if let Some(grad) = var.grad_2() {
1265 let var_value = var.value();
1266
1267 let update = if self.momentum > 0.0 {
1269 if i >= self.velocity.len() {
1270 self.velocity.resize_with(i + 1, || None);
1271 }
1272
1273 if let Some(vel) = &self.velocity[i] {
1274 let scaled_grad = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
1276 let scaled_vel = multiply_by_scalar(vel.as_ref(), self.momentum)?;
1277 let update = add(scaled_vel.as_ref(), scaled_grad.as_ref())?;
1278 self.velocity[i] = Some(update.clone());
1279 update
1280 } else {
1281 let update = multiply_by_scalar(grad.as_ref(), self.learningrate)?;
1283 self.velocity[i] = Some(update.clone());
1284 update
1285 }
1286 } else {
1287 multiply_by_scalar(grad.as_ref(), self.learningrate)?
1289 };
1290
1291 let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
1293 var.set_value(updated_value);
1294 }
1295 }
1296
1297 Ok(())
1298 }
1299
1300 fn zero_grad(&mut self) {
1301 for var in &self.variables {
1302 var.tensor.node.borrow_mut().grad = None;
1303 }
1304 }
1305
1306 fn add_variable(&mut self, var: Variable) {
1307 self.variables.push(var);
1308 self.velocity.push(None);
1309 }
1310
1311 fn variables(&self) -> &[Variable] {
1312 &self.variables
1313 }
1314
1315 fn variables_mut(&mut self) -> &mut [Variable] {
1316 &mut self.variables
1317 }
1318}
1319
1320pub struct Adam {
1322 variables: Vec<Variable>,
1324
1325 learningrate: f64,
1327
1328 beta1: f64,
1330
1331 beta2: f64,
1333
1334 epsilon: f64,
1336
1337 m: Vec<Option<Box<dyn ArrayProtocol>>>,
1339
1340 v: Vec<Option<Box<dyn ArrayProtocol>>>,
1342
1343 t: usize,
1345}
1346
1347impl Adam {
1348 pub fn new(
1350 learningrate: f64,
1351 beta1: Option<f64>,
1352 beta2: Option<f64>,
1353 epsilon: Option<f64>,
1354 ) -> Self {
1355 Self {
1356 variables: Vec::new(),
1357 learningrate,
1358 beta1: beta1.unwrap_or(0.9),
1359 beta2: beta2.unwrap_or(0.999),
1360 epsilon: epsilon.unwrap_or(1e-8),
1361 m: Vec::new(),
1362 v: Vec::new(),
1363 t: 0,
1364 }
1365 }
1366}
1367
1368impl Optimizer for Adam {
1369 fn step(&mut self) -> CoreResult<()> {
1370 self.t += 1;
1371
1372 for (i, var) in self.variables.iter_mut().enumerate() {
1373 if let Some(grad) = var.grad_2() {
1374 let var_value = var.value();
1375
1376 if i >= self.m.len() {
1378 self.m.resize_with(i + 1, || None);
1379 self.v.resize_with(i + 1, || None);
1380 }
1381
1382 let m = if let Some(m_prev) = &self.m[i] {
1384 let scaled_m = multiply_by_scalar(m_prev.as_ref(), self.beta1)?;
1386 let scaled_grad = multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?;
1387 add(scaled_m.as_ref(), scaled_grad.as_ref())?
1388 } else {
1389 multiply_by_scalar(grad.as_ref(), 1.0 - self.beta1)?
1391 };
1392
1393 let v = if let Some(v_prev) = &self.v[i] {
1395 let scaled_v = multiply_by_scalar(v_prev.as_ref(), self.beta2)?;
1397 let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
1398 let scaled_grad_sq =
1399 multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?;
1400 add(scaled_v.as_ref(), scaled_grad_sq.as_ref())?
1401 } else {
1402 let grad_squared = multiply(grad.as_ref(), grad.as_ref())?;
1404 multiply_by_scalar(grad_squared.as_ref(), 1.0 - self.beta2)?
1405 };
1406
1407 self.m[i] = Some(m.clone());
1409 self.v[i] = Some(v.clone());
1410
1411 let m_hat =
1413 multiply_by_scalar(m.as_ref(), 1.0 / (1.0 - self.beta1.powi(self.t as i32)))?;
1414 let v_hat =
1415 multiply_by_scalar(v.as_ref(), 1.0 / (1.0 - self.beta2.powi(self.t as i32)))?;
1416
1417 let v_hat_sqrt = sqrt(v_hat.as_ref())?;
1419 let v_hat_sqrt_eps = add_scalar(v_hat_sqrt.as_ref(), self.epsilon)?;
1420 let update_dir = divide(m_hat.as_ref(), v_hat_sqrt_eps.as_ref())?;
1421 let update = multiply_by_scalar(update_dir.as_ref(), self.learningrate)?;
1422
1423 let updated_value = subtract_arrays(var_value.as_ref(), update.as_ref())?;
1425 var.set_value(updated_value);
1426 }
1427 }
1428
1429 Ok(())
1430 }
1431
1432 fn zero_grad(&mut self) {
1433 for var in &self.variables {
1434 var.tensor.node.borrow_mut().grad = None;
1435 }
1436 }
1437
1438 fn add_variable(&mut self, var: Variable) {
1439 self.variables.push(var);
1440 self.m.push(None);
1441 self.v.push(None);
1442 }
1443
1444 fn variables(&self) -> &[Variable] {
1445 &self.variables
1446 }
1447
1448 fn variables_mut(&mut self) -> &mut [Variable] {
1449 &mut self.variables
1450 }
1451}
1452
1453fn multiply_by_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
1457 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1458 let inputarray = a_array.as_array();
1459 let result = inputarray.mapv(|x| x * scalar);
1460 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1461 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1462 let inputarray = a_array.as_array();
1463 let result = inputarray.mapv(|x| x * scalar as f32);
1464 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1465 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
1466 let inputarray = a_array.as_array();
1467 let result = inputarray.mapv(|x| (x as f64 * scalar) as i32);
1468 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1469 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
1470 let inputarray = a_array.as_array();
1471 let result = inputarray.mapv(|x| (x as f64 * scalar) as i64);
1472 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1473 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
1474 let inputarray = a_array.as_array();
1475 let result = inputarray.mapv(|x| (x as f64 * scalar) as u8);
1476 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1477 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>() {
1478 let inputarray = a_array.as_array();
1479 let result = inputarray.mapv(|x| (x as f64 * scalar) as u16);
1480 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1481 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>() {
1482 let inputarray = a_array.as_array();
1483 let result = inputarray.mapv(|x| (x as f64 * scalar) as u32);
1484 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1485 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>() {
1486 let inputarray = a_array.as_array();
1487 let result = inputarray.mapv(|x| (x as f64 * scalar) as u64);
1488 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1489 } else {
1490 Err(CoreError::NotImplementedError(ErrorContext::new(
1491 "multiply_by_scalar not implemented for this array type".to_string(),
1492 )))
1493 }
1494}
1495
1496fn subtract_arrays(
1498 a: &dyn ArrayProtocol,
1499 b: &dyn ArrayProtocol,
1500) -> CoreResult<Box<dyn ArrayProtocol>> {
1501 if let (Some(a_wrapper), Some(b_array)) = (
1503 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1504 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1505 ) {
1506 let a_arr = a_wrapper.as_array();
1507 let b_arr = b_array.as_array();
1508 let result = a_arr - b_arr;
1509 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1510 } else if let (Some(a_wrapper), Some(b_array)) = (
1511 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1512 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1513 ) {
1514 let a_arr = a_wrapper.as_array();
1515 let b_arr = b_array.as_array();
1516 let result = a_arr - b_arr;
1517 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1518 } else if let (Some(a_wrapper), Some(b_array)) = (
1519 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1520 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1521 ) {
1522 let a_arr = a_wrapper.as_array();
1523 let b_arr = b_array.as_array();
1524 let result = a_arr - b_arr;
1525 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1526 } else if let (Some(a_wrapper), Some(b_array)) = (
1527 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1528 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1529 ) {
1530 let a_arr = a_wrapper.as_array();
1531 let b_arr = b_array.as_array();
1532 let result = a_arr - b_arr;
1533 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1534 } else if let (Some(a_wrapper), Some(b_array)) = (
1535 a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
1536 b.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
1537 ) {
1538 let a_arr = a_wrapper.as_array();
1539 let b_arr = b_array.as_array();
1540 let result = a_arr - b_arr;
1541 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1542 } else if let (Some(a_wrapper), Some(b_array)) = (
1543 a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
1544 b.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
1545 ) {
1546 let a_arr = a_wrapper.as_array();
1547 let b_arr = b_array.as_array();
1548 let result = a_arr - b_arr;
1549 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1550 } else if let (Some(a_wrapper), Some(b_array)) = (
1551 a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
1552 b.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
1553 ) {
1554 let a_arr = a_wrapper.as_array();
1555 let b_arr = b_array.as_array();
1556 let result = a_arr - b_arr;
1557 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1558 } else if let (Some(a_wrapper), Some(b_array)) = (
1559 a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
1560 b.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
1561 ) {
1562 let a_arr = a_wrapper.as_array();
1563 let b_arr = b_array.as_array();
1564 let result = a_arr - b_arr;
1565 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1566 } else {
1567 Err(CoreError::NotImplementedError(ErrorContext::new(
1568 "subtract_arrays not implemented for these array types".to_string(),
1569 )))
1570 }
1571}
1572
1573fn sqrt(a: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
1575 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1576 let result = a_array.as_array().mapv(|x| x.sqrt());
1577 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1578 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1579 let result = a_array.as_array().mapv(|x| x.sqrt());
1580 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1581 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
1582 let result = a_array.as_array().mapv(|x| (x as f64).sqrt());
1583 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1584 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
1585 let result = a_array.as_array().mapv(|x| (x as f64).sqrt());
1586 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1587 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
1588 let result = a_array.as_array().mapv(|x| (x as f64).sqrt());
1589 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1590 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>() {
1591 let result = a_array.as_array().mapv(|x| (x as f64).sqrt());
1592 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1593 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>() {
1594 let result = a_array.as_array().mapv(|x| (x as f64).sqrt());
1595 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1596 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>() {
1597 let result = a_array.as_array().mapv(|x| (x as f64).sqrt());
1598 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1599 } else {
1600 Err(CoreError::NotImplementedError(ErrorContext::new(
1601 "sqrt not implemented for this array type".to_string(),
1602 )))
1603 }
1604}
1605
1606fn add_scalar(a: &dyn ArrayProtocol, scalar: f64) -> CoreResult<Box<dyn ArrayProtocol>> {
1608 if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
1609 let result = a_array.as_array().mapv(|x| x + scalar);
1610 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1611 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>() {
1612 let result = a_array.as_array().mapv(|x| x + scalar as f32);
1613 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1614 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>() {
1615 let result = a_array.as_array().mapv(|x| x + scalar as i32);
1616 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1617 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>() {
1618 let result = a_array.as_array().mapv(|x| x + scalar as i64);
1619 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1620 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>() {
1621 let result = a_array.as_array().mapv(|x| x + scalar as u8);
1622 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1623 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>() {
1624 let result = a_array.as_array().mapv(|x| x + scalar as u16);
1625 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1626 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>() {
1627 let result = a_array.as_array().mapv(|x| x + scalar as u32);
1628 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1629 } else if let Some(a_array) = a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>() {
1630 let result = a_array.as_array().mapv(|x| x + scalar as u64);
1631 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1632 } else {
1633 Err(CoreError::NotImplementedError(ErrorContext::new(
1634 "add_scalar not implemented for this array type".to_string(),
1635 )))
1636 }
1637}
1638
1639fn divide(a: &dyn ArrayProtocol, b: &dyn ArrayProtocol) -> CoreResult<Box<dyn ArrayProtocol>> {
1641 if let (Some(a_array), Some(b_array)) = (
1642 a.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1643 b.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>(),
1644 ) {
1645 let result = a_array.as_array() / b_array.as_array();
1646 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1647 } else if let (Some(a_array), Some(b_array)) = (
1648 a.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1649 b.as_any().downcast_ref::<NdarrayWrapper<f32, IxDyn>>(),
1650 ) {
1651 let result = a_array.as_array() / b_array.as_array();
1652 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1653 } else if let (Some(a_array), Some(b_array)) = (
1654 a.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1655 b.as_any().downcast_ref::<NdarrayWrapper<i32, IxDyn>>(),
1656 ) {
1657 let result = ::ndarray::Zip::from(a_array.as_array())
1658 .and(b_array.as_array())
1659 .map_collect(|&av, &bv| av as f64 / bv as f64);
1660 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1661 } else if let (Some(a_array), Some(b_array)) = (
1662 a.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1663 b.as_any().downcast_ref::<NdarrayWrapper<i64, IxDyn>>(),
1664 ) {
1665 let result = ::ndarray::Zip::from(a_array.as_array())
1666 .and(b_array.as_array())
1667 .map_collect(|&av, &bv| av as f64 / bv as f64);
1668 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1669 } else if let (Some(a_array), Some(b_array)) = (
1670 a.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
1671 b.as_any().downcast_ref::<NdarrayWrapper<u8, IxDyn>>(),
1672 ) {
1673 let result = ::ndarray::Zip::from(a_array.as_array())
1674 .and(b_array.as_array())
1675 .map_collect(|&av, &bv| av as f64 / bv as f64);
1676 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1677 } else if let (Some(a_array), Some(b_array)) = (
1678 a.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
1679 b.as_any().downcast_ref::<NdarrayWrapper<u16, IxDyn>>(),
1680 ) {
1681 let result = ::ndarray::Zip::from(a_array.as_array())
1682 .and(b_array.as_array())
1683 .map_collect(|&av, &bv| av as f64 / bv as f64);
1684 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1685 } else if let (Some(a_array), Some(b_array)) = (
1686 a.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
1687 b.as_any().downcast_ref::<NdarrayWrapper<u32, IxDyn>>(),
1688 ) {
1689 let result = ::ndarray::Zip::from(a_array.as_array())
1690 .and(b_array.as_array())
1691 .map_collect(|&av, &bv| av as f64 / bv as f64);
1692 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1693 } else if let (Some(a_array), Some(b_array)) = (
1694 a.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
1695 b.as_any().downcast_ref::<NdarrayWrapper<u64, IxDyn>>(),
1696 ) {
1697 let result = ::ndarray::Zip::from(a_array.as_array())
1698 .and(b_array.as_array())
1699 .map_collect(|&av, &bv| av as f64 / bv as f64);
1700 Ok(Box::new(NdarrayWrapper::new(result)) as Box<dyn ArrayProtocol>)
1701 } else {
1702 Err(CoreError::NotImplementedError(ErrorContext::new(
1703 "divide not implemented for these array types".to_string(),
1704 )))
1705 }
1706}
1707
1708#[cfg(test)]
1709#[path = "grad_tests.rs"]
1710mod tests;