tract_pulse/
model.rs

1use std::sync::RwLock;
2
3use crate::fact::StreamInfo;
4use crate::{internal::*, ops::sync_inputs};
5use tract_core::model::translator::Translate;
6use tract_pulse_opl::tract_core::ops::konst::Const;
7use tract_pulse_opl::tract_core::ops::source::TypedSource;
8
9pub type PulsedModel = Graph<PulsedFact, Box<dyn PulsedOp>>;
10pub type PulsedNode = Node<PulsedFact, Box<dyn PulsedOp>>;
11
12#[allow(clippy::new_ret_no_self)]
13pub trait PulsedModelExt {
14    fn new(source: &TypedModel, symbol: Symbol, pulse: &TDim) -> TractResult<PulsedModel>;
15
16    fn new_with_mapping(
17        source: &TypedModel,
18        symbol: Symbol,
19        pulse: &TDim,
20    ) -> TractResult<(PulsedModel, HashMap<OutletId, OutletId>)>;
21
22    fn into_typed(self) -> TractResult<TypedModel>;
23}
24
25impl PulsedModelExt for PulsedModel {
26    fn new(source: &TypedModel, symbol: Symbol, pulse: &TDim) -> TractResult<PulsedModel> {
27        Ok(PulsedModel::new_with_mapping(source, symbol, pulse)?.0)
28    }
29
30    fn new_with_mapping(
31        source: &TypedModel,
32        symbol: Symbol,
33        pulse: &TDim,
34    ) -> TractResult<(PulsedModel, HashMap<OutletId, OutletId>)> {
35        let pulsifiers = crate::ops::OpPulsifier::inventory();
36        Pulsifier(symbol, pulse.to_owned(), pulsifiers).translate_model_with_mappings(source)
37    }
38
39    fn into_typed(self) -> TractResult<TypedModel> {
40        let mut typed = tract_core::model::translator::IntoTranslator.translate_model(&self)?;
41        ensure!(self.input_outlets()?.iter().all(|o| self
42            .outlet_fact(*o)
43            .unwrap()
44            .stream
45            .is_some()));
46        ensure!(self.output_outlets()?.iter().all(|o| self
47            .outlet_fact(*o)
48            .unwrap()
49            .stream
50            .is_some()));
51        let delays = tensor1(
52            &self
53                .output_outlets()?
54                .iter()
55                .map(|oo| Ok(self.outlet_fact(*oo)?.stream.as_ref().unwrap().delay as _))
56                .collect::<TractResult<TVec<i64>>>()?,
57        );
58        typed.properties.insert("pulse.delay".to_string(), delays.into_arc_tensor());
59        let input_axes = tensor1(
60            &self
61                .input_outlets()?
62                .iter()
63                .map(|oo| Ok(self.outlet_fact(*oo)?.stream.as_ref().unwrap().axis as _))
64                .collect::<TractResult<TVec<i64>>>()?,
65        );
66        typed.properties.insert("pulse.input_axes".to_string(), input_axes.into_arc_tensor());
67        let output_axes = tensor1(
68            &self
69                .output_outlets()?
70                .iter()
71                .map(|oo| Ok(self.outlet_fact(*oo)?.stream.as_ref().unwrap().axis as _))
72                .collect::<TractResult<TVec<i64>>>()?,
73        );
74        typed.properties.insert("pulse.output_axes".to_string(), output_axes.into_arc_tensor());
75        Ok(typed)
76    }
77}
78
79impl SpecialOps<PulsedFact, Box<dyn PulsedOp>> for PulsedModel {
80    fn is_source(op: &Box<dyn PulsedOp>) -> bool {
81        op.as_op().downcast_ref::<crate::ops::source::PulsedSource>().is_some()
82    }
83
84    fn create_source(&self, fact: PulsedFact) -> Box<dyn PulsedOp> {
85        Box::new(crate::ops::source::PulsedSource(fact))
86    }
87
88    fn create_dummy(&self) -> Box<dyn PulsedOp> {
89        Box::new(tract_core::ops::dummy::Dummy::new())
90    }
91
92    fn wire_node(
93        &mut self,
94        name: impl Into<String>,
95        op: impl Into<Box<dyn PulsedOp>>,
96        inputs: &[OutletId],
97    ) -> TractResult<TVec<OutletId>> {
98        let op = op.into();
99        let output_facts = {
100            let input_facts =
101                inputs.iter().map(|o| self.outlet_fact(*o)).collect::<TractResult<TVec<_>>>()?;
102            op.pulsed_output_facts(&input_facts)?
103        };
104        let id = self.add_node(name, op, output_facts)?;
105        inputs
106            .iter()
107            .enumerate()
108            .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
109        Ok(self.node(id).outputs.iter().enumerate().map(|(ix, _)| OutletId::new(id, ix)).collect())
110    }
111
112    fn add_const(
113        &mut self,
114        name: impl Into<String>,
115        v: impl IntoArcTensor,
116    ) -> TractResult<OutletId> {
117        let v = v.into_arc_tensor();
118        for node in &self.nodes {
119            if let Some(op) = node.op_as::<Const>() {
120                if op.val() == &v {
121                    return Ok(node.id.into());
122                }
123            }
124        }
125        let op = NonPulsingWrappingOp(Box::new(Const::new(v)?));
126        Ok(self.wire_node(name, op, &[])?[0])
127    }
128}
129
130struct Pulsifier(
131    Symbol,
132    TDim,
133    #[allow(dead_code)] Arc<RwLock<HashMap<TypeId, crate::ops::OpPulsifier>>>,
134);
135
136impl std::fmt::Debug for Pulsifier {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(f, "Pulsifier({})", self.0)
139    }
140}
141
142impl
143    tract_core::model::translator::Translate<
144        TypedFact,
145        Box<dyn TypedOp>,
146        PulsedFact,
147        Box<dyn PulsedOp>,
148    > for Pulsifier
149{
150    fn translate_node(
151        &self,
152        source: &TypedModel,
153        node: &TypedNode,
154        target: &mut PulsedModel,
155        mapping: &HashMap<OutletId, OutletId>,
156    ) -> TractResult<TVec<OutletId>> {
157        if let Some(op) = node.op_as::<TypedSource>() {
158            return Ok(crate::ops::source::pulsify(
159                op, source, node, target, mapping, &self.0, &self.1,
160            )?
161            .unwrap());
162        }
163        log::debug!("Pulsifying node {node}");
164
165        if !source
166            .node_input_facts(node.id)?
167            .iter()
168            .any(|f| f.shape.iter().any(|d| d.symbols().contains(&self.0)))
169            && !node
170                .outputs
171                .iter()
172                .any(|o| o.fact.shape.iter().any(|d| d.symbols().contains(&self.0)))
173        {
174            let pulse_op = NonPulsingWrappingOp(node.op.clone());
175            let inputs: TVec<OutletId> = node.inputs.iter().map(|i| mapping[i]).collect();
176            log::debug!("Pulsified node {node} with NonPulsingWrappingOp");
177            return target.wire_node(&node.name, pulse_op, &inputs);
178        }
179
180        if let Some(pulsified) =
181            OpPulsifier::pulsify(source, node, target, mapping, &self.0, &self.1)?
182        {
183            log::debug!("Pulsified node {node} with adhoc pulsifier");
184            return Ok(pulsified);
185        }
186
187        let pulse_facts: TVec<PulsedFact> =
188            node.inputs.iter().map(|i| target.outlet_fact(mapping[i]).unwrap().clone()).collect();
189        if pulse_facts.iter().all(|pf| pf.stream.is_none()) {
190            let pulse_op = NonPulsingWrappingOp(node.op.clone());
191            let inputs: TVec<OutletId> = node.inputs.iter().map(|i| mapping[i]).collect();
192            log::debug!("Pulsified node {node} with NonPulsingWrappingOp");
193            return target.wire_node(&node.name, pulse_op, &inputs);
194        }
195
196        let (stream_input_ix, pulse_fact) =
197            pulse_facts.iter().enumerate().find(|(_ix, pf)| pf.stream.is_some()).unwrap();
198        let (input_facts, output_facts) = source.node_facts(node.id)?;
199        let axes_mapping = node.op.axes_mapping(&input_facts, &output_facts)?;
200        let axis_info = axes_mapping
201            .axis((InOut::In(stream_input_ix), pulse_fact.stream.as_ref().unwrap().axis))?;
202        if axis_info.outputs[0].len() == 1 {
203            let pulse_op = PulseWrappingOp(node.op.clone());
204            let inputs = sync_inputs(node, target, mapping)?;
205            log::debug!("Pulsified node {node} with PulsingWrappingOp");
206            return target.wire_node(&node.name, pulse_op, &inputs);
207        }
208
209        bail!("No specific pulse transformation for {}, and could not track pulsing axis.", node)
210    }
211}
212
213#[derive(Debug, Clone)]
214pub(crate) struct PulseWrappingOp(pub Box<dyn TypedOp>);
215
216impl Op for PulseWrappingOp {
217    fn name(&self) -> Cow<str> {
218        format!("PulseWrapping({}", self.0.name()).into()
219    }
220
221    fn as_typed(&self) -> Option<&dyn TypedOp> {
222        Some(self.0.as_ref())
223    }
224}
225
226impl EvalOp for PulseWrappingOp {
227    fn is_stateless(&self) -> bool {
228        self.0.is_stateless()
229    }
230
231    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
232        self.0.eval(inputs)
233    }
234
235    fn state(
236        &self,
237        session: &mut SessionState,
238        node_id: usize,
239    ) -> TractResult<Option<Box<dyn OpState>>> {
240        self.0.state(session, node_id)
241    }
242}
243
244impl PulsedOp for PulseWrappingOp {
245    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
246        let (pulsing_input, stream) = if let Some((ix, fact)) =
247            &inputs.iter().enumerate().find(|(_ix, f)| f.stream.is_some())
248        {
249            (*ix, fact.stream.as_ref().unwrap())
250        } else {
251            bail!("PulseWrappingOp used on non streaming input")
252        };
253        let input_facts =
254            inputs.iter().map(|pf| pf.to_typed_fact()).collect::<TractResult<TVec<_>>>()?;
255        let input_facts_ref = input_facts.iter().map(|f| f.as_ref()).collect::<TVec<_>>();
256        let output_facts = self.0.output_facts(&input_facts_ref)?;
257        let output_facts_ref = output_facts.iter().collect::<TVec<_>>();
258        let axes_mapping = self.0.axes_mapping(&input_facts_ref, &output_facts_ref)?;
259        let axis_info = axes_mapping.axis((InOut::In(pulsing_input), stream.axis))?;
260        std::mem::drop(output_facts_ref);
261        output_facts
262            .into_iter()
263            .enumerate()
264            .map(|(ix, tf)| {
265                if let &[axis] = &*axis_info.outputs[ix] {
266                    Ok(PulsedFact {
267                        shape: tf.shape,
268                        datum_type: tf.datum_type,
269                        stream: Some(StreamInfo {
270                            delay: stream.delay,
271                            axis,
272                            dim: stream.dim.clone(),
273                        }),
274                    })
275                } else {
276                    bail!("Disappearing pulsing axis")
277                }
278            })
279            .collect()
280    }
281
282    as_op!();
283
284    fn to_typed(&self) -> Box<dyn TypedOp> {
285        self.0.clone()
286    }
287}
288
289#[derive(Debug, Clone)]
290pub(crate) struct NonPulsingWrappingOp(pub Box<dyn TypedOp>);
291
292impl Op for NonPulsingWrappingOp {
293    fn name(&self) -> Cow<str> {
294        format!("NonePulsingWrapping({}", self.0.name()).into()
295    }
296
297    fn as_typed(&self) -> Option<&dyn TypedOp> {
298        Some(self.0.as_ref())
299    }
300}
301
302impl EvalOp for NonPulsingWrappingOp {
303    fn is_stateless(&self) -> bool {
304        self.0.is_stateless()
305    }
306
307    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
308        self.0.eval(inputs)
309    }
310
311    fn state(
312        &self,
313        session: &mut SessionState,
314        node_id: usize,
315    ) -> TractResult<Option<Box<dyn OpState>>> {
316        self.0.state(session, node_id)
317    }
318}
319
320impl PulsedOp for NonPulsingWrappingOp {
321    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
322        let input_facts =
323            inputs.iter().map(|pf| pf.to_typed_fact()).collect::<TractResult<TVec<_>>>()?;
324        let input_facts_ref = input_facts.iter().map(|f| f.as_ref()).collect::<TVec<_>>();
325        let output_facts = self.0.output_facts(&input_facts_ref)?;
326        let output_facts_ref = output_facts.iter().collect::<TVec<_>>();
327        std::mem::drop(output_facts_ref);
328        output_facts
329            .into_iter()
330            .map(|tf| Ok(PulsedFact { shape: tf.shape, datum_type: tf.datum_type, stream: None }))
331            .collect()
332    }
333
334    as_op!();
335
336    fn to_typed(&self) -> Box<dyn TypedOp> {
337        self.0.clone()
338    }
339}