svod_tensor/nn/mod.rs
1//! Neural network operations: convolution, pooling, normalization.
2
3mod conv;
4mod conv1d;
5mod grid_sample;
6mod linear;
7mod lstm_cell;
8mod norm;
9pub mod pad;
10mod pool;
11mod quantize;
12mod resize;
13mod rnn;
14
15pub use conv1d::Conv1d;
16pub use linear::Linear;
17pub use lstm_cell::LSTMCell;
18pub use rnn::{GruOutput, LstmOutput, RnnOutput};
19
20/// A neural network layer.
21pub trait Layer {
22 fn forward(&self, x: &Tensor) -> Result<Tensor>;
23}
24
25/// ReLU activation layer: `max(0, x)`.
26pub struct Relu;
27
28impl Layer for Relu {
29 fn forward(&self, x: &Tensor) -> Result<Tensor> {
30 x.relu()
31 }
32}
33
34pub use pad::{auto_pad_split, flat_pads_to_pairs, resolve_pool_pads};
35
36use bon::bon;
37use snafu::ResultExt;
38use svod_dtype::DType;
39use svod_ir::SInt;
40
41use crate::Tensor;
42use crate::error::{DivisibilitySnafu, NdimExactSnafu, NdimMinimumSnafu, ParamRangeSnafu, UOpSnafu};
43use crate::reduce::AxisSpec;
44
45type Result<T> = crate::Result<T>;
46
47// =========================================================================
48// Type-safe enums for string parameters
49// =========================================================================
50
51use strum::{Display, EnumString};
52
53/// Auto-padding mode for convolution and pooling.
54#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
55pub enum AutoPad {
56 #[default]
57 #[strum(serialize = "NOTSET", serialize = "")]
58 NotSet,
59 #[strum(serialize = "VALID")]
60 Valid,
61 #[strum(serialize = "SAME_UPPER")]
62 SameUpper,
63 #[strum(serialize = "SAME_LOWER")]
64 SameLower,
65}
66
67/// Reduction mode for loss functions.
68#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
69pub enum Reduction {
70 #[strum(serialize = "none")]
71 None,
72 #[default]
73 #[strum(serialize = "mean")]
74 Mean,
75 #[strum(serialize = "sum")]
76 Sum,
77}
78
79/// Resize interpolation mode.
80#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
81pub enum ResizeMode {
82 #[default]
83 #[strum(serialize = "nearest")]
84 Nearest,
85 #[strum(serialize = "linear")]
86 Linear,
87 #[strum(serialize = "cubic")]
88 Cubic,
89}
90
91/// Coordinate transformation mode for resize.
92#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
93pub enum CoordinateTransformMode {
94 #[default]
95 #[strum(serialize = "half_pixel")]
96 HalfPixel,
97 #[strum(serialize = "align_corners")]
98 AlignCorners,
99 #[strum(serialize = "asymmetric")]
100 Asymmetric,
101 #[strum(serialize = "pytorch_half_pixel")]
102 PytorchHalfPixel,
103 #[strum(serialize = "half_pixel_symmetric")]
104 HalfPixelSymmetric,
105 #[strum(serialize = "tf_crop_and_resize")]
106 TfCropAndResize,
107}
108
109/// Nearest-neighbor rounding mode for resize.
110#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
111pub enum NearestMode {
112 #[default]
113 #[strum(serialize = "round_prefer_floor")]
114 RoundPreferFloor,
115 #[strum(serialize = "round_prefer_ceil")]
116 RoundPreferCeil,
117 #[strum(serialize = "floor")]
118 Floor,
119 #[strum(serialize = "ceil")]
120 Ceil,
121}
122
123/// Depth-to-space rearrangement mode.
124#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
125pub enum DepthToSpaceMode {
126 /// DCR: depth-column-row (default, ONNX standard).
127 #[default]
128 #[strum(serialize = "DCR")]
129 Dcr,
130 /// CRD: column-row-depth (PyTorch pixel_shuffle order).
131 #[strum(serialize = "CRD")]
132 Crd,
133}
134
135/// Padding fill mode.
136///
137/// Determines how values outside the original tensor are filled when padding.
138/// ONNX uses "edge"/"reflect"/"wrap"; Tinygrad uses "replicate"/"reflect"/"circular".
139#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
140pub enum PadMode {
141 /// Fill with a constant value (default: 0).
142 #[default]
143 #[strum(serialize = "constant")]
144 Constant,
145 /// Replicate boundary values. `[1,2,3]` pad(2,2) → `[1,1,1,2,3,3,3]`.
146 #[strum(serialize = "edge", serialize = "replicate")]
147 Replicate,
148 /// Mirror without repeating boundary. `[1,2,3]` pad(2,2) → `[3,2,1,2,3,2,1]`.
149 #[strum(serialize = "reflect")]
150 Reflect,
151 /// Wrap around (circular). `[1,2,3]` pad(2,2) → `[2,3,1,2,3,1,2]`.
152 #[strum(serialize = "wrap", serialize = "circular")]
153 Circular,
154}
155
156/// GridSample interpolation mode.
157#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
158pub enum GridSampleMode {
159 #[default]
160 #[strum(serialize = "linear", serialize = "bilinear")]
161 Linear,
162 #[strum(serialize = "nearest")]
163 Nearest,
164 #[strum(serialize = "cubic", serialize = "bicubic")]
165 Cubic,
166}
167
168/// GridSample padding mode.
169#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
170pub enum GridSamplePaddingMode {
171 #[default]
172 #[strum(serialize = "zeros")]
173 Zeros,
174 #[strum(serialize = "border")]
175 Border,
176 #[strum(serialize = "reflection")]
177 Reflection,
178}
179
180/// Aspect ratio policy for resize.
181#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
182pub enum AspectRatioPolicy {
183 #[default]
184 #[strum(serialize = "stretch")]
185 Stretch,
186 #[strum(serialize = "not_larger")]
187 NotLarger,
188 #[strum(serialize = "not_smaller")]
189 NotSmaller,
190}
191
192// =========================================================================
193// Higher-level building blocks (ONNX-style wrappers)
194// =========================================================================
195
196#[bon]
197impl Tensor {
198 /// Negative log-likelihood loss.
199 ///
200 /// `self` is `[N, C, ...]` log-probabilities, `target` is `[N, ...]` class indices
201 /// (dtype `i64`). Gathers the log-prob at the target class and negates it.
202 ///
203 /// Supports optional per-class `weight`, `ignore_index` to mask out a class,
204 /// and `reduction` (default `Mean`).
205 ///
206 /// # Examples
207 ///
208 /// ```
209 /// # use svod_tensor::Tensor;
210 /// # use ndarray::array;
211 /// let logprobs = Tensor::from_ndarray(&array![[-0.5f32, -1.0, -2.0]]);
212 /// let target = Tensor::from_slice([0i64]);
213 /// let mut loss = logprobs.nll_loss().target(&target).call().unwrap();
214 /// loss.realize().unwrap();
215 /// let val = loss.as_vec::<f32>().unwrap();
216 /// // -(-0.5) = 0.5
217 /// assert!((val[0] - 0.5).abs() < 1e-5);
218 /// ```
219 ///
220 /// With sum reduction:
221 ///
222 /// ```
223 /// # use svod_tensor::Tensor;
224 /// # use svod_tensor::nn::Reduction;
225 /// # use ndarray::array;
226 /// let logprobs = Tensor::from_ndarray(&array![[-0.5f32, -1.0], [-2.0, -0.3]]);
227 /// let target = Tensor::from_slice([0i64, 1]);
228 /// let mut loss = logprobs.nll_loss().target(&target).reduction(Reduction::Sum).call().unwrap();
229 /// loss.realize().unwrap();
230 /// let val = loss.as_vec::<f32>().unwrap();
231 /// // sum of 0.5 + 0.3 = 0.8
232 /// assert!((val[0] - 0.8).abs() < 1e-5);
233 /// ```
234 #[builder]
235 pub fn nll_loss(
236 &self,
237 target: &Tensor,
238 weight: Option<&Tensor>,
239 ignore_index: Option<i64>,
240 #[builder(default)] reduction: Reduction,
241 ) -> Result<Tensor> {
242 let ndim = self.ndim()?;
243 snafu::ensure!(ndim >= 2, NdimMinimumSnafu { op: "nll_loss", min: 2_usize, actual: ndim });
244 // Gather log-probs at target class, negate
245 let nll = self.gather(1, &target.try_unsqueeze(1)?)?.try_squeeze(Some(1))?.try_neg()?;
246
247 // Per-sample weight: weight[target] or ones
248 let sample_weight = match weight {
249 Some(w) => {
250 let flat = target.try_reshape([-1])?;
251 let sel = w.gather(0, &flat)?;
252 let target_shape = svod_ir::shape::to_vec_isize(&target.shape()?).context(UOpSnafu)?;
253 sel.try_reshape(&target_shape)?
254 }
255 None => {
256 let shape = svod_ir::shape::to_vec_usize(&target.shape()?).context(UOpSnafu)?;
257 Tensor::full(&shape, 1.0, self.uop().dtype())?
258 }
259 };
260
261 // Mask out ignore_index
262 let masked_weight = match ignore_index {
263 Some(idx) => {
264 let mask = target.try_ne(&Tensor::const_(idx as f64, target.uop().dtype()))?;
265 sample_weight.try_mul(&mask.cast(sample_weight.uop().dtype())?)?
266 }
267 None => sample_weight,
268 };
269
270 let weighted_loss = nll.try_mul(&masked_weight)?;
271 match reduction {
272 Reduction::Mean => weighted_loss.sum(AxisSpec::All)?.try_div(&masked_weight.sum(AxisSpec::All)?),
273 Reduction::Sum => weighted_loss.sum(AxisSpec::All),
274 Reduction::None => Ok(weighted_loss),
275 }
276 }
277
278 /// Dropout: randomly zeros elements during training, passes through in inference.
279 ///
280 /// Returns `(output, mask)` where mask is a boolean tensor (`true` = kept).
281 /// In inference mode (`training=false`, the default), the output is identical
282 /// to the input and the mask is all-true.
283 ///
284 /// **Note:** Training mode is not yet implemented (requires RNG); currently
285 /// returns identity regardless of `training`.
286 ///
287 /// # Examples
288 ///
289 /// ```
290 /// # use svod_tensor::Tensor;
291 /// # use ndarray::array;
292 /// let x = Tensor::from_ndarray(&array![1.0f32, 2.0, 3.0]);
293 /// let (mut out, mut mask) = x.dropout().p(0.5).call().unwrap();
294 /// out.realize().unwrap();
295 /// mask.realize().unwrap();
296 /// // Default is inference mode: output == input
297 /// assert_eq!(out.as_vec::<f32>().unwrap(), vec![1.0, 2.0, 3.0]);
298 /// assert_eq!(mask.as_vec::<bool>().unwrap(), vec![true, true, true]);
299 /// ```
300 #[builder]
301 pub fn dropout(&self, p: f64, #[builder(default = false)] training: bool) -> Result<(Tensor, Tensor)> {
302 snafu::ensure!(
303 (0.0..=1.0).contains(&p),
304 ParamRangeSnafu { op: "dropout", param: "p", value: p.to_string(), constraint: "0.0 <= p <= 1.0" }
305 );
306 let _ = p;
307 let shape = svod_ir::shape::to_vec_usize(&self.shape()?).context(UOpSnafu)?;
308 if !training {
309 let mask = Tensor::full(&shape, true, DType::Bool)?;
310 return Ok((self.clone(), mask));
311 }
312 // Training mode deferred (needs RNG: rand_like / Threefry)
313 let mask = Tensor::full(&shape, true, DType::Bool)?;
314 Ok((self.clone(), mask))
315 }
316 /// Convolution with ONNX-style parameters.
317 ///
318 /// Wraps the lower-level [`conv2d`](Tensor::conv2d) after resolving ONNX padding conventions
319 /// (`auto_pad`, flat `pads`). Input shape is `[N, C, H, W, ...]` and weight
320 /// shape is `[out_channels, in_channels/group, kH, kW, ...]`.
321 ///
322 /// # Examples
323 ///
324 /// Basic convolution with no padding:
325 ///
326 /// ```
327 /// # use svod_tensor::Tensor;
328 /// # use ndarray::Array4;
329 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
330 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
331 /// let mut y = x.conv().weight(&w).call().unwrap();
332 /// y.realize().unwrap();
333 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
334 /// assert_eq!(shape, [1, 1, 3, 3]);
335 /// // Each output element sums a 3x3 window of ones = 9.0
336 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![9.0; 9]);
337 /// ```
338 ///
339 /// With explicit padding and strides:
340 ///
341 /// ```
342 /// # use svod_tensor::Tensor;
343 /// # use ndarray::Array4;
344 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 5, 5), 1.0f32));
345 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
346 /// let mut y = x.conv().weight(&w).pads(&[1, 1, 1, 1]).strides(&[2, 2]).call().unwrap();
347 /// y.realize().unwrap();
348 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
349 /// assert_eq!(shape, [1, 1, 3, 3]);
350 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![4.0, 6.0, 4.0, 6.0, 9.0, 6.0, 4.0, 6.0, 4.0]);
351 /// ```
352 #[builder]
353 pub fn conv(
354 &self,
355 weight: &Tensor,
356 bias: Option<&Tensor>,
357 #[builder(default)] auto_pad: AutoPad,
358 #[builder(default = 1)] group: usize,
359 kernel_shape: Option<&[usize]>,
360 pads: Option<&[i64]>,
361 strides: Option<&[i64]>,
362 dilations: Option<&[i64]>,
363 ) -> Result<Tensor> {
364 let w_shape = weight.shape()?;
365 let kernel: Vec<usize> = kernel_shape
366 .map(|ks| ks.to_vec())
367 .unwrap_or_else(|| w_shape[2..].iter().map(|s| s.as_const().unwrap()).collect());
368 let n = kernel.len();
369 let strides_u: Vec<usize> =
370 strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
371 let dilations_u: Vec<usize> =
372 dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
373 let x_shape = self.shape()?;
374 let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
375 let empty_pads: Vec<i64> = vec![];
376 let padding =
377 resolve_pool_pads(&input_spatial, pads.unwrap_or(&empty_pads), &kernel, &dilations_u, &strides_u, auto_pad);
378 self.conv2d()
379 .weight(weight)
380 .maybe_bias(bias)
381 .groups(group)
382 .stride(&strides_u)
383 .dilation(&dilations_u)
384 .padding(&padding)
385 .call()
386 }
387
388 /// Transposed convolution with ONNX-style parameters.
389 ///
390 /// Wraps [`conv_transpose2d`](Tensor::conv_transpose2d) after resolving ONNX padding conventions.
391 /// Supports `output_shape` and `output_padding` for precise output size control.
392 ///
393 /// # Examples
394 ///
395 /// Basic transposed convolution (upsampling):
396 ///
397 /// ```
398 /// # use svod_tensor::Tensor;
399 /// # use ndarray::Array4;
400 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
401 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
402 /// let mut y = x.conv_transpose().weight(&w).call().unwrap();
403 /// y.realize().unwrap();
404 /// let vals = y.as_vec::<f32>().unwrap();
405 /// assert_eq!(vals.len(), 16); // 4x4 output
406 /// assert_eq!(vals[5], 4.0); // center sees full overlap
407 /// ```
408 ///
409 /// With stride (larger upsampling factor):
410 ///
411 /// ```
412 /// # use svod_tensor::Tensor;
413 /// # use ndarray::Array4;
414 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
415 /// let w = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 1.0f32));
416 /// let mut y = x.conv_transpose().weight(&w).strides(&[2, 2]).call().unwrap();
417 /// y.realize().unwrap();
418 /// let vals = y.as_vec::<f32>().unwrap();
419 /// assert_eq!(vals.len(), 25); // 5x5 output
420 /// ```
421 #[builder]
422 pub fn conv_transpose(
423 &self,
424 weight: &Tensor,
425 bias: Option<&Tensor>,
426 #[builder(default)] auto_pad: AutoPad,
427 #[builder(default = 1)] group: usize,
428 kernel_shape: Option<&[usize]>,
429 pads: Option<&[i64]>,
430 output_shape: Option<&[i64]>,
431 output_padding: Option<&[usize]>,
432 strides: Option<&[i64]>,
433 dilations: Option<&[i64]>,
434 ) -> Result<Tensor> {
435 let w_shape = weight.shape()?;
436 let kernel: Vec<usize> = kernel_shape
437 .map(|ks| ks.to_vec())
438 .unwrap_or_else(|| w_shape[2..].iter().map(|s| s.as_const().unwrap()).collect());
439 let n = kernel.len();
440 let x_shape = self.shape()?;
441 let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
442 let strides_u: Vec<usize> =
443 strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
444 let dilations_u: Vec<usize> =
445 dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
446 let output_padding_u: Vec<usize> = output_padding.map(|op| op.to_vec()).unwrap_or_else(|| vec![0; n]);
447
448 // 3-path padding resolution (matches Tinygrad's ConvTranspose)
449 let mut pads_resolved: Option<Vec<isize>> = None;
450
451 // ConvTranspose padding resolution requires concrete spatial dims.
452 let input_spatial_c: Vec<usize> = input_spatial
453 .iter()
454 .map(|s| s.as_const().expect("conv_transpose requires concrete spatial dims"))
455 .collect();
456
457 // Path 1: output_shape provided → derive total pads, apply auto_pad
458 if let Some(os) = output_shape {
459 let total_pads: Vec<isize> = (0..n)
460 .map(|i| {
461 (strides_u[i] * (input_spatial_c[i] - 1)
462 + output_padding_u[i]
463 + (kernel[i] - 1) * dilations_u[i]
464 + 1) as isize
465 - os[i] as isize
466 })
467 .collect();
468 pads_resolved = Some(auto_pad_split(&total_pads, auto_pad));
469 }
470
471 // Path 2: no explicit pads → derive from default output_shape
472 if pads_resolved.is_none() && pads.is_none_or(|p| p.is_empty()) {
473 let default_out: Vec<usize> = (0..n).map(|i| input_spatial_c[i] * strides_u[i]).collect();
474 let total_pads: Vec<isize> = (0..n)
475 .map(|i| {
476 (strides_u[i] * (input_spatial_c[i] - 1)
477 + output_padding_u[i]
478 + (kernel[i] - 1) * dilations_u[i]
479 + 1) as isize
480 - default_out[i] as isize
481 })
482 .collect();
483 pads_resolved =
484 Some(if auto_pad != AutoPad::NotSet { auto_pad_split(&total_pads, auto_pad) } else { vec![0; n * 2] });
485 }
486
487 // Path 3: explicit pads provided
488 let padding: Vec<(isize, isize)> = if let Some(flat) = pads_resolved {
489 let half = flat.len() / 2;
490 (0..half).map(|i| (flat[i], flat[i + half])).collect()
491 } else {
492 flat_pads_to_pairs(pads.unwrap())
493 };
494
495 self.conv_transpose2d()
496 .weight(weight)
497 .maybe_bias(bias)
498 .groups(group)
499 .stride(&strides_u)
500 .dilation(&dilations_u)
501 .padding(&padding)
502 .output_padding(&output_padding_u)
503 .call()
504 }
505
506 /// Average pooling with ONNX-style parameters.
507 ///
508 /// Wraps [`avg_pool2d`](Tensor::avg_pool2d) after resolving ONNX padding and stride conventions.
509 /// Stride defaults to 1 (unlike [`avg_pool2d`](Tensor::avg_pool2d) which defaults to `kernel_size`).
510 /// Input shape is `[N, C, H, W]`.
511 ///
512 /// # Examples
513 ///
514 /// ```
515 /// # use svod_tensor::Tensor;
516 /// # use ndarray::Array4;
517 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
518 /// let mut y = x.avg_pool().kernel_shape(&[2, 2]).call().unwrap();
519 /// y.realize().unwrap();
520 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
521 /// assert_eq!(shape, [1, 1, 3, 3]);
522 /// // Average of all-ones windows is 1.0
523 /// assert!(y.as_vec::<f32>().unwrap().iter().all(|&v| (v - 1.0).abs() < 1e-6));
524 /// ```
525 ///
526 /// With strides:
527 ///
528 /// ```
529 /// # use svod_tensor::Tensor;
530 /// # use ndarray::Array4;
531 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
532 /// let mut y = x.avg_pool().kernel_shape(&[2, 2]).strides(&[2, 2]).call().unwrap();
533 /// y.realize().unwrap();
534 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
535 /// assert_eq!(shape, [1, 1, 2, 2]);
536 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
537 /// ```
538 #[builder]
539 pub fn avg_pool(
540 &self,
541 kernel_shape: &[usize],
542 #[builder(default)] auto_pad: AutoPad,
543 #[builder(default = false)] ceil_mode: bool,
544 #[builder(default = false)] count_include_pad: bool,
545 pads: Option<&[i64]>,
546 strides: Option<&[i64]>,
547 dilations: Option<&[i64]>,
548 ) -> Result<Tensor> {
549 let n = kernel_shape.len();
550 let strides_u: Vec<usize> =
551 strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
552 let dilations_u: Vec<usize> =
553 dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
554 let x_shape = self.shape()?;
555 let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
556 let empty_pads: Vec<i64> = vec![];
557 let padding = resolve_pool_pads(
558 &input_spatial,
559 pads.unwrap_or(&empty_pads),
560 kernel_shape,
561 &dilations_u,
562 &strides_u,
563 auto_pad,
564 );
565 self.avg_pool2d()
566 .kernel_size(kernel_shape)
567 .stride(&strides_u)
568 .dilation(&dilations_u)
569 .padding(&padding)
570 .ceil_mode(ceil_mode)
571 .count_include_pad(count_include_pad)
572 .call()
573 }
574
575 /// Lp norm pooling with ONNX-style parameters.
576 ///
577 /// Computes `(sum(|x|^p))^(1/p)` over each pooling window. Defaults to
578 /// `p=2` (L2 pooling). Input shape is `[N, C, H, W]`.
579 ///
580 /// # Examples
581 ///
582 /// ```
583 /// # use svod_tensor::Tensor;
584 /// # use ndarray::Array4;
585 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
586 /// let mut y = x.lp_pool().kernel_shape(&[2, 2]).call().unwrap();
587 /// y.realize().unwrap();
588 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
589 /// assert_eq!(shape, [1, 1, 3, 3]);
590 /// // L2 pool of 2x2 window of ones = sqrt(4) = 2.0
591 /// assert!((y.as_vec::<f32>().unwrap()[0] - 2.0).abs() < 1e-5);
592 /// ```
593 #[builder]
594 pub fn lp_pool(
595 &self,
596 kernel_shape: &[usize],
597 #[builder(default = 2)] p: usize,
598 #[builder(default)] auto_pad: AutoPad,
599 #[builder(default = false)] ceil_mode: bool,
600 pads: Option<&[i64]>,
601 strides: Option<&[i64]>,
602 dilations: Option<&[i64]>,
603 ) -> Result<Tensor> {
604 snafu::ensure!(p >= 1, ParamRangeSnafu { op: "lp_pool", param: "p", value: p.to_string(), constraint: ">= 1" });
605 let n_spatial = kernel_shape.len();
606 let strides_u: Vec<usize> =
607 strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n_spatial]);
608 let dilations_u: Vec<usize> =
609 dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n_spatial]);
610 let x_shape = self.shape()?;
611 let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
612 let empty_pads: Vec<i64> = vec![];
613 let padding = resolve_pool_pads(
614 &input_spatial,
615 pads.unwrap_or(&empty_pads),
616 kernel_shape,
617 &dilations_u,
618 &strides_u,
619 auto_pad,
620 );
621
622 let p_f = p as f64;
623 let dtype = self.uop().dtype();
624 let p_tensor = Tensor::const_(p_f, dtype.clone());
625 let inv_p = Tensor::const_(1.0 / p_f, dtype);
626 let x_abs_p = self.try_abs()?.try_pow(&p_tensor)?;
627
628 // Pad, pool (create windows), then sum over kernel axes.
629 // This computes sum(|x|^p) directly — correct for all padding/ceil modes
630 // because padded zeros contribute 0 to the sum.
631 let reg_pads = padding;
632 let ceil_pads = if ceil_mode {
633 pad::apply_ceil_mode(®_pads, &input_spatial, kernel_shape, &strides_u, &dilations_u)
634 } else {
635 reg_pads.clone()
636 };
637 let pads_to_use = if ceil_mode { &ceil_pads } else { ®_pads };
638 let mut padded = x_abs_p;
639 if pads_to_use.iter().any(|&(b, e)| b != 0 || e != 0) {
640 let n_batch = x_shape.len() - n_spatial;
641 let mut full_pad: Vec<(isize, isize)> = vec![(0, 0); n_batch];
642 full_pad.extend_from_slice(pads_to_use);
643 padded = padded.try_pad(&full_pad)?;
644 }
645 let pooled = padded.pool(kernel_shape, &strides_u, &dilations_u)?;
646 let reduce_axes: Vec<isize> = (0..n_spatial).map(|j| -(1 + j as isize)).collect();
647 let sum_p = pooled.sum(crate::reduce::AxisSpec::Multiple(reduce_axes))?;
648 sum_p.try_pow(&inv_p)
649 }
650
651 /// Rearrange depth data into spatial blocks (inverse of [`space_to_depth`](Tensor::space_to_depth)).
652 ///
653 /// Equivalent to PyTorch's `F.pixel_shuffle`. Reshapes a `[N, C, H, W]`
654 /// tensor to `[N, C/(b*b), H*b, W*b]` where `b` is the blocksize.
655 ///
656 /// # Examples
657 ///
658 /// ```
659 /// # use svod_tensor::Tensor;
660 /// # use ndarray::Array4;
661 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 1, 1), 1.0f32));
662 /// let mut y = x.depth_to_space().blocksize(2).call().unwrap();
663 /// y.realize().unwrap();
664 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
665 /// assert_eq!(shape, [1, 1, 2, 2]);
666 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
667 /// ```
668 ///
669 /// Using CRD mode (PyTorch pixel_shuffle order):
670 ///
671 /// ```
672 /// # use svod_tensor::Tensor;
673 /// # use svod_tensor::nn::DepthToSpaceMode;
674 /// # use ndarray::Array4;
675 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 1, 1), 1.0f32));
676 /// let mut y = x.depth_to_space().blocksize(2).mode(DepthToSpaceMode::Crd).call().unwrap();
677 /// y.realize().unwrap();
678 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 4]);
679 /// ```
680 #[builder]
681 pub fn depth_to_space(&self, blocksize: usize, #[builder(default)] mode: DepthToSpaceMode) -> Result<Tensor> {
682 let ndim = self.ndim()?;
683 snafu::ensure!(ndim == 4, NdimExactSnafu { op: "depth_to_space", expected: 4_usize, actual: ndim });
684 snafu::ensure!(
685 blocksize > 0,
686 ParamRangeSnafu {
687 op: "depth_to_space",
688 param: "blocksize",
689 value: blocksize.to_string(),
690 constraint: "> 0"
691 }
692 );
693 let shape = self.shape()?;
694 let (b, c, h, w) = (
695 shape[0].as_const().unwrap(),
696 shape[1].as_const().unwrap(),
697 shape[2].as_const().unwrap(),
698 shape[3].as_const().unwrap(),
699 );
700 let bs_sq = blocksize * blocksize;
701 snafu::ensure!(
702 c.is_multiple_of(bs_sq),
703 DivisibilitySnafu {
704 op: "depth_to_space",
705 lhs_name: "channels",
706 lhs: c,
707 rhs_name: "blocksize^2",
708 rhs: bs_sq
709 }
710 );
711 let c_out = c / bs_sq;
712 let result = if mode == DepthToSpaceMode::Crd {
713 self.try_reshape([
714 b as isize,
715 c_out as isize,
716 blocksize as isize,
717 blocksize as isize,
718 h as isize,
719 w as isize,
720 ])?
721 .try_permute(&[0, 1, 4, 2, 5, 3])?
722 } else {
723 // DCR (default)
724 self.try_reshape([
725 b as isize,
726 blocksize as isize,
727 blocksize as isize,
728 c_out as isize,
729 h as isize,
730 w as isize,
731 ])?
732 .try_permute(&[0, 3, 4, 1, 5, 2])?
733 };
734 result.try_reshape([b as isize, c_out as isize, (h * blocksize) as isize, (w * blocksize) as isize])
735 }
736
737 /// Rearrange spatial data into depth (inverse of [`depth_to_space`](Tensor::depth_to_space)).
738 ///
739 /// Reshapes a `[N, C, H, W]` tensor to `[N, C*b*b, H/b, W/b]` where `b`
740 /// is the blocksize. Both `H` and `W` must be divisible by `blocksize`.
741 ///
742 /// # Examples
743 ///
744 /// ```
745 /// # use svod_tensor::Tensor;
746 /// # use ndarray::Array4;
747 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
748 /// let mut y = x.space_to_depth(2).unwrap();
749 /// y.realize().unwrap();
750 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
751 /// assert_eq!(shape, [1, 4, 2, 2]);
752 /// assert_eq!(y.as_vec::<f32>().unwrap(), vec![1.0; 16]);
753 /// ```
754 pub fn space_to_depth(&self, blocksize: usize) -> Result<Tensor> {
755 let ndim = self.ndim()?;
756 snafu::ensure!(ndim == 4, NdimExactSnafu { op: "space_to_depth", expected: 4_usize, actual: ndim });
757 snafu::ensure!(
758 blocksize > 0,
759 ParamRangeSnafu {
760 op: "space_to_depth",
761 param: "blocksize",
762 value: blocksize.to_string(),
763 constraint: "> 0"
764 }
765 );
766 let shape = self.shape()?;
767 let (b, c, h, w) = (
768 shape[0].as_const().unwrap(),
769 shape[1].as_const().unwrap(),
770 shape[2].as_const().unwrap(),
771 shape[3].as_const().unwrap(),
772 );
773 snafu::ensure!(
774 h.is_multiple_of(blocksize),
775 DivisibilitySnafu {
776 op: "space_to_depth",
777 lhs_name: "height",
778 lhs: h,
779 rhs_name: "blocksize",
780 rhs: blocksize
781 }
782 );
783 snafu::ensure!(
784 w.is_multiple_of(blocksize),
785 DivisibilitySnafu {
786 op: "space_to_depth",
787 lhs_name: "width",
788 lhs: w,
789 rhs_name: "blocksize",
790 rhs: blocksize
791 }
792 );
793 self.try_reshape([
794 b as isize,
795 c as isize,
796 (h / blocksize) as isize,
797 blocksize as isize,
798 (w / blocksize) as isize,
799 blocksize as isize,
800 ])?
801 .try_permute(&[0, 3, 5, 1, 2, 4])?
802 .try_reshape([
803 b as isize,
804 (c * blocksize * blocksize) as isize,
805 (h / blocksize) as isize,
806 (w / blocksize) as isize,
807 ])
808 }
809
810 /// Max pooling with ONNX-style parameters.
811 ///
812 /// Always returns `(values, indices)` where indices are flattened positions
813 /// (dtype `i64`). Wraps [`max_pool2d_with_indices`](Tensor::max_pool2d_with_indices) after resolving ONNX
814 /// padding conventions.
815 ///
816 /// # Examples
817 ///
818 /// ```
819 /// # use svod_tensor::Tensor;
820 /// # use ndarray::Array4;
821 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
822 /// let (vals, indices) = x.max_pool().kernel_shape(&[2, 2]).call().unwrap();
823 /// let shape: Vec<_> = vals.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
824 /// assert_eq!(shape, [1, 1, 3, 3]);
825 /// ```
826 ///
827 /// With strides:
828 ///
829 /// ```
830 /// # use svod_tensor::Tensor;
831 /// # use ndarray::Array4;
832 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
833 /// let (vals, _) = x.max_pool().kernel_shape(&[2, 2]).strides(&[2, 2]).call().unwrap();
834 /// let shape: Vec<_> = vals.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
835 /// assert_eq!(shape, [1, 1, 2, 2]);
836 /// ```
837 #[builder]
838 pub fn max_pool(
839 &self,
840 kernel_shape: &[usize],
841 #[builder(default)] auto_pad: AutoPad,
842 #[builder(default = false)] ceil_mode: bool,
843 #[builder(default = 0)] storage_order: usize,
844 pads: Option<&[i64]>,
845 strides: Option<&[i64]>,
846 dilations: Option<&[i64]>,
847 ) -> Result<(Tensor, Tensor)> {
848 let n = kernel_shape.len();
849 let strides_u: Vec<usize> =
850 strides.map(|s| s.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
851 let dilations_u: Vec<usize> =
852 dilations.map(|d| d.iter().map(|&v| v as usize).collect()).unwrap_or_else(|| vec![1; n]);
853 let x_shape = self.shape()?;
854 let input_spatial: Vec<SInt> = x_shape[2..].to_vec();
855 let empty_pads: Vec<i64> = vec![];
856 let padding = resolve_pool_pads(
857 &input_spatial,
858 pads.unwrap_or(&empty_pads),
859 kernel_shape,
860 &dilations_u,
861 &strides_u,
862 auto_pad,
863 );
864 let (values, indices) = self
865 .max_pool2d_with_indices()
866 .kernel_size(kernel_shape)
867 .stride(&strides_u)
868 .dilation(&dilations_u)
869 .padding(&padding)
870 .ceil_mode(ceil_mode)
871 .call()?;
872 let indices = if storage_order == 1 {
873 indices.try_transpose(-2, -1)?.cast(DType::Int64)?
874 } else {
875 indices.cast(DType::Int64)?
876 };
877 Ok((values, indices))
878 }
879
880 /// Local Response Normalization (LRN).
881 ///
882 /// Normalizes each element by dividing by a scaled sum of squares over a
883 /// local neighborhood of `size` channels:
884 /// `y = x / (bias + alpha * avg_pool(x^2, size))^beta`.
885 ///
886 /// Input must be 4-D `[N, C, H, W]`.
887 ///
888 /// # Examples
889 ///
890 /// ```
891 /// # use svod_tensor::Tensor;
892 /// # use ndarray::Array4;
893 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 3, 2, 2), 1.0f32));
894 /// let y = x.lrn().size(3).call().unwrap();
895 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
896 /// assert_eq!(shape, [1, 3, 2, 2]);
897 /// ```
898 ///
899 /// Custom alpha, beta, and bias:
900 ///
901 /// ```
902 /// # use svod_tensor::Tensor;
903 /// # use ndarray::Array4;
904 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 3, 2, 2), 1.0f32));
905 /// let y = x.lrn().size(3).alpha(0.001).beta(0.5).bias(2.0).call().unwrap();
906 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
907 /// assert_eq!(shape, [1, 3, 2, 2]);
908 /// ```
909 #[builder]
910 pub fn lrn(
911 &self,
912 size: usize,
913 #[builder(default = 0.0001)] alpha: f64,
914 #[builder(default = 0.75)] beta: f64,
915 #[builder(default = 1.0)] bias: f64,
916 ) -> Result<Tensor> {
917 let ndim = self.ndim()?;
918 snafu::ensure!(ndim == 4, NdimExactSnafu { op: "lrn", expected: 4_usize, actual: ndim });
919 snafu::ensure!(
920 size > 0,
921 ParamRangeSnafu { op: "lrn", param: "size", value: size.to_string(), constraint: "> 0" }
922 );
923 let shape = self.shape()?;
924 let (b, c, h, w) = (
925 shape[0].as_const().unwrap(),
926 shape[1].as_const().unwrap(),
927 shape[2].as_const().unwrap(),
928 shape[3].as_const().unwrap(),
929 );
930 let x_sq = self.square()?;
931 let x_sq = x_sq.try_reshape([b as isize, 1, c as isize, (h * w) as isize])?;
932 let pad_before = ((size - 1) / 2) as isize;
933 let pad_after = (size / 2) as isize;
934 let x_sq = x_sq.try_pad(&[(0, 0), (0, 0), (pad_before, pad_after), (0, 0)])?;
935 let pooled = x_sq.avg_pool2d().kernel_size(&[size, 1]).stride(&[1, 1]).call()?;
936 let pooled = pooled.try_reshape([b as isize, c as isize, h as isize, w as isize])?;
937 let dtype = self.uop().dtype();
938 let scale = pooled
939 .try_mul(&Tensor::const_(alpha, dtype.clone()))?
940 .try_add(&Tensor::const_(bias, dtype.clone()))?
941 .try_pow(&Tensor::const_(beta, dtype))?;
942 self.try_div(&scale)
943 }
944}
945
946impl Tensor {
947 /// Apply a sequence of layers to this tensor.
948 pub fn sequential(&self, layers: &[&dyn Layer]) -> Result<Tensor> {
949 let mut x = self.clone();
950 for layer in layers {
951 x = layer.forward(&x)?;
952 }
953 Ok(x)
954 }
955}