Skip to main content

tensorlogic_scirs_backend/
convolution.rs

1//! Convolution operations for neural network tensor processing.
2//!
3//! Provides 1D convolution, 2D convolution, transposed convolution, depthwise convolution,
4//! and im2col/col2im utilities for efficient convolution via matrix multiplication.
5
6use scirs2_core::ndarray::{ArrayD, IxDyn};
7
/// Errors that can occur during convolution operations.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// instead of matching on `Debug` strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConvError {
    /// Kernel size contains a zero or is otherwise invalid.
    InvalidKernelSize(String),
    /// Stride contains a zero or is otherwise invalid.
    InvalidStride(String),
    /// Padding value is invalid.
    InvalidPadding(String),
    /// Dilation value is invalid.
    InvalidDilation(String),
    /// Shape mismatch between expected and actual tensors.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// Input tensor does not have enough dimensions.
    InsufficientDimensions {
        /// Number of dimensions the tensor actually has.
        ndim: usize,
        /// Minimum number of dimensions required.
        required: usize,
    },
    /// Groups parameter is invalid for the given channel counts.
    InvalidGroups {
        /// The offending `groups` value.
        groups: usize,
        /// Input channel count (0 when unknown at validation time).
        in_channels: usize,
        /// Output channel count (0 when unknown at validation time).
        out_channels: usize,
    },
    /// Input tensor is empty.
    EmptyInput,
}
35
36impl std::fmt::Display for ConvError {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::InvalidKernelSize(msg) => write!(f, "Invalid kernel size: {msg}"),
40            Self::InvalidStride(msg) => write!(f, "Invalid stride: {msg}"),
41            Self::InvalidPadding(msg) => write!(f, "Invalid padding: {msg}"),
42            Self::InvalidDilation(msg) => write!(f, "Invalid dilation: {msg}"),
43            Self::ShapeMismatch { expected, got } => {
44                write!(f, "Shape mismatch: expected {expected:?}, got {got:?}")
45            }
46            Self::InsufficientDimensions { ndim, required } => {
47                write!(
48                    f,
49                    "Insufficient dimensions: got {ndim}, need at least {required}"
50                )
51            }
52            Self::InvalidGroups {
53                groups,
54                in_channels,
55                out_channels,
56            } => write!(
57                f,
58                "Invalid groups={groups}: in_channels={in_channels} and \
59                 out_channels={out_channels} must both be divisible by groups"
60            ),
61            Self::EmptyInput => write!(f, "Empty input tensor"),
62        }
63    }
64}
65
66impl std::error::Error for ConvError {}
67
/// Convolution configuration specifying kernel size, stride, padding, dilation, and groups.
///
/// All spatial parameter vectors (`kernel_size`, `stride`, `padding`,
/// `dilation`) must share the same length; call `validate` to check.
#[derive(Debug, Clone)]
pub struct ConvConfig {
    /// Kernel size for each spatial dimension (e.g. `[3, 3]`).
    pub kernel_size: Vec<usize>,
    /// Stride for each spatial dimension (e.g. `[1, 1]`).
    pub stride: Vec<usize>,
    /// Zero-padding on each side for each spatial dimension (e.g. `[1, 1]`).
    pub padding: Vec<usize>,
    /// Dilation factor for each spatial dimension (e.g. `[1, 1]`).
    pub dilation: Vec<usize>,
    /// Number of groups for grouped convolution (1 = standard convolution).
    pub groups: usize,
}
82
83impl ConvConfig {
84    /// Create a new convolution config with the given kernel size.
85    /// Defaults: stride=1, padding=0, dilation=1, groups=1.
86    pub fn new(kernel_size: Vec<usize>) -> Self {
87        let ndim = kernel_size.len();
88        Self {
89            kernel_size,
90            stride: vec![1; ndim],
91            padding: vec![0; ndim],
92            dilation: vec![1; ndim],
93            groups: 1,
94        }
95    }
96
97    /// Set the stride (builder pattern).
98    pub fn with_stride(mut self, stride: Vec<usize>) -> Self {
99        self.stride = stride;
100        self
101    }
102
103    /// Set the padding (builder pattern).
104    pub fn with_padding(mut self, padding: Vec<usize>) -> Self {
105        self.padding = padding;
106        self
107    }
108
109    /// Set the dilation (builder pattern).
110    pub fn with_dilation(mut self, dilation: Vec<usize>) -> Self {
111        self.dilation = dilation;
112        self
113    }
114
115    /// Set the number of groups (builder pattern).
116    pub fn with_groups(mut self, groups: usize) -> Self {
117        self.groups = groups;
118        self
119    }
120
121    /// Compute the output size for one spatial dimension.
122    ///
123    /// Formula: `(input_size + 2*padding - dilation*(kernel-1) - 1) / stride + 1`
124    pub fn output_size(&self, input_size: usize, dim: usize) -> usize {
125        let k = self.kernel_size[dim];
126        let s = self.stride[dim];
127        let p = self.padding[dim];
128        let d = self.dilation[dim];
129        let effective_k = d * (k - 1) + 1;
130        (input_size + 2 * p - effective_k) / s + 1
131    }
132
133    /// Validate the configuration, returning an error if any parameter is invalid.
134    pub fn validate(&self) -> Result<(), ConvError> {
135        let ndim = self.kernel_size.len();
136
137        // All spatial parameter vectors must have the same length
138        if self.stride.len() != ndim {
139            return Err(ConvError::InvalidStride(format!(
140                "stride length {} != kernel_size length {ndim}",
141                self.stride.len()
142            )));
143        }
144        if self.padding.len() != ndim {
145            return Err(ConvError::InvalidPadding(format!(
146                "padding length {} != kernel_size length {ndim}",
147                self.padding.len()
148            )));
149        }
150        if self.dilation.len() != ndim {
151            return Err(ConvError::InvalidDilation(format!(
152                "dilation length {} != kernel_size length {ndim}",
153                self.dilation.len()
154            )));
155        }
156
157        for i in 0..ndim {
158            if self.kernel_size[i] == 0 {
159                return Err(ConvError::InvalidKernelSize(format!(
160                    "kernel_size[{i}] must be > 0"
161                )));
162            }
163            if self.stride[i] == 0 {
164                return Err(ConvError::InvalidStride(format!("stride[{i}] must be > 0")));
165            }
166            if self.dilation[i] == 0 {
167                return Err(ConvError::InvalidDilation(format!(
168                    "dilation[{i}] must be > 0"
169                )));
170            }
171        }
172
173        if self.groups == 0 {
174            return Err(ConvError::InvalidGroups {
175                groups: 0,
176                in_channels: 0,
177                out_channels: 0,
178            });
179        }
180
181        Ok(())
182    }
183
184    /// Number of spatial dimensions (length of kernel_size).
185    pub fn num_spatial_dims(&self) -> usize {
186        self.kernel_size.len()
187    }
188}
189
190/// 1D convolution.
191///
192/// - Input shape: `[batch, in_channels, length]`
193/// - Weight shape: `[out_channels, in_channels/groups, kernel_length]`
194/// - Output shape: `[batch, out_channels, output_length]`
195pub fn conv1d(
196    input: &ArrayD<f64>,
197    weight: &ArrayD<f64>,
198    bias: Option<&ArrayD<f64>>,
199    config: &ConvConfig,
200) -> Result<ArrayD<f64>, ConvError> {
201    config.validate()?;
202
203    let in_shape = input.shape();
204    if in_shape.is_empty() || input.is_empty() {
205        return Err(ConvError::EmptyInput);
206    }
207    if in_shape.len() != 3 {
208        return Err(ConvError::InsufficientDimensions {
209            ndim: in_shape.len(),
210            required: 3,
211        });
212    }
213
214    let w_shape = weight.shape();
215    if w_shape.len() != 3 {
216        return Err(ConvError::InsufficientDimensions {
217            ndim: w_shape.len(),
218            required: 3,
219        });
220    }
221
222    let batch = in_shape[0];
223    let in_channels = in_shape[1];
224    let in_len = in_shape[2];
225    let out_channels = w_shape[0];
226    let kernel_len = config.kernel_size[0];
227    let groups = config.groups;
228
229    // Validate groups
230    if !in_channels.is_multiple_of(groups) || !out_channels.is_multiple_of(groups) {
231        return Err(ConvError::InvalidGroups {
232            groups,
233            in_channels,
234            out_channels,
235        });
236    }
237
238    let out_len = config.output_size(in_len, 0);
239    let in_channels_per_group = in_channels / groups;
240    let out_channels_per_group = out_channels / groups;
241
242    let mut output = ArrayD::zeros(IxDyn(&[batch, out_channels, out_len]));
243
244    let stride = config.stride[0];
245    let padding = config.padding[0];
246    let dilation = config.dilation[0];
247
248    for b in 0..batch {
249        for g in 0..groups {
250            let oc_start = g * out_channels_per_group;
251            let ic_start = g * in_channels_per_group;
252
253            for oc in 0..out_channels_per_group {
254                for ol in 0..out_len {
255                    let mut sum = 0.0_f64;
256                    for ic in 0..in_channels_per_group {
257                        for kl in 0..kernel_len {
258                            let il_raw = ol as isize * stride as isize
259                                + kl as isize * dilation as isize
260                                - padding as isize;
261                            if il_raw >= 0 && (il_raw as usize) < in_len {
262                                let il = il_raw as usize;
263                                sum += input[[b, ic_start + ic, il].as_ref()]
264                                    * weight[[oc_start + oc, ic, kl].as_ref()];
265                            }
266                        }
267                    }
268                    output[[b, oc_start + oc, ol].as_ref()] = sum;
269                }
270            }
271        }
272    }
273
274    // Apply bias
275    if let Some(bias_arr) = bias {
276        for b in 0..batch {
277            for oc in 0..out_channels {
278                let bias_val = bias_arr[IxDyn(&[oc])];
279                for ol in 0..out_len {
280                    output[[b, oc, ol].as_ref()] += bias_val;
281                }
282            }
283        }
284    }
285
286    Ok(output)
287}
288
289/// 2D convolution.
290///
291/// - Input shape: `[batch, in_channels, height, width]`
292/// - Weight shape: `[out_channels, in_channels/groups, kH, kW]`
293/// - Output shape: `[batch, out_channels, outH, outW]`
294pub fn conv2d(
295    input: &ArrayD<f64>,
296    weight: &ArrayD<f64>,
297    bias: Option<&ArrayD<f64>>,
298    config: &ConvConfig,
299) -> Result<ArrayD<f64>, ConvError> {
300    config.validate()?;
301
302    let in_shape = input.shape();
303    if in_shape.is_empty() || input.is_empty() {
304        return Err(ConvError::EmptyInput);
305    }
306    if in_shape.len() != 4 {
307        return Err(ConvError::InsufficientDimensions {
308            ndim: in_shape.len(),
309            required: 4,
310        });
311    }
312
313    let w_shape = weight.shape();
314    if w_shape.len() != 4 {
315        return Err(ConvError::InsufficientDimensions {
316            ndim: w_shape.len(),
317            required: 4,
318        });
319    }
320
321    let batch = in_shape[0];
322    let in_channels = in_shape[1];
323    let in_h = in_shape[2];
324    let in_w = in_shape[3];
325    let out_channels = w_shape[0];
326    let groups = config.groups;
327
328    if !in_channels.is_multiple_of(groups) || !out_channels.is_multiple_of(groups) {
329        return Err(ConvError::InvalidGroups {
330            groups,
331            in_channels,
332            out_channels,
333        });
334    }
335
336    let out_h = config.output_size(in_h, 0);
337    let out_w = config.output_size(in_w, 1);
338    let in_channels_per_group = in_channels / groups;
339    let out_channels_per_group = out_channels / groups;
340
341    let k_h = config.kernel_size[0];
342    let k_w = config.kernel_size[1];
343    let stride_h = config.stride[0];
344    let stride_w = config.stride[1];
345    let pad_h = config.padding[0];
346    let pad_w = config.padding[1];
347    let dil_h = config.dilation[0];
348    let dil_w = config.dilation[1];
349
350    let mut output = ArrayD::zeros(IxDyn(&[batch, out_channels, out_h, out_w]));
351
352    for b in 0..batch {
353        for g in 0..groups {
354            let oc_start = g * out_channels_per_group;
355            let ic_start = g * in_channels_per_group;
356
357            for oc in 0..out_channels_per_group {
358                for oh in 0..out_h {
359                    for ow in 0..out_w {
360                        let mut sum = 0.0_f64;
361                        for ic in 0..in_channels_per_group {
362                            for kh in 0..k_h {
363                                for kw in 0..k_w {
364                                    let ih_raw = oh as isize * stride_h as isize
365                                        + kh as isize * dil_h as isize
366                                        - pad_h as isize;
367                                    let iw_raw = ow as isize * stride_w as isize
368                                        + kw as isize * dil_w as isize
369                                        - pad_w as isize;
370                                    if ih_raw >= 0
371                                        && (ih_raw as usize) < in_h
372                                        && iw_raw >= 0
373                                        && (iw_raw as usize) < in_w
374                                    {
375                                        let ih = ih_raw as usize;
376                                        let iw = iw_raw as usize;
377                                        sum += input[IxDyn(&[b, ic_start + ic, ih, iw])]
378                                            * weight[IxDyn(&[oc_start + oc, ic, kh, kw])];
379                                    }
380                                }
381                            }
382                        }
383                        output[IxDyn(&[b, oc_start + oc, oh, ow])] = sum;
384                    }
385                }
386            }
387        }
388    }
389
390    // Apply bias
391    if let Some(bias_arr) = bias {
392        for b in 0..batch {
393            for oc in 0..out_channels {
394                let bias_val = bias_arr[IxDyn(&[oc])];
395                for oh in 0..out_h {
396                    for ow in 0..out_w {
397                        output[IxDyn(&[b, oc, oh, ow])] += bias_val;
398                    }
399                }
400            }
401        }
402    }
403
404    Ok(output)
405}
406
407/// Transposed 2D convolution (deconvolution / fractionally-strided convolution).
408///
409/// - Input shape: `[batch, in_channels, height, width]`
410/// - Weight shape: `[in_channels, out_channels/groups, kH, kW]`
411/// - Output shape: `[batch, out_channels, outH, outW]`
412///
413/// Output size formula per dimension:
414/// `(input - 1) * stride - 2*padding + dilation*(kernel - 1) + output_padding + 1`
415pub fn conv_transpose2d(
416    input: &ArrayD<f64>,
417    weight: &ArrayD<f64>,
418    bias: Option<&ArrayD<f64>>,
419    config: &ConvConfig,
420    output_padding: &[usize],
421) -> Result<ArrayD<f64>, ConvError> {
422    config.validate()?;
423
424    let in_shape = input.shape();
425    if in_shape.is_empty() || input.is_empty() {
426        return Err(ConvError::EmptyInput);
427    }
428    if in_shape.len() != 4 {
429        return Err(ConvError::InsufficientDimensions {
430            ndim: in_shape.len(),
431            required: 4,
432        });
433    }
434
435    let w_shape = weight.shape();
436    if w_shape.len() != 4 {
437        return Err(ConvError::InsufficientDimensions {
438            ndim: w_shape.len(),
439            required: 4,
440        });
441    }
442
443    let batch = in_shape[0];
444    let in_channels = in_shape[1];
445    let in_h = in_shape[2];
446    let in_w = in_shape[3];
447    let groups = config.groups;
448
449    // For transposed conv, weight is [in_channels, out_channels/groups, kH, kW]
450    let out_channels_per_group = w_shape[1];
451    let out_channels = out_channels_per_group * groups;
452
453    if !in_channels.is_multiple_of(groups) {
454        return Err(ConvError::InvalidGroups {
455            groups,
456            in_channels,
457            out_channels,
458        });
459    }
460
461    let in_channels_per_group = in_channels / groups;
462    let k_h = config.kernel_size[0];
463    let k_w = config.kernel_size[1];
464    let stride_h = config.stride[0];
465    let stride_w = config.stride[1];
466    let pad_h = config.padding[0];
467    let pad_w = config.padding[1];
468    let dil_h = config.dilation[0];
469    let dil_w = config.dilation[1];
470
471    let out_pad_h = if output_padding.is_empty() {
472        0
473    } else {
474        output_padding[0]
475    };
476    let out_pad_w = if output_padding.len() < 2 {
477        0
478    } else {
479        output_padding[1]
480    };
481
482    let out_h = (in_h - 1) * stride_h + dil_h * (k_h - 1) + 1 + out_pad_h - 2 * pad_h;
483    let out_w = (in_w - 1) * stride_w + dil_w * (k_w - 1) + 1 + out_pad_w - 2 * pad_w;
484
485    let mut output = ArrayD::zeros(IxDyn(&[batch, out_channels, out_h, out_w]));
486
487    // Transposed convolution: for each input position, scatter-add weighted kernel to output
488    for b in 0..batch {
489        for g in 0..groups {
490            let ic_start = g * in_channels_per_group;
491            let oc_start = g * out_channels_per_group;
492
493            for ic in 0..in_channels_per_group {
494                for ih in 0..in_h {
495                    for iw in 0..in_w {
496                        let input_val = input[IxDyn(&[b, ic_start + ic, ih, iw])];
497                        for oc in 0..out_channels_per_group {
498                            for kh in 0..k_h {
499                                for kw in 0..k_w {
500                                    let oh_raw = ih as isize * stride_h as isize
501                                        + kh as isize * dil_h as isize
502                                        - pad_h as isize;
503                                    let ow_raw = iw as isize * stride_w as isize
504                                        + kw as isize * dil_w as isize
505                                        - pad_w as isize;
506                                    if oh_raw >= 0
507                                        && (oh_raw as usize) < out_h
508                                        && ow_raw >= 0
509                                        && (ow_raw as usize) < out_w
510                                    {
511                                        let oh = oh_raw as usize;
512                                        let ow = ow_raw as usize;
513                                        output[IxDyn(&[b, oc_start + oc, oh, ow])] +=
514                                            input_val * weight[IxDyn(&[ic_start + ic, oc, kh, kw])];
515                                    }
516                                }
517                            }
518                        }
519                    }
520                }
521            }
522        }
523    }
524
525    // Apply bias
526    if let Some(bias_arr) = bias {
527        for b in 0..batch {
528            for oc in 0..out_channels {
529                let bias_val = bias_arr[IxDyn(&[oc])];
530                for oh in 0..out_h {
531                    for ow in 0..out_w {
532                        output[IxDyn(&[b, oc, oh, ow])] += bias_val;
533                    }
534                }
535            }
536        }
537    }
538
539    Ok(output)
540}
541
542/// Depthwise 2D convolution: groups == in_channels == out_channels.
543///
544/// Convenience wrapper around [`conv2d`] that sets `groups = in_channels`.
545/// Weight shape: `[in_channels, 1, kH, kW]` (one filter per channel).
546pub fn depthwise_conv2d(
547    input: &ArrayD<f64>,
548    weight: &ArrayD<f64>,
549    bias: Option<&ArrayD<f64>>,
550    config: &ConvConfig,
551) -> Result<ArrayD<f64>, ConvError> {
552    let in_shape = input.shape();
553    if in_shape.len() < 4 {
554        return Err(ConvError::InsufficientDimensions {
555            ndim: in_shape.len(),
556            required: 4,
557        });
558    }
559
560    let in_channels = in_shape[1];
561    let mut dw_config = config.clone();
562    dw_config.groups = in_channels;
563
564    conv2d(input, weight, bias, &dw_config)
565}
566
567/// im2col: unfold input patches into columns for efficient convolution via GEMM.
568///
569/// - Input shape: `[batch, channels, H, W]`
570/// - Output shape: `[batch, channels * kH * kW, outH * outW]`
571pub fn im2col(
572    input: &ArrayD<f64>,
573    kernel_size: &[usize],
574    stride: &[usize],
575    padding: &[usize],
576    dilation: &[usize],
577) -> Result<ArrayD<f64>, ConvError> {
578    let in_shape = input.shape();
579    if in_shape.is_empty() || input.is_empty() {
580        return Err(ConvError::EmptyInput);
581    }
582    if in_shape.len() != 4 {
583        return Err(ConvError::InsufficientDimensions {
584            ndim: in_shape.len(),
585            required: 4,
586        });
587    }
588    if kernel_size.len() != 2 || stride.len() != 2 || padding.len() != 2 || dilation.len() != 2 {
589        return Err(ConvError::InvalidKernelSize(
590            "im2col requires exactly 2 spatial dimensions".to_string(),
591        ));
592    }
593
594    let batch = in_shape[0];
595    let channels = in_shape[1];
596    let in_h = in_shape[2];
597    let in_w = in_shape[3];
598    let k_h = kernel_size[0];
599    let k_w = kernel_size[1];
600    let s_h = stride[0];
601    let s_w = stride[1];
602    let p_h = padding[0];
603    let p_w = padding[1];
604    let d_h = dilation[0];
605    let d_w = dilation[1];
606
607    let eff_k_h = d_h * (k_h - 1) + 1;
608    let eff_k_w = d_w * (k_w - 1) + 1;
609    let out_h = (in_h + 2 * p_h - eff_k_h) / s_h + 1;
610    let out_w = (in_w + 2 * p_w - eff_k_w) / s_w + 1;
611
612    let col_rows = channels * k_h * k_w;
613    let col_cols = out_h * out_w;
614    let mut cols = ArrayD::zeros(IxDyn(&[batch, col_rows, col_cols]));
615
616    for b in 0..batch {
617        let mut col_idx = 0;
618        for c in 0..channels {
619            for kh in 0..k_h {
620                for kw in 0..k_w {
621                    let mut spatial_idx = 0;
622                    for oh in 0..out_h {
623                        for ow in 0..out_w {
624                            let ih_raw = oh as isize * s_h as isize + kh as isize * d_h as isize
625                                - p_h as isize;
626                            let iw_raw = ow as isize * s_w as isize + kw as isize * d_w as isize
627                                - p_w as isize;
628                            let val = if ih_raw >= 0
629                                && (ih_raw as usize) < in_h
630                                && iw_raw >= 0
631                                && (iw_raw as usize) < in_w
632                            {
633                                input[IxDyn(&[b, c, ih_raw as usize, iw_raw as usize])]
634                            } else {
635                                0.0
636                            };
637                            cols[IxDyn(&[b, col_idx, spatial_idx])] = val;
638                            spatial_idx += 1;
639                        }
640                    }
641                    col_idx += 1;
642                }
643            }
644        }
645    }
646
647    Ok(cols)
648}
649
650/// col2im: fold columns back into image form (inverse of im2col).
651///
652/// - Cols shape: `[batch, channels * kH * kW, outH * outW]`
653/// - Output shape: `[batch, channels, H, W]` (specified via `output_size`)
654///
655/// Where overlapping patches are summed (accumulated).
656pub fn col2im(
657    cols: &ArrayD<f64>,
658    output_size: &[usize],
659    kernel_size: &[usize],
660    stride: &[usize],
661    padding: &[usize],
662    dilation: &[usize],
663) -> Result<ArrayD<f64>, ConvError> {
664    let col_shape = cols.shape();
665    if col_shape.is_empty() || cols.is_empty() {
666        return Err(ConvError::EmptyInput);
667    }
668    if col_shape.len() != 3 {
669        return Err(ConvError::InsufficientDimensions {
670            ndim: col_shape.len(),
671            required: 3,
672        });
673    }
674    if output_size.len() != 4 {
675        return Err(ConvError::InvalidKernelSize(
676            "output_size must have 4 elements [batch, channels, H, W]".to_string(),
677        ));
678    }
679
680    let batch = output_size[0];
681    let channels = output_size[1];
682    let out_h_img = output_size[2];
683    let out_w_img = output_size[3];
684
685    let k_h = kernel_size[0];
686    let k_w = kernel_size[1];
687    let s_h = stride[0];
688    let s_w = stride[1];
689    let p_h = padding[0];
690    let p_w = padding[1];
691    let d_h = dilation[0];
692    let d_w = dilation[1];
693
694    let eff_k_h = d_h * (k_h - 1) + 1;
695    let eff_k_w = d_w * (k_w - 1) + 1;
696    let col_out_h = (out_h_img + 2 * p_h - eff_k_h) / s_h + 1;
697    let col_out_w = (out_w_img + 2 * p_w - eff_k_w) / s_w + 1;
698
699    let mut output = ArrayD::zeros(IxDyn(&[batch, channels, out_h_img, out_w_img]));
700
701    for b in 0..batch {
702        let mut col_idx = 0;
703        for c in 0..channels {
704            for kh in 0..k_h {
705                for kw in 0..k_w {
706                    let mut spatial_idx = 0;
707                    for oh in 0..col_out_h {
708                        for ow in 0..col_out_w {
709                            let ih_raw = oh as isize * s_h as isize + kh as isize * d_h as isize
710                                - p_h as isize;
711                            let iw_raw = ow as isize * s_w as isize + kw as isize * d_w as isize
712                                - p_w as isize;
713                            if ih_raw >= 0
714                                && (ih_raw as usize) < out_h_img
715                                && iw_raw >= 0
716                                && (iw_raw as usize) < out_w_img
717                            {
718                                output[IxDyn(&[b, c, ih_raw as usize, iw_raw as usize])] +=
719                                    cols[IxDyn(&[b, col_idx, spatial_idx])];
720                            }
721                            spatial_idx += 1;
722                        }
723                    }
724                    col_idx += 1;
725                }
726            }
727        }
728    }
729
730    Ok(output)
731}
732
/// Statistics about a convolution operation (parameter count, FLOPs, receptive field).
///
/// Built from shapes alone via `ConvStats::compute`; no tensor data is needed.
#[derive(Debug, Clone)]
pub struct ConvStats {
    /// Shape of the input tensor.
    pub input_shape: Vec<usize>,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
    /// Shape of the kernel/weight tensor.
    pub kernel_shape: Vec<usize>,
    /// Total number of learnable parameters (weights + bias if present).
    pub num_parameters: usize,
    /// Estimated floating-point operations (multiply-accumulate counted as 2).
    pub flops: u64,
    /// Receptive field size in each spatial dimension.
    pub receptive_field: Vec<usize>,
}
749
750impl ConvStats {
751    /// Compute convolution statistics from input/weight shapes and config.
752    ///
753    /// Input shape: `[batch, in_channels, spatial...]`
754    /// Weight shape: `[out_channels, in_channels/groups, kernel_spatial...]`
755    pub fn compute(
756        input_shape: &[usize],
757        weight_shape: &[usize],
758        config: &ConvConfig,
759    ) -> Result<Self, ConvError> {
760        config.validate()?;
761
762        if input_shape.len() < 3 {
763            return Err(ConvError::InsufficientDimensions {
764                ndim: input_shape.len(),
765                required: 3,
766            });
767        }
768        if weight_shape.len() < 3 {
769            return Err(ConvError::InsufficientDimensions {
770                ndim: weight_shape.len(),
771                required: 3,
772            });
773        }
774
775        let batch = input_shape[0];
776        let out_channels = weight_shape[0];
777        let ndim = config.num_spatial_dims();
778
779        // Compute output spatial dimensions
780        let mut output_spatial = Vec::with_capacity(ndim);
781        for d in 0..ndim {
782            let in_size = input_shape[2 + d];
783            output_spatial.push(config.output_size(in_size, d));
784        }
785
786        let mut output_shape = vec![batch, out_channels];
787        output_shape.extend_from_slice(&output_spatial);
788
789        // Number of parameters: weight elements + bias (out_channels)
790        let weight_params: usize = weight_shape.iter().product();
791        let num_parameters = weight_params + out_channels; // assume bias present
792
793        // FLOPs: for each output element, we do kernel_volume * in_channels_per_group
794        // multiply-accumulates. Each MAC = 2 ops (mul + add).
795        let kernel_volume: usize = config.kernel_size.iter().product();
796        let in_channels_per_group = if config.groups > 0 {
797            weight_shape[1]
798        } else {
799            return Err(ConvError::InvalidGroups {
800                groups: 0,
801                in_channels: 0,
802                out_channels: 0,
803            });
804        };
805        let output_elements: u64 = output_shape.iter().map(|&s| s as u64).product();
806        let macs_per_element = (kernel_volume * in_channels_per_group) as u64;
807        let flops = output_elements * macs_per_element * 2;
808
809        // Receptive field per spatial dim: dilation * (kernel - 1) + 1
810        let receptive_field: Vec<usize> = (0..ndim)
811            .map(|d| config.dilation[d] * (config.kernel_size[d] - 1) + 1)
812            .collect();
813
814        Ok(Self {
815            input_shape: input_shape.to_vec(),
816            output_shape,
817            kernel_shape: weight_shape.to_vec(),
818            num_parameters,
819            flops,
820            receptive_field,
821        })
822    }
823
824    /// Human-readable summary string.
825    pub fn summary(&self) -> String {
826        format!(
827            "ConvStats {{ input: {:?}, output: {:?}, kernel: {:?}, \
828             params: {}, flops: {}, receptive_field: {:?} }}",
829            self.input_shape,
830            self.output_shape,
831            self.kernel_shape,
832            self.num_parameters,
833            self.flops,
834            self.receptive_field,
835        )
836    }
837}
838
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{ArrayD, IxDyn};

    /// Asserts that `value` is within 1e-10 of `expected`.
    fn assert_near(value: f64, expected: f64) {
        assert!(
            (value - expected).abs() < 1e-10,
            "expected {expected}, got {value}"
        );
    }

    #[test]
    fn test_conv_config_output_size() {
        // A 3x3 kernel with unit stride/dilation and padding 1 preserves size.
        let cfg = ConvConfig::new(vec![3, 3]).with_padding(vec![1, 1]);
        for dim in 0..2 {
            assert_eq!(cfg.output_size(8, dim), 8);
        }
    }

    #[test]
    fn test_conv_config_validate_valid() {
        // A fully specified, well-formed configuration passes validation.
        let cfg = ConvConfig::new(vec![3, 3])
            .with_stride(vec![1, 1])
            .with_padding(vec![1, 1])
            .with_dilation(vec![1, 1])
            .with_groups(1);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_conv_config_validate_zero_kernel() {
        // A zero kernel dimension must be rejected, naming the offending field.
        let result = ConvConfig::new(vec![0, 3]).validate();
        let err = result.expect_err("expected error");
        assert!(err.to_string().contains("kernel_size"));
    }

    #[test]
    fn test_conv1d_basic() {
        // [1,1,5] input against a [1,1,3] box filter, no padding, stride 1:
        // output length = (5 - 3) / 1 + 1 = 3.
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 1, 5]), vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .expect("input shape");
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3]));

        let out = conv1d(&input, &weight, None, &ConvConfig::new(vec![3])).expect("conv1d");
        assert_eq!(out.shape(), &[1, 1, 3]);
        // Sliding-window sums: [1+2+3, 2+3+4, 3+4+5].
        for (pos, &expected) in [6.0, 9.0, 12.0].iter().enumerate() {
            assert_near(out[IxDyn(&[0, 0, pos])], expected);
        }
    }

    #[test]
    fn test_conv1d_with_bias() {
        // Two output channels: one selects the first input element, the other
        // the last; each gets its own bias.
        let input =
            ArrayD::from_shape_vec(IxDyn(&[1, 1, 3]), vec![1.0, 2.0, 3.0]).expect("input shape");
        let weight = ArrayD::from_shape_vec(IxDyn(&[2, 1, 3]), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
            .expect("weight shape");
        let bias = ArrayD::from_shape_vec(IxDyn(&[2]), vec![10.0, 20.0]).expect("bias shape");

        let out =
            conv1d(&input, &weight, Some(&bias), &ConvConfig::new(vec![3])).expect("conv1d");
        assert_eq!(out.shape(), &[1, 2, 1]);
        assert_near(out[IxDyn(&[0, 0, 0])], 11.0); // 1*1 + bias 10
        assert_near(out[IxDyn(&[0, 1, 0])], 23.0); // 1*3 + bias 20
    }

    #[test]
    fn test_conv2d_identity_kernel() {
        // A 1x1 kernel performs channel mixing only; the spatial grid is kept.
        let data: Vec<f64> = (1..=8).map(|x| x as f64).collect();
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 2, 2, 2]), data).expect("input shape");
        // Single output channel equal to ch0 + ch1.
        let weight = ArrayD::ones(IxDyn(&[1, 2, 1, 1]));

        let out = conv2d(&input, &weight, None, &ConvConfig::new(vec![1, 1])).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        // Pairwise sums of channel 0 (1..=4) and channel 1 (5..=8).
        for &(h, w, expected) in &[(0, 0, 6.0), (0, 1, 8.0), (1, 0, 10.0), (1, 1, 12.0)] {
            assert_near(out[IxDyn(&[0, 0, h, w])], expected);
        }
    }

    #[test]
    fn test_conv2d_basic() {
        // 4x4 input holding 1..16, 3x3 sum kernel, no padding → 2x2 output.
        let values: Vec<f64> = (1..=16).map(|x| x as f64).collect();
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 1, 4, 4]), values).expect("input shape");
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));

        let out = conv2d(&input, &weight, None, &ConvConfig::new(vec![3, 3])).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        // Top-left 3x3 window: 1+2+3 + 5+6+7 + 9+10+11 = 54.
        assert_near(out[IxDyn(&[0, 0, 0, 0])], 54.0);
    }

    #[test]
    fn test_conv2d_with_padding() {
        // padding 1 with a 3x3 kernel keeps the 4x4 spatial size.
        let input = ArrayD::ones(IxDyn(&[1, 1, 4, 4]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_padding(vec![1, 1]);

        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
        // An interior position sees all 9 ones; a corner sees only 4.
        assert_near(out[IxDyn(&[0, 0, 1, 1])], 9.0);
        assert_near(out[IxDyn(&[0, 0, 0, 0])], 4.0);
    }

    #[test]
    fn test_conv2d_stride2() {
        // stride 2 with padding 1: (4 + 2 - 3) / 2 + 1 = 2 per dim.
        let input = ArrayD::ones(IxDyn(&[1, 1, 4, 4]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3])
            .with_stride(vec![2, 2])
            .with_padding(vec![1, 1]);

        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_conv2d_groups() {
        // groups=2 with 2 in/out channels: each group convolves one channel.
        let mut data = vec![1.0; 9]; // group 0 input: all ones
        data.extend_from_slice(&[2.0; 9]); // group 1 input: all twos
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 2, 3, 3]), data).expect("input shape");
        // Weight [2, 1, 3, 3]: one input channel per group.
        let weight = ArrayD::ones(IxDyn(&[2, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_groups(2);

        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 2, 1, 1]);
        assert_near(out[IxDyn(&[0, 0, 0, 0])], 9.0); // nine 1s
        assert_near(out[IxDyn(&[0, 1, 0, 0])], 18.0); // nine 2s
    }

    #[test]
    fn test_conv2d_dilation() {
        // dilation 2 stretches a 3x3 kernel to a 5x5 receptive field.
        let input = ArrayD::ones(IxDyn(&[1, 1, 7, 7]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_dilation(vec![2, 2]);

        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        // output_size = (7 - 5) / 1 + 1 = 3 per spatial dim.
        assert_eq!(out.shape(), &[1, 1, 3, 3]);
        // All nine sampled positions are in bounds and hold 1 → sum = 9.
        assert_near(out[IxDyn(&[0, 0, 1, 1])], 9.0);
    }

    #[test]
    fn test_conv_transpose2d_basic() {
        // Transposed conv upsamples: output = (2-1)*2 + 3 = 5 per dim.
        let input = ArrayD::ones(IxDyn(&[1, 1, 2, 2]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_stride(vec![2, 2]);

        let out = conv_transpose2d(&input, &weight, None, &cfg, &[]).expect("conv_transpose2d");
        assert_eq!(out.shape(), &[1, 1, 5, 5]);
        // The center cell (2,2) is reached by all four input positions.
        assert_near(out[IxDyn(&[0, 0, 2, 2])], 4.0);
        // A corner is reached by exactly one.
        assert_near(out[IxDyn(&[0, 0, 0, 0])], 1.0);
    }

    #[test]
    fn test_depthwise_conv2d() {
        // Depthwise: each channel is convolved independently, no mixing.
        let mut data = vec![1.0; 9]; // channel 0
        data.extend_from_slice(&[2.0; 9]); // channel 1
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 2, 3, 3]), data).expect("input shape");
        // Weight [2, 1, 3, 3]: one 3x3 filter per channel.
        let weight = ArrayD::ones(IxDyn(&[2, 1, 3, 3]));

        let out = depthwise_conv2d(&input, &weight, None, &ConvConfig::new(vec![3, 3]))
            .expect("depthwise");
        assert_eq!(out.shape(), &[1, 2, 1, 1]);
        assert_near(out[IxDyn(&[0, 0, 0, 0])], 9.0);
        assert_near(out[IxDyn(&[0, 1, 0, 0])], 18.0);
    }

    #[test]
    fn test_im2col_shape() {
        // Columns: [batch, C*kh*kw, out_h*out_w] = [1, 2*3*3, 2*2].
        let input = ArrayD::ones(IxDyn(&[1, 2, 4, 4]));
        let cols = im2col(&input, &[3, 3], &[1, 1], &[0, 0], &[1, 1]).expect("im2col");
        assert_eq!(cols.shape(), &[1, 18, 4]);
    }

    #[test]
    fn test_im2col_values() {
        // 3x3 input holding 1..9, 2x2 patches, stride 1:
        // output spatial (3-2)/1+1 = 2 per dim → four columns.
        let values: Vec<f64> = (1..=9).map(|x| x as f64).collect();
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 1, 3, 3]), values).expect("input shape");

        let cols = im2col(&input, &[2, 2], &[1, 1], &[0, 0], &[1, 1]).expect("im2col");
        assert_eq!(cols.shape(), &[1, 4, 4]);

        // Column 0 is the top-left patch, flattened row-major: [1, 2, 4, 5].
        for (row, &expected) in [1.0, 2.0, 4.0, 5.0].iter().enumerate() {
            assert_near(cols[IxDyn(&[0, row, 0])], expected);
        }
    }

    #[test]
    fn test_col2im_roundtrip_no_overlap() {
        // stride >= kernel → non-overlapping patches, so the roundtrip is exact.
        let values: Vec<f64> = (1..=16).map(|x| x as f64).collect();
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 1, 4, 4]), values).expect("input shape");

        let (kernel, stride, padding, dilation) = ([2, 2], [2, 2], [0, 0], [1, 1]);
        let cols = im2col(&input, &kernel, &stride, &padding, &dilation).expect("im2col");
        let reconstructed =
            col2im(&cols, &[1, 1, 4, 4], &kernel, &stride, &padding, &dilation).expect("col2im");

        assert_eq!(reconstructed.shape(), input.shape());
        for (a, b) in input.iter().zip(reconstructed.iter()) {
            assert!((a - b).abs() < 1e-10, "mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_conv_stats_flops() {
        let cfg = ConvConfig::new(vec![3, 3]);
        let stats =
            ConvStats::compute(&[1, 3, 32, 32], &[16, 3, 3, 3], &cfg).expect("conv stats");
        assert!(stats.flops > 0);
    }

    #[test]
    fn test_conv_stats_parameters() {
        // Weight [16, 3, 3, 3] = 432 weights plus one bias per output channel.
        let cfg = ConvConfig::new(vec![3, 3]);
        let stats =
            ConvStats::compute(&[1, 3, 32, 32], &[16, 3, 3, 3], &cfg).expect("conv stats");
        assert_eq!(stats.num_parameters, 432 + 16);
    }

    #[test]
    fn test_conv_stats_summary_nonempty() {
        let cfg = ConvConfig::new(vec![3, 3]);
        let stats =
            ConvStats::compute(&[1, 3, 32, 32], &[16, 3, 3, 3], &cfg).expect("conv stats");
        let text = stats.summary();
        assert!(!text.is_empty());
        assert!(text.contains("ConvStats"));
    }

    #[test]
    fn test_conv_error_display() {
        // Every variant must render to a non-empty Display message.
        let cases = [
            ConvError::InvalidKernelSize("zero".to_string()),
            ConvError::InvalidStride("zero".to_string()),
            ConvError::InvalidPadding("negative".to_string()),
            ConvError::InvalidDilation("zero".to_string()),
            ConvError::ShapeMismatch {
                expected: vec![1, 2],
                got: vec![3, 4],
            },
            ConvError::InsufficientDimensions {
                ndim: 2,
                required: 4,
            },
            ConvError::InvalidGroups {
                groups: 3,
                in_channels: 4,
                out_channels: 6,
            },
            ConvError::EmptyInput,
        ];
        for case in &cases {
            assert!(!case.to_string().is_empty(), "error display should be non-empty");
        }
    }
}