tract_core/model/
order.rs

1//! Evaluation order for nodes.
2use crate::internal::*;
3use bit_set::BitSet;
4use std::collections::VecDeque;
5use std::fmt::{Debug, Display};
6use tract_itertools::Itertools;
7
8/// Find an evaluation order for a model, using its default inputs and outputs
9/// as boundaries.
10pub fn eval_order<F, O>(model: &super::Graph<F, O>) -> TractResult<Vec<usize>>
11where
12    F: Fact + Clone + 'static,
13    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
14{
15    let inputs = model.input_outlets()?.iter().map(|n| n.node).collect_vec();
16    let targets = model.output_outlets()?.iter().map(|n| n.node).collect_vec();
17    eval_order_for_nodes(model.nodes(), &inputs, &targets, &[])
18}
19
/// Find a working evaluation order for a list of nodes.
/// This algorithm starts from the outputs, so it will only compute what is necessary.
///
/// * `nodes` - all nodes of the graph, indexed by their id.
/// * `model_inputs` - node ids treated as boundaries: they are scheduled
///   without visiting their precursors.
/// * `model_outputs` - node ids the order must cover.
/// * `more_dependencies` - extra `(node, dependency)` edges beyond the wired
///   inputs.
///
/// Returns the node ids in a valid execution order, or an error if a
/// dependency cycle is found.
pub fn eval_order_for_nodes<F, O>(
    nodes: &[Node<F, O>],
    model_inputs: &[usize],
    model_outputs: &[usize],
    more_dependencies: &[(usize, usize)],
) -> TractResult<Vec<usize>>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    // Nodes already emitted into `order` (across all targets).
    let mut done = BitSet::with_capacity(nodes.len());
    let mut order: Vec<usize> = vec![];
    for &model_target in model_outputs {
        if done.contains(model_target) {
            continue;
        }
        // Explicit DFS stack of (node, index of the next dependency to visit).
        let mut current_stack: Vec<(usize, usize)> = vec![(model_target, 0)];
        // Nodes currently awaiting their dependencies on the stack; a
        // precursor already in this set means we walked into a cycle.
        let mut pending = BitSet::with_capacity(nodes.len());
        while let Some((current_node, current_input)) = current_stack.pop() {
            let deps_from_inputs = nodes[current_node].inputs.len();
            // Total dependency count: wired inputs plus extra edges targeting
            // this node.
            let all_deps_count =
                deps_from_inputs + more_dependencies.iter().filter(|a| a.0 == current_node).count();
            // Emit the node once all its dependencies are satisfied — or
            // immediately if it is a declared model input (boundary).
            if model_inputs.contains(&current_node) || current_input == all_deps_count {
                order.push(current_node);
                done.insert(current_node);
                pending.remove(current_node);
            } else {
                // Select the `current_input`-th dependency. Enumeration order:
                // wired inputs that have precursors of their own first, then
                // the extra dependencies, then dependency-free inputs (e.g.
                // consts), so that leaf values are scheduled as late as
                // possible.
                let precursor: usize = nodes[current_node]
                    .inputs
                    .iter()
                    .filter(|n| nodes[n.node].inputs.len() > 0)
                    .map(|n| n.node)
                    .chain(more_dependencies.iter().filter(|a| a.0 == current_node).map(|n| n.1))
                    .chain(
                        nodes[current_node]
                            .inputs
                            .iter()
                            .filter(|n| nodes[n.node].inputs.len() == 0)
                            .map(|n| n.node),
                    )
                    .nth(current_input)
                    .unwrap();
                if done.contains(precursor) {
                    // Dependency already scheduled: advance to the next one.
                    current_stack.push((current_node, current_input + 1));
                } else if pending.contains(precursor) {
                    // The precursor is still waiting for us somewhere below on
                    // the stack: that is a cycle.
                    if log_enabled!(log::Level::Debug) {
                        debug!("Loop detected:");
                        current_stack
                            .iter()
                            .skip_while(|s| s.0 != precursor)
                            .for_each(|n| debug!("  {}", nodes[n.0]));
                    }
                    bail!("Loop detected")
                } else {
                    // Recurse into the precursor; re-push ourselves so the
                    // same dependency index is re-examined afterwards.
                    pending.insert(precursor);
                    current_stack.push((current_node, current_input));
                    current_stack.push((precursor, 0));
                }
            }
        }
    }
    Ok(order)
}
85
86pub fn build_flush_list<F, O, Flushable>(
87    model: &Graph<F, O>,
88    order: &[usize],
89    outputs: &[OutletId],
90    flushable: Flushable,
91) -> Vec<TVec<usize>>
92where
93    F: Fact + Clone + 'static,
94    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
95    Flushable: Fn(&Node<F, O>) -> bool,
96{
97    let mut values_needed_until_step = vec![0; model.nodes().len()];
98    for (step, node) in order.iter().enumerate() {
99        for i in &model.node(*node).inputs {
100            values_needed_until_step[i.node] = step;
101        }
102    }
103    for o in outputs.iter() {
104        values_needed_until_step[o.node] = order.len();
105    }
106    let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
107
108    for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
109        if flush_at != 0 && (flushable)(model.node(node)) {
110            flush_lists[flush_at].push(node)
111        }
112    }
113    flush_lists
114}
115
116/// Find an evaluation order for a list of model trying to minimize memory occupation.
117pub fn eval_order_opt_ram<F, O>(model: &super::Graph<F, O>) -> TractResult<Vec<usize>>
118where
119    F: Fact + Clone + 'static,
120    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
121{
122    let inputs = model.input_outlets()?.iter().map(|n| n.node).collect_vec();
123    let targets = model.output_outlets()?.iter().map(|n| n.node).collect_vec();
124    eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &targets, &[])
125}
126
/// Find an evaluation order for a list of nodes trying to minimize memory occupation.
///
/// Greedy heuristic: starting from the model inputs, repeatedly schedule a
/// candidate whose dependencies are all satisfied; when none exists, pick the
/// dependency-free "source" node whose candidate has the smallest remaining
/// upstream cone, so large constants are materialized as late as possible.
pub fn eval_order_opt_ram_for_nodes<F, O>(
    nodes: &[Node<F, O>],
    model_inputs: &[usize],
    model_outputs: &[usize],
    more_dependencies: &[(usize, usize)],
) -> TractResult<Vec<usize>>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    // Restrict the search to nodes actually needed for the outputs. This also
    // rejects cyclic graphs up front, since eval_order_for_nodes errors on them.
    let tocompute: BitSet =
        eval_order_for_nodes(nodes, model_inputs, model_outputs, more_dependencies)?
            .into_iter()
            .collect();

    // Build deduplicated adjacency lists over the pruned node set:
    // ups[n] = precursors of n, downs[n] = successors of n.
    let mut ups = vec![tvec!(); nodes.len()];
    let mut downs = vec![tvec!(); nodes.len()];
    for ix in tocompute.iter() {
        for input in &nodes[ix].inputs {
            if !ups[ix].contains(&input.node) {
                ups[ix].push(input.node);
                downs[input.node].push(ix);
            }
        }
    }
    for (down, up) in more_dependencies {
        if !ups[*down].contains(up) {
            ups[*down].push(*up);
            downs[*up].push(*down);
        }
    }

    // Immutable dependency graph shared by the scheduling loop.
    #[derive(Debug)]
    struct Dfs {
        ups: Vec<TVec<usize>>,
        downs: Vec<TVec<usize>>,
    }

    let dfs = Dfs { ups, downs };

    // State of the greedy schedule under construction.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Path {
        // Nodes in scheduled order.
        order: Vec<usize>,
        // Set view of `order`.
        done: BitSet,
        // Scheduled nodes whose value may still be needed by an unscheduled
        // successor. Maintained but not read by the current heuristic.
        alive: BitSet,
        // Unscheduled successors of scheduled nodes: the frontier.
        candidates: BitSet,
        // Per-candidate memo: (count, set) of unscheduled dependency-free
        // "source" nodes reachable upstream of that candidate.
        cache_upstream: Vec<Option<(usize, BitSet)>>,
    }

    impl Path {
        // Empty schedule sized for `nodes` node ids.
        fn with_size(nodes: usize) -> Path {
            Path {
                order: Vec::with_capacity(nodes),
                done: BitSet::with_capacity(nodes),
                alive: BitSet::with_capacity(nodes),
                candidates: BitSet::with_capacity(nodes),
                cache_upstream: vec![None; nodes],
            }
        }

        // Append `next` to the schedule and update all derived state.
        fn follow_one(&mut self, dfs: &Dfs, next: usize) {
            assert!(!self.done.contains(next));
            self.order.push(next);
            self.done.insert(next);
            self.alive.insert(next);
            self.candidates.remove(next);
            // Successors of `next` join the frontier.
            for &succ in &dfs.downs[next] {
                self.candidates.insert(succ);
            }
            // A precursor dies once all of its consumers are scheduled.
            for &maybe_dead in &dfs.ups[next] {
                if dfs.downs[maybe_dead].iter().all(|n| self.done.contains(*n)) {
                    self.alive.remove(maybe_dead);
                }
            }
            // Keep the upstream-source memos consistent: `next` is no longer
            // an unscheduled source for any candidate.
            self.cache_upstream[next] = None;
            for c in &self.candidates {
                if let Some(upstream) = self.cache_upstream[c].as_mut() {
                    upstream.0 -= upstream.1.remove(next) as usize;
                }
            }
        }

        // Among the frontier, pick the candidate with the fewest unscheduled
        // upstream sources and return one of those sources to schedule next.
        fn best_upstream_starter(&mut self, dfs: &Dfs) -> Option<usize> {
            // Lazily (re)compute missing memo entries with an upward BFS that
            // stops at already-scheduled nodes.
            for from in self.candidates.iter() {
                if self.cache_upstream[from].is_none() {
                    let mut found = BitSet::with_capacity(self.done.len());
                    let mut visited = self.done.clone();
                    let mut todo = VecDeque::<usize>::new();
                    todo.push_back(from);
                    visited.insert(from);
                    while let Some(next) = todo.pop_front() {
                        // Nodes with no precursors are schedulable "sources".
                        if dfs.ups[next].len() == 0 {
                            found.insert(next);
                        }
                        for up in &dfs.ups[next] {
                            if visited.insert(*up) {
                                todo.push_back(*up);
                            }
                        }
                    }
                    debug_assert!(found.len() > 0);
                    self.cache_upstream[from] = Some((found.len(), found));
                }
            }
            // NOTE(review): if a memoized set were ever emptied by follow_one
            // while its candidate stays blocked, the unwrap below would panic
            // — the code relies on that never happening; worth confirming.
            self.candidates
                .iter()
                .map(|n| self.cache_upstream[n].as_ref().unwrap())
                .min_by_key(|s| s.0)
                .map(|s| s.1.iter().next().unwrap())
        }
    }

    // Seed the schedule with the model inputs that are actually needed.
    let mut done: Path = Path::with_size(nodes.len());
    for i in model_inputs {
        if tocompute.contains(*i) {
            done.follow_one(&dfs, *i);
        }
    }

    while !model_outputs.iter().all(|o| done.done.contains(*o)) {
        // 1. Prefer a frontier node whose dependencies are all satisfied.
        // 2. Otherwise start the cheapest upstream cone of the frontier.
        // 3. Fallback: any remaining node whose dependencies are satisfied
        //    (e.g. a disconnected source); one must exist since the pruned
        //    graph is acyclic.
        let next = if let Some(next) =
            done.candidates.iter().find(|n| dfs.ups[*n].iter().all(|n| done.done.contains(*n)))
        {
            next
        } else if let Some(next) = done.best_upstream_starter(&dfs) {
            next
        } else {
            tocompute
                .difference(&done.done)
                .find(|n| dfs.ups[*n].iter().all(|n| done.done.contains(*n)))
                .unwrap()
        };
        done.follow_one(&dfs, next);
    }

    Ok(done.order.clone())
}
265
#[cfg(test)]
mod tests {
    use crate::internal::*;
    use crate::ops::array::Gather;
    use crate::ops::math;

    // Linear graph: both orderings must schedule the inputs then the add.
    #[test]
    fn simple() {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1])).unwrap();
        let b = model.add_const("b", tensor1(&[12.0f32])).unwrap();
        let sum = model.wire_node("add", math::add(), &[a, b]).unwrap()[0];
        model.auto_outputs().unwrap();
        let expected = vec![a.node, b.node, sum.node];
        assert_eq!(model.eval_order().unwrap(), expected);
        assert_eq!(model.eval_order_opt_ram().unwrap(), expected);
    }

    // A node consuming the same outlet twice must be scheduled only once.
    #[test]
    fn diamond() {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1])).unwrap();
        let sum = model.wire_node("add", math::add(), &[a, a]).unwrap()[0];
        model.auto_outputs().unwrap();
        let expected = vec![a.node, sum.node];
        assert_eq!(model.eval_order().unwrap(), expected);
        assert_eq!(model.eval_order_opt_ram().unwrap(), expected);
    }

    // The test is disabled on Wasm because it uses threads.
    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn dodge_loop() {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1])).unwrap();
        let add = model.wire_node("add", math::add(), &[a, a]).unwrap()[0];
        let neg = model.wire_node("neg", math::add(), &[add, a]).unwrap()[0];
        // Close a cycle: feed "neg" back into "add"'s second input.
        model.add_edge(neg, InletId::new(add.node, 1)).unwrap();
        model.set_output_outlets(&[neg]).unwrap();
        // Run each ordering in a worker thread with a timeout so a hang (an
        // undetected loop) fails the test instead of blocking forever.
        let cloned = model.clone();
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            tx.send(cloned.eval_order()).unwrap();
        });
        assert!(rx.recv_timeout(std::time::Duration::from_secs(1)).unwrap().is_err());
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            tx.send(model.eval_order_opt_ram()).unwrap();
        });
        assert!(rx.recv_timeout(std::time::Duration::from_secs(1)).unwrap().is_err());
    }

    // The big const "b" should not be materialized before it is needed.
    #[test]
    fn opt_ram() -> TractResult<()> {
        let mut model = TypedModel::default();
        let b = model.add_const("b", tensor1(&[0i64; 1000]))?;
        let d = model.add_const("d", tensor1(&[0i64; 100]))?;
        let a = model.add_source("a", i32::fact([10]))?;
        let c = model.wire_node("c", Gather::new(0), &[a, b])?[0];
        let e = model.wire_node("e", Gather::new(0), &[c, d])?[0];
        model.set_output_outlets(&[e]).unwrap();
        eprintln!("{model}");
        let order = model.eval_order_opt_ram()?;
        assert!(order[2..] == [c.node, d.node, e.node]);
        Ok(())
    }
}