1use crate::internal::*;
3use bit_set::BitSet;
4use std::collections::VecDeque;
5use std::fmt::{Debug, Display};
6use tract_itertools::Itertools;
7
8pub fn eval_order<F, O>(model: &super::Graph<F, O>) -> TractResult<Vec<usize>>
11where
12 F: Fact + Clone + 'static,
13 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
14{
15 let inputs = model.input_outlets()?.iter().map(|n| n.node).collect_vec();
16 let targets = model.output_outlets()?.iter().map(|n| n.node).collect_vec();
17 eval_order_for_nodes(model.nodes(), &inputs, &targets, &[])
18}
19
20pub fn eval_order_for_nodes<F, O>(
23 nodes: &[Node<F, O>],
24 model_inputs: &[usize],
25 model_outputs: &[usize],
26 more_dependencies: &[(usize, usize)],
27) -> TractResult<Vec<usize>>
28where
29 F: Fact + Clone + 'static,
30 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
31{
32 let mut done = BitSet::with_capacity(nodes.len());
33 let mut order: Vec<usize> = vec![];
34 for &model_target in model_outputs {
35 if done.contains(model_target) {
36 continue;
37 }
38 let mut current_stack: Vec<(usize, usize)> = vec![(model_target, 0)];
39 let mut pending = BitSet::with_capacity(nodes.len());
40 while let Some((current_node, current_input)) = current_stack.pop() {
41 let deps_from_inputs = nodes[current_node].inputs.len();
42 let all_deps_count =
43 deps_from_inputs + more_dependencies.iter().filter(|a| a.0 == current_node).count();
44 if model_inputs.contains(¤t_node) || current_input == all_deps_count {
45 order.push(current_node);
46 done.insert(current_node);
47 pending.remove(current_node);
48 } else {
49 let precursor: usize = nodes[current_node]
50 .inputs
51 .iter()
52 .filter(|n| nodes[n.node].inputs.len() > 0)
53 .map(|n| n.node)
54 .chain(more_dependencies.iter().filter(|a| a.0 == current_node).map(|n| n.1))
55 .chain(
56 nodes[current_node]
57 .inputs
58 .iter()
59 .filter(|n| nodes[n.node].inputs.len() == 0)
60 .map(|n| n.node),
61 )
62 .nth(current_input)
63 .unwrap();
64 if done.contains(precursor) {
65 current_stack.push((current_node, current_input + 1));
66 } else if pending.contains(precursor) {
67 if log_enabled!(log::Level::Debug) {
68 debug!("Loop detected:");
69 current_stack
70 .iter()
71 .skip_while(|s| s.0 != precursor)
72 .for_each(|n| debug!(" {}", nodes[n.0]));
73 }
74 bail!("Loop detected")
75 } else {
76 pending.insert(precursor);
77 current_stack.push((current_node, current_input));
78 current_stack.push((precursor, 0));
79 }
80 }
81 }
82 }
83 Ok(order)
84}
85
86pub fn build_flush_list<F, O, Flushable>(
87 model: &Graph<F, O>,
88 order: &[usize],
89 outputs: &[OutletId],
90 flushable: Flushable,
91) -> Vec<TVec<usize>>
92where
93 F: Fact + Clone + 'static,
94 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
95 Flushable: Fn(&Node<F, O>) -> bool,
96{
97 let mut values_needed_until_step = vec![0; model.nodes().len()];
98 for (step, node) in order.iter().enumerate() {
99 for i in &model.node(*node).inputs {
100 values_needed_until_step[i.node] = step;
101 }
102 }
103 for o in outputs.iter() {
104 values_needed_until_step[o.node] = order.len();
105 }
106 let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
107
108 for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
109 if flush_at != 0 && (flushable)(model.node(node)) {
110 flush_lists[flush_at].push(node)
111 }
112 }
113 flush_lists
114}
115
116pub fn eval_order_opt_ram<F, O>(model: &super::Graph<F, O>) -> TractResult<Vec<usize>>
118where
119 F: Fact + Clone + 'static,
120 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
121{
122 let inputs = model.input_outlets()?.iter().map(|n| n.node).collect_vec();
123 let targets = model.output_outlets()?.iter().map(|n| n.node).collect_vec();
124 eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &targets, &[])
125}
126
/// Compute an evaluation order for `nodes` that tries to reduce the number of
/// node outputs kept alive at any given time (a RAM-friendlier schedule).
///
/// `more_dependencies` lists extra `(node, dependency)` edges beyond the
/// regular data inputs. Fails (via `eval_order_for_nodes`) if the graph
/// contains a dependency loop.
pub fn eval_order_opt_ram_for_nodes<F, O>(
    nodes: &[Node<F, O>],
    model_inputs: &[usize],
    model_outputs: &[usize],
    more_dependencies: &[(usize, usize)],
) -> TractResult<Vec<usize>>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    // Start from a plain topological order: it validates the graph (loop
    // detection) and tells us which nodes actually need computing.
    let tocompute: BitSet =
        eval_order_for_nodes(nodes, model_inputs, model_outputs, more_dependencies)?
            .into_iter()
            .collect();

    // Build the dependency graph restricted to needed nodes:
    // ups[n] = distinct nodes n depends on; downs[n] = nodes depending on n.
    let mut ups = vec![tvec!(); nodes.len()];
    let mut downs = vec![tvec!(); nodes.len()];
    for ix in tocompute.iter() {
        for input in &nodes[ix].inputs {
            if !ups[ix].contains(&input.node) {
                ups[ix].push(input.node);
                downs[input.node].push(ix);
            }
        }
    }
    for (down, up) in more_dependencies {
        if !ups[*down].contains(up) {
            ups[*down].push(*up);
            downs[*up].push(*down);
        }
    }

    // Adjacency lists shared by the greedy search below.
    #[derive(Debug)]
    struct Dfs {
        ups: Vec<TVec<usize>>,
        downs: Vec<TVec<usize>>,
    }

    let dfs = Dfs { ups, downs };

    // Search state: the schedule built so far plus bookkeeping bitsets.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Path {
        // Schedule built so far.
        order: Vec<usize>,
        // Nodes already scheduled.
        done: BitSet,
        // Scheduled nodes whose output may still be needed downstream.
        alive: BitSet,
        // Unscheduled successors of scheduled nodes (frontier).
        candidates: BitSet,
        // Per-candidate memo: (count, set) of not-yet-done source nodes
        // (nodes with no deps) reachable upstream. Updated incrementally as
        // nodes get scheduled; None = not computed yet.
        cache_upstream: Vec<Option<(usize, BitSet)>>,
    }

    impl Path {
        fn with_size(nodes: usize) -> Path {
            Path {
                order: Vec::with_capacity(nodes),
                done: BitSet::with_capacity(nodes),
                alive: BitSet::with_capacity(nodes),
                candidates: BitSet::with_capacity(nodes),
                cache_upstream: vec![None; nodes],
            }
        }

        // Schedule `next` and update all bookkeeping accordingly.
        fn follow_one(&mut self, dfs: &Dfs, next: usize) {
            assert!(!self.done.contains(next));
            self.order.push(next);
            self.done.insert(next);
            self.alive.insert(next);
            self.candidates.remove(next);
            // Successors of `next` join the frontier.
            for &succ in &dfs.downs[next] {
                self.candidates.insert(succ);
            }
            // A predecessor dies once every one of its consumers is scheduled.
            for &maybe_dead in &dfs.ups[next] {
                if dfs.downs[maybe_dead].iter().all(|n| self.done.contains(*n)) {
                    self.alive.remove(maybe_dead);
                }
            }
            self.cache_upstream[next] = None;
            // `next` is done now: drop it from all cached upstream-source
            // sets, adjusting the cached counts to match.
            for c in &self.candidates {
                if let Some(upstream) = self.cache_upstream[c].as_mut() {
                    upstream.0 -= upstream.1.remove(next) as usize;
                }
            }
        }

        // Pick a source node (no dependencies) to schedule next: one feeding
        // the candidate with the fewest not-yet-done sources upstream, i.e.
        // the candidate closest to becoming runnable.
        fn best_upstream_starter(&mut self, dfs: &Dfs) -> Option<usize> {
            for from in self.candidates.iter() {
                if self.cache_upstream[from].is_none() {
                    // BFS upstream from `from` (skipping done nodes),
                    // collecting reachable dependency-less sources.
                    let mut found = BitSet::with_capacity(self.done.len());
                    let mut visited = self.done.clone();
                    let mut todo = VecDeque::<usize>::new();
                    todo.push_back(from);
                    visited.insert(from);
                    while let Some(next) = todo.pop_front() {
                        if dfs.ups[next].len() == 0 {
                            found.insert(next);
                        }
                        for up in &dfs.ups[next] {
                            if visited.insert(*up) {
                                todo.push_back(*up);
                            }
                        }
                    }
                    debug_assert!(found.len() > 0);
                    self.cache_upstream[from] = Some((found.len(), found));
                }
            }
            self.candidates
                .iter()
                .map(|n| self.cache_upstream[n].as_ref().unwrap())
                .min_by_key(|s| s.0)
                .map(|s| s.1.iter().next().unwrap())
        }
    }

    let mut done: Path = Path::with_size(nodes.len());
    // Seed the schedule with the model inputs that are actually needed.
    for i in model_inputs {
        if tocompute.contains(*i) {
            done.follow_one(&dfs, *i);
        }
    }

    while !model_outputs.iter().all(|o| done.done.contains(*o)) {
        // Greedy choice: prefer a frontier candidate whose dependencies are
        // all satisfied; otherwise start a source unblocking the cheapest
        // candidate; as a last resort, take any runnable unscheduled node.
        let next = if let Some(next) =
            done.candidates.iter().find(|n| dfs.ups[*n].iter().all(|n| done.done.contains(*n)))
        {
            next
        } else if let Some(next) = done.best_upstream_starter(&dfs) {
            next
        } else {
            tocompute
                .difference(&done.done)
                .find(|n| dfs.ups[*n].iter().all(|n| done.done.contains(*n)))
                .unwrap()
        };
        done.follow_one(&dfs, next);
    }

    Ok(done.order.clone())
}
265
#[cfg(test)]
mod tests {
    use crate::internal::*;
    use crate::ops::array::Gather;
    use crate::ops::math;

    #[test]
    fn simple() {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1])).unwrap();
        let b = model.add_const("b", tensor1(&[12.0f32])).unwrap();
        let add = model.wire_node("add", math::add(), &[a, b]).unwrap()[0];
        model.auto_outputs().unwrap();
        assert_eq!(model.eval_order().unwrap(), vec!(a.node, b.node, add.node));
        assert_eq!(model.eval_order_opt_ram().unwrap(), vec!(a.node, b.node, add.node));
    }

    #[test]
    fn diamond() {
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1])).unwrap();
        let add = model.wire_node("add", math::add(), &[a, a]).unwrap()[0];
        model.auto_outputs().unwrap();
        assert_eq!(model.eval_order().unwrap(), vec!(a.node, add.node));
        assert_eq!(model.eval_order_opt_ram().unwrap(), vec!(a.node, add.node));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn dodge_loop() {
        // Wire a cycle (add -> neg -> add) and check both order computations
        // report an error instead of hanging. Each run happens on a worker
        // thread with a 1s timeout so a hang fails the test.
        let mut model = TypedModel::default();
        let a = model.add_source("a", f32::fact([1])).unwrap();
        let add = model.wire_node("add", math::add(), &[a, a]).unwrap()[0];
        let neg = model.wire_node("neg", math::add(), &[add, a]).unwrap()[0];
        model.add_edge(neg, InletId::new(add.node, 1)).unwrap();
        model.set_output_outlets(&[neg]).unwrap();
        let cloned = model.clone();
        // mpsc::channel returns (Sender, Receiver): bind them as (tx, rx).
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            tx.send(cloned.eval_order()).unwrap();
        });
        assert!(rx.recv_timeout(std::time::Duration::from_secs(1)).unwrap().is_err());
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::spawn(move || {
            tx.send(model.eval_order_opt_ram()).unwrap();
        });
        assert!(rx.recv_timeout(std::time::Duration::from_secs(1)).unwrap().is_err());
    }

    #[test]
    fn opt_ram() -> TractResult<()> {
        // The big constant `b` should be materialized as late as possible:
        // the RAM-optimizing order must end with c, d, e.
        let mut model = TypedModel::default();
        let b = model.add_const("b", tensor1(&[0i64; 1000]))?;
        let d = model.add_const("d", tensor1(&[0i64; 100]))?;
        let a = model.add_source("a", i32::fact([10]))?;
        let c = model.wire_node("c", Gather::new(0), &[a, b])?[0];
        let e = model.wire_node("e", Gather::new(0), &[c, d])?[0];
        model.set_output_outlets(&[e]).unwrap();
        eprintln!("{model}");
        assert!(model.eval_order_opt_ram()?[2..] == [c.node, d.node, e.node]);
        Ok(())
    }
}