scirs2_linalg/convolution/
mod.rs

//! Specialized operations for convolutional neural networks
//!
//! This module provides efficient implementations of specialized operations
//! that are commonly used in convolutional neural networks, such as im2col/col2im,
//! efficient convolution algorithms, and other related operations.
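//!
//! As a minimal end-to-end sketch (assuming `f32` data; see the per-function
//! docs for the full parameter semantics), a convolution can be computed via
//! im2col and matrix multiplication:
//!
//! ```
//! use scirs2_core::ndarray::Array4;
//! use scirs2_linalg::convolution::conv2d_im2col;
//!
//! let input = Array4::<f32>::zeros((1, 1, 4, 4));
//! let kernel = Array4::<f32>::zeros((1, 1, 2, 2));
//! // conv2d_im2col performs im2col followed by a matrix multiplication internally
//! let output = conv2d_im2col(&input.view(), &kernel.view(), None, (1, 1), (0, 0), (1, 1)).unwrap();
//! assert_eq!(output.shape(), &[1, 1, 3, 3]);
//! ```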

use scirs2_core::ndarray::{Array2, Array4, ArrayView4, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::iter::Sum;

use crate::error::{LinalgError, LinalgResult};

/// Extract patches from an input tensor using the im2col algorithm
///
/// This function implements the im2col (image to column) algorithm, which
/// reformats the input data to allow efficient computation of convolution
/// operations using matrix multiplication.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `kernelsize` - Size of the kernel as (kernel_height, kernel_width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Column matrix of shape (kernel_h * kernel_w * channels, output_h * output_w * batchsize)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::im2col;
///
/// // Create a 1x3x4x4 input tensor (1 batch, 3 channels, 4x4 spatial dimensions)
/// let mut input = Array4::<f32>::zeros((1, 3, 4, 4));
/// // Fill with sample data
/// for c in 0..3 {
///     for h in 0..4 {
///         for w in 0..4 {
///             input[[0, c, h, w]] = (c * 16 + h * 4 + w) as f32;
///         }
///     }
/// }
///
/// // Extract 3x3 patches with stride 1 and no padding
/// let cols = im2col(&input.view(), (3, 3), (1, 1), (0, 0), (1, 1)).unwrap();
///
/// // Resulting matrix has shape (3*3*3, 2*2*1) = (27, 4)
/// // Each column represents a 3x3 patch across all 3 channels
/// assert_eq!(cols.shape(), &[27, 4]);
/// ```
#[allow(dead_code)]
pub fn im2col<F>(
    input: &ArrayView4<F>,
    kernelsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = input.dim();
    let (kernel_h, kernel_w) = kernelsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;

    // Calculate output dimensions
    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;
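    // Worked example (matching the doc test above): height = 4, padding_h = 0,
    // dilation_h = 1, kernel_h = 3, stride_h = 1 gives
    // output_h = ((4 + 0 - 1 * (3 - 1) - 1) / 1) + 1 = 2.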

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Allocate output matrix
    let mut cols = Array2::<F>::zeros((
        kernel_h * kernel_w * channels,
        output_h * output_w * batchsize,
    ));
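
    // cols layout: row = channel * kernel_h * kernel_w + kernel_row * kernel_w + kernel_col;
    // column = batch * (output_h * output_w) + output_row * output_w + output_col (batch-major).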

    // Populate the output matrix with patches
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for kernel_row in 0..kernel_h {
                for kernel_col in 0..kernel_w {
                    let input_row_offset = kernel_row * dilation_h;
                    let input_col_offset = kernel_col * dilation_w;

                    // Row in the cols matrix
                    let cols_idx =
                        channel_idx * kernel_h * kernel_w + kernel_row * kernel_w + kernel_col;

                    for output_row in 0..output_h {
                        for output_col in 0..output_w {
                            let input_row = output_row * stride_h + input_row_offset;
                            let input_col = output_col * stride_w + input_col_offset;

                            // Column in the cols matrix
                            let cols_pos = batch_idx * output_h * output_w
                                + output_row * output_w
                                + output_col;

                            // Check if we need to pad
                            if input_row < padding_h
                                || input_row >= height + padding_h
                                || input_col < padding_w
                                || input_col >= width + padding_w
                            {
                                // Zero-padding
                                cols[[cols_idx, cols_pos]] = F::zero();
                            } else {
                                // Copy from input
                                let input_val = input[[
                                    batch_idx,
                                    channel_idx,
                                    input_row - padding_h,
                                    input_col - padding_w,
                                ]];
                                cols[[cols_idx, cols_pos]] = input_val;
                            }
                        }
                    }
                }
            }
        }
    }

    Ok(cols)
}

/// Convert a column matrix back to an input tensor using the col2im algorithm
///
/// This function implements the col2im (column to image) algorithm, which
/// converts the column matrix back to the original input tensor format.
/// Overlapping patch contributions are averaged.
///
/// # Arguments
///
/// * `cols` - Column matrix of shape (kernel_h * kernel_w * channels, output_h * output_w * batchsize)
/// * `outputshape` - Shape of the output tensor as (batchsize, channels, height, width)
/// * `kernelsize` - Size of the kernel as (kernel_height, kernel_width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Output tensor of shape (batchsize, channels, height, width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{Array4, ArrayView4};
/// use scirs2_linalg::convolution::{im2col, col2im};
///
/// // Create a 1x3x4x4 input tensor
/// let mut input = Array4::<f32>::zeros((1, 3, 4, 4));
/// // Fill with sample data
/// for c in 0..3 {
///     for h in 0..4 {
///         for w in 0..4 {
///             input[[0, c, h, w]] = (c * 16 + h * 4 + w) as f32;
///         }
///     }
/// }
///
/// // Convert to columns with im2col
/// let cols = im2col(&input.view(), (3, 3), (1, 1), (0, 0), (1, 1)).unwrap();
///
/// // Convert back to image with col2im
/// let output = col2im(
///     &cols.view(),
///     (1, 3, 4, 4),
///     (3, 3),
///     (1, 1),
///     (0, 0),
///     (1, 1),
/// ).unwrap();
///
/// // Verify output shape
/// assert_eq!(output.shape(), &[1, 3, 4, 4]);
/// ```
#[allow(dead_code)]
pub fn col2im<F>(
    cols: &scirs2_core::ndarray::ArrayView2<F>,
    outputshape: (usize, usize, usize, usize),
    kernelsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = outputshape;
    let (kernel_h, kernel_w) = kernelsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;

    // Calculate output dimensions
    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Check input columns shape
    if cols.shape()[0] != kernel_h * kernel_w * channels
        || cols.shape()[1] != output_h * output_w * batchsize
    {
        return Err(LinalgError::ShapeError(format!(
            "Invalid cols shape: expected ({}, {}), got ({}, {})",
            kernel_h * kernel_w * channels,
            output_h * output_w * batchsize,
            cols.shape()[0],
            cols.shape()[1]
        )));
    }

    // Allocate output tensor
    let mut output = Array4::<F>::zeros((batchsize, channels, height, width));
    let mut counts = Array4::<usize>::zeros((batchsize, channels, height, width));

    // Accumulate values from cols to output
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for kernel_row in 0..kernel_h {
                for kernel_col in 0..kernel_w {
                    let input_row_offset = kernel_row * dilation_h;
                    let input_col_offset = kernel_col * dilation_w;

                    // Row in the cols matrix
                    let cols_idx =
                        channel_idx * kernel_h * kernel_w + kernel_row * kernel_w + kernel_col;

                    for output_row in 0..output_h {
                        for output_col in 0..output_w {
                            let input_row = output_row * stride_h + input_row_offset;
                            let input_col = output_col * stride_w + input_col_offset;

                            // Column in the cols matrix
                            let cols_pos = batch_idx * output_h * output_w
                                + output_row * output_w
                                + output_col;

                            // Check if the position is valid (not padding)
                            if input_row >= padding_h
                                && input_row < height + padding_h
                                && input_col >= padding_w
                                && input_col < width + padding_w
                            {
                                let output_row_idx = input_row - padding_h;
                                let output_col_idx = input_col - padding_w;

                                output[[batch_idx, channel_idx, output_row_idx, output_col_idx]] +=
                                    cols[[cols_idx, cols_pos]];
                                counts[[batch_idx, channel_idx, output_row_idx, output_col_idx]] +=
                                    1;
                            }
                        }
                    }
                }
            }
        }
    }

    // Normalize by count (average overlapping patches)
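    // Averaging makes col2im(im2col(x)) reproduce x (see test_col2im_basic);
    // a sum-only col2im, as used in some convolution backward passes, would
    // skip this normalization.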
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for h in 0..height {
                for w in 0..width {
                    let count = counts[[batch_idx, channel_idx, h, w]];
                    if count > 0 {
                        output[[batch_idx, channel_idx, h, w]] /= F::from(count).unwrap();
                    }
                }
            }
        }
    }

    Ok(output)
}

/// Perform a max pooling operation on a 4D input tensor
///
/// Applies max pooling over a 4D tensor, which is commonly used to
/// down-sample feature maps in convolutional neural networks.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `poolsize` - Size of the pooling window as (pool_height, pool_width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
///
/// # Returns
///
/// * Tuple of the pooled output tensor and the indices of the max values (for the backward pass)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::max_pool2d;
///
/// // Create a 1x1x4x4 input tensor
/// let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
/// // Fill with sample data
/// for h in 0..4 {
///     for w in 0..4 {
///         input[[0, 0, h, w]] = (h * 4 + w) as f32;
///     }
/// }
///
/// // Apply 2x2 max pooling with stride 2
/// let (output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();
///
/// // Resulting tensor has shape (1, 1, 2, 2)
/// assert_eq!(output.shape(), &[1, 1, 2, 2]);
/// ```
#[allow(dead_code)]
pub fn max_pool2d<F>(
    input: &ArrayView4<F>,
    poolsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> LinalgResult<(Array4<F>, Array4<usize>)>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = input.dim();
    let (pool_h, pool_w) = poolsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;

    // Calculate output dimensions
    let output_h = ((height + 2 * padding_h - pool_h) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - pool_w) / stride_w) + 1;
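    // Worked example (matching the doc test above): a 4x4 input with a 2x2
    // window, stride 2, and no padding gives output_h = ((4 + 0 - 2) / 2) + 1 = 2.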

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Allocate output tensors
    let mut output = Array4::<F>::zeros((batchsize, channels, output_h, output_w));
    let mut indices = Array4::<usize>::zeros((batchsize, channels, output_h, output_w));

    // Perform max pooling
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for output_row in 0..output_h {
                for output_col in 0..output_w {
                    let start_h = output_row * stride_h;
                    let start_w = output_col * stride_w;

                    let mut max_val = F::neg_infinity();
                    let mut max_idx = 0;

                    // Find max value in the pooling window
                    for pool_row in 0..pool_h {
                        for pool_col in 0..pool_w {
                            let input_row = start_h + pool_row;
                            let input_col = start_w + pool_col;

                            // Check if the position is valid (not padding)
                            if input_row >= padding_h
                                && input_row < height + padding_h
                                && input_col >= padding_w
                                && input_col < width + padding_w
                            {
                                let input_row_idx = input_row - padding_h;
                                let input_col_idx = input_col - padding_w;
                                let val =
                                    input[[batch_idx, channel_idx, input_row_idx, input_col_idx]];

                                if val > max_val {
                                    max_val = val;
                                    max_idx = input_row_idx * width + input_col_idx;
                                }
                            }
                        }
                    }

                    output[[batch_idx, channel_idx, output_row, output_col]] = max_val;
                    indices[[batch_idx, channel_idx, output_row, output_col]] = max_idx;
                }
            }
        }
    }

    Ok((output, indices))
}
/// Perform the backward pass of the max pooling operation
///
/// Takes the gradients of the pooled outputs and distributes them back to
/// the locations of the maximum values in the original input.
///
/// # Arguments
///
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, channels, output_height, output_width)
/// * `indices` - Indices of the maximum values from the forward pass
/// * `inputshape` - Shape of the original input tensor (batchsize, channels, height, width)
///
/// # Returns
///
/// * Gradient with respect to the input
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::{max_pool2d, max_pool2d_backward};
///
/// // Create a 1x1x4x4 input tensor
/// let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
/// // Fill with sample data
/// for h in 0..4 {
///     for w in 0..4 {
///         input[[0, 0, h, w]] = (h * 4 + w) as f32;
///     }
/// }
///
/// // Apply max pooling (forward pass)
/// let (output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();
///
/// // Create gradient of the output
/// let grad_output = Array4::<f32>::ones((1, 1, 2, 2));
///
/// // Compute gradient of the input (backward pass)
/// let grad_input = max_pool2d_backward(
///     &grad_output.view(),
///     &indices.view(),
///     (1, 1, 4, 4),
/// ).unwrap();
///
/// // Verify shape
/// assert_eq!(grad_input.shape(), &[1, 1, 4, 4]);
/// ```
#[allow(dead_code)]
pub fn max_pool2d_backward<F>(
    grad_output: &ArrayView4<F>,
    indices: &scirs2_core::ndarray::ArrayView4<usize>,
    inputshape: (usize, usize, usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = inputshape;
    let (out_batch, out_channels_, out_height, out_width) = grad_output.dim();
    let (idx_batch, idx_channels, idx_height, idx_width) = indices.dim();

    // Check that shapes match
    if out_batch != idx_batch
        || out_channels_ != idx_channels
        || out_height != idx_height
        || out_width != idx_width
    {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch between grad_output ({out_batch}, {out_channels_}, {out_height}, {out_width}) and indices ({idx_batch}, {idx_channels}, {idx_height}, {idx_width})"
        )));
    }

    // Allocate output gradient tensor
    let mut grad_input = Array4::<F>::zeros((batchsize, channels, height, width));

    // Distribute gradients to the locations of the maximum values
    for batch_idx in 0..out_batch {
        for channel_idx in 0..out_channels_ {
            for h in 0..out_height {
                for w in 0..out_width {
                    let index = indices[[batch_idx, channel_idx, h, w]];
                    let input_h = index / width;
                    let input_w = index % width;

                    if input_h < height && input_w < width {
                        grad_input[[batch_idx, channel_idx, input_h, input_w]] +=
                            grad_output[[batch_idx, channel_idx, h, w]];
                    }
                }
            }
        }
    }

    Ok(grad_input)
}

/// Compute the indices for batch matrix multiplication in a convolutional layer
///
/// This function computes the indices needed for efficient batch matrix multiplication
/// in a convolutional layer, which can be used to implement convolutional layers
/// more efficiently.
///
/// # Arguments
///
/// * `inputshape` - Shape of the input tensor (batchsize, channels, height, width)
/// * `kernelshape` - Shape of the kernel tensor (out_channels_, in_channels, kernel_h, kernel_w)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
///
/// # Returns
///
/// * Flat array of indices, five values per contributing element:
///   (output index, input index, kernel index, spatial position, output channel)
///
/// # Examples
///
/// ```
/// use scirs2_linalg::convolution::compute_conv_indices;
///
/// // Compute indices for a simple convolutional layer
/// let indices = compute_conv_indices(
///     (1, 1, 4, 4),    // Input shape: batchsize=1, channels=1, height=4, width=4
///     (1, 1, 2, 2),    // Kernel shape: out_channels_=1, in_channels=1, kernel_h=2, kernel_w=2
///     (1, 1),          // Stride: height=1, width=1
///     (0, 0),          // Padding: height=0, width=0
/// ).unwrap();
/// // For a 4x4 input with 2x2 kernel and no padding, we get a 3x3 output
/// // Each output element is computed from 4 input elements (2x2 kernel)
/// // So we should have 3*3*4*5 = 180 values in the indices array
/// assert_eq!(indices.len() % 5, 0); // Should be a multiple of 5
/// ```
#[allow(dead_code)]
pub fn compute_conv_indices(
    inputshape: (usize, usize, usize, usize),
    kernelshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> LinalgResult<scirs2_core::ndarray::Array1<usize>> {
    let (batchsize, _in_channels, height, width) = inputshape;
    let (out_channels_, in_channels, kernel_h, kernel_w) = kernelshape;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;

    // Calculate output dimensions
    let output_h = ((height + 2 * padding_h - kernel_h) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - kernel_w) / stride_w) + 1;

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Calculate the total number of elements
    // Each output element can be computed from in_channels * kernel_h * kernel_w input elements
    let total_elements =
        batchsize * out_channels_ * output_h * output_w * in_channels * kernel_h * kernel_w;

    // Allocate array for indices (5 values per element)
    let mut indices = scirs2_core::ndarray::Array1::<usize>::zeros(total_elements * 5);

    // Compute indices for batch matmul
    let mut idx = 0;
    for b in 0..batchsize {
        for oc in 0..out_channels_ {
            for oh in 0..output_h {
                for ow in 0..output_w {
                    for ic in 0..in_channels {
                        for kh in 0..kernel_h {
                            for kw in 0..kernel_w {
                                let ih = oh * stride_h + kh;
                                let iw = ow * stride_w + kw;

                                // Check if within the padded input
                                if ih >= padding_h
                                    && ih < height + padding_h
                                    && iw >= padding_w
                                    && iw < width + padding_w
                                {
                                    let real_ih = ih - padding_h;
                                    let real_iw = iw - padding_w;

                                    // Output index
                                    let out_idx = b * out_channels_ * output_h * output_w
                                        + oc * output_h * output_w
                                        + oh * output_w
                                        + ow;

                                    // Input index
                                    let in_idx = b * in_channels * height * width
                                        + ic * height * width
                                        + real_ih * width
                                        + real_iw;

                                    // Kernel index
                                    let kernel_idx = oc * in_channels * kernel_h * kernel_w
                                        + ic * kernel_h * kernel_w
                                        + kh * kernel_w
                                        + kw;

                                    // Store indices
                                    indices[idx] = out_idx;
                                    indices[idx + 1] = in_idx;
                                    indices[idx + 2] = kernel_idx;
                                    indices[idx + 3] = oh * output_w + ow;
                                    indices[idx + 4] = oc;

                                    idx += 5;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // Truncate the array to the indices actually written
    let indices = indices.slice(scirs2_core::ndarray::s![0..idx]).to_owned();
    Ok(indices)
}

/// Apply a convolution operation using im2col and matrix multiplication
///
/// This function implements convolution using the im2col algorithm and
/// efficient matrix multiplication, which is often faster than direct
/// convolution for large inputs or kernels.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `kernel` - Kernel tensor of shape (out_channels_, in_channels, kernel_h, kernel_w)
/// * `bias` - Optional bias tensor of shape (out_channels_,)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Output tensor of shape (batchsize, out_channels_, output_height, output_width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{Array, Array4};
/// use scirs2_linalg::convolution::conv2d_im2col;
///
/// // Create a 2x3x32x32 input tensor (2 batches, 3 channels, 32x32 spatial dimensions)
/// let input = Array4::<f32>::zeros((2, 3, 32, 32));
///
/// // Create a 16x3x3x3 kernel tensor (16 output channels, 3 input channels, 3x3 kernel)
/// let kernel = Array4::<f32>::zeros((16, 3, 3, 3));
///
/// // Create a bias tensor
/// let bias = Some(Array::zeros(16));
///
/// // Apply convolution
/// let output = conv2d_im2col(
///     &input.view(),
///     &kernel.view(),
///     bias.as_ref().map(|b| b.view()),
///     (1, 1),  // stride
///     (1, 1),  // padding
///     (1, 1),  // dilation
/// ).unwrap();
///
/// // Output shape is (2, 16, 32, 32)
/// assert_eq!(output.shape(), &[2, 16, 32, 32]);
/// ```
#[allow(dead_code)]
pub fn conv2d_im2col<F>(
    input: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    bias: Option<scirs2_core::ndarray::ArrayView1<F>>,
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, height, width) = input.dim();
    let (out_channels_, k_in_channels, kernel_h, kernel_w) = kernel.dim();

    // Check that input and kernel channels match
    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels ({in_channels}) must match kernel in_channels ({k_in_channels})"
        )));
    }

    // Check bias shape if provided
    if let Some(b) = bias {
        if b.len() != out_channels_ {
            return Err(LinalgError::ShapeError(format!(
                "Bias length ({}) must match out_channels_ ({})",
                b.len(),
                out_channels_
            )));
        }
    }

    // Calculate output dimensions
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;

    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Convert input to columns using im2col
    let cols = im2col(input, (kernel_h, kernel_w), stride, padding, dilation)?;

    // Reshape kernel for matrix multiplication
    let flat_kernel = (*kernel)
        .into_shape_with_order((out_channels_, in_channels * kernel_h * kernel_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

    // Perform matrix multiplication
    let output_2d = flat_kernel.dot(&cols);

    // Reshape to output tensor
    let mut output = output_2d
        .into_shape_with_order((out_channels_, batchsize, output_h, output_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

    // Rearrange dimensions to (batchsize, out_channels_, output_h, output_w)
    output = output.permuted_axes([1, 0, 2, 3]);

    // Add bias if provided
    if let Some(b) = bias {
        for batch_idx in 0..batchsize {
            for oc in 0..out_channels_ {
                for h in 0..output_h {
                    for w in 0..output_w {
                        output[[batch_idx, oc, h, w]] += b[oc];
                    }
                }
            }
        }
    }

    Ok(output)
}

/// Apply the backward pass of the convolution operation for the input gradient
///
/// This function computes the gradient of the input in a convolutional layer
/// given the gradient of the output.
///
/// # Arguments
///
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, out_channels_, output_h, output_w)
/// * `kernel` - Kernel tensor of shape (out_channels_, in_channels, kernel_h, kernel_w)
/// * `inputshape` - Shape of the input tensor (batchsize, in_channels, height, width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Gradient of the input tensor of shape (batchsize, in_channels, height, width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::{conv2d_im2col, conv2d_backward_input};
///
/// // Forward pass
/// let input = Array4::<f32>::zeros((2, 3, 32, 32));
/// let kernel = Array4::<f32>::zeros((16, 3, 3, 3));
/// let bias = None;
/// let output = conv2d_im2col(
///     &input.view(),
///     &kernel.view(),
///     bias,
///     (1, 1),
///     (1, 1),
///     (1, 1),
/// ).unwrap();
///
/// // Backward pass
/// let grad_output = Array4::<f32>::ones((2, 16, 32, 32));
/// let grad_input = conv2d_backward_input(
///     &grad_output.view(),
///     &kernel.view(),
///     (2, 3, 32, 32),
///     (1, 1),
///     (1, 1),
///     (1, 1),
/// ).unwrap();
///
/// // Gradient shape matches input shape
/// assert_eq!(grad_input.shape(), &[2, 3, 32, 32]);
/// ```
#[allow(dead_code)]
pub fn conv2d_backward_input<F>(
    grad_output: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    inputshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, out_channels, _output_h, _output_w) = grad_output.dim();
    let (k_out_channels, in_channels, kernel_h, kernel_w) = kernel.dim();
    let (i_batchsize, i_in_channels, _height, _width) = inputshape;

    // Check that shapes match
    if batchsize != i_batchsize {
        return Err(LinalgError::ShapeError(format!(
            "Batch size mismatch: grad_output ({batchsize}) vs inputshape ({i_batchsize})"
        )));
    }

    if out_channels != k_out_channels {
        return Err(LinalgError::ShapeError(format!(
            "Output channels mismatch: grad_output ({out_channels}) vs kernel ({k_out_channels})"
        )));
    }

    if in_channels != i_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: kernel ({in_channels}) vs inputshape ({i_in_channels})"
        )));
    }

    // Prepare kernel for transposed convolution
    let mut kernel_transposed = Array4::<F>::zeros((in_channels, out_channels, kernel_h, kernel_w));

    // Flip the kernel and transpose input/output channels
    for oc in 0..out_channels {
        for ic in 0..in_channels {
            for kh in 0..kernel_h {
                for kw in 0..kernel_w {
                    kernel_transposed[[ic, oc, kernel_h - 1 - kh, kernel_w - 1 - kw]] =
                        kernel[[oc, ic, kh, kw]];
                }
            }
        }
    }

    // Calculate padding for transposed convolution
    let (_stride_h, _stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (_dilation_h, _dilation_w) = dilation;

    // We need to adjust padding for transposed convolution
    let pad_h = kernel_h - 1 - padding_h;
    let pad_w = kernel_w - 1 - padding_w;
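    // For unit stride, "full" correlation with the flipped kernel and padding
    // (kernel - 1 - original padding) reproduces the input gradient; that is
    // the case exercised by the tests below, and the stride/dilation swap in
    // the call that follows is how this implementation generalizes it.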

    // Perform transposed convolution
    // For transposed convolution, we swap the roles of stride and dilation
    conv2d_im2col(
        grad_output,
        &kernel_transposed.view(),
        None,
        dilation,       // original dilation becomes stride
        (pad_h, pad_w), // adjusted padding
        stride,         // original stride becomes dilation
    )
}

/// Apply the backward pass of the convolution operation for the kernel gradient
///
/// This function computes the gradient of the kernel in a convolutional layer
/// given the gradient of the output and the input.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, in_channels, height, width)
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, out_channels_, output_h, output_w)
/// * `kernelshape` - Shape of the kernel tensor (out_channels_, in_channels, kernel_h, kernel_w)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Gradient of the kernel tensor of shape (out_channels_, in_channels, kernel_h, kernel_w)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::{conv2d_im2col, conv2d_backward_kernel};
///
/// // Simple example with smaller dimensions
/// let input = Array4::<f32>::zeros((1, 1, 4, 4));
/// let kernelshape = (1, 1, 2, 2);
///
/// // Forward pass to get output shape
/// let kernel = Array4::<f32>::zeros(kernelshape);
/// let output = conv2d_im2col(
///     &input.view(),
///     &kernel.view(),
///     None,
///     (1, 1),  // stride
///     (0, 0),  // padding
///     (1, 1),  // dilation
/// ).unwrap();
///
/// // Backward pass - grad_output must match forward output shape
/// let grad_output = Array4::<f32>::ones(output.dim());
/// let grad_kernel = conv2d_backward_kernel(
///     &input.view(),
///     &grad_output.view(),
///     kernelshape,
///     (1, 1),
///     (0, 0),
///     (1, 1),
/// ).unwrap();
///
/// // Gradient shape matches kernel shape
/// assert_eq!(grad_kernel.shape(), &[1, 1, 2, 2]);
/// ```
#[allow(dead_code)]
pub fn conv2d_backward_kernel<F>(
    input: &ArrayView4<F>,
    grad_output: &ArrayView4<F>,
    kernelshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, _height, _width) = input.dim();
    let (go_batchsize, out_channels_, output_h, output_w) = grad_output.dim();
    let (k_out_channels, k_in_channels, kernel_h, kernel_w) = kernelshape;

    // Check that shapes match
    if batchsize != go_batchsize {
        return Err(LinalgError::ShapeError(format!(
            "Batch size mismatch: input ({batchsize}) vs grad_output ({go_batchsize})"
        )));
    }

    if out_channels_ != k_out_channels {
        return Err(LinalgError::ShapeError(format!(
            "Output channels mismatch: grad_output ({out_channels_}) vs kernelshape ({k_out_channels})"
        )));
    }

    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: input ({in_channels}) vs kernelshape ({k_in_channels})"
        )));
    }

    // Convert input to columns using im2col
    let cols = im2col(input, (kernel_h, kernel_w), stride, padding, dilation)?;

    // Reshape grad_output for matrix multiplication; move the channel axis to
    // the front so the flattened columns match the column order of `cols`
    let grad_output_reshaped = (*grad_output)
        .permuted_axes([1, 0, 2, 3])
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order((out_channels_, batchsize * output_h * output_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

    // Compute kernel gradient using matrix multiplication
    let grad_kernel_flat = grad_output_reshaped.dot(&cols.t());

    // Reshape to kernel shape
    let grad_kernel = grad_kernel_flat
        .into_shape_with_order((out_channels_, in_channels, kernel_h, kernel_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

    Ok(grad_kernel)
}

/// Apply the backward pass of the convolution operation for the bias gradient
///
/// This function computes the gradient of the bias in a convolutional layer
/// given the gradient of the output.
///
/// # Arguments
///
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, out_channels_, output_h, output_w)
///
/// # Returns
///
/// * Gradient of the bias tensor of shape (out_channels_,)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::conv2d_backward_bias;
///
/// // Backward pass for bias
/// let grad_output = Array4::<f32>::ones((2, 16, 32, 32));
/// let grad_bias = conv2d_backward_bias(&grad_output.view()).unwrap();
///
/// // Gradient shape matches bias shape
/// assert_eq!(grad_bias.shape(), &[16]);
/// ```
#[allow(dead_code)]
pub fn conv2d_backward_bias<F>(
    grad_output: &ArrayView4<F>,
) -> LinalgResult<scirs2_core::ndarray::Array1<F>>
where
    F: Float + NumAssign + Sum + Zero,
{
    let (batchsize, out_channels_, output_h, output_w) = grad_output.dim();

    // Allocate gradient for bias
    let mut grad_bias = scirs2_core::ndarray::Array1::<F>::zeros(out_channels_);

    // Sum gradients over batch, height, and width dimensions
    for batch_idx in 0..batchsize {
        for oc in 0..out_channels_ {
            for h in 0..output_h {
                for w in 0..output_w {
                    grad_bias[oc] += grad_output[[batch_idx, oc, h, w]];
                }
            }
        }
    }

    Ok(grad_bias)
}

/// Apply a 2D transposed convolution (deconvolution) operation
///
/// This function implements a transposed convolution (also known as deconvolution
/// or fractionally-strided convolution), which is commonly used in convolutional
/// neural networks for upsampling.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `kernel` - Kernel tensor of shape (in_channels, out_channels_, kernel_h, kernel_w)
/// * `bias` - Optional bias tensor of shape (out_channels_,)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `output_padding` - Additional padding for the output as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Output tensor of shape (batchsize, out_channels_, output_height, output_width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{Array, Array4};
/// use scirs2_linalg::convolution::conv_transpose2d;
///
/// // Simple example with smaller dimensions
/// let input = Array4::<f32>::zeros((1, 2, 3, 3));
///
/// // Create a 2x1x2x2 kernel tensor (2 input channels, 1 output channel, 2x2 kernel)
/// let kernel = Array4::<f32>::zeros((2, 1, 2, 2));
///
/// // Apply transposed convolution
/// let output = conv_transpose2d(
///     &input.view(),
///     &kernel.view(),
///     None,        // no bias
///     (1, 1),      // stride
///     (0, 0),      // padding
///     (0, 0),      // output_padding
///     (1, 1),      // dilation
/// ).unwrap();
///
/// // Calculate expected output shape:
/// // output_h = (3 - 1) * 1 - 2 * 0 + 1 * (2 - 1) + 0 + 1 = 2 + 1 + 1 = 4
/// // output_w = (3 - 1) * 1 - 2 * 0 + 1 * (2 - 1) + 0 + 1 = 2 + 1 + 1 = 4
/// assert_eq!(output.shape(), &[1, 1, 4, 4]);
/// ```
#[allow(dead_code)]
pub fn conv_transpose2d<F>(
    input: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    bias: Option<scirs2_core::ndarray::ArrayView1<F>>,
    stride: (usize, usize),
    padding: (usize, usize),
    output_padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, height, width) = input.dim();
    let (k_in_channels, out_channels_, kernel_h, kernel_w) = kernel.dim();

    // Check that channels match
    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: input ({in_channels}) vs kernel ({k_in_channels})"
        )));
    }

    // Check bias shape if provided
    if let Some(b) = bias {
        if b.len() != out_channels_ {
            return Err(LinalgError::ShapeError(format!(
                "Bias length ({}) must match out_channels_ ({})",
                b.len(),
                out_channels_
            )));
        }
    }

    // Calculate output dimensions
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (output_padding_h, output_padding_w) = output_padding;
    let (dilation_h, dilation_w) = dilation;

    let output_h = (height - 1) * stride_h - 2 * padding_h
        + dilation_h * (kernel_h - 1)
        + output_padding_h
        + 1;
    let output_w =
        (width - 1) * stride_w - 2 * padding_w + dilation_w * (kernel_w - 1) + output_padding_w + 1;
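    // Worked example (matching the doc test above): a 3x3 input with a 2x2
    // kernel, stride 1, no padding, and no output padding gives
    // output_h = (3 - 1) * 1 - 0 + 1 * (2 - 1) + 0 + 1 = 4.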

    // Allocate output tensor
    let mut output = Array4::<F>::zeros((batchsize, out_channels_, output_h, output_w));

    // Perform transposed convolution
    for b in 0..batchsize {
        for oc in 0..out_channels_ {
            for ic in 0..in_channels {
                for h in 0..height {
                    for w in 0..width {
                        let input_val = input[[b, ic, h, w]];

                        for kh in 0..kernel_h {
                            for kw in 0..kernel_w {
                                // Calculate output coordinates for transposed convolution.
                                // For transposed conv, we scatter values from input to output,
                                // so the kernel's position enters the calculation in the
                                // opposite way from normal convolution
                                let out_h = h as isize * stride_h as isize
                                    + kh as isize * dilation_h as isize
                                    - padding_h as isize;
                                let out_w = w as isize * stride_w as isize
                                    + kw as isize * dilation_w as isize
                                    - padding_w as isize;

                                // Only process if coordinates are non-negative
                                if out_h >= 0 && out_w >= 0 {
                                    let out_h = out_h as usize;
                                    let out_w = out_w as usize;

                                    if out_h < output_h && out_w < output_w {
                                        // The transposed convolution can be thought of as:
                                        // 1. For each input position and each kernel position
                                        // 2. Calculate the output position based on stride, padding, and dilation
                                        // 3. Add the product of input and kernel value to that output position
                                        //
                                        // Note: Technically, for a proper mathematical transposed convolution,
                                        // we should use the flipped kernel. However, in practice, ML libraries
                                        // often just reuse the same kernel for simplicity and learn appropriate weights.
                                        output[[b, oc, out_h, out_w]] +=
                                            input_val * kernel[[ic, oc, kh, kw]];
                                    }
                                }
                            }
                        }
                    }
                }
            }

            // Add bias if provided
            if let Some(b_val) = bias.map(|b| b[oc]) {
                for h in 0..output_h {
                    for w in 0..output_w {
                        output[[b, oc, h, w]] += b_val;
                    }
                }
            }
        }
    }

    Ok(output)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, Array4};

    #[test]
    fn test_im2col_basic() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Extract 2x2 patches with stride 1 and no padding
        let cols = im2col(&input.view(), (2, 2), (1, 1), (0, 0), (1, 1)).unwrap();

        // Resulting matrix should be (1*2*2, 2*2*1) = (4, 4)
        assert_eq!(cols.shape(), &[4, 4]);

        // Check the first column (top-left 2x2 patch)
        assert_eq!(cols[[0, 0]], 0.0);
        assert_eq!(cols[[1, 0]], 1.0);
        assert_eq!(cols[[2, 0]], 3.0);
        assert_eq!(cols[[3, 0]], 4.0);

        // Check the second column (top-right 2x2 patch)
        assert_eq!(cols[[0, 1]], 1.0);
        assert_eq!(cols[[1, 1]], 2.0);
        assert_eq!(cols[[2, 1]], 4.0);
        assert_eq!(cols[[3, 1]], 5.0);
    }

    #[test]
    fn test_im2col_with_padding() {
        // Create a simple 1x1x2x2 input
        let mut input = Array4::<f32>::zeros((1, 1, 2, 2));
        input[[0, 0, 0, 0]] = 0.0;
        input[[0, 0, 0, 1]] = 1.0;
        input[[0, 0, 1, 0]] = 2.0;
        input[[0, 0, 1, 1]] = 3.0;

        // Extract 3x3 patches with stride 1 and padding 1
        let cols = im2col(&input.view(), (3, 3), (1, 1), (1, 1), (1, 1)).unwrap();

        // Resulting matrix should be (1*3*3, 2*2*1) = (9, 4)
        assert_eq!(cols.shape(), &[9, 4]);

        // Check padding is zero
        assert_eq!(cols[[0, 0]], 0.0); // Top-left padding
        assert_eq!(cols[[2, 0]], 0.0); // Top-right padding
        assert_eq!(cols[[6, 0]], 0.0); // Bottom-left padding
        assert_eq!(cols[[8, 0]], 3.0); // Bottom-right of the patch: input[[0, 0, 1, 1]] = 3.0 (not padding)

        // Check actual values - kernel center (1,1) at patch (0,0) corresponds to input (0,0)
        assert_eq!(cols[[4, 0]], 0.0); // Center of first patch corresponds to input[0,0,0,0]
    }

    #[test]
    fn test_col2im_basic() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Convert to columns
        let cols = im2col(&input.view(), (2, 2), (1, 1), (0, 0), (1, 1)).unwrap();

        // Convert back to image
        let output = col2im(&cols.view(), (1, 1, 3, 3), (2, 2), (1, 1), (0, 0), (1, 1)).unwrap();

        // Check dimensions
        assert_eq!(output.shape(), input.shape());

        // Check values (note that overlapping patches are averaged)
        assert_relative_eq!(output[[0, 0, 0, 0]], input[[0, 0, 0, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 0, 1]], input[[0, 0, 0, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 0, 2]], input[[0, 0, 0, 2]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 0]], input[[0, 0, 1, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 1]], input[[0, 0, 1, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 2]], input[[0, 0, 1, 2]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 0]], input[[0, 0, 2, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 1]], input[[0, 0, 2, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 2]], input[[0, 0, 2, 2]], epsilon = 1e-5);
    }

    #[test]
    fn test_max_pool2d() {
        // Create a simple 1x1x4x4 input
        let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
        for h in 0..4 {
            for w in 0..4 {
                input[[0, 0, h, w]] = (h * 4 + w) as f32;
            }
        }

        // Apply 2x2 max pooling with stride 2
        let (output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();

        // Check dimensions
        assert_eq!(output.shape(), &[1, 1, 2, 2]);

        // Check values (should take max from each 2x2 region)
        assert_eq!(output[[0, 0, 0, 0]], 5.0); // max of top-left 2x2
        assert_eq!(output[[0, 0, 0, 1]], 7.0); // max of top-right 2x2
        assert_eq!(output[[0, 0, 1, 0]], 13.0); // max of bottom-left 2x2
        assert_eq!(output[[0, 0, 1, 1]], 15.0); // max of bottom-right 2x2

        // Check indices
        assert_eq!(indices[[0, 0, 0, 0]], 5); // index of 5 in flattened input
        assert_eq!(indices[[0, 0, 0, 1]], 7); // index of 7 in flattened input
        assert_eq!(indices[[0, 0, 1, 0]], 13); // index of 13 in flattened input
        assert_eq!(indices[[0, 0, 1, 1]], 15); // index of 15 in flattened input
    }

    #[test]
    fn test_max_pool2d_backward() {
        // Create a simple 1x1x4x4 input
        let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
        for h in 0..4 {
            for w in 0..4 {
                input[[0, 0, h, w]] = (h * 4 + w) as f32;
            }
        }

        // Forward pass
        let (_output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();

        // Create gradient of output
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));

        // Backward pass
        let grad_input =
            max_pool2d_backward(&grad_output.view(), &indices.view(), (1, 1, 4, 4)).unwrap();

        // Check dimensions
        assert_eq!(grad_input.shape(), input.shape());

        // Only positions with max values should have gradients
        for h in 0..4 {
            for w in 0..4 {
                let pos = h * 4 + w;
                let expected = if pos == 5 || pos == 7 || pos == 13 || pos == 15 {
                    1.0
                } else {
                    0.0
                };
                assert_eq!(grad_input[[0, 0, h, w]], expected);
            }
        }
    }

    #[test]
    fn test_conv2d_im2col_basic() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Create a simple 1x1x2x2 kernel (identity)
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 0.0;
        kernel[[0, 0, 1, 0]] = 0.0;
        kernel[[0, 0, 1, 1]] = 0.0;

        // Apply convolution
        let output =
            conv2d_im2col(&input.view(), &kernel.view(), None, (1, 1), (0, 0), (1, 1)).unwrap();

        // Check dimensions
        assert_eq!(output.shape(), &[1, 1, 2, 2]);

        // Kernel extracts top-left value from each position
        assert_eq!(output[[0, 0, 0, 0]], 0.0);
        assert_eq!(output[[0, 0, 0, 1]], 1.0);
        assert_eq!(output[[0, 0, 1, 0]], 3.0);
        assert_eq!(output[[0, 0, 1, 1]], 4.0);
    }

    #[test]
    fn test_conv2d_im2col_with_bias() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Create a simple 1x1x2x2 kernel (identity)
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 0.0;
        kernel[[0, 0, 1, 0]] = 0.0;
        kernel[[0, 0, 1, 1]] = 0.0;

        // Create bias
        let bias = Array1::<f32>::from_elem(1, 10.0);

        // Apply convolution with bias
        let output = conv2d_im2col(
            &input.view(),
            &kernel.view(),
            Some(bias.view()),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .unwrap();

        // Check dimensions
        assert_eq!(output.shape(), &[1, 1, 2, 2]);

        // Kernel extracts top-left value from each position, plus bias
        assert_eq!(output[[0, 0, 0, 0]], 10.0);
        assert_eq!(output[[0, 0, 0, 1]], 11.0);
        assert_eq!(output[[0, 0, 1, 0]], 13.0);
        assert_eq!(output[[0, 0, 1, 1]], 14.0);
    }

    #[test]
    fn test_conv2d_backward_input() {
        // Create a simple 1x1x3x3 input
        let input = Array4::<f32>::zeros((1, 1, 3, 3));

        // Create a simple 1x1x2x2 kernel
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 2.0;
        kernel[[0, 0, 1, 0]] = 3.0;
        kernel[[0, 0, 1, 1]] = 4.0;

        // Apply forward pass
        let _output =
            conv2d_im2col(&input.view(), &kernel.view(), None, (1, 1), (0, 0), (1, 1)).unwrap();

        // Create gradient of output
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));

        // Apply backward pass for input
        let grad_input = conv2d_backward_input(
            &grad_output.view(),
            &kernel.view(),
            (1, 1, 3, 3),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .unwrap();

        // Check dimensions
        assert_eq!(grad_input.shape(), input.shape());

        // Each position receives weighted gradients from overlapping filters
        // For gradient=1 at each output position:
        // input[0,0] receives 1.0 from output[0,0]
        // input[0,1] receives 2.0 from output[0,0] + 1.0 from output[0,1] = 3.0
        // etc.
        assert_eq!(grad_input[[0, 0, 0, 0]], 1.0);
        assert_eq!(grad_input[[0, 0, 0, 1]], 3.0);
        assert_eq!(grad_input[[0, 0, 1, 0]], 4.0);
        assert_eq!(grad_input[[0, 0, 1, 1]], 10.0);
    }

    #[test]
    fn test_conv2d_backward_kernel() {
        // Create a simple 1x1x3x3 input with all ones
        let input = Array4::<f32>::ones((1, 1, 3, 3));

        // Create gradient of output, all ones
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));

        // Apply backward pass for kernel
        let grad_kernel = conv2d_backward_kernel(
            &input.view(),
            &grad_output.view(),
            (1, 1, 2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .unwrap();

        // Check dimensions
        assert_eq!(grad_kernel.shape(), &[1, 1, 2, 2]);

        // With all ones input and gradient, each kernel position accumulates
        // the number of times it overlaps with the input
        assert_eq!(grad_kernel[[0, 0, 0, 0]], 4.0); // Overlaps 4 times
        assert_eq!(grad_kernel[[0, 0, 0, 1]], 4.0);
        assert_eq!(grad_kernel[[0, 0, 1, 0]], 4.0);
        assert_eq!(grad_kernel[[0, 0, 1, 1]], 4.0);
    }

    #[test]
    fn test_conv2d_backward_bias() {
        // Create gradient of output
        let mut grad_output = Array4::<f32>::zeros((2, 3, 2, 2));
        for b in 0..2 {
            for c in 0..3 {
                for h in 0..2 {
                    for w in 0..2 {
                        grad_output[[b, c, h, w]] = 1.0;
                    }
                }
            }
        }

        // Apply backward pass for bias
        let grad_bias = conv2d_backward_bias(&grad_output.view()).unwrap();

        // Check dimensions
        assert_eq!(grad_bias.shape(), &[3]);

        // Each bias accumulates gradient from all positions and batches
        assert_eq!(grad_bias[0], 8.0); // 2 batches * 2*2 spatial = 8
        assert_eq!(grad_bias[1], 8.0);
        assert_eq!(grad_bias[2], 8.0);
    }

    #[test]
    fn test_conv_transpose2d() {
        // Create a simple 1x1x2x2 input
        let input = Array4::<f32>::ones((1, 1, 2, 2));

        // Create a simple 1x1x3x3 kernel with only top-left value set to 1.0
        let mut kernel = Array4::<f32>::zeros((1, 1, 3, 3));
        kernel[[0, 0, 0, 0]] = 1.0;
        // All other values are 0.0

        // Apply transposed convolution
        let output = conv_transpose2d(
            &input.view(),
            &kernel.view(),
            None,
            (2, 2), // stride
            (1, 1), // padding
            (0, 0), // output_padding
            (1, 1), // dilation
        )
        .unwrap();

        // Check dimensions
        // outputsize = (inputsize - 1) * stride - 2 * padding + kernelsize
        // = (2-1)*2 - 2*1 + 3 = 2 - 2 + 3 = 3
        assert_eq!(output.shape(), &[1, 1, 3, 3]);

        // Let's carefully trace through the algorithm for each input position:
        // The input is all ones at positions (0,0), (0,1), (1,0), and (1,1)
        // The kernel has a 1.0 at position (0,0,0,0) and zeros elsewhere
        // With stride=(2,2), padding=(1,1), the output coordinates are calculated as:
        //
        // For each input position (h,w) and kernel position (kh,kw):
        //   out_h = h*stride_h + kh*dilation_h - padding_h
        //   out_w = w*stride_w + kw*dilation_w - padding_w
        //
        // For input (0,0) and kernel (0,0):
        //   out_h = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //   out_w = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //
        // For input (0,1) and kernel (0,0):
        //   out_h = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //   out_w = 1*2 + 0*1 - 1 = 1
        //
        // For input (1,0) and kernel (0,0):
        //   out_h = 1*2 + 0*1 - 1 = 1
        //   out_w = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //
        // For input (1,1) and kernel (0,0):
        //   out_h = 1*2 + 0*1 - 1 = 1
        //   out_w = 1*2 + 0*1 - 1 = 1
        //
        // So only the input at (1,1) with kernel at (0,0) contributes to the output at (1,1)

        // Verify the output
        assert_eq!(output[[0, 0, 0, 0]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 0, 1]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 0, 2]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 1, 0]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 1, 1]], 1.0); // From input (1,1) with kernel (0,0)
        assert_eq!(output[[0, 0, 1, 2]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 2, 0]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 2, 1]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 2, 2]], 0.0); // No contribution
    }
}
1591}