tract_core/model/
node.rs

1use super::*;
2use crate::internal::*;
3use crate::ops::Op;
4use tract_itertools::Itertools;
5use std::fmt;
6use std::fmt::{Debug, Display};
7
8/// A Node in an Model.
9///
10/// Parameterized by a Fact implementation matching the one used in the
11/// model.
12#[derive(Debug, Clone)]
13pub struct Node<F: Fact , O> {
14    /// node id in the model
15    ///
16    /// Caution: this id will not be persistent during networks transformation
17    pub id: usize,
18    /// name of the node
19    ///
20    /// This will usually come from the importing framework. `tract`
21    /// transformation try to maintain the names accross transformations.
22    pub name: String,
23    /// A list of incoming tensors, identified by the node outlet that creates
24    /// them.
25    pub inputs: Vec<OutletId>,
26    /// The actual operation the node performs.
27    pub op: O,
28    /// List of ouputs, with their descendant and tensor type information.
29    pub outputs: TVec<Outlet<F>>,
30}
31
32impl<F: Fact , O: std::fmt::Display> fmt::Display for Node<F, O> {
33    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
34        write!(fmt, "#{} \"{}\" {}", self.id, self.name, self.op)
35    }
36}
37
38impl<F, NodeOp> Node<F, NodeOp>
39where
40    F: Fact ,
41    NodeOp: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + AsMut<dyn Op> ,
42{
43    /// Access the op of the node
44    pub fn op(&self) -> &dyn Op {
45        self.op.as_ref()
46    }
47
48    /// Try to downcast the node operation to O.
49    pub fn op_as<O: Op>(&self) -> Option<&O> {
50        self.op().downcast_ref::<O>()
51    }
52
53    /// Try to downcast the node operation to O.
54    pub fn op_as_mut<O: Op>(&mut self) -> Option<&mut O> {
55        self.op.as_mut().downcast_mut::<O>()
56    }
57
58    /// Check if the node operation is of type O.
59    pub fn op_is<O: Op>(&self) -> bool {
60        self.op_as::<O>().is_some()
61    }
62
63    /// Check that this node produce the same outputs as `other`.
64    pub fn same_as(&self, other: &Node<F, NodeOp>) -> bool {
65        self.inputs == other.inputs && self.op().same_as(other.op())
66    }
67}
68
69/// Information for each outlet of a node
70#[derive(Clone, Default)]
71pub struct Outlet<F: Fact > {
72    /// the tensor type information
73    pub fact: F,
74    /// where this outlet is used.
75    pub successors: TVec<InletId>,
76}
77
78impl<F: Fact > fmt::Debug for Outlet<F> {
79    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
80        write!(
81            fmt,
82            "{:?} {}",
83            self.fact,
84            self.successors.iter().map(|o| format!("{o:?}")).join(" ")
85        )
86    }
87}
88
89/// Identifier for a node output in the graph.
90///
91/// This happens to be a unique identifier of any variable tensor in the graph
92/// (as the graph typically connect one single node output to one or several
93/// inputs slots)
94#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, new)]
95pub struct OutletId {
96    /// node identifier in the graph
97    pub node: usize,
98    /// rank of the input in the node
99    pub slot: usize,
100}
101
102impl fmt::Debug for OutletId {
103    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
104        write!(fmt, "{}/{}>", self.node, self.slot)
105    }
106}
107
108impl From<usize> for OutletId {
109    fn from(node: usize) -> OutletId {
110        OutletId::new(node, 0)
111    }
112}
113
114impl From<(usize, usize)> for OutletId {
115    fn from(pair: (usize, usize)) -> OutletId {
116        OutletId::new(pair.0, pair.1)
117    }
118}
119
120/// Identifier for a node input in the graph.
121#[derive(Clone, Copy, PartialEq, Eq, Hash, new, Ord, PartialOrd)]
122pub struct InletId {
123    /// node identifier in the graph
124    pub node: usize,
125    /// rank of the input in the node
126    pub slot: usize,
127}
128
129impl fmt::Debug for InletId {
130    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
131        write!(fmt, ">{}/{}", self.node, self.slot)
132    }
133}