Skip to main content

svod_tensor/nn/
pad.rs

1//! Padding helpers: flat-to-pair conversion, auto-pad, pool pad resolution.
2
3use bon::bon;
4use svod_ir::{ConstValue, SInt, UOp};
5
6use super::{AutoPad, PadMode};
7use crate::Tensor;
8
9type Result<T> = crate::Result<T>;
10
/// Pair up ONNX-style flat pads into per-dimension `(begin, end)` tuples.
///
/// ONNX lays padding out as `[begin0, begin1, ..., end0, end1, ...]`: all
/// begin-pads first, then all end-pads. The slice is cut at `pads.len() / 2`
/// and the two halves are zipped together, so a trailing element of an
/// odd-length slice is ignored.
///
/// # Examples
///
/// ```
/// # use svod_tensor::nn::pad::flat_pads_to_pairs;
/// let pairs = flat_pads_to_pairs(&[1, 2, 3, 4]);
/// assert_eq!(pairs, vec![(1, 3), (2, 4)]);
/// ```
///
/// ```
/// # use svod_tensor::nn::pad::flat_pads_to_pairs;
/// let pairs = flat_pads_to_pairs(&[0, 0, 1, 1, 1, 1, 0, 0]);
/// assert_eq!(pairs, vec![(0, 1), (0, 1), (1, 0), (1, 0)]);
/// ```
pub fn flat_pads_to_pairs(pads: &[i64]) -> Vec<(isize, isize)> {
    let (begins, ends) = pads.split_at(pads.len() / 2);
    begins.iter().zip(ends).map(|(&begin, &end)| (begin as isize, end as isize)).collect()
}
33
34/// Split total padding per dimension into `[begin0, begin1, ..., end0, end1, ...]`
35/// based on auto-pad mode (`SAME_UPPER`: more padding at end; `SAME_LOWER`: more at begin).
36///
37/// # Examples
38///
39/// ```
40/// # use svod_tensor::nn::AutoPad;
41/// # use svod_tensor::nn::pad::auto_pad_split;
42/// // Total pad of 3: SAME_UPPER puts floor at begin, ceil at end
43/// let flat = auto_pad_split(&[3], AutoPad::SameUpper);
44/// assert_eq!(flat, vec![1, 2]); // begin=1, end=2
45///
46/// let flat = auto_pad_split(&[3], AutoPad::SameLower);
47/// assert_eq!(flat, vec![2, 1]); // begin=2, end=1
48/// ```
49pub fn auto_pad_split(total_pads: &[isize], auto_pad: AutoPad) -> Vec<isize> {
50    let first: Vec<isize> = if auto_pad == AutoPad::SameUpper {
51        total_pads.iter().map(|&p| p.div_euclid(2)).collect()
52    } else {
53        total_pads.iter().map(|&p| p - p.div_euclid(2)).collect()
54    };
55    let mut result = first.clone();
56    result.extend(total_pads.iter().zip(&first).map(|(p, f)| p - f));
57    result
58}
59
60/// Resolve auto-pad mode and flat pads into `[(begin, end), ...]` pairs.
61///
62/// Handles all ONNX auto-pad modes: `VALID` (no padding), `NOTSET` (use explicit pads),
63/// `SAME_UPPER` and `SAME_LOWER` (compute padding to preserve spatial size).
64///
65/// # Examples
66///
67/// ```
68/// # use svod_tensor::nn::AutoPad;
69/// # use svod_tensor::nn::pad::resolve_pool_pads;
70/// # use svod_ir::SInt;
71/// // VALID mode: no padding regardless of explicit pads
72/// let pads = resolve_pool_pads(&[SInt::from(5), SInt::from(5)], &[], &[3, 3], &[1, 1], &[1, 1], AutoPad::Valid);
73/// assert_eq!(pads, vec![(0, 0), (0, 0)]);
74/// ```
75///
76/// ```
77/// # use svod_tensor::nn::AutoPad;
78/// # use svod_tensor::nn::pad::resolve_pool_pads;
79/// # use svod_ir::SInt;
80/// // SAME_UPPER: compute pads to keep output size = ceil(input/stride)
81/// let pads = resolve_pool_pads(&[SInt::from(5), SInt::from(5)], &[], &[3, 3], &[1, 1], &[1, 1], AutoPad::SameUpper);
82/// assert_eq!(pads, vec![(1, 1), (1, 1)]);
83/// ```
84pub fn resolve_pool_pads(
85    input_spatial: &[SInt],
86    pads: &[i64],
87    kernel: &[usize],
88    dilations: &[usize],
89    strides: &[usize],
90    auto_pad: AutoPad,
91) -> Vec<(isize, isize)> {
92    let n = kernel.len();
93    match auto_pad {
94        AutoPad::Valid => vec![(0, 0); n],
95        AutoPad::NotSet => {
96            if pads.is_empty() {
97                vec![(0, 0); n]
98            } else {
99                flat_pads_to_pairs(pads)
100            }
101        }
102        AutoPad::SameUpper | AutoPad::SameLower => {
103            // SameUpper/SameLower require concrete spatial dims to compute padding.
104            let total_pads: Vec<isize> = (0..n)
105                .map(|i| {
106                    let is = input_spatial[i]
107                        .as_const()
108                        .expect("SameUpper/SameLower auto_pad requires concrete spatial dims");
109                    let out_size = usize::div_ceil(is, strides[i]);
110                    let eff_kernel = dilations[i] * (kernel[i] - 1) + 1;
111                    let needed = (out_size - 1) * strides[i] + eff_kernel;
112                    needed.saturating_sub(is) as isize
113                })
114                .collect();
115            let flat = auto_pad_split(&total_pads, auto_pad);
116            let half = flat.len() / 2;
117            (0..half).map(|i| (flat[i], flat[i + half])).collect()
118        }
119    }
120}
121
122#[bon]
123impl Tensor {
124    /// Pad with a custom fill value. Delegates to `try_pad` when `value == 0.0`.
125    ///
126    /// Each element of `padding` is `(before, after)` for the corresponding dimension.
127    /// Non-zero fill is implemented via an additive mask to avoid nested WHERE conditions.
128    ///
129    /// # Examples
130    ///
131    /// Zero padding (delegates to `try_pad`):
132    ///
133    /// ```
134    /// # use svod_tensor::Tensor;
135    /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
136    /// let mut y = x.try_pad_value(&[(1, 1)], 0.0).unwrap();
137    /// y.realize().unwrap();
138    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 0.0]);
139    /// ```
140    ///
141    /// Negative-infinity padding (useful for max pooling):
142    ///
143    /// ```
144    /// # use svod_tensor::Tensor;
145    /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
146    /// let mut y = x.try_pad_value(&[(1, 0)], f64::NEG_INFINITY).unwrap();
147    /// y.realize().unwrap();
148    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![f32::NEG_INFINITY, 1.0, 2.0, 3.0]);
149    /// ```
150    pub fn try_pad_value(&self, padding: &[(isize, isize)], value: f64) -> Result<Tensor> {
151        if value == 0.0 {
152            return self.try_pad(padding);
153        }
154        // Tinygrad approach: x.pad(0) + ones_pad.where(0, fill_value)
155        // ADD-based avoids fragile nested WHERE conditions that can evaluate to -inf.
156        let dtype = self.uop().dtype();
157        let sdtype = dtype.scalar().expect("pad_value requires scalar dtype");
158        let padded = self.try_pad(padding)?;
159        let ones = Tensor::new(UOp::const_(dtype.clone(), ConstValue::one(sdtype)));
160        let ones = ones.broadcast_to(&self.shape()?)?;
161        let ones_padded = ones.try_pad(padding)?;
162        let zero_cmp = Tensor::new(UOp::const_(dtype.clone(), ConstValue::zero(sdtype)));
163        let mask = ones_padded.try_ne(&zero_cmp)?;
164        let zero_val = Tensor::new(UOp::const_(dtype.clone(), ConstValue::zero(sdtype)));
165        let fill_val = Tensor::new(UOp::const_(dtype, ConstValue::Float(value)));
166        // mask ? zero : fill_value  →  data region gets 0, pad region gets fill_value
167        let fill_term = zero_val.where_(&mask, &fill_val)?;
168        padded.try_add(&fill_term)
169    }
170
171    /// Pad with configurable mode and fill value.
172    ///
173    /// Supports four padding modes via [`PadMode`]:
174    /// - `Constant` (default): fill with `value` (default 0.0)
175    /// - `Replicate`: repeat boundary values
176    /// - `Reflect`: mirror without repeating boundary
177    /// - `Circular`: wrap around
178    ///
179    /// # Examples
180    ///
181    /// Constant padding (default mode):
182    ///
183    /// ```
184    /// # use svod_tensor::Tensor;
185    /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
186    /// let mut y = x.pad_with().padding(&[(1, 1)]).call().unwrap();
187    /// y.realize().unwrap();
188    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 0.0]);
189    /// ```
190    ///
191    /// Constant padding with a custom fill value:
192    ///
193    /// ```
194    /// # use svod_tensor::Tensor;
195    /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
196    /// let mut y = x.pad_with().padding(&[(1, 1)]).value(-f64::INFINITY).call().unwrap();
197    /// y.realize().unwrap();
198    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![f32::NEG_INFINITY, 1.0, 2.0, 3.0, f32::NEG_INFINITY]);
199    /// ```
200    ///
201    /// Replicate (edge) padding:
202    ///
203    /// ```
204    /// # use svod_tensor::Tensor;
205    /// # use svod_tensor::nn::PadMode;
206    /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
207    /// let mut y = x.pad_with().padding(&[(2, 2)]).mode(PadMode::Replicate).call().unwrap();
208    /// y.realize().unwrap();
209    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);
210    /// ```
211    ///
212    /// Reflect padding:
213    ///
214    /// ```
215    /// # use svod_tensor::Tensor;
216    /// # use svod_tensor::nn::PadMode;
217    /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
218    /// let mut y = x.pad_with().padding(&[(2, 2)]).mode(PadMode::Reflect).call().unwrap();
219    /// y.realize().unwrap();
220    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0]);
221    /// ```
222    ///
223    /// Circular (wrap) padding:
224    ///
225    /// ```
226    /// # use svod_tensor::Tensor;
227    /// # use svod_tensor::nn::PadMode;
228    /// let x = Tensor::from_slice([1.0f32, 2.0, 3.0]);
229    /// let mut y = x.pad_with().padding(&[(2, 2)]).mode(PadMode::Circular).call().unwrap();
230    /// y.realize().unwrap();
231    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0]);
232    /// ```
233    #[builder]
234    pub fn pad_with(
235        &self,
236        padding: &[(isize, isize)],
237        #[builder(default)] mode: PadMode,
238        #[builder(default)] value: f64,
239    ) -> Result<Tensor> {
240        match mode {
241            PadMode::Constant => self.try_pad_value(padding, value),
242            PadMode::Replicate => pad_replicate(self, padding),
243            PadMode::Reflect => pad_reflect(self, padding),
244            PadMode::Circular => pad_circular(self, padding),
245        }
246    }
247}
248
249/// Replicate (edge) padding: repeats boundary values.
250///
251/// For each padded dimension, extracts edge slices via shrink, replicates
252/// via expand, then concatenates. Mirrors Tinygrad's `pad(mode="replicate")`.
253fn pad_replicate(data: &Tensor, padding: &[(isize, isize)]) -> Result<Tensor> {
254    let mut result = data.clone();
255    for (d, &(pad_before, pad_after)) in padding.iter().enumerate() {
256        if pad_before == 0 && pad_after == 0 {
257            continue;
258        }
259        let shape = result.shape()?;
260        let dim_size = shape[d].as_const().expect("replicate pad requires concrete dims") as isize;
261        let mut parts: Vec<Tensor> = Vec::new();
262
263        if pad_before > 0 {
264            let mut shrink_ranges: Vec<(isize, isize)> =
265                shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
266            shrink_ranges[d] = (0, 1);
267            let edge = result.try_shrink(&shrink_ranges)?;
268            let mut expand_shape: Vec<isize> = shape.iter().map(|s| s.as_const().unwrap() as isize).collect();
269            expand_shape[d] = pad_before;
270            parts.push(edge.try_expand(&expand_shape)?);
271        }
272
273        parts.push(result.clone());
274
275        if pad_after > 0 {
276            let mut shrink_ranges: Vec<(isize, isize)> =
277                shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
278            shrink_ranges[d] = (dim_size - 1, dim_size);
279            let edge = result.try_shrink(&shrink_ranges)?;
280            let mut expand_shape: Vec<isize> = shape.iter().map(|s| s.as_const().unwrap() as isize).collect();
281            expand_shape[d] = pad_after;
282            parts.push(edge.try_expand(&expand_shape)?);
283        }
284
285        let refs: Vec<&Tensor> = parts.iter().collect();
286        result = Tensor::cat(&refs, d as isize)?;
287    }
288    Ok(result)
289}
290
291/// Reflect padding: mirrors values without repeating the boundary.
292///
293/// For each padded dimension, extracts interior slices via shrink, flips them,
294/// then concatenates. E.g. `[1,2,3]` pad(2,2) → `[3,2,1,2,3,2,1]`.
295fn pad_reflect(data: &Tensor, padding: &[(isize, isize)]) -> Result<Tensor> {
296    let mut result = data.clone();
297    for (d, &(pad_before, pad_after)) in padding.iter().enumerate() {
298        if pad_before == 0 && pad_after == 0 {
299            continue;
300        }
301        let shape = result.shape()?;
302        let dim_size = shape[d].as_const().expect("reflect pad requires concrete dims") as isize;
303        let mut parts: Vec<Tensor> = Vec::new();
304
305        if pad_before > 0 {
306            let mut shrink_ranges: Vec<(isize, isize)> =
307                shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
308            shrink_ranges[d] = (1, 1 + pad_before);
309            let slice = result.try_shrink(&shrink_ranges)?;
310            parts.push(slice.flip(&[d as isize])?);
311        }
312
313        parts.push(result.clone());
314
315        if pad_after > 0 {
316            let mut shrink_ranges: Vec<(isize, isize)> =
317                shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
318            shrink_ranges[d] = (dim_size - 1 - pad_after, dim_size - 1);
319            let slice = result.try_shrink(&shrink_ranges)?;
320            parts.push(slice.flip(&[d as isize])?);
321        }
322
323        let refs: Vec<&Tensor> = parts.iter().collect();
324        result = Tensor::cat(&refs, d as isize)?;
325    }
326    Ok(result)
327}
328
329/// Circular (wrap) padding: wraps values from the opposite end.
330///
331/// Uses repeat + shrink: tile the tensor up to 3x per padded dimension,
332/// then shrink to extract the wrapped window. Mirrors Tinygrad's `pad(mode="circular")`.
333fn pad_circular(data: &Tensor, padding: &[(isize, isize)]) -> Result<Tensor> {
334    let shape = data.shape()?;
335    let ndim = shape.len();
336    let repeats: Vec<SInt> =
337        padding.iter().map(|&(pb, pa)| SInt::from(1 + usize::from(pb > 0) + usize::from(pa > 0))).collect();
338    let repeated = data.repeat(&repeats)?;
339    let rep_shape = repeated.shape()?;
340
341    let shrink_ranges: Vec<(isize, isize)> = (0..ndim)
342        .map(|d| {
343            let (pb, _pa) = padding[d];
344            let orig = shape[d].as_const().expect("circular pad requires concrete dims") as isize;
345            let rep_dim = rep_shape[d].as_const().unwrap() as isize;
346            let start = if pb == 0 { 0 } else { orig - pb };
347            let end = if padding[d].1 == 0 { rep_dim } else { rep_dim - orig + padding[d].1 };
348            (start, end)
349        })
350        .collect();
351    repeated.try_shrink(&shrink_ranges)
352}
353
354/// Adjust padding for ceil_mode output sizes.
355/// Per arXiv:1603.07285 section 5.1, relationship 15.
356pub(super) fn apply_ceil_mode(
357    padding: &[(isize, isize)],
358    input_spatial: &[SInt],
359    kernel: &[usize],
360    stride: &[usize],
361    dilation: &[usize],
362) -> Vec<(isize, isize)> {
363    let n = kernel.len();
364    let grouped: Vec<(isize, isize)> = padding.to_vec();
365    let mut ceil_pads = grouped.clone();
366    for i in 0..n {
367        let is = input_spatial[i].as_const().expect("ceil_mode requires concrete spatial dims");
368        let padded = is as isize + grouped[i].0 + grouped[i].1;
369        let eff_k = (dilation[i] * (kernel[i] - 1) + 1) as isize;
370        let s = stride[i] as isize;
371        let o_ceil = (padded - eff_k + s - 1) / s + 1;
372        let o_floor = (padded - eff_k) / s + 1;
373        if o_ceil > o_floor {
374            let last_start = s * (o_ceil - 1);
375            let extra = last_start + eff_k - padded;
376            let correction = (last_start - (grouped[i].0 + is as isize - 1)).max(0);
377            ceil_pads[i].1 += extra - correction;
378        }
379    }
380    ceil_pads
381}