tract_hir/ops/array/
crop.rs1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Default, Hash)]
5pub struct Crop {
6 pub axis: usize,
7 pub start: usize,
8 pub end: usize,
9}
10
11
12
13impl Expansion for Crop {
14 fn name(&self) -> StaticName {
15 "Crop".into()
16 }
17
18
19 fn rules<'r, 'p: 'r, 's: 'r>(
20 &'s self,
21 s: &mut Solver<'r>,
22 inputs: &'p [TensorProxy],
23 outputs: &'p [TensorProxy],
24 ) -> InferenceResult {
25 check_input_arity(inputs, 1)?;
26 check_output_arity(outputs, 1)?;
27 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
28 s.equals(&inputs[0].rank, &outputs[0].rank)?;
29 s.given(&inputs[0].rank, move |s, rank| {
30 (0..rank as usize).try_for_each(|ax| {
31 if self.axis == ax {
32 s.equals(
33 &inputs[0].shape[ax],
34 outputs[0].shape[ax].bex() + self.start.to_dim() + self.end.to_dim(),
35 )
36 } else {
37 s.equals(&inputs[0].shape[ax], &outputs[0].shape[ax])
38 }
39 })
40 })?;
41 Ok(())
42 }
43
44 fn wire(
45 &self,
46 prefix: &str,
47 target: &mut TypedModel,
48 inputs: &[OutletId],
49 ) -> TractResult<TVec<OutletId>> {
50 let len = target.outlet_fact(inputs[0])?.shape[self.axis].clone();
51 target.wire_node(
52 prefix,
53 crate::ops::array::Slice::new(
54 self.axis,
55 self.start.to_dim(),
56 len - self.end.to_dim(),
57 ),
58 inputs,
59 )
60 }
61}