1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
use crate::internal::*;
use tract_core::ops::change_axes::AxisOp;

register_all!(AxisOp: pulsify);

/// Pulsifies an [`AxisOp`] node by wiring an unchanged copy of the op into the pulsed model.
///
/// Only the node's first input is connected. It is translated through `mapping`, which maps
/// outlets of the source model to the corresponding outlets already wired into `target`.
/// `_source` and `_pulse` are not used by this op.
///
/// # Errors
///
/// Propagates any error returned by `target.wire_node`.
///
/// # Panics
///
/// Panics if `node` has no inputs, or if `node.inputs[0]` has no entry in `mapping`.
fn pulsify(
    op: &AxisOp,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _pulse: usize,
) -> TractResult<TVec<OutletId>> {
    let input = mapping[&node.inputs[0]];
    target.wire_node(&*node.name, op.clone(), &[input])
}

impl PulsedOp for AxisOp {
    /// Computes the pulsed output fact for this axis operation.
    ///
    /// The result is the first input's fact with two changes. The op's shape change is
    /// applied to `shape`. The fact's `axis` (presumably the pulsing/streaming axis) is
    /// remapped to where the op moves it.
    ///
    /// # Errors
    ///
    /// Fails if `change_shape_array` rejects the input shape, or if `transform_axis` returns
    /// `None` for the fact's axis (i.e. the op cannot map that axis to an output axis).
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        // `fact` is already a full copy of the input fact, shape included, so the shape can be
        // transformed in place without cloning it a second time.
        let mut fact = inputs[0].clone();
        self.change_shape_array(&mut fact.shape)?;
        fact.axis = self
            .transform_axis(fact.axis)
            .ok_or_else(|| format_err!("Invalid axis for pulsification"))?;
        Ok(tvec!(fact))
    }

    as_op!();
    pulsed_op_to_typed_op!();
}