svod_tensor/nn/pad.rs
1//! Padding helpers: flat-to-pair conversion, auto-pad, pool pad resolution.
2
3use bon::bon;
4use svod_ir::{ConstValue, SInt, UOp};
5
6use super::{AutoPad, PadMode};
7use crate::Tensor;
8
9type Result<T> = crate::Result<T>;
10
/// Regroup ONNX-style flat pads into per-dimension `(begin, end)` pairs.
///
/// The input layout is `[begin0, begin1, ..., end0, end1, ...]`: the first half
/// holds the begin-pads and the second half holds the end-pads. The result is
/// `[(begin0, end0), (begin1, end1), ...]`.
///
/// The number of pairs is `pads.len() / 2`; if `pads` has an odd length the
/// trailing element is ignored. Values are converted with an `as isize` cast.
///
/// # Examples
///
/// ```
/// # use svod_tensor::nn::pad::flat_pads_to_pairs;
/// let pairs = flat_pads_to_pairs(&[1, 2, 3, 4]);
/// assert_eq!(pairs, vec![(1, 3), (2, 4)]);
/// ```
///
/// ```
/// # use svod_tensor::nn::pad::flat_pads_to_pairs;
/// let pairs = flat_pads_to_pairs(&[0, 0, 1, 1, 1, 1, 0, 0]);
/// assert_eq!(pairs, vec![(0, 1), (0, 1), (1, 0), (1, 0)]);
/// ```
pub fn flat_pads_to_pairs(pads: &[i64]) -> Vec<(isize, isize)> {
    // Split into the begin half and the end half; `zip` stops at the shorter
    // (begin) half, so an odd trailing element never produces a pair.
    let (begins, ends) = pads.split_at(pads.len() / 2);
    begins
        .iter()
        .zip(ends)
        .map(|(&begin, &end)| (begin as isize, end as isize))
        .collect()
}
33
34/// Split total padding per dimension into `[begin0, begin1, ..., end0, end1, ...]`
35/// based on auto-pad mode (`SAME_UPPER`: more padding at end; `SAME_LOWER`: more at begin).
36///
37/// # Examples
38///
39/// ```
40/// # use svod_tensor::nn::AutoPad;
41/// # use svod_tensor::nn::pad::auto_pad_split;
42/// // Total pad of 3: SAME_UPPER puts floor at begin, ceil at end
43/// let flat = auto_pad_split(&[3], AutoPad::SameUpper);
44/// assert_eq!(flat, vec![1, 2]); // begin=1, end=2
45///
46/// let flat = auto_pad_split(&[3], AutoPad::SameLower);
47/// assert_eq!(flat, vec![2, 1]); // begin=2, end=1
48/// ```
49pub fn auto_pad_split(total_pads: &[isize], auto_pad: AutoPad) -> Vec<isize> {
50 let first: Vec<isize> = if auto_pad == AutoPad::SameUpper {
51 total_pads.iter().map(|&p| p.div_euclid(2)).collect()
52 } else {
53 total_pads.iter().map(|&p| p - p.div_euclid(2)).collect()
54 };
55 let mut result = first.clone();
56 result.extend(total_pads.iter().zip(&first).map(|(p, f)| p - f));
57 result
58}
59
60/// Resolve auto-pad mode and flat pads into `[(begin, end), ...]` pairs.
61///
62/// Handles all ONNX auto-pad modes: `VALID` (no padding), `NOTSET` (use explicit pads),
63/// `SAME_UPPER` and `SAME_LOWER` (compute padding to preserve spatial size).
64///
65/// # Examples
66///
67/// ```
68/// # use svod_tensor::nn::AutoPad;
69/// # use svod_tensor::nn::pad::resolve_pool_pads;
70/// # use svod_ir::SInt;
71/// // VALID mode: no padding regardless of explicit pads
72/// let pads = resolve_pool_pads(&[SInt::from(5), SInt::from(5)], &[], &[3, 3], &[1, 1], &[1, 1], AutoPad::Valid);
73/// assert_eq!(pads, vec![(0, 0), (0, 0)]);
74/// ```
75///
76/// ```
77/// # use svod_tensor::nn::AutoPad;
78/// # use svod_tensor::nn::pad::resolve_pool_pads;
79/// # use svod_ir::SInt;
80/// // SAME_UPPER: compute pads to keep output size = ceil(input/stride)
81/// let pads = resolve_pool_pads(&[SInt::from(5), SInt::from(5)], &[], &[3, 3], &[1, 1], &[1, 1], AutoPad::SameUpper);
82/// assert_eq!(pads, vec![(1, 1), (1, 1)]);
83/// ```
84pub fn resolve_pool_pads(
85 input_spatial: &[SInt],
86 pads: &[i64],
87 kernel: &[usize],
88 dilations: &[usize],
89 strides: &[usize],
90 auto_pad: AutoPad,
91) -> Vec<(isize, isize)> {
92 let n = kernel.len();
93 match auto_pad {
94 AutoPad::Valid => vec![(0, 0); n],
95 AutoPad::NotSet => {
96 if pads.is_empty() {
97 vec![(0, 0); n]
98 } else {
99 flat_pads_to_pairs(pads)
100 }
101 }
102 AutoPad::SameUpper | AutoPad::SameLower => {
103 // SameUpper/SameLower require concrete spatial dims to compute padding.
104 let total_pads: Vec<isize> = (0..n)
105 .map(|i| {
106 let is = input_spatial[i]
107 .as_const()
108 .expect("SameUpper/SameLower auto_pad requires concrete spatial dims");
109 let out_size = usize::div_ceil(is, strides[i]);
110 let eff_kernel = dilations[i] * (kernel[i] - 1) + 1;
111 let needed = (out_size - 1) * strides[i] + eff_kernel;
112 needed.saturating_sub(is) as isize
113 })
114 .collect();
115 let flat = auto_pad_split(&total_pads, auto_pad);
116 let half = flat.len() / 2;
117 (0..half).map(|i| (flat[i], flat[i + half])).collect()
118 }
119 }
120}
121
122#[bon]
123impl Tensor {
124 /// Pad with a custom fill value. Delegates to `try_pad` when `value == 0.0`.
125 ///
126 /// Each element of `padding` is `(before, after)` for the corresponding dimension.
127 /// Non-zero fill is implemented via an additive mask to avoid nested WHERE conditions.
128 ///
129 /// # Examples
130 ///
131 /// Zero padding (delegates to `try_pad`):
132 ///
133 /// ```
134 /// # use svod_tensor::Tensor;
135 /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
136 /// let mut y = x.try_pad_value(&[(1, 1)], 0.0).unwrap();
137 /// y.realize().unwrap();
138 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 0.0]);
139 /// ```
140 ///
141 /// Negative-infinity padding (useful for max pooling):
142 ///
143 /// ```
144 /// # use svod_tensor::Tensor;
145 /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
146 /// let mut y = x.try_pad_value(&[(1, 0)], f64::NEG_INFINITY).unwrap();
147 /// y.realize().unwrap();
148 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![f32::NEG_INFINITY, 1.0, 2.0, 3.0]);
149 /// ```
150 pub fn try_pad_value(&self, padding: &[(isize, isize)], value: f64) -> Result<Tensor> {
151 if value == 0.0 {
152 return self.try_pad(padding);
153 }
154 // Tinygrad approach: x.pad(0) + ones_pad.where(0, fill_value)
155 // ADD-based avoids fragile nested WHERE conditions that can evaluate to -inf.
156 let dtype = self.uop().dtype();
157 let sdtype = dtype.scalar().expect("pad_value requires scalar dtype");
158 let padded = self.try_pad(padding)?;
159 let ones = Tensor::new(UOp::const_(dtype.clone(), ConstValue::one(sdtype)));
160 let ones = ones.broadcast_to(&self.shape()?)?;
161 let ones_padded = ones.try_pad(padding)?;
162 let zero_cmp = Tensor::new(UOp::const_(dtype.clone(), ConstValue::zero(sdtype)));
163 let mask = ones_padded.try_ne(&zero_cmp)?;
164 let zero_val = Tensor::new(UOp::const_(dtype.clone(), ConstValue::zero(sdtype)));
165 let fill_val = Tensor::new(UOp::const_(dtype, ConstValue::Float(value)));
166 // mask ? zero : fill_value → data region gets 0, pad region gets fill_value
167 let fill_term = zero_val.where_(&mask, &fill_val)?;
168 padded.try_add(&fill_term)
169 }
170
171 /// Pad with configurable mode and fill value.
172 ///
173 /// Supports four padding modes via [`PadMode`]:
174 /// - `Constant` (default): fill with `value` (default 0.0)
175 /// - `Replicate`: repeat boundary values
176 /// - `Reflect`: mirror without repeating boundary
177 /// - `Circular`: wrap around
178 ///
179 /// # Examples
180 ///
181 /// Constant padding (default mode):
182 ///
183 /// ```
184 /// # use svod_tensor::Tensor;
185 /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
186 /// let mut y = x.pad_with().padding(&[(1, 1)]).call().unwrap();
187 /// y.realize().unwrap();
188 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 0.0]);
189 /// ```
190 ///
191 /// Constant padding with a custom fill value:
192 ///
193 /// ```
194 /// # use svod_tensor::Tensor;
195 /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
196 /// let mut y = x.pad_with().padding(&[(1, 1)]).value(-f64::INFINITY).call().unwrap();
197 /// y.realize().unwrap();
198 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![f32::NEG_INFINITY, 1.0, 2.0, 3.0, f32::NEG_INFINITY]);
199 /// ```
200 ///
201 /// Replicate (edge) padding:
202 ///
203 /// ```
204 /// # use svod_tensor::Tensor;
205 /// # use svod_tensor::nn::PadMode;
206 /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
207 /// let mut y = x.pad_with().padding(&[(2, 2)]).mode(PadMode::Replicate).call().unwrap();
208 /// y.realize().unwrap();
209 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);
210 /// ```
211 ///
212 /// Reflect padding:
213 ///
214 /// ```
215 /// # use svod_tensor::Tensor;
216 /// # use svod_tensor::nn::PadMode;
217 /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
218 /// let mut y = x.pad_with().padding(&[(2, 2)]).mode(PadMode::Reflect).call().unwrap();
219 /// y.realize().unwrap();
220 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0]);
221 /// ```
222 ///
223 /// Circular (wrap) padding:
224 ///
225 /// ```
226 /// # use svod_tensor::Tensor;
227 /// # use svod_tensor::nn::PadMode;
228 /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
229 /// let mut y = x.pad_with().padding(&[(2, 2)]).mode(PadMode::Circular).call().unwrap();
230 /// y.realize().unwrap();
231 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0]);
232 /// ```
233 #[builder]
234 pub fn pad_with(
235 &self,
236 padding: &[(isize, isize)],
237 #[builder(default)] mode: PadMode,
238 #[builder(default)] value: f64,
239 ) -> Result<Tensor> {
240 match mode {
241 PadMode::Constant => self.try_pad_value(padding, value),
242 PadMode::Replicate => pad_replicate(self, padding),
243 PadMode::Reflect => pad_reflect(self, padding),
244 PadMode::Circular => pad_circular(self, padding),
245 }
246 }
247}
248
249/// Replicate (edge) padding: repeats boundary values.
250///
251/// For each padded dimension, extracts edge slices via shrink, replicates
252/// via expand, then concatenates. Mirrors Tinygrad's `pad(mode="replicate")`.
253fn pad_replicate(data: &Tensor, padding: &[(isize, isize)]) -> Result<Tensor> {
254 let mut result = data.clone();
255 for (d, &(pad_before, pad_after)) in padding.iter().enumerate() {
256 if pad_before == 0 && pad_after == 0 {
257 continue;
258 }
259 let shape = result.shape()?;
260 let dim_size = shape[d].as_const().expect("replicate pad requires concrete dims") as isize;
261 let mut parts: Vec<Tensor> = Vec::new();
262
263 if pad_before > 0 {
264 let mut shrink_ranges: Vec<(isize, isize)> =
265 shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
266 shrink_ranges[d] = (0, 1);
267 let edge = result.try_shrink(&shrink_ranges)?;
268 let mut expand_shape: Vec<isize> = shape.iter().map(|s| s.as_const().unwrap() as isize).collect();
269 expand_shape[d] = pad_before;
270 parts.push(edge.try_expand(&expand_shape)?);
271 }
272
273 parts.push(result.clone());
274
275 if pad_after > 0 {
276 let mut shrink_ranges: Vec<(isize, isize)> =
277 shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
278 shrink_ranges[d] = (dim_size - 1, dim_size);
279 let edge = result.try_shrink(&shrink_ranges)?;
280 let mut expand_shape: Vec<isize> = shape.iter().map(|s| s.as_const().unwrap() as isize).collect();
281 expand_shape[d] = pad_after;
282 parts.push(edge.try_expand(&expand_shape)?);
283 }
284
285 let refs: Vec<&Tensor> = parts.iter().collect();
286 result = Tensor::cat(&refs, d as isize)?;
287 }
288 Ok(result)
289}
290
291/// Reflect padding: mirrors values without repeating the boundary.
292///
293/// For each padded dimension, extracts interior slices via shrink, flips them,
294/// then concatenates. E.g. `[1,2,3]` pad(2,2) → `[3,2,1,2,3,2,1]`.
295fn pad_reflect(data: &Tensor, padding: &[(isize, isize)]) -> Result<Tensor> {
296 let mut result = data.clone();
297 for (d, &(pad_before, pad_after)) in padding.iter().enumerate() {
298 if pad_before == 0 && pad_after == 0 {
299 continue;
300 }
301 let shape = result.shape()?;
302 let dim_size = shape[d].as_const().expect("reflect pad requires concrete dims") as isize;
303 let mut parts: Vec<Tensor> = Vec::new();
304
305 if pad_before > 0 {
306 let mut shrink_ranges: Vec<(isize, isize)> =
307 shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
308 shrink_ranges[d] = (1, 1 + pad_before);
309 let slice = result.try_shrink(&shrink_ranges)?;
310 parts.push(slice.flip(&[d as isize])?);
311 }
312
313 parts.push(result.clone());
314
315 if pad_after > 0 {
316 let mut shrink_ranges: Vec<(isize, isize)> =
317 shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
318 shrink_ranges[d] = (dim_size - 1 - pad_after, dim_size - 1);
319 let slice = result.try_shrink(&shrink_ranges)?;
320 parts.push(slice.flip(&[d as isize])?);
321 }
322
323 let refs: Vec<&Tensor> = parts.iter().collect();
324 result = Tensor::cat(&refs, d as isize)?;
325 }
326 Ok(result)
327}
328
329/// Circular (wrap) padding: wraps values from the opposite end.
330///
331/// Uses repeat + shrink: tile the tensor up to 3x per padded dimension,
332/// then shrink to extract the wrapped window. Mirrors Tinygrad's `pad(mode="circular")`.
333fn pad_circular(data: &Tensor, padding: &[(isize, isize)]) -> Result<Tensor> {
334 let shape = data.shape()?;
335 let ndim = shape.len();
336 let repeats: Vec<SInt> =
337 padding.iter().map(|&(pb, pa)| SInt::from(1 + usize::from(pb > 0) + usize::from(pa > 0))).collect();
338 let repeated = data.repeat(&repeats)?;
339 let rep_shape = repeated.shape()?;
340
341 let shrink_ranges: Vec<(isize, isize)> = (0..ndim)
342 .map(|d| {
343 let (pb, _pa) = padding[d];
344 let orig = shape[d].as_const().expect("circular pad requires concrete dims") as isize;
345 let rep_dim = rep_shape[d].as_const().unwrap() as isize;
346 let start = if pb == 0 { 0 } else { orig - pb };
347 let end = if padding[d].1 == 0 { rep_dim } else { rep_dim - orig + padding[d].1 };
348 (start, end)
349 })
350 .collect();
351 repeated.try_shrink(&shrink_ranges)
352}
353
354/// Adjust padding for ceil_mode output sizes.
355/// Per arXiv:1603.07285 section 5.1, relationship 15.
356pub(super) fn apply_ceil_mode(
357 padding: &[(isize, isize)],
358 input_spatial: &[SInt],
359 kernel: &[usize],
360 stride: &[usize],
361 dilation: &[usize],
362) -> Vec<(isize, isize)> {
363 let n = kernel.len();
364 let grouped: Vec<(isize, isize)> = padding.to_vec();
365 let mut ceil_pads = grouped.clone();
366 for i in 0..n {
367 let is = input_spatial[i].as_const().expect("ceil_mode requires concrete spatial dims");
368 let padded = is as isize + grouped[i].0 + grouped[i].1;
369 let eff_k = (dilation[i] * (kernel[i] - 1) + 1) as isize;
370 let s = stride[i] as isize;
371 let o_ceil = (padded - eff_k + s - 1) / s + 1;
372 let o_floor = (padded - eff_k) / s + 1;
373 if o_ceil > o_floor {
374 let last_start = s * (o_ceil - 1);
375 let extra = last_start + eff_k - padded;
376 let correction = (last_start - (grouped[i].0 + is as isize - 1)).max(0);
377 ceil_pads[i].1 += extra - correction;
378 }
379 }
380 ceil_pads
381}