1use super::*;
2use crate::internal::*;
3use crate::ops::Op;
4use crate::plan::PlanOptions;
5use crate::prelude::*;
6
7use std::fmt;
8use tract_data::internal::*;
9use tract_itertools::Itertools;
10
/// Operations that depend on the concrete fact (`F`) and operator (`O`) types of a graph.
///
/// `Graph` methods such as `add_source` and `compact` are only available when
/// `Graph<F, O>` implements this trait.
pub trait SpecialOps<F, O> {
    /// Builds a placeholder operator. `Graph::compact` uses it to pre-fill its new
    /// node list before swapping the real nodes in.
    fn create_dummy(&self) -> O;
    /// Builds the operator of a source node for `fact` (used by `Graph::add_source`
    /// and `Graph::compact`).
    fn create_source(&self, fact: F) -> O;
    /// Tells whether `op` is a source operator. `Graph::compact` only replaces the
    /// operator of a graph input when this returns `false`.
    fn is_source(op: &O) -> bool;
    // NOTE(review): the implementation lives outside this section — presumably adds a
    // node named `name` running `op` on `inputs` and returns its output outlets; confirm.
    fn wire_node(
        &mut self,
        name: impl Into<String>,
        op: impl Into<O>,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>>;
    // NOTE(review): the implementation lives outside this section — presumably adds a
    // constant node holding `v` and returns its outlet; confirm.
    fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId>;
}
27
/// A graph of nodes connected through outlets (node outputs) and inlets (node inputs).
///
/// `F` is the fact type stored on every outlet, `O` the operator type stored on every node.
#[derive(Clone, Debug)]
pub struct Graph<F, O>
where
    F: Fact + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// All nodes of the graph; `add_node` gives each new node an id equal to its index here.
    pub nodes: Vec<Node<F, O>>,
    /// Outlets acting as graph inputs; `add_source` appends to this list.
    pub inputs: Vec<OutletId>,
    /// Outlets acting as graph outputs.
    pub outputs: Vec<OutletId>,
    /// Optional labels attached to outlets.
    pub outlet_labels: HashMap<OutletId, String>,
    /// Named tensors attached to the graph (not read or written in this section of the file).
    pub properties: HashMap<String, Arc<Tensor>>,
    /// Symbol scope used by `sym` and `new_sym_with_prefix`.
    pub symbols: SymbolScope,
}
50
51impl<F, O> Default for Graph<F, O>
52where
53 F: Fact + Clone + 'static,
54 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
55{
56 fn default() -> Graph<F, O> {
57 Graph {
58 nodes: vec![],
59 inputs: vec![],
60 outputs: vec![],
61 outlet_labels: HashMap::new(),
62 properties: HashMap::new(),
63 symbols: Default::default(),
64 }
65 }
66}
67
68impl<F, O> Graph<F, O>
69where
70 F: Fact + Clone + 'static,
71 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
72 Graph<F, O>: SpecialOps<F, O>,
73{
74 pub fn add_source(&mut self, name: impl Into<String>, fact: F) -> TractResult<OutletId> {
75 let source = self.create_source(fact.clone());
76 let id = self.add_node(name, source, tvec!(fact))?;
77 let id = OutletId::new(id, 0);
78 self.inputs.push(id);
79 Ok(id)
80 }
81}
82
83impl<F, O> Graph<F, O>
84where
85 F: Fact + Clone + 'static,
86 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
87{
88 pub fn add_node(
89 &mut self,
90 name: impl Into<String>,
91 op: impl Into<O>,
92 output_facts: TVec<F>,
93 ) -> TractResult<usize> {
94 let op = op.into();
95 let name = name.into();
96 let id = self.nodes.len();
97 let outputs =
98 output_facts.into_iter().map(|fact| Outlet { fact, successors: tvec!() }).collect();
99 let node = Node { id, name, op, inputs: vec![], outputs };
100 self.nodes.push(node);
101 Ok(id)
102 }
103
104 pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()> {
106 if let Some(previous) = self.nodes[inlet.node].inputs.get(inlet.slot).cloned() {
107 self.nodes[previous.node].outputs[previous.slot]
108 .successors
109 .retain(|&mut succ| succ != inlet);
110 }
111 {
112 let prec = &mut self.nodes[outlet.node];
113 prec.outputs[outlet.slot].successors.push(inlet);
114 }
115 let succ = &mut self.nodes[inlet.node];
116 #[allow(clippy::comparison_chain)]
117 if inlet.slot == succ.inputs.len() {
118 succ.inputs.push(outlet);
119 } else if inlet.slot < succ.inputs.len() {
120 succ.inputs[inlet.slot] = outlet;
121 } else {
122 bail!("Edges must be added in order and consecutive. Trying to connect input {:?} of node {:?} ", inlet.slot, succ)
123 }
124 Ok(())
125 }
126
127 pub fn input_outlets(&self) -> TractResult<&[OutletId]> {
131 Ok(&self.inputs)
132 }
133
134 pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()> {
136 self.inputs = inputs.to_vec();
137 Ok(())
138 }
139
140 pub fn with_input_outlets(mut self, inputs: &[OutletId]) -> TractResult<Self> {
142 self.set_input_outlets(inputs)?;
143 Ok(self)
144 }
145
146 pub fn set_input_names(
148 &mut self,
149 inputs: impl IntoIterator<Item = impl AsRef<str>>,
150 ) -> TractResult<()> {
151 let mut ids = vec![];
152 for i in inputs.into_iter() {
153 let node = self.node_by_name(&i)?;
154 for o in 0..node.outputs.len() {
155 ids.push(OutletId::new(node.id, o))
156 }
157 }
158 self.inputs = ids;
159 Ok(())
160 }
161
162 pub fn with_input_names(
164 mut self,
165 inputs: impl IntoIterator<Item = impl AsRef<str>>,
166 ) -> TractResult<Self> {
167 self.set_input_names(inputs)?;
168 Ok(self)
169 }
170
171 pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
173 let input = self.input_outlets()?[ix];
174 self.outlet_fact(input)
175 }
176
177 pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
179 let input = self.input_outlets()?[ix];
180 self.outlet_fact_mut(input)
181 }
182
183 pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
185 let outlet = self.inputs[input];
186 self.set_outlet_fact(outlet, fact)
187 }
188
189 pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
191 self.set_input_fact(input, fact)?;
192 Ok(self)
193 }
194
195 pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
198 Ok(&self.outputs)
199 }
200
201 pub fn auto_outputs(&mut self) -> TractResult<()> {
203 let outputs = self
204 .nodes
205 .iter()
206 .flat_map(|n| {
207 let id = n.id;
208 n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
209 (OutletId::new(id, ix), output_fact.successors.len())
210 })
211 })
212 .filter(|(_f, succs)| *succs == 0)
213 .map(|(f, _)| f)
214 .collect();
215 self.outputs = outputs;
216 Ok(())
217 }
218
219 pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
221 self.outputs = outputs.to_vec();
222 Ok(())
223 }
224
225 pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
227 self.set_output_outlets(outputs)?;
228 Ok(self)
229 }
230
231 pub fn set_output_names(
233 &mut self,
234 outputs: impl IntoIterator<Item = impl AsRef<str>>,
235 ) -> TractResult<()> {
236 let mut labels: HashMap<Cow<str>, OutletId> =
237 self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
238 for n in self.nodes() {
239 for ix in 0..n.outputs.len() {
240 labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
241 }
242 }
243 let ids: Vec<OutletId> = outputs
244 .into_iter()
245 .map(|s| {
246 let s = s.as_ref();
247 labels
248 .get(s)
249 .cloned()
250 .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
251 .ok_or_else(|| format_err!("Node {} not found", s))
252 })
253 .collect::<TractResult<_>>()?;
254 self.outputs = ids;
255 Ok(())
256 }
257
258 pub fn with_output_names(
260 mut self,
261 outputs: impl IntoIterator<Item = impl AsRef<str>>,
262 ) -> TractResult<Self> {
263 self.set_output_names(outputs)?;
264 Ok(self)
265 }
266
267 pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
269 let output = self.output_outlets()?[ix];
270 self.outlet_fact(output)
271 }
272
273 pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
275 let output = self.output_outlets()?[ix];
276 self.outlet_fact_mut(output)
277 }
278
279 pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
281 let outlet = self.outputs[output];
282 self.set_outlet_fact(outlet, fact)
283 }
284
285 pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
287 self.set_output_fact(output, fact)?;
288 Ok(self)
289 }
290
291 pub fn node_names(&self) -> impl Iterator<Item = &str> {
295 self.nodes.iter().map(|s| &*s.name)
296 }
297
298 pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
299 self.nodes
300 .iter()
301 .find(|n| n.name == name)
302 .map(|n| n.id)
303 .with_context(|| format!("No node found for name: \"{name}\""))
304 }
305
306 pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
308 let id: usize = self.node_id_by_name(name.as_ref())?;
309 Ok(&self.nodes[id])
310 }
311
312 pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
314 let id: usize = self.node_id_by_name(name.as_ref())?;
315 Ok(&mut self.nodes[id])
316 }
317
318 pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
319 self.node_mut(id).name = name.to_string();
320 Ok(())
321 }
322
323 pub fn node(&self, id: usize) -> &Node<F, O> {
325 &self.nodes[id]
326 }
327
328 pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
330 &mut self.nodes[id]
331 }
332
333 pub fn nodes(&self) -> &[Node<F, O>] {
335 &self.nodes
336 }
337
338 pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
340 &mut self.nodes
341 }
342
343 pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
345 Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
346 }
347
348 pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
350 self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
351 }
352
353 pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
355 Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
356 }
357
358 pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
362 ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
363 let outlets = &self.nodes[outlet.node].outputs;
364 outlets
365 .get(outlet.slot)
366 .map(|o| &o.fact)
367 .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
368 }
369
370 pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
372 let outlets = &mut self.nodes[outlet.node].outputs;
373 outlets
374 .get_mut(outlet.slot)
375 .map(|o| &mut o.fact)
376 .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
377 }
378
379 pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
381 assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
382 unsafe {
383 outlets
384 .iter()
385 .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
386 .collect()
387 }
388 }
389
390 pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
392 let outlets = &mut self.nodes[outlet.node].outputs;
393 if outlets.len() <= outlet.slot {
394 bail!("Invalid outlet refererence: {:?}", outlet)
395 }
396 outlets[outlet.slot].fact = fact;
397 Ok(())
398 }
399
400 pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
402 self.set_outlet_fact(outlet, fact)?;
403 Ok(self)
404 }
405
406 pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
410 self.outlet_labels.get(&outlet).map(|s| &**s)
411 }
412
413 pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
415 self.outlet_labels.insert(outlet, label);
416 Ok(())
417 }
418
419 pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
421 self.set_outlet_label(outlet, label)?;
422 Ok(self)
423 }
424
425 pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
427 self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
428 }
429
    /// Computes a node evaluation order for this graph, as a list of node ids.
    ///
    /// Delegates to `super::order::eval_order`.
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        super::order::eval_order(self)
    }
436
    /// Computes a node evaluation order for this graph, as a list of node ids.
    ///
    /// Delegates to `super::order::eval_order_opt_ram`.
    // NOTE(review): the name suggests the order is chosen to reduce memory usage — confirm
    // against `super::order`.
    pub fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>> {
        super::order::eval_order_opt_ram(self)
    }
442
    /// No-op variant of `check_edges`, compiled unless both `debug_assertions` and
    /// the `paranoid_assertions` feature are enabled. Always returns `Ok(())`.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }
448
449 #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
451 pub fn check_edges(&self) -> TractResult<()> {
452 for node_id in self.eval_order()? {
453 let node = &self.nodes[node_id];
454 for (ix, input) in node.inputs.iter().enumerate() {
455 let prec = &self.nodes[input.node];
456 if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
457 bail!(
458 "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
459 node.id,
460 ix,
461 prec
462 )
463 }
464 }
465 for (ix, output) in node.outputs.iter().enumerate() {
466 for succ in &output.successors {
467 if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
468 bail!(
469 "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
470 node.id,
471 ix,
472 succ
473 )
474 }
475 }
476 }
477 }
478 Ok(())
479 }
480
    /// Delegates to `super::memory::eval_tmp_memory_usage` with this graph, the given
    /// node `order` and the `flushable` node predicate.
    // NOTE(review): presumably each returned pair is (node id, memory usage as a `TDim`)
    // — confirm against `super::memory`.
    pub fn eval_tmp_memory_usage<Flushable>(
        &self,
        order: &[usize],
        flushable: Flushable,
    ) -> TractResult<TVec<(usize, TDim)>>
    where
        Flushable: Fn(&Node<F, O>) -> bool,
    {
        super::memory::eval_tmp_memory_usage(self, order, flushable)
    }
492
    /// No-op variant of `check_names`, compiled unless both `debug_assertions` and
    /// the `paranoid_assertions` feature are enabled. Always returns `Ok(())`.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_names(&self) -> TractResult<()> {
        Ok(())
    }
498
499 #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
501 pub fn check_names(&self) -> TractResult<()> {
502 let dups =
503 self.eval_order()?.iter().map(|n| &self.nodes[*n].name).duplicates().collect_vec();
504 ensure!(dups.len() == 0, "Duplicate node name(s) : {:?}\n{}", dups, &self);
505 Ok(())
506 }
507
508 pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
510 crate::plan::SimplePlan::new_with_options(self, &PlanOptions::default())
511 }
512
    /// Converts the graph into a runnable model, building the plan with `options`.
    ///
    /// Delegates to `crate::plan::SimplePlan::new_with_options`.
    pub fn into_runnable_with_options(
        self,
        options: &PlanOptions,
    ) -> TractResult<RunnableModel<F, O, Self>> {
        crate::plan::SimplePlan::new_with_options(self, options)
    }
521
522 pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
523 let node = &self.nodes()[id];
524 if node.inputs.len() != 1 {
525 return Ok(None);
526 }
527 let prec = &self.nodes()[node.inputs[0].node];
528 if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
529 return Ok(None);
530 }
531 Ok(Some(prec))
532 }
533
534 pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
535 let mut node = self.node(id);
536 for _ in 0..count {
537 if let Some(next) = self.single_prec(node.id)? {
538 node = next
539 } else {
540 return Ok(None);
541 }
542 }
543 Ok(Some(node))
544 }
545
546 pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
547 let mut node = self.node(id);
548 for _ in 0..count {
549 if let Some(next) = self.single_succ(node.id)? {
550 node = next
551 } else {
552 return Ok(None);
553 }
554 }
555 Ok(Some(node))
556 }
557
558 pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
561 let node = &self.nodes()[id];
562
563 if node.outputs.len() != 1 || node.outputs[0].successors.len() != 1 {
564 return Ok(None);
565 }
566 let succ = node.outputs[0].successors[0];
567 let succ = &self.nodes()[succ.node];
568 if succ.inputs.len() != 1 {
569 return Ok(None);
570 }
571 Ok(Some(succ))
572 }
573
574 pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
575 &self.nodes[outlet.node].outputs[outlet.slot].successors
576 }
577
    /// Returns the symbol designated by `s` in this graph's symbol scope.
    ///
    /// Delegates to `SymbolScope::sym`.
    // NOTE(review): presumably creates the symbol when it does not exist yet — confirm
    // against `SymbolScope`.
    pub fn sym(&self, s: &str) -> Symbol {
        self.symbols.sym(s)
    }
582
    /// Returns a symbol from this graph's symbol scope built from `prefix`.
    ///
    /// Delegates to `SymbolScope::new_with_prefix`.
    // NOTE(review): presumably the symbol is new and its name starts with `prefix` —
    // confirm against `SymbolScope`.
    pub fn new_sym_with_prefix(&self, prefix: &str) -> Symbol {
        self.symbols.new_with_prefix(prefix)
    }
587
588 pub fn unique_name<'n>(&self, prefix: impl Into<Cow<'n, str>>) -> Cow<'n, str> {
591 let prefix = prefix.into();
592 if self.nodes.iter().all(|n| n.name != *prefix) {
593 return prefix;
594 }
595 for i in 1.. {
596 let s = format!("{prefix}.{i}");
597 if self.nodes.iter().all(|n| n.name != s) {
598 return Cow::Owned(s);
599 }
600 }
601 unreachable!();
602 }
603}
604
605impl<F, O> fmt::Display for Graph<F, O>
606where
607 F: Fact + Clone + 'static,
608 O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
609{
610 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
611 for i in 0..self.nodes.len() {
612 let input_1 =
613 self.nodes[i].inputs.first().map(|o| format!("{o:?}")).unwrap_or_default();
614 let input_2 = self.nodes[i].inputs.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
615 let successors = self.nodes[i]
616 .outputs
617 .first()
618 .iter()
619 .flat_map(|o| o.successors.iter())
620 .collect_vec();
621 let output_1 = successors.first().map(|o| format!("{o:?}")).unwrap_or_default();
622 let output_2 = successors.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
623 writeln!(
624 fmt,
625 "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {} => {}",
626 i,
627 input_1,
628 input_2,
629 output_1,
630 output_2,
631 self.nodes[i].op().name(),
632 self.nodes[i].name,
633 self.node_input_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
634 self.node_output_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
635 )?;
636 if self.nodes[i].inputs.len() > 2 {
637 writeln!(
638 fmt,
639 " | * inputs: {}",
640 self.nodes[i].inputs.iter().map(|s| format!("{s:?}")).join(", ")
641 )?;
642 }
643 if self.nodes[i].outputs.len() > 1
644 || successors.len() > 2
645 || (self.outlet_label(i.into()).is_some()
646 && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
647 {
648 for o in 0..self.nodes[i].outputs.len() {
649 if self.outlet_successors((i, o).into()).len() > 0 {
650 writeln!(
651 fmt,
652 " | * output #{}: {} {}",
653 o,
654 self.outlet_label((i, o).into()).unwrap_or(""),
655 self.outlet_successors((i, o).into())
656 .iter()
657 .map(|s| format!("{s:?}"))
658 .join(", "),
659 )?;
660 }
661 }
662 }
663 }
664 writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{o:?}")).join(", "))?;
665 Ok(())
666 }
667}
668
669impl<F, O> Graph<F, O>
670where
671 F: Fact + Clone + 'static + for<'a> std::convert::From<&'a F>,
672 O: std::fmt::Display
673 + std::fmt::Debug
674 + Clone
675 + AsRef<dyn Op>
676 + AsMut<dyn Op>
677 + Clone
678 + 'static
679 + for<'a> std::convert::From<&'a O>,
680 Graph<F, O>: SpecialOps<F, O>,
681{
682 #[cfg(debug_assertions)]
683 pub fn check_compact(&self) -> TractResult<()> {
684 let order = self.eval_order()?;
685 let useless_sources = self
686 .input_outlets()?
687 .iter()
688 .filter(|io| {
689 self.outlet_successors(**io).len() == 0
690 && !self.output_outlets().unwrap().contains(io)
691 })
692 .count();
693 if order.len() + useless_sources != self.nodes.len() {
694 bail!(
695 "Eval order is {} long, nodes are {}, including {} unused sources",
696 order.len(),
697 self.nodes.len(),
698 useless_sources
699 );
700 }
701 if (0..order.len()).any(|ix| order[ix] != ix) {
702 bail!("eval order is not trivial");
703 }
704 let mut seen = std::collections::HashSet::new();
705 for (ix, n) in self.nodes.iter().enumerate() {
706 if ix != n.id {
707 bail!("Invalid node id: position is {}, node is {}", ix, n);
708 }
709 if seen.contains(&n.name) {
710 bail!("duplicate name {}", n.name);
711 }
712 seen.insert(&n.name);
713 }
714 Ok(())
715 }
716
    /// Rebuilds the node list so that it only contains the nodes of the evaluation
    /// order (plus the graph inputs), stored in that order, with every node id equal
    /// to its index. Inputs, outputs, successors and outlet labels are renumbered
    /// accordingly.
    ///
    /// # Errors
    /// Fails if the evaluation order can not be computed, if the renumbered ids do not
    /// match positions, or (debug builds) if `check_compact` rejects the result.
    ///
    /// # Panics
    /// Panics if a kept node input, a graph input or a graph output refers to a node
    /// that is not kept.
    pub fn compact(&mut self) -> TractResult<()> {
        let mut order = self.eval_order()?;
        // Nothing to do when the order already is 0, 1, 2, ... over all nodes.
        if order.len() == self.nodes.len() && order.iter().enumerate().all(|(a, b)| a == *b) {
            return Ok(());
        }
        // Graph inputs are kept even when the evaluation order does not reach them.
        for i in &self.inputs {
            if !order.contains(&i.node) {
                order.push(i.node);
            }
        }
        // `usize::MAX` marks nodes that are dropped by the compaction.
        let mut old_to_new = vec![usize::MAX; self.nodes.len()];
        // Pre-fill the new list with dummy nodes, then swap the real nodes in.
        let mut new_nodes = vec![
            Node {
                id: self.nodes.len(),
                name: "".to_string(),
                inputs: vec![],
                op: self.create_dummy(),
                outputs: tvec!(),
            };
            order.len()
        ];
        for (ix, id) in order.iter().enumerate() {
            old_to_new[*id] = ix;
            std::mem::swap(&mut new_nodes[ix], &mut self.nodes[*id]);
        }
        for node in &mut new_nodes {
            // A graph input that is not a source op is turned into a source carrying the
            // fact of its first output (`node.id` still is the old id at this point).
            if self.inputs.iter().any(|n| n.node == node.id) && !Self::is_source(&node.op) {
                node.inputs.clear();
                node.op = self.create_source(node.outputs[0].fact.clone());
            }
            node.id = old_to_new[node.id];
            for input in &mut node.inputs {
                assert!(old_to_new[input.node] < order.len());
                input.node = old_to_new[input.node];
            }
            for output in &mut node.outputs {
                for succ in &mut output.successors {
                    succ.node = old_to_new[succ.node];
                }
                // Successors that were dropped got renumbered to `usize::MAX`: remove them.
                output.successors.retain(|s| s.node < order.len());
                output.successors.sort();
            }
        }
        self.nodes = new_nodes;
        for input in &mut self.inputs {
            assert!(old_to_new[input.node] < order.len());
            input.node = old_to_new[input.node];
        }
        for output in &mut self.outputs {
            assert!(old_to_new[output.node] < order.len());
            output.node = old_to_new[output.node];
        }
        // Renumber the labels, dropping those attached to dropped nodes.
        self.outlet_labels = std::mem::take(&mut self.outlet_labels)
            .into_iter()
            .map(|(k, v)| (OutletId::new(old_to_new[k.node], k.slot), v))
            .filter(|(k, _)| k.node < order.len())
            .collect();
        ensure!(self.nodes.iter().enumerate().all(|(ix, n)| n.id == ix));
        #[cfg(debug_assertions)]
        {
            self.check_compact().context("after graph compaction")?;
        }
        Ok(())
    }
781
782 pub fn into_compact(mut self) -> TractResult<Self> {
783 self.compact()?;
784 Ok(self)
785 }
786}