tract_core/model/
graph.rs

1use super::*;
2use crate::internal::*;
3use crate::ops::Op;
4use crate::plan::PlanOptions;
5use crate::prelude::*;
6
7use std::fmt;
8use tract_data::internal::*;
9use tract_itertools::Itertools;
10
11pub trait SpecialOps<F, O> {
12    fn create_dummy(&self) -> O;
13    fn create_source(&self, fact: F) -> O;
14    fn is_source(op: &O) -> bool;
15    fn wire_node(
16        &mut self,
17        name: impl Into<String>,
18        op: impl Into<O>,
19        inputs: &[OutletId],
20    ) -> TractResult<TVec<OutletId>>;
21    fn add_const(
22        &mut self,
23        name: impl Into<String>,
24        v: impl IntoArcTensor,
25    ) -> TractResult<OutletId>;
26}
27
28/// Main model class
29///
30/// Parameterized by a Fact class.
31#[derive(Clone, Debug)]
32pub struct Graph<F, O>
33where
34    F: Fact + Clone + 'static,
35    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
36{
37    /// all nodes in the model
38    pub nodes: Vec<Node<F, O>>,
39    /// model inputs
40    pub inputs: Vec<OutletId>,
41    /// model outputs
42    pub outputs: Vec<OutletId>,
43    /// outlet labels
44    pub outlet_labels: HashMap<OutletId, String>,
45    /// model properties
46    pub properties: HashMap<String, Arc<Tensor>>,
47    /// symbol scope, including table
48    pub symbols: SymbolScope,
49}
50
51impl<F, O> Default for Graph<F, O>
52where
53    F: Fact + Clone + 'static,
54    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
55{
56    fn default() -> Graph<F, O> {
57        Graph {
58            nodes: vec![],
59            inputs: vec![],
60            outputs: vec![],
61            outlet_labels: HashMap::new(),
62            properties: HashMap::new(),
63            symbols: Default::default(),
64        }
65    }
66}
67
68impl<F, O> Graph<F, O>
69where
70    F: Fact + Clone + 'static,
71    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
72    Graph<F, O>: SpecialOps<F, O>,
73{
74    pub fn add_source(&mut self, name: impl Into<String>, fact: F) -> TractResult<OutletId> {
75        let source = self.create_source(fact.clone());
76        let id = self.add_node(name, source, tvec!(fact))?;
77        let id = OutletId::new(id, 0);
78        self.inputs.push(id);
79        Ok(id)
80    }
81}
82
83impl<F, O> Graph<F, O>
84where
85    F: Fact + Clone + 'static,
86    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
87{
88    pub fn add_node(
89        &mut self,
90        name: impl Into<String>,
91        op: impl Into<O>,
92        output_facts: TVec<F>,
93    ) -> TractResult<usize> {
94        let op = op.into();
95        let name = name.into();
96        let id = self.nodes.len();
97        let outputs =
98            output_facts.into_iter().map(|fact| Outlet { fact, successors: tvec!() }).collect();
99        let node = Node { id, name, op, inputs: vec![], outputs };
100        self.nodes.push(node);
101        Ok(id)
102    }
103
104    /// Connect a node outlet to a node inlet.
105    pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()> {
106        if let Some(previous) = self.nodes[inlet.node].inputs.get(inlet.slot).cloned() {
107            self.nodes[previous.node].outputs[previous.slot]
108                .successors
109                .retain(|&mut succ| succ != inlet);
110        }
111        {
112            let prec = &mut self.nodes[outlet.node];
113            prec.outputs[outlet.slot].successors.push(inlet);
114        }
115        let succ = &mut self.nodes[inlet.node];
116        #[allow(clippy::comparison_chain)]
117        if inlet.slot == succ.inputs.len() {
118            succ.inputs.push(outlet);
119        } else if inlet.slot < succ.inputs.len() {
120            succ.inputs[inlet.slot] = outlet;
121        } else {
122            bail!("Edges must be added in order and consecutive. Trying to connect input {:?} of node {:?} ", inlet.slot, succ)
123        }
124        Ok(())
125    }
126
127    // Inputs
128
129    /// Get model inputs.
130    pub fn input_outlets(&self) -> TractResult<&[OutletId]> {
131        Ok(&self.inputs)
132    }
133
134    /// Change model inputs.
135    pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()> {
136        self.inputs = inputs.to_vec();
137        Ok(())
138    }
139
140    /// Change model inputs and return `self`.
141    pub fn with_input_outlets(mut self, inputs: &[OutletId]) -> TractResult<Self> {
142        self.set_input_outlets(inputs)?;
143        Ok(self)
144    }
145
146    /// Set model inputs by the node name.
147    pub fn set_input_names(
148        &mut self,
149        inputs: impl IntoIterator<Item = impl AsRef<str>>,
150    ) -> TractResult<()> {
151        let mut ids = vec![];
152        for i in inputs.into_iter() {
153            let node = self.node_by_name(&i)?;
154            for o in 0..node.outputs.len() {
155                ids.push(OutletId::new(node.id, o))
156            }
157        }
158        self.inputs = ids;
159        Ok(())
160    }
161
162    /// Set model inputs by the node name and return `self`.
163    pub fn with_input_names(
164        mut self,
165        inputs: impl IntoIterator<Item = impl AsRef<str>>,
166    ) -> TractResult<Self> {
167        self.set_input_names(inputs)?;
168        Ok(self)
169    }
170
171    /// Get the `ix`-th input tensor type information.
172    pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
173        let input = self.input_outlets()?[ix];
174        self.outlet_fact(input)
175    }
176
177    /// Get the `ix`-th input tensor type information, mutably.
178    pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
179        let input = self.input_outlets()?[ix];
180        self.outlet_fact_mut(input)
181    }
182
183    /// Set the `ix`-th input tensor type information.
184    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
185        let outlet = self.inputs[input];
186        self.set_outlet_fact(outlet, fact)
187    }
188
189    /// Set the `ix`-th input tensor type information and return `self`.
190    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
191        self.set_input_fact(input, fact)?;
192        Ok(self)
193    }
194
195    // Outputs
196    /// Get model outputs.
197    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
198        Ok(&self.outputs)
199    }
200
201    /// Guess outputs from the topology: node or nodes with no successors.
202    pub fn auto_outputs(&mut self) -> TractResult<()> {
203        let outputs = self
204            .nodes
205            .iter()
206            .flat_map(|n| {
207                let id = n.id;
208                n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
209                    (OutletId::new(id, ix), output_fact.successors.len())
210                })
211            })
212            .filter(|(_f, succs)| *succs == 0)
213            .map(|(f, _)| f)
214            .collect();
215        self.outputs = outputs;
216        Ok(())
217    }
218
219    /// Change model outputs.
220    pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
221        self.outputs = outputs.to_vec();
222        Ok(())
223    }
224
225    /// Change model outputs and return `self`.
226    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
227        self.set_output_outlets(outputs)?;
228        Ok(self)
229    }
230
231    /// Set model outputs by node names.
232    pub fn set_output_names(
233        &mut self,
234        outputs: impl IntoIterator<Item = impl AsRef<str>>,
235    ) -> TractResult<()> {
236        let mut labels: HashMap<Cow<str>, OutletId> =
237            self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
238        for n in self.nodes() {
239            for ix in 0..n.outputs.len() {
240                labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
241            }
242        }
243        let ids: Vec<OutletId> = outputs
244            .into_iter()
245            .map(|s| {
246                let s = s.as_ref();
247                labels
248                    .get(s)
249                    .cloned()
250                    .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
251                    .ok_or_else(|| format_err!("Node {} not found", s))
252            })
253            .collect::<TractResult<_>>()?;
254        self.outputs = ids;
255        Ok(())
256    }
257
258    /// Set model outputs by node names and return `self`.
259    pub fn with_output_names(
260        mut self,
261        outputs: impl IntoIterator<Item = impl AsRef<str>>,
262    ) -> TractResult<Self> {
263        self.set_output_names(outputs)?;
264        Ok(self)
265    }
266
267    /// Get the `ix`-th input tensor type information.
268    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
269        let output = self.output_outlets()?[ix];
270        self.outlet_fact(output)
271    }
272
273    /// Get the `ix`-th input tensor type information, mutably.
274    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
275        let output = self.output_outlets()?[ix];
276        self.outlet_fact_mut(output)
277    }
278
279    /// Set the `ix`-th output tensor type information.
280    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
281        let outlet = self.outputs[output];
282        self.set_outlet_fact(outlet, fact)
283    }
284
285    /// Set the `ix`-th output tensor type information and return `self`.
286    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
287        self.set_output_fact(output, fact)?;
288        Ok(self)
289    }
290
291    // nodes and their facts
292
293    /// Iterate over all node names.
294    pub fn node_names(&self) -> impl Iterator<Item = &str> {
295        self.nodes.iter().map(|s| &*s.name)
296    }
297
298    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
299        self.nodes
300            .iter()
301            .find(|n| n.name == name)
302            .map(|n| n.id)
303            .with_context(|| format!("No node found for name: \"{name}\""))
304    }
305
306    /// Find a node by its name.
307    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
308        let id: usize = self.node_id_by_name(name.as_ref())?;
309        Ok(&self.nodes[id])
310    }
311
312    /// Borrow mutably a node by its name.
313    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
314        let id: usize = self.node_id_by_name(name.as_ref())?;
315        Ok(&mut self.nodes[id])
316    }
317
318    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
319        self.node_mut(id).name = name.to_string();
320        Ok(())
321    }
322
323    /// Find a node by its id.
324    pub fn node(&self, id: usize) -> &Node<F, O> {
325        &self.nodes[id]
326    }
327
328    /// Find a node by its id.
329    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
330        &mut self.nodes[id]
331    }
332
333    /// Access the nodes table.
334    pub fn nodes(&self) -> &[Node<F, O>] {
335        &self.nodes
336    }
337
338    /// Access the nodes table.
339    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
340        &mut self.nodes
341    }
342
343    /// Get input and output tensor information for a node.
344    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
345        Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
346    }
347
348    /// Get input tensor information for a node.
349    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
350        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
351    }
352
353    /// Get output tensor information for a node.
354    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
355        Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
356    }
357
358    // outlets
359
360    /// Get tensor information for a single outlet.
361    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
362        ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
363        let outlets = &self.nodes[outlet.node].outputs;
364        outlets
365            .get(outlet.slot)
366            .map(|o| &o.fact)
367            .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
368    }
369
370    /// Get tensor information for a single outlet.
371    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
372        let outlets = &mut self.nodes[outlet.node].outputs;
373        outlets
374            .get_mut(outlet.slot)
375            .map(|o| &mut o.fact)
376            .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
377    }
378
379    /// Get multiple mutable tensor information for outlets.
380    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
381        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
382        unsafe {
383            outlets
384                .iter()
385                .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
386                .collect()
387        }
388    }
389
390    /// Set tensor information for a single outlet.
391    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
392        let outlets = &mut self.nodes[outlet.node].outputs;
393        if outlets.len() <= outlet.slot {
394            bail!("Invalid outlet refererence: {:?}", outlet)
395        }
396        outlets[outlet.slot].fact = fact;
397        Ok(())
398    }
399
400    /// Set tensor information for a single outlet and return `self`.
401    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
402        self.set_outlet_fact(outlet, fact)?;
403        Ok(self)
404    }
405
406    // outlet labels
407
408    /// Get label for an outlet.
409    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
410        self.outlet_labels.get(&outlet).map(|s| &**s)
411    }
412
413    /// Set label for an outlet.
414    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
415        self.outlet_labels.insert(outlet, label);
416        Ok(())
417    }
418
419    /// Set label for an outlet and return `self`.
420    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
421        self.set_outlet_label(outlet, label)?;
422        Ok(self)
423    }
424
425    /// Find outlet by label.
426    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
427        self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
428    }
429
430    // misc
431
432    /// Computes an evalutation order for the graph inputs and outputs
433    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
434        super::order::eval_order(self)
435    }
436
437    /// Computes an evalutation order for the graph inputs and outputs. This order will minimize
438    /// temporary buffers.
439    pub fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>> {
440        super::order::eval_order_opt_ram(self)
441    }
442
443    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
444    #[inline]
445    pub fn check_edges(&self) -> TractResult<()> {
446        Ok(())
447    }
448
449    /// Performs a sanity check on network connections.
450    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
451    pub fn check_edges(&self) -> TractResult<()> {
452        for node_id in self.eval_order()? {
453            let node = &self.nodes[node_id];
454            for (ix, input) in node.inputs.iter().enumerate() {
455                let prec = &self.nodes[input.node];
456                if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
457                    bail!(
458                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
459                        node.id,
460                        ix,
461                        prec
462                    )
463                }
464            }
465            for (ix, output) in node.outputs.iter().enumerate() {
466                for succ in &output.successors {
467                    if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
468                        bail!(
469                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
470                            node.id,
471                            ix,
472                            succ
473                        )
474                    }
475                }
476            }
477        }
478        Ok(())
479    }
480
481    /// Evaluate temporary memory usage with its related node at each step of the given order.
482    pub fn eval_tmp_memory_usage<Flushable>(
483        &self,
484        order: &[usize],
485        flushable: Flushable,
486    ) -> TractResult<TVec<(usize, TDim)>>
487    where
488        Flushable: Fn(&Node<F, O>) -> bool,
489    {
490        super::memory::eval_tmp_memory_usage(self, order, flushable)
491    }
492
493    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
494    #[inline]
495    pub fn check_names(&self) -> TractResult<()> {
496        Ok(())
497    }
498
499    /// Performs a sanity check on network connections.
500    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
501    pub fn check_names(&self) -> TractResult<()> {
502        let dups =
503            self.eval_order()?.iter().map(|n| &self.nodes[*n].name).duplicates().collect_vec();
504        ensure!(dups.len() == 0, "Duplicate node name(s) : {:?}\n{}", dups, &self);
505        Ok(())
506    }
507
508    /// Converts the model into a `RunnableModel` to actually process user data.
509    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
510        crate::plan::SimplePlan::new_with_options(self, &PlanOptions::default())
511    }
512
513    /// Converts the model into a `RunnableModel` to actually process user data. This variant
514    /// accepts options.
515    pub fn into_runnable_with_options(
516        self,
517        options: &PlanOptions,
518    ) -> TractResult<RunnableModel<F, O, Self>> {
519        crate::plan::SimplePlan::new_with_options(self, options)
520    }
521
522    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
523        let node = &self.nodes()[id];
524        if node.inputs.len() != 1 {
525            return Ok(None);
526        }
527        let prec = &self.nodes()[node.inputs[0].node];
528        if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
529            return Ok(None);
530        }
531        Ok(Some(prec))
532    }
533
534    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
535        let mut node = self.node(id);
536        for _ in 0..count {
537            if let Some(next) = self.single_prec(node.id)? {
538                node = next
539            } else {
540                return Ok(None);
541            }
542        }
543        Ok(Some(node))
544    }
545
546    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
547        let mut node = self.node(id);
548        for _ in 0..count {
549            if let Some(next) = self.single_succ(node.id)? {
550                node = next
551            } else {
552                return Ok(None);
553            }
554        }
555        Ok(Some(node))
556    }
557
558    /// single_succ is only intended for optimisation of simple operators
559    /// with 1 output, and only 1 output successors (successor with only 1 input)
560    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
561        let node = &self.nodes()[id];
562
563        if node.outputs.len() != 1 || node.outputs[0].successors.len() != 1 {
564            return Ok(None);
565        }
566        let succ = node.outputs[0].successors[0];
567        let succ = &self.nodes()[succ.node];
568        if succ.inputs.len() != 1 {
569            return Ok(None);
570        }
571        Ok(Some(succ))
572    }
573
574    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
575        &self.nodes[outlet.node].outputs[outlet.slot].successors
576    }
577
578    /// retrieve of create a symbol
579    pub fn sym(&self, s: &str) -> Symbol {
580        self.symbols.sym(s)
581    }
582
583    /// create a new symbol with the prefix
584    pub fn new_sym_with_prefix(&self, prefix: &str) -> Symbol {
585        self.symbols.new_with_prefix(prefix)
586    }
587
588    /// generates a name for a new node in the model that will not conflict (by suffixing with a
589    /// dot and number)
590    pub fn unique_name<'n>(&self, prefix: impl Into<Cow<'n, str>>) -> Cow<'n, str> {
591        let prefix = prefix.into();
592        if self.nodes.iter().all(|n| n.name != *prefix) {
593            return prefix;
594        }
595        for i in 1.. {
596            let s = format!("{prefix}.{i}");
597            if self.nodes.iter().all(|n| n.name != s) {
598                return Cow::Owned(s);
599            }
600        }
601        unreachable!();
602    }
603}
604
605impl<F, O> fmt::Display for Graph<F, O>
606where
607    F: Fact + Clone + 'static,
608    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
609{
610    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
611        for i in 0..self.nodes.len() {
612            let input_1 =
613                self.nodes[i].inputs.first().map(|o| format!("{o:?}")).unwrap_or_default();
614            let input_2 = self.nodes[i].inputs.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
615            let successors = self.nodes[i]
616                .outputs
617                .first()
618                .iter()
619                .flat_map(|o| o.successors.iter())
620                .collect_vec();
621            let output_1 = successors.first().map(|o| format!("{o:?}")).unwrap_or_default();
622            let output_2 = successors.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
623            writeln!(
624                fmt,
625                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {} => {}",
626                i,
627                input_1,
628                input_2,
629                output_1,
630                output_2,
631                self.nodes[i].op().name(),
632                self.nodes[i].name,
633                self.node_input_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
634                self.node_output_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
635            )?;
636            if self.nodes[i].inputs.len() > 2 {
637                writeln!(
638                    fmt,
639                    "                                               |   * inputs: {}",
640                    self.nodes[i].inputs.iter().map(|s| format!("{s:?}")).join(", ")
641                )?;
642            }
643            if self.nodes[i].outputs.len() > 1
644                || successors.len() > 2
645                || (self.outlet_label(i.into()).is_some()
646                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
647            {
648                for o in 0..self.nodes[i].outputs.len() {
649                    if self.outlet_successors((i, o).into()).len() > 0 {
650                        writeln!(
651                                    fmt,
652                                    "                                               |   * output #{}: {} {}",
653                                    o,
654                                    self.outlet_label((i, o).into()).unwrap_or(""),
655                                    self.outlet_successors((i, o).into())
656                                    .iter()
657                                    .map(|s| format!("{s:?}"))
658                                    .join(", "),
659                                    )?;
660                    }
661                }
662            }
663        }
664        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{o:?}")).join(", "))?;
665        Ok(())
666    }
667}
668
669impl<F, O> Graph<F, O>
670where
671    F: Fact + Clone + 'static + for<'a> std::convert::From<&'a F>,
672    O: std::fmt::Display
673        + std::fmt::Debug
674        + Clone
675        + AsRef<dyn Op>
676        + AsMut<dyn Op>
677        + Clone
678        + 'static
679        + for<'a> std::convert::From<&'a O>,
680    Graph<F, O>: SpecialOps<F, O>,
681{
682    #[cfg(debug_assertions)]
683    pub fn check_compact(&self) -> TractResult<()> {
684        let order = self.eval_order()?;
685        let useless_sources = self
686            .input_outlets()?
687            .iter()
688            .filter(|io| {
689                self.outlet_successors(**io).len() == 0
690                    && !self.output_outlets().unwrap().contains(io)
691            })
692            .count();
693        if order.len() + useless_sources != self.nodes.len() {
694            bail!(
695                "Eval order is {} long, nodes are {}, including {} unused sources",
696                order.len(),
697                self.nodes.len(),
698                useless_sources
699            );
700        }
701        if (0..order.len()).any(|ix| order[ix] != ix) {
702            bail!("eval order is not trivial");
703        }
704        let mut seen = std::collections::HashSet::new();
705        for (ix, n) in self.nodes.iter().enumerate() {
706            if ix != n.id {
707                bail!("Invalid node id: position is {}, node is {}", ix, n);
708            }
709            if seen.contains(&n.name) {
710                bail!("duplicate name {}", n.name);
711            }
712            seen.insert(&n.name);
713        }
714        Ok(())
715    }
716
717    pub fn compact(&mut self) -> TractResult<()> {
718        let mut order = self.eval_order()?;
719        if order.len() == self.nodes.len() && order.iter().enumerate().all(|(a, b)| a == *b) {
720            return Ok(());
721        }
722        for i in &self.inputs {
723            if !order.contains(&i.node) {
724                order.push(i.node);
725            }
726        }
727        let mut old_to_new = vec![usize::MAX; self.nodes.len()];
728        let mut new_nodes = vec![
729            Node {
730                id: self.nodes.len(),
731                name: "".to_string(),
732                inputs: vec![],
733                op: self.create_dummy(),
734                outputs: tvec!(),
735            };
736            order.len()
737        ];
738        for (ix, id) in order.iter().enumerate() {
739            old_to_new[*id] = ix;
740            std::mem::swap(&mut new_nodes[ix], &mut self.nodes[*id]);
741        }
742        for node in &mut new_nodes {
743            if self.inputs.iter().any(|n| n.node == node.id) && !Self::is_source(&node.op) {
744                node.inputs.clear();
745                node.op = self.create_source(node.outputs[0].fact.clone());
746            }
747            node.id = old_to_new[node.id];
748            for input in &mut node.inputs {
749                assert!(old_to_new[input.node] < order.len());
750                input.node = old_to_new[input.node];
751            }
752            for output in &mut node.outputs {
753                for succ in &mut output.successors {
754                    succ.node = old_to_new[succ.node];
755                }
756                output.successors.retain(|s| s.node < order.len());
757                output.successors.sort();
758            }
759        }
760        self.nodes = new_nodes;
761        for input in &mut self.inputs {
762            assert!(old_to_new[input.node] < order.len());
763            input.node = old_to_new[input.node];
764        }
765        for output in &mut self.outputs {
766            assert!(old_to_new[output.node] < order.len());
767            output.node = old_to_new[output.node];
768        }
769        self.outlet_labels = std::mem::take(&mut self.outlet_labels)
770            .into_iter()
771            .map(|(k, v)| (OutletId::new(old_to_new[k.node], k.slot), v))
772            .filter(|(k, _)| k.node < order.len())
773            .collect();
774        ensure!(self.nodes.iter().enumerate().all(|(ix, n)| n.id == ix));
775        #[cfg(debug_assertions)]
776        {
777            self.check_compact().context("after graph compaction")?;
778        }
779        Ok(())
780    }
781
782    pub fn into_compact(mut self) -> TractResult<Self> {
783        self.compact()?;
784        Ok(self)
785    }
786}