Skip to main content

tract_core/model/
graph.rs

1use super::*;
2use crate::internal::*;
3use crate::ops::Op;
4use crate::prelude::*;
5use crate::runtime::RunOptions;
6
7use std::fmt;
8use tract_data::internal::*;
9use tract_itertools::Itertools;
10
11pub trait SpecialOps<F, O> {
12    fn create_dummy(&self) -> O;
13    fn create_source(&self, fact: F) -> O;
14    fn is_source(op: &O) -> bool;
15    fn wire_node(
16        &mut self,
17        name: impl Into<String>,
18        op: impl Into<O>,
19        inputs: &[OutletId],
20    ) -> TractResult<TVec<OutletId>>;
21    fn add_const(
22        &mut self,
23        name: impl Into<String>,
24        v: impl IntoArcTensor,
25    ) -> TractResult<OutletId>;
26}
27
28/// Main model class
29///
30/// Parameterized by a Fact class.
31#[derive(Clone, Debug)]
32pub struct Graph<F, O>
33where
34    F: Fact + Clone + 'static,
35    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
36{
37    /// all nodes in the model
38    pub nodes: Vec<Node<F, O>>,
39    /// model inputs
40    pub inputs: Vec<OutletId>,
41    /// model outputs
42    pub outputs: Vec<OutletId>,
43    /// outlet labels
44    pub outlet_labels: HashMap<OutletId, String>,
45    /// model properties
46    pub properties: HashMap<String, Arc<Tensor>>,
47    /// symbol scope, including table
48    pub symbols: SymbolScope,
49}
50
51impl<F, O> Default for Graph<F, O>
52where
53    F: Fact + Clone + 'static,
54    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
55{
56    fn default() -> Graph<F, O> {
57        Graph {
58            nodes: vec![],
59            inputs: vec![],
60            outputs: vec![],
61            outlet_labels: HashMap::new(),
62            properties: HashMap::new(),
63            symbols: Default::default(),
64        }
65    }
66}
67
68impl<F, O> Graph<F, O>
69where
70    F: Fact + Clone + 'static,
71    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
72    Graph<F, O>: SpecialOps<F, O>,
73{
74    pub fn add_source(&mut self, name: impl Into<String>, fact: F) -> TractResult<OutletId> {
75        let source = self.create_source(fact.clone());
76        let id = self.add_node(name, source, tvec!(fact))?;
77        let id = OutletId::new(id, 0);
78        self.inputs.push(id);
79        Ok(id)
80    }
81}
82
83impl<F, O> Graph<F, O>
84where
85    F: Fact + Clone + 'static,
86    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
87{
88    pub fn add_node(
89        &mut self,
90        name: impl Into<String>,
91        op: impl Into<O>,
92        output_facts: TVec<F>,
93    ) -> TractResult<usize> {
94        let op = op.into();
95        let name = name.into();
96        let id = self.nodes.len();
97        let outputs =
98            output_facts.into_iter().map(|fact| Outlet { fact, successors: tvec!() }).collect();
99        let node = Node { id, name, op, inputs: vec![], outputs };
100        self.nodes.push(node);
101        Ok(id)
102    }
103
104    /// Connect a node outlet to a node inlet.
105    pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()> {
106        if let Some(previous) = self.nodes[inlet.node].inputs.get(inlet.slot).cloned() {
107            self.nodes[previous.node].outputs[previous.slot]
108                .successors
109                .retain(|&mut succ| succ != inlet);
110        }
111        {
112            let prec = &mut self.nodes[outlet.node];
113            prec.outputs[outlet.slot].successors.push(inlet);
114        }
115        let succ = &mut self.nodes[inlet.node];
116        #[allow(clippy::comparison_chain)]
117        if inlet.slot == succ.inputs.len() {
118            succ.inputs.push(outlet);
119        } else if inlet.slot < succ.inputs.len() {
120            succ.inputs[inlet.slot] = outlet;
121        } else {
122            bail!(
123                "Edges must be added in order and consecutive. Trying to connect input {:?} of node {:?} ",
124                inlet.slot,
125                succ
126            )
127        }
128        Ok(())
129    }
130
131    // Inputs
132
133    /// Get model inputs.
134    pub fn input_outlets(&self) -> TractResult<&[OutletId]> {
135        Ok(&self.inputs)
136    }
137
138    /// Change model inputs.
139    pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()> {
140        self.inputs = inputs.to_vec();
141        Ok(())
142    }
143
144    /// Change model inputs and return `self`.
145    pub fn with_input_outlets(mut self, inputs: &[OutletId]) -> TractResult<Self> {
146        self.set_input_outlets(inputs)?;
147        Ok(self)
148    }
149
150    /// Set model inputs by the node name.
151    pub fn set_input_names(
152        &mut self,
153        inputs: impl IntoIterator<Item = impl AsRef<str>>,
154    ) -> TractResult<()> {
155        let mut ids = vec![];
156        for i in inputs.into_iter() {
157            let node = self.node_by_name(&i)?;
158            for o in 0..node.outputs.len() {
159                ids.push(OutletId::new(node.id, o))
160            }
161        }
162        self.inputs = ids;
163        Ok(())
164    }
165
166    /// Set model inputs by the node name and return `self`.
167    pub fn with_input_names(
168        mut self,
169        inputs: impl IntoIterator<Item = impl AsRef<str>>,
170    ) -> TractResult<Self> {
171        self.set_input_names(inputs)?;
172        Ok(self)
173    }
174
175    /// Set model inputs by node name — mirror of [`Self::select_outputs_by_name`].
176    /// Removed inputs become dangling Source nodes; declutter prunes them.
177    pub fn select_inputs_by_name(
178        &mut self,
179        inputs: impl IntoIterator<Item = impl AsRef<str>>,
180    ) -> TractResult<()> {
181        self.set_input_names(inputs)
182    }
183
184    /// Set model inputs by node name and return `self`.
185    pub fn with_inputs_by_name(
186        mut self,
187        inputs: impl IntoIterator<Item = impl AsRef<str>>,
188    ) -> TractResult<Self> {
189        self.select_inputs_by_name(inputs)?;
190        Ok(self)
191    }
192
193    /// Get the `ix`-th input tensor type information.
194    pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
195        let input = self.input_outlets()?[ix];
196        self.outlet_fact(input)
197    }
198
199    /// Get the `ix`-th input tensor type information, mutably.
200    pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
201        let input = self.input_outlets()?[ix];
202        self.outlet_fact_mut(input)
203    }
204
205    /// Set the `ix`-th input tensor type information.
206    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
207        let outlet = self.inputs[input];
208        self.set_outlet_fact(outlet, fact)
209    }
210
211    /// Set the `ix`-th input tensor type information and return `self`.
212    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
213        self.set_input_fact(input, fact)?;
214        Ok(self)
215    }
216
217    // Outputs
218    /// Get model outputs.
219    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
220        Ok(&self.outputs)
221    }
222
223    /// Guess outputs from the topology: node or nodes with no successors.
224    pub fn auto_outputs(&mut self) -> TractResult<()> {
225        let outputs = self
226            .nodes
227            .iter()
228            .flat_map(|n| {
229                let id = n.id;
230                n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
231                    (OutletId::new(id, ix), output_fact.successors.len())
232                })
233            })
234            .filter(|(_f, succs)| *succs == 0)
235            .map(|(f, _)| f)
236            .collect();
237        self.outputs = outputs;
238        Ok(())
239    }
240
241    /// Change model outputs.
242    pub fn select_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
243        self.outputs = outputs.to_vec();
244        Ok(())
245    }
246
247    /// Change model outputs and return `self`.
248    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
249        self.select_output_outlets(outputs)?;
250        Ok(self)
251    }
252
253    /// Set model outputs by node names.
254    pub fn select_outputs_by_name(
255        &mut self,
256        outputs: impl IntoIterator<Item = impl AsRef<str>>,
257    ) -> TractResult<()> {
258        let mut labels: HashMap<StaticName, OutletId> =
259            self.outlet_labels.iter().map(|(o, s)| (Cow::Owned((*s).to_string()), *o)).collect();
260        for n in self.nodes() {
261            for ix in 0..n.outputs.len() {
262                labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
263            }
264        }
265        let ids: Vec<OutletId> = outputs
266            .into_iter()
267            .map(|s| {
268                let s = s.as_ref();
269                labels
270                    .get(s)
271                    .cloned()
272                    .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
273                    .ok_or_else(|| format_err!("Node {} not found", s))
274            })
275            .collect::<TractResult<_>>()?;
276        self.outputs = ids;
277        Ok(())
278    }
279
280    /// Set model outputs by node names and return `self`.
281    pub fn with_outputs_by_name(
282        mut self,
283        outputs: impl IntoIterator<Item = impl AsRef<str>>,
284    ) -> TractResult<Self> {
285        self.select_outputs_by_name(outputs)?;
286        Ok(self)
287    }
288
289    /// Get the `ix`-th input tensor type information.
290    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
291        let output = self.output_outlets()?[ix];
292        self.outlet_fact(output)
293    }
294
295    /// Get the `ix`-th input tensor type information, mutably.
296    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
297        let output = self.output_outlets()?[ix];
298        self.outlet_fact_mut(output)
299    }
300
301    /// Set the `ix`-th output tensor type information.
302    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
303        let outlet = self.outputs[output];
304        self.set_outlet_fact(outlet, fact)
305    }
306
307    /// Set the `ix`-th output tensor type information and return `self`.
308    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
309        self.set_output_fact(output, fact)?;
310        Ok(self)
311    }
312
313    // nodes and their facts
314
315    /// Iterate over all node names.
316    pub fn node_names(&self) -> impl Iterator<Item = &str> {
317        self.nodes.iter().map(|s| &*s.name)
318    }
319
320    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
321        self.nodes
322            .iter()
323            .find(|n| n.name == name)
324            .map(|n| n.id)
325            .with_context(|| format!("No node found for name: \"{name}\""))
326    }
327
328    /// Find a node by its name.
329    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
330        let id: usize = self.node_id_by_name(name.as_ref())?;
331        Ok(&self.nodes[id])
332    }
333
334    /// Borrow mutably a node by its name.
335    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
336        let id: usize = self.node_id_by_name(name.as_ref())?;
337        Ok(&mut self.nodes[id])
338    }
339
340    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
341        self.node_mut(id).name = name.to_string();
342        Ok(())
343    }
344
345    /// Find a node by its id.
346    pub fn node(&self, id: usize) -> &Node<F, O> {
347        &self.nodes[id]
348    }
349
350    /// Find a node by its id.
351    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
352        &mut self.nodes[id]
353    }
354
355    /// Access the nodes table.
356    pub fn nodes(&self) -> &[Node<F, O>] {
357        &self.nodes
358    }
359
360    /// Access the nodes table.
361    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
362        &mut self.nodes
363    }
364
365    /// Get input and output tensor information for a node.
366    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
367        Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
368    }
369
370    /// Get input tensor information for a node.
371    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
372        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
373    }
374
375    /// Get output tensor information for a node.
376    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
377        Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
378    }
379
380    // outlets
381
382    /// Get tensor information for a single outlet.
383    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
384        ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
385        let outlets = &self.nodes[outlet.node].outputs;
386        outlets
387            .get(outlet.slot)
388            .map(|o| &o.fact)
389            .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
390    }
391
392    /// Get tensor information for a single outlet.
393    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
394        let outlets = &mut self.nodes[outlet.node].outputs;
395        outlets
396            .get_mut(outlet.slot)
397            .map(|o| &mut o.fact)
398            .with_context(|| format!("Invalid outlet reference: {outlet:?}"))
399    }
400
401    /// Get multiple mutable tensor information for outlets.
402    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
403        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
404        unsafe {
405            outlets
406                .iter()
407                .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
408                .collect()
409        }
410    }
411
412    /// Set tensor information for a single outlet.
413    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
414        let outlets = &mut self.nodes[outlet.node].outputs;
415        if outlets.len() <= outlet.slot {
416            bail!("Invalid outlet refererence: {:?}", outlet)
417        }
418        outlets[outlet.slot].fact = fact;
419        Ok(())
420    }
421
422    /// Set tensor information for a single outlet and return `self`.
423    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
424        self.set_outlet_fact(outlet, fact)?;
425        Ok(self)
426    }
427
428    // outlet labels
429
430    /// Get label for an outlet.
431    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
432        self.outlet_labels.get(&outlet).map(|s| &**s)
433    }
434
435    /// Set label for an outlet.
436    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
437        self.outlet_labels.insert(outlet, label);
438        Ok(())
439    }
440
441    /// Set label for an outlet and return `self`.
442    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
443        self.set_outlet_label(outlet, label)?;
444        Ok(self)
445    }
446
447    /// Find outlet by label.
448    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
449        self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
450    }
451
452    // misc
453
454    /// Computes an evalutation order for the graph inputs and outputs
455    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
456        super::order::eval_order(self)
457    }
458
459    /// Computes an evalutation order for the graph inputs and outputs. This order will minimize
460    /// temporary buffers.
461    pub fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>> {
462        super::order::eval_order_opt_ram(self)
463    }
464
465    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
466    #[inline]
467    pub fn check_edges(&self) -> TractResult<()> {
468        Ok(())
469    }
470
471    /// Performs a sanity check on network connections.
472    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
473    pub fn check_edges(&self) -> TractResult<()> {
474        for node_id in self.eval_order()? {
475            let node = &self.nodes[node_id];
476            for (ix, input) in node.inputs.iter().enumerate() {
477                let prec = &self.nodes[input.node];
478                if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
479                    bail!(
480                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
481                        node.id,
482                        ix,
483                        prec
484                    )
485                }
486            }
487            for (ix, output) in node.outputs.iter().enumerate() {
488                for succ in &output.successors {
489                    if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
490                        bail!(
491                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
492                            node.id,
493                            ix,
494                            succ
495                        )
496                    }
497                }
498            }
499        }
500        Ok(())
501    }
502
503    /// Evaluate temporary memory usage with its related node at each step of the given order.
504    pub fn eval_tmp_memory_usage<Flushable>(
505        &self,
506        order: &[usize],
507        flushable: Flushable,
508    ) -> TractResult<TVec<(usize, TDim)>>
509    where
510        Flushable: Fn(&Node<F, O>) -> bool,
511    {
512        super::memory::eval_tmp_memory_usage(self, order, flushable)
513    }
514
515    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
516    #[inline]
517    pub fn check_names(&self) -> TractResult<()> {
518        Ok(())
519    }
520
521    /// Performs a sanity check on network connections.
522    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
523    pub fn check_names(&self) -> TractResult<()> {
524        let dups =
525            self.eval_order()?.iter().map(|n| &self.nodes[*n].name).duplicates().collect_vec();
526        ensure!(dups.len() == 0, "Duplicate node name(s) : {:?}\n{}", dups, &self);
527        Ok(())
528    }
529
530    // Converts the model into a `RunnableModel` to actually process user data.
531    // pub fn into_runnable(self) -> TractResult<Arc<RunnableModel<F, O>>> {
532    //     crate::plan::SimplePlan::new_with_options(self, &PlanOptions::default())
533    // }
534
535    /// Converts the model into a `RunnableModel` to actually process user data. This variant
536    /// accepts options.
537    pub fn into_runnable_with_options(
538        self,
539        options: &RunOptions,
540    ) -> TractResult<Arc<RunnableModel<F, O>>> {
541        crate::plan::SimplePlan::new_with_options(self, options)
542    }
543
544    pub fn linear_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
545        let node = &self.nodes()[id];
546        rule_if!(node.inputs.len() == 1);
547        let prec = &self.nodes()[node.inputs[0].node];
548        rule_if!(prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() == 1);
549        Ok(Some(prec))
550    }
551
552    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
553        let node = &self.nodes()[id];
554        rule_if!(node.inputs.len() == 1);
555        let prec = &self.nodes()[node.inputs[0].node];
556        Ok(Some(prec))
557    }
558
559    pub fn all_prec(&self, id: usize) -> TractResult<Option<TVec<&Node<F, O>>>> {
560        let node = &self.nodes()[id];
561        rule_if!(node.inputs.len() > 0);
562        Ok(Some(node.inputs.iter().map(|n| &self.nodes()[n.node]).collect()))
563    }
564
565    /// linear_succ is only intended for optimisation of simple operators
566    /// with 1 output, and only 1 output successors (successor with only 1 input)
567    pub fn linear_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
568        let node = &self.nodes()[id];
569
570        rule_if!(node.outputs.len() == 1);
571        rule_if!(node.outputs[0].successors.len() == 1);
572        let succ = node.outputs[0].successors[0];
573        let succ = &self.nodes()[succ.node];
574        rule_if!(succ.inputs.len() == 1);
575        Ok(Some(succ))
576    }
577
578    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
579        let node = &self.nodes()[id];
580
581        rule_if!(node.outputs.len() == 1);
582        rule_if!(node.outputs[0].successors.len() == 1);
583        let succ = node.outputs[0].successors[0];
584        Ok(Some(&self.nodes()[succ.node]))
585    }
586
587    pub fn all_succ(&self, id: usize) -> TractResult<Option<TVec<&Node<F, O>>>> {
588        let node = &self.nodes()[id];
589        rule_if!(!node.outputs.is_empty());
590
591        Ok(Some(
592            node.outputs
593                .iter()
594                .flat_map(|o| {
595                    o.successors.iter().map(|succ| &self.nodes()[succ.node]).collect::<Vec<_>>()
596                })
597                .collect(),
598        ))
599    }
600
601    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
602        &self.nodes[outlet.node].outputs[outlet.slot].successors
603    }
604
605    /// retrieve of create a symbol
606    pub fn sym(&self, s: &str) -> Symbol {
607        self.symbols.sym(s)
608    }
609
610    /// create a new symbol with the prefix
611    pub fn new_sym_with_prefix(&self, prefix: &str) -> Symbol {
612        self.symbols.new_with_prefix(prefix)
613    }
614
615    /// generates a name for a new node in the model that will not conflict (by suffixing with a
616    /// dot and number)
617    pub fn unique_name<'n>(&self, prefix: impl Into<Cow<'n, str>>) -> Cow<'n, str> {
618        let prefix = prefix.into();
619        if self.nodes.iter().all(|n| n.name != *prefix) {
620            return prefix;
621        }
622        for i in 1.. {
623            let s = format!("{prefix}.{i}");
624            if self.nodes.iter().all(|n| n.name != s) {
625                return Cow::Owned(s);
626            }
627        }
628        unreachable!();
629    }
630}
631
632impl<F, O> fmt::Display for Graph<F, O>
633where
634    F: Fact + Clone + 'static,
635    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
636{
637    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
638        for i in 0..self.nodes.len() {
639            let input_1 =
640                self.nodes[i].inputs.first().map(|o| format!("{o:?}")).unwrap_or_default();
641            let input_2 = self.nodes[i].inputs.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
642            let successors = self.nodes[i]
643                .outputs
644                .first()
645                .iter()
646                .flat_map(|o| o.successors.iter())
647                .collect_vec();
648            let output_1 = successors.first().map(|o| format!("{o:?}")).unwrap_or_default();
649            let output_2 = successors.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
650            writeln!(
651                fmt,
652                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {} => {}",
653                i,
654                input_1,
655                input_2,
656                output_1,
657                output_2,
658                self.nodes[i].op().name(),
659                self.nodes[i].name,
660                self.node_input_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
661                self.node_output_facts(i).unwrap().iter().map(|f| format!("{f:?}")).join(" ; "),
662            )?;
663            if self.nodes[i].inputs.len() > 2 {
664                writeln!(
665                    fmt,
666                    "                                               |   * inputs: {}",
667                    self.nodes[i].inputs.iter().map(|s| format!("{s:?}")).join(", ")
668                )?;
669            }
670            if self.nodes[i].outputs.len() > 1
671                || successors.len() > 2
672                || (self.outlet_label(i.into()).is_some()
673                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
674            {
675                for o in 0..self.nodes[i].outputs.len() {
676                    if self.outlet_successors((i, o).into()).len() > 0 {
677                        writeln!(
678                            fmt,
679                            "                                               |   * output #{}: {} {}",
680                            o,
681                            self.outlet_label((i, o).into()).unwrap_or(""),
682                            self.outlet_successors((i, o).into())
683                                .iter()
684                                .map(|s| format!("{s:?}"))
685                                .join(", "),
686                        )?;
687                    }
688                }
689            }
690        }
691        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{o:?}")).join(", "))?;
692        Ok(())
693    }
694}
695
696impl<F, O> Graph<F, O>
697where
698    F: Fact + Clone + 'static + for<'a> std::convert::From<&'a F>,
699    O: std::fmt::Display
700        + std::fmt::Debug
701        + Clone
702        + AsRef<dyn Op>
703        + AsMut<dyn Op>
704        + Clone
705        + 'static
706        + for<'a> std::convert::From<&'a O>,
707    Graph<F, O>: SpecialOps<F, O>,
708{
709    #[cfg(debug_assertions)]
710    pub fn check_compact(&self) -> TractResult<()> {
711        let order = self.eval_order()?;
712        let useless_sources = self
713            .input_outlets()?
714            .iter()
715            .filter(|io| {
716                self.outlet_successors(**io).len() == 0
717                    && !self.output_outlets().unwrap().contains(io)
718            })
719            .count();
720        if order.len() + useless_sources != self.nodes.len() {
721            bail!(
722                "Eval order is {} long, nodes are {}, including {} unused sources",
723                order.len(),
724                self.nodes.len(),
725                useless_sources
726            );
727        }
728        if (0..order.len()).any(|ix| order[ix] != ix) {
729            bail!("eval order is not trivial");
730        }
731        let mut seen = std::collections::HashSet::new();
732        for (ix, n) in self.nodes.iter().enumerate() {
733            if ix != n.id {
734                bail!("Invalid node id: position is {}, node is {}", ix, n);
735            }
736            if seen.contains(&n.name) {
737                bail!("duplicate name for node {n}");
738            }
739            seen.insert(&n.name);
740        }
741        Ok(())
742    }
743
744    pub fn compact(&mut self) -> TractResult<()> {
745        let mut order = self.eval_order()?;
746        if order.len() == self.nodes.len() && order.iter().enumerate().all(|(a, b)| a == *b) {
747            return Ok(());
748        }
749        for i in &self.inputs {
750            if !order.contains(&i.node) {
751                order.push(i.node);
752            }
753        }
754        let mut old_to_new = vec![usize::MAX; self.nodes.len()];
755        let mut new_nodes = vec![
756            Node {
757                id: self.nodes.len(),
758                name: "".to_string(),
759                inputs: vec![],
760                op: self.create_dummy(),
761                outputs: tvec!(),
762            };
763            order.len()
764        ];
765        for (ix, id) in order.iter().enumerate() {
766            old_to_new[*id] = ix;
767            std::mem::swap(&mut new_nodes[ix], &mut self.nodes[*id]);
768        }
769        for node in &mut new_nodes {
770            if self.inputs.iter().any(|n| n.node == node.id) && !Self::is_source(&node.op) {
771                node.inputs.clear();
772                node.op = self.create_source(node.outputs[0].fact.clone());
773            }
774            node.id = old_to_new[node.id];
775            for input in &mut node.inputs {
776                assert!(old_to_new[input.node] < order.len());
777                input.node = old_to_new[input.node];
778            }
779            for output in &mut node.outputs {
780                for succ in &mut output.successors {
781                    succ.node = old_to_new[succ.node];
782                }
783                output.successors.retain(|s| s.node < order.len());
784                output.successors.sort();
785            }
786        }
787        self.nodes = new_nodes;
788        for input in &mut self.inputs {
789            assert!(old_to_new[input.node] < order.len());
790            input.node = old_to_new[input.node];
791        }
792        for output in &mut self.outputs {
793            assert!(old_to_new[output.node] < order.len());
794            output.node = old_to_new[output.node];
795        }
796        self.outlet_labels = std::mem::take(&mut self.outlet_labels)
797            .into_iter()
798            .map(|(k, v)| (OutletId::new(old_to_new[k.node], k.slot), v))
799            .filter(|(k, _)| k.node < order.len())
800            .collect();
801        ensure!(self.nodes.iter().enumerate().all(|(ix, n)| n.id == ix));
802        #[cfg(debug_assertions)]
803        {
804            self.check_compact().context("after graph compaction")?;
805        }
806        Ok(())
807    }
808
809    pub fn into_compact(mut self) -> TractResult<Self> {
810        self.compact()?;
811        Ok(self)
812    }
813}
814
815pub trait IntoRunnable<F, O>
816where
817    F: Fact + Clone + 'static,
818    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
819{
820    fn into_runnable(self) -> TractResult<Arc<RunnableModel<F, O>>>;
821}
822
823impl<G, F, O> IntoRunnable<F, O> for G
824where
825    F: Fact + Clone + 'static,
826    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
827    G: Into<Arc<Graph<F, O>>>,
828{
829    fn into_runnable(self) -> TractResult<Arc<RunnableModel<F, O>>> {
830        SimplePlan::new(self)
831    }
832}