1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
use crate::internal::*;
use tract_core::ops::array::Slice;

register_all!(Slice: pulsify);

/// Wires the pulsed counterpart of a typed `Slice` node into `target`.
///
/// The node's first input is looked up in `mapping`. When the slice axis is
/// the `axis` recorded on that input's pulsed fact, a [`PulsedAxisSlice`] is
/// wired (carrying `start` as `skip` and `end - start` as `take`); otherwise
/// the original `Slice` op is cloned and wired as-is.
///
/// # Errors
///
/// Propagates failures from fetching the input fact, from converting
/// `op.start` to `usize`, and from wiring the node.
fn pulsify(
    op: &Slice,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _pulse: usize,
) -> TractResult<TVec<OutletId>> {
    let input = mapping[&node.inputs[0]];
    // Only the axis of the input fact is needed to pick the pulsed op.
    let fact_axis = target.outlet_fact(input)?.axis;
    let pulsed: Box<dyn PulsedOp> = if op.axis == fact_axis {
        let skip = op.start.to_usize()?;
        let take = (op.end.clone() - &op.start).to_dim();
        PulsedAxisSlice { axis: op.axis, skip, take }.into()
    } else {
        tract_core::dyn_clone::clone_box(op)
    };
    target.wire_node(&*node.name, pulsed, &[input])
}

impl PulsedOp for Slice {
    /// Derives the output fact from the first input fact.
    ///
    /// If the slice axis equals the fact's `axis`, `start` is added to the
    /// fact's `delay` and its `dim` becomes `end - start`. Otherwise only the
    /// sliced entry of `shape` is replaced by `end - start`.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let mut out = inputs[0].clone();
        let sliced_len = (self.end.clone() - &self.start).to_dim();
        if self.axis != out.axis {
            // Slice on another axis: just shrink that dimension.
            out.shape[self.axis] = sliced_len;
            return Ok(tvec!(out));
        }
        out.delay += self.start.to_usize()?;
        out.dim = sliced_len;
        Ok(tvec!(out))
    }

    as_op!();
    pulsed_op_to_typed_op!();
}

/// Pulsed op wired by `pulsify` when a `Slice` acts on the input fact's `axis`.
///
/// Its evaluation hands the inputs back untouched; the slice is expressed
/// through the output fact (`delay` increased by `skip`, `dim` set to `take`).
#[derive(Debug, Clone, Default, Hash)]
pub struct PulsedAxisSlice {
    /// Axis the slice applies to.
    pub axis: usize,
    /// Amount added to the input fact's `delay`.
    pub skip: usize,
    /// Value assigned to the output fact's `dim`.
    pub take: TDim,
}

impl_dyn_hash!(PulsedAxisSlice);

impl Op for PulsedAxisSlice {
    /// Static operator name.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("PulsedAxisSlice")
    }

    /// One-line summary of the op's parameters.
    fn info(&self) -> TractResult<Vec<String>> {
        let summary = format!("axis:{}, skip:{} take:{}", self.axis, self.skip, self.take);
        Ok(vec![summary])
    }

    op_pulse!();
    not_a_typed_op!();
}

impl EvalOp for PulsedAxisSlice {
    /// This op is reported as stateless.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Returns the inputs unchanged.
    fn eval(&self, inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        Ok(inputs)
    }
}

impl PulsedOp for PulsedAxisSlice {
    /// Copies the first input fact, adds `skip` to its `delay` and sets its
    /// `dim` to `take`.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let mut out = inputs[0].clone();
        out.dim = self.take.clone();
        out.delay += self.skip;
        Ok(tvec!(out))
    }

    /// The typed equivalent of this op is an `Identity`.
    fn to_typed(&self) -> Box<dyn TypedOp> {
        Box::new(tract_core::ops::identity::Identity::default())
    }

    as_op!();
}