1use crate::error::ZyxError;
2use crate::utils::{get_dtype, get_shape};
3use crate::{
4 dtype::DType,
5 node::Node,
6 scalar::Scalar,
7 shape::Shape,
8 tensor::{self, Id},
9};
10use alloc::collections::btree_map::Entry;
11use alloc::{
12 collections::{BTreeMap, BTreeSet},
13 vec::Vec,
14};
15use core::ops::Range;
16use rand::distributions::Uniform;
17
18pub trait RuntimeBackend {
23 fn is_evaluated(&self, x: Id) -> bool;
25 fn is_free_id(&self, x: Id) -> bool;
27 fn remove(&mut self, x: Id) -> Result<(), ZyxError>;
29 fn store<T: Scalar, IT>(&mut self, x: Id, iter: IT) -> Result<(), ZyxError>
31 where
32 IT: IntoIterator<Item = T>,
33 IT::IntoIter: ExactSizeIterator;
34 fn load<T: Scalar>(&mut self, x: Id, numel: usize) -> Result<Vec<T>, ZyxError>;
36 fn evaluate(
39 &mut self,
40 rcs: BTreeMap<Id, u32>,
41 order: &[Id],
42 nodes: &[Node],
43 ) -> Result<(), ZyxError>;
44}
45
46pub struct Runtime<R: RuntimeBackend> {
49 rng: rand::rngs::SmallRng,
50 rcs: Vec<u32>,
51 nodes: Vec<Node>,
52 unrealized_nodes_count: usize,
53 runtime_backend: R,
54}
55
56impl<R: RuntimeBackend> Runtime<R> {
57 #[must_use]
59 pub fn new(runtime_backend: R) -> Self {
60 use rand::SeedableRng;
61 Self {
62 rng: rand::rngs::SmallRng::seed_from_u64(420_694_206_942_069),
63 rcs: Vec::new(),
64 nodes: Vec::new(),
65 unrealized_nodes_count: 0,
66 runtime_backend,
67 }
68 }
69
70 pub fn randn(&mut self, shape: Shape, dtype: DType) -> Result<Id, ZyxError> {
72 use rand::Rng;
73 let n = shape.numel();
74 let mut rng = self.rng.clone();
75 use rand::distributions::Standard;
76 let data1 = match dtype {
77 DType::F32 => self.store::<f32, _>((0..n).map(move |_| rng.sample(Standard))),
78 DType::F64 => self.store::<f64, _>((0..n).map(move |_| rng.sample(Standard))),
79 DType::I32 => self.store::<i32, _>((0..n).map(move |_| rng.sample(Standard))),
80 }?;
81 let data = self.push(Node::Reshape(data1, shape))?;
82 self.release(data1)?;
83 for _ in 0..n {
85 self.rng.sample::<f32, _>(Standard);
86 }
87 Ok(data)
88 }
89
90 pub fn uniform<T: Scalar>(&mut self, shape: Shape, range: Range<T>) -> Result<Id, ZyxError> {
92 use rand::Rng;
94 let n = shape.numel();
95 let mut rng = self.rng.clone();
96 use rand::distributions::Standard;
97 let data1 = match T::dtype() {
98 DType::F32 => self.store((0..n).map(move |_| {
99 rng.sample(Uniform::new(
100 range.start.clone().into_f32(),
101 range.end.clone().into_f32(),
102 ))
103 })),
104 DType::F64 => self.store((0..n).map(move |_| {
105 rng.sample(Uniform::new(
106 range.start.clone().into_f64(),
107 range.end.clone().into_f64(),
108 ))
109 })),
110 DType::I32 => self.store((0..n).map(move |_| {
111 rng.sample(Uniform::new(
112 range.start.clone().into_i32(),
113 range.end.clone().into_i32(),
114 ))
115 })),
116 }?;
117 let data = self.push(Node::Reshape(data1, shape))?;
118 self.release(data1)?;
119 for _ in 0..n {
121 self.rng.sample::<f32, _>(Standard);
122 }
123 Ok(data)
124 }
125
126 #[must_use]
128 pub fn shape(&self, x: Id) -> &Shape {
129 get_shape(self.nodes.as_slice(), x)
130 }
131
132 #[must_use]
134 pub fn dtype(&self, x: Id) -> DType {
135 get_dtype(self.nodes.as_slice(), x)
136 }
137
138 pub fn load<T: Scalar>(&mut self, x: Id) -> Result<Vec<T>, ZyxError> {
140 if !self.runtime_backend.is_evaluated(x) {
141 self.evaluate(BTreeSet::from([x]))?;
142 }
143 let numel = get_shape(self.nodes.as_slice(), x).numel();
144 self.runtime_backend.load(x, numel)
146 }
147
148 pub fn store<T: Scalar, IT>(&mut self, iter: IT) -> Result<Id, ZyxError>
150 where
151 IT: IntoIterator<Item = T>,
152 IT::IntoIter: ExactSizeIterator,
153 {
154 let iter = iter.into_iter();
157 let len = iter.len();
158 let node = Node::Leaf(len.into(), T::dtype());
159 let id = if let Some(i) = self
160 .rcs
161 .iter()
162 .enumerate()
163 .position(|(i, rc)| *rc == 0 && self.runtime_backend.is_free_id(tensor::id(i)))
164 {
165 let id = tensor::id(i);
166 self.rcs[i] = 1;
167 self.nodes[i] = node;
168 id
169 } else {
170 let id = tensor::id(self.rcs.len());
171 self.rcs.push(1);
172 self.nodes.push(node);
173 id
174 };
175 self.runtime_backend.store(id, iter)?;
177 Ok(id)
179 }
180
181 pub fn push(&mut self, node: Node) -> Result<Id, ZyxError> {
185 match node {
188 Node::Reshape(x, ref shape) | Node::Expand(x, ref shape) => {
189 if shape == self.shape(x) {
190 self.retain(x);
191 return Ok(x);
192 }
193 }
194 Node::Sum(x, ref axes, ..) | Node::Max(x, ref axes, ..) => {
195 if axes.len() == 0 {
196 self.retain(x);
197 return Ok(x);
198 }
199 }
200 _ => {}
201 }
202 for nid in node.parameters() {
203 self.retain(nid);
204 }
205 let id = if let Some(i) = self
206 .rcs
207 .iter()
208 .enumerate()
209 .position(|(i, rc)| *rc == 0 && self.runtime_backend.is_free_id(tensor::id(i)))
210 {
211 let id = tensor::id(i);
212 self.rcs[i] = 1;
213 self.nodes[i] = node;
214 id
215 } else {
216 let id = tensor::id(self.rcs.len());
217 self.rcs.push(1);
218 if self.rcs.len() > 4000000000 {
219 panic!("Maximum number of tensors has been reached. Zyx supports up to 4 billion tensors. \
220 Please check your code for memory leaks. If you really need to use more tensors, please raise an issue: https://github.com/zk4x/zyx");
221 }
222 self.nodes.push(node);
223 id
224 };
225 self.unrealized_nodes_count += 1;
227 if self.unrealized_nodes_count > 10000 {
229 self.evaluate([id].into_iter().collect::<BTreeSet<Id>>())?;
230 }
232 Ok(id)
233 }
234
235 pub fn release(&mut self, x: Id) -> Result<(), ZyxError> {
238 let mut params = Vec::with_capacity(10);
240 params.push(x);
241 while let Some(x) = params.pop() {
242 self.rcs[x.i()] -= 1;
243 if self.rcs[x.i()] == 0 {
245 params.extend(self.nodes[x.i()].parameters());
246 self.runtime_backend.remove(x)?;
247 if !matches!(self.nodes[x.i()], Node::Leaf(..) | Node::Uniform(..)) {
249 self.unrealized_nodes_count -= 1;
250 }
251 }
252 }
253 Ok(())
255 }
256
257 pub fn retain(&mut self, x: Id) {
259 debug_assert!(
262 self.rcs[x.i()] < u32::MAX,
263 "Reference count of tensor {x} has been exceeded,\
264 This is zyx bug. please report it at: https://github.com/zk4x/zyx"
265 );
266 self.rcs[x.i()] += 1;
267 }
268
269 pub fn debug_graph(&self) {
271 for (id, node) in self.nodes.iter().enumerate() {
272 std::println!("{id:>5} x{:>3} -> {node:?}", self.rcs[id]);
273 }
274 }
275
276 pub fn evaluate(&mut self, nodes: BTreeSet<Id>) -> Result<(), ZyxError> {
278 let mut temp_rcs: BTreeMap<Id, u32> = BTreeMap::new();
297 let mut params: Vec<Id> = nodes.iter().copied().collect();
298 params.reserve(100);
299 while let Some(nid) = params.pop() {
300 temp_rcs
302 .entry(nid)
303 .and_modify(|rc| *rc += 1)
304 .or_insert_with(|| {
305 if !self.runtime_backend.is_evaluated(nid) {
306 params.extend(self.nodes[nid.i()].parameters());
307 }
308 1
309 });
310 }
311 let mut order = Vec::new();
313 let mut rcs: BTreeMap<Id, u32> = BTreeMap::new();
314 let mut params: Vec<Id> = nodes.iter().copied().collect();
315 params.reserve(100);
316 while let Some(nid) = params.pop() {
317 if let Some(temp_rc) = temp_rcs.get(&nid) {
318 let rc = rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
319 if *temp_rc == *rc {
320 order.push(nid);
321 params.extend(self.nodes[nid.i()].parameters());
322 }
323 }
324 }
325 order.reverse();
326
327 let mut drop_nodes = BTreeSet::new();
330 let mut temp_rcs: BTreeMap<Id, u32> = BTreeMap::new();
331 let mut params: Vec<Id> = nodes.iter().copied().collect();
332 params.reserve(100);
333 while let Some(nid) = params.pop() {
334 if !matches!(self.nodes[nid.i()], Node::Leaf(..)) {
335 temp_rcs
336 .entry(nid)
337 .and_modify(|rc| *rc += 1)
338 .or_insert_with(|| {
339 params.extend(self.nodes[nid.i()].parameters());
340 1
341 });
342 } else {
343 let rc = temp_rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
344 if *rc == self.rcs[nid.i()] && !nodes.contains(&nid) {
347 drop_nodes.insert(nid);
348 }
349 }
350 }
351 let mut new_order = Vec::new();
352 let mut new_rcs: BTreeMap<Id, u32> = BTreeMap::new();
353 let mut params: Vec<Id> = nodes.iter().copied().collect();
354 params.reserve(100);
355 while let Some(nid) = params.pop() {
356 if let Some(temp_rc) = temp_rcs.get(&nid) {
357 let rc = new_rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
358 if *temp_rc == *rc {
359 new_order.push(nid);
360 params.extend(self.nodes[nid.i()].parameters());
361 }
362 }
363 }
364 new_order.reverse();
365
366 let mut new_leafs = BTreeSet::new();
368 for nid in &new_order {
369 if new_rcs[nid] == self.rcs[nid.i()] && !nodes.contains(nid) {
370 for p in self.nodes[nid.i()].parameters() {
371 if drop_nodes.contains(&p) {
372 drop_nodes.insert(*nid);
373 }
374 }
375 } else {
376 if self.nodes[nid.i()]
377 .parameters()
378 .any(|p| drop_nodes.contains(&p))
379 {
380 new_leafs.insert(*nid);
381 }
382 }
383 }
384
385 let mut user_rc = self.rcs.clone();
389 for (i, node) in self.nodes.iter().enumerate() {
390 if self.rcs[i] != 0 {
391 for p in node.parameters() {
393 user_rc[p.i()] -= 1;
394 }
395 }
396 }
397 for nid in &new_order {
398 match &self.nodes[nid.i()] {
400 Node::Cmplt(x, y) => {
401 let mut detach_rc = BTreeMap::new();
402 new_leafs.insert(*x);
403 new_leafs.insert(*y);
404 let mut params = Vec::with_capacity(10);
405 params.push(*x);
406 params.push(*y);
407 while let Some(x) = params.pop() {
408 let rc = detach_rc.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
409 if *rc == self.rcs[x.i()] {
410 drop_nodes.insert(x);
411 }
412 }
413 }
414 Node::Detach(x) => {
415 let mut detach_rc = BTreeMap::new();
416 new_leafs.insert(*x);
417 let mut params = Vec::with_capacity(10);
418 params.push(*x);
419 while let Some(x) = params.pop() {
420 let rc = detach_rc.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
421 if *rc == self.rcs[x.i()] {
422 drop_nodes.insert(x);
423 }
424 }
425 }
426 _ => {}
427 }
428 }
429
430 for nid in &new_leafs {
433 if let Some(rc) = rcs.get_mut(nid) {
435 *rc += 1;
436 }
437 }
438
439 self.runtime_backend.evaluate(rcs, &order, &self.nodes)?;
450
451 for nid in new_leafs {
452 self.unrealized_nodes_count -= 1;
453 self.nodes[nid.i()] = Node::Leaf(
454 get_shape(&self.nodes, nid).clone(),
455 get_dtype(&self.nodes, nid),
456 );
457 }
458 for nid in drop_nodes {
459 self.rcs[nid.i()] = 0;
460 self.runtime_backend.remove(nid)?;
461 if !matches!(self.nodes[nid.i()], Node::Leaf(..) | Node::Uniform(..)) {
462 self.unrealized_nodes_count -= 1;
463 }
464 }
465 std::println!("Non-evaluated nodes count after: {}", self.unrealized_nodes_count);
466
467 Ok(())
474 }
475
476 #[must_use]
478 pub fn plot_graph_dot(&self, ids: &[Id]) -> alloc::string::String {
479 let mut params: Vec<Id> = ids.into();
481 let mut rcs: BTreeMap<Id, u8> = BTreeMap::new();
482 while let Some(nid) = params.pop() {
483 rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert_with(|| {
484 params.extend(self.nodes[nid.i()].parameters());
485 1
486 });
487 }
488 let mut order = Vec::new();
490 let mut internal_rcs: BTreeMap<Id, u8> = BTreeMap::new();
491 let mut params: Vec<Id> = ids.into();
492 while let Some(nid) = params.pop() {
493 if rcs[&nid]
494 == *internal_rcs
495 .entry(nid)
496 .and_modify(|rc| *rc += 1)
497 .or_insert(1)
498 {
499 order.push(nid);
500 if rcs.contains_key(&nid) {
501 params.extend(self.nodes[nid.i()].parameters());
502 }
503 }
504 }
505 let mut topo: BTreeSet<Id> = ids.iter().copied().collect();
508 for nid in order.into_iter().rev() {
509 for p in self.nodes[nid.i()].parameters() {
510 if topo.contains(&p) {
511 topo.insert(nid);
512 }
513 }
514 }
515
516 crate::utils::plot_graph_dot(&topo, &self.nodes, &self.rcs)
517 }
518
519 pub fn backward(
521 &mut self,
522 x: Id,
523 sources: &BTreeSet<Id>,
524 ) -> Result<BTreeMap<Id, Id>, ZyxError> {
525 fn build_topo(x: Id, sources: &BTreeSet<Id>, nodes: &[Node]) -> Vec<Id> {
526 let mut params: Vec<Id> = alloc::vec![x];
528 let mut rcs: BTreeMap<Id, u8> = BTreeMap::new();
529 while let Some(nid) = params.pop() {
530 rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert_with(|| {
531 if !sources.contains(&nid) && !matches!(nodes[nid.i()], Node::Detach(..) | Node::Cmplt(..)) {
532 params.extend(nodes[nid.i()].parameters());
533 }
534 1
535 });
536 }
537 let mut order = Vec::new();
539 let mut internal_rcs: BTreeMap<Id, u8> = BTreeMap::new();
540 let mut params: Vec<Id> = alloc::vec![x];
541 while let Some(nid) = params.pop() {
542 if let Some(rc) = rcs.get(&nid) {
543 if *rc
544 == *internal_rcs
545 .entry(nid)
546 .and_modify(|rc| *rc += 1)
547 .or_insert(1)
548 {
549 order.push(nid);
550 params.extend(nodes[nid.i()].parameters());
551 }
552 }
553 }
554 let mut topo = Vec::new();
557 let mut req_grad = sources.clone();
558 let mut visited = BTreeSet::new();
559 for nid in order.into_iter().rev() {
560 for p in nodes[nid.i()].parameters() {
561 if req_grad.contains(&p) && visited.insert(nid) {
562 req_grad.insert(nid);
563 topo.push(nid);
564 }
565 }
566 }
567 topo.reverse();
568 topo
569 }
570
571 let topo = build_topo(x, sources, &self.nodes);
572 let req_grad: BTreeSet<Id> = topo
575 .iter()
576 .copied()
577 .chain(sources.iter().copied())
578 .collect();
579 let mut grads: BTreeMap<Id, Id> = BTreeMap::new();
581 let grad1 = match get_dtype(&self.nodes, x) {
583 DType::F32 => self.store([1f32]),
584 DType::F64 => self.store([1f64]),
585 DType::I32 => self.store([1i32]),
586 }?;
587 let sh = get_shape(&self.nodes, x).clone();
588 grads.insert(x, self.push(Node::Expand(grad1, sh))?);
589 self.release(grad1)?;
590 fn insert_or_add_grad<B: RuntimeBackend>(
593 r: &mut Runtime<B>,
594 grads: &mut BTreeMap<Id, Id>,
595 x: Id,
596 grad: Id,
597 ) -> Result<(), ZyxError> {
598 match grads.entry(x) {
599 Entry::Vacant(e) => {
600 e.insert(grad);
601 }
602 Entry::Occupied(e) => {
603 let (k, prev_grad) = e.remove_entry();
604 grads.insert(k, r.push(Node::Add(prev_grad, grad))?);
605 r.release(prev_grad)?;
606 r.release(grad)?;
607 }
608 }
609 Ok(())
610 }
611
612 for nid in topo {
615 let grad = grads[&nid];
616 match self.nodes[nid.i()] {
617 Node::Detach(..) | Node::Leaf(..) | Node::Uniform(..) => {}
618 Node::Add(x, y) => {
619 if req_grad.contains(&x) {
620 self.retain(grad);
621 insert_or_add_grad(self, &mut grads, x, grad)?;
622 }
623 if req_grad.contains(&y) {
624 self.retain(grad);
625 insert_or_add_grad(self, &mut grads, y, grad)?;
626 }
627 }
628 Node::Sub(x, y) => {
629 if req_grad.contains(&x) {
630 self.retain(grad);
631 insert_or_add_grad(self, &mut grads, x, grad)?;
632 }
633 if req_grad.contains(&y) {
634 let grad = self.push(Node::Neg(grad))?;
635 insert_or_add_grad(self, &mut grads, y, grad)?;
636 }
637 }
638 Node::Mul(x, y) => {
639 if req_grad.contains(&x) {
640 let grad = self.push(Node::Mul(y, grad))?;
641 insert_or_add_grad(self, &mut grads, x, grad)?;
642 }
643 if req_grad.contains(&y) {
644 let grad = self.push(Node::Mul(x, grad))?;
645 insert_or_add_grad(self, &mut grads, y, grad)?;
646 }
647 }
648 Node::Div(x, y) => {
649 if req_grad.contains(&x) {
650 grads.insert(x, self.push(Node::Div(grad, y))?);
651 insert_or_add_grad(self, &mut grads, x, grad)?;
652 }
653 if req_grad.contains(&y) {
654 let two = match get_dtype(&self.nodes, y) {
656 DType::F32 => self.store([2f32]),
657 DType::F64 => self.store([2f64]),
658 DType::I32 => self.store([2i32]),
659 }?;
660 let two_e =
661 self.push(Node::Expand(two, get_shape(&self.nodes, y).clone()))?;
662 self.release(two)?;
663 let two_2 = self.push(Node::Pow(y, two_e))?;
664 self.release(two_e)?;
665 let temp = self.push(Node::Mul(x, grad))?;
666 let temp_neg = self.push(Node::Neg(temp))?;
667 self.release(temp)?;
668 let y_grad = self.push(Node::Div(temp_neg, two_2))?;
669 self.release(temp_neg)?;
670 self.release(two_2)?;
671 grads.insert(y, y_grad);
672 insert_or_add_grad(self, &mut grads, y, grad)?;
673 }
674 }
675 Node::Pow(x, y) => {
676 if req_grad.contains(&x) {
677 let one = match get_dtype(&self.nodes, y) {
679 DType::F32 => self.store([1f32]),
680 DType::F64 => self.store([1f64]),
681 DType::I32 => self.store([1i32]),
682 }?;
683 let one1 =
684 self.push(Node::Expand(one, get_shape(&self.nodes, y).clone()))?;
685 self.release(one)?;
686 let y_1 = self.push(Node::Sub(y, one1))?;
687 self.release(one1)?;
688 let pow_y_1 = self.push(Node::Pow(x, y_1))?;
689 self.release(y_1)?;
690 let y_mul = self.push(Node::Mul(y, pow_y_1))?;
691 self.release(pow_y_1)?;
692 let x_grad = self.push(Node::Mul(grad, y_mul))?;
693 self.release(y_mul)?;
694 insert_or_add_grad(self, &mut grads, x, x_grad)?;
695 }
696 if req_grad.contains(&y) {
697 let temp1 = self.push(Node::Ln(x))?;
699 let temp2 = self.push(Node::Mul(nid, temp1))?;
700 self.release(temp1)?;
701 let y_grad = self.push(Node::Mul(grad, temp2))?;
702 self.release(temp2)?;
703 insert_or_add_grad(self, &mut grads, y, y_grad)?;
704 }
705 }
706 Node::Cmplt(..) => {
707 panic!(
708 "Compare less than (cmplt, operator <) is not a differentiable operation."
709 );
710 }
711 Node::Where(x, y, z) => {
712 if req_grad.contains(&y) {
716 let zero = match get_dtype(&self.nodes, x) {
717 DType::F32 => self.store([0f32]),
718 DType::F64 => self.store([0f64]),
719 DType::I32 => self.store([0i32]),
720 }?;
721 let zeros =
722 self.push(Node::Expand(zero, get_shape(&self.nodes, x).clone()))?;
723 self.release(zero)?;
724 let y_grad = self.push(Node::Where(x, grad, zeros))?;
725 self.release(zeros)?;
726 insert_or_add_grad(self, &mut grads, y, y_grad)?;
727 }
728 if req_grad.contains(&z) {
729 let zero = match get_dtype(&self.nodes, x) {
730 DType::F32 => self.store([0f32]),
731 DType::F64 => self.store([0f64]),
732 DType::I32 => self.store([0i32]),
733 }?;
734 let zeros =
735 self.push(Node::Expand(zero, get_shape(&self.nodes, x).clone()))?;
736 self.release(zero)?;
737 let z_grad = self.push(Node::Where(x, zeros, grad))?;
738 self.release(zeros)?;
739 insert_or_add_grad(self, &mut grads, z, z_grad)?;
740 }
741 }
742 Node::ReLU(x) => {
743 let zero = match get_dtype(&self.nodes, x) {
744 DType::F32 => self.store([0f32]),
745 DType::F64 => self.store([0f64]),
746 DType::I32 => self.store([0i32]),
747 }?;
748 let zeros = self.push(Node::Expand(zero, get_shape(&self.nodes, x).clone()))?;
749 self.release(zero)?;
750 let zl = self.push(Node::Cmplt(zeros, x))?;
751 self.release(zeros)?;
752 let x_grad = self.push(Node::Mul(zl, grad))?;
753 self.release(zl)?;
754 insert_or_add_grad(self, &mut grads, x, x_grad)?;
755 }
756 Node::Exp(x) => {
757 let grad = self.push(Node::Mul(nid, grad))?;
758 insert_or_add_grad(self, &mut grads, x, grad)?;
759 }
760 Node::Ln(x) => {
761 let grad = self.push(Node::Div(grad, x))?;
762 insert_or_add_grad(self, &mut grads, x, grad)?;
763 }
764 Node::Sin(x) => {
765 let x_temp = self.push(Node::Cos(x))?;
766 let grad = self.push(Node::Mul(x_temp, grad))?;
767 self.release(x_temp)?;
768 insert_or_add_grad(self, &mut grads, x, grad)?;
769 }
770 Node::Cos(x) => {
771 let x_temp1 = self.push(Node::Sin(x))?;
772 let x_temp = self.push(Node::Neg(x_temp1))?;
773 self.release(x_temp1)?;
774 let grad = self.push(Node::Mul(x_temp, grad))?;
775 self.release(x_temp)?;
776 insert_or_add_grad(self, &mut grads, x, grad)?;
777 }
778 Node::Sqrt(x) => {
779 let x_shape = get_shape(&self.nodes, x).clone();
781 let two1 = match get_dtype(&self.nodes, x) {
782 DType::F32 => self.store([2f32]),
783 DType::F64 => self.store([2f64]),
784 DType::I32 => self.store([2i32]),
785 }?;
786 let two2 = self.push(Node::Expand(two1, x_shape))?;
787 self.release(two1)?;
788 let x_temp = self.push(Node::Mul(two2, nid))?;
789 self.release(two2)?;
790 let grad = self.push(Node::Div(grad, x_temp))?;
791 self.release(x_temp)?;
792 insert_or_add_grad(self, &mut grads, x, grad)?;
793 }
794 Node::Cast(x, _) => {
795 let grad = self.push(Node::Cast(grad, get_dtype(&self.nodes, x)))?;
796 insert_or_add_grad(self, &mut grads, x, grad)?;
797 }
798 Node::Neg(x) => {
799 let grad = self.push(Node::Neg(grad))?;
800 insert_or_add_grad(self, &mut grads, x, grad)?;
801 }
802 Node::Tanh(x) => {
803 let shape = get_shape(&self.nodes, x).clone();
805 let (two1, one1) = match get_dtype(&self.nodes, x) {
806 DType::F32 => (self.store([2f32])?, self.store([1f32])?),
807 DType::F64 => (self.store([2f64])?, self.store([1f64])?),
808 DType::I32 => (self.store([2i32])?, self.store([1i32])?),
809 };
810 let two2 = self.push(Node::Expand(two1, shape.clone()))?;
811 self.release(two1)?;
812 let two = self.push(Node::Pow(nid, two2))?;
813 self.release(two2)?;
814 let one2 = self.push(Node::Expand(one1, shape))?;
815 self.release(one1)?;
816 let one = self.push(Node::Sub(one2, two))?;
817 self.release(one2)?;
818 self.release(two)?;
819 let grad = self.push(Node::Mul(one, grad))?;
820 self.release(one)?;
821 insert_or_add_grad(self, &mut grads, x, grad)?;
822 }
823 Node::Reshape(x, ..) => {
824 let grad = self.push(Node::Reshape(grad, get_shape(&self.nodes, x).clone()))?;
825 insert_or_add_grad(self, &mut grads, x, grad)?;
826 }
827 Node::Expand(x, ref sh) => {
828 let org_shape = get_shape(&self.nodes, x).clone();
829 let axes = org_shape.expand_axes(sh);
830 let temp = self.push(Node::Sum(grad, axes, org_shape.clone()))?;
831 let grad = self.push(Node::Reshape(temp, org_shape))?;
832 self.release(temp)?;
833 insert_or_add_grad(self, &mut grads, x, grad)?;
834 }
835 Node::Permute(x, ref axes, _) => {
836 let shape = get_shape(&self.nodes, x);
837 let grad = self.push(Node::Permute(grad, axes.argsort(), shape.clone()))?;
838 insert_or_add_grad(self, &mut grads, x, grad)?;
839 }
840 Node::Pad(x, ref padding, _) => {
841 let sh = get_shape(&self.nodes, x).clone();
842 let inv_padding = padding.iter().map(|(lp, rp)| (-lp, -rp)).collect();
843 let grad = self.push(Node::Pad(grad, inv_padding, sh))?;
844 insert_or_add_grad(self, &mut grads, x, grad)?;
845 }
846 Node::Sum(x, ..) => {
847 let grad = self.push(Node::Expand(grad, get_shape(&self.nodes, x).clone()))?;
848 insert_or_add_grad(self, &mut grads, x, grad)?;
849 }
850 Node::Max(x, ..) => {
851 let x_shape = get_shape(&self.nodes, x).clone();
853 let z_temp = self.push(Node::Expand(nid, x_shape.clone()))?;
854 let cmp_t = self.push(Node::Cmplt(x, z_temp))?;
855 self.release(z_temp)?;
856 let one1 = match get_dtype(&self.nodes, x) {
857 DType::F32 => self.store([1f32]),
858 DType::F64 => self.store([1f64]),
859 DType::I32 => self.store([1i32]),
860 }?;
861 let one2 = self.push(Node::Expand(one1, x_shape))?;
862 self.release(one1)?;
863 let max_1s = self.push(Node::Sub(one2, cmp_t))?;
864 self.release(one2)?;
865 self.release(cmp_t)?;
866 let grad = self.push(Node::Mul(max_1s, grad))?;
867 self.release(max_1s)?;
868 insert_or_add_grad(self, &mut grads, x, grad)?;
869 }
870 }
871 }
872 let mut res = BTreeMap::new();
873 for (k, v) in grads.into_iter() {
874 if sources.contains(&k) {
875 res.insert(k, v);
876 } else {
877 self.release(v)?;
878 }
879 }
880 Ok(res)
881 }
882}