1use scirs2_core::ndarray::{ArrayD, IxDyn};
7
/// Errors produced by the pooling operations in this module.
///
/// `PartialEq`/`Eq` are derived so callers (and tests) can compare errors with
/// `assert_eq!` / `==` instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolingError {
    /// A kernel extent was zero.
    InvalidKernelSize { size: usize },
    /// A stride was zero.
    InvalidStride { stride: usize },
    /// A padding value was not strictly smaller than the kernel extent of its dimension.
    InvalidPadding { padding: usize, kernel_size: usize },
    /// The input has fewer axes than the operation requires.
    InsufficientDimensions { ndim: usize, required: usize },
    /// The input tensor contains no elements.
    EmptyInput,
    /// Two shapes that must agree do not; the payload describes the mismatch.
    ShapeMismatch(String),
}
24
25impl std::fmt::Display for PoolingError {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 match self {
28 Self::InvalidKernelSize { size } => {
29 write!(f, "Invalid kernel size: {size} (must be > 0)")
30 }
31 Self::InvalidStride { stride } => {
32 write!(f, "Invalid stride: {stride} (must be > 0)")
33 }
34 Self::InvalidPadding {
35 padding,
36 kernel_size,
37 } => write!(
38 f,
39 "Invalid padding: {padding} (must be < kernel_size {kernel_size})"
40 ),
41 Self::InsufficientDimensions { ndim, required } => {
42 write!(
43 f,
44 "Insufficient dimensions: got {ndim}, need at least {required}"
45 )
46 }
47 Self::EmptyInput => write!(f, "Empty input tensor"),
48 Self::ShapeMismatch(msg) => write!(f, "Shape mismatch: {msg}"),
49 }
50 }
51}
52
53impl std::error::Error for PoolingError {}
54
/// Hyper-parameters for the windowed pooling operations.
///
/// Each vector holds one entry per spatial dimension. The spatial dimensions are
/// the trailing axes of the input tensor.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Window extent along each spatial dimension.
    pub kernel_size: Vec<usize>,
    /// Step between consecutive windows; a missing entry falls back to the kernel extent.
    pub stride: Vec<usize>,
    /// Implicit padding on both ends of each spatial dimension; a missing entry means 0.
    pub padding: Vec<usize>,
    /// Round output sizes up (`true`) instead of down (`false`).
    pub ceil_mode: bool,
}
67
68impl PoolConfig {
69 pub fn new(kernel_size: Vec<usize>) -> Self {
72 Self {
73 stride: kernel_size.clone(),
74 padding: vec![0; kernel_size.len()],
75 kernel_size,
76 ceil_mode: false,
77 }
78 }
79
80 pub fn with_stride(mut self, stride: Vec<usize>) -> Self {
82 self.stride = stride;
83 self
84 }
85
86 pub fn with_padding(mut self, padding: Vec<usize>) -> Self {
88 self.padding = padding;
89 self
90 }
91
92 pub fn with_ceil_mode(mut self, ceil: bool) -> Self {
94 self.ceil_mode = ceil;
95 self
96 }
97
98 pub fn output_size(&self, input_size: usize, dim: usize) -> usize {
103 let k = self.kernel_size.get(dim).copied().unwrap_or(1);
104 let s = self.effective_stride(dim);
105 let p = self.padding.get(dim).copied().unwrap_or(0);
106 let numerator = input_size + 2 * p;
107 if numerator < k {
108 return 0;
109 }
110 let diff = numerator - k;
111 if self.ceil_mode {
112 diff.div_ceil(s) + 1
113 } else {
114 diff / s + 1
115 }
116 }
117
118 pub fn validate(&self) -> Result<(), PoolingError> {
120 for &k in &self.kernel_size {
121 if k == 0 {
122 return Err(PoolingError::InvalidKernelSize { size: k });
123 }
124 }
125 for &s in &self.stride {
126 if s == 0 {
127 return Err(PoolingError::InvalidStride { stride: s });
128 }
129 }
130 for (i, &p) in self.padding.iter().enumerate() {
131 let k = self.kernel_size.get(i).copied().unwrap_or(1);
132 if p >= k {
133 return Err(PoolingError::InvalidPadding {
134 padding: p,
135 kernel_size: k,
136 });
137 }
138 }
139 Ok(())
140 }
141
142 pub fn num_spatial_dims(&self) -> usize {
144 self.kernel_size.len()
145 }
146
147 fn effective_stride(&self, dim: usize) -> usize {
149 self.stride
150 .get(dim)
151 .copied()
152 .unwrap_or_else(|| self.kernel_size.get(dim).copied().unwrap_or(1))
153 }
154
155 fn effective_padding(&self, dim: usize) -> usize {
157 self.padding.get(dim).copied().unwrap_or(0)
158 }
159}
160
161fn validate_input(input: &ArrayD<f64>, num_spatial: usize) -> Result<(), PoolingError> {
163 if input.is_empty() {
164 return Err(PoolingError::EmptyInput);
165 }
166 let required = num_spatial + 2;
167 if input.ndim() < required {
168 return Err(PoolingError::InsufficientDimensions {
169 ndim: input.ndim(),
170 required,
171 });
172 }
173 Ok(())
174}
175
176fn compute_output_shape(
179 input_shape: &[usize],
180 config: &PoolConfig,
181) -> Result<Vec<usize>, PoolingError> {
182 let num_spatial = config.num_spatial_dims();
183 let mut out_shape = Vec::with_capacity(input_shape.len());
184 for &d in &input_shape[..input_shape.len() - num_spatial] {
186 out_shape.push(d);
187 }
188 for i in 0..num_spatial {
190 let spatial_idx = input_shape.len() - num_spatial + i;
191 let out = config.output_size(input_shape[spatial_idx], i);
192 out_shape.push(out);
193 }
194 Ok(out_shape)
195}
196
/// Number of independent spatial slices: the product of the leading
/// (`shape.len() - num_spatial`) axes.
fn num_outer_slices(shape: &[usize], num_spatial: usize) -> usize {
    let (leading, _spatial) = shape.split_at(shape.len() - num_spatial);
    leading.iter().product()
}
202
/// Decodes a row-major flat index over the leading (non-spatial) axes of `shape`.
///
/// Returns one coordinate per leading axis; the last leading axis varies fastest.
fn flat_to_outer_indices(flat: usize, shape: &[usize], num_spatial: usize) -> Vec<usize> {
    let outer_dims = shape.len() - num_spatial;
    let mut remaining = flat;
    let mut coords: Vec<usize> = shape[..outer_dims]
        .iter()
        .rev()
        .map(|&extent| {
            let coord = remaining % extent;
            remaining /= extent;
            coord
        })
        .collect();
    coords.reverse();
    coords
}
213
214fn get_spatial_value(
217 input: &ArrayD<f64>,
218 outer_indices: &[usize],
219 spatial_indices: &[usize],
220 num_spatial: usize,
221) -> f64 {
222 let ndim = input.ndim();
223 let mut idx = vec![0usize; ndim];
224 for (i, &oi) in outer_indices.iter().enumerate() {
225 idx[i] = oi;
226 }
227 let offset = ndim - num_spatial;
228 for (i, &si) in spatial_indices.iter().enumerate() {
229 idx[offset + i] = si;
230 }
231 input[IxDyn(&idx)]
232}
233
234fn for_each_window<F>(
237 input_spatial_shape: &[usize],
238 config: &PoolConfig,
239 output_spatial_shape: &[usize],
240 mut callback: F,
241) where
242 F: FnMut(&[usize], Vec<(f64, Vec<usize>)>),
243{
244 let num_spatial = config.num_spatial_dims();
245 let mut out_pos = vec![0usize; num_spatial];
246
247 loop {
248 let mut window_values: Vec<(f64, Vec<usize>)> = Vec::new();
250 collect_window_values(
251 input_spatial_shape,
252 config,
253 &out_pos,
254 num_spatial,
255 0,
256 &mut vec![0usize; num_spatial],
257 &mut window_values,
258 );
259
260 callback(&out_pos, window_values);
261
262 if !advance_indices(&mut out_pos, output_spatial_shape) {
264 break;
265 }
266 }
267}
268
269fn collect_window_values(
271 input_spatial_shape: &[usize],
272 config: &PoolConfig,
273 out_pos: &[usize],
274 num_spatial: usize,
275 dim: usize,
276 current_input_pos: &mut Vec<usize>,
277 results: &mut Vec<(f64, Vec<usize>)>,
278) {
279 if dim == num_spatial {
280 let mut valid = true;
282 let mut actual_pos = Vec::with_capacity(num_spatial);
283 for d in 0..num_spatial {
284 let p = config.effective_padding(d);
285 let pos_with_pad = current_input_pos[d];
286 if pos_with_pad < p || pos_with_pad >= input_spatial_shape[d] + p {
287 valid = false;
288 break;
289 }
290 actual_pos.push(pos_with_pad - p);
291 }
292 if valid {
293 results.push((0.0, actual_pos));
295 }
296 return;
297 }
298
299 let stride = config.effective_stride(dim);
300 let k = config.kernel_size.get(dim).copied().unwrap_or(1);
301 let start = out_pos[dim] * stride;
302
303 for ki in 0..k {
304 current_input_pos[dim] = start + ki;
305 collect_window_values(
306 input_spatial_shape,
307 config,
308 out_pos,
309 num_spatial,
310 dim + 1,
311 current_input_pos,
312 results,
313 );
314 }
315}
316
/// Steps `indices` to the next position of a row-major odometer bounded by `shape`.
///
/// The last axis moves fastest. An axis that reaches its bound wraps to 0 and carries
/// into the previous axis. Returns `false` once every position has been visited; at that
/// point `indices` has wrapped back to all zeros.
fn advance_indices(indices: &mut [usize], shape: &[usize]) -> bool {
    for (axis, slot) in indices.iter_mut().enumerate().rev() {
        *slot += 1;
        if *slot < shape[axis] {
            return true;
        }
        *slot = 0;
    }
    false
}
328
/// Row-major flat offset of `spatial_indices` within a block of shape `spatial_shape`.
///
/// Returned as `i64` to match the index tensor built by `max_pool_with_indices`.
fn spatial_flat_index(spatial_indices: &[usize], spatial_shape: &[usize]) -> i64 {
    // Horner's scheme: offset = ((i0 * s1 + i1) * s2 + i2) ...
    (0..spatial_indices.len()).fold(0i64, |offset, d| {
        offset * spatial_shape[d] as i64 + spatial_indices[d] as i64
    })
}
339
340pub fn max_pool(input: &ArrayD<f64>, config: &PoolConfig) -> Result<ArrayD<f64>, PoolingError> {
345 config.validate()?;
346 let num_spatial = config.num_spatial_dims();
347 validate_input(input, num_spatial)?;
348
349 let input_shape = input.shape();
350 let out_shape = compute_output_shape(input_shape, config)?;
351 let spatial_offset = input_shape.len() - num_spatial;
352 let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
353 let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();
354
355 let mut output = ArrayD::zeros(IxDyn(&out_shape));
356 let n_outer = num_outer_slices(input_shape, num_spatial);
357
358 for outer_flat in 0..n_outer {
359 let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);
360
361 for_each_window(
362 &input_spatial,
363 config,
364 &output_spatial,
365 |out_pos, positions| {
366 let mut max_val = f64::NEG_INFINITY;
367 for (_, actual_pos) in &positions {
368 let val = get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
369 if val > max_val {
370 max_val = val;
371 }
372 }
373 if max_val == f64::NEG_INFINITY {
375 max_val = 0.0;
376 }
377 let mut full_idx: Vec<usize> = outer_idx.clone();
378 full_idx.extend_from_slice(out_pos);
379 output[IxDyn(&full_idx)] = max_val;
380 },
381 );
382 }
383
384 Ok(output)
385}
386
387pub fn max_pool_with_indices(
391 input: &ArrayD<f64>,
392 config: &PoolConfig,
393) -> Result<(ArrayD<f64>, ArrayD<i64>), PoolingError> {
394 config.validate()?;
395 let num_spatial = config.num_spatial_dims();
396 validate_input(input, num_spatial)?;
397
398 let input_shape = input.shape();
399 let out_shape = compute_output_shape(input_shape, config)?;
400 let spatial_offset = input_shape.len() - num_spatial;
401 let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
402 let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();
403
404 let mut output = ArrayD::zeros(IxDyn(&out_shape));
405 let mut indices = ArrayD::zeros(IxDyn(&out_shape));
406 let n_outer = num_outer_slices(input_shape, num_spatial);
407
408 for outer_flat in 0..n_outer {
409 let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);
410
411 for_each_window(
412 &input_spatial,
413 config,
414 &output_spatial,
415 |out_pos, positions| {
416 let mut max_val = f64::NEG_INFINITY;
417 let mut max_idx: i64 = -1;
418 for (_, actual_pos) in &positions {
419 let val = get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
420 if val > max_val {
421 max_val = val;
422 max_idx = spatial_flat_index(actual_pos, &input_spatial);
423 }
424 }
425 if max_val == f64::NEG_INFINITY {
426 max_val = 0.0;
427 max_idx = 0;
428 }
429 let mut full_idx: Vec<usize> = outer_idx.clone();
430 full_idx.extend_from_slice(out_pos);
431 output[IxDyn(&full_idx)] = max_val;
432 indices[IxDyn(&full_idx)] = max_idx;
433 },
434 );
435 }
436
437 Ok((output, indices))
438}
439
440pub fn avg_pool(input: &ArrayD<f64>, config: &PoolConfig) -> Result<ArrayD<f64>, PoolingError> {
444 config.validate()?;
445 let num_spatial = config.num_spatial_dims();
446 validate_input(input, num_spatial)?;
447
448 let input_shape = input.shape();
449 let out_shape = compute_output_shape(input_shape, config)?;
450 let spatial_offset = input_shape.len() - num_spatial;
451 let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
452 let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();
453
454 let mut output = ArrayD::zeros(IxDyn(&out_shape));
455 let n_outer = num_outer_slices(input_shape, num_spatial);
456
457 for outer_flat in 0..n_outer {
458 let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);
459
460 for_each_window(
461 &input_spatial,
462 config,
463 &output_spatial,
464 |out_pos, positions| {
465 let mut sum = 0.0;
466 let count = positions.len();
467 for (_, actual_pos) in &positions {
468 sum += get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
469 }
470 let avg = if count > 0 { sum / count as f64 } else { 0.0 };
471 let mut full_idx: Vec<usize> = outer_idx.clone();
472 full_idx.extend_from_slice(out_pos);
473 output[IxDyn(&full_idx)] = avg;
474 },
475 );
476 }
477
478 Ok(output)
479}
480
481pub fn lp_pool(
483 input: &ArrayD<f64>,
484 config: &PoolConfig,
485 p: f64,
486) -> Result<ArrayD<f64>, PoolingError> {
487 config.validate()?;
488 let num_spatial = config.num_spatial_dims();
489 validate_input(input, num_spatial)?;
490
491 let input_shape = input.shape();
492 let out_shape = compute_output_shape(input_shape, config)?;
493 let spatial_offset = input_shape.len() - num_spatial;
494 let input_spatial: Vec<usize> = input_shape[spatial_offset..].to_vec();
495 let output_spatial: Vec<usize> = out_shape[spatial_offset..].to_vec();
496
497 let mut output = ArrayD::zeros(IxDyn(&out_shape));
498 let n_outer = num_outer_slices(input_shape, num_spatial);
499
500 for outer_flat in 0..n_outer {
501 let outer_idx = flat_to_outer_indices(outer_flat, input_shape, num_spatial);
502
503 for_each_window(
504 &input_spatial,
505 config,
506 &output_spatial,
507 |out_pos, positions| {
508 let count = positions.len();
509 let mut sum_pow = 0.0;
510 for (_, actual_pos) in &positions {
511 let val = get_spatial_value(input, &outer_idx, actual_pos, num_spatial);
512 sum_pow += val.abs().powf(p);
513 }
514 let result = if count > 0 {
515 (sum_pow / count as f64).powf(1.0 / p)
516 } else {
517 0.0
518 };
519 let mut full_idx: Vec<usize> = outer_idx.clone();
520 full_idx.extend_from_slice(out_pos);
521 output[IxDyn(&full_idx)] = result;
522 },
523 );
524 }
525
526 Ok(output)
527}
528
529pub fn global_max_pool(input: &ArrayD<f64>) -> Result<ArrayD<f64>, PoolingError> {
533 if input.is_empty() {
534 return Err(PoolingError::EmptyInput);
535 }
536 if input.ndim() < 3 {
537 return Err(PoolingError::InsufficientDimensions {
538 ndim: input.ndim(),
539 required: 3,
540 });
541 }
542
543 let shape = input.shape();
544 let batch = shape[0];
545 let channels = shape[1];
546 let num_spatial = input.ndim() - 2;
547 let spatial_size: usize = shape[2..].iter().product();
548
549 let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
550
551 for b in 0..batch {
552 for c in 0..channels {
553 let mut max_val = f64::NEG_INFINITY;
554 for s in 0..spatial_size {
556 let spatial_idx = flat_to_spatial_indices(s, &shape[2..]);
557 let mut full_idx = vec![b, c];
558 full_idx.extend_from_slice(&spatial_idx);
559 let val = input[IxDyn(&full_idx)];
560 if val > max_val {
561 max_val = val;
562 }
563 }
564 if max_val == f64::NEG_INFINITY {
565 max_val = 0.0;
566 }
567 output[IxDyn(&[b, c])] = max_val;
568 }
569 }
570 let _ = num_spatial;
572
573 Ok(output)
574}
575
576pub fn global_avg_pool(input: &ArrayD<f64>) -> Result<ArrayD<f64>, PoolingError> {
580 if input.is_empty() {
581 return Err(PoolingError::EmptyInput);
582 }
583 if input.ndim() < 3 {
584 return Err(PoolingError::InsufficientDimensions {
585 ndim: input.ndim(),
586 required: 3,
587 });
588 }
589
590 let shape = input.shape();
591 let batch = shape[0];
592 let channels = shape[1];
593 let spatial_size: usize = shape[2..].iter().product();
594
595 let mut output = ArrayD::zeros(IxDyn(&[batch, channels]));
596
597 for b in 0..batch {
598 for c in 0..channels {
599 let mut sum = 0.0;
600 for s in 0..spatial_size {
601 let spatial_idx = flat_to_spatial_indices(s, &shape[2..]);
602 let mut full_idx = vec![b, c];
603 full_idx.extend_from_slice(&spatial_idx);
604 sum += input[IxDyn(&full_idx)];
605 }
606 output[IxDyn(&[b, c])] = sum / spatial_size as f64;
607 }
608 }
609
610 Ok(output)
611}
612
/// Decodes a row-major flat offset into one coordinate per axis of `spatial_shape`.
///
/// The last axis varies fastest.
fn flat_to_spatial_indices(flat: usize, spatial_shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0usize; spatial_shape.len()];
    let mut rest = flat;
    for (coord, &extent) in coords.iter_mut().zip(spatial_shape).rev() {
        *coord = rest % extent;
        rest /= extent;
    }
    coords
}
622
623pub fn adaptive_avg_pool(
627 input: &ArrayD<f64>,
628 output_size: &[usize],
629) -> Result<ArrayD<f64>, PoolingError> {
630 if input.is_empty() {
631 return Err(PoolingError::EmptyInput);
632 }
633 let num_spatial = output_size.len();
634 if input.ndim() < num_spatial + 2 {
635 return Err(PoolingError::InsufficientDimensions {
636 ndim: input.ndim(),
637 required: num_spatial + 2,
638 });
639 }
640
641 let shape = input.shape();
642 let spatial_offset = shape.len() - num_spatial;
643 let input_spatial: Vec<usize> = shape[spatial_offset..].to_vec();
644
645 let mut out_shape: Vec<usize> = shape[..spatial_offset].to_vec();
647 out_shape.extend_from_slice(output_size);
648
649 let mut output = ArrayD::zeros(IxDyn(&out_shape));
650 let n_outer = num_outer_slices(shape, num_spatial);
651
652 for outer_flat in 0..n_outer {
653 let outer_idx = flat_to_outer_indices(outer_flat, shape, num_spatial);
654
655 let mut out_pos = vec![0usize; num_spatial];
657 loop {
658 let mut ranges: Vec<(usize, usize)> = Vec::with_capacity(num_spatial);
660 for d in 0..num_spatial {
661 let in_size = input_spatial[d];
662 let out_sz = output_size[d];
663 let start = (out_pos[d] * in_size) / out_sz;
664 let end = ((out_pos[d] + 1) * in_size) / out_sz;
665 ranges.push((start, end));
666 }
667
668 let mut sum = 0.0;
670 let mut count = 0usize;
671 let mut win_pos = vec![0usize; num_spatial];
672 for d in 0..num_spatial {
674 win_pos[d] = ranges[d].0;
675 }
676 loop {
677 let val = get_spatial_value(input, &outer_idx, &win_pos, num_spatial);
678 sum += val;
679 count += 1;
680
681 if !advance_within_ranges(&mut win_pos, &ranges) {
683 break;
684 }
685 }
686
687 let avg = if count > 0 { sum / count as f64 } else { 0.0 };
688 let mut full_idx: Vec<usize> = outer_idx.clone();
689 full_idx.extend_from_slice(&out_pos);
690 output[IxDyn(&full_idx)] = avg;
691
692 if !advance_indices(&mut out_pos, output_size) {
693 break;
694 }
695 }
696 }
697
698 Ok(output)
699}
700
/// Odometer step over the box `ranges[d].0 .. ranges[d].1` (half-open per axis).
///
/// The last axis moves fastest. An axis reaching its end wraps to its start and carries
/// into the previous axis. Returns `false` once the whole box has been visited; at that
/// point `indices` is back at the box's starting corner.
fn advance_within_ranges(indices: &mut [usize], ranges: &[(usize, usize)]) -> bool {
    for (axis, slot) in indices.iter_mut().enumerate().rev() {
        let (lo, hi) = ranges[axis];
        *slot += 1;
        if *slot < hi {
            return true;
        }
        *slot = lo;
    }
    false
}
712
713pub fn max_unpool(
718 pooled: &ArrayD<f64>,
719 indices: &ArrayD<i64>,
720 output_size: &[usize],
721) -> Result<ArrayD<f64>, PoolingError> {
722 if pooled.shape() != indices.shape() {
723 return Err(PoolingError::ShapeMismatch(format!(
724 "pooled shape {:?} != indices shape {:?}",
725 pooled.shape(),
726 indices.shape()
727 )));
728 }
729 if pooled.is_empty() {
730 return Err(PoolingError::EmptyInput);
731 }
732
733 let pooled_shape = pooled.shape();
734 if output_size.len() != pooled_shape.len() {
736 return Err(PoolingError::ShapeMismatch(format!(
737 "output_size len {} != pooled ndim {}",
738 output_size.len(),
739 pooled_shape.len()
740 )));
741 }
742
743 let num_spatial = pooled_shape.len().saturating_sub(2);
746 let spatial_offset = pooled_shape.len() - num_spatial;
747 let output_spatial: Vec<usize> = output_size[spatial_offset..].to_vec();
748
749 let mut output = ArrayD::zeros(IxDyn(output_size));
750 let n_outer = num_outer_slices(pooled_shape, num_spatial);
751
752 let output_spatial_total: usize = output_spatial.iter().product();
754
755 for outer_flat in 0..n_outer {
756 let outer_idx = flat_to_outer_indices(outer_flat, pooled_shape, num_spatial);
757
758 let pooled_spatial: Vec<usize> = pooled_shape[spatial_offset..].to_vec();
760 let mut pos = vec![0usize; num_spatial];
761 loop {
762 let mut pooled_full: Vec<usize> = outer_idx.clone();
763 pooled_full.extend_from_slice(&pos);
764 let val = pooled[IxDyn(&pooled_full)];
765 let idx = indices[IxDyn(&pooled_full)];
766
767 if idx >= 0 && (idx as usize) < output_spatial_total {
768 let spatial_pos = flat_to_spatial_indices(idx as usize, &output_spatial);
769 let mut out_full: Vec<usize> = outer_idx.clone();
770 out_full.extend_from_slice(&spatial_pos);
771 output[IxDyn(&out_full)] = val;
772 }
773
774 if !advance_indices(&mut pos, &pooled_spatial) {
775 break;
776 }
777 }
778 }
779
780 Ok(output)
781}
782
/// Descriptive statistics for applying a [`PoolConfig`] to a given input shape.
#[derive(Debug, Clone)]
pub struct PoolingStats {
    /// Shape of the input tensor.
    pub input_shape: Vec<usize>,
    /// Shape produced by pooling.
    pub output_shape: Vec<usize>,
    /// Kernel extent per spatial dimension.
    pub kernel_size: Vec<usize>,
    /// Effective stride per spatial dimension (missing entries resolved to the kernel extent).
    pub stride: Vec<usize>,
    /// Number of cells covered by one window (product of the kernel extents).
    pub receptive_field_size: usize,
    /// Spatial input cell count divided by spatial output cell count; infinite for an empty output.
    pub compression_ratio: f64,
    /// Mean over spatial dimensions of `max(0, (kernel - stride) / kernel)`.
    pub overlap_ratio: f64,
}
801
802impl PoolingStats {
803 pub fn compute(input_shape: &[usize], config: &PoolConfig) -> Result<Self, PoolingError> {
805 config.validate()?;
806 let num_spatial = config.num_spatial_dims();
807 if input_shape.len() < num_spatial + 2 {
808 return Err(PoolingError::InsufficientDimensions {
809 ndim: input_shape.len(),
810 required: num_spatial + 2,
811 });
812 }
813
814 let output_shape = compute_output_shape(input_shape, config)?;
815 let spatial_offset = input_shape.len() - num_spatial;
816
817 let input_spatial_size: usize = input_shape[spatial_offset..].iter().product();
818 let output_spatial_size: usize = output_shape[spatial_offset..].iter().product();
819
820 let receptive_field_size: usize = config.kernel_size.iter().product();
821
822 let compression_ratio = if output_spatial_size > 0 {
823 input_spatial_size as f64 / output_spatial_size as f64
824 } else {
825 f64::INFINITY
826 };
827
828 let mut overlap_sum = 0.0;
831 for d in 0..num_spatial {
832 let k = config.kernel_size.get(d).copied().unwrap_or(1) as f64;
833 let s = config.effective_stride(d) as f64;
834 let overlap = ((k - s) / k).max(0.0);
835 overlap_sum += overlap;
836 }
837 let overlap_ratio = if num_spatial > 0 {
838 overlap_sum / num_spatial as f64
839 } else {
840 0.0
841 };
842
843 let effective_stride: Vec<usize> = (0..num_spatial)
844 .map(|d| config.effective_stride(d))
845 .collect();
846
847 Ok(Self {
848 input_shape: input_shape.to_vec(),
849 output_shape,
850 kernel_size: config.kernel_size.clone(),
851 stride: effective_stride,
852 receptive_field_size,
853 compression_ratio,
854 overlap_ratio,
855 })
856 }
857
858 pub fn summary(&self) -> String {
860 format!(
861 "Pooling: {:?} -> {:?}, kernel={:?}, stride={:?}, \
862 receptive_field={}, compression={:.2}x, overlap={:.2}",
863 self.input_shape,
864 self.output_shape,
865 self.kernel_size,
866 self.stride,
867 self.receptive_field_size,
868 self.compression_ratio,
869 self.overlap_ratio,
870 )
871 }
872}
873
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    /// Wraps `data` as a `[1, 1, h, w]` tensor.
    fn make_4d(data: Vec<f64>, h: usize, w: usize) -> ArrayD<f64> {
        ArrayD::from_shape_vec(IxDyn(&[1, 1, h, w]), data)
            .expect("test tensor creation should succeed")
    }

    /// The values 1.0 ..= 16.0 in row-major order, i.e. a 4x4 grid.
    fn grid_1_to_16() -> Vec<f64> {
        (1..=16u8).map(f64::from).collect()
    }

    #[test]
    fn test_pool_config_output_size() {
        let config = PoolConfig::new(vec![2, 2]);
        for dim in 0..2 {
            assert_eq!(config.output_size(4, dim), 2);
        }
    }

    #[test]
    fn test_pool_config_output_size_with_padding() {
        let config = PoolConfig::new(vec![2, 2]).with_padding(vec![1, 1]);
        // (4 + 2 * 1 - 2) / 2 + 1 = 3
        assert_eq!(config.output_size(4, 0), 3);
    }

    #[test]
    fn test_pool_config_validate_valid() {
        assert!(PoolConfig::new(vec![2, 2]).validate().is_ok());
    }

    #[test]
    fn test_pool_config_validate_zero_kernel() {
        let result = PoolConfig::new(vec![0, 2]).validate();
        assert!(
            matches!(result, Err(PoolingError::InvalidKernelSize { size: 0 })),
            "Expected InvalidKernelSize, got {:?}",
            result
        );
    }

    #[test]
    fn test_max_pool_basic() {
        let input = make_4d(grid_1_to_16(), 4, 4);
        let config = PoolConfig::new(vec![2, 2]);
        let output = max_pool(&input, &config).expect("max_pool should succeed");

        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        let expected = [([0, 0], 6.0), ([0, 1], 8.0), ([1, 0], 14.0), ([1, 1], 16.0)];
        for &([h, w], value) in &expected {
            assert_eq!(output[IxDyn(&[0, 0, h, w])], value);
        }
    }

    #[test]
    fn test_max_pool_with_indices_correct() {
        let input = make_4d(grid_1_to_16(), 4, 4);
        let config = PoolConfig::new(vec![2, 2]);
        let (output, indices) =
            max_pool_with_indices(&input, &config).expect("max_pool_with_indices should succeed");

        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        // Each maximum is the bottom-right cell of its 2x2 window; the index is its
        // row-major offset in the 4x4 input plane.
        let expected = [
            ([0, 0], 6.0, 5),
            ([0, 1], 8.0, 7),
            ([1, 0], 14.0, 13),
            ([1, 1], 16.0, 15),
        ];
        for &([h, w], value, flat) in &expected {
            assert_eq!(output[IxDyn(&[0, 0, h, w])], value);
            assert_eq!(indices[IxDyn(&[0, 0, h, w])], flat);
        }
    }

    #[test]
    fn test_avg_pool_basic() {
        let input = make_4d(grid_1_to_16(), 4, 4);
        let config = PoolConfig::new(vec![2, 2]);
        let output = avg_pool(&input, &config).expect("avg_pool should succeed");

        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        let expected = [([0, 0], 3.5), ([0, 1], 5.5), ([1, 0], 11.5), ([1, 1], 13.5)];
        for &([h, w], value) in &expected {
            assert!((output[IxDyn(&[0, 0, h, w])] - value).abs() < 1e-10);
        }
    }

    #[test]
    fn test_avg_pool_padding() {
        let input = make_4d(vec![1.0; 16], 4, 4);
        let config = PoolConfig::new(vec![2, 2]).with_padding(vec![1, 1]);
        let output = avg_pool(&input, &config).expect("avg_pool with padding should succeed");

        assert_eq!(output.shape(), &[1, 1, 3, 3]);
    }

    #[test]
    fn test_lp_pool_p2() {
        let input = make_4d(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let config = PoolConfig::new(vec![2, 2]);
        let output = lp_pool(&input, &config, 2.0).expect("lp_pool p=2 should succeed");

        assert_eq!(output.shape(), &[1, 1, 1, 1]);
        // sqrt((1 + 4 + 9 + 16) / 4) = sqrt(7.5)
        let expected = (7.5_f64).sqrt();
        assert!((output[IxDyn(&[0, 0, 0, 0])] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_lp_pool_p1() {
        let input = make_4d(vec![1.0, -2.0, 3.0, -4.0], 2, 2);
        let config = PoolConfig::new(vec![2, 2]);
        let output = lp_pool(&input, &config, 1.0).expect("lp_pool p=1 should succeed");

        assert_eq!(output.shape(), &[1, 1, 1, 1]);
        // Mean absolute value: (1 + 2 + 3 + 4) / 4
        assert!((output[IxDyn(&[0, 0, 0, 0])] - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_global_max_pool_shape() {
        let input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        let output = global_max_pool(&input).expect("global_max_pool should succeed");
        assert_eq!(output.shape(), &[1, 3]);
    }

    #[test]
    fn test_global_max_pool_values() {
        let mut input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        input[IxDyn(&[0, 0, 2, 3])] = 42.0;
        input[IxDyn(&[0, 1, 0, 0])] = 99.0;
        input[IxDyn(&[0, 2, 3, 3])] = -1.0;

        let output = global_max_pool(&input).expect("global_max_pool should succeed");
        assert_eq!(output[IxDyn(&[0, 0])], 42.0);
        assert_eq!(output[IxDyn(&[0, 1])], 99.0);
        // Channel 2 still contains zeros, which beat the single -1.0.
        assert_eq!(output[IxDyn(&[0, 2])], 0.0);
    }

    #[test]
    fn test_global_avg_pool_shape() {
        let input = ArrayD::zeros(IxDyn(&[1, 3, 4, 4]));
        let output = global_avg_pool(&input).expect("global_avg_pool should succeed");
        assert_eq!(output.shape(), &[1, 3]);
    }

    #[test]
    fn test_global_avg_pool_values() {
        let mut input = ArrayD::ones(IxDyn(&[1, 2, 2, 2]));
        // Channel 1 is all 2.0; channel 0 stays all 1.0.
        for &(h, w) in &[(0, 0), (0, 1), (1, 0), (1, 1)] {
            input[IxDyn(&[0, 1, h, w])] = 2.0;
        }

        let output = global_avg_pool(&input).expect("global_avg_pool should succeed");
        assert!((output[IxDyn(&[0, 0])] - 1.0).abs() < 1e-10);
        assert!((output[IxDyn(&[0, 1])] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_adaptive_avg_pool_output_size() {
        let input = ArrayD::ones(IxDyn(&[1, 1, 4, 4]));
        let output = adaptive_avg_pool(&input, &[2, 2]).expect("adaptive_avg_pool should succeed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_adaptive_avg_pool_identity() {
        let data = grid_1_to_16();
        let input = make_4d(data.clone(), 4, 4);
        let output =
            adaptive_avg_pool(&input, &[4, 4]).expect("adaptive_avg_pool identity should succeed");
        assert_eq!(output.shape(), &[1, 1, 4, 4]);
        for (i, &v) in data.iter().enumerate() {
            let (h, w) = (i / 4, i % 4);
            assert!(
                (output[IxDyn(&[0, 0, h, w])] - v).abs() < 1e-10,
                "mismatch at ({}, {})",
                h,
                w
            );
        }
    }

    #[test]
    fn test_max_unpool_basic() {
        let input = make_4d(grid_1_to_16(), 4, 4);
        let config = PoolConfig::new(vec![2, 2]);

        let (pooled, indices) =
            max_pool_with_indices(&input, &config).expect("max_pool_with_indices should succeed");

        let unpooled =
            max_unpool(&pooled, &indices, &[1, 1, 4, 4]).expect("max_unpool should succeed");

        assert_eq!(unpooled.shape(), &[1, 1, 4, 4]);
        // Each maximum returns to its source cell.
        let restored = [([1, 1], 6.0), ([1, 3], 8.0), ([3, 1], 14.0), ([3, 3], 16.0)];
        for &([h, w], value) in &restored {
            assert_eq!(unpooled[IxDyn(&[0, 0, h, w])], value);
        }
        // Cells that were not a window maximum remain zero.
        assert_eq!(unpooled[IxDyn(&[0, 0, 0, 0])], 0.0);
        assert_eq!(unpooled[IxDyn(&[0, 0, 2, 2])], 0.0);
    }

    #[test]
    fn test_pooling_stats_compression() {
        let config = PoolConfig::new(vec![2, 2]);
        let stats =
            PoolingStats::compute(&[1, 1, 4, 4], &config).expect("stats compute should succeed");
        assert_eq!(stats.output_shape, vec![1, 1, 2, 2]);
        // 16 input cells -> 4 output cells
        assert!((stats.compression_ratio - 4.0).abs() < 1e-10);
        assert_eq!(stats.receptive_field_size, 4);
        // stride == kernel, so windows do not overlap
        assert!((stats.overlap_ratio - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_pooling_stats_summary() {
        let config = PoolConfig::new(vec![2, 2]);
        let stats =
            PoolingStats::compute(&[1, 1, 4, 4], &config).expect("stats compute should succeed");
        let summary = stats.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("Pooling"));
    }

    #[test]
    fn test_pooling_error_display() {
        let errors = [
            PoolingError::InvalidKernelSize { size: 0 },
            PoolingError::InvalidStride { stride: 0 },
            PoolingError::InvalidPadding {
                padding: 3,
                kernel_size: 2,
            },
            PoolingError::InsufficientDimensions {
                ndim: 2,
                required: 4,
            },
            PoolingError::EmptyInput,
            PoolingError::ShapeMismatch("test".to_string()),
        ];
        for err in &errors {
            let msg = format!("{err}");
            assert!(!msg.is_empty(), "Error display should not be empty");
        }
    }
}