// tensorlogic_scirs_backend/pooling.rs

1//! Pooling operations for neural network tensor processing.
2//!
3//! Provides max pooling, average pooling, Lp pooling, global pooling,
4//! adaptive pooling, and unpooling operations over N-dimensional spatial data.
5
6use scirs2_core::ndarray::{ArrayD, IxDyn};
7
/// Errors that can occur during pooling operations.
///
/// Returned by the pooling entry points (`max_pool`, `avg_pool`, `lp_pool`,
/// the global/adaptive variants, `max_unpool`) and by `PoolConfig::validate`.
#[derive(Debug, Clone)]
pub enum PoolingError {
    /// Kernel size must be > 0 in every spatial dimension.
    InvalidKernelSize { size: usize },
    /// Stride must be > 0 in every spatial dimension.
    InvalidStride { stride: usize },
    /// Padding must be strictly less than the kernel size of the same dimension.
    InvalidPadding { padding: usize, kernel_size: usize },
    /// Input tensor has fewer dimensions than batch + channel + spatial requires.
    InsufficientDimensions { ndim: usize, required: usize },
    /// Input tensor contains no elements.
    EmptyInput,
    /// Shape mismatch between tensors; the message describes the mismatch.
    ShapeMismatch(String),
}
24
25impl std::fmt::Display for PoolingError {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        match self {
28            Self::InvalidKernelSize { size } => {
29                write!(f, "Invalid kernel size: {size} (must be > 0)")
30            }
31            Self::InvalidStride { stride } => {
32                write!(f, "Invalid stride: {stride} (must be > 0)")
33            }
34            Self::InvalidPadding {
35                padding,
36                kernel_size,
37            } => write!(
38                f,
39                "Invalid padding: {padding} (must be < kernel_size {kernel_size})"
40            ),
41            Self::InsufficientDimensions { ndim, required } => {
42                write!(
43                    f,
44                    "Insufficient dimensions: got {ndim}, need at least {required}"
45                )
46            }
47            Self::EmptyInput => write!(f, "Empty input tensor"),
48            Self::ShapeMismatch(msg) => write!(f, "Shape mismatch: {msg}"),
49        }
50    }
51}
52
53impl std::error::Error for PoolingError {}
54
/// Pooling configuration specifying kernel size, stride, padding, and rounding mode.
///
/// Each vector holds one entry per spatial dimension. Build with
/// [`PoolConfig::new`] plus the `with_*` builder methods, then check with
/// [`PoolConfig::validate`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Kernel (window) size for each spatial dimension.
    pub kernel_size: Vec<usize>,
    /// Stride for each spatial dimension. If empty, defaults to kernel_size.
    pub stride: Vec<usize>,
    /// Zero-padding on each side for each spatial dimension.
    pub padding: Vec<usize>,
    /// Use ceil instead of floor for output size computation.
    pub ceil_mode: bool,
}
67
68impl PoolConfig {
69    /// Create a new config with the given kernel size, stride equal to kernel size,
70    /// zero padding, and floor mode.
71    pub fn new(kernel_size: Vec<usize>) -> Self {
72        Self {
73            stride: kernel_size.clone(),
74            padding: vec![0; kernel_size.len()],
75            kernel_size,
76            ceil_mode: false,
77        }
78    }
79
80    /// Set the stride (builder pattern).
81    pub fn with_stride(mut self, stride: Vec<usize>) -> Self {
82        self.stride = stride;
83        self
84    }
85
86    /// Set the padding (builder pattern).
87    pub fn with_padding(mut self, padding: Vec<usize>) -> Self {
88        self.padding = padding;
89        self
90    }
91
92    /// Set ceil mode (builder pattern).
93    pub fn with_ceil_mode(mut self, ceil: bool) -> Self {
94        self.ceil_mode = ceil;
95        self
96    }
97
98    /// Compute the output size for one spatial dimension.
99    ///
100    /// Formula: `floor((input + 2*padding - kernel) / stride) + 1`
101    /// (or ceil if `ceil_mode` is true).
102    pub fn output_size(&self, input_size: usize, dim: usize) -> usize {
103        let k = self.kernel_size.get(dim).copied().unwrap_or(1);
104        let s = self.effective_stride(dim);
105        let p = self.padding.get(dim).copied().unwrap_or(0);
106        let numerator = input_size + 2 * p;
107        if numerator < k {
108            return 0;
109        }
110        let diff = numerator - k;
111        if self.ceil_mode {
112            diff.div_ceil(s) + 1
113        } else {
114            diff / s + 1
115        }
116    }
117
118    /// Validate the config, returning an error if any field is invalid.
119    pub fn validate(&self) -> Result<(), PoolingError> {
120        for &k in &self.kernel_size {
121            if k == 0 {
122                return Err(PoolingError::InvalidKernelSize { size: k });
123            }
124        }
125        for &s in &self.stride {
126            if s == 0 {
127                return Err(PoolingError::InvalidStride { stride: s });
128            }
129        }
130        for (i, &p) in self.padding.iter().enumerate() {
131            let k = self.kernel_size.get(i).copied().unwrap_or(1);
132            if p >= k {
133                return Err(PoolingError::InvalidPadding {
134                    padding: p,
135                    kernel_size: k,
136                });
137            }
138        }
139        Ok(())
140    }
141
142    /// Number of spatial dimensions this config covers.
143    pub fn num_spatial_dims(&self) -> usize {
144        self.kernel_size.len()
145    }
146
147    /// Effective stride for a given dimension (defaults to kernel_size if stride vec is short).
148    fn effective_stride(&self, dim: usize) -> usize {
149        self.stride
150            .get(dim)
151            .copied()
152            .unwrap_or_else(|| self.kernel_size.get(dim).copied().unwrap_or(1))
153    }
154
155    /// Effective padding for a given dimension.
156    fn effective_padding(&self, dim: usize) -> usize {
157        self.padding.get(dim).copied().unwrap_or(0)
158    }
159}
160
161/// Validate that input has at least `batch + channel + spatial` dimensions.
162fn validate_input(input: &ArrayD<f64>, num_spatial: usize) -> Result<(), PoolingError> {
163    if input.is_empty() {
164        return Err(PoolingError::EmptyInput);
165    }
166    let required = num_spatial + 2;
167    if input.ndim() < required {
168        return Err(PoolingError::InsufficientDimensions {
169            ndim: input.ndim(),
170            required,
171        });
172    }
173    Ok(())
174}
175
176/// Compute the output shape given input shape and pool config.
177/// Returns full shape: [batch, channels, ...spatial_out...]
178fn compute_output_shape(
179    input_shape: &[usize],
180    config: &PoolConfig,
181) -> Result<Vec<usize>, PoolingError> {
182    let num_spatial = config.num_spatial_dims();
183    let mut out_shape = Vec::with_capacity(input_shape.len());
184    // Copy batch + channel dims
185    for &d in &input_shape[..input_shape.len() - num_spatial] {
186        out_shape.push(d);
187    }
188    // Compute spatial dims
189    for i in 0..num_spatial {
190        let spatial_idx = input_shape.len() - num_spatial + i;
191        let out = config.output_size(input_shape[spatial_idx], i);
192        out_shape.push(out);
193    }
194    Ok(out_shape)
195}
196
/// Number of "outer" slices: the product of the leading (batch + channel)
/// dimensions, i.e. everything except the trailing `num_spatial` dims.
fn num_outer_slices(shape: &[usize], num_spatial: usize) -> usize {
    let outer = &shape[..shape.len() - num_spatial];
    outer.iter().copied().product()
}
202
/// Convert a flat outer index to multi-dimensional indices for the leading
/// (non-spatial) dims, in row-major order.
fn flat_to_outer_indices(mut flat: usize, shape: &[usize], num_spatial: usize) -> Vec<usize> {
    let outer_dims = shape.len() - num_spatial;
    let mut indices = vec![0usize; outer_dims];
    // Peel digits off the flat index from the fastest-varying dim backwards.
    for (dim, slot) in indices.iter_mut().enumerate().rev() {
        *slot = flat % shape[dim];
        flat /= shape[dim];
    }
    indices
}
213
214/// Extract a spatial slice from the input given outer indices.
215/// Returns a view into the spatial portion.
216fn get_spatial_value(
217    input: &ArrayD<f64>,
218    outer_indices: &[usize],
219    spatial_indices: &[usize],
220    num_spatial: usize,
221) -> f64 {
222    let ndim = input.ndim();
223    let mut idx = vec![0usize; ndim];
224    for (i, &oi) in outer_indices.iter().enumerate() {
225        idx[i] = oi;
226    }
227    let offset = ndim - num_spatial;
228    for (i, &si) in spatial_indices.iter().enumerate() {
229        idx[offset + i] = si;
230    }
231    input[IxDyn(&idx)]
232}
233
234/// Iterate over all windows for the spatial dimensions, calling the callback with
235/// the output spatial indices and the collected window values (with their flat spatial positions).
236fn for_each_window<F>(
237    input_spatial_shape: &[usize],
238    config: &PoolConfig,
239    output_spatial_shape: &[usize],
240    mut callback: F,
241) where
242    F: FnMut(&[usize], Vec<(f64, Vec<usize>)>),
243{
244    let num_spatial = config.num_spatial_dims();
245    let mut out_pos = vec![0usize; num_spatial];
246
247    loop {
248        // Collect values in the window at out_pos
249        let mut window_values: Vec<(f64, Vec<usize>)> = Vec::new();
250        collect_window_values(
251            input_spatial_shape,
252            config,
253            &out_pos,
254            num_spatial,
255            0,
256            &mut vec![0usize; num_spatial],
257            &mut window_values,
258        );
259
260        callback(&out_pos, window_values);
261
262        // Advance out_pos
263        if !advance_indices(&mut out_pos, output_spatial_shape) {
264            break;
265        }
266    }
267}
268
/// Recursively collect all values within a pooling window.
///
/// Recurses one spatial dimension per call; at the deepest level it checks
/// the accumulated (padded-coordinate) position against the input bounds and
/// records the position translated back into unpadded input coordinates.
///
/// Note: the `f64` in each result tuple is a 0.0 placeholder — callers look
/// up the actual element themselves via the returned position.
fn collect_window_values(
    input_spatial_shape: &[usize],
    config: &PoolConfig,
    out_pos: &[usize],
    num_spatial: usize,
    dim: usize,
    current_input_pos: &mut Vec<usize>,
    results: &mut Vec<(f64, Vec<usize>)>,
) {
    if dim == num_spatial {
        // Check bounds (accounting for padding): coordinates live in padded
        // space, so valid input positions lie in [p, input_size + p).
        let mut valid = true;
        let mut actual_pos = Vec::with_capacity(num_spatial);
        for d in 0..num_spatial {
            let p = config.effective_padding(d);
            let pos_with_pad = current_input_pos[d];
            if pos_with_pad < p || pos_with_pad >= input_spatial_shape[d] + p {
                valid = false;
                break;
            }
            // Translate from padded coordinates back to input coordinates.
            actual_pos.push(pos_with_pad - p);
        }
        if valid {
            // We push a placeholder value; the caller will look it up
            results.push((0.0, actual_pos));
        }
        return;
    }

    // Window start for this dimension, in padded coordinates.
    let stride = config.effective_stride(dim);
    let k = config.kernel_size.get(dim).copied().unwrap_or(1);
    let start = out_pos[dim] * stride;

    // Sweep the kernel extent in this dimension, recursing into the next one.
    for ki in 0..k {
        current_input_pos[dim] = start + ki;
        collect_window_values(
            input_spatial_shape,
            config,
            out_pos,
            num_spatial,
            dim + 1,
            current_input_pos,
            results,
        );
    }
}
316
/// Advance a multi-dimensional index in row-major (odometer) order.
/// Returns false once every position has been visited (index wraps to zeros).
fn advance_indices(indices: &mut [usize], shape: &[usize]) -> bool {
    for (idx, &limit) in indices.iter_mut().zip(shape).rev() {
        *idx += 1;
        if *idx < limit {
            return true;
        }
        // This digit overflowed: reset it and carry into the next-slower dim.
        *idx = 0;
    }
    false
}
328
/// Compute the row-major flat spatial index from multi-dim spatial indices.
fn spatial_flat_index(spatial_indices: &[usize], spatial_shape: &[usize]) -> i64 {
    // Horner-style accumulation from the slowest-varying dimension forward:
    // ((i0 * d1 + i1) * d2 + i2) ... equals the row-major flat offset.
    spatial_indices
        .iter()
        .zip(spatial_shape)
        .fold(0i64, |acc, (&idx, &dim)| acc * dim as i64 + idx as i64)
}
339
340/// Max pooling over spatial dimensions.
341///
342/// Input shape: `[batch, channels, ...spatial_dims...]`
343/// Output: max over each kernel window.
344pub fn max_pool(input: &ArrayD<f64>, config: &PoolConfig) -> Result<ArrayD<f64>, PoolingError> {
345    config.validate()?;
346    let num_spatial = config.num_spatial_dims();
347    validate_input(input, num_spatial)?;
348
349    let input_shape = input.shape();
350    let out_shape = compute_output_shape(input_shape, config)?;
351    let spatial_offset = input_shape.len() - num_spatial;
352    let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
353    let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();
354
355    let mut output = ArrayD::zeros(IxDyn(&out_shape));
356    let n_outer = num_outer_slices(input_shape, num_spatial);
357
358    for outer_flat in 0..n_outer {
359        let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);
360
361        for_each_window(
362            &input_spatial,
363            config,
364            &output_spatial,
365            |out_pos, positions| {
366                let mut max_val = f64::NEG_INFINITY;
367                for (_, actual_pos) in &positions {
368                    let val = get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
369                    if val > max_val {
370                        max_val = val;
371                    }
372                }
373                // If no valid positions (all padding), use 0
374                if max_val == f64::NEG_INFINITY {
375                    max_val = 0.0;
376                }
377                let mut full_idx: Vec<usize> = outer_idx.clone();
378                full_idx.extend_from_slice(out_pos);
379                output[IxDyn(&full_idx)] = max_val;
380            },
381        );
382    }
383
384    Ok(output)
385}
386
/// Max pooling with indices: returns `(pooled_output, indices_of_max)`.
///
/// The indices are flat indices into the spatial dimensions of the input,
/// suitable for `max_unpool`. Windows covering only padding emit value 0.0
/// at index 0; the interim `-1` sentinel never escapes this function.
pub fn max_pool_with_indices(
    input: &ArrayD<f64>,
    config: &PoolConfig,
) -> Result<(ArrayD<f64>, ArrayD<i64>), PoolingError> {
    config.validate()?;
    let num_spatial = config.num_spatial_dims();
    validate_input(input, num_spatial)?;

    let input_shape = input.shape();
    let out_shape = compute_output_shape(input_shape, config)?;
    let spatial_offset = input_shape.len() - num_spatial;
    let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
    let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();

    let mut output = ArrayD::zeros(IxDyn(&out_shape));
    // i64 element type so `-1` can mark "no valid position seen yet".
    let mut indices = ArrayD::zeros(IxDyn(&out_shape));
    let n_outer = num_outer_slices(input_shape, num_spatial);

    for outer_flat in 0..n_outer {
        let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);

        for_each_window(
            &input_spatial,
            config,
            &output_spatial,
            |out_pos, positions| {
                // Track the best value and where (flat spatial index) it lives.
                // Strict `>` means the FIRST occurrence of a tied maximum wins.
                let mut max_val = f64::NEG_INFINITY;
                let mut max_idx: i64 = -1;
                for (_, actual_pos) in &positions {
                    let val = get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
                    if val > max_val {
                        max_val = val;
                        max_idx = spatial_flat_index(actual_pos, &input_spatial);
                    }
                }
                // All-padding window: emit 0.0 at index 0 instead of the sentinel.
                if max_val == f64::NEG_INFINITY {
                    max_val = 0.0;
                    max_idx = 0;
                }
                let mut full_idx: Vec<usize> = outer_idx.clone();
                full_idx.extend_from_slice(out_pos);
                output[IxDyn(&full_idx)] = max_val;
                indices[IxDyn(&full_idx)] = max_idx;
            },
        );
    }

    Ok((output, indices))
}
439
440/// Average pooling over spatial dimensions.
441///
442/// Input shape: `[batch, channels, ...spatial_dims...]`
443pub fn avg_pool(input: &ArrayD<f64>, config: &PoolConfig) -> Result<ArrayD<f64>, PoolingError> {
444    config.validate()?;
445    let num_spatial = config.num_spatial_dims();
446    validate_input(input, num_spatial)?;
447
448    let input_shape = input.shape();
449    let out_shape = compute_output_shape(input_shape, config)?;
450    let spatial_offset = input_shape.len() - num_spatial;
451    let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
452    let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();
453
454    let mut output = ArrayD::zeros(IxDyn(&out_shape));
455    let n_outer = num_outer_slices(input_shape, num_spatial);
456
457    for outer_flat in 0..n_outer {
458        let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);
459
460        for_each_window(
461            &input_spatial,
462            config,
463            &output_spatial,
464            |out_pos, positions| {
465                let mut sum = 0.0;
466                let count = positions.len();
467                for (_, actual_pos) in &positions {
468                    sum += get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
469                }
470                let avg = if count > 0 { sum / count as f64 } else { 0.0 };
471                let mut full_idx: Vec<usize> = outer_idx.clone();
472                full_idx.extend_from_slice(out_pos);
473                output[IxDyn(&full_idx)] = avg;
474            },
475        );
476    }
477
478    Ok(output)
479}
480
/// Lp pooling (generalized): `(sum(|x|^p) / count)^(1/p)`.
///
/// `p = 1` gives the mean of absolute values; large `p` approaches max
/// pooling of magnitudes. The divisor `count` is the number of valid
/// (non-padding) positions in each window.
///
/// NOTE(review): `p` is not validated — `p == 0.0` or a negative/non-finite
/// `p` yields inf/NaN outputs via `powf`. Callers are assumed to pass a
/// finite `p > 0`; confirm, and consider validating upstream.
pub fn lp_pool(
    input: &ArrayD<f64>,
    config: &PoolConfig,
    p: f64,
) -> Result<ArrayD<f64>, PoolingError> {
    config.validate()?;
    let num_spatial = config.num_spatial_dims();
    validate_input(input, num_spatial)?;

    let input_shape = input.shape();
    let out_shape = compute_output_shape(input_shape, config)?;
    let spatial_offset = input_shape.len() - num_spatial;
    let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
    let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();

    let mut output = ArrayD::zeros(IxDyn(&out_shape));
    let n_outer = num_outer_slices(input_shape, num_spatial);

    for outer_flat in 0..n_outer {
        let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);

        for_each_window(
            &input_spatial,
            config,
            &output_spatial,
            |out_pos, positions| {
                let count = positions.len();
                // Accumulate sum of |x|^p over the window's valid positions.
                let mut sum_pow = 0.0;
                for (_, actual_pos) in &positions {
                    let val = get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
                    sum_pow += val.abs().powf(p);
                }
                // Empty (all-padding) windows produce 0 rather than NaN.
                let result = if count > 0 {
                    (sum_pow / count as f64).powf(1.0 / p)
                } else {
                    0.0
                };
                let mut full_idx: Vec<usize> = outer_idx.clone();
                full_idx.extend_from_slice(out_pos);
                output[IxDyn(&full_idx)] = result;
            },
        );
    }

    Ok(output)
}
528
529/// Global max pooling: reduce all spatial dims to a single value per (batch, channel).
530///
531/// Input: `[batch, channels, ...spatial...]` → Output: `[batch, channels]`
532pub fn global_max_pool(input: &ArrayD<f64>) -> Result<ArrayD<f64>, PoolingError> {
533    if input.is_empty() {
534        return Err(PoolingError::EmptyInput);
535    }
536    if input.ndim() < 3 {
537        return Err(PoolingError::InsufficientDimensions {
538            ndim: input.ndim(),
539            required: 3,
540        });
541    }
542
543    let shape = input.shape();
544    let batch = shape[0];
545    let channels = shape[1];
546    let num_spatial = input.ndim() - 2;
547    let spatial_size: usize = shape[2..].iter().product();
548
549    let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
550
551    for b in 0..batch {
552        for c in 0..channels {
553            let mut max_val = f64::NEG_INFINITY;
554            // Iterate over all spatial positions
555            for s in 0..spatial_size {
556                let spatial_idx = flat_to_spatial_indices(s, &shape[2..]);
557                let mut full_idx = vec![b, c];
558                full_idx.extend_from_slice(&spatial_idx);
559                let val = input[IxDyn(&full_idx)];
560                if val > max_val {
561                    max_val = val;
562                }
563            }
564            if max_val == f64::NEG_INFINITY {
565                max_val = 0.0;
566            }
567            output[IxDyn(&[b, c])] = max_val;
568        }
569    }
570    // Suppress unused warning for num_spatial
571    let _ = num_spatial;
572
573    Ok(output)
574}
575
576/// Global average pooling: reduce spatial dims to mean.
577///
578/// Input: `[batch, channels, ...spatial...]` → Output: `[batch, channels]`
579pub fn global_avg_pool(input: &ArrayD<f64>) -> Result<ArrayD<f64>, PoolingError> {
580    if input.is_empty() {
581        return Err(PoolingError::EmptyInput);
582    }
583    if input.ndim() < 3 {
584        return Err(PoolingError::InsufficientDimensions {
585            ndim: input.ndim(),
586            required: 3,
587        });
588    }
589
590    let shape = input.shape();
591    let batch = shape[0];
592    let channels = shape[1];
593    let spatial_size: usize = shape[2..].iter().product();
594
595    let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
596
597    for b in 0..batch {
598        for c in 0..channels {
599            let mut sum = 0.0;
600            for s in 0..spatial_size {
601                let spatial_idx = flat_to_spatial_indices(s, &shape[2..]);
602                let mut full_idx = vec![b, c];
603                full_idx.extend_from_slice(&spatial_idx);
604                sum += input[IxDyn(&full_idx)];
605            }
606            output[IxDyn(&[b, c])] = sum / spatial_size as f64;
607        }
608    }
609
610    Ok(output)
611}
612
/// Convert a row-major flat index to multi-dimensional spatial indices.
fn flat_to_spatial_indices(mut flat: usize, spatial_shape: &[usize]) -> Vec<usize> {
    let mut indices = vec![0usize; spatial_shape.len()];
    // Extract digits from the fastest-varying dimension backwards.
    for (slot, &dim) in indices.iter_mut().zip(spatial_shape).rev() {
        *slot = flat % dim;
        flat /= dim;
    }
    indices
}
622
623/// Adaptive average pooling: automatically compute kernel/stride to achieve target output size.
624///
625/// Input: `[batch, channels, ...spatial...]`, `output_size` for each spatial dim.
626pub fn adaptive_avg_pool(
627    input: &ArrayD<f64>,
628    output_size: &[usize],
629) -> Result<ArrayD<f64>, PoolingError> {
630    if input.is_empty() {
631        return Err(PoolingError::EmptyInput);
632    }
633    let num_spatial = output_size.len();
634    if input.ndim() < num_spatial + 2 {
635        return Err(PoolingError::InsufficientDimensions {
636            ndim: input.ndim(),
637            required: num_spatial + 2,
638        });
639    }
640
641    let shape = input.shape();
642    let spatial_offset = shape.len() - num_spatial;
643    let input_spatial: Vec<usize> = shape[spatial_offset..].to_vec();
644
645    // Build output shape
646    let mut out_shape: Vec<usize> = shape[..spatial_offset].to_vec();
647    out_shape.extend_from_slice(output_size);
648
649    let mut output = ArrayD::zeros(IxDyn(&out_shape));
650    let n_outer = num_outer_slices(shape, num_spatial);
651
652    for outer_flat in 0..n_outer {
653        let outer_idx = flat_to_outer_indices(outer_flat, shape, num_spatial);
654
655        // Iterate over all output spatial positions
656        let mut out_pos = vec![0usize; num_spatial];
657        loop {
658            // For each spatial dim, compute the input range using the adaptive formula
659            let mut ranges: Vec<(usize, usize)> = Vec::with_capacity(num_spatial);
660            for d in 0..num_spatial {
661                let in_size = input_spatial[d];
662                let out_sz = output_size[d];
663                let start = (out_pos[d] * in_size) / out_sz;
664                let end = ((out_pos[d] + 1) * in_size) / out_sz;
665                ranges.push((start, end));
666            }
667
668            // Average over the adaptive window
669            let mut sum = 0.0;
670            let mut count = 0usize;
671            let mut win_pos = vec![0usize; num_spatial];
672            // Initialize win_pos to range starts
673            for d in 0..num_spatial {
674                win_pos[d] = ranges[d].0;
675            }
676            loop {
677                let val = get_spatial_value(input, &outer_idx, &win_pos, num_spatial);
678                sum += val;
679                count += 1;
680
681                // Advance win_pos within ranges
682                if !advance_within_ranges(&mut win_pos, &ranges) {
683                    break;
684                }
685            }
686
687            let avg = if count > 0 { sum / count as f64 } else { 0.0 };
688            let mut full_idx: Vec<usize> = outer_idx.clone();
689            full_idx.extend_from_slice(&out_pos);
690            output[IxDyn(&full_idx)] = avg;
691
692            if !advance_indices(&mut out_pos, output_size) {
693                break;
694            }
695        }
696    }
697
698    Ok(output)
699}
700
/// Advance indices odometer-style within per-dimension half-open ranges
/// (inclusive start, exclusive end). Returns false once the last position
/// has been visited (indices reset to the range starts).
fn advance_within_ranges(indices: &mut [usize], ranges: &[(usize, usize)]) -> bool {
    for (idx, &(start, end)) in indices.iter_mut().zip(ranges).rev() {
        *idx += 1;
        if *idx < end {
            return true;
        }
        // Overflowed this dim: snap back to its range start and carry.
        *idx = start;
    }
    false
}
712
/// Unpool (inverse of max_pool): scatter pooled values back using stored indices.
///
/// Creates a zero tensor of `output_size` and places pooled values at the positions
/// indicated by `indices` (as produced by `max_pool_with_indices`: flat row-major
/// indices into the output's spatial dimensions).
///
/// `output_size` is the FULL target shape, including batch and channel dims,
/// and must have the same rank as `pooled`. Negative or out-of-range indices
/// are silently skipped rather than erroring.
pub fn max_unpool(
    pooled: &ArrayD<f64>,
    indices: &ArrayD<i64>,
    output_size: &[usize],
) -> Result<ArrayD<f64>, PoolingError> {
    if pooled.shape() != indices.shape() {
        return Err(PoolingError::ShapeMismatch(format!(
            "pooled shape {:?} != indices shape {:?}",
            pooled.shape(),
            indices.shape()
        )));
    }
    if pooled.is_empty() {
        return Err(PoolingError::EmptyInput);
    }

    let pooled_shape = pooled.shape();
    // output_size should be the full shape including batch+channel dims
    if output_size.len() != pooled_shape.len() {
        return Err(PoolingError::ShapeMismatch(format!(
            "output_size len {} != pooled ndim {}",
            output_size.len(),
            pooled_shape.len()
        )));
    }

    // Determine num_spatial by finding how many trailing dims differ
    // We assume at least 2 leading dims (batch, channel) match
    // (so spatial_offset below is always 2 when ndim >= 2).
    let num_spatial = pooled_shape.len().saturating_sub(2);
    let spatial_offset = pooled_shape.len() - num_spatial;
    let output_spatial: Vec<usize> = output_size[spatial_offset..].to_vec();

    let mut output = ArrayD::zeros(IxDyn(output_size));
    let n_outer = num_outer_slices(pooled_shape, num_spatial);

    // Total spatial size of output for flat index mapping
    let output_spatial_total: usize = output_spatial.iter().product();

    for outer_flat in 0..n_outer {
        let outer_idx = flat_to_outer_indices(outer_flat, pooled_shape, num_spatial);

        // Iterate over all pooled spatial positions
        let pooled_spatial: Vec<usize> = pooled_shape[spatial_offset..].to_vec();
        let mut pos = vec![0usize; num_spatial];
        loop {
            let mut pooled_full: Vec<usize> = outer_idx.clone();
            pooled_full.extend_from_slice(&pos);
            let val = pooled[IxDyn(&pooled_full)];
            let idx = indices[IxDyn(&pooled_full)];

            // Scatter: only indices landing inside the output's spatial
            // volume are written; everything else is dropped silently.
            if idx >= 0 && (idx as usize) < output_spatial_total {
                let spatial_pos = flat_to_spatial_indices(idx as usize, &output_spatial);
                let mut out_full: Vec<usize> = outer_idx.clone();
                out_full.extend_from_slice(&spatial_pos);
                output[IxDyn(&out_full)] = val;
            }

            if !advance_indices(&mut pos, &pooled_spatial) {
                break;
            }
        }
    }

    Ok(output)
}
782
/// Statistics from a pooling operation.
///
/// Produced by [`PoolingStats::compute`]; purely descriptive — computing the
/// stats does not run the pooling itself.
#[derive(Debug, Clone)]
pub struct PoolingStats {
    /// Shape of the input tensor.
    pub input_shape: Vec<usize>,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
    /// Kernel size for each spatial dimension.
    pub kernel_size: Vec<usize>,
    /// Effective stride for each spatial dimension.
    pub stride: Vec<usize>,
    /// Total number of elements in one kernel window (product of kernel dims).
    pub receptive_field_size: usize,
    /// Ratio of input spatial elements to output spatial elements.
    pub compression_ratio: f64,
    /// Overlap ratio: how much windows overlap (0 = no overlap).
    pub overlap_ratio: f64,
}
801
802impl PoolingStats {
803    /// Compute pooling statistics from input shape and config.
804    pub fn compute(input_shape: &[usize], config: &PoolConfig) -> Result<Self, PoolingError> {
805        config.validate()?;
806        let num_spatial = config.num_spatial_dims();
807        if input_shape.len() < num_spatial + 2 {
808            return Err(PoolingError::InsufficientDimensions {
809                ndim: input_shape.len(),
810                required: num_spatial + 2,
811            });
812        }
813
814        let output_shape = compute_output_shape(input_shape, config)?;
815        let spatial_offset = input_shape.len() - num_spatial;
816
817        let input_spatial_size: usize = input_shape[spatial_offset..].iter().product();
818        let output_spatial_size: usize = output_shape[spatial_offset..].iter().product();
819
820        let receptive_field_size: usize = config.kernel_size.iter().product();
821
822        let compression_ratio = if output_spatial_size > 0 {
823            input_spatial_size as f64 / output_spatial_size as f64
824        } else {
825            f64::INFINITY
826        };
827
828        // Overlap ratio: for each dim, overlap = (kernel - stride) / kernel
829        // Average across dims, clamped to [0, 1]
830        let mut overlap_sum = 0.0;
831        for d in 0..num_spatial {
832            let k = config.kernel_size.get(d).copied().unwrap_or(1) as f64;
833            let s = config.effective_stride(d) as f64;
834            let overlap = ((k - s) / k).max(0.0);
835            overlap_sum += overlap;
836        }
837        let overlap_ratio = if num_spatial > 0 {
838            overlap_sum / num_spatial as f64
839        } else {
840            0.0
841        };
842
843        let effective_stride: Vec<usize> = (0..num_spatial)
844            .map(|d| config.effective_stride(d))
845            .collect();
846
847        Ok(Self {
848            input_shape: input_shape.to_vec(),
849            output_shape,
850            kernel_size: config.kernel_size.clone(),
851            stride: effective_stride,
852            receptive_field_size,
853            compression_ratio,
854            overlap_ratio,
855        })
856    }
857
858    /// Return a human-readable summary string.
859    pub fn summary(&self) -> String {
860        format!(
861            "Pooling: {:?} -> {:?}, kernel={:?}, stride={:?}, \
862             receptive_field={}, compression={:.2}x, overlap={:.2}",
863            self.input_shape,
864            self.output_shape,
865            self.kernel_size,
866            self.stride,
867            self.receptive_field_size,
868            self.compression_ratio,
869            self.overlap_ratio,
870        )
871    }
872}
873
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Build a `[1, 1, h, w]` tensor from row-major data.
    fn make_4d(data: Vec<f64>, h: usize, w: usize) -> ArrayD<f64> {
        ArrayD::from_shape_vec(IxDyn(&[1, 1, h, w]), data)
            .expect("test tensor creation should succeed")
    }

    /// Row-major 4x4 ramp `1.0..=16.0` as a `[1, 1, 4, 4]` tensor.
    fn ramp_4x4() -> ArrayD<f64> {
        make_4d((1..=16).map(f64::from).collect(), 4, 4)
    }

    #[test]
    fn test_pool_config_output_size() {
        let cfg = PoolConfig::new(vec![2, 2]);
        for dim in 0..2 {
            assert_eq!(cfg.output_size(4, dim), 2);
        }
    }

    #[test]
    fn test_pool_config_output_size_with_padding() {
        // (4 + 2*1 - 2) / 2 + 1 = 4/2 + 1 = 3
        let cfg = PoolConfig::new(vec![2, 2]).with_padding(vec![1, 1]);
        assert_eq!(cfg.output_size(4, 0), 3);
    }

    #[test]
    fn test_pool_config_validate_valid() {
        assert!(PoolConfig::new(vec![2, 2]).validate().is_ok());
    }

    #[test]
    fn test_pool_config_validate_zero_kernel() {
        let result = PoolConfig::new(vec![0, 2]).validate();
        assert!(matches!(
            result,
            Err(PoolingError::InvalidKernelSize { size: 0 })
        ));
    }

    #[test]
    fn test_max_pool_basic() {
        let input = ramp_4x4();
        let out = max_pool(&input, &PoolConfig::new(vec![2, 2])).expect("max_pool should succeed");

        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        // Each 2x2 window keeps its bottom-right value on a monotone ramp.
        for ((h, w), v) in [((0, 0), 6.0), ((0, 1), 8.0), ((1, 0), 14.0), ((1, 1), 16.0)] {
            assert_eq!(out[IxDyn(&[0, 0, h, w])], v);
        }
    }

    #[test]
    fn test_max_pool_with_indices_correct() {
        let input = ramp_4x4();
        let (out, idx) = max_pool_with_indices(&input, &PoolConfig::new(vec![2, 2]))
            .expect("max_pool_with_indices should succeed");

        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        // (output position, expected max, flat index of that max in the input)
        let cases = [
            ((0, 0), 6.0, 5),   // top-left 2x2: 6.0 at (1,1)
            ((0, 1), 8.0, 7),   // top-right 2x2: 8.0 at (1,3)
            ((1, 0), 14.0, 13), // bottom-left 2x2: 14.0 at (3,1)
            ((1, 1), 16.0, 15), // bottom-right 2x2: 16.0 at (3,3)
        ];
        for ((h, w), max_val, flat) in cases {
            assert_eq!(out[IxDyn(&[0, 0, h, w])], max_val);
            assert_eq!(idx[IxDyn(&[0, 0, h, w])], flat);
        }
    }

    #[test]
    fn test_avg_pool_basic() {
        let input = ramp_4x4();
        let out = avg_pool(&input, &PoolConfig::new(vec![2, 2])).expect("avg_pool should succeed");

        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        // Window means: [1,2,5,6]→3.5, [3,4,7,8]→5.5, [9,10,13,14]→11.5, [11,12,15,16]→13.5
        for ((h, w), mean) in [((0, 0), 3.5), ((0, 1), 5.5), ((1, 0), 11.5), ((1, 1), 13.5)] {
            assert!((out[IxDyn(&[0, 0, h, w])] - mean).abs() < 1e-10);
        }
    }

    #[test]
    fn test_avg_pool_padding() {
        // padding=1, kernel=2, stride=2, input=4 → output = (4+2-2)/2 + 1 = 3
        let input = make_4d(vec![1.0; 16], 4, 4);
        let cfg = PoolConfig::new(vec![2, 2]).with_padding(vec![1, 1]);
        let out = avg_pool(&input, &cfg).expect("avg_pool with padding should succeed");

        assert_eq!(out.shape(), &[1, 1, 3, 3]);
    }

    #[test]
    fn test_lp_pool_p2() {
        // L2 pooling is the root-mean-square of the window.
        let input = make_4d(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let out = lp_pool(&input, &PoolConfig::new(vec![2, 2]), 2.0)
            .expect("lp_pool p=2 should succeed");

        assert_eq!(out.shape(), &[1, 1, 1, 1]);
        // sqrt((1 + 4 + 9 + 16) / 4) = sqrt(7.5)
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 7.5_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_lp_pool_p1() {
        // L1 pooling reduces to the mean of absolute values.
        let input = make_4d(vec![1.0, -2.0, 3.0, -4.0], 2, 2);
        let out = lp_pool(&input, &PoolConfig::new(vec![2, 2]), 1.0)
            .expect("lp_pool p=1 should succeed");

        assert_eq!(out.shape(), &[1, 1, 1, 1]);
        // mean(|1|, |-2|, |3|, |-4|) = 2.5
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_global_max_pool_shape() {
        let input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        let out = global_max_pool(&input).expect("global_max_pool should succeed");
        assert_eq!(out.shape(), &[1, 3]);
    }

    #[test]
    fn test_global_max_pool_values() {
        let mut input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        input[IxDyn(&[0, 0, 2, 3])] = 42.0;
        input[IxDyn(&[0, 1, 0, 0])] = 99.0;
        input[IxDyn(&[0, 2, 3, 3])] = -1.0; // channel 2 max stays 0.0 (zeros > -1)

        let out = global_max_pool(&input).expect("global_max_pool should succeed");
        for (ch, expected) in [(0, 42.0), (1, 99.0), (2, 0.0)] {
            assert_eq!(out[IxDyn(&[0, ch])], expected);
        }
    }

    #[test]
    fn test_global_avg_pool_shape() {
        let input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        let out = global_avg_pool(&input).expect("global_avg_pool should succeed");
        assert_eq!(out.shape(), &[1, 3]);
    }

    #[test]
    fn test_global_avg_pool_values() {
        // Channel 0 stays all ones (mean 1.0); channel 1 is overwritten to 2.0.
        let mut input = ArrayD::ones(IxDyn(&[1, 2, 2, 2]));
        for h in 0..2 {
            for w in 0..2 {
                input[IxDyn(&[0, 1, h, w])] = 2.0;
            }
        }

        let out = global_avg_pool(&input).expect("global_avg_pool should succeed");
        assert!((out[IxDyn(&[0, 0])] - 1.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 1])] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_adaptive_avg_pool_output_size() {
        let input = ArrayD::ones(IxDyn(&[1, 1, 4, 4]));
        let out = adaptive_avg_pool(&input, &[2, 2]).expect("adaptive_avg_pool should succeed");
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_adaptive_avg_pool_identity() {
        // Target size equal to the input must reproduce the input exactly.
        let data: Vec<f64> = (1..=16).map(f64::from).collect();
        let input = make_4d(data.clone(), 4, 4);
        let out =
            adaptive_avg_pool(&input, &[4, 4]).expect("adaptive_avg_pool identity should succeed");

        assert_eq!(out.shape(), &[1, 1, 4, 4]);
        for (i, &v) in data.iter().enumerate() {
            let (h, w) = (i / 4, i % 4);
            assert!(
                (out[IxDyn(&[0, 0, h, w])] - v).abs() < 1e-10,
                "mismatch at ({}, {})",
                h,
                w
            );
        }
    }

    #[test]
    fn test_max_unpool_basic() {
        let input = ramp_4x4();
        let cfg = PoolConfig::new(vec![2, 2]);
        let (pooled, indices) =
            max_pool_with_indices(&input, &cfg).expect("max_pool_with_indices should succeed");
        let restored =
            max_unpool(&pooled, &indices, &[1, 1, 4, 4]).expect("max_unpool should succeed");

        assert_eq!(restored.shape(), &[1, 1, 4, 4]);
        // Each window max is scattered back to where it came from...
        for ((h, w), v) in [((1, 1), 6.0), ((1, 3), 8.0), ((3, 1), 14.0), ((3, 3), 16.0)] {
            assert_eq!(restored[IxDyn(&[0, 0, h, w])], v);
        }
        // ...and every other position is zero-filled.
        for pos in [[0, 0, 0, 0], [0, 0, 2, 2]] {
            assert_eq!(restored[IxDyn(&pos)], 0.0);
        }
    }

    #[test]
    fn test_pooling_stats_compression() {
        let cfg = PoolConfig::new(vec![2, 2]);
        let stats =
            PoolingStats::compute(&[1, 1, 4, 4], &cfg).expect("stats compute should succeed");

        assert_eq!(stats.output_shape, vec![1, 1, 2, 2]);
        // Spatial 4*4 in, 2*2 out → 16/4 = 4.0 compression.
        assert!((stats.compression_ratio - 4.0).abs() < 1e-10);
        assert_eq!(stats.receptive_field_size, 4);
        // kernel == stride → windows never overlap.
        assert!(stats.overlap_ratio.abs() < 1e-10);
    }

    #[test]
    fn test_pooling_stats_summary() {
        let cfg = PoolConfig::new(vec![2, 2]);
        let text = PoolingStats::compute(&[1, 1, 4, 4], &cfg)
            .expect("stats compute should succeed")
            .summary();

        assert!(!text.is_empty());
        assert!(text.contains("Pooling"));
    }

    #[test]
    fn test_pooling_error_display() {
        let errors = vec![
            PoolingError::InvalidKernelSize { size: 0 },
            PoolingError::InvalidStride { stride: 0 },
            PoolingError::InvalidPadding {
                padding: 3,
                kernel_size: 2,
            },
            PoolingError::InsufficientDimensions {
                ndim: 2,
                required: 4,
            },
            PoolingError::EmptyInput,
            PoolingError::ShapeMismatch("test".to_string()),
        ];
        for err in &errors {
            assert!(
                !err.to_string().is_empty(),
                "Error display should not be empty"
            );
        }
    }
}