1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
use super::*;
use crate::ops::Op;
use itertools::Itertools;
use std::fmt;
use std::fmt::{Debug, Display};

/// A `SmallVec` instantiation with 4 embeddable values.
///
/// Used almost everywhere in tract, for node inputs and outputs, or
/// tensor dimensions. Up to 4 elements are stored inline before the
/// vector spills to a heap allocation.
pub type TVec<T> = ::smallvec::SmallVec<[T; 4]>;

/// A node in a model.
///
/// Parameterized by a Fact implementation matching the one used in the
/// model, and by the type `O` holding the operation the node performs.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
pub struct BaseNode<F: Fact, O> {
    /// node id in the model
    ///
    /// Caution: this id will not be persistent during network transformations
    pub id: usize,
    /// name of the node
    ///
    /// This will usually come from the importing framework. `tract`
    /// transformations try to maintain the names across transformations.
    pub name: String,
    /// A list of incoming tensors, identified by the node outlet that creates
    /// them.
    pub inputs: Vec<OutletId>,
    /// The actual operation the node performs.
    ///
    /// Skipped when serializing (with the `serialize` feature enabled).
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub op: O,
    /// List of outputs, each with its tensor type information and the
    /// inlets (successors) consuming it.
    pub outputs: TVec<OutletFact<F>>,
}

impl<F: Fact, O: std::fmt::Display> fmt::Display for BaseNode<F, O> {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "#{} \"{}\" {}", self.id, self.name, self.op)
    }
}

/// A [`BaseNode`] whose operation is stored as a boxed `dyn Op` trait object.
pub type Node<F> = BaseNode<F, Box<dyn Op>>;

impl<F, NodeOp> BaseNode<F, NodeOp>
where
    F: Fact,
    // `AsMut<dyn Op>` was previously listed twice; one bound is enough.
    NodeOp: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op>,
{
    /// Access the op of the node as a `dyn Op` trait object.
    pub fn op(&self) -> &dyn Op {
        self.op.as_ref()
    }

    /// Try to downcast the node operation to `O`.
    ///
    /// Returns `None` when the operation is not of type `O`.
    pub fn op_as<O: Op>(&self) -> Option<&O> {
        self.op().downcast_ref::<O>()
    }

    /// Try to mutably downcast the node operation to `O`.
    ///
    /// Returns `None` when the operation is not of type `O`.
    pub fn op_as_mut<O: Op>(&mut self) -> Option<&mut O> {
        self.op.as_mut().downcast_mut::<O>()
    }

    /// Check if the node operation is of type `O`.
    pub fn op_is<O: Op>(&self) -> bool {
        self.op_as::<O>().is_some()
    }

    /// Check that this node produces the same outputs as `other`.
    ///
    /// Both nodes must read exactly the same inputs (same outlets, same
    /// order), and this node's op must report itself `same_as` the other's.
    pub fn same_as(&self, other: &BaseNode<F, NodeOp>) -> bool {
        self.inputs == other.inputs && self.op().same_as(other.op())
    }
}

/// Information for each outlet (output slot) of a node: the fact describing
/// the tensor it produces, and the inlets that consume it.
#[derive(Clone, Default)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
pub struct OutletFact<F: Fact> {
    /// the tensor type information
    pub fact: F,
    /// where this outlet is used: the node inputs (inlets) reading it.
    pub successors: TVec<InletId>,
}

impl<F: Fact> fmt::Debug for OutletFact<F> {
    /// Shows the fact, then the space-separated list of successor inlets.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let successors = self
            .successors
            .iter()
            .map(|inlet| format!("{:?}", inlet))
            .join(" ");
        write!(f, "{:?} {}", self.fact, successors)
    }
}

/// Identifier for a node output in the graph.
///
/// This happens to be a unique identifier of any variable tensor in the graph
/// (as the graph typically connects one single node output to one or several
/// input slots).
#[derive(Clone, Copy, PartialEq, Eq, Hash, new)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
pub struct OutletId {
    /// node identifier in the graph
    pub node: usize,
    /// rank of the output in the node
    pub slot: usize,
}

impl fmt::Debug for OutletId {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "{}/{}>", self.node, self.slot)
    }
}

impl From<usize> for OutletId {
    fn from(node: usize) -> OutletId {
        OutletId::new(node, 0)
    }
}

impl From<(usize, usize)> for OutletId {
    fn from(pair: (usize, usize)) -> OutletId {
        OutletId::new(pair.0, pair.1)
    }
}

/// Identifier for a node input in the graph.
///
/// The derived ordering compares `node` first, then `slot` (field
/// declaration order).
#[derive(Clone, Copy, PartialEq, Eq, Hash, new, Ord, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
pub struct InletId {
    /// node identifier in the graph
    pub node: usize,
    /// rank of the input in the node
    pub slot: usize,
}

impl fmt::Debug for InletId {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, ">{}/{}", self.node, self.slot)
    }
}