Skip to main content

svod_tensor/nn/
mod.rs

1//! Neural network operations: convolution, pooling, normalization.
2
3mod conv;
4mod conv1d;
5mod grid_sample;
6mod linear;
7mod lstm_cell;
8mod norm;
9pub mod pad;
10mod pool;
11mod quantize;
12mod resize;
13mod rnn;
14
15pub use conv1d::Conv1d;
16pub use linear::Linear;
17pub use lstm_cell::LSTMCell;
18pub use rnn::{GruOutput, LstmOutput, RnnOutput};
19
20/// A neural network layer.
21pub trait Layer {
22    fn forward(&self, x: &Tensor) -> Result<Tensor>;
23}
24
25/// ReLU activation layer: `max(0, x)`.
26pub struct Relu;
27
28impl Layer for Relu {
29    fn forward(&self, x: &Tensor) -> Result<Tensor> {
30        x.relu()
31    }
32}
33
34pub use pad::{auto_pad_split, flat_pads_to_pairs, resolve_pool_pads};
35
36use bon::bon;
37use snafu::ResultExt;
38use svod_dtype::DType;
39use svod_ir::SInt;
40
41use crate::Tensor;
42use crate::error::{DivisibilitySnafu, NdimExactSnafu, NdimMinimumSnafu, ParamRangeSnafu, UOpSnafu};
43use crate::reduce::AxisSpec;
44
45type Result<T> = crate::Result<T>;
46
47// =========================================================================
48// Type-safe enums for string parameters
49// =========================================================================
50
51use strum::{Display, EnumString};
52
53/// Auto-padding mode for convolution and pooling.
54#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
55pub enum AutoPad {
56    #[default]
57    #[strum(serialize = "NOTSET", serialize = "")]
58    NotSet,
59    #[strum(serialize = "VALID")]
60    Valid,
61    #[strum(serialize = "SAME_UPPER")]
62    SameUpper,
63    #[strum(serialize = "SAME_LOWER")]
64    SameLower,
65}
66
67/// Reduction mode for loss functions.
68#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
69pub enum Reduction {
70    #[strum(serialize = "none")]
71    None,
72    #[default]
73    #[strum(serialize = "mean")]
74    Mean,
75    #[strum(serialize = "sum")]
76    Sum,
77}
78
79/// Resize interpolation mode.
80#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
81pub enum ResizeMode {
82    #[default]
83    #[strum(serialize = "nearest")]
84    Nearest,
85    #[strum(serialize = "linear")]
86    Linear,
87    #[strum(serialize = "cubic")]
88    Cubic,
89}
90
91/// Coordinate transformation mode for resize.
92#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
93pub enum CoordinateTransformMode {
94    #[default]
95    #[strum(serialize = "half_pixel")]
96    HalfPixel,
97    #[strum(serialize = "align_corners")]
98    AlignCorners,
99    #[strum(serialize = "asymmetric")]
100    Asymmetric,
101    #[strum(serialize = "pytorch_half_pixel")]
102    PytorchHalfPixel,
103    #[strum(serialize = "half_pixel_symmetric")]
104    HalfPixelSymmetric,
105    #[strum(serialize = "tf_crop_and_resize")]
106    TfCropAndResize,
107}
108
109/// Nearest-neighbor rounding mode for resize.
110#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
111pub enum NearestMode {
112    #[default]
113    #[strum(serialize = "round_prefer_floor")]
114    RoundPreferFloor,
115    #[strum(serialize = "round_prefer_ceil")]
116    RoundPreferCeil,
117    #[strum(serialize = "floor")]
118    Floor,
119    #[strum(serialize = "ceil")]
120    Ceil,
121}
122
123/// Depth-to-space rearrangement mode.
124#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
125pub enum DepthToSpaceMode {
126    /// DCR: depth-column-row (default, ONNX standard).
127    #[default]
128    #[strum(serialize = "DCR")]
129    Dcr,
130    /// CRD: column-row-depth (PyTorch pixel_shuffle order).
131    #[strum(serialize = "CRD")]
132    Crd,
133}
134
135/// Padding fill mode.
136///
137/// Determines how values outside the original tensor are filled when padding.
138/// ONNX uses "edge"/"reflect"/"wrap"; Tinygrad uses "replicate"/"reflect"/"circular".
139#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
140pub enum PadMode {
141    /// Fill with a constant value (default: 0).
142    #[default]
143    #[strum(serialize = "constant")]
144    Constant,
145    /// Replicate boundary values. `[1,2,3]` pad(2,2) → `[1,1,1,2,3,3,3]`.
146    #[strum(serialize = "edge", serialize = "replicate")]
147    Replicate,
148    /// Mirror without repeating boundary. `[1,2,3]` pad(2,2) → `[3,2,1,2,3,2,1]`.
149    #[strum(serialize = "reflect")]
150    Reflect,
151    /// Wrap around (circular). `[1,2,3]` pad(2,2) → `[2,3,1,2,3,1,2]`.
152    #[strum(serialize = "wrap", serialize = "circular")]
153    Circular,
154}
155
156/// GridSample interpolation mode.
157#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
158pub enum GridSampleMode {
159    #[default]
160    #[strum(serialize = "linear", serialize = "bilinear")]
161    Linear,
162    #[strum(serialize = "nearest")]
163    Nearest,
164    #[strum(serialize = "cubic", serialize = "bicubic")]
165    Cubic,
166}
167
168/// GridSample padding mode.
169#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
170pub enum GridSamplePaddingMode {
171    #[default]
172    #[strum(serialize = "zeros")]
173    Zeros,
174    #[strum(serialize = "border")]
175    Border,
176    #[strum(serialize = "reflection")]
177    Reflection,
178}
179
180/// Aspect ratio policy for resize.
181#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
182pub enum AspectRatioPolicy {
183    #[default]
184    #[strum(serialize = "stretch")]
185    Stretch,
186    #[strum(serialize = "not_larger")]
187    NotLarger,
188    #[strum(serialize = "not_smaller")]
189    NotSmaller,
190}
191
192// =========================================================================
193// Higher-level building blocks (ONNX-style wrappers)
194// =========================================================================
195
196#[bon]
197impl Tensor {
198    /// Negative log-likelihood loss.
199    ///
200    /// `self` is `[N, C, ...]` log-probabilities, `target` is `[N, ...]` class indices
201    /// (dtype `i64`). Gathers the log-prob at the target class and negates it.
202    ///
203    /// Supports optional per-class `weight`, `ignore_index` to mask out a class,
204    /// and `reduction` (default `Mean`).
205    ///
206    /// # Examples
207    ///
208    /// ```
209    /// # use svod_tensor::Tensor;
210    /// # use ndarray::array;
211    /// let logprobs = Tensor::from_ndarray(&array![[-0.5f32, -1.0, -2.0]]);
212    /// let target = Tensor::from_slice([0i64]);
213    /// let mut loss = logprobs.nll_loss().target(&target).call().unwrap();
214    /// loss.realize().unwrap();
215    /// let val = loss.as_vec::<f32>().unwrap();
216    /// // -(-0.5) = 0.5
217    /// assert!((val[0] - 0.5).abs() < 1e-5);
218    /// ```
219    ///
220    /// With sum reduction:
221    ///
222    /// ```
223    /// # use svod_tensor::Tensor;
224    /// # use svod_tensor::nn::Reduction;
225    /// # use ndarray::array;
226    /// let logprobs = Tensor::from_ndarray(&array![[-0.5f32, -1.0], [-2.0, -0.3]]);
227    /// let target = Tensor::from_slice([0i64, 1]);
228    /// let mut loss = logprobs.nll_loss().target(&target).reduction(Reduction::Sum).call().unwrap();
229    /// loss.realize().unwrap();
230    /// let val = loss.as_vec::<f32>().unwrap();
231    /// // sum of 0.5 + 0.3 = 0.8
232    /// assert!((val[0] - 0.8).abs() < 1e-5);
233    /// ```
234    #[builder]
235    pub fn nll_loss(
236        &self,
237        target: &Tensor,
238        weight: Option<&Tensor>,
239        ignore_index: Option<i64>,
240        #[builder(default)] reduction: Reduction,
241    ) -> Result<Tensor> {
242        let ndim = self.ndim()?;
243        snafu::ensure!(ndim >= 2, NdimMinimumSnafu { op: "nll_loss", min: 2_usize, actual: ndim });
244        // Gather log-probs at target class, negate
245        let nll = self.gather(1, &target.try_unsqueeze(1)?)?.try_squeeze(Some(1))?.try_neg()?;
246
247        // Per-sample weight: weight[target] or ones
248        let sample_weight = match weight {
249            Some(w) => {
250                let flat = target.try_reshape([-1])?;
251                let sel = w.gather(0, &flat)?;
252                let target_shape = svod_ir::shape::to_vec_isize(&target.shape()?).context(UOpSnafu)?;
253                sel.try_reshape(&target_shape)?
254            }
255            None => {
256                let shape = svod_ir::shape::to_vec_usize(&target.shape()?).context(UOpSnafu)?;
257                Tensor::full(&shape, 1.0, self.uop().dtype())?
258            }
259        };
260
261        // Mask out ignore_index
262        let masked_weight = match ignore_index {
263            Some(idx) => {
264                let mask = target.try_ne(&Tensor::const_(idx as f64, target.uop().dtype()))?;
265                sample_weight.try_mul(&mask.cast(sample_weight.uop().dtype())?)?
266            }
267            None => sample_weight,
268        };
269
270        let weighted_loss = nll.try_mul(&masked_weight)?;
271        match reduction {
272            Reduction::Mean => weighted_loss.sum(AxisSpec::All)?.try_div(&masked_weight.sum(AxisSpec::All)?),
273            Reduction::Sum => weighted_loss.sum(AxisSpec::All),
274            Reduction::None => Ok(weighted_loss),
275        }
276    }
277
278    /// Dropout: randomly zeros elements during training, passes through in inference.
279    ///
280    /// Returns `(output, mask)` where mask is a boolean tensor (`true` = kept).
281    /// In inference mode (`training=false`, the default), the output is identical
282    /// to the input and the mask is all-true.
283    ///
284    /// **Note:** Training mode is not yet implemented (requires RNG); currently
285    /// returns identity regardless of `training`.
286    ///
287    /// # Examples
288    ///
289    /// ```
290    /// # use svod_tensor::Tensor;
291    /// # use ndarray::array;
292    /// let x = Tensor::from_ndarray(&array![1.0f32, 2.0, 3.0]);
293    /// let (mut out, mut mask) = x.dropout().p(0.5).call().unwrap();
294    /// out.realize().unwrap();
295    /// mask.realize().unwrap();
296    /// // Default is inference mode: output == input
297    /// assert_eq!(out.as_vec::<f32>().unwrap(), vec![1.0, 2.0, 3.0]);
298    /// assert_eq!(mask.as_vec::<bool>().unwrap(), vec![true, true, true]);
299    /// ```
300    #[builder]
301    pub fn dropout(&self, p: f64, #[builder(default = false)] training: bool) -> Result<(Tensor, Tensor)> {
302        snafu::ensure!(
303            (0.0..=1.0).contains(&p),
304            ParamRangeSnafu { op: "dropout", param: "p", value: p.to_string(), constraint: "0.0 <= p <= 1.0" }
305        );
306        let _ = p;
307        let shape = svod_ir::shape::to_vec_usize(&self.shape()?).context(UOpSnafu)?;
308        if !training {
309            let mask = Tensor::full(&shape, true, DType::Bool)?;
310            return Ok((self.clone(), mask));
311        }
312        // Training mode deferred (needs RNG: rand_like / Threefry)
313        let mask = Tensor::full(&shape, true, DType::Bool)?;
314        Ok((self.clone(), mask))
315    }
316    /// Convolution with ONNX-style parameters.
317    ///
318    /// Wraps the lower-level [`conv2d`](Tensor::conv2d) after resolving ONNX padding conventions
319    /// (`auto_pad`, flat `pads`). Input shape is `[N, C, H, W, ...]` and weight
320    /// shape is `[out_channels, in_channels/group, kH, kW, ...]`.
321    ///
322    /// # Examples
323    ///
324    /// Basic convolution with no padding:
325    ///
326    /// ```
327    /// # use svod_tensor::Tensor;
328    /// # use ndarray::Array4;
329    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
330    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
331    /// let mut y = x.conv().weight(&w).call().unwrap();
332    /// y.realize().unwrap();
333    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
334    /// assert_eq!(shape, [1, 1, 3, 3]);
335    /// // Each output element sums a 3x3 window of ones = 9.0
336    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![9.0; 9]);
337    /// ```
338    ///
339    /// With explicit padding and strides:
340    ///
341    /// ```
342    /// # use svod_tensor::Tensor;
343    /// # use ndarray::Array4;
344    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
345    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
346    /// let mut y = x.conv().weight(&w).pads(&[1, 1, 1, 1]).strides(&[2, 2]).call().unwrap();
347    /// y.realize().unwrap();
348    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
349    /// assert_eq!(shape, [1, 1, 3, 3]);
350    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![4.0, 6.0, 4.0, 6.0, 9.0, 6.0, 4.0, 6.0, 4.0]);
351    /// ```
352    #[builder]
353    pub fn conv(
354        &self,
355        weight: &Tensor,
356        bias: Option<&Tensor>,
357        #[builder(default)] auto_pad: AutoPad,
358        #[builder(default = 1)] group: usize,
359        kernel_shape: Option<&[usize]>,
360        pads: Option<&[i64]>,
361        strides: Option<&[i64]>,
362        dilations: Option<&[i64]>,
363    ) -> Result<Tensor> {
364        let w_shape = weight.shape()?;
365        let kernel: Vec<usize> = kernel_shape
366            .map(|ks| ks.to_vec())
367            .unwrap_or_else(|| w_shape[2..].iter().map(|s| s.as_const().unwrap()).collect());
368        let n = kernel.len();
369        let strides_u: Vec<usize> =
370            strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
371        let dilations_u: Vec<usize> =
372            dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
373        let x_shape = self.shape()?;
374        let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
375        let empty_pads: Vec<i64> = vec![];
376        let padding =
377            resolve_pool_pads(&input_spatial, pads.unwrap_or(&empty_pads), &kernel, &dilations_u, &strides_u, auto_pad);
378        self.conv2d()
379            .weight(weight)
380            .maybe_bias(bias)
381            .groups(group)
382            .stride(&strides_u)
383            .dilation(&dilations_u)
384            .padding(&padding)
385            .call()
386    }
387
388    /// Transposed convolution with ONNX-style parameters.
389    ///
390    /// Wraps [`conv_transpose2d`](Tensor::conv_transpose2d) after resolving ONNX padding conventions.
391    /// Supports `output_shape` and `output_padding` for precise output size control.
392    ///
393    /// # Examples
394    ///
395    /// Basic transposed convolution (upsampling):
396    ///
397    /// ```
398    /// # use svod_tensor::Tensor;
399    /// # use ndarray::Array4;
400    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
401    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
402    /// let mut y = x.conv_transpose().weight(&w).call().unwrap();
403    /// y.realize().unwrap();
404    /// let vals = y.as_vec::<f32>().unwrap();
405    /// assert_eq!(vals.len(), 16); // 4x4 output
406    /// assert_eq!(vals[5], 4.0); // center sees full overlap
407    /// ```
408    ///
409    /// With stride (larger upsampling factor):
410    ///
411    /// ```
412    /// # use svod_tensor::Tensor;
413    /// # use ndarray::Array4;
414    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
415    /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
416    /// let mut y = x.conv_transpose().weight(&w).strides(&[2, 2]).call().unwrap();
417    /// y.realize().unwrap();
418    /// let vals = y.as_vec::<f32>().unwrap();
419    /// assert_eq!(vals.len(), 25); // 5x5 output
420    /// ```
421    #[builder]
422    pub fn conv_transpose(
423        &self,
424        weight: &Tensor,
425        bias: Option<&Tensor>,
426        #[builder(default)] auto_pad: AutoPad,
427        #[builder(default = 1)] group: usize,
428        kernel_shape: Option<&[usize]>,
429        pads: Option<&[i64]>,
430        output_shape: Option<&[i64]>,
431        output_padding: Option<&[usize]>,
432        strides: Option<&[i64]>,
433        dilations: Option<&[i64]>,
434    ) -> Result<Tensor> {
435        let w_shape = weight.shape()?;
436        let kernel: Vec<usize> = kernel_shape
437            .map(|ks| ks.to_vec())
438            .unwrap_or_else(|| w_shape[2..].iter().map(|s| s.as_const().unwrap()).collect());
439        let n = kernel.len();
440        let x_shape = self.shape()?;
441        let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
442        let strides_u: Vec<usize> =
443            strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
444        let dilations_u: Vec<usize> =
445            dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
446        let output_padding_u: Vec<usize> = output_padding.map(|op| op.to_vec()).unwrap_or_else(|| vec![0; n]);
447
448        // 3-path padding resolution (matches Tinygrad's ConvTranspose)
449        let mut pads_resolved: Option<Vec<isize>> = None;
450
451        // ConvTranspose padding resolution requires concrete spatial dims.
452        let input_spatial_c: Vec<usize> = input_spatial
453            .iter()
454            .map(|s| s.as_const().expect("conv_transpose requires concrete spatial dims"))
455            .collect();
456
457        // Path 1: output_shape provided → derive total pads, apply auto_pad
458        if let Some(os) = output_shape {
459            let total_pads: Vec<isize> = (0..n)
460                .map(|i| {
461                    (strides_u[i] * (input_spatial_c[i] - 1)
462                        + output_padding_u[i]
463                        + (kernel[i] - 1) * dilations_u[i]
464                        + 1) as isize
465                        - os[i] as isize
466                })
467                .collect();
468            pads_resolved = Some(auto_pad_split(&total_pads, auto_pad));
469        }
470
471        // Path 2: no explicit pads → derive from default output_shape
472        if pads_resolved.is_none() && pads.is_none_or(|p| p.is_empty()) {
473            let default_out: Vec<usize> = (0..n).map(|i| input_spatial_c[i] * strides_u[i]).collect();
474            let total_pads: Vec<isize> = (0..n)
475                .map(|i| {
476                    (strides_u[i] * (input_spatial_c[i] - 1)
477                        + output_padding_u[i]
478                        + (kernel[i] - 1) * dilations_u[i]
479                        + 1) as isize
480                        - default_out[i] as isize
481                })
482                .collect();
483            pads_resolved =
484                Some(if auto_pad != AutoPad::NotSet { auto_pad_split(&total_pads, auto_pad) } else { vec![0; n * 2] });
485        }
486
487        // Path 3: explicit pads provided
488        let padding: Vec<(isize, isize)> = if let Some(flat) = pads_resolved {
489            let half = flat.len() / 2;
490            (0..half).map(|i| (flat[i], flat[i + half])).collect()
491        } else {
492            flat_pads_to_pairs(pads.unwrap())
493        };
494
495        self.conv_transpose2d()
496            .weight(weight)
497            .maybe_bias(bias)
498            .groups(group)
499            .stride(&strides_u)
500            .dilation(&dilations_u)
501            .padding(&padding)
502            .output_padding(&output_padding_u)
503            .call()
504    }
505
506    /// Average pooling with ONNX-style parameters.
507    ///
508    /// Wraps [`avg_pool2d`](Tensor::avg_pool2d) after resolving ONNX padding and stride conventions.
509    /// Stride defaults to 1 (unlike [`avg_pool2d`](Tensor::avg_pool2d) which defaults to `kernel_size`).
510    /// Input shape is `[N, C, H, W]`.
511    ///
512    /// # Examples
513    ///
514    /// ```
515    /// # use svod_tensor::Tensor;
516    /// # use ndarray::Array4;
517    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
518    /// let mut y = x.avg_pool().kernel_shape(&[2, 2]).call().unwrap();
519    /// y.realize().unwrap();
520    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
521    /// assert_eq!(shape, [1, 1, 3, 3]);
522    /// // Average of all-ones windows is 1.0
523    /// assert!(y.as_vec::<f32>().unwrap().iter().all(|&v| (v - 1.0).abs() < 1e-6));
524    /// ```
525    ///
526    /// With strides:
527    ///
528    /// ```
529    /// # use svod_tensor::Tensor;
530    /// # use ndarray::Array4;
531    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
532    /// let mut y = x.avg_pool().kernel_shape(&[2, 2]).strides(&[2, 2]).call().unwrap();
533    /// y.realize().unwrap();
534    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
535    /// assert_eq!(shape, [1, 1, 2, 2]);
536    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
537    /// ```
538    #[builder]
539    pub fn avg_pool(
540        &self,
541        kernel_shape: &[usize],
542        #[builder(default)] auto_pad: AutoPad,
543        #[builder(default = false)] ceil_mode: bool,
544        #[builder(default = false)] count_include_pad: bool,
545        pads: Option<&[i64]>,
546        strides: Option<&[i64]>,
547        dilations: Option<&[i64]>,
548    ) -> Result<Tensor> {
549        let n = kernel_shape.len();
550        let strides_u: Vec<usize> =
551            strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
552        let dilations_u: Vec<usize> =
553            dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
554        let x_shape = self.shape()?;
555        let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
556        let empty_pads: Vec<i64> = vec![];
557        let padding = resolve_pool_pads(
558            &input_spatial,
559            pads.unwrap_or(&empty_pads),
560            kernel_shape,
561            &dilations_u,
562            &strides_u,
563            auto_pad,
564        );
565        self.avg_pool2d()
566            .kernel_size(kernel_shape)
567            .stride(&strides_u)
568            .dilation(&dilations_u)
569            .padding(&padding)
570            .ceil_mode(ceil_mode)
571            .count_include_pad(count_include_pad)
572            .call()
573    }
574
575    /// Lp norm pooling with ONNX-style parameters.
576    ///
577    /// Computes `(sum(|x|^p))^(1/p)` over each pooling window. Defaults to
578    /// `p=2` (L2 pooling). Input shape is `[N, C, H, W]`.
579    ///
580    /// # Examples
581    ///
582    /// ```
583    /// # use svod_tensor::Tensor;
584    /// # use ndarray::Array4;
585    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
586    /// let mut y = x.lp_pool().kernel_shape(&[2, 2]).call().unwrap();
587    /// y.realize().unwrap();
588    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
589    /// assert_eq!(shape, [1, 1, 3, 3]);
590    /// // L2 pool of 2x2 window of ones = sqrt(4) = 2.0
591    /// assert!((y.as_vec::<f32>().unwrap()[0] - 2.0).abs() < 1e-5);
592    /// ```
593    #[builder]
594    pub fn lp_pool(
595        &self,
596        kernel_shape: &[usize],
597        #[builder(default = 2)] p: usize,
598        #[builder(default)] auto_pad: AutoPad,
599        #[builder(default = false)] ceil_mode: bool,
600        pads: Option<&[i64]>,
601        strides: Option<&[i64]>,
602        dilations: Option<&[i64]>,
603    ) -> Result<Tensor> {
604        snafu::ensure!(p >= 1, ParamRangeSnafu { op: "lp_pool", param: "p", value: p.to_string(), constraint: ">= 1" });
605        let n_spatial = kernel_shape.len();
606        let strides_u: Vec<usize> =
607            strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n_spatial]);
608        let dilations_u: Vec<usize> =
609            dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n_spatial]);
610        let x_shape = self.shape()?;
611        let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
612        let empty_pads: Vec<i64> = vec![];
613        let padding = resolve_pool_pads(
614            &input_spatial,
615            pads.unwrap_or(&empty_pads),
616            kernel_shape,
617            &dilations_u,
618            &strides_u,
619            auto_pad,
620        );
621
622        let p_f = p as f64;
623        let dtype = self.uop().dtype();
624        let p_tensor = Tensor::const_(p_f, dtype.clone());
625        let inv_p = Tensor::const_(1.0 / p_f, dtype);
626        let x_abs_p = self.try_abs()?.try_pow(&p_tensor)?;
627
628        // Pad, pool (create windows), then sum over kernel axes.
629        // This computes sum(|x|^p) directly — correct for all padding/ceil modes
630        // because padded zeros contribute 0 to the sum.
631        let reg_pads = padding;
632        let ceil_pads = if ceil_mode {
633            pad::apply_ceil_mode(&reg_pads, &input_spatial, kernel_shape, &strides_u, &dilations_u)
634        } else {
635            reg_pads.clone()
636        };
637        let pads_to_use = if ceil_mode { &ceil_pads } else { &reg_pads };
638        let mut padded = x_abs_p;
639        if pads_to_use.iter().any(|&(b, e)| b != 0 || e != 0) {
640            let n_batch = x_shape.len() - n_spatial;
641            let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
642            full_pad.extend_from_slice(pads_to_use);
643            padded = padded.try_pad(&full_pad)?;
644        }
645        let pooled = padded.pool(kernel_shape, &strides_u, &dilations_u)?;
646        let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
647        let sum_p = pooled.sum(crate::reduce::AxisSpec::Multiple(reduce_axes))?;
648        sum_p.try_pow(&inv_p)
649    }
650
651    /// Rearrange depth data into spatial blocks (inverse of [`space_to_depth`](Tensor::space_to_depth)).
652    ///
653    /// Equivalent to PyTorch's `F.pixel_shuffle`. Reshapes a `[N, C, H, W]`
654    /// tensor to `[N, C/(b*b), H*b, W*b]` where `b` is the blocksize.
655    ///
656    /// # Examples
657    ///
658    /// ```
659    /// # use svod_tensor::Tensor;
660    /// # use ndarray::Array4;
661    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 1, 1), 1.0f32));
662    /// let mut y = x.depth_to_space().blocksize(2).call().unwrap();
663    /// y.realize().unwrap();
664    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
665    /// assert_eq!(shape, [1, 1, 2, 2]);
666    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
667    /// ```
668    ///
669    /// Using CRD mode (PyTorch pixel_shuffle order):
670    ///
671    /// ```
672    /// # use svod_tensor::Tensor;
673    /// # use svod_tensor::nn::DepthToSpaceMode;
674    /// # use ndarray::Array4;
675    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 1, 1), 1.0f32));
676    /// let mut y = x.depth_to_space().blocksize(2).mode(DepthToSpaceMode::Crd).call().unwrap();
677    /// y.realize().unwrap();
678    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
679    /// ```
680    #[builder]
681    pub fn depth_to_space(&self, blocksize: usize, #[builder(default)] mode: DepthToSpaceMode) -> Result<Tensor> {
682        let ndim = self.ndim()?;
683        snafu::ensure!(ndim == 4, NdimExactSnafu { op: "depth_to_space", expected: 4_usize, actual: ndim });
684        snafu::ensure!(
685            blocksize > 0,
686            ParamRangeSnafu {
687                op: "depth_to_space",
688                param: "blocksize",
689                value: blocksize.to_string(),
690                constraint: "> 0"
691            }
692        );
693        let shape = self.shape()?;
694        let (b, c, h, w) = (
695            shape[0].as_const().unwrap(),
696            shape[1].as_const().unwrap(),
697            shape[2].as_const().unwrap(),
698            shape[3].as_const().unwrap(),
699        );
700        let bs_sq = blocksize * blocksize;
701        snafu::ensure!(
702            c.is_multiple_of(bs_sq),
703            DivisibilitySnafu {
704                op: "depth_to_space",
705                lhs_name: "channels",
706                lhs: c,
707                rhs_name: "blocksize^2",
708                rhs: bs_sq
709            }
710        );
711        let c_out = c / bs_sq;
712        let result = if mode == DepthToSpaceMode::Crd {
713            self.try_reshape([
714                b as isize,
715                c_out as isize,
716                blocksize as isize,
717                blocksize as isize,
718                h as isize,
719                w as isize,
720            ])?
721            .try_permute(&[0, 1, 4, 2, 5, 3])?
722        } else {
723            // DCR (default)
724            self.try_reshape([
725                b as isize,
726                blocksize as isize,
727                blocksize as isize,
728                c_out as isize,
729                h as isize,
730                w as isize,
731            ])?
732            .try_permute(&[0, 3, 4, 1, 5, 2])?
733        };
734        result.try_reshape([b as isize, c_out as isize, (h * blocksize) as isize, (w * blocksize) as isize])
735    }
736
737    /// Rearrange spatial data into depth (inverse of [`depth_to_space`](Tensor::depth_to_space)).
738    ///
739    /// Reshapes a `[N, C, H, W]` tensor to `[N, C*b*b, H/b, W/b]` where `b`
740    /// is the blocksize. Both `H` and `W` must be divisible by `blocksize`.
741    ///
742    /// # Examples
743    ///
744    /// ```
745    /// # use svod_tensor::Tensor;
746    /// # use ndarray::Array4;
747    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
748    /// let mut y = x.space_to_depth(2).unwrap();
749    /// y.realize().unwrap();
750    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
751    /// assert_eq!(shape, [1, 4, 2, 2]);
752    /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 16]);
753    /// ```
754    pub fn space_to_depth(&self, blocksize: usize) -> Result<Tensor> {
755        let ndim = self.ndim()?;
756        snafu::ensure!(ndim == 4, NdimExactSnafu { op: "space_to_depth", expected: 4_usize, actual: ndim });
757        snafu::ensure!(
758            blocksize > 0,
759            ParamRangeSnafu {
760                op: "space_to_depth",
761                param: "blocksize",
762                value: blocksize.to_string(),
763                constraint: "> 0"
764            }
765        );
766        let shape = self.shape()?;
767        let (b, c, h, w) = (
768            shape[0].as_const().unwrap(),
769            shape[1].as_const().unwrap(),
770            shape[2].as_const().unwrap(),
771            shape[3].as_const().unwrap(),
772        );
773        snafu::ensure!(
774            h.is_multiple_of(blocksize),
775            DivisibilitySnafu {
776                op: "space_to_depth",
777                lhs_name: "height",
778                lhs: h,
779                rhs_name: "blocksize",
780                rhs: blocksize
781            }
782        );
783        snafu::ensure!(
784            w.is_multiple_of(blocksize),
785            DivisibilitySnafu {
786                op: "space_to_depth",
787                lhs_name: "width",
788                lhs: w,
789                rhs_name: "blocksize",
790                rhs: blocksize
791            }
792        );
793        self.try_reshape([
794            b as isize,
795            c as isize,
796            (h / blocksize) as isize,
797            blocksize as isize,
798            (w / blocksize) as isize,
799            blocksize as isize,
800        ])?
801        .try_permute(&[0, 3, 5, 1, 2, 4])?
802        .try_reshape([
803            b as isize,
804            (c * blocksize * blocksize) as isize,
805            (h / blocksize) as isize,
806            (w / blocksize) as isize,
807        ])
808    }
809
810    /// Max pooling with ONNX-style parameters.
811    ///
812    /// Always returns `(values, indices)` where indices are flattened positions
813    /// (dtype `i64`). Wraps [`max_pool2d_with_indices`](Tensor::max_pool2d_with_indices) after resolving ONNX
814    /// padding conventions.
815    ///
816    /// # Examples
817    ///
818    /// ```
819    /// # use svod_tensor::Tensor;
820    /// # use ndarray::Array4;
821    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
822    /// let (vals, indices) = x.max_pool().kernel_shape(&[2, 2]).call().unwrap();
823    /// let shape: Vec<_> = vals.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
824    /// assert_eq!(shape, [1, 1, 3, 3]);
825    /// ```
826    ///
827    /// With strides:
828    ///
829    /// ```
830    /// # use svod_tensor::Tensor;
831    /// # use ndarray::Array4;
832    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
833    /// let (vals, _) = x.max_pool().kernel_shape(&[2, 2]).strides(&[2, 2]).call().unwrap();
834    /// let shape: Vec<_> = vals.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
835    /// assert_eq!(shape, [1, 1, 2, 2]);
836    /// ```
837    #[builder]
838    pub fn max_pool(
839        &self,
840        kernel_shape: &[usize],
841        #[builder(default)] auto_pad: AutoPad,
842        #[builder(default = false)] ceil_mode: bool,
843        #[builder(default = 0)] storage_order: usize,
844        pads: Option<&[i64]>,
845        strides: Option<&[i64]>,
846        dilations: Option<&[i64]>,
847    ) -> Result<(Tensor, Tensor)> {
848        let n = kernel_shape.len();
849        let strides_u: Vec<usize> =
850            strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
851        let dilations_u: Vec<usize> =
852            dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
853        let x_shape = self.shape()?;
854        let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
855        let empty_pads: Vec<i64> = vec![];
856        let padding = resolve_pool_pads(
857            &input_spatial,
858            pads.unwrap_or(&empty_pads),
859            kernel_shape,
860            &dilations_u,
861            &strides_u,
862            auto_pad,
863        );
864        let (values, indices) = self
865            .max_pool2d_with_indices()
866            .kernel_size(kernel_shape)
867            .stride(&strides_u)
868            .dilation(&dilations_u)
869            .padding(&padding)
870            .ceil_mode(ceil_mode)
871            .call()?;
872        let indices = if storage_order == 1 {
873            indices.try_transpose(-2, -1)?.cast(DType::Int64)?
874        } else {
875            indices.cast(DType::Int64)?
876        };
877        Ok((values, indices))
878    }
879
880    /// Local Response Normalization (LRN).
881    ///
882    /// Normalizes each element by dividing by a scaled sum of squares over a
883    /// local neighborhood of `size` channels:
884    /// `y = x / (bias + alpha * avg_pool(x^2, size))^beta`.
885    ///
886    /// Input must be 4-D `[N, C, H, W]`.
887    ///
888    /// # Examples
889    ///
890    /// ```
891    /// # use svod_tensor::Tensor;
892    /// # use ndarray::Array4;
893    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 3, 2, 2), 1.0f32));
894    /// let y = x.lrn().size(3).call().unwrap();
895    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
896    /// assert_eq!(shape, [1, 3, 2, 2]);
897    /// ```
898    ///
899    /// Custom alpha, beta, and bias:
900    ///
901    /// ```
902    /// # use svod_tensor::Tensor;
903    /// # use ndarray::Array4;
904    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 3, 2, 2), 1.0f32));
905    /// let y = x.lrn().size(3).alpha(0.001).beta(0.5).bias(2.0).call().unwrap();
906    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
907    /// assert_eq!(shape, [1, 3, 2, 2]);
908    /// ```
909    #[builder]
910    pub fn lrn(
911        &self,
912        size: usize,
913        #[builder(default = 0.0001)] alpha: f64,
914        #[builder(default = 0.75)] beta: f64,
915        #[builder(default = 1.0)] bias: f64,
916    ) -> Result<Tensor> {
917        let ndim = self.ndim()?;
918        snafu::ensure!(ndim == 4, NdimExactSnafu { op: "lrn", expected: 4_usize, actual: ndim });
919        snafu::ensure!(
920            size > 0,
921            ParamRangeSnafu { op: "lrn", param: "size", value: size.to_string(), constraint: "> 0" }
922        );
923        let shape = self.shape()?;
924        let (b, c, h, w) = (
925            shape[0].as_const().unwrap(),
926            shape[1].as_const().unwrap(),
927            shape[2].as_const().unwrap(),
928            shape[3].as_const().unwrap(),
929        );
930        let x_sq = self.square()?;
931        let x_sq = x_sq.try_reshape([b as isize, 1, c as isize, (h * w) as isize])?;
932        let pad_before = ((size - 1) / 2) as isize;
933        let pad_after = (size / 2) as isize;
934        let x_sq = x_sq.try_pad(&[(0, 0), (0, 0), (pad_before, pad_after), (0, 0)])?;
935        let pooled = x_sq.avg_pool2d().kernel_size(&[size, 1]).stride(&[1, 1]).call()?;
936        let pooled = pooled.try_reshape([b as isize, c as isize, h as isize, w as isize])?;
937        let dtype = self.uop().dtype();
938        let scale = pooled
939            .try_mul(&Tensor::const_(alpha, dtype.clone()))?
940            .try_add(&Tensor::const_(bias, dtype.clone()))?
941            .try_pow(&Tensor::const_(beta, dtype))?;
942        self.try_div(&scale)
943    }
944}
945
946impl Tensor {
947    /// Apply a sequence of layers to this tensor.
948    pub fn sequential(&self, layers: &[&dyn Layer]) -> Result<Tensor> {
949        let mut x = self.clone();
950        for layer in layers {
951            x = layer.forward(&x)?;
952        }
953        Ok(x)
954    }
955}