1use super::*;
2use crate::internal::*;
3use crate::ops::Op;
4use crate::prelude::*;
5use crate::runtime::RunOptions;
6
7use std::fmt;
8use tract_data::internal::*;
9use tract_itertools::Itertools;
10
11pub trait SpecialOps<F, O> {
12 fn create_dummy(&self) -> O;
13 fn create_source(&self, fact: F) -> O;
14 fn is_source(op: &O) -> bool;
15 fn wire_node(
16 &mut self,
17 name: impl Into<String>,
18 op: impl Into<O>,
19 inputs: &[OutletId],
20 ) -> TractResult<TVec<OutletId>>;
21 fn add_const(
22 &mut self,
23 name: impl Into<String>,
24 v: impl IntoArcTensor,
25 ) -> TractResult<OutletId>;
26}
27
28#[derive(Clone, Debug)]
32pub struct Graph<F, O>
33where
34 F: Fact + Clone + 'static,
35 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
36{
37 pub nodes: Vec<Node<F, O>>,
39 pub inputs: Vec<OutletId>,
41 pub outputs: Vec<OutletId>,
43 pub outlet_labels: HashMap<OutletId, String>,
45 pub properties: HashMap<String, Arc<Tensor>>,
47 pub symbols: SymbolScope,
49}
50
51impl<F, O> Default for Graph<F, O>
52where
53 F: Fact + Clone + 'static,
54 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
55{
56 fn default() -> Graph<F, O> {
57 Graph {
58 nodes: vec![],
59 inputs: vec![],
60 outputs: vec![],
61 outlet_labels: HashMap::new(),
62 properties: HashMap::new(),
63 symbols: Default::default(),
64 }
65 }
66}
67
68impl<F, O> Graph<F, O>
69where
70 F: Fact + Clone + 'static,
71 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
72 Graph<F, O>: SpecialOps<F, O>,
73{
74 pub fn add_source(&mut self, name: impl Into<String>, fact: F) -> TractResult<OutletId> {
75 let source = self.create_source(fact.clone());
76 let id = self.add_node(name, source, tvec!(fact))?;
77 let id = OutletId::new(id, 0);
78 self.inputs.push(id);
79 Ok(id)
80 }
81}
82
83impl<F, O> Graph<F, O>
84where
85 F: Fact + Clone + 'static,
86 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
87{
88 pub fn add_node(
89 &mut self,
90 name: impl Into<String>,
91 op: impl Into<O>,
92 output_facts: TVec<F>,
93 ) -> TractResult<usize> {
94 let op = op.into();
95 let name = name.into();
96 let id = self.nodes.len();
97 let outputs =
98 output_facts.into_iter().map(|fact| Outlet { fact, successors: tvec!() }).collect();
99 let node = Node { id, name, op, inputs: vec![], outputs };
100 self.nodes.push(node);
101 Ok(id)
102 }
103
104 pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()> {
106 if let Some(previous) = self.nodes[inlet.node].inputs.get(inlet.slot).cloned() {
107 self.nodes[previous.node].outputs[previous.slot]
108 .successors
109 .retain(|&mut succ| succ != inlet);
110 }
111 {
112 let prec = &mut self.nodes[outlet.node];
113 prec.outputs[outlet.slot].successors.push(inlet);
114 }
115 let succ = &mut self.nodes[inlet.node];
116 #[allow(clippy::comparison_chain)]
117 if inlet.slot == succ.inputs.len() {
118 succ.inputs.push(outlet);
119 } else if inlet.slot < succ.inputs.len() {
120 succ.inputs[inlet.slot] = outlet;
121 } else {
122 bail!(
123 "Edges must be added in order and consecutive. Trying to connect input {:?} of node {:?} ",
124 inlet.slot,
125 succ
126 )
127 }
128 Ok(())
129 }
130
131 pub fn input_outlets(&self) -> TractResult<&[OutletId]> {
135 Ok(&self.inputs)
136 }
137
138 pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()> {
140 self.inputs = inputs.to_vec();
141 Ok(())
142 }
143
144 pub fn with_input_outlets(mut self, inputs: &[OutletId]) -> TractResult<Self> {
146 self.set_input_outlets(inputs)?;
147 Ok(self)
148 }
149
150 pub fn set_input_names(
152 &mut self,
153 inputs: impl IntoIterator<Item = impl AsRef<str>>,
154 ) -> TractResult<()> {
155 let mut ids = vec![];
156 for i in inputs.into_iter() {
157 let node = self.node_by_name(&i)?;
158 for o in 0..node.outputs.len() {
159 ids.push(OutletId::new(node.id, o))
160 }
161 }
162 self.inputs = ids;
163 Ok(())
164 }
165
166 pub fn with_input_names(
168 mut self,
169 inputs: impl IntoIterator<Item = impl AsRef<str>>,
170 ) -> TractResult<Self> {
171 self.set_input_names(inputs)?;
172 Ok(self)
173 }
174
175 pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
177 let input = self.input_outlets()?[ix];
178 self.outlet_fact(input)
179 }
180
181 pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
183 let input = self.input_outlets()?[ix];
184 self.outlet_fact_mut(input)
185 }
186
187 pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
189 let outlet = self.inputs[input];
190 self.set_outlet_fact(outlet, fact)
191 }
192
193 pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
195 self.set_input_fact(input, fact)?;
196 Ok(self)
197 }
198
199 pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
202 Ok(&self.outputs)
203 }
204
205 pub fn auto_outputs(&mut self) -> TractResult<()> {
207 let outputs = self
208 .nodes
209 .iter()
210 .flat_map(|n| {
211 let id = n.id;
212 n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
213 (OutletId::new(id, ix), output_fact.successors.len())
214 })
215 })
216 .filter(|(_f, succs)| *succs == 0)
217 .map(|(f, _)| f)
218 .collect();
219 self.outputs = outputs;
220 Ok(())
221 }
222
223 pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
225 self.outputs = outputs.to_vec();
226 Ok(())
227 }
228
229 pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
231 self.set_output_outlets(outputs)?;
232 Ok(self)
233 }
234
235 pub fn set_output_names(
237 &mut self,
238 outputs: impl IntoIterator<Item = impl AsRef<str>>,
239 ) -> TractResult<()> {
240 let mut labels: HashMap<StaticName, OutletId> =
241 self.outlet_labels.iter().map(|(o, s)| (Cow::Owned((*s).to_string()), *o)).collect();
242 for n in self.nodes() {
243 for ix in 0..n.outputs.len() {
244 labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
245 }
246 }
247 let ids: Vec<OutletId> = outputs
248 .into_iter()
249 .map(|s| {
250 let s = s.as_ref();
251 labels
252 .get(s)
253 .cloned()
254 .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
255 .ok_or_else(|| format_err!("Node {} not found", s))
256 })
257 .collect::<TractResult<_>>()?;
258 self.outputs = ids;
259 Ok(())
260 }
261
262 pub fn with_output_names(
264 mut self,
265 outputs: impl IntoIterator<Item = impl AsRef<str>>,
266 ) -> TractResult<Self> {
267 self.set_output_names(outputs)?;
268 Ok(self)
269 }
270
271 pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
273 let output = self.output_outlets()?[ix];
274 self.outlet_fact(output)
275 }
276
277 pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
279 let output = self.output_outlets()?[ix];
280 self.outlet_fact_mut(output)
281 }
282
283 pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
285 let outlet = self.outputs[output];
286 self.set_outlet_fact(outlet, fact)
287 }
288
289 pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
291 self.set_output_fact(output, fact)?;
292 Ok(self)
293 }
294
295 pub fn node_names(&self) -> impl Iterator<Item = &str> {
299 self.nodes.iter().map(|s| &*s.name)
300 }
301
302 pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
303 self.nodes
304 .iter()
305 .find(|n| n.name == name)
306 .map(|n| n.id)
307 .with_context(|| format!("No node found for name: \"{name}\""))
308 }
309
310 pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
312 let id: usize = self.node_id_by_name(name.as_ref())?;
313 Ok(&self.nodes[id])
314 }
315
316 pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
318 let id: usize = self.node_id_by_name(name.as_ref())?;
319 Ok(&mut self.nodes[id])
320 }
321
322 pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
323 self.node_mut(id).name = name.to_string();
324 Ok(())
325 }
326
327 pub fn node(&self, id: usize) -> &Node<F, O> {
329 &self.nodes[id]
330 }
331
332 pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
334 &mut self.nodes[id]
335 }
336
337 pub fn nodes(&self) -> &[Node<F, O>] {
339 &self.nodes
340 }
341
342 pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
344 &mut self.nodes
345 }
346
347 pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
349 Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
350 }
351
352 pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
354 self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
355 }
356
357 pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
359 Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
360 }
361
362 pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
366 ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
367 let outlets = &self.nodes[outlet.node].outputs;
368 outlets
369 .get(outlet.slot)
370 .map(|o| &o.fact)
371 .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
372 }
373
374 pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
376 let outlets = &mut self.nodes[outlet.node].outputs;
377 outlets
378 .get_mut(outlet.slot)
379 .map(|o| &mut o.fact)
380 .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
381 }
382
383 pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
385 assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
386 unsafe {
387 outlets
388 .iter()
389 .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
390 .collect()
391 }
392 }
393
394 pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
396 let outlets = &mut self.nodes[outlet.node].outputs;
397 if outlets.len() <= outlet.slot {
398 bail!("Invalid outlet refererence: {:?}", outlet)
399 }
400 outlets[outlet.slot].fact = fact;
401 Ok(())
402 }
403
404 pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
406 self.set_outlet_fact(outlet, fact)?;
407 Ok(self)
408 }
409
410 pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
414 self.outlet_labels.get(&outlet).map(|s| &**s)
415 }
416
417 pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
419 self.outlet_labels.insert(outlet, label);
420 Ok(())
421 }
422
423 pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
425 self.set_outlet_label(outlet, label)?;
426 Ok(self)
427 }
428
429 pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
431 self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
432 }
433
434 pub fn eval_order(&self) -> TractResult<Vec<usize>> {
438 super::order::eval_order(self)
439 }
440
441 pub fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>> {
444 super::order::eval_order_opt_ram(self)
445 }
446
447 #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
448 #[inline]
449 pub fn check_edges(&self) -> TractResult<()> {
450 Ok(())
451 }
452
453 #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
455 pub fn check_edges(&self) -> TractResult<()> {
456 for node_id in self.eval_order()? {
457 let node = &self.nodes[node_id];
458 for (ix, input) in node.inputs.iter().enumerate() {
459 let prec = &self.nodes[input.node];
460 if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
461 bail!(
462 "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
463 node.id,
464 ix,
465 prec
466 )
467 }
468 }
469 for (ix, output) in node.outputs.iter().enumerate() {
470 for succ in &output.successors {
471 if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
472 bail!(
473 "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
474 node.id,
475 ix,
476 succ
477 )
478 }
479 }
480 }
481 }
482 Ok(())
483 }
484
485 pub fn eval_tmp_memory_usage<Flushable>(
487 &self,
488 order: &[usize],
489 flushable: Flushable,
490 ) -> TractResult<TVec<(usize, TDim)>>
491 where
492 Flushable: Fn(&Node<F, O>) -> bool,
493 {
494 super::memory::eval_tmp_memory_usage(self, order, flushable)
495 }
496
497 #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
498 #[inline]
499 pub fn check_names(&self) -> TractResult<()> {
500 Ok(())
501 }
502
503 #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
505 pub fn check_names(&self) -> TractResult<()> {
506 let dups =
507 self.eval_order()?.iter().map(|n| &self.nodes[*n].name).duplicates().collect_vec();
508 ensure!(dups.len() == 0, "Duplicate node name(s) : {:?}\n{}", dups, &self);
509 Ok(())
510 }
511
512 pub fn into_runnable_with_options(
520 self,
521 options: &RunOptions,
522 ) -> TractResult<Arc<RunnableModel<F, O>>> {
523 crate::plan::SimplePlan::new_with_options(self, options)
524 }
525
526 pub fn linear_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
527 let node = &self.nodes()[id];
528 rule_if!(node.inputs.len() == 1);
529 let prec = &self.nodes()[node.inputs[0].node];
530 rule_if!(prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() == 1);
531 Ok(Some(prec))
532 }
533
534 pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
535 let node = &self.nodes()[id];
536 rule_if!(node.inputs.len() == 1);
537 let prec = &self.nodes()[node.inputs[0].node];
538 Ok(Some(prec))
539 }
540
541 pub fn all_prec(&self, id: usize) -> TractResult<Option<TVec<&Node<F, O>>>> {
542 let node = &self.nodes()[id];
543 rule_if!(node.inputs.len() > 0);
544 Ok(Some(node.inputs.iter().map(|n| &self.nodes()[n.node]).collect()))
545 }
546
547 pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
548 let mut node = self.node(id);
549 for _ in 0..count {
550 if let Some(next) = self.linear_prec(node.id)? {
551 node = next
552 } else {
553 return Ok(None);
554 }
555 }
556 Ok(Some(node))
557 }
558
559 pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
560 let mut node = self.node(id);
561 for _ in 0..count {
562 if let Some(next) = self.linear_succ(node.id)? {
563 node = next
564 } else {
565 return Ok(None);
566 }
567 }
568 Ok(Some(node))
569 }
570
571 pub fn linear_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
574 let node = &self.nodes()[id];
575
576 rule_if!(node.outputs.len() == 1);
577 rule_if!(node.outputs[0].successors.len() == 1);
578 let succ = node.outputs[0].successors[0];
579 let succ = &self.nodes()[succ.node];
580 rule_if!(succ.inputs.len() == 1);
581 Ok(Some(succ))
582 }
583
584 pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
585 let node = &self.nodes()[id];
586
587 rule_if!(node.outputs.len() == 1);
588 rule_if!(node.outputs[0].successors.len() == 1);
589 let succ = node.outputs[0].successors[0];
590 Ok(Some(&self.nodes()[succ.node]))
591 }
592
593 pub fn all_succ(&self, id: usize) -> TractResult<Option<TVec<&Node<F, O>>>> {
594 let node = &self.nodes()[id];
595 rule_if!(!node.outputs.is_empty());
596
597 Ok(Some(
598 node.outputs
599 .iter()
600 .flat_map(|o| {
601 o.successors.iter().map(|succ| &self.nodes()[succ.node]).collect::<Vec<_>>()
602 })
603 .collect(),
604 ))
605 }
606
607 pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
608 &self.nodes[outlet.node].outputs[outlet.slot].successors
609 }
610
611 pub fn sym(&self, s: &str) -> Symbol {
613 self.symbols.sym(s)
614 }
615
616 pub fn new_sym_with_prefix(&self, prefix: &str) -> Symbol {
618 self.symbols.new_with_prefix(prefix)
619 }
620
621 pub fn unique_name<'n>(&self, prefix: impl Into<Cow<'n, str>>) -> Cow<'n, str> {
624 let prefix = prefix.into();
625 if self.nodes.iter().all(|n| n.name != *prefix) {
626 return prefix;
627 }
628 for i in 1.. {
629 let s = format!("{prefix}.{i}");
630 if self.nodes.iter().all(|n| n.name != s) {
631 return Cow::Owned(s);
632 }
633 }
634 unreachable!();
635 }
636}
637
638impl<F, O> fmt::Display for Graph<F, O>
639where
640 F: Fact + Clone + 'static,
641 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
642{
643 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
644 for i in 0..self.nodes.len() {
645 let input_1 =
646 self.nodes[i].inputs.first().map(|o| format!("{o:?}")).unwrap_or_default();
647 let input_2 = self.nodes[i].inputs.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
648 let successors = self.nodes[i]
649 .outputs
650 .first()
651 .iter()
652 .flat_map(|o| o.successors.iter())
653 .collect_vec();
654 let output_1 = successors.first().map(|o| format!("{o:?}")).unwrap_or_default();
655 let output_2 = successors.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
656 writeln!(
657 fmt,
658 "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {} => {}",
659 i,
660 input_1,
661 input_2,
662 output_1,
663 output_2,
664 self.nodes[i].op().name(),
665 self.nodes[i].name,
666 self.node_input_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
667 self.node_output_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
668 )?;
669 if self.nodes[i].inputs.len() > 2 {
670 writeln!(
671 fmt,
672 " | * inputs: {}",
673 self.nodes[i].inputs.iter().map(|s| format!("{s:?}")).join(", ")
674 )?;
675 }
676 if self.nodes[i].outputs.len() > 1
677 || successors.len() > 2
678 || (self.outlet_label(i.into()).is_some()
679 && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
680 {
681 for o in 0..self.nodes[i].outputs.len() {
682 if self.outlet_successors((i, o).into()).len() > 0 {
683 writeln!(
684 fmt,
685 " | * output #{}: {} {}",
686 o,
687 self.outlet_label((i, o).into()).unwrap_or(""),
688 self.outlet_successors((i, o).into())
689 .iter()
690 .map(|s| format!("{s:?}"))
691 .join(", "),
692 )?;
693 }
694 }
695 }
696 }
697 writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{o:?}")).join(", "))?;
698 Ok(())
699 }
700}
701
702impl<F, O> Graph<F, O>
703where
704 F: Fact + Clone + 'static + for<'a> std::convert::From<&'a F>,
705 O: std::fmt::Display
706 + std::fmt::Debug
707 + Clone
708 + AsRef<dyn Op>
709 + AsMut<dyn Op>
710 + Clone
711 + 'static
712 + for<'a> std::convert::From<&'a O>,
713 Graph<F, O>: SpecialOps<F, O>,
714{
715 #[cfg(debug_assertions)]
716 pub fn check_compact(&self) -> TractResult<()> {
717 let order = self.eval_order()?;
718 let useless_sources = self
719 .input_outlets()?
720 .iter()
721 .filter(|io| {
722 self.outlet_successors(**io).len() == 0
723 && !self.output_outlets().unwrap().contains(io)
724 })
725 .count();
726 if order.len() + useless_sources != self.nodes.len() {
727 bail!(
728 "Eval order is {} long, nodes are {}, including {} unused sources",
729 order.len(),
730 self.nodes.len(),
731 useless_sources
732 );
733 }
734 if (0..order.len()).any(|ix| order[ix] != ix) {
735 bail!("eval order is not trivial");
736 }
737 let mut seen = std::collections::HashSet::new();
738 for (ix, n) in self.nodes.iter().enumerate() {
739 if ix != n.id {
740 bail!("Invalid node id: position is {}, node is {}", ix, n);
741 }
742 if seen.contains(&n.name) {
743 bail!("duplicate name for node {n}");
744 }
745 seen.insert(&n.name);
746 }
747 Ok(())
748 }
749
750 pub fn compact(&mut self) -> TractResult<()> {
751 let mut order = self.eval_order()?;
752 if order.len() == self.nodes.len() && order.iter().enumerate().all(|(a, b)| a == *b) {
753 return Ok(());
754 }
755 for i in &self.inputs {
756 if !order.contains(&i.node) {
757 order.push(i.node);
758 }
759 }
760 let mut old_to_new = vec![usize::MAX; self.nodes.len()];
761 let mut new_nodes = vec![
762 Node {
763 id: self.nodes.len(),
764 name: "".to_string(),
765 inputs: vec![],
766 op: self.create_dummy(),
767 outputs: tvec!(),
768 };
769 order.len()
770 ];
771 for (ix, id) in order.iter().enumerate() {
772 old_to_new[*id] = ix;
773 std::mem::swap(&mut new_nodes[ix], &mut self.nodes[*id]);
774 }
775 for node in &mut new_nodes {
776 if self.inputs.iter().any(|n| n.node == node.id) && !Self::is_source(&node.op) {
777 node.inputs.clear();
778 node.op = self.create_source(node.outputs[0].fact.clone());
779 }
780 node.id = old_to_new[node.id];
781 for input in &mut node.inputs {
782 assert!(old_to_new[input.node] < order.len());
783 input.node = old_to_new[input.node];
784 }
785 for output in &mut node.outputs {
786 for succ in &mut output.successors {
787 succ.node = old_to_new[succ.node];
788 }
789 output.successors.retain(|s| s.node < order.len());
790 output.successors.sort();
791 }
792 }
793 self.nodes = new_nodes;
794 for input in &mut self.inputs {
795 assert!(old_to_new[input.node] < order.len());
796 input.node = old_to_new[input.node];
797 }
798 for output in &mut self.outputs {
799 assert!(old_to_new[output.node] < order.len());
800 output.node = old_to_new[output.node];
801 }
802 self.outlet_labels = std::mem::take(&mut self.outlet_labels)
803 .into_iter()
804 .map(|(k, v)| (OutletId::new(old_to_new[k.node], k.slot), v))
805 .filter(|(k, _)| k.node < order.len())
806 .collect();
807 ensure!(self.nodes.iter().enumerate().all(|(ix, n)| n.id == ix));
808 #[cfg(debug_assertions)]
809 {
810 self.check_compact().context("after graph compaction")?;
811 }
812 Ok(())
813 }
814
815 pub fn into_compact(mut self) -> TractResult<Self> {
816 self.compact()?;
817 Ok(self)
818 }
819}
820
821pub trait IntoRunnable<F, O>
822where
823 F: Fact + Clone + 'static,
824 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
825{
826 fn into_runnable(self) -> TractResult<Arc<RunnableModel<F, O>>>;
827}
828
829impl<G, F, O> IntoRunnable<F, O> for G
830where
831 F: Fact + Clone + 'static,
832 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
833 G: Into<Arc<Graph<F, O>>>,
834{
835 fn into_runnable(self) -> TractResult<Arc<RunnableModel<F, O>>> {
836 SimplePlan::new(self)
837 }
838}