scirs2_linalg/convolution/mod.rs
//! Specialized operations for convolutional neural networks
//!
//! This module provides efficient implementations of specialized operations
//! that are commonly used in convolutional neural networks, such as im2col/col2im,
//! efficient convolution algorithms, and other related operations.
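//!
//! The common thread (see [`conv2d_im2col`] below) is that a convolution
//! collapses into one matrix product once the input is unrolled:
//!
//! ```text
//! cols        = im2col(input)        // (C_in*K_h*K_w, O_h*O_w*N)
//! kernel_flat = reshape(kernel)      // (C_out, C_in*K_h*K_w)
//! output_2d   = kernel_flat x cols   // (C_out, O_h*O_w*N)
//! ```
//!
//! after which `output_2d` is reshaped and transposed back into an
//! (N, C_out, O_h, O_w) output tensor.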

use scirs2_core::ndarray::{Array2, Array4, ArrayView4, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::iter::Sum;

use crate::error::{LinalgError, LinalgResult};

/// Extract patches from an input tensor using the im2col algorithm
///
/// This function implements the im2col (image to column) algorithm, which
/// reformats the input data to allow efficient computation of convolution
/// operations using matrix multiplication.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `kernelsize` - Size of the kernel as (kernel_height, kernel_width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Column matrix of shape (kernel_h * kernel_w * channels, output_h * output_w * batchsize)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::im2col;
///
/// // Create a 1x3x4x4 input tensor (1 batch, 3 channels, 4x4 spatial dimensions)
/// let mut input = Array4::<f32>::zeros((1, 3, 4, 4));
/// // Fill with sample data
/// for c in 0..3 {
///     for h in 0..4 {
///         for w in 0..4 {
///             input[[0, c, h, w]] = (c * 16 + h * 4 + w) as f32;
///         }
///     }
/// }
///
/// // Extract 3x3 patches with stride 1 and no padding
/// let cols = im2col(&input.view(), (3, 3), (1, 1), (0, 0), (1, 1)).unwrap();
///
/// // Resulting matrix has shape (3*3*3, 2*2*1) = (27, 4)
/// // Each column represents a 3x3 patch across all 3 channels
/// assert_eq!(cols.shape(), &[27, 4]);
/// ```
#[allow(dead_code)]
pub fn im2col<F>(
    input: &ArrayView4<F>,
    kernelsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = input.dim();
    let (kernel_h, kernel_w) = kernelsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;

    // Calculate output dimensions
    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;
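    // Worked instance of this formula, matching the doc example above
    // (4x4 input, 3x3 kernel, stride 1, no padding, dilation 1):
    // output_h = ((4 + 0 - 1 * (3 - 1) - 1) / 1) + 1 = 2.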

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Allocate output matrix
    let mut cols = Array2::<F>::zeros((
        kernel_h * kernel_w * channels,
        output_h * output_w * batchsize,
    ));

    // Populate the output matrix with patches
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for kernel_row in 0..kernel_h {
                for kernel_col in 0..kernel_w {
                    let input_row_offset = kernel_row * dilation_h;
                    let input_col_offset = kernel_col * dilation_w;

                    // Row in the cols matrix
                    let cols_idx =
                        channel_idx * kernel_h * kernel_w + kernel_row * kernel_w + kernel_col;

                    for output_row in 0..output_h {
                        for output_col in 0..output_w {
                            let input_row = output_row * stride_h + input_row_offset;
                            let input_col = output_col * stride_w + input_col_offset;

                            // Column in the cols matrix
                            let cols_pos = batch_idx * output_h * output_w
                                + output_row * output_w
                                + output_col;

                            // Check if we need to pad
                            if input_row < padding_h
                                || input_row >= height + padding_h
                                || input_col < padding_w
                                || input_col >= width + padding_w
                            {
                                // Zero-padding
                                cols[[cols_idx, cols_pos]] = F::zero();
                            } else {
                                // Copy from input
                                let input_val = input[[
                                    batch_idx,
                                    channel_idx,
                                    input_row - padding_h,
                                    input_col - padding_w,
                                ]];
                                cols[[cols_idx, cols_pos]] = input_val;
                            }
                        }
                    }
                }
            }
        }
    }

    Ok(cols)
}

/// Convert a column matrix back to an input tensor using the col2im algorithm
///
/// This function implements the col2im (column to image) algorithm, which
/// converts the column matrix back to the original input tensor format.
/// Overlapping patch contributions are averaged.
///
/// # Arguments
///
/// * `cols` - Column matrix of shape (kernel_h * kernel_w * channels, output_h * output_w * batchsize)
/// * `outputshape` - Shape of the output tensor as (batchsize, channels, height, width)
/// * `kernelsize` - Size of the kernel as (kernel_height, kernel_width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Output tensor of shape (batchsize, channels, height, width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{Array4, ArrayView4};
/// use scirs2_linalg::convolution::{im2col, col2im};
///
/// // Create a 1x3x4x4 input tensor
/// let mut input = Array4::<f32>::zeros((1, 3, 4, 4));
/// // Fill with sample data
/// for c in 0..3 {
///     for h in 0..4 {
///         for w in 0..4 {
///             input[[0, c, h, w]] = (c * 16 + h * 4 + w) as f32;
///         }
///     }
/// }
///
/// // Convert to columns with im2col
/// let cols = im2col(&input.view(), (3, 3), (1, 1), (0, 0), (1, 1)).unwrap();
///
/// // Convert back to image with col2im
/// let output = col2im(
///     &cols.view(),
///     (1, 3, 4, 4),
///     (3, 3),
///     (1, 1),
///     (0, 0),
///     (1, 1),
/// ).unwrap();
///
/// // Verify output shape
/// assert_eq!(output.shape(), &[1, 3, 4, 4]);
/// ```
#[allow(dead_code)]
pub fn col2im<F>(
    cols: &scirs2_core::ndarray::ArrayView2<F>,
    outputshape: (usize, usize, usize, usize),
    kernelsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = outputshape;
    let (kernel_h, kernel_w) = kernelsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;

    // Calculate the convolution output dimensions implied by the parameters
    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Check input columns shape
    if cols.shape()[0] != kernel_h * kernel_w * channels
        || cols.shape()[1] != output_h * output_w * batchsize
    {
        return Err(LinalgError::ShapeError(format!(
            "Invalid cols shape: expected ({}, {}), got ({}, {})",
            kernel_h * kernel_w * channels,
            output_h * output_w * batchsize,
            cols.shape()[0],
            cols.shape()[1]
        )));
    }

    // Allocate output tensor and per-element contribution counts
    let mut output = Array4::<F>::zeros((batchsize, channels, height, width));
    let mut counts = Array4::<usize>::zeros((batchsize, channels, height, width));

    // Accumulate values from cols to output
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for kernel_row in 0..kernel_h {
                for kernel_col in 0..kernel_w {
                    let input_row_offset = kernel_row * dilation_h;
                    let input_col_offset = kernel_col * dilation_w;

                    // Row in the cols matrix
                    let cols_idx =
                        channel_idx * kernel_h * kernel_w + kernel_row * kernel_w + kernel_col;

                    for output_row in 0..output_h {
                        for output_col in 0..output_w {
                            let input_row = output_row * stride_h + input_row_offset;
                            let input_col = output_col * stride_w + input_col_offset;

                            // Column in the cols matrix
                            let cols_pos = batch_idx * output_h * output_w
                                + output_row * output_w
                                + output_col;

                            // Check if the position is valid (not padding)
                            if input_row >= padding_h
                                && input_row < height + padding_h
                                && input_col >= padding_w
                                && input_col < width + padding_w
                            {
                                let output_row_idx = input_row - padding_h;
                                let output_col_idx = input_col - padding_w;

                                output[[batch_idx, channel_idx, output_row_idx, output_col_idx]] +=
                                    cols[[cols_idx, cols_pos]];
                                counts[[batch_idx, channel_idx, output_row_idx, output_col_idx]] +=
                                    1;
                            }
                        }
                    }
                }
            }
        }
    }

    // Normalize by count (average overlapping patches)
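    // Note: the averaging below makes col2im reconstruct the original tensor
    // exactly (col2im(im2col(x)) == x, see `test_col2im_basic`); the strict
    // mathematical adjoint of im2col, as used in convolution gradients,
    // would sum the overlaps instead of averaging them.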
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for h in 0..height {
                for w in 0..width {
                    let count = counts[[batch_idx, channel_idx, h, w]];
                    if count > 0 {
                        output[[batch_idx, channel_idx, h, w]] /= F::from(count).unwrap();
                    }
                }
            }
        }
    }

    Ok(output)
}

/// Perform max pooling operation on a 4D input tensor
///
/// Applies max pooling over a 4D tensor, which is commonly used to
/// down-sample feature maps in convolutional neural networks.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `poolsize` - Size of the pooling window as (pool_height, pool_width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
///
/// # Returns
///
/// * Output tensor of pooled values and indices of max values (for backward pass)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::max_pool2d;
///
/// // Create a 1x1x4x4 input tensor
/// let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
/// // Fill with sample data
/// for h in 0..4 {
///     for w in 0..4 {
///         input[[0, 0, h, w]] = (h * 4 + w) as f32;
///     }
/// }
///
/// // Apply 2x2 max pooling with stride 2
/// let (output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();
///
/// // Resulting tensor has shape (1, 1, 2, 2)
/// assert_eq!(output.shape(), &[1, 1, 2, 2]);
/// ```
#[allow(dead_code)]
pub fn max_pool2d<F>(
    input: &ArrayView4<F>,
    poolsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> LinalgResult<(Array4<F>, Array4<usize>)>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = input.dim();
    let (pool_h, pool_w) = poolsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;

    // Calculate output dimensions
    let output_h = ((height + 2 * padding_h - pool_h) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - pool_w) / stride_w) + 1;

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Allocate output tensors
    let mut output = Array4::<F>::zeros((batchsize, channels, output_h, output_w));
    let mut indices = Array4::<usize>::zeros((batchsize, channels, output_h, output_w));

    // Perform max pooling
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for output_row in 0..output_h {
                for output_col in 0..output_w {
                    let start_h = output_row * stride_h;
                    let start_w = output_col * stride_w;

                    let mut max_val = F::neg_infinity();
                    let mut max_idx = 0;

                    // Find max value in the pooling window
                    for pool_row in 0..pool_h {
                        for pool_col in 0..pool_w {
                            let input_row = start_h + pool_row;
                            let input_col = start_w + pool_col;

                            // Check if the position is valid (not padding)
                            if input_row >= padding_h
                                && input_row < height + padding_h
                                && input_col >= padding_w
                                && input_col < width + padding_w
                            {
                                let input_row_idx = input_row - padding_h;
                                let input_col_idx = input_col - padding_w;
                                let val =
                                    input[[batch_idx, channel_idx, input_row_idx, input_col_idx]];

                                if val > max_val {
                                    max_val = val;
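                                    // Record the flattened (row * width + col)
                                    // position in the un-padded input; the
                                    // backward pass decodes it via `/ width`
                                    // and `% width`.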
                                    max_idx = input_row_idx * width + input_col_idx;
                                }
                            }
                        }
                    }

                    output[[batch_idx, channel_idx, output_row, output_col]] = max_val;
                    indices[[batch_idx, channel_idx, output_row, output_col]] = max_idx;
                }
            }
        }
    }

    Ok((output, indices))
}

/// Perform the backward pass of max pooling operation
///
/// Takes the gradients of the pooled outputs and distributes them back to
/// the locations of the maximum values in the original input.
///
/// # Arguments
///
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, channels, output_height, output_width)
/// * `indices` - Indices of the maximum values from the forward pass
/// * `inputshape` - Shape of the original input tensor (batchsize, channels, height, width)
///
/// # Returns
///
/// * Gradient with respect to input
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::{max_pool2d, max_pool2d_backward};
///
/// // Create a 1x1x4x4 input tensor
/// let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
/// // Fill with sample data
/// for h in 0..4 {
///     for w in 0..4 {
///         input[[0, 0, h, w]] = (h * 4 + w) as f32;
///     }
/// }
///
/// // Apply max pooling (forward pass)
/// let (output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();
///
/// // Create gradient of the output
/// let grad_output = Array4::<f32>::ones((1, 1, 2, 2));
///
/// // Compute gradient of the input (backward pass)
/// let grad_input = max_pool2d_backward(
///     &grad_output.view(),
///     &indices.view(),
///     (1, 1, 4, 4),
/// ).unwrap();
///
/// // Verify shape
/// assert_eq!(grad_input.shape(), &[1, 1, 4, 4]);
/// ```
#[allow(dead_code)]
pub fn max_pool2d_backward<F>(
    grad_output: &ArrayView4<F>,
    indices: &scirs2_core::ndarray::ArrayView4<usize>,
    inputshape: (usize, usize, usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = inputshape;
    let (out_batch, out_channels_, out_height, out_width) = grad_output.dim();
    let (idx_batch, idx_channels, idx_height, idx_width) = indices.dim();

    // Check that shapes match
    if out_batch != idx_batch
        || out_channels_ != idx_channels
        || out_height != idx_height
        || out_width != idx_width
    {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch between grad_output ({out_batch}, {out_channels_}, {out_height}, {out_width}) and indices ({idx_batch}, {idx_channels}, {idx_height}, {idx_width})"
        )));
    }
    // Allocate output gradient tensor
    let mut grad_input = Array4::<F>::zeros((batchsize, channels, height, width));

    // Distribute gradients to the locations of the maximum values
    for batch_idx in 0..out_batch {
        for channel_idx in 0..out_channels_ {
            for h in 0..out_height {
                for w in 0..out_width {
                    let index = indices[[batch_idx, channel_idx, h, w]];
                    let input_h = index / width;
                    let input_w = index % width;

                    if input_h < height && input_w < width {
                        grad_input[[batch_idx, channel_idx, input_h, input_w]] +=
                            grad_output[[batch_idx, channel_idx, h, w]];
                    }
                }
            }
        }
    }

    Ok(grad_input)
}

/// Compute the indices for batch matrix multiplication in a convolutional layer
///
/// This function precomputes, for every (output element, kernel position) pair,
/// the flat indices needed to drive a gather-based batch matrix multiplication
/// implementation of a convolutional layer.
///
/// # Arguments
///
/// * `inputshape` - Shape of the input tensor (batchsize, channels, height, width)
/// * `kernelshape` - Shape of the kernel tensor (out_channels_, in_channels, kernel_h, kernel_w)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
///
/// # Returns
///
/// * Indices for efficient batch matrix multiplication
///
/// # Examples
///
/// ```
/// use scirs2_linalg::convolution::compute_conv_indices;
///
/// // Compute indices for a simple convolutional layer
/// let indices = compute_conv_indices(
///     (1, 1, 4, 4), // Input shape: batchsize=1, channels=1, height=4, width=4
///     (1, 1, 2, 2), // Kernel shape: out_channels_=1, in_channels=1, kernel_h=2, kernel_w=2
///     (1, 1),       // Stride: height=1, width=1
///     (0, 0),       // Padding: height=0, width=0
/// ).unwrap();
/// // For a 4x4 input with 2x2 kernel and no padding, we get a 3x3 output
/// // Each output element is computed from 4 input elements (2x2 kernel)
/// // So we should have 3*3*4*5 = 180 values in the indices array
/// assert_eq!(indices.len() % 5, 0); // Should be multiple of 5
/// ```
#[allow(dead_code)]
pub fn compute_conv_indices(
    inputshape: (usize, usize, usize, usize),
    kernelshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> LinalgResult<scirs2_core::ndarray::Array1<usize>> {
    let (batchsize, _in_channels, height, width) = inputshape;
    let (out_channels_, in_channels, kernel_h, kernel_w) = kernelshape;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;

    // Calculate output dimensions
    let output_h = ((height + 2 * padding_h - kernel_h) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - kernel_w) / stride_w) + 1;

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Calculate total number of elements
    // Each output element can be computed from in_channels * kernel_h * kernel_w input elements
    let total_elements =
        batchsize * out_channels_ * output_h * output_w * in_channels * kernel_h * kernel_w;

    // Allocate array for indices (5 values per element)
    let mut indices = scirs2_core::ndarray::Array1::<usize>::zeros(total_elements * 5);

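    // Each stored contribution is a contiguous quintuple:
    //   [flat output index, flat input index, flat kernel index,
    //    spatial position (oh * output_w + ow), output channel]
    // Padded positions are skipped, so the array is truncated to `idx` at the end.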
    // Compute indices for batch matmul
    let mut idx = 0;
    for b in 0..batchsize {
        for oc in 0..out_channels_ {
            for oh in 0..output_h {
                for ow in 0..output_w {
                    for ic in 0..in_channels {
                        for kh in 0..kernel_h {
                            for kw in 0..kernel_w {
                                let ih = oh * stride_h + kh;
                                let iw = ow * stride_w + kw;

                                // Check if within padded input
                                if ih >= padding_h
                                    && ih < height + padding_h
                                    && iw >= padding_w
                                    && iw < width + padding_w
                                {
                                    let real_ih = ih - padding_h;
                                    let real_iw = iw - padding_w;

                                    // Output index
                                    let out_idx = b * out_channels_ * output_h * output_w
                                        + oc * output_h * output_w
                                        + oh * output_w
                                        + ow;

                                    // Input index
                                    let in_idx = b * in_channels * height * width
                                        + ic * height * width
                                        + real_ih * width
                                        + real_iw;

                                    // Kernel index
                                    let kernel_idx = oc * in_channels * kernel_h * kernel_w
                                        + ic * kernel_h * kernel_w
                                        + kh * kernel_w
                                        + kw;

                                    // Store indices
                                    indices[idx] = out_idx;
                                    indices[idx + 1] = in_idx;
                                    indices[idx + 2] = kernel_idx;
                                    indices[idx + 3] = oh * output_w + ow;
                                    indices[idx + 4] = oc;

                                    idx += 5;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // Resize the array to remove unused elements
    let indices = indices.slice(scirs2_core::ndarray::s![0..idx]).to_owned();
    Ok(indices)
}

/// Apply convolution operation using im2col and matrix multiplication
///
/// This function implements convolution using the im2col algorithm and
/// efficient matrix multiplication, which is often faster than direct
/// convolution for large inputs or kernels.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `kernel` - Kernel tensor of shape (out_channels_, in_channels, kernel_h, kernel_w)
/// * `bias` - Optional bias tensor of shape (out_channels_,)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Output tensor of shape (batchsize, out_channels_, output_height, output_width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{Array, Array4};
/// use scirs2_linalg::convolution::conv2d_im2col;
///
/// // Create a 2x3x32x32 input tensor (2 batches, 3 channels, 32x32 spatial dimensions)
/// let input = Array4::<f32>::zeros((2, 3, 32, 32));
///
/// // Create a 16x3x3x3 kernel tensor (16 output channels, 3 input channels, 3x3 kernel)
/// let kernel = Array4::<f32>::zeros((16, 3, 3, 3));
///
/// // Create a bias tensor
/// let bias = Some(Array::zeros(16));
///
/// // Apply convolution
/// let output = conv2d_im2col(
///     &input.view(),
///     &kernel.view(),
///     bias.as_ref().map(|b| b.view()),
///     (1, 1), // stride
///     (1, 1), // padding
///     (1, 1), // dilation
/// ).unwrap();
///
/// // Output shape is (2, 16, 32, 32)
/// assert_eq!(output.shape(), &[2, 16, 32, 32]);
/// ```
#[allow(dead_code)]
pub fn conv2d_im2col<F>(
    input: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    bias: Option<scirs2_core::ndarray::ArrayView1<F>>,
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, height, width) = input.dim();
    let (out_channels_, k_in_channels, kernel_h, kernel_w) = kernel.dim();

    // Check that input and kernel channels match
    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels ({in_channels}) must match kernel in_channels ({k_in_channels})"
        )));
    }

    // Check bias shape if provided
    if let Some(b) = bias {
        if b.len() != out_channels_ {
            return Err(LinalgError::ShapeError(format!(
                "Bias length ({}) must match out_channels_ ({})",
                b.len(),
                out_channels_
            )));
        }
    }

    // Calculate output dimensions
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;

    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;

    // Check for valid dimensions
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }

    // Convert input to columns using im2col
    let cols = im2col(input, (kernel_h, kernel_w), stride, padding, dilation)?;

    // Reshape kernel for matrix multiplication
    let flat_kernel = (*kernel)
        .into_shape_with_order((out_channels_, in_channels * kernel_h * kernel_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

    // Perform matrix multiplication
    let output_2d = flat_kernel.dot(&cols);

    // Reshape to output tensor
    let mut output = output_2d
        .into_shape_with_order((out_channels_, batchsize, output_h, output_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

    // Rearrange dimensions to (batchsize, out_channels_, output_h, output_w)
    output = output.permuted_axes([1, 0, 2, 3]);

    // Add bias if provided
    if let Some(b) = bias {
        for batch_idx in 0..batchsize {
            for oc in 0..out_channels_ {
                for h in 0..output_h {
                    for w in 0..output_w {
                        output[[batch_idx, oc, h, w]] += b[oc];
                    }
                }
            }
        }
    }

    Ok(output)
}

/// Apply backward pass of convolution operation for input gradient
///
/// This function computes the gradient of the input in a convolutional layer
/// given the gradient of the output.
///
/// # Arguments
///
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, out_channels_, output_h, output_w)
/// * `kernel` - Kernel tensor of shape (out_channels_, in_channels, kernel_h, kernel_w)
/// * `inputshape` - Shape of the input tensor (batchsize, in_channels, height, width)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Gradient of the input tensor of shape (batchsize, in_channels, height, width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::{conv2d_im2col, conv2d_backward_input};
///
/// // Forward pass
/// let input = Array4::<f32>::zeros((2, 3, 32, 32));
/// let kernel = Array4::<f32>::zeros((16, 3, 3, 3));
/// let bias = None;
/// let output = conv2d_im2col(
///     &input.view(),
///     &kernel.view(),
///     bias,
///     (1, 1),
///     (1, 1),
///     (1, 1),
/// ).unwrap();
///
/// // Backward pass
/// let grad_output = Array4::<f32>::ones((2, 16, 32, 32));
/// let grad_input = conv2d_backward_input(
///     &grad_output.view(),
///     &kernel.view(),
///     (2, 3, 32, 32),
///     (1, 1),
///     (1, 1),
///     (1, 1),
/// ).unwrap();
///
/// // Gradient shape matches input shape
/// assert_eq!(grad_input.shape(), &[2, 3, 32, 32]);
/// ```
#[allow(dead_code)]
pub fn conv2d_backward_input<F>(
    grad_output: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    inputshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, out_channels, _output_h, _output_w) = grad_output.dim();
    let (k_out_channels, in_channels, kernel_h, kernel_w) = kernel.dim();
    let (i_batchsize, i_in_channels, _height, _width) = inputshape;

    // Check that shapes match
    if batchsize != i_batchsize {
        return Err(LinalgError::ShapeError(format!(
            "Batch size mismatch: grad_output ({batchsize}) vs inputshape ({i_batchsize})"
        )));
    }

    if out_channels != k_out_channels {
        return Err(LinalgError::ShapeError(format!(
            "Output channels mismatch: grad_output ({out_channels}) vs kernel ({k_out_channels})"
        )));
    }

    if in_channels != i_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: kernel ({in_channels}) vs inputshape ({i_in_channels})"
        )));
    }

    // Prepare kernel for transposed convolution
    let mut kernel_transposed = Array4::<F>::zeros((in_channels, out_channels, kernel_h, kernel_w));

    // Flip the kernel spatially and swap the input/output channel axes
    for oc in 0..out_channels {
        for ic in 0..in_channels {
            for kh in 0..kernel_h {
                for kw in 0..kernel_w {
                    kernel_transposed[[ic, oc, kernel_h - 1 - kh, kernel_w - 1 - kw]] =
                        kernel[[oc, ic, kh, kw]];
                }
            }
        }
    }

    // Calculate padding for transposed convolution
    let (_stride_h, _stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (_dilation_h, _dilation_w) = dilation;

    // We need to adjust padding for transposed convolution
    let pad_h = kernel_h - 1 - padding_h;
    let pad_w = kernel_w - 1 - padding_w;
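    // For instance, with the 3x3 kernel and padding 1 of the doc example:
    // pad_h = 3 - 1 - 1 = 1, the "full"-correlation padding that keeps the
    // input gradient the same spatial size as the original input.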

    // Perform transposed convolution
    // For transposed convolution, we swap the roles of stride and dilation
    conv2d_im2col(
        grad_output,
        &kernel_transposed.view(),
        None,
        dilation,       // original dilation becomes stride
        (pad_h, pad_w), // adjusted padding
        stride,         // original stride becomes dilation
    )
}

/// Apply backward pass of convolution operation for kernel gradient
///
/// This function computes the gradient of the kernel in a convolutional layer
/// given the gradient of the output and the input.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, in_channels, height, width)
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, out_channels_, output_h, output_w)
/// * `kernelshape` - Shape of the kernel tensor (out_channels_, in_channels, kernel_h, kernel_w)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Gradient of the kernel tensor of shape (out_channels_, in_channels, kernel_h, kernel_w)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::{conv2d_im2col, conv2d_backward_kernel};
///
/// // Simple example with smaller dimensions
/// let input = Array4::<f32>::zeros((1, 1, 4, 4));
/// let kernelshape = (1, 1, 2, 2);
///
/// // Forward pass to get output shape
/// let kernel = Array4::<f32>::zeros(kernelshape);
/// let output = conv2d_im2col(
///     &input.view(),
///     &kernel.view(),
///     None,
///     (1, 1), // stride
///     (0, 0), // padding
///     (1, 1), // dilation
/// ).unwrap();
///
/// // Backward pass - grad_output must match forward output shape
/// let grad_output = Array4::<f32>::ones(output.dim());
/// let grad_kernel = conv2d_backward_kernel(
///     &input.view(),
///     &grad_output.view(),
///     kernelshape,
///     (1, 1),
///     (0, 0),
///     (1, 1),
/// ).unwrap();
///
/// // Gradient shape matches kernel shape
/// assert_eq!(grad_kernel.shape(), &[1, 1, 2, 2]);
/// ```
#[allow(dead_code)]
pub fn conv2d_backward_kernel<F>(
    input: &ArrayView4<F>,
    grad_output: &ArrayView4<F>,
    kernelshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, _height, _width) = input.dim();
    let (go_batchsize, out_channels_, output_h, output_w) = grad_output.dim();
    let (k_out_channels, k_in_channels, kernel_h, kernel_w) = kernelshape;

    // Check that shapes match
    if batchsize != go_batchsize {
        return Err(LinalgError::ShapeError(format!(
            "Batch size mismatch: input ({batchsize}) vs grad_output ({go_batchsize})"
        )));
    }

    if out_channels_ != k_out_channels {
        return Err(LinalgError::ShapeError(format!(
            "Output channels mismatch: grad_output ({out_channels_}) vs kernelshape ({k_out_channels})"
        )));
    }

    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: input ({in_channels}) vs kernelshape ({k_in_channels})"
        )));
    }

    // Convert input to columns using im2col
    let cols = im2col(input, (kernel_h, kernel_w), stride, padding, dilation)?;

    // Accumulate the kernel gradient batch by batch: the columns of `cols`
    // are grouped per batch (see `im2col`), so batch `b` owns the column
    // block [b * spatial, (b + 1) * spatial). Reshaping grad_output across
    // the whole batch at once would misalign the two operands for batch > 1.
    let spatial = output_h * output_w;
    let mut grad_kernel_flat =
        Array2::<F>::zeros((out_channels_, in_channels * kernel_h * kernel_w));
    for batch_idx in 0..batchsize {
        let cols_block = cols.slice(scirs2_core::ndarray::s![
            ..,
            batch_idx * spatial..(batch_idx + 1) * spatial
        ]);
        let grad_block = grad_output
            .slice(scirs2_core::ndarray::s![batch_idx, .., .., ..])
            .into_shape_with_order((out_channels_, spatial))
            .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

        // Compute this batch's contribution via matrix multiplication
        grad_kernel_flat += &grad_block.dot(&cols_block.t());
    }

    // Reshape to kernel shape
    let grad_kernel = grad_kernel_flat
        .into_shape_with_order((out_channels_, in_channels, kernel_h, kernel_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;

    Ok(grad_kernel)
}

/// Apply backward pass of convolution operation for bias gradient
///
/// This function computes the gradient of the bias in a convolutional layer
/// given the gradient of the output.
///
/// # Arguments
///
/// * `grad_output` - Gradient of the output tensor of shape (batchsize, out_channels_, output_h, output_w)
///
/// # Returns
///
/// * Gradient of the bias tensor of shape (out_channels_,)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use scirs2_linalg::convolution::conv2d_backward_bias;
///
/// // Backward pass for bias
/// let grad_output = Array4::<f32>::ones((2, 16, 32, 32));
/// let grad_bias = conv2d_backward_bias(&grad_output.view()).unwrap();
///
/// // Gradient shape matches bias shape
/// assert_eq!(grad_bias.shape(), &[16]);
/// ```
#[allow(dead_code)]
pub fn conv2d_backward_bias<F>(
    grad_output: &ArrayView4<F>,
) -> LinalgResult<scirs2_core::ndarray::Array1<F>>
where
    F: Float + NumAssign + Sum + Zero,
{
    let (batchsize, out_channels_, output_h, output_w) = grad_output.dim();

    // Allocate gradient for bias
    let mut grad_bias = scirs2_core::ndarray::Array1::<F>::zeros(out_channels_);

    // Sum gradients over batch, height, and width dimensions
    for batch_idx in 0..batchsize {
        for oc in 0..out_channels_ {
            for h in 0..output_h {
                for w in 0..output_w {
                    grad_bias[oc] += grad_output[[batch_idx, oc, h, w]];
                }
            }
        }
    }

    Ok(grad_bias)
}

/// Apply 2D transposed convolution (deconvolution) operation
///
/// This function implements a transposed convolution (also known as deconvolution
/// or fractionally-strided convolution), which is commonly used in convolutional
/// neural networks for upsampling.
///
/// # Arguments
///
/// * `input` - Input tensor of shape (batchsize, channels, height, width)
/// * `kernel` - Kernel tensor of shape (in_channels, out_channels_, kernel_h, kernel_w)
/// * `bias` - Optional bias tensor of shape (out_channels_,)
/// * `stride` - Stride as (stride_height, stride_width)
/// * `padding` - Padding as (padding_height, padding_width)
/// * `output_padding` - Additional padding for output as (padding_height, padding_width)
/// * `dilation` - Dilation as (dilation_height, dilation_width)
///
/// # Returns
///
/// * Output tensor of shape (batchsize, out_channels_, output_height, output_width)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{Array, Array4};
/// use scirs2_linalg::convolution::conv_transpose2d;
///
/// // Simple example with smaller dimensions
/// let input = Array4::<f32>::zeros((1, 2, 3, 3));
///
/// // Create a 2x1x2x2 kernel tensor (2 input channels, 1 output channel, 2x2 kernel)
/// let kernel = Array4::<f32>::zeros((2, 1, 2, 2));
///
/// // Apply transposed convolution
/// let output = conv_transpose2d(
///     &input.view(),
///     &kernel.view(),
///     None,   // no bias
///     (1, 1), // stride
///     (0, 0), // padding
///     (0, 0), // output_padding
///     (1, 1), // dilation
/// ).unwrap();
///
/// // Calculate expected output shape:
/// // output_h = (3 - 1) * 1 - 2 * 0 + 1 * (2 - 1) + 0 + 1 = 2 + 1 + 1 = 4
/// // output_w = (3 - 1) * 1 - 2 * 0 + 1 * (2 - 1) + 0 + 1 = 2 + 1 + 1 = 4
/// assert_eq!(output.shape(), &[1, 1, 4, 4]);
/// ```
#[allow(dead_code)]
pub fn conv_transpose2d<F>(
    input: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    bias: Option<scirs2_core::ndarray::ArrayView1<F>>,
    stride: (usize, usize),
    padding: (usize, usize),
    output_padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, height, width) = input.dim();
    let (k_in_channels, out_channels_, kernel_h, kernel_w) = kernel.dim();

    // Check that channels match
    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: input ({in_channels}) vs kernel ({k_in_channels})"
        )));
    }

    // Check bias shape if provided
    if let Some(b) = bias {
        if b.len() != out_channels_ {
            return Err(LinalgError::ShapeError(format!(
                "Bias length ({}) must match out_channels_ ({})",
                b.len(),
                out_channels_
            )));
        }
    }

    // Calculate output dimensions
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (output_padding_h, output_padding_w) = output_padding;
    let (dilation_h, dilation_w) = dilation;

    let output_h = (height - 1) * stride_h - 2 * padding_h
        + dilation_h * (kernel_h - 1)
        + output_padding_h
        + 1;
    let output_w =
        (width - 1) * stride_w - 2 * padding_w + dilation_w * (kernel_w - 1) + output_padding_w + 1;
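    // e.g. for the doc example (3x3 input, 2x2 kernel, stride 1, no padding):
    // output_h = (3 - 1) * 1 - 2 * 0 + 1 * (2 - 1) + 0 + 1 = 4.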

    // Allocate output tensor
    let mut output = Array4::<F>::zeros((batchsize, out_channels_, output_h, output_w));

    // Perform transposed convolution
    for b in 0..batchsize {
        for oc in 0..out_channels_ {
            for ic in 0..in_channels {
                for h in 0..height {
                    for w in 0..width {
                        let input_val = input[[b, ic, h, w]];

                        for kh in 0..kernel_h {
                            for kw in 0..kernel_w {
                                // Calculate output coordinates for transposed convolution.
                                // For transposed conv, we're scattering values from input to
                                // output, so the kernel's position enters the calculation in
                                // the opposite way from normal convolution.
                                let out_h = h as isize * stride_h as isize
                                    + kh as isize * dilation_h as isize
                                    - padding_h as isize;
                                let out_w = w as isize * stride_w as isize
                                    + kw as isize * dilation_w as isize
                                    - padding_w as isize;

                                // Only process if coordinates are non-negative
                                if out_h >= 0 && out_w >= 0 {
                                    let out_h = out_h as usize;
                                    let out_w = out_w as usize;

                                    if out_h < output_h && out_w < output_w {
                                        // The transposed convolution can be thought of as:
                                        // 1. For each input position and each kernel position
                                        // 2. Calculate the output position based on stride, padding, and dilation
                                        // 3. Add the product of input and kernel value to that output position
                                        //
                                        // Note: Technically, for a proper mathematical transposed convolution,
                                        // we should use the flipped kernel. However, in practice, ML libraries
                                        // often just reuse the same kernel for simplicity and learn appropriate weights.
                                        output[[b, oc, out_h, out_w]] +=
                                            input_val * kernel[[ic, oc, kh, kw]];
                                    }
                                }
                            }
                        }
                    }
                }
            }

            // Add bias if provided
            if let Some(b_val) = bias.map(|b| b[oc]) {
                for h in 0..output_h {
                    for w in 0..output_w {
                        output[[b, oc, h, w]] += b_val;
                    }
                }
            }
        }
    }

    Ok(output)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, Array4};

    #[test]
    fn test_im2col_basic() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Extract 2x2 patches with stride 1 and no padding
        let cols = im2col(&input.view(), (2, 2), (1, 1), (0, 0), (1, 1)).unwrap();

        // Resulting matrix should be (1*2*2, 2*2*1) = (4, 4)
        assert_eq!(cols.shape(), &[4, 4]);

        // Check the first column (top-left 2x2 patch)
        assert_eq!(cols[[0, 0]], 0.0);
        assert_eq!(cols[[1, 0]], 1.0);
        assert_eq!(cols[[2, 0]], 3.0);
        assert_eq!(cols[[3, 0]], 4.0);

        // Check the second column (top-right 2x2 patch)
        assert_eq!(cols[[0, 1]], 1.0);
        assert_eq!(cols[[1, 1]], 2.0);
        assert_eq!(cols[[2, 1]], 4.0);
        assert_eq!(cols[[3, 1]], 5.0);
    }

    #[test]
    fn test_im2col_with_padding() {
        // Create a simple 1x1x2x2 input
        let mut input = Array4::<f32>::zeros((1, 1, 2, 2));
        input[[0, 0, 0, 0]] = 0.0;
        input[[0, 0, 0, 1]] = 1.0;
        input[[0, 0, 1, 0]] = 2.0;
        input[[0, 0, 1, 1]] = 3.0;

        // Extract 3x3 patches with stride 1 and padding 1
        let cols = im2col(&input.view(), (3, 3), (1, 1), (1, 1), (1, 1)).unwrap();

        // Resulting matrix should be (1*3*3, 2*2*1) = (9, 4)
        assert_eq!(cols.shape(), &[9, 4]);

        // Check padding is zero
        assert_eq!(cols[[0, 0]], 0.0); // Top-left padding
        assert_eq!(cols[[2, 0]], 0.0); // Top-right padding
        assert_eq!(cols[[6, 0]], 0.0); // Bottom-left padding
        assert_eq!(cols[[8, 0]], 3.0); // Bottom-right of the patch is input[1,1] = 3.0, not padding

        // Check actual values - kernel center (1,1) at patch (0,0) corresponds to input (0,0)
        assert_eq!(cols[[4, 0]], 0.0); // Center of first patch corresponds to input[0,0,0,0]
    }

    #[test]
    fn test_col2im_basic() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Convert to columns
        let cols = im2col(&input.view(), (2, 2), (1, 1), (0, 0), (1, 1)).unwrap();

        // Convert back to image
        let output = col2im(&cols.view(), (1, 1, 3, 3), (2, 2), (1, 1), (0, 0), (1, 1)).unwrap();

        // Check dimensions
        assert_eq!(output.shape(), input.shape());

        // Check values (note that overlapping patches are averaged)
        assert_relative_eq!(output[[0, 0, 0, 0]], input[[0, 0, 0, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 0, 1]], input[[0, 0, 0, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 0, 2]], input[[0, 0, 0, 2]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 0]], input[[0, 0, 1, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 1]], input[[0, 0, 1, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 2]], input[[0, 0, 1, 2]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 0]], input[[0, 0, 2, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 1]], input[[0, 0, 2, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 2]], input[[0, 0, 2, 2]], epsilon = 1e-5);
    }

    #[test]
    fn test_max_pool2d() {
        // Create a simple 1x1x4x4 input
        let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
        for h in 0..4 {
            for w in 0..4 {
                input[[0, 0, h, w]] = (h * 4 + w) as f32;
            }
        }

        // Apply 2x2 max pooling with stride 2
        let (output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();

        // Check dimensions
        assert_eq!(output.shape(), &[1, 1, 2, 2]);

        // Check values (should take max from each 2x2 region)
        assert_eq!(output[[0, 0, 0, 0]], 5.0); // max of top-left 2x2
        assert_eq!(output[[0, 0, 0, 1]], 7.0); // max of top-right 2x2
        assert_eq!(output[[0, 0, 1, 0]], 13.0); // max of bottom-left 2x2
        assert_eq!(output[[0, 0, 1, 1]], 15.0); // max of bottom-right 2x2

        // Check indices
        assert_eq!(indices[[0, 0, 0, 0]], 5); // index of 5 in flattened input
        assert_eq!(indices[[0, 0, 0, 1]], 7); // index of 7 in flattened input
        assert_eq!(indices[[0, 0, 1, 0]], 13); // index of 13 in flattened input
        assert_eq!(indices[[0, 0, 1, 1]], 15); // index of 15 in flattened input
    }

    #[test]
    fn test_max_pool2d_backward() {
        // Create a simple 1x1x4x4 input
        let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
        for h in 0..4 {
            for w in 0..4 {
                input[[0, 0, h, w]] = (h * 4 + w) as f32;
            }
        }

        // Forward pass
        let (_output, indices) = max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).unwrap();

        // Create gradient of output
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));

        // Backward pass
        let grad_input =
            max_pool2d_backward(&grad_output.view(), &indices.view(), (1, 1, 4, 4)).unwrap();

        // Check dimensions
        assert_eq!(grad_input.shape(), input.shape());

        // Only positions with max values should have gradients
        for h in 0..4 {
            for w in 0..4 {
                let pos = h * 4 + w;
                let expected = if pos == 5 || pos == 7 || pos == 13 || pos == 15 {
                    1.0
                } else {
                    0.0
                };
                assert_eq!(grad_input[[0, 0, h, w]], expected);
            }
        }
    }

    #[test]
    fn test_conv2d_im2col_basic() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Create a simple 1x1x2x2 kernel that selects the top-left value
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 0.0;
        kernel[[0, 0, 1, 0]] = 0.0;
        kernel[[0, 0, 1, 1]] = 0.0;

        // Apply convolution
        let output =
            conv2d_im2col(&input.view(), &kernel.view(), None, (1, 1), (0, 0), (1, 1)).unwrap();

        // Check dimensions
        assert_eq!(output.shape(), &[1, 1, 2, 2]);

        // Kernel extracts top-left value from each position
        assert_eq!(output[[0, 0, 0, 0]], 0.0);
        assert_eq!(output[[0, 0, 0, 1]], 1.0);
        assert_eq!(output[[0, 0, 1, 0]], 3.0);
        assert_eq!(output[[0, 0, 1, 1]], 4.0);
    }

    #[test]
    fn test_conv2d_im2col_with_bias() {
        // Create a simple 1x1x3x3 input
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }

        // Create a simple 1x1x2x2 kernel that selects the top-left value
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 0.0;
        kernel[[0, 0, 1, 0]] = 0.0;
        kernel[[0, 0, 1, 1]] = 0.0;

        // Create bias
        let bias = Array1::<f32>::from_elem(1, 10.0);

        // Apply convolution with bias
        let output = conv2d_im2col(
            &input.view(),
            &kernel.view(),
            Some(bias.view()),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .unwrap();

        // Check dimensions
        assert_eq!(output.shape(), &[1, 1, 2, 2]);

        // Kernel extracts top-left value from each position, plus bias
        assert_eq!(output[[0, 0, 0, 0]], 10.0);
        assert_eq!(output[[0, 0, 0, 1]], 11.0);
        assert_eq!(output[[0, 0, 1, 0]], 13.0);
        assert_eq!(output[[0, 0, 1, 1]], 14.0);
    }

    #[test]
    fn test_conv2d_backward_input() {
        // Create a simple 1x1x3x3 input
        let input = Array4::<f32>::zeros((1, 1, 3, 3));

        // Create a simple 1x1x2x2 kernel
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 2.0;
        kernel[[0, 0, 1, 0]] = 3.0;
        kernel[[0, 0, 1, 1]] = 4.0;

        // Apply forward pass
        let _output =
            conv2d_im2col(&input.view(), &kernel.view(), None, (1, 1), (0, 0), (1, 1)).unwrap();

        // Create gradient of output
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));

        // Apply backward pass for input
        let grad_input = conv2d_backward_input(
            &grad_output.view(),
            &kernel.view(),
            (1, 1, 3, 3),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .unwrap();

        // Check dimensions
        assert_eq!(grad_input.shape(), input.shape());

        // Each position receives weighted gradients from overlapping filters
        // For gradient=1 at each output position:
        // input[0,0] receives 1.0 from output[0,0]
        // input[0,1] receives 2.0 from output[0,0] + 1.0 from output[0,1] = 3.0
        // etc.
        assert_eq!(grad_input[[0, 0, 0, 0]], 1.0);
        assert_eq!(grad_input[[0, 0, 0, 1]], 3.0);
        assert_eq!(grad_input[[0, 0, 1, 0]], 4.0);
        assert_eq!(grad_input[[0, 0, 1, 1]], 10.0);
    }

    #[test]
    fn test_conv2d_backward_kernel() {
        // Create a simple 1x1x3x3 input with all ones
        let input = Array4::<f32>::ones((1, 1, 3, 3));

        // Create gradient of output, all ones
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));

        // Apply backward pass for kernel
        let grad_kernel = conv2d_backward_kernel(
            &input.view(),
            &grad_output.view(),
            (1, 1, 2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .unwrap();

        // Check dimensions
        assert_eq!(grad_kernel.shape(), &[1, 1, 2, 2]);

        // With all ones input and gradient, each kernel position accumulates
        // the number of times it overlaps with the input
        assert_eq!(grad_kernel[[0, 0, 0, 0]], 4.0); // Overlaps 4 times
        assert_eq!(grad_kernel[[0, 0, 0, 1]], 4.0);
        assert_eq!(grad_kernel[[0, 0, 1, 0]], 4.0);
        assert_eq!(grad_kernel[[0, 0, 1, 1]], 4.0);
    }

    #[test]
    fn test_conv2d_backward_bias() {
        // Create gradient of output
        let mut grad_output = Array4::<f32>::zeros((2, 3, 2, 2));
        for b in 0..2 {
            for c in 0..3 {
                for h in 0..2 {
                    for w in 0..2 {
                        grad_output[[b, c, h, w]] = 1.0;
                    }
                }
            }
        }

        // Apply backward pass for bias
        let grad_bias = conv2d_backward_bias(&grad_output.view()).unwrap();

        // Check dimensions
        assert_eq!(grad_bias.shape(), &[3]);

        // Each bias accumulates gradient from all positions and batches
        assert_eq!(grad_bias[0], 8.0); // 2 batches * 2*2 spatial = 8
        assert_eq!(grad_bias[1], 8.0);
        assert_eq!(grad_bias[2], 8.0);
    }

    #[test]
    fn test_conv_transpose2d() {
        // Create a simple 1x1x2x2 input
        let input = Array4::<f32>::ones((1, 1, 2, 2));

        // Create a simple 1x1x3x3 kernel with only top-left value set to 1.0
        let mut kernel = Array4::<f32>::zeros((1, 1, 3, 3));
        kernel[[0, 0, 0, 0]] = 1.0;
        // All other values are 0.0

        // Apply transposed convolution
        let output = conv_transpose2d(
            &input.view(),
            &kernel.view(),
            None,
            (2, 2), // stride
            (1, 1), // padding
            (0, 0), // output_padding
            (1, 1), // dilation
        )
        .unwrap();

        // Check dimensions
        // outputsize = (inputsize - 1) * stride - 2 * padding + kernelsize
        //            = (2 - 1) * 2 - 2 * 1 + 3 = 2 - 2 + 3 = 3
        assert_eq!(output.shape(), &[1, 1, 3, 3]);

        // Let's carefully trace through the algorithm for each input position:
        // The input is all ones at positions (0,0), (0,1), (1,0), and (1,1)
        // The kernel has a 1.0 at position (0,0,0,0) and zeros elsewhere
        // With stride=(2,2), padding=(1,1), the output coordinates are calculated as:
        //
        // For each input position (h,w) and kernel position (kh,kw):
        //   out_h = h*stride_h + kh*dilation_h - padding_h
        //   out_w = w*stride_w + kw*dilation_w - padding_w
        //
        // For input (0,0) and kernel (0,0):
        //   out_h = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //   out_w = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //
        // For input (0,1) and kernel (0,0):
        //   out_h = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //   out_w = 1*2 + 0*1 - 1 = 1
        //
        // For input (1,0) and kernel (0,0):
        //   out_h = 1*2 + 0*1 - 1 = 1
        //   out_w = 0*2 + 0*1 - 1 = -1 (out of bounds)
        //
        // For input (1,1) and kernel (0,0):
        //   out_h = 1*2 + 0*1 - 1 = 1
        //   out_w = 1*2 + 0*1 - 1 = 1
        //
        // So only the input at (1,1) with kernel at (0,0) contributes to the output at (1,1)

        // Verify the output
        assert_eq!(output[[0, 0, 0, 0]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 0, 1]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 0, 2]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 1, 0]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 1, 1]], 1.0); // From input (1,1) with kernel (0,0)
        assert_eq!(output[[0, 0, 1, 2]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 2, 0]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 2, 1]], 0.0); // No contribution
        assert_eq!(output[[0, 0, 2, 2]], 0.0); // No contribution
    }
}