tract_hir/ops/array/
crop.rs

1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Default, Hash)]
5pub struct Crop {
6    pub axis: usize,
7    pub start: usize,
8    pub end: usize,
9}
10
11
12
13impl Expansion for Crop {
14    fn name(&self) -> StaticName {
15        "Crop".into()
16    }
17
18
19    fn rules<'r, 'p: 'r, 's: 'r>(
20        &'s self,
21        s: &mut Solver<'r>,
22        inputs: &'p [TensorProxy],
23        outputs: &'p [TensorProxy],
24    ) -> InferenceResult {
25        check_input_arity(inputs, 1)?;
26        check_output_arity(outputs, 1)?;
27        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
28        s.equals(&inputs[0].rank, &outputs[0].rank)?;
29        s.given(&inputs[0].rank, move |s, rank| {
30            (0..rank as usize).try_for_each(|ax| {
31                if self.axis == ax {
32                    s.equals(
33                        &inputs[0].shape[ax],
34                        outputs[0].shape[ax].bex() + self.start.to_dim() + self.end.to_dim(),
35                    )
36                } else {
37                    s.equals(&inputs[0].shape[ax], &outputs[0].shape[ax])
38                }
39            })
40        })?;
41        Ok(())
42    }
43
44    fn wire(
45        &self,
46        prefix: &str,
47        target: &mut TypedModel,
48        inputs: &[OutletId],
49    ) -> TractResult<TVec<OutletId>> {
50        let len = target.outlet_fact(inputs[0])?.shape[self.axis].clone();
51        target.wire_node(
52            prefix,
53            crate::ops::array::Slice::new(
54                self.axis,
55                self.start.to_dim(),
56                len - self.end.to_dim(),
57            ),
58            inputs,
59        )
60    }
61}