tract_pulse/
model.rs

1#![allow(clippy::collapsible_if)]
2use std::sync::RwLock;
3
4use crate::fact::StreamInfo;
5use crate::{internal::*, ops::sync_inputs};
6use tract_core::model::translator::Translate;
7use tract_pulse_opl::tract_core::ops::konst::Const;
8use tract_pulse_opl::tract_core::ops::source::TypedSource;
9
/// A model graph whose node facts are [`PulsedFact`]s and whose ops are boxed [`PulsedOp`]s.
pub type PulsedModel = Graph<PulsedFact, Box<dyn PulsedOp>>;
/// A single node of a [`PulsedModel`].
pub type PulsedNode = Node<PulsedFact, Box<dyn PulsedOp>>;
12
/// Construction of a [`PulsedModel`] from a [`TypedModel`], and conversion back.
// `new` returns `TractResult<PulsedModel>` rather than `Self`, which clippy's
// `new_ret_no_self` lint would otherwise flag on this trait declaration.
#[allow(clippy::new_ret_no_self)]
pub trait PulsedModelExt {
    /// Pulsifies `source` along the streaming dimension `symbol`.
    ///
    /// `pulse` is forwarded to the source and op pulsifiers. Same as
    /// [`PulsedModelExt::new_with_mapping`] with the outlet mapping discarded.
    fn new(source: &TypedModel, symbol: Symbol, pulse: &TDim) -> TractResult<PulsedModel>;

    /// Pulsifies `source` like [`PulsedModelExt::new`], and also returns the mapping from
    /// `source` outlets to the corresponding outlets of the pulsed model.
    fn new_with_mapping(
        source: &TypedModel,
        symbol: Symbol,
        pulse: &TDim,
    ) -> TractResult<(PulsedModel, HashMap<OutletId, OutletId>)>;

    /// Converts the pulsed model back into a [`TypedModel`].
    ///
    /// Every input and output outlet must be streaming. The pulse metadata is stored in the
    /// returned model's `properties`, with one `i64` per outlet:
    /// - `pulse.delay` holds the output delays.
    /// - `pulse.input_axes` holds the input streaming axes.
    /// - `pulse.output_axes` holds the output streaming axes.
    fn into_typed(self) -> TractResult<TypedModel>;
}
25
26impl PulsedModelExt for PulsedModel {
27    fn new(source: &TypedModel, symbol: Symbol, pulse: &TDim) -> TractResult<PulsedModel> {
28        Ok(PulsedModel::new_with_mapping(source, symbol, pulse)?.0)
29    }
30
31    fn new_with_mapping(
32        source: &TypedModel,
33        symbol: Symbol,
34        pulse: &TDim,
35    ) -> TractResult<(PulsedModel, HashMap<OutletId, OutletId>)> {
36        let pulsifiers = crate::ops::OpPulsifier::inventory();
37        Pulsifier(symbol, pulse.to_owned(), pulsifiers).translate_model_with_mappings(source)
38    }
39
40    fn into_typed(self) -> TractResult<TypedModel> {
41        let mut typed = tract_core::model::translator::IntoTranslator.translate_model(&self)?;
42        ensure!(self.input_outlets()?.iter().all(|o| self
43            .outlet_fact(*o)
44            .unwrap()
45            .stream
46            .is_some()));
47        ensure!(self.output_outlets()?.iter().all(|o| self
48            .outlet_fact(*o)
49            .unwrap()
50            .stream
51            .is_some()));
52        let delays = tensor1(
53            &self
54                .output_outlets()?
55                .iter()
56                .map(|oo| Ok(self.outlet_fact(*oo)?.stream.as_ref().unwrap().delay as _))
57                .collect::<TractResult<TVec<i64>>>()?,
58        );
59        typed.properties.insert("pulse.delay".to_string(), delays.into_arc_tensor());
60        let input_axes = tensor1(
61            &self
62                .input_outlets()?
63                .iter()
64                .map(|oo| Ok(self.outlet_fact(*oo)?.stream.as_ref().unwrap().axis as _))
65                .collect::<TractResult<TVec<i64>>>()?,
66        );
67        typed.properties.insert("pulse.input_axes".to_string(), input_axes.into_arc_tensor());
68        let output_axes = tensor1(
69            &self
70                .output_outlets()?
71                .iter()
72                .map(|oo| Ok(self.outlet_fact(*oo)?.stream.as_ref().unwrap().axis as _))
73                .collect::<TractResult<TVec<i64>>>()?,
74        );
75        typed.properties.insert("pulse.output_axes".to_string(), output_axes.into_arc_tensor());
76        Ok(typed)
77    }
78}
79
80impl SpecialOps<PulsedFact, Box<dyn PulsedOp>> for PulsedModel {
81    fn is_source(op: &Box<dyn PulsedOp>) -> bool {
82        op.as_op().downcast_ref::<crate::ops::source::PulsedSource>().is_some()
83    }
84
85    fn create_source(&self, fact: PulsedFact) -> Box<dyn PulsedOp> {
86        Box::new(crate::ops::source::PulsedSource(fact))
87    }
88
89    fn create_dummy(&self) -> Box<dyn PulsedOp> {
90        Box::new(tract_core::ops::dummy::Dummy::new())
91    }
92
93    fn wire_node(
94        &mut self,
95        name: impl Into<String>,
96        op: impl Into<Box<dyn PulsedOp>>,
97        inputs: &[OutletId],
98    ) -> TractResult<TVec<OutletId>> {
99        let op = op.into();
100        let output_facts = {
101            let input_facts =
102                inputs.iter().map(|o| self.outlet_fact(*o)).collect::<TractResult<TVec<_>>>()?;
103            op.pulsed_output_facts(&input_facts)?
104        };
105        let id = self.add_node(name, op, output_facts)?;
106        inputs
107            .iter()
108            .enumerate()
109            .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
110        Ok(self.node(id).outputs.iter().enumerate().map(|(ix, _)| OutletId::new(id, ix)).collect())
111    }
112
113    fn add_const(
114        &mut self,
115        name: impl Into<String>,
116        v: impl IntoArcTensor,
117    ) -> TractResult<OutletId> {
118        let v = v.into_arc_tensor();
119        for node in &self.nodes {
120            if let Some(op) = node.op_as::<Const>() {
121                if op.val() == &v {
122                    return Ok(node.id.into());
123                }
124            }
125        }
126        let op = NonPulsingWrappingOp(Box::new(Const::new(v)?));
127        Ok(self.wire_node(name, op, &[])?[0])
128    }
129}
130
/// Translator turning a [`TypedModel`] into a [`PulsedModel`].
///
/// Fields, in order: the streaming symbol, the pulse, and the op pulsifier registry.
struct Pulsifier(
    Symbol,
    TDim,
    // Filled from `OpPulsifier::inventory()` in `new_with_mapping`. It is never read in this
    // file, because `translate_node` calls `OpPulsifier::pulsify` directly; hence the allow.
    #[allow(dead_code)] Arc<RwLock<HashMap<TypeId, crate::ops::OpPulsifier>>>,
);
136
137impl std::fmt::Debug for Pulsifier {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        write!(f, "Pulsifier({})", self.0)
140    }
141}
142
143impl
144    tract_core::model::translator::Translate<
145        TypedFact,
146        Box<dyn TypedOp>,
147        PulsedFact,
148        Box<dyn PulsedOp>,
149    > for Pulsifier
150{
151    fn translate_node(
152        &self,
153        source: &TypedModel,
154        node: &TypedNode,
155        target: &mut PulsedModel,
156        mapping: &HashMap<OutletId, OutletId>,
157    ) -> TractResult<TVec<OutletId>> {
158        if let Some(op) = node.op_as::<TypedSource>() {
159            return Ok(crate::ops::source::pulsify(
160                op, source, node, target, mapping, &self.0, &self.1,
161            )?
162            .unwrap());
163        }
164        log::debug!("Pulsifying node {node}");
165
166        if !source
167            .node_input_facts(node.id)?
168            .iter()
169            .any(|f| f.shape.iter().any(|d| d.symbols().contains(&self.0)))
170            && !node
171                .outputs
172                .iter()
173                .any(|o| o.fact.shape.iter().any(|d| d.symbols().contains(&self.0)))
174        {
175            let pulse_op = NonPulsingWrappingOp(node.op.clone());
176            let inputs: TVec<OutletId> = node.inputs.iter().map(|i| mapping[i]).collect();
177            log::debug!("Pulsified node {node} with NonPulsingWrappingOp");
178            return target.wire_node(&node.name, pulse_op, &inputs);
179        }
180
181        if let Some(pulsified) =
182            OpPulsifier::pulsify(source, node, target, mapping, &self.0, &self.1)?
183        {
184            log::debug!("Pulsified node {node} with adhoc pulsifier");
185            return Ok(pulsified);
186        }
187
188        let pulse_facts: TVec<PulsedFact> =
189            node.inputs.iter().map(|i| target.outlet_fact(mapping[i]).unwrap().clone()).collect();
190        if pulse_facts.iter().all(|pf| pf.stream.is_none()) {
191            let pulse_op = NonPulsingWrappingOp(node.op.clone());
192            let inputs: TVec<OutletId> = node.inputs.iter().map(|i| mapping[i]).collect();
193            log::debug!("Pulsified node {node} with NonPulsingWrappingOp");
194            return target.wire_node(&node.name, pulse_op, &inputs);
195        }
196
197        let (stream_input_ix, pulse_fact) =
198            pulse_facts.iter().enumerate().find(|(_ix, pf)| pf.stream.is_some()).unwrap();
199        let (input_facts, output_facts) = source.node_facts(node.id)?;
200        let axes_mapping = node.op.axes_mapping(&input_facts, &output_facts)?;
201        let axis_info = axes_mapping
202            .axis((InOut::In(stream_input_ix), pulse_fact.stream.as_ref().unwrap().axis))?;
203        if axis_info.outputs[0].len() == 1 {
204            let pulse_op = PulseWrappingOp(node.op.clone());
205            let inputs = sync_inputs(node, target, mapping)?;
206            log::debug!("Pulsified node {node} with PulsingWrappingOp");
207            return target.wire_node(&node.name, pulse_op, &inputs);
208        }
209
210        bail!("No specific pulse transformation for {}, and could not track pulsing axis.", node)
211    }
212}
213
/// Wraps a typed op whose streaming axis can be tracked through its axes mapping.
///
/// Evaluation is delegated to the wrapped op. Every output is marked as streaming along the
/// axis that the mapping derives from the streaming input.
#[derive(Debug, Clone)]
pub(crate) struct PulseWrappingOp(pub Box<dyn TypedOp>);
216
217impl Op for PulseWrappingOp {
218    fn name(&self) -> StaticName {
219        format!("PulseWrapping({}", self.0.name()).into()
220    }
221
222    fn as_typed(&self) -> Option<&dyn TypedOp> {
223        Some(self.0.as_ref())
224    }
225}
226
227impl EvalOp for PulseWrappingOp {
228    fn is_stateless(&self) -> bool {
229        self.0.is_stateless()
230    }
231
232    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
233        self.0.eval(inputs)
234    }
235
236    fn state(
237        &self,
238        session: &mut SessionState,
239        node_id: usize,
240    ) -> TractResult<Option<Box<dyn OpState>>> {
241        self.0.state(session, node_id)
242    }
243}
244
245impl PulsedOp for PulseWrappingOp {
246    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
247        let (pulsing_input, stream) = if let Some((ix, fact)) =
248            &inputs.iter().enumerate().find(|(_ix, f)| f.stream.is_some())
249        {
250            (*ix, fact.stream.as_ref().unwrap())
251        } else {
252            bail!("PulseWrappingOp used on non streaming input")
253        };
254        let input_facts =
255            inputs.iter().map(|pf| pf.to_typed_fact()).collect::<TractResult<TVec<_>>>()?;
256        let input_facts_ref = input_facts.iter().map(|f| f.as_ref()).collect::<TVec<_>>();
257        let output_facts = self.0.output_facts(&input_facts_ref)?;
258        let output_facts_ref = output_facts.iter().collect::<TVec<_>>();
259        let axes_mapping = self.0.axes_mapping(&input_facts_ref, &output_facts_ref)?;
260        let axis_info = axes_mapping.axis((InOut::In(pulsing_input), stream.axis))?;
261        std::mem::drop(output_facts_ref);
262        output_facts
263            .into_iter()
264            .enumerate()
265            .map(|(ix, tf)| {
266                if let &[axis] = &*axis_info.outputs[ix] {
267                    Ok(PulsedFact {
268                        shape: tf.shape,
269                        datum_type: tf.datum_type,
270                        stream: Some(StreamInfo {
271                            delay: stream.delay,
272                            axis,
273                            dim: stream.dim.clone(),
274                        }),
275                    })
276                } else {
277                    bail!("Disappearing pulsing axis")
278                }
279            })
280            .collect()
281    }
282
283    as_op!();
284
285    fn to_typed(&self) -> Box<dyn TypedOp> {
286        self.0.clone()
287    }
288}
289
/// Wraps a typed op that is not affected by pulsing.
///
/// Evaluation is delegated to the wrapped op, and all of its output facts are marked as
/// non-streaming (`stream: None`).
#[derive(Debug, Clone)]
pub(crate) struct NonPulsingWrappingOp(pub Box<dyn TypedOp>);
292
293impl Op for NonPulsingWrappingOp {
294    fn name(&self) -> StaticName {
295        format!("NonePulsingWrapping({}", self.0.name()).into()
296    }
297
298    fn as_typed(&self) -> Option<&dyn TypedOp> {
299        Some(self.0.as_ref())
300    }
301}
302
303impl EvalOp for NonPulsingWrappingOp {
304    fn is_stateless(&self) -> bool {
305        self.0.is_stateless()
306    }
307
308    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
309        self.0.eval(inputs)
310    }
311
312    fn state(
313        &self,
314        session: &mut SessionState,
315        node_id: usize,
316    ) -> TractResult<Option<Box<dyn OpState>>> {
317        self.0.state(session, node_id)
318    }
319}
320
321impl PulsedOp for NonPulsingWrappingOp {
322    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
323        let input_facts =
324            inputs.iter().map(|pf| pf.to_typed_fact()).collect::<TractResult<TVec<_>>>()?;
325        let input_facts_ref = input_facts.iter().map(|f| f.as_ref()).collect::<TVec<_>>();
326        let output_facts = self.0.output_facts(&input_facts_ref)?;
327        let output_facts_ref = output_facts.iter().collect::<TVec<_>>();
328        std::mem::drop(output_facts_ref);
329        output_facts
330            .into_iter()
331            .map(|tf| Ok(PulsedFact { shape: tf.shape, datum_type: tf.datum_type, stream: None }))
332            .collect()
333    }
334
335    as_op!();
336
337    fn to_typed(&self) -> Box<dyn TypedOp> {
338        self.0.clone()
339    }
340}