1#![allow(clippy::collapsible_if)]
2use std::any::Any;
3use std::sync::RwLock;
4
5use crate::internal::*;
6use lazy_static::lazy_static;
7use tract_pulse_opl::ops::Delay;
8
9pub mod array;
10pub mod cnn;
11pub mod delay;
12pub mod downsample;
13pub mod dummy;
14pub mod mask;
15pub mod scan;
16pub mod slice;
17pub mod source;
18
19pub(crate) fn sync_inputs(
20 node: &TypedNode,
21 target: &mut PulsedModel,
22 mapping: &HashMap<OutletId, OutletId>,
23) -> TractResult<TVec<OutletId>> {
24 let mut max_delay = 0;
25 for input in &node.inputs {
26 let fact = target.outlet_fact(mapping[input])?;
27 if let Some(stream) = &fact.stream {
28 max_delay = max_delay.max(stream.delay);
29 }
30 }
31 let mut inputs = tvec!();
32 for input in &node.inputs {
33 let mut input = mapping[input];
34 let fact = target.outlet_fact(input)?.clone();
35 if let Some(stream) = &fact.stream {
36 if stream.delay < max_delay {
37 let add_delay = max_delay - stream.delay;
38 let delay_axis = stream.axis;
39 input = target.wire_node(
40 format!("{}.Delay", &*node.name),
41 Delay::new_typed(&fact.into(), delay_axis, add_delay, 0),
42 &[input],
43 )?[0];
44 }
45 }
46 inputs.push(input);
47 }
48 Ok(inputs)
49}
50
// NOTE(review): `register_all_mod!` is defined outside this file; presumably it
// expands to the `register_all` function called by `OpPulsifier::inventory`,
// delegating to each listed submodule. Only `array`, `cnn`, `downsample`, `scan`
// and `source` are listed here (not `delay`, `dummy`, `mask` or `slice`) — confirm
// that this is intentional.
register_all_mod!(array, cnn, downsample, scan, source);
52
/// Signature of the callback stored in an `OpPulsifier`.
///
/// `OpPulsifier::pulsify` calls it with, in order: the source model, the node
/// being translated, the pulsed model under construction, the outlet mapping,
/// and the two values it receives as `symbol` and `pulse`.
///
/// The result is handed back unchanged by `OpPulsifier::pulsify`: `Ok(Some(_))`
/// carries outlets, while `Ok(None)` is reported the same way as a missing
/// registration (presumably "this callback declined" — confirm with callers).
type PulsifierFn = fn(
    &TypedModel,
    &TypedNode,
    &mut PulsedModel,
    &HashMap<OutletId, OutletId>,
    &Symbol,
    &TDim,
) -> TractResult<Option<TVec<OutletId>>>;
61
/// Registry entry tying an op type to the function able to pulsify it.
pub struct OpPulsifier {
    /// `TypeId` of the op type; `OpPulsifier::register` also uses it as the inventory key.
    pub type_id: std::any::TypeId,
    /// Name of the op type, filled from `std::any::type_name` by `OpPulsifier::register`.
    pub name: &'static str,
    /// Callback run by `OpPulsifier::pulsify` for nodes whose op has this type.
    pub func: PulsifierFn,
}
67
68impl OpPulsifier {
69 pub fn inventory() -> Arc<RwLock<HashMap<TypeId, OpPulsifier>>> {
70 lazy_static! {
71 static ref INVENTORY: Arc<RwLock<HashMap<TypeId, OpPulsifier>>> = {
72 let mut it = HashMap::default();
73 register_all(&mut it);
74 Arc::new(RwLock::new(it))
75 };
76 };
77 (*INVENTORY).clone()
78 }
79
80 pub fn register<T: Any>(func: PulsifierFn) -> TractResult<()> {
81 let inv = Self::inventory();
82 let mut inv = inv.write().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
83 inv.insert(
84 std::any::TypeId::of::<T>(),
85 OpPulsifier {
86 type_id: std::any::TypeId::of::<T>(),
87 name: std::any::type_name::<T>(),
88 func,
89 },
90 );
91 Ok(())
92 }
93
94 pub fn pulsify(
95 source: &TypedModel,
96 node: &TypedNode,
97 target: &mut PulsedModel,
98 mapping: &HashMap<OutletId, OutletId>,
99 symbol: &Symbol,
100 pulse: &TDim,
101 ) -> TractResult<Option<TVec<OutletId>>> {
102 let inv = Self::inventory();
103 let inv = inv.read().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
104 if let Some(pulsifier) = inv.get(&(*node.op).type_id()) {
105 if let Some(pulsified) = (pulsifier.func)(source, node, target, mapping, symbol, pulse)?
106 {
107 return Ok(Some(pulsified));
108 }
109 }
110 Ok(None)
111 }
112}
113
/// Trait implemented by the operators that a `PulsedModel` is made of.
///
/// On top of the `Op` and `EvalOp` supertraits, an implementor exposes
/// conversions to `dyn Op` and `TypedOp`, and computes its pulsed output facts.
pub trait PulsedOp:
    Op + fmt::Debug + tract_core::dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
{
    /// Returns this operator viewed as a `dyn Op`.
    fn as_op(&self) -> &dyn Op;

    /// Returns this operator viewed as a mutable `dyn Op`.
    fn as_op_mut(&mut self) -> &mut dyn Op;

    /// Returns a boxed `TypedOp` for this operator (presumably its typed-model equivalent).
    fn to_typed(&self) -> Box<dyn TypedOp>;

    /// Computes the output facts of this operator from the facts of its inputs.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>>;
}
129
// Makes `Box<dyn PulsedOp>` cloneable, relying on the `DynClone` supertrait of `PulsedOp`.
tract_core::dyn_clone::clone_trait_object!(PulsedOp);
131
132impl<O: PulsedOp> From<O> for Box<dyn PulsedOp> {
133 fn from(it: O) -> Box<dyn PulsedOp> {
134 Box::new(it)
135 }
136}
137
138impl AsMut<dyn Op> for Box<dyn PulsedOp> {
139 fn as_mut(&mut self) -> &mut dyn Op {
140 self.as_op_mut()
141 }
142}
143
144impl AsRef<dyn Op> for dyn PulsedOp {
145 fn as_ref(&self) -> &dyn Op {
146 self.as_op()
147 }
148}
149
150impl AsRef<dyn Op> for Box<dyn PulsedOp> {
151 fn as_ref(&self) -> &dyn Op {
152 self.as_op()
153 }
154}
155
156impl AsMut<dyn Op> for dyn PulsedOp {
157 fn as_mut(&mut self) -> &mut dyn Op {
158 self.as_op_mut()
159 }
160}
161
162impl std::fmt::Display for Box<dyn PulsedOp> {
163 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
164 write!(fmt, "{}", self.name())
165 }
166}
167
168impl<'a> From<&'a Box<dyn PulsedOp>> for Box<dyn TypedOp> {
169 fn from(op: &'a Box<dyn PulsedOp>) -> Box<dyn TypedOp> {
170 op.to_typed()
171 }
172}