tract_pulse/ops/mod.rs

1#![allow(clippy::collapsible_if)]
2use std::any::Any;
3use std::sync::RwLock;
4
5use crate::internal::*;
6use lazy_static::lazy_static;
7use tract_pulse_opl::ops::Delay;
8
9pub mod array;
10pub mod cnn;
11pub mod delay;
12pub mod downsample;
13pub mod dummy;
14pub mod mask;
15pub mod scan;
16pub mod slice;
17pub mod source;
18
19pub(crate) fn sync_inputs(
20    node: &TypedNode,
21    target: &mut PulsedModel,
22    mapping: &HashMap<OutletId, OutletId>,
23) -> TractResult<TVec<OutletId>> {
24    let mut max_delay = 0;
25    for input in &node.inputs {
26        let fact = target.outlet_fact(mapping[input])?;
27        if let Some(stream) = &fact.stream {
28            max_delay = max_delay.max(stream.delay);
29        }
30    }
31    let mut inputs = tvec!();
32    for input in &node.inputs {
33        let mut input = mapping[input];
34        let fact = target.outlet_fact(input)?.clone();
35        if let Some(stream) = &fact.stream {
36            if stream.delay < max_delay {
37                let add_delay = max_delay - stream.delay;
38                let delay_axis = stream.axis;
39                input = target.wire_node(
40                    format!("{}.Delay", &*node.name),
41                    Delay::new_typed(&fact.into(), delay_axis, add_delay, 0),
42                    &[input],
43                )?[0];
44            }
45        }
46        inputs.push(input);
47    }
48    Ok(inputs)
49}
50
// Registration of the built-in pulsifiers. `register_all_mod!` is defined elsewhere; presumably
// it expands to the `register_all` function that seeds `OpPulsifier::inventory`, delegating to
// each listed submodule. Verify against the macro definition.
// NOTE(review): `delay`, `dummy`, `mask` and `slice` are declared in this module but not listed
// here. Presumably they register no pulsifiers of their own; confirm.
register_all_mod!(array, cnn, downsample, scan, source);
52
53type PulsifierFn = fn(
54    &TypedModel,
55    &TypedNode,
56    &mut PulsedModel,
57    &HashMap<OutletId, OutletId>,
58    &Symbol,
59    &TDim,
60) -> TractResult<Option<TVec<OutletId>>>;
61
62pub struct OpPulsifier {
63    pub type_id: std::any::TypeId,
64    pub name: &'static str,
65    pub func: PulsifierFn,
66}
67
68impl OpPulsifier {
69    pub fn inventory() -> Arc<RwLock<HashMap<TypeId, OpPulsifier>>> {
70        lazy_static! {
71            static ref INVENTORY: Arc<RwLock<HashMap<TypeId, OpPulsifier>>> = {
72                let mut it = HashMap::default();
73                register_all(&mut it);
74                Arc::new(RwLock::new(it))
75            };
76        };
77        (*INVENTORY).clone()
78    }
79
80    pub fn register<T: Any>(func: PulsifierFn) -> TractResult<()> {
81        let inv = Self::inventory();
82        let mut inv = inv.write().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
83        inv.insert(
84            std::any::TypeId::of::<T>(),
85            OpPulsifier {
86                type_id: std::any::TypeId::of::<T>(),
87                name: std::any::type_name::<T>(),
88                func,
89            },
90        );
91        Ok(())
92    }
93
94    pub fn pulsify(
95        source: &TypedModel,
96        node: &TypedNode,
97        target: &mut PulsedModel,
98        mapping: &HashMap<OutletId, OutletId>,
99        symbol: &Symbol,
100        pulse: &TDim,
101    ) -> TractResult<Option<TVec<OutletId>>> {
102        let inv = Self::inventory();
103        let inv = inv.read().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
104        if let Some(pulsifier) = inv.get(&(*node.op).type_id()) {
105            if let Some(pulsified) = (pulsifier.func)(source, node, target, mapping, symbol, pulse)?
106            {
107                return Ok(Some(pulsified));
108            }
109        }
110        Ok(None)
111    }
112}
113
114pub trait PulsedOp:
115    Op + fmt::Debug + tract_core::dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
116{
117    /// Reinterpret the PulsedOp as an Op.
118    fn as_op(&self) -> &dyn Op;
119
120    /// Reinterpret the PulsedOp as an Op, mutably.
121    fn as_op_mut(&mut self) -> &mut dyn Op;
122
123    /// Reinterpret the PulsedOp as an TypedOp.
124    fn to_typed(&self) -> Box<dyn TypedOp>;
125
126    /// Deduce output facts from input facts.
127    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>>;
128}
129
// Implements `Clone` for `Box<dyn PulsedOp>` through `dyn_clone`. This relies on `DynClone`
// being a supertrait of `PulsedOp`.
tract_core::dyn_clone::clone_trait_object!(PulsedOp);
131
132impl<O: PulsedOp> From<O> for Box<dyn PulsedOp> {
133    fn from(it: O) -> Box<dyn PulsedOp> {
134        Box::new(it)
135    }
136}
137
138impl AsMut<dyn Op> for Box<dyn PulsedOp> {
139    fn as_mut(&mut self) -> &mut dyn Op {
140        self.as_op_mut()
141    }
142}
143
144impl AsRef<dyn Op> for dyn PulsedOp {
145    fn as_ref(&self) -> &dyn Op {
146        self.as_op()
147    }
148}
149
150impl AsRef<dyn Op> for Box<dyn PulsedOp> {
151    fn as_ref(&self) -> &dyn Op {
152        self.as_op()
153    }
154}
155
156impl AsMut<dyn Op> for dyn PulsedOp {
157    fn as_mut(&mut self) -> &mut dyn Op {
158        self.as_op_mut()
159    }
160}
161
162impl std::fmt::Display for Box<dyn PulsedOp> {
163    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
164        write!(fmt, "{}", self.name())
165    }
166}
167
168impl<'a> From<&'a Box<dyn PulsedOp>> for Box<dyn TypedOp> {
169    fn from(op: &'a Box<dyn PulsedOp>) -> Box<dyn TypedOp> {
170        op.to_typed()
171    }
172}