// tract_core/ops/array/pad.rs
use crate::internal::*;

/// How the padded border regions of the output tensor are filled.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Fill with a constant value, stored as a scalar tensor and cast to the
    /// input's element type at eval time.
    Constant(Arc<Tensor>),
    /// Mirror the input across the edge (the edge element itself is not repeated).
    Reflect,
    /// Repeat the edge element.
    Edge,
}

10impl Default for PadMode {
11    fn default() -> PadMode {
12        PadMode::Constant(Arc::new(0.0f32.into()))
13    }
14}
15
/// Pads a tensor along each axis by the amounts in `pads`, filling the new
/// elements according to `mode`.
#[derive(Debug, Clone, new, Default, Hash)]
pub struct Pad {
    // Per-axis (before, after) padding amounts; one entry per input axis
    // (enforced in `output_facts`).
    pub pads: Vec<(usize, usize)>,
    // Fill strategy for the padded regions.
    pub mode: PadMode,
}

impl Pad {
    /// Pads `input_tensor` (element type `T`) according to `self.pads` and
    /// `self.mode`, returning the padded tensor as a `TValue`.
    fn eval_t<T>(&self, input_tensor: TValue) -> TractResult<TValue>
    where
        T: Copy + Datum,
    {
        use tract_ndarray::*;
        let input = input_tensor.to_array_view::<T>()?;
        // Each output dim grows by (before + after). Assumes self.pads has
        // one entry per input axis (zip silently truncates otherwise).
        let output_shape: Vec<usize> =
            input.shape().iter().zip(self.pads.iter()).map(|(&d, &(a, b))| d + a + b).collect();
        // Initial fill value: the requested constant (cast to T) for Constant
        // mode; T::default() otherwise — the borders get overwritten below.
        let element = match &self.mode {
            PadMode::Constant(f) => f.cast_to_scalar::<T>()?,
            _ => T::default(),
        };
        let mut output = ArrayD::<T>::from_elem(output_shape, element);
        // Slice of the output corresponding to the original data: skip `a`
        // leading elements, stop `b` before the end, on each axis.
        let slice_spec: Vec<SliceInfoElem> = self
            .pads
            .iter()
            .map(|&(a, b)| SliceInfoElem::Slice {
                start: a as isize,
                // None means "to the end"; -b counts back from the end.
                end: if b != 0 { Some(-(b as isize)) } else { None },
                step: 1,
            })
            .collect();
        let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
        output.slice_mut(slice_info.as_ref()).assign(&input);
        // Reflect/Edge: fill the borders axis by axis, copying one-element
        // slabs out of the already-populated part of `output`.
        if self.mode == PadMode::Reflect || self.mode == PadMode::Edge {
            for (ax, &(bef, aft)) in self.pads.iter().enumerate() {
                let axis = Axis(ax);
                let dim = output.shape()[ax];
                // Leading border: `pad` is output[..bef], `data` is the rest
                // (indices in `data` are output indices shifted by -bef).
                {
                    let (mut pad, data) = output.view_mut().split_at(axis, bef);
                    for i in 0..bef {
                        let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
                        let source_slice = match self.mode {
                            // Edge: repeat the first data slab.
                            PadMode::Edge => 0,
                            // Reflect: mirror output index i about the first
                            // data slab (output index bef) -> output index
                            // 2*bef - i, i.e. bef - i in `data` coordinates.
                            PadMode::Reflect => bef - i,
                            _ => panic!(),
                        };
                        let source =
                            data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
                        target.assign(&source);
                    }
                }
                // Trailing border: `data` is output[..dim-aft], `pad` the rest.
                {
                    let (data, mut pad) = output.view_mut().split_at(axis, dim - aft);
                    for i in 0..aft {
                        let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
                        let source_slice = match self.mode {
                            // Edge: repeat the last data slab.
                            PadMode::Edge => dim - aft - 1,
                            // Reflect: mirror pad position dim - aft + i about
                            // the last data slab at index dim - aft - 1.
                            PadMode::Reflect => dim - aft - 2 - i,
                            _ => panic!(),
                        };
                        let source =
                            data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
                        target.assign(&source);
                    }
                }
            }
        }
        let mut output = output.into_tensor();
        // Restore the input's logical datum type — NOTE(review): presumably
        // because dispatch happens on the storage type, so logical types
        // sharing a representation with T (e.g. quantized) must be re-tagged;
        // confirm against dispatch_numbers! semantics.
        unsafe { output.set_datum_type(input_tensor.datum_type()) }
        Ok(output.into_tvalue())
    }
}

87impl Op for Pad {
88    fn name(&self) -> Cow<str> {
89        "Pad".into()
90    }
91
92    fn info(&self) -> TractResult<Vec<String>> {
93        Ok(vec![format!("Mode: {:?}, pads: {:?})", self.mode, self.pads,)])
94    }
95
96    op_as_typed_op!();
97}
98
99impl EvalOp for Pad {
100    fn is_stateless(&self) -> bool {
101        true
102    }
103
104    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
105        let input = args_1!(inputs);
106        Ok(tvec!(dispatch_numbers!(Self::eval_t(input.datum_type())(self, input))?))
107    }
108}
109
110impl TypedOp for Pad {
111    as_op!();
112
113    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
114        let mut fact = inputs[0].without_value();
115        if self.pads.len() != fact.rank() {
116            bail!("Inconsistent pad: input of rank {}, pads are: {:?}", fact.rank(), self.pads);
117        }
118        for (ix, (b, e)) in self.pads.iter().enumerate() {
119            fact.shape.set(ix, fact.shape[ix].clone() + *b + *e);
120        }
121        Ok(tvec!(fact))
122    }
123
124    fn axes_mapping(
125        &self,
126        inputs: &[&TypedFact],
127        outputs: &[&TypedFact],
128    ) -> TractResult<AxesMapping> {
129        let mut result = AxesMapping::disconnected(inputs, outputs)?;
130        for (ix, pads) in self.pads.iter().enumerate() {
131            if pads == &(0, 0) {
132                result = result.linking((InOut::In(0), ix), (InOut::Out(0), ix))?;
133            }
134        }
135        Ok(result)
136    }
137
138    fn change_axes(
139        &self,
140        model: &TypedModel,
141        node: &TypedNode,
142        io: InOut,
143        change: &AxisOp,
144    ) -> TractResult<Option<AxisChangeConsequence>> {
145        let mut new_op = self.clone();
146        if let (InOut::In(0), AxisOp::Rm(ix)) = (io, change) {
147            if new_op.pads.remove(*ix) == (0, 0) {
148                return Ok(Some(AxisChangeConsequence::new(
149                    model,
150                    node,
151                    Some(Box::new(new_op)),
152                    change,
153                )));
154            }
155        }
156        if let (InOut::In(0), AxisOp::Add(ix)) = (io, change) {
157            new_op.pads.insert(*ix, (0, 0));
158            return Ok(Some(AxisChangeConsequence::new(
159                model,
160                node,
161                Some(Box::new(new_op)),
162                change,
163            )));
164        }
165        Ok(None)
166    }
167
168    fn declutter(
169        &self,
170        model: &TypedModel,
171        node: &TypedNode,
172    ) -> TractResult<Option<TypedModelPatch>> {
173        if self.pads.iter().all(|p| p.0 == 0 && p.1 == 0) {
174            TypedModelPatch::shunt_one_op(model, node)
175        } else {
176            Ok(None)
177        }
178    }
179}