zyx_core/runtime.rs

use crate::error::ZyxError;
use crate::utils::{get_dtype, get_shape};
use crate::{
    dtype::DType,
    node::Node,
    scalar::Scalar,
    shape::Shape,
    tensor::{self, Id},
};
use alloc::collections::btree_map::Entry;
use alloc::{
    collections::{BTreeMap, BTreeSet},
    vec::Vec,
};
use core::ops::Range;
use rand::distributions::Uniform;

/// RuntimeBackend is a good plug-in point for backend developers.
/// Use Runtime::new(YourOwnStructThatImplementsRuntimeBackend::new()) to write your
/// own backend, which only needs to implement evaluation of the graph.
/// Used by the torch and native backends.
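///
/// A minimal sketch of a custom backend (hypothetical `MyBackend`, untested,
/// method bodies elided with `todo!()`):
///
/// ```ignore
/// struct MyBackend { /* device buffers keyed by Id */ }
///
/// impl RuntimeBackend for MyBackend {
///     fn is_evaluated(&self, x: Id) -> bool { todo!() }
///     fn is_free_id(&self, x: Id) -> bool { todo!() }
///     fn remove(&mut self, x: Id) -> Result<(), ZyxError> { todo!() }
///     fn store<T: Scalar, IT>(&mut self, x: Id, iter: IT) -> Result<(), ZyxError>
///     where
///         IT: IntoIterator<Item = T>,
///         IT::IntoIter: ExactSizeIterator,
///     { todo!() }
///     fn load<T: Scalar>(&mut self, x: Id, numel: usize) -> Result<Vec<T>, ZyxError> { todo!() }
///     fn evaluate(
///         &mut self,
///         rcs: BTreeMap<Id, u32>,
///         order: &[Id],
///         nodes: &[Node],
///     ) -> Result<(), ZyxError> { todo!() }
/// }
///
/// let runtime = Runtime::new(MyBackend { /* ... */ });
/// ```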
pub trait RuntimeBackend {
    /// Is tensor x evaluated?
    fn is_evaluated(&self, x: Id) -> bool;
    /// Check if there are no more buffers on id x
    fn is_free_id(&self, x: Id) -> bool;
    /// Delete all memory used by tensor x.
    fn remove(&mut self, x: Id) -> Result<(), ZyxError>;
    /// Store iterator into runtime backend
    fn store<T: Scalar, IT>(&mut self, x: Id, iter: IT) -> Result<(), ZyxError>
    where
        IT: IntoIterator<Item = T>,
        IT::IntoIter: ExactSizeIterator;
    /// Load evaluated tensor x.
    fn load<T: Scalar>(&mut self, x: Id, numel: usize) -> Result<Vec<T>, ZyxError>;
    /// Evaluate tensors to_eval with given graph of nodes and recommended
    /// order of evaluation.
    fn evaluate(
        &mut self,
        rcs: BTreeMap<Id, u32>,
        order: &[Id],
        nodes: &[Node],
    ) -> Result<(), ZyxError>;
}

/// Runtime with autograd engine.
/// This runtime uses the [Node] enum as the representation of tensors.
pub struct Runtime<R: RuntimeBackend> {
    rng: rand::rngs::SmallRng,
    rcs: Vec<u32>,
    nodes: Vec<Node>,
    unrealized_nodes_count: usize,
    runtime_backend: R,
}

impl<R: RuntimeBackend> Runtime<R> {
    /// Initialize new runtime.
    #[must_use]
    pub fn new(runtime_backend: R) -> Self {
        use rand::SeedableRng;
        Self {
            rng: rand::rngs::SmallRng::seed_from_u64(420_694_206_942_069),
            rcs: Vec::new(),
            nodes: Vec::new(),
            unrealized_nodes_count: 0,
            runtime_backend,
        }
    }

    /// Create tensor initialized from normal distribution.
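    ///
    /// A usage sketch (hypothetical, untested; assumes `Shape` can be built
    /// from an array of dimensions):
    ///
    /// ```ignore
    /// let x = runtime.randn([2, 3].into(), DType::F32)?;
    /// assert_eq!(runtime.dtype(x), DType::F32);
    /// ```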
    pub fn randn(&mut self, shape: Shape, dtype: DType) -> Result<Id, ZyxError> {
        use rand::Rng;
        let n = shape.numel();
        let mut rng = self.rng.clone();
        use rand::distributions::Standard;
        // NOTE: Standard samples floats uniformly from [0, 1); a true normal
        // distribution would need e.g. rand_distr::StandardNormal.
        let data1 = match dtype {
            DType::F32 => self.store::<f32, _>((0..n).map(move |_| rng.sample(Standard))),
            DType::F64 => self.store::<f64, _>((0..n).map(move |_| rng.sample(Standard))),
            DType::I32 => self.store::<i32, _>((0..n).map(move |_| rng.sample(Standard))),
        }?;
        let data = self.push(Node::Reshape(data1, shape))?;
        self.release(data1)?;
        // advance the state of self.rng so that the next call samples new values
        for _ in 0..n {
            self.rng.sample::<f32, _>(Standard);
        }
        Ok(data)
    }

    /// Create uniform tensor from range low..high
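    ///
    /// A usage sketch (hypothetical, untested):
    ///
    /// ```ignore
    /// // 2x2 tensor of f32 values drawn uniformly from 0.0..10.0
    /// let x = runtime.uniform([2, 2].into(), 0f32..10f32)?;
    /// ```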
    pub fn uniform<T: Scalar>(&mut self, shape: Shape, range: Range<T>) -> Result<Id, ZyxError> {
        // TODO for f32 in range 0.0..1.0 switch to Node::UniformF32 for better performance
        use rand::Rng;
        let n = shape.numel();
        let mut rng = self.rng.clone();
        use rand::distributions::Standard;
        let data1 = match T::dtype() {
            DType::F32 => self.store((0..n).map(move |_| {
                rng.sample(Uniform::new(
                    range.start.clone().into_f32(),
                    range.end.clone().into_f32(),
                ))
            })),
            DType::F64 => self.store((0..n).map(move |_| {
                rng.sample(Uniform::new(
                    range.start.clone().into_f64(),
                    range.end.clone().into_f64(),
                ))
            })),
            DType::I32 => self.store((0..n).map(move |_| {
                rng.sample(Uniform::new(
                    range.start.clone().into_i32(),
                    range.end.clone().into_i32(),
                ))
            })),
        }?;
        let data = self.push(Node::Reshape(data1, shape))?;
        self.release(data1)?;
        // advance the state of self.rng so that the next call samples new values
        for _ in 0..n {
            self.rng.sample::<f32, _>(Standard);
        }
        Ok(data)
    }

    /// Get shape of tensor x
    #[must_use]
    pub fn shape(&self, x: Id) -> &Shape {
        get_shape(self.nodes.as_slice(), x)
    }

    /// Get dtype of tensor x
    #[must_use]
    pub fn dtype(&self, x: Id) -> DType {
        get_dtype(self.nodes.as_slice(), x)
    }

    /// Load tensor x, evaluating it first if needed.
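    ///
    /// A usage sketch (hypothetical, untested):
    ///
    /// ```ignore
    /// let x = runtime.store([1f32, 2., 3.])?;
    /// let data: Vec<f32> = runtime.load(x)?; // forces evaluation if x is not yet realized
    /// assert_eq!(data, vec![1., 2., 3.]);
    /// ```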
    pub fn load<T: Scalar>(&mut self, x: Id) -> Result<Vec<T>, ZyxError> {
        if !self.runtime_backend.is_evaluated(x) {
            self.evaluate(BTreeSet::from([x]))?;
        }
        let numel = get_shape(self.nodes.as_slice(), x).numel();
        //std::println!("Reading buffer with {numel} elements.");
        self.runtime_backend.load(x, numel)
    }

    /// Store iterator into runtime as tensor
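    ///
    /// A usage sketch (hypothetical, untested); any `ExactSizeIterator` works,
    /// not just arrays:
    ///
    /// ```ignore
    /// let x = runtime.store((0..6).map(|i| i as f32))?;
    /// assert_eq!(runtime.shape(x).numel(), 6);
    /// ```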
    pub fn store<T: Scalar, IT>(&mut self, iter: IT) -> Result<Id, ZyxError>
    where
        IT: IntoIterator<Item = T>,
        IT::IntoIter: ExactSizeIterator,
    {
        // TODO optimizations for scalars and very small tensors, by using Node::Scalar(...) or Node::SmallTensor(..)
        // With those optimizations, these can be compiled into kernels for better performance.
        let iter = iter.into_iter();
        let len = iter.len();
        let node = Node::Leaf(len.into(), T::dtype());
        // Reuse a dead id (rc == 0 and no backend buffer) if one exists,
        // otherwise append a new one.
        let id = if let Some(i) = self
            .rcs
            .iter()
            .enumerate()
            .position(|(i, rc)| *rc == 0 && self.runtime_backend.is_free_id(tensor::id(i)))
        {
            let id = tensor::id(i);
            self.rcs[i] = 1;
            self.nodes[i] = node;
            id
        } else {
            let id = tensor::id(self.rcs.len());
            self.rcs.push(1);
            self.nodes.push(node);
            id
        };
        //if id.i() == 1 { panic!("break") }
        self.runtime_backend.store(id, iter)?;
        //std::println!("Storing {id}, {:?}", self.rcs);
        Ok(id)
    }

    /// Push a new Node into the graph, creating a new tensor.
    /// This function does ZERO verification that the node is correct, but it optimizes
    /// out useless operations (like reshaping to the same shape).
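    ///
    /// A usage sketch (hypothetical, untested): `push` retains the node's
    /// parameters, and the caller releases ids it no longer needs.
    ///
    /// ```ignore
    /// let y = runtime.push(Node::Exp(x))?; // y = exp(x), x is retained by the graph
    /// runtime.release(x)?;                 // drop the caller's own reference to x
    /// ```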
    pub fn push(&mut self, node: Node) -> Result<Id, ZyxError> {
        //std::println!("Pushing {node:?}, len: {}, rcs: {:?}", self.nodes.len(), self.rcs);
        // get rid of noops :)
        match node {
            Node::Reshape(x, ref shape) | Node::Expand(x, ref shape) => {
                if shape == self.shape(x) {
                    self.retain(x);
                    return Ok(x);
                }
            }
            Node::Sum(x, ref axes, ..) | Node::Max(x, ref axes, ..) => {
                if axes.len() == 0 {
                    self.retain(x);
                    return Ok(x);
                }
            }
            _ => {}
        }
        for nid in node.parameters() {
            self.retain(nid);
        }
        let id = if let Some(i) = self
            .rcs
            .iter()
            .enumerate()
            .position(|(i, rc)| *rc == 0 && self.runtime_backend.is_free_id(tensor::id(i)))
        {
            let id = tensor::id(i);
            self.rcs[i] = 1;
            self.nodes[i] = node;
            id
        } else {
            let id = tensor::id(self.rcs.len());
            self.rcs.push(1);
            if self.rcs.len() > 4_000_000_000 {
                panic!("Maximum number of tensors has been reached. Zyx supports up to 4 billion tensors. \
                Please check your code for memory leaks. If you really need to use more tensors, please raise an issue: https://github.com/zk4x/zyx");
            }
            self.nodes.push(node);
            id
        };
        //std::println!("Assigned id: {id}, rcs {:?}", self.rcs);
        self.unrealized_nodes_count += 1;
        // This regulates caching; once more than 10000 unrealized nodes accumulate,
        // the graph is flushed to the backend in one batch.
        if self.unrealized_nodes_count > 10000 {
            self.evaluate([id].into_iter().collect::<BTreeSet<Id>>())?;
            //std::println!("Num tensors: {}", self.nodes.len());
        }
        Ok(id)
    }

    /// Decrease reference count of x. If x's reference count reaches zero, this function will delete
    /// x and release all of its predecessors in the graph.
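    ///
    /// A sketch of the reference-counting contract (hypothetical, untested):
    ///
    /// ```ignore
    /// let x = runtime.store([1f32])?; // rc(x) == 1
    /// runtime.retain(x);              // rc(x) == 2
    /// runtime.release(x)?;            // rc(x) == 1
    /// runtime.release(x)?;            // rc(x) == 0, x and unused predecessors are freed
    /// ```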
    pub fn release(&mut self, x: Id) -> Result<(), ZyxError> {
        //std::println!("Releasing {x}");
        let mut params = Vec::with_capacity(10);
        params.push(x);
        while let Some(x) = params.pop() {
            self.rcs[x.i()] -= 1;
            //std::println!("Releasing {x} {:?}", self.rcs);
            if self.rcs[x.i()] == 0 {
                params.extend(self.nodes[x.i()].parameters());
                self.runtime_backend.remove(x)?;
                // We count only non-leaf nodes
                if !matches!(self.nodes[x.i()], Node::Leaf(..) | Node::Uniform(..)) {
                    self.unrealized_nodes_count -= 1;
                }
            }
        }
        //std::println!("After released {x} rcs {:?}", self.rcs);
        Ok(())
    }

    /// Increase reference count of tensor x.
    pub fn retain(&mut self, x: Id) {
        //std::println!("Retaining {x}, rcs: {:?}", self.rcs);
        //panic!();
        debug_assert!(
            self.rcs[x.i()] < u32::MAX,
            "Reference count of tensor {x} has been exceeded. \
        This is a zyx bug, please report it at: https://github.com/zk4x/zyx"
        );
        self.rcs[x.i()] += 1;
    }

    /// Debug print all nodes
    pub fn debug_graph(&self) {
        for (id, node) in self.nodes.iter().enumerate() {
            std::println!("{id:>5} x{:>3} -> {node:?}", self.rcs[id]);
        }
    }

    /// Evaluate specified nodes.
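    ///
    /// A usage sketch (hypothetical, untested); note that [Runtime::load]
    /// calls this automatically when needed:
    ///
    /// ```ignore
    /// let z = runtime.push(Node::Add(x, y))?;
    /// runtime.evaluate(BTreeSet::from([z]))?; // realize z and whatever it depends on
    /// ```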
    pub fn evaluate(&mut self, nodes: BTreeSet<Id>) -> Result<(), ZyxError> {
        // This whole function is needed so that we can batch ops together.
        // This alleviates the cost of keeping intermediate buffers for backpropagation,
        // as this function runs independently from backpropagation, and if some tensors
        // are dropped after backpropagation, this function optimizes those away.
        // It is essentially the difference between immediate and lazy execution with caching:
        // we simply wait to get more information about the graph structure before
        // we push it to the device.

        //std::println!("Evaluating {nodes:?}, rcs: {:?}", self.rcs);

        // TODO in order to be more efficient, we can optimize the graph
        // by reordering nodes and removing unnecessary nodes
        // TODO should we decrease refcount of some other nodes to drop them from memory?
        // This memory <=> performance tradeoff should be decided by the user, with some setting.
        // TODO simplify this function if possible

        // Creation of graph (DFS) runs in linear time, max once per node in self.nodes.
        // Make a list of visited nodes and their reference counts.
        let mut temp_rcs: BTreeMap<Id, u32> = BTreeMap::new();
        let mut params: Vec<Id> = nodes.iter().copied().collect();
        params.reserve(100);
        while let Some(nid) = params.pop() {
            //std::println!("{nid} is evaluated: {}", self.runtime_backend.is_evaluated(nid));
            temp_rcs
                .entry(nid)
                .and_modify(|rc| *rc += 1)
                .or_insert_with(|| {
                    if !self.runtime_backend.is_evaluated(nid) {
                        params.extend(self.nodes[nid.i()].parameters());
                    }
                    1
                });
        }
        // Order them using the temp_rcs reference counts: a node is appended to order
        // only once it has been visited as many times as it has consumers, so after
        // the final reverse every node precedes all of its users (topological order).
        let mut order = Vec::new();
        let mut rcs: BTreeMap<Id, u32> = BTreeMap::new();
        let mut params: Vec<Id> = nodes.iter().copied().collect();
        params.reserve(100);
        while let Some(nid) = params.pop() {
            if let Some(temp_rc) = temp_rcs.get(&nid) {
                let rc = rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
                if *temp_rc == *rc {
                    order.push(nid);
                    params.extend(self.nodes[nid.i()].parameters());
                }
            }
        }
        order.reverse();

        // Create another DFS that goes all the way to Node::Leaf and adds branches to drop_nodes
        // if needed.
        let mut drop_nodes = BTreeSet::new();
        let mut temp_rcs: BTreeMap<Id, u32> = BTreeMap::new();
        let mut params: Vec<Id> = nodes.iter().copied().collect();
        params.reserve(100);
        while let Some(nid) = params.pop() {
            if !matches!(self.nodes[nid.i()], Node::Leaf(..)) {
                temp_rcs
                    .entry(nid)
                    .and_modify(|rc| *rc += 1)
                    .or_insert_with(|| {
                        params.extend(self.nodes[nid.i()].parameters());
                        1
                    });
            } else {
                let rc = temp_rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
                // This does not account for the possible existence of user or global graph
                // references that are not directly on leafs.
                if *rc == self.rcs[nid.i()] && !nodes.contains(&nid) {
                    drop_nodes.insert(nid);
                }
            }
        }
        let mut new_order = Vec::new();
        let mut new_rcs: BTreeMap<Id, u32> = BTreeMap::new();
        let mut params: Vec<Id> = nodes.iter().copied().collect();
        params.reserve(100);
        while let Some(nid) = params.pop() {
            if let Some(temp_rc) = temp_rcs.get(&nid) {
                let rc = new_rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert(1);
                if *temp_rc == *rc {
                    new_order.push(nid);
                    params.extend(self.nodes[nid.i()].parameters());
                }
            }
        }
        new_order.reverse();

        // This must go over the graph from the previous loop!
        let mut new_leafs = BTreeSet::new();
        for nid in &new_order {
            if new_rcs[nid] == self.rcs[nid.i()] && !nodes.contains(nid) {
                for p in self.nodes[nid.i()].parameters() {
                    if drop_nodes.contains(&p) {
                        drop_nodes.insert(*nid);
                    }
                }
            } else if self.nodes[nid.i()]
                .parameters()
                .any(|p| drop_nodes.contains(&p))
            {
                new_leafs.insert(*nid);
            }
        }

385        //std::println!("RCS: {:?}", self.rcs);
386        // Dealing with Detach nodes
387        // TODO this is a waste, optimize this.
388        let mut user_rc = self.rcs.clone();
389        for (i, node) in self.nodes.iter().enumerate() {
390            if self.rcs[i] != 0 {
391                //std::println!("{i}: {node:?}");
392                for p in node.parameters() {
393                    user_rc[p.i()] -= 1;
394                }
395            }
396        }
        for nid in &new_order {
            match &self.nodes[nid.i()] {
                Node::Cmplt(x, y) => {
                    let mut detach_rc = BTreeMap::new();
                    new_leafs.insert(*x);
                    new_leafs.insert(*y);
                    let mut params = Vec::with_capacity(10);
                    params.push(*x);
                    params.push(*y);
                    while let Some(x) = params.pop() {
                        let rc = detach_rc.entry(x).and_modify(|rc| *rc += 1).or_insert(1);
                        if *rc == self.rcs[x.i()] {
                            drop_nodes.insert(x);
                        }
                    }
                }
                Node::Detach(x) => {
                    let mut detach_rc = BTreeMap::new();
                    new_leafs.insert(*x);
                    let mut params = Vec::with_capacity(10);
                    params.push(*x);
                    while let Some(x) = params.pop() {
                        let rc = detach_rc.entry(x).and_modify(|rc| *rc += 1).or_insert(1);
                        if *rc == self.rcs[x.i()] {
                            drop_nodes.insert(x);
                        }
                    }
                }
                _ => {}
            }
        }

        // Increase rcs for nodes that we want to keep evaluated.
        // First it MUST be all new_leafs.
        for nid in &new_leafs {
            // Some leafs are already evaluated, so they are not in rcs
            if let Some(rc) = rcs.get_mut(nid) {
                *rc += 1;
            }
        }

        /*std::println!("Non-evaluated nodes count before: {}", self.unrealized_nodes_count);
        //self.debug_graph();
        for i in &order {
            std::println!("{i} x {} -> {:?}", rcs[i], self.nodes[i.i()]);
        }
        std::println!("Drop nodes: {drop_nodes:?}");
        std::println!("New leafs {new_leafs:?}");
        std::println!("Order: {order:?}");
        std::println!();*/

        self.runtime_backend.evaluate(rcs, &order, &self.nodes)?;

        for nid in new_leafs {
            self.unrealized_nodes_count -= 1;
            self.nodes[nid.i()] = Node::Leaf(
                get_shape(&self.nodes, nid).clone(),
                get_dtype(&self.nodes, nid),
            );
        }
        for nid in drop_nodes {
            self.rcs[nid.i()] = 0;
            self.runtime_backend.remove(nid)?;
            if !matches!(self.nodes[nid.i()], Node::Leaf(..) | Node::Uniform(..)) {
                self.unrealized_nodes_count -= 1;
            }
        }
        //std::println!("Non-evaluated nodes count after: {}", self.unrealized_nodes_count);

        // TODO fix this
        /*if self.backprop_nodes_count > 2000000000 {
            panic!("Maximum number of tensors in gradient tape has been reached. Zyx supports up to 2 billion tensors on the tape.\
            This error can be raised for example in RNNs. Please detach gradient tape (Tensor::detach) from some tensors.\
            If you really need to use more tensors, please raise an issue: https://github.com/zk4x/zyx");
        }*/
        Ok(())
    }

    /// Plot the graph between given nodes in graphviz dot format.
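    ///
    /// A usage sketch (hypothetical, untested; assumes `std` is available for file IO):
    ///
    /// ```ignore
    /// let dot = runtime.plot_graph_dot(&[x, y]);
    /// std::fs::write("graph.dot", dot)?; // render with: dot -Tsvg graph.dot
    /// ```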
    #[must_use]
    pub fn plot_graph_dot(&self, ids: &[Id]) -> alloc::string::String {
        // Make a list of visited nodes and their reference counts.
        let mut params: Vec<Id> = ids.into();
        let mut rcs: BTreeMap<Id, u8> = BTreeMap::new();
        while let Some(nid) = params.pop() {
            rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert_with(|| {
                params.extend(self.nodes[nid.i()].parameters());
                1
            });
        }
        // Order them using rcs reference counts
        let mut order = Vec::new();
        let mut internal_rcs: BTreeMap<Id, u8> = BTreeMap::new();
        let mut params: Vec<Id> = ids.into();
        while let Some(nid) = params.pop() {
            if rcs[&nid]
                == *internal_rcs
                    .entry(nid)
                    .and_modify(|rc| *rc += 1)
                    .or_insert(1)
            {
                order.push(nid);
                if rcs.contains_key(&nid) {
                    params.extend(self.nodes[nid.i()].parameters());
                }
            }
        }
        // Build the topo set of nodes to plot: ids plus every node that
        // (transitively) takes one of them as input.
        let mut topo: BTreeSet<Id> = ids.iter().copied().collect();
        for nid in order.into_iter().rev() {
            for p in self.nodes[nid.i()].parameters() {
                if topo.contains(&p) {
                    topo.insert(nid);
                }
            }
        }

        crate::utils::plot_graph_dot(&topo, &self.nodes, &self.rcs)
    }

    /// Common autograd engine, currently used by all backends.
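    ///
    /// A usage sketch (hypothetical, untested): gradients of `loss` w.r.t. `w` and `b`:
    ///
    /// ```ignore
    /// let sources = BTreeSet::from([w, b]);
    /// let grads = runtime.backward(loss, &sources)?; // maps source id -> gradient id
    /// let w_grad = grads[&w]; // the caller must eventually release these gradient ids
    /// ```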
    pub fn backward(
        &mut self,
        x: Id,
        sources: &BTreeSet<Id>,
    ) -> Result<BTreeMap<Id, Id>, ZyxError> {
        fn build_topo(x: Id, sources: &BTreeSet<Id>, nodes: &[Node]) -> Vec<Id> {
            // Make a list of visited nodes and their reference counts.
            let mut params: Vec<Id> = alloc::vec![x];
            let mut rcs: BTreeMap<Id, u8> = BTreeMap::new();
            while let Some(nid) = params.pop() {
                rcs.entry(nid).and_modify(|rc| *rc += 1).or_insert_with(|| {
                    if !sources.contains(&nid)
                        && !matches!(nodes[nid.i()], Node::Detach(..) | Node::Cmplt(..))
                    {
                        params.extend(nodes[nid.i()].parameters());
                    }
                    1
                });
            }
            // Order them using rcs reference counts
            let mut order = Vec::new();
            let mut internal_rcs: BTreeMap<Id, u8> = BTreeMap::new();
            let mut params: Vec<Id> = alloc::vec![x];
            while let Some(nid) = params.pop() {
                if let Some(rc) = rcs.get(&nid) {
                    if *rc
                        == *internal_rcs
                            .entry(nid)
                            .and_modify(|rc| *rc += 1)
                            .or_insert(1)
                    {
                        order.push(nid);
                        params.extend(nodes[nid.i()].parameters());
                    }
                }
            }
            // Build topo. This ordering ensures that a grad is not used in backprop
            // before all of its parents have added their contribution to it.
            let mut topo = Vec::new();
            let mut req_grad = sources.clone();
            let mut visited = BTreeSet::new();
            for nid in order.into_iter().rev() {
                for p in nodes[nid.i()].parameters() {
                    if req_grad.contains(&p) && visited.insert(nid) {
                        req_grad.insert(nid);
                        topo.push(nid);
                    }
                }
            }
            topo.reverse();
            topo
        }

        let topo = build_topo(x, sources, &self.nodes);
        //std::println!("Topo: {topo:?}");

        let req_grad: BTreeSet<Id> = topo
            .iter()
            .copied()
            .chain(sources.iter().copied())
            .collect();
        // Node -> Grad
        let mut grads: BTreeMap<Id, Id> = BTreeMap::new();
        // Initial gradient of ones
        let grad1 = match get_dtype(&self.nodes, x) {
            DType::F32 => self.store([1f32]),
            DType::F64 => self.store([1f64]),
            DType::I32 => self.store([1i32]),
        }?;
        let sh = get_shape(&self.nodes, x).clone();
        grads.insert(x, self.push(Node::Expand(grad1, sh))?);
        self.release(grad1)?;
        //std::println!("{:?}", self.nodes.last().unwrap());

        fn insert_or_add_grad<B: RuntimeBackend>(
            r: &mut Runtime<B>,
            grads: &mut BTreeMap<Id, Id>,
            x: Id,
            grad: Id,
        ) -> Result<(), ZyxError> {
            match grads.entry(x) {
                Entry::Vacant(e) => {
                    e.insert(grad);
                }
                Entry::Occupied(e) => {
                    let (k, prev_grad) = e.remove_entry();
                    grads.insert(k, r.push(Node::Add(prev_grad, grad))?);
                    r.release(prev_grad)?;
                    r.release(grad)?;
                }
            }
            Ok(())
        }

        // backpropagate
        // TODO this is not very clean code. Can we make it cleaner?
        for nid in topo {
            let grad = grads[&nid];
            match self.nodes[nid.i()] {
                Node::Detach(..) | Node::Leaf(..) | Node::Uniform(..) => {}
                Node::Add(x, y) => {
                    if req_grad.contains(&x) {
                        self.retain(grad);
                        insert_or_add_grad(self, &mut grads, x, grad)?;
                    }
                    if req_grad.contains(&y) {
                        self.retain(grad);
                        insert_or_add_grad(self, &mut grads, y, grad)?;
                    }
                }
                Node::Sub(x, y) => {
                    if req_grad.contains(&x) {
                        self.retain(grad);
                        insert_or_add_grad(self, &mut grads, x, grad)?;
                    }
                    if req_grad.contains(&y) {
                        let grad = self.push(Node::Neg(grad))?;
                        insert_or_add_grad(self, &mut grads, y, grad)?;
                    }
                }
                Node::Mul(x, y) => {
                    if req_grad.contains(&x) {
                        let grad = self.push(Node::Mul(y, grad))?;
                        insert_or_add_grad(self, &mut grads, x, grad)?;
                    }
                    if req_grad.contains(&y) {
                        let grad = self.push(Node::Mul(x, grad))?;
                        insert_or_add_grad(self, &mut grads, y, grad)?;
                    }
                }
                Node::Div(x, y) => {
                    if req_grad.contains(&x) {
                        // x_grad = grad/y
                        let x_grad = self.push(Node::Div(grad, y))?;
                        insert_or_add_grad(self, &mut grads, x, x_grad)?;
                    }
                    if req_grad.contains(&y) {
                        // y_grad = -grad*x/(y^2)
                        let two = match get_dtype(&self.nodes, y) {
                            DType::F32 => self.store([2f32]),
                            DType::F64 => self.store([2f64]),
                            DType::I32 => self.store([2i32]),
                        }?;
                        let two_e =
                            self.push(Node::Expand(two, get_shape(&self.nodes, y).clone()))?;
                        self.release(two)?;
                        let y_squared = self.push(Node::Pow(y, two_e))?;
                        self.release(two_e)?;
                        let temp = self.push(Node::Mul(x, grad))?;
                        let temp_neg = self.push(Node::Neg(temp))?;
                        self.release(temp)?;
                        let y_grad = self.push(Node::Div(temp_neg, y_squared))?;
                        self.release(temp_neg)?;
                        self.release(y_squared)?;
                        insert_or_add_grad(self, &mut grads, y, y_grad)?;
                    }
                }
                Node::Pow(x, y) => {
                    if req_grad.contains(&x) {
                        // grad * y * x.pow(y-1)
                        let one = match get_dtype(&self.nodes, y) {
                            DType::F32 => self.store([1f32]),
                            DType::F64 => self.store([1f64]),
                            DType::I32 => self.store([1i32]),
                        }?;
                        let one1 =
                            self.push(Node::Expand(one, get_shape(&self.nodes, y).clone()))?;
                        self.release(one)?;
                        let y_1 = self.push(Node::Sub(y, one1))?;
                        self.release(one1)?;
                        let pow_y_1 = self.push(Node::Pow(x, y_1))?;
                        self.release(y_1)?;
                        let y_mul = self.push(Node::Mul(y, pow_y_1))?;
                        self.release(pow_y_1)?;
                        let x_grad = self.push(Node::Mul(grad, y_mul))?;
                        self.release(y_mul)?;
                        insert_or_add_grad(self, &mut grads, x, x_grad)?;
                    }
                    if req_grad.contains(&y) {
                        // grad * x.pow(y) * ln(x)
                        let temp1 = self.push(Node::Ln(x))?;
                        let temp2 = self.push(Node::Mul(nid, temp1))?;
                        self.release(temp1)?;
                        let y_grad = self.push(Node::Mul(grad, temp2))?;
                        self.release(temp2)?;
                        insert_or_add_grad(self, &mut grads, y, y_grad)?;
                    }
                }
                Node::Cmplt(..) => {
                    panic!(
                        "Compare less than (cmplt, operator <) is not a differentiable operation."
                    );
                }
                Node::Where(x, y, z) => {
                    //return None, \
                    //self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
                    //self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
                    if req_grad.contains(&y) {
                        let zero = match get_dtype(&self.nodes, x) {
                            DType::F32 => self.store([0f32]),
                            DType::F64 => self.store([0f64]),
                            DType::I32 => self.store([0i32]),
                        }?;
                        let zeros =
                            self.push(Node::Expand(zero, get_shape(&self.nodes, x).clone()))?;
                        self.release(zero)?;
                        let y_grad = self.push(Node::Where(x, grad, zeros))?;
                        self.release(zeros)?;
                        insert_or_add_grad(self, &mut grads, y, y_grad)?;
                    }
                    if req_grad.contains(&z) {
                        let zero = match get_dtype(&self.nodes, x) {
                            DType::F32 => self.store([0f32]),
                            DType::F64 => self.store([0f64]),
                            DType::I32 => self.store([0i32]),
                        }?;
                        let zeros =
                            self.push(Node::Expand(zero, get_shape(&self.nodes, x).clone()))?;
                        self.release(zero)?;
                        let z_grad = self.push(Node::Where(x, zeros, grad))?;
                        self.release(zeros)?;
                        insert_or_add_grad(self, &mut grads, z, z_grad)?;
                    }
                }
                Node::ReLU(x) => {
                    let zero = match get_dtype(&self.nodes, x) {
                        DType::F32 => self.store([0f32]),
                        DType::F64 => self.store([0f64]),
                        DType::I32 => self.store([0i32]),
                    }?;
                    let zeros = self.push(Node::Expand(zero, get_shape(&self.nodes, x).clone()))?;
                    self.release(zero)?;
                    let zl = self.push(Node::Cmplt(zeros, x))?;
                    self.release(zeros)?;
                    let x_grad = self.push(Node::Mul(zl, grad))?;
                    self.release(zl)?;
                    insert_or_add_grad(self, &mut grads, x, x_grad)?;
                }
                Node::Exp(x) => {
                    let grad = self.push(Node::Mul(nid, grad))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Ln(x) => {
                    let grad = self.push(Node::Div(grad, x))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Sin(x) => {
                    let x_temp = self.push(Node::Cos(x))?;
                    let grad = self.push(Node::Mul(x_temp, grad))?;
                    self.release(x_temp)?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Cos(x) => {
                    let x_temp1 = self.push(Node::Sin(x))?;
                    let x_temp = self.push(Node::Neg(x_temp1))?;
                    self.release(x_temp1)?;
                    let grad = self.push(Node::Mul(x_temp, grad))?;
                    self.release(x_temp)?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Sqrt(x) => {
                    // x_grad = grad/(2*sqrt(x))
                    let x_shape = get_shape(&self.nodes, x).clone();
                    let two1 = match get_dtype(&self.nodes, x) {
                        DType::F32 => self.store([2f32]),
                        DType::F64 => self.store([2f64]),
                        DType::I32 => self.store([2i32]),
                    }?;
                    let two2 = self.push(Node::Expand(two1, x_shape))?;
                    self.release(two1)?;
                    let x_temp = self.push(Node::Mul(two2, nid))?;
                    self.release(two2)?;
                    let grad = self.push(Node::Div(grad, x_temp))?;
                    self.release(x_temp)?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Cast(x, _) => {
                    let grad = self.push(Node::Cast(grad, get_dtype(&self.nodes, x)))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Neg(x) => {
                    let grad = self.push(Node::Neg(grad))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Tanh(x) => {
                    // 1 - tanh^2(x)
                    let shape = get_shape(&self.nodes, x).clone();
                    let (two1, one1) = match get_dtype(&self.nodes, x) {
                        DType::F32 => (self.store([2f32])?, self.store([1f32])?),
                        DType::F64 => (self.store([2f64])?, self.store([1f64])?),
                        DType::I32 => (self.store([2i32])?, self.store([1i32])?),
                    };
                    let two2 = self.push(Node::Expand(two1, shape.clone()))?;
                    self.release(two1)?;
                    let two = self.push(Node::Pow(nid, two2))?;
                    self.release(two2)?;
                    let one2 = self.push(Node::Expand(one1, shape))?;
                    self.release(one1)?;
                    let one = self.push(Node::Sub(one2, two))?;
                    self.release(one2)?;
                    self.release(two)?;
                    let grad = self.push(Node::Mul(one, grad))?;
                    self.release(one)?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Reshape(x, ..) => {
                    let grad = self.push(Node::Reshape(grad, get_shape(&self.nodes, x).clone()))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Expand(x, ref sh) => {
                    let org_shape = get_shape(&self.nodes, x).clone();
                    let axes = org_shape.expand_axes(sh);
                    let temp = self.push(Node::Sum(grad, axes, org_shape.clone()))?;
                    let grad = self.push(Node::Reshape(temp, org_shape))?;
                    self.release(temp)?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Permute(x, ref axes, _) => {
                    let shape = get_shape(&self.nodes, x);
                    let grad = self.push(Node::Permute(grad, axes.argsort(), shape.clone()))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Pad(x, ref padding, _) => {
                    let sh = get_shape(&self.nodes, x).clone();
                    let inv_padding = padding.iter().map(|(lp, rp)| (-lp, -rp)).collect();
                    let grad = self.push(Node::Pad(grad, inv_padding, sh))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Sum(x, ..) => {
                    let grad = self.push(Node::Expand(grad, get_shape(&self.nodes, x).clone()))?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
                Node::Max(x, ..) => {
                    // x_grad = (1 - (x < z.expand(x.shape()))) * grad
                    let x_shape = get_shape(&self.nodes, x).clone();
                    let z_temp = self.push(Node::Expand(nid, x_shape.clone()))?;
                    let cmp_t = self.push(Node::Cmplt(x, z_temp))?;
                    self.release(z_temp)?;
                    let one1 = match get_dtype(&self.nodes, x) {
                        DType::F32 => self.store([1f32]),
                        DType::F64 => self.store([1f64]),
                        DType::I32 => self.store([1i32]),
                    }?;
                    let one2 = self.push(Node::Expand(one1, x_shape))?;
                    self.release(one1)?;
                    let max_1s = self.push(Node::Sub(one2, cmp_t))?;
                    self.release(one2)?;
                    self.release(cmp_t)?;
                    let grad = self.push(Node::Mul(max_1s, grad))?;
                    self.release(max_1s)?;
                    insert_or_add_grad(self, &mut grads, x, grad)?;
                }
            }
        }
        let mut res = BTreeMap::new();
        for (k, v) in grads.into_iter() {
            if sources.contains(&k) {
                res.insert(k, v);
            } else {
                self.release(v)?;
            }
        }
        Ok(res)
    }
}