1use std::any::Any;
2use std::sync::RwLock;
3
4use crate::internal::*;
5use lazy_static::lazy_static;
6use tract_pulse_opl::ops::Delay;
7
8pub mod array;
9pub mod cnn;
10pub mod delay;
11pub mod downsample;
12pub mod dummy;
13pub mod mask;
14pub mod scan;
15pub mod slice;
16pub mod source;
17
18pub(crate) fn sync_inputs(
19 node: &TypedNode,
20 target: &mut PulsedModel,
21 mapping: &HashMap<OutletId, OutletId>,
22) -> TractResult<TVec<OutletId>> {
23 let mut max_delay = 0;
24 for input in &node.inputs {
25 let fact = target.outlet_fact(mapping[input])?;
26 if let Some(stream) = &fact.stream {
27 max_delay = max_delay.max(stream.delay);
28 }
29 }
30 let mut inputs = tvec!();
31 for input in &node.inputs {
32 let mut input = mapping[input];
33 let fact = target.outlet_fact(input)?.clone();
34 if let Some(stream) = &fact.stream {
35 if stream.delay < max_delay {
36 let add_delay = max_delay - stream.delay;
37 let delay_axis = stream.axis;
38 input = target.wire_node(
39 format!("{}.Delay", &*node.name),
40 Delay::new_typed(&fact.into(), delay_axis, add_delay, 0),
41 &[input],
42 )?[0];
43 }
44 }
45 inputs.push(input);
46 }
47 Ok(inputs)
48}
49
// Collects the pulsifiers contributed by the listed submodules. This presumably defines
// the `register_all` function that `OpPulsifier::inventory` calls to seed its map
// (TODO confirm against the macro definition).
// NOTE(review): `delay`, `dummy`, `mask` and `slice` are declared above but not listed
// here. Presumably they register no pulsifier; verify before adding one to them.
register_all_mod!(array, cnn, downsample, scan, source);
51
52type PulsifierFn = fn(
53 &TypedModel,
54 &TypedNode,
55 &mut PulsedModel,
56 &HashMap<OutletId, OutletId>,
57 &Symbol,
58 &TDim,
59) -> TractResult<Option<TVec<OutletId>>>;
60
61pub struct OpPulsifier {
62 pub type_id: std::any::TypeId,
63 pub name: &'static str,
64 pub func: PulsifierFn,
65}
66
67impl OpPulsifier {
68 pub fn inventory() -> Arc<RwLock<HashMap<TypeId, OpPulsifier>>> {
69 lazy_static! {
70 static ref INVENTORY: Arc<RwLock<HashMap<TypeId, OpPulsifier>>> = {
71 let mut it = HashMap::default();
72 register_all(&mut it);
73 Arc::new(RwLock::new(it))
74 };
75 };
76 (*INVENTORY).clone()
77 }
78
79 pub fn register<T: Any>(func: PulsifierFn) -> TractResult<()> {
80 let inv = Self::inventory();
81 let mut inv = inv.write().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
82 inv.insert(
83 std::any::TypeId::of::<T>(),
84 OpPulsifier {
85 type_id: std::any::TypeId::of::<T>(),
86 name: std::any::type_name::<T>(),
87 func,
88 },
89 );
90 Ok(())
91 }
92
93 pub fn pulsify(
94 source: &TypedModel,
95 node: &TypedNode,
96 target: &mut PulsedModel,
97 mapping: &HashMap<OutletId, OutletId>,
98 symbol: &Symbol,
99 pulse: &TDim,
100 ) -> TractResult<Option<TVec<OutletId>>> {
101 let inv = Self::inventory();
102 let inv = inv.read().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
103 if let Some(pulsifier) = inv.get(&(*node.op).type_id()) {
104 if let Some(pulsified) = (pulsifier.func)(source, node, target, mapping, symbol, pulse)?
105 {
106 return Ok(Some(pulsified));
107 }
108 }
109 Ok(None)
110 }
111}
112
113pub trait PulsedOp:
114 Op + fmt::Debug + tract_core::dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
115{
116 fn as_op(&self) -> &dyn Op;
118
119 fn as_op_mut(&mut self) -> &mut dyn Op;
121
122 fn to_typed(&self) -> Box<dyn TypedOp>;
124
125 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>>;
127}
128
// Make `Box<dyn PulsedOp>` cloneable through the `DynClone` supertrait of `PulsedOp`.
tract_core::dyn_clone::clone_trait_object!(PulsedOp);
130
131impl<O: PulsedOp> From<O> for Box<dyn PulsedOp> {
132 fn from(it: O) -> Box<dyn PulsedOp> {
133 Box::new(it)
134 }
135}
136
137impl AsMut<dyn Op> for Box<dyn PulsedOp> {
138 fn as_mut(&mut self) -> &mut dyn Op {
139 self.as_op_mut()
140 }
141}
142
143impl AsRef<dyn Op> for dyn PulsedOp {
144 fn as_ref(&self) -> &dyn Op {
145 self.as_op()
146 }
147}
148
149impl AsRef<dyn Op> for Box<dyn PulsedOp> {
150 fn as_ref(&self) -> &dyn Op {
151 self.as_op()
152 }
153}
154
155impl AsMut<dyn Op> for dyn PulsedOp {
156 fn as_mut(&mut self) -> &mut dyn Op {
157 self.as_op_mut()
158 }
159}
160
161impl std::fmt::Display for Box<dyn PulsedOp> {
162 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
163 write!(fmt, "{}", self.name())
164 }
165}
166
167impl<'a> From<&'a Box<dyn PulsedOp>> for Box<dyn TypedOp> {
168 fn from(op: &'a Box<dyn PulsedOp>) -> Box<dyn TypedOp> {
169 op.to_typed()
170 }
171}