1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use crate::internal::*;

// Public submodules of this module.
// NOTE(review): `delay` and `dummy` are declared here but do not appear in the
// `register_all_mod!` list further down in this file — presumably on purpose; confirm.
pub mod array;
pub mod binary;
pub mod change_axes;
pub mod cnn;
pub mod delay;
pub mod downsample;
pub mod dummy;
pub mod element_wise;
pub mod matmul;
pub mod nn;
pub mod quant;
pub mod scan;
pub mod source;

// Hands the listed submodules to the `register_all_mod!` macro.
// NOTE(review): the macro is defined elsewhere; presumably it generates the
// `register_all` function that `OpPulsifier::inventory` calls — confirm.
// `delay` and `dummy` are declared as submodules above but are not listed here.
register_all_mod!(
    array,
    binary,
    change_axes,
    cnn,
    downsample,
    element_wise,
    matmul,
    nn,
    quant,
    scan,
    source
);

/// One registry entry: an op type, a name, and a function that works on a node
/// of a `TypedModel` and a `PulsedModel`.
///
/// [`OpPulsifier::inventory`] returns these entries in a map keyed by `TypeId`.
pub struct OpPulsifier {
    /// `TypeId` carried by this entry.
    // NOTE(review): presumably the concrete type of the typed op this entry handles — confirm.
    pub type_id: std::any::TypeId,
    /// Static name attached to this entry.
    pub name: &'static str,
    /// Function pointer invoked for a matching node; returns outlets or an error.
    // NOTE(review): arguments are presumably, in order: the source typed model, the node
    // being converted, the pulsed model under construction, a source-outlet to
    // target-outlet mapping, and the pulse size — confirm against the call site.
    pub func: fn(
        &TypedModel,
        &TypedNode,
        &mut PulsedModel,
        &HashMap<OutletId, OutletId>,
        usize,
    ) -> TractResult<TVec<OutletId>>,
}

impl OpPulsifier {
    /// Builds the table of registered pulsifiers, indexed by `TypeId`.
    ///
    /// Starts from an empty map and hands it to `register_all` to be populated.
    pub fn inventory() -> HashMap<TypeId, OpPulsifier> {
        let mut pulsifiers: HashMap<TypeId, OpPulsifier> = HashMap::default();
        register_all(&mut pulsifiers);
        pulsifiers
    }
}

/// An operator usable in a pulsed model.
///
/// On top of the supertraits listed below, implementors provide views of
/// themselves as a plain `Op`, a conversion to a boxed `TypedOp`, and the
/// computation of output `PulsedFact`s from input `PulsedFact`s.
pub trait PulsedOp:
    Op
    + fmt::Debug
    + tract_core::dyn_clone::DynClone
    + Send
    + Sync
    + 'static
    + Downcast
    + EvalOp
    + DynHash
{
    /// Reinterpret the PulsedOp as an Op.
    fn as_op(&self) -> &dyn Op;

    /// Reinterpret the PulsedOp as an Op, mutably.
    fn as_op_mut(&mut self) -> &mut dyn Op;

    /// Produce a boxed `TypedOp` from this PulsedOp.
    fn to_typed(&self) -> Box<dyn TypedOp>;

    /// Deduce output facts from input facts.
    ///
    /// Takes one `PulsedFact` reference per input and returns the output
    /// facts, or an error.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>>;
}

// Invokes dyn_clone's `clone_trait_object!` for `PulsedOp`, which provides `Clone`
// for boxed `PulsedOp` trait objects.
tract_core::dyn_clone::clone_trait_object!(PulsedOp);

impl Hash for Box<dyn PulsedOp> {
    /// Feeds the op's type identity, then its own state via `dyn_hash`, into `state`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Dereference down to the trait object before asking for the `TypeId`.
        // Calling `.type_id()` directly on the box can resolve to the blanket `Any`
        // impl for `Box<dyn PulsedOp>` itself when `Any` is in scope, which yields
        // the same container id for every op instead of the concrete op's id.
        std::hash::Hash::hash(&(**self).type_id(), state);
        self.dyn_hash(state)
    }
}

impl<O: PulsedOp> From<O> for Box<dyn PulsedOp> {
    /// Moves a concrete pulsed op into a box and erases its type.
    fn from(op: O) -> Self {
        Box::new(op)
    }
}

impl AsMut<dyn Op> for Box<dyn PulsedOp> {
    /// Mutable `Op` view of the boxed op, obtained through `PulsedOp::as_op_mut`.
    fn as_mut(&mut self) -> &mut dyn Op {
        (**self).as_op_mut()
    }
}

impl AsRef<dyn Op> for dyn PulsedOp {
    fn as_ref(&self) -> &dyn Op {
        self.as_op()
    }
}

impl AsRef<dyn Op> for Box<dyn PulsedOp> {
    /// Shared `Op` view of the boxed op, obtained through `PulsedOp::as_op`.
    fn as_ref(&self) -> &dyn Op {
        (**self).as_op()
    }
}

impl AsMut<dyn Op> for dyn PulsedOp {
    fn as_mut(&mut self) -> &mut dyn Op {
        self.as_op_mut()
    }
}

impl std::fmt::Display for Box<dyn PulsedOp> {
    /// Prints the value returned by the op's `name()`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.name()))
    }
}

impl<'a> From<&'a Box<dyn PulsedOp>> for Box<dyn TypedOp> {
    /// Produces a boxed `TypedOp` from a borrowed boxed pulsed op via `to_typed`.
    fn from(op: &'a Box<dyn PulsedOp>) -> Self {
        (**op).to_typed()
    }
}