1use crate::internal::*;
2
/// Strategy used by [`Pad`] to fill the border cells it adds around its input.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Fill every added cell with this scalar; it is cast to the input's element type
    /// at evaluation time.
    Constant(Arc<Tensor>),
    /// Mirror the input across each edge, without repeating the edge element itself
    /// (the cell next to the boundary copies the second input element, and so on).
    Reflect,
    /// Repeat the outermost input element of each padded axis.
    Edge,
}
9
10impl Default for PadMode {
11 fn default() -> PadMode {
12 PadMode::Constant(Arc::new(0.0f32.into()))
13 }
14}
15
/// Pads a tensor along every axis.
///
/// The output extent of axis `i` is the input extent plus `pads[i].0 + pads[i].1`.
/// The derived `Default` has no pads and uses `PadMode::default()`, which is
/// constant padding with an `f32` zero.
#[derive(Debug, Clone, new, Default, Hash)]
pub struct Pad {
    /// One `(before, after)` pair per input axis: the number of cells added at the
    /// start and at the end of that axis.
    pub pads: Vec<(usize, usize)>,
    /// How the added cells are filled.
    pub mode: PadMode,
}
21
22impl Pad {
23 fn eval_t<T>(&self, input_tensor: TValue) -> TractResult<TValue>
24 where
25 T: Copy + Datum,
26 {
27 use tract_ndarray::*;
28 let input = input_tensor.to_array_view::<T>()?;
29 let output_shape: Vec<usize> =
30 input.shape().iter().zip(self.pads.iter()).map(|(&d, &(a, b))| d + a + b).collect();
31 let element = match &self.mode {
32 PadMode::Constant(f) => f.cast_to_scalar::<T>()?,
33 _ => T::default(),
34 };
35 let mut output = ArrayD::<T>::from_elem(output_shape, element);
36 let slice_spec: Vec<SliceInfoElem> = self
37 .pads
38 .iter()
39 .map(|&(a, b)| SliceInfoElem::Slice {
40 start: a as isize,
41 end: if b != 0 { Some(-(b as isize)) } else { None },
42 step: 1,
43 })
44 .collect();
45 let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
46 output.slice_mut(slice_info.as_ref()).assign(&input);
47 if self.mode == PadMode::Reflect || self.mode == PadMode::Edge {
48 for (ax, &(bef, aft)) in self.pads.iter().enumerate() {
49 let axis = Axis(ax);
50 let dim = output.shape()[ax];
51 {
52 let (mut pad, data) = output.view_mut().split_at(axis, bef);
53 for i in 0..bef {
54 let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
55 let source_slice = match self.mode {
56 PadMode::Edge => 0,
57 PadMode::Reflect => bef - i,
58 _ => panic!(),
59 };
60 let source =
61 data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
62 target.assign(&source);
63 }
64 }
65 {
66 let (data, mut pad) = output.view_mut().split_at(axis, dim - aft);
67 for i in 0..aft {
68 let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
69 let source_slice = match self.mode {
70 PadMode::Edge => dim - aft - 1,
71 PadMode::Reflect => dim - aft - 2 - i,
72 _ => panic!(),
73 };
74 let source =
75 data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
76 target.assign(&source);
77 }
78 }
79 }
80 }
81 let mut output = output.into_tensor();
82 unsafe { output.set_datum_type(input_tensor.datum_type()) }
83 Ok(output.into_tvalue())
84 }
85}
86
87impl Op for Pad {
88 fn name(&self) -> Cow<str> {
89 "Pad".into()
90 }
91
92 fn info(&self) -> TractResult<Vec<String>> {
93 Ok(vec![format!("Mode: {:?}, pads: {:?})", self.mode, self.pads,)])
94 }
95
96 op_as_typed_op!();
97}
98
99impl EvalOp for Pad {
100 fn is_stateless(&self) -> bool {
101 true
102 }
103
104 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
105 let input = args_1!(inputs);
106 Ok(tvec!(dispatch_numbers!(Self::eval_t(input.datum_type())(self, input))?))
107 }
108}
109
110impl TypedOp for Pad {
111 as_op!();
112
113 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
114 let mut fact = inputs[0].without_value();
115 if self.pads.len() != fact.rank() {
116 bail!("Inconsistent pad: input of rank {}, pads are: {:?}", fact.rank(), self.pads);
117 }
118 for (ix, (b, e)) in self.pads.iter().enumerate() {
119 fact.shape.set(ix, fact.shape[ix].clone() + *b + *e);
120 }
121 Ok(tvec!(fact))
122 }
123
124 fn axes_mapping(
125 &self,
126 inputs: &[&TypedFact],
127 outputs: &[&TypedFact],
128 ) -> TractResult<AxesMapping> {
129 let mut result = AxesMapping::disconnected(inputs, outputs)?;
130 for (ix, pads) in self.pads.iter().enumerate() {
131 if pads == &(0, 0) {
132 result = result.linking((InOut::In(0), ix), (InOut::Out(0), ix))?;
133 }
134 }
135 Ok(result)
136 }
137
138 fn change_axes(
139 &self,
140 model: &TypedModel,
141 node: &TypedNode,
142 io: InOut,
143 change: &AxisOp,
144 ) -> TractResult<Option<AxisChangeConsequence>> {
145 let mut new_op = self.clone();
146 if let (InOut::In(0), AxisOp::Rm(ix)) = (io, change) {
147 if new_op.pads.remove(*ix) == (0, 0) {
148 return Ok(Some(AxisChangeConsequence::new(
149 model,
150 node,
151 Some(Box::new(new_op)),
152 change,
153 )));
154 }
155 }
156 if let (InOut::In(0), AxisOp::Add(ix)) = (io, change) {
157 new_op.pads.insert(*ix, (0, 0));
158 return Ok(Some(AxisChangeConsequence::new(
159 model,
160 node,
161 Some(Box::new(new_op)),
162 change,
163 )));
164 }
165 Ok(None)
166 }
167
168 fn declutter(
169 &self,
170 model: &TypedModel,
171 node: &TypedNode,
172 ) -> TractResult<Option<TypedModelPatch>> {
173 if self.pads.iter().all(|p| p.0 == 0 && p.1 == 0) {
174 TypedModelPatch::shunt_one_op(model, node)
175 } else {
176 Ok(None)
177 }
178 }
179}