1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
use super::*;
use crate::internal::*;
use crate::ops::Op;
use tract_itertools::Itertools;
use std::fmt;
use std::fmt::{Debug, Display};
/// A node in a model.
///
/// Generic over the `Fact` implementation used by the enclosing model (`F`)
/// and over the type of operation the node carries (`O`).
#[derive(Debug, Clone)]
pub struct Node<F: Fact, O> {
    /// Identifier of the node within its model.
    ///
    /// Caution: this id is not guaranteed to stay the same across network
    /// transformations.
    pub id: usize,
    /// Name of the node.
    ///
    /// This usually comes from the importing framework; `tract`
    /// transformations try to preserve names across transformations.
    pub name: String,
    /// Incoming tensors, each identified by the node outlet that produces it.
    pub inputs: Vec<OutletId>,
    /// The operation this node performs.
    #[cfg_attr(feature = "serialize", serde(skip))]
    pub op: O,
    /// Outputs of the node, each with its successors and tensor type information.
    pub outputs: TVec<Outlet<F>>,
}
impl<F: Fact , O: std::fmt::Display> fmt::Display for Node<F, O> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "#{} \"{}\" {}", self.id, self.name, self.op)
}
}
impl<F, NodeOp> Node<F, NodeOp>
where
    F: Fact,
    // `AsMut<dyn Op>` was previously listed twice; once is sufficient.
    NodeOp: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op>,
{
    /// Borrow the node's operation as a `dyn Op` trait object.
    pub fn op(&self) -> &dyn Op {
        self.op.as_ref()
    }

    /// Try to downcast the node operation to a shared reference to `O`.
    ///
    /// Returns `None` if the downcast to `O` fails.
    pub fn op_as<O: Op>(&self) -> Option<&O> {
        self.op().downcast_ref::<O>()
    }

    /// Try to downcast the node operation to a mutable reference to `O`.
    ///
    /// Returns `None` if the downcast to `O` fails.
    pub fn op_as_mut<O: Op>(&mut self) -> Option<&mut O> {
        self.op.as_mut().downcast_mut::<O>()
    }

    /// Check whether the node operation can be downcast to `O`.
    pub fn op_is<O: Op>(&self) -> bool {
        self.op_as::<O>().is_some()
    }

    /// Check that this node produces the same outputs as `other`.
    ///
    /// Both nodes must consume exactly the same input outlets, in the same
    /// order. Their operations are then compared with `Op::same_as`.
    pub fn same_as(&self, other: &Node<F, NodeOp>) -> bool {
        self.inputs == other.inputs && self.op().same_as(other.op())
    }
}
/// Information attached to each output (outlet) of a node.
#[derive(Clone, Default)]
pub struct Outlet<F: Fact> {
    /// Type information for the tensor produced on this outlet.
    pub fact: F,
    /// Inlets that consume this outlet, i.e. where it is used.
    pub successors: TVec<InletId>,
}
impl<F: Fact> fmt::Debug for Outlet<F> {
    /// Format as the fact's `Debug` form followed by a single space, then
    /// each successor's `Debug` form separated by single spaces.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?} ", self.fact)?;
        for (ix, succ) in self.successors.iter().enumerate() {
            if ix > 0 {
                f.write_str(" ")?;
            }
            write!(f, "{succ:?}")?;
        }
        Ok(())
    }
}
/// Identifier for a node output in the graph.
///
/// This happens to be a unique identifier of any variable tensor in the graph
/// (as the graph typically connects one single node output to one or several
/// input slots).
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, new)]
pub struct OutletId {
    /// node identifier in the graph
    pub node: usize,
    /// rank of the output in the node
    pub slot: usize,
}
impl fmt::Debug for OutletId {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "{}/{}>", self.node, self.slot)
}
}
impl From<usize> for OutletId {
fn from(node: usize) -> OutletId {
OutletId::new(node, 0)
}
}
impl From<(usize, usize)> for OutletId {
fn from(pair: (usize, usize)) -> OutletId {
OutletId::new(pair.0, pair.1)
}
}
/// Identifier for an input slot of a node in the graph.
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, new)]
pub struct InletId {
    /// Identifier of the node in the graph.
    pub node: usize,
    /// Rank of the input within the node.
    pub slot: usize,
}
impl fmt::Debug for InletId {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, ">{}/{}", self.node, self.slot)
}
}