1use crate::internal::*;
2
/// Strategy used to generate the values in the padded border.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Fill the border with a constant scalar (cast to the input's element type).
    Constant(Arc<Tensor>),
    /// Mirror the input around its border element (the border itself is not repeated).
    Reflect,
    /// Repeat the border element.
    Edge,
}
9
10impl Default for PadMode {
11 fn default() -> PadMode {
12 PadMode::Constant(Arc::new(0.0f32.into()))
13 }
14}
15
/// Pads a tensor with extra values on each side of each axis.
#[derive(Debug, Clone, new, Default, Hash, PartialEq, Eq)]
pub struct Pad {
    /// Per-axis (before, after) padding amounts; one pair per input axis.
    pub pads: Vec<(usize, usize)>,
    /// How the padded values are produced.
    pub mode: PadMode,
}
21
impl Pad {
    /// Pads `input_tensor` according to `self.pads` and `self.mode` and
    /// returns the padded tensor.
    ///
    /// The output is first filled with a constant (the `Constant` value, or
    /// `T::default()` for the other modes), the input is copied into the
    /// interior, then for `Reflect`/`Edge` the border slices are rewritten
    /// axis by axis.
    fn eval_t<T>(&self, input_tensor: TValue) -> TractResult<TValue>
    where
        T: Copy + Datum,
    {
        use tract_ndarray::*;
        let input = input_tensor.to_plain_array_view::<T>()?;
        // Each axis grows by its (before, after) padding amounts.
        let output_shape: Vec<usize> =
            input.shape().iter().zip(self.pads.iter()).map(|(&d, &(a, b))| d + a + b).collect();
        let element = match &self.mode {
            PadMode::Constant(f) => f.cast_to_scalar::<T>()?,
            // Reflect/Edge borders are overwritten below, so this fill value
            // never survives in the result.
            _ => T::default(),
        };
        let mut output = ArrayD::<T>::from_elem(output_shape, element);
        // Slice of the output receiving the original data: skip `a` elements
        // at the start of each axis and stop `b` before its end (`end: None`
        // means "up to the end" when there is no after-padding).
        let slice_spec: Vec<SliceInfoElem> = self
            .pads
            .iter()
            .map(|&(a, b)| SliceInfoElem::Slice {
                start: a as isize,
                end: if b != 0 { Some(-(b as isize)) } else { None },
                step: 1,
            })
            .collect();
        let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
        output.slice_mut(slice_info.as_ref()).assign(&input);
        if self.mode == PadMode::Reflect || self.mode == PadMode::Edge {
            for (ax, &(bef, aft)) in self.pads.iter().enumerate() {
                let axis = Axis(ax);
                let dim = output.shape()[ax];
                // Fill the `bef` slots before the data on this axis.
                {
                    let (mut pad, data) = output.view_mut().split_at(axis, bef);
                    for i in 0..bef {
                        let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
                        // Index relative to `data` (output minus the before
                        // pad): Edge repeats the first element, Reflect
                        // mirrors around it without duplicating it.
                        let source_slice = match self.mode {
                            PadMode::Edge => 0,
                            PadMode::Reflect => bef - i,
                            _ => panic!(),
                        };
                        let source =
                            data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
                        target.assign(&source);
                    }
                }
                // Fill the `aft` slots after the data on this axis.
                {
                    let (data, mut pad) = output.view_mut().split_at(axis, dim - aft);
                    for i in 0..aft {
                        let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
                        // `data` keeps the output's coordinates here; the last
                        // data element sits at index dim - aft - 1.
                        let source_slice = match self.mode {
                            PadMode::Edge => dim - aft - 1,
                            PadMode::Reflect => dim - aft - 2 - i,
                            _ => panic!(),
                        };
                        let source =
                            data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
                        target.assign(&source);
                    }
                }
            }
        }
        let mut output = output.into_tensor();
        // SAFETY: only relabels the type tag with the input's original datum
        // type; the buffer came from a plain view of that same input, so the
        // bit layout is unchanged. NOTE(review): relies on the caller
        // dispatching a `T` bit-compatible with the input's datum type —
        // confirm against `to_plain_array_view`'s contract.
        unsafe { output.set_datum_type(input_tensor.datum_type()) }
        Ok(output.into_tvalue())
    }
}
86
87impl Op for Pad {
88 fn name(&self) -> StaticName {
89 "Pad".into()
90 }
91
92 fn info(&self) -> TractResult<Vec<String>> {
93 Ok(vec![format!("Mode: {:?}, pads: {:?})", self.mode, self.pads,)])
94 }
95
96 op_as_typed_op!();
97}
98
impl EvalOp for Pad {
    // Output depends only on the input tensor and the op's fields.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates the pad by dispatching `eval_t` on the input's numeric
    /// datum type.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        Ok(tvec!(dispatch_numbers!(Self::eval_t(input.datum_type())(self, input))?))
    }
}
109
impl TypedOp for Pad {
    as_op!();

    /// Output keeps the input's type; every axis grows by its (before, after)
    /// padding amounts.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut fact = inputs[0].without_value();
        // One (before, after) pair is required per input axis.
        if self.pads.len() != fact.rank() {
            bail!("Inconsistent pad: input of rank {}, pads are: {:?}", fact.rank(), self.pads);
        }
        for (ix, (b, e)) in self.pads.iter().enumerate() {
            fact.shape.set(ix, fact.shape[ix].clone() + *b + *e);
        }
        Ok(tvec!(fact))
    }

    /// Translates a region of interest expressed in output coordinates back
    /// to input coordinates by shifting each axis' coordinate symbol by that
    /// axis' before-padding.
    fn input_roi(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TVec<Option<TDim>>>> {
        let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
        let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
        let mut input_roi = roi.clone();
        for (axis, &(before, _)) in self.pads.iter().enumerate() {
            if before == 0 {
                continue;
            }
            // NOTE(review): assumes sym_to_coord_axis maps a coordinate symbol
            // back to the axis it indexes — confirm against its definition in
            // ops::logic.
            if let Some(sym) = input_roi
                .symbols()
                .into_iter()
                .find(|s| crate::ops::logic::sym_to_coord_axis(s) == Some(axis))
            {
                // output coordinate = input coordinate + before.
                let shifted = TDim::Sym(sym.clone()) - TDim::Val(before as i64);
                input_roi = input_roi.substitute(&sym, &shifted).unwrap_or(input_roi);
            }
        }
        Ok(Some(tvec![Some(input_roi)]))
    }

    /// Links only the unpadded axes across the op; a padded axis changes
    /// length, so it is left disconnected.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut result = AxesMapping::disconnected(inputs, outputs)?;
        for (ix, pads) in self.pads.iter().enumerate() {
            if pads == &(0, 0) {
                result = result.linking((InOut::In(0), ix), (InOut::Out(0), ix))?;
            }
        }
        Ok(result)
    }

    /// Keeps `pads` aligned with the input axes when an axis is added or
    /// removed; removal is only accepted for an unpadded axis.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let mut new_op = self.clone();
        // Axis removal: drop the pair; only legal when it was (0, 0). If it is
        // not, the mutated new_op is simply discarded via the Ok(None) below.
        if let (InOut::In(0), AxisOp::Rm(ix)) = (io, change)
            && new_op.pads.remove(*ix) == (0, 0)
        {
            return Ok(Some(AxisChangeConsequence::new(
                model,
                node,
                Some(Box::new(new_op)),
                change,
            )));
        }
        if let (InOut::In(0), AxisOp::Add(ix)) = (io, change) {
            // A freshly inserted axis gets no padding.
            new_op.pads.insert(*ix, (0, 0));
            return Ok(Some(AxisChangeConsequence::new(
                model,
                node,
                Some(Box::new(new_op)),
                change,
            )));
        }
        Ok(None)
    }

    /// A pad with all-zero paddings is a no-op and can be shunted away.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.pads.iter().all(|p| p.0 == 0 && p.1 == 0) {
            TypedModelPatch::shunt_one_op(model, node)
        } else {
            Ok(None)
        }
    }
}