tract_pulse/ops/mod.rs

1use std::any::Any;
2use std::sync::RwLock;
3
4use crate::internal::*;
5use lazy_static::lazy_static;
6use tract_pulse_opl::ops::Delay;
7
8pub mod array;
9pub mod cnn;
10pub mod delay;
11pub mod downsample;
12pub mod dummy;
13pub mod mask;
14pub mod scan;
15pub mod slice;
16pub mod source;
17
18pub(crate) fn sync_inputs(
19    node: &TypedNode,
20    target: &mut PulsedModel,
21    mapping: &HashMap<OutletId, OutletId>,
22) -> TractResult<TVec<OutletId>> {
23    let mut max_delay = 0;
24    for input in &node.inputs {
25        let fact = target.outlet_fact(mapping[input])?;
26        if let Some(stream) = &fact.stream {
27            max_delay = max_delay.max(stream.delay);
28        }
29    }
30    let mut inputs = tvec!();
31    for input in &node.inputs {
32        let mut input = mapping[input];
33        let fact = target.outlet_fact(input)?.clone();
34        if let Some(stream) = &fact.stream {
35            if stream.delay < max_delay {
36                let add_delay = max_delay - stream.delay;
37                let delay_axis = stream.axis;
38                input = target.wire_node(
39                    format!("{}.Delay", &*node.name),
40                    Delay::new_typed(&fact.into(), delay_axis, add_delay, 0),
41                    &[input],
42                )?[0];
43            }
44        }
45        inputs.push(input);
46    }
47    Ok(inputs)
48}
49
50register_all_mod!(array, cnn, downsample, scan, source);
51
52type PulsifierFn = fn(
53    &TypedModel,
54    &TypedNode,
55    &mut PulsedModel,
56    &HashMap<OutletId, OutletId>,
57    &Symbol,
58    &TDim,
59) -> TractResult<Option<TVec<OutletId>>>;
60
61pub struct OpPulsifier {
62    pub type_id: std::any::TypeId,
63    pub name: &'static str,
64    pub func: PulsifierFn,
65}
66
67impl OpPulsifier {
68    pub fn inventory() -> Arc<RwLock<HashMap<TypeId, OpPulsifier>>> {
69        lazy_static! {
70            static ref INVENTORY: Arc<RwLock<HashMap<TypeId, OpPulsifier>>> = {
71                let mut it = HashMap::default();
72                register_all(&mut it);
73                Arc::new(RwLock::new(it))
74            };
75        };
76        (*INVENTORY).clone()
77    }
78
79    pub fn register<T: Any>(func: PulsifierFn) -> TractResult<()> {
80        let inv = Self::inventory();
81        let mut inv = inv.write().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
82        inv.insert(
83            std::any::TypeId::of::<T>(),
84            OpPulsifier {
85                type_id: std::any::TypeId::of::<T>(),
86                name: std::any::type_name::<T>(),
87                func,
88            },
89        );
90        Ok(())
91    }
92
93    pub fn pulsify(
94        source: &TypedModel,
95        node: &TypedNode,
96        target: &mut PulsedModel,
97        mapping: &HashMap<OutletId, OutletId>,
98        symbol: &Symbol,
99        pulse: &TDim,
100    ) -> TractResult<Option<TVec<OutletId>>> {
101        let inv = Self::inventory();
102        let inv = inv.read().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
103        if let Some(pulsifier) = inv.get(&(*node.op).type_id()) {
104            if let Some(pulsified) = (pulsifier.func)(source, node, target, mapping, symbol, pulse)?
105            {
106                return Ok(Some(pulsified));
107            }
108        }
109        Ok(None)
110    }
111}
112
113pub trait PulsedOp:
114    Op + fmt::Debug + tract_core::dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
115{
116    /// Reinterpret the PulsedOp as an Op.
117    fn as_op(&self) -> &dyn Op;
118
119    /// Reinterpret the PulsedOp as an Op, mutably.
120    fn as_op_mut(&mut self) -> &mut dyn Op;
121
122    /// Reinterpret the PulsedOp as an TypedOp.
123    fn to_typed(&self) -> Box<dyn TypedOp>;
124
125    /// Deduce output facts from input facts.
126    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>>;
127}
128
129tract_core::dyn_clone::clone_trait_object!(PulsedOp);
130
131impl<O: PulsedOp> From<O> for Box<dyn PulsedOp> {
132    fn from(it: O) -> Box<dyn PulsedOp> {
133        Box::new(it)
134    }
135}
136
137impl AsMut<dyn Op> for Box<dyn PulsedOp> {
138    fn as_mut(&mut self) -> &mut dyn Op {
139        self.as_op_mut()
140    }
141}
142
143impl AsRef<dyn Op> for dyn PulsedOp {
144    fn as_ref(&self) -> &dyn Op {
145        self.as_op()
146    }
147}
148
149impl AsRef<dyn Op> for Box<dyn PulsedOp> {
150    fn as_ref(&self) -> &dyn Op {
151        self.as_op()
152    }
153}
154
155impl AsMut<dyn Op> for dyn PulsedOp {
156    fn as_mut(&mut self) -> &mut dyn Op {
157        self.as_op_mut()
158    }
159}
160
161impl std::fmt::Display for Box<dyn PulsedOp> {
162    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
163        write!(fmt, "{}", self.name())
164    }
165}
166
167impl<'a> From<&'a Box<dyn PulsedOp>> for Box<dyn TypedOp> {
168    fn from(op: &'a Box<dyn PulsedOp>) -> Box<dyn TypedOp> {
169        op.to_typed()
170    }
171}