1use scirs2_core::ndarray::{ArrayD, IxDyn};
7
/// Errors produced by the convolution routines in this module.
#[derive(Debug, Clone)]
pub enum ConvError {
    /// A kernel-size entry is invalid (e.g. zero); the payload explains why.
    InvalidKernelSize(String),
    /// A stride entry is invalid (e.g. zero); the payload explains why.
    InvalidStride(String),
    /// A padding specification is invalid; the payload explains why.
    InvalidPadding(String),
    /// A dilation entry is invalid (e.g. zero); the payload explains why.
    InvalidDilation(String),
    /// A tensor's shape disagrees with the shape the operation requires.
    ShapeMismatch {
        // Shape the operation expected.
        expected: Vec<usize>,
        // Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// A tensor has fewer dimensions than the operation requires.
    InsufficientDimensions { ndim: usize, required: usize },
    /// `groups` is zero or does not evenly divide the channel counts.
    InvalidGroups {
        groups: usize,
        in_channels: usize,
        out_channels: usize,
    },
    /// The input tensor contains no elements.
    EmptyInput,
}
35
36impl std::fmt::Display for ConvError {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 Self::InvalidKernelSize(msg) => write!(f, "Invalid kernel size: {msg}"),
40 Self::InvalidStride(msg) => write!(f, "Invalid stride: {msg}"),
41 Self::InvalidPadding(msg) => write!(f, "Invalid padding: {msg}"),
42 Self::InvalidDilation(msg) => write!(f, "Invalid dilation: {msg}"),
43 Self::ShapeMismatch { expected, got } => {
44 write!(f, "Shape mismatch: expected {expected:?}, got {got:?}")
45 }
46 Self::InsufficientDimensions { ndim, required } => {
47 write!(
48 f,
49 "Insufficient dimensions: got {ndim}, need at least {required}"
50 )
51 }
52 Self::InvalidGroups {
53 groups,
54 in_channels,
55 out_channels,
56 } => write!(
57 f,
58 "Invalid groups={groups}: in_channels={in_channels} and \
59 out_channels={out_channels} must both be divisible by groups"
60 ),
61 Self::EmptyInput => write!(f, "Empty input tensor"),
62 }
63 }
64}
65
// ConvError carries no underlying source error, so the default trait methods suffice.
impl std::error::Error for ConvError {}
67
/// Hyper-parameters describing an N-dimensional convolution.
///
/// The number of spatial dimensions is `kernel_size.len()`; `stride`,
/// `padding`, and `dilation` must all have that same length (enforced by
/// `validate`).
#[derive(Debug, Clone)]
pub struct ConvConfig {
    /// Kernel extent per spatial dimension.
    pub kernel_size: Vec<usize>,
    /// Step between adjacent output positions, per spatial dimension.
    pub stride: Vec<usize>,
    /// Zero-padding added on both sides, per spatial dimension.
    pub padding: Vec<usize>,
    /// Spacing between kernel taps, per spatial dimension (1 = dense).
    pub dilation: Vec<usize>,
    /// Number of channel groups (1 = ordinary convolution).
    pub groups: usize,
}
82
83impl ConvConfig {
84 pub fn new(kernel_size: Vec<usize>) -> Self {
87 let ndim = kernel_size.len();
88 Self {
89 kernel_size,
90 stride: vec![1; ndim],
91 padding: vec![0; ndim],
92 dilation: vec![1; ndim],
93 groups: 1,
94 }
95 }
96
97 pub fn with_stride(mut self, stride: Vec<usize>) -> Self {
99 self.stride = stride;
100 self
101 }
102
103 pub fn with_padding(mut self, padding: Vec<usize>) -> Self {
105 self.padding = padding;
106 self
107 }
108
109 pub fn with_dilation(mut self, dilation: Vec<usize>) -> Self {
111 self.dilation = dilation;
112 self
113 }
114
115 pub fn with_groups(mut self, groups: usize) -> Self {
117 self.groups = groups;
118 self
119 }
120
121 pub fn output_size(&self, input_size: usize, dim: usize) -> usize {
125 let k = self.kernel_size[dim];
126 let s = self.stride[dim];
127 let p = self.padding[dim];
128 let d = self.dilation[dim];
129 let effective_k = d * (k - 1) + 1;
130 (input_size + 2 * p - effective_k) / s + 1
131 }
132
133 pub fn validate(&self) -> Result<(), ConvError> {
135 let ndim = self.kernel_size.len();
136
137 if self.stride.len() != ndim {
139 return Err(ConvError::InvalidStride(format!(
140 "stride length {} != kernel_size length {ndim}",
141 self.stride.len()
142 )));
143 }
144 if self.padding.len() != ndim {
145 return Err(ConvError::InvalidPadding(format!(
146 "padding length {} != kernel_size length {ndim}",
147 self.padding.len()
148 )));
149 }
150 if self.dilation.len() != ndim {
151 return Err(ConvError::InvalidDilation(format!(
152 "dilation length {} != kernel_size length {ndim}",
153 self.dilation.len()
154 )));
155 }
156
157 for i in 0..ndim {
158 if self.kernel_size[i] == 0 {
159 return Err(ConvError::InvalidKernelSize(format!(
160 "kernel_size[{i}] must be > 0"
161 )));
162 }
163 if self.stride[i] == 0 {
164 return Err(ConvError::InvalidStride(format!("stride[{i}] must be > 0")));
165 }
166 if self.dilation[i] == 0 {
167 return Err(ConvError::InvalidDilation(format!(
168 "dilation[{i}] must be > 0"
169 )));
170 }
171 }
172
173 if self.groups == 0 {
174 return Err(ConvError::InvalidGroups {
175 groups: 0,
176 in_channels: 0,
177 out_channels: 0,
178 });
179 }
180
181 Ok(())
182 }
183
184 pub fn num_spatial_dims(&self) -> usize {
186 self.kernel_size.len()
187 }
188}
189
190pub fn conv1d(
196 input: &ArrayD<f64>,
197 weight: &ArrayD<f64>,
198 bias: Option<&ArrayD<f64>>,
199 config: &ConvConfig,
200) -> Result<ArrayD<f64>, ConvError> {
201 config.validate()?;
202
203 let in_shape = input.shape();
204 if in_shape.is_empty() || input.is_empty() {
205 return Err(ConvError::EmptyInput);
206 }
207 if in_shape.len() != 3 {
208 return Err(ConvError::InsufficientDimensions {
209 ndim: in_shape.len(),
210 required: 3,
211 });
212 }
213
214 let w_shape = weight.shape();
215 if w_shape.len() != 3 {
216 return Err(ConvError::InsufficientDimensions {
217 ndim: w_shape.len(),
218 required: 3,
219 });
220 }
221
222 let batch = in_shape[0];
223 let in_channels = in_shape[1];
224 let in_len = in_shape[2];
225 let out_channels = w_shape[0];
226 let kernel_len = config.kernel_size[0];
227 let groups = config.groups;
228
229 if !in_channels.is_multiple_of(groups) || !out_channels.is_multiple_of(groups) {
231 return Err(ConvError::InvalidGroups {
232 groups,
233 in_channels,
234 out_channels,
235 });
236 }
237
238 let out_len = config.output_size(in_len, 0);
239 let in_channels_per_group = in_channels / groups;
240 let out_channels_per_group = out_channels / groups;
241
242 let mut output = ArrayD::zeros(IxDyn(&[batch, out_channels, out_len]));
243
244 let stride = config.stride[0];
245 let padding = config.padding[0];
246 let dilation = config.dilation[0];
247
248 for b in 0..batch {
249 for g in 0..groups {
250 let oc_start = g * out_channels_per_group;
251 let ic_start = g * in_channels_per_group;
252
253 for oc in 0..out_channels_per_group {
254 for ol in 0..out_len {
255 let mut sum = 0.0_f64;
256 for ic in 0..in_channels_per_group {
257 for kl in 0..kernel_len {
258 let il_raw = ol as isize * stride as isize
259 + kl as isize * dilation as isize
260 - padding as isize;
261 if il_raw >= 0 && (il_raw as usize) < in_len {
262 let il = il_raw as usize;
263 sum += input[[b, ic_start + ic, il].as_ref()]
264 * weight[[oc_start + oc, ic, kl].as_ref()];
265 }
266 }
267 }
268 output[[b, oc_start + oc, ol].as_ref()] = sum;
269 }
270 }
271 }
272 }
273
274 if let Some(bias_arr) = bias {
276 for b in 0..batch {
277 for oc in 0..out_channels {
278 let bias_val = bias_arr[IxDyn(&[oc])];
279 for ol in 0..out_len {
280 output[[b, oc, ol].as_ref()] += bias_val;
281 }
282 }
283 }
284 }
285
286 Ok(output)
287}
288
289pub fn conv2d(
295 input: &ArrayD<f64>,
296 weight: &ArrayD<f64>,
297 bias: Option<&ArrayD<f64>>,
298 config: &ConvConfig,
299) -> Result<ArrayD<f64>, ConvError> {
300 config.validate()?;
301
302 let in_shape = input.shape();
303 if in_shape.is_empty() || input.is_empty() {
304 return Err(ConvError::EmptyInput);
305 }
306 if in_shape.len() != 4 {
307 return Err(ConvError::InsufficientDimensions {
308 ndim: in_shape.len(),
309 required: 4,
310 });
311 }
312
313 let w_shape = weight.shape();
314 if w_shape.len() != 4 {
315 return Err(ConvError::InsufficientDimensions {
316 ndim: w_shape.len(),
317 required: 4,
318 });
319 }
320
321 let batch = in_shape[0];
322 let in_channels = in_shape[1];
323 let in_h = in_shape[2];
324 let in_w = in_shape[3];
325 let out_channels = w_shape[0];
326 let groups = config.groups;
327
328 if !in_channels.is_multiple_of(groups) || !out_channels.is_multiple_of(groups) {
329 return Err(ConvError::InvalidGroups {
330 groups,
331 in_channels,
332 out_channels,
333 });
334 }
335
336 let out_h = config.output_size(in_h, 0);
337 let out_w = config.output_size(in_w, 1);
338 let in_channels_per_group = in_channels / groups;
339 let out_channels_per_group = out_channels / groups;
340
341 let k_h = config.kernel_size[0];
342 let k_w = config.kernel_size[1];
343 let stride_h = config.stride[0];
344 let stride_w = config.stride[1];
345 let pad_h = config.padding[0];
346 let pad_w = config.padding[1];
347 let dil_h = config.dilation[0];
348 let dil_w = config.dilation[1];
349
350 let mut output = ArrayD::zeros(IxDyn(&[batch, out_channels, out_h, out_w]));
351
352 for b in 0..batch {
353 for g in 0..groups {
354 let oc_start = g * out_channels_per_group;
355 let ic_start = g * in_channels_per_group;
356
357 for oc in 0..out_channels_per_group {
358 for oh in 0..out_h {
359 for ow in 0..out_w {
360 let mut sum = 0.0_f64;
361 for ic in 0..in_channels_per_group {
362 for kh in 0..k_h {
363 for kw in 0..k_w {
364 let ih_raw = oh as isize * stride_h as isize
365 + kh as isize * dil_h as isize
366 - pad_h as isize;
367 let iw_raw = ow as isize * stride_w as isize
368 + kw as isize * dil_w as isize
369 - pad_w as isize;
370 if ih_raw >= 0
371 && (ih_raw as usize) < in_h
372 && iw_raw >= 0
373 && (iw_raw as usize) < in_w
374 {
375 let ih = ih_raw as usize;
376 let iw = iw_raw as usize;
377 sum += input[IxDyn(&[b, ic_start + ic, ih, iw])]
378 * weight[IxDyn(&[oc_start + oc, ic, kh, kw])];
379 }
380 }
381 }
382 }
383 output[IxDyn(&[b, oc_start + oc, oh, ow])] = sum;
384 }
385 }
386 }
387 }
388 }
389
390 if let Some(bias_arr) = bias {
392 for b in 0..batch {
393 for oc in 0..out_channels {
394 let bias_val = bias_arr[IxDyn(&[oc])];
395 for oh in 0..out_h {
396 for ow in 0..out_w {
397 output[IxDyn(&[b, oc, oh, ow])] += bias_val;
398 }
399 }
400 }
401 }
402 }
403
404 Ok(output)
405}
406
407pub fn conv_transpose2d(
416 input: &ArrayD<f64>,
417 weight: &ArrayD<f64>,
418 bias: Option<&ArrayD<f64>>,
419 config: &ConvConfig,
420 output_padding: &[usize],
421) -> Result<ArrayD<f64>, ConvError> {
422 config.validate()?;
423
424 let in_shape = input.shape();
425 if in_shape.is_empty() || input.is_empty() {
426 return Err(ConvError::EmptyInput);
427 }
428 if in_shape.len() != 4 {
429 return Err(ConvError::InsufficientDimensions {
430 ndim: in_shape.len(),
431 required: 4,
432 });
433 }
434
435 let w_shape = weight.shape();
436 if w_shape.len() != 4 {
437 return Err(ConvError::InsufficientDimensions {
438 ndim: w_shape.len(),
439 required: 4,
440 });
441 }
442
443 let batch = in_shape[0];
444 let in_channels = in_shape[1];
445 let in_h = in_shape[2];
446 let in_w = in_shape[3];
447 let groups = config.groups;
448
449 let out_channels_per_group = w_shape[1];
451 let out_channels = out_channels_per_group * groups;
452
453 if !in_channels.is_multiple_of(groups) {
454 return Err(ConvError::InvalidGroups {
455 groups,
456 in_channels,
457 out_channels,
458 });
459 }
460
461 let in_channels_per_group = in_channels / groups;
462 let k_h = config.kernel_size[0];
463 let k_w = config.kernel_size[1];
464 let stride_h = config.stride[0];
465 let stride_w = config.stride[1];
466 let pad_h = config.padding[0];
467 let pad_w = config.padding[1];
468 let dil_h = config.dilation[0];
469 let dil_w = config.dilation[1];
470
471 let out_pad_h = if output_padding.is_empty() {
472 0
473 } else {
474 output_padding[0]
475 };
476 let out_pad_w = if output_padding.len() < 2 {
477 0
478 } else {
479 output_padding[1]
480 };
481
482 let out_h = (in_h - 1) * stride_h + dil_h * (k_h - 1) + 1 + out_pad_h - 2 * pad_h;
483 let out_w = (in_w - 1) * stride_w + dil_w * (k_w - 1) + 1 + out_pad_w - 2 * pad_w;
484
485 let mut output = ArrayD::zeros(IxDyn(&[batch, out_channels, out_h, out_w]));
486
487 for b in 0..batch {
489 for g in 0..groups {
490 let ic_start = g * in_channels_per_group;
491 let oc_start = g * out_channels_per_group;
492
493 for ic in 0..in_channels_per_group {
494 for ih in 0..in_h {
495 for iw in 0..in_w {
496 let input_val = input[IxDyn(&[b, ic_start + ic, ih, iw])];
497 for oc in 0..out_channels_per_group {
498 for kh in 0..k_h {
499 for kw in 0..k_w {
500 let oh_raw = ih as isize * stride_h as isize
501 + kh as isize * dil_h as isize
502 - pad_h as isize;
503 let ow_raw = iw as isize * stride_w as isize
504 + kw as isize * dil_w as isize
505 - pad_w as isize;
506 if oh_raw >= 0
507 && (oh_raw as usize) < out_h
508 && ow_raw >= 0
509 && (ow_raw as usize) < out_w
510 {
511 let oh = oh_raw as usize;
512 let ow = ow_raw as usize;
513 output[IxDyn(&[b, oc_start + oc, oh, ow])] +=
514 input_val * weight[IxDyn(&[ic_start + ic, oc, kh, kw])];
515 }
516 }
517 }
518 }
519 }
520 }
521 }
522 }
523 }
524
525 if let Some(bias_arr) = bias {
527 for b in 0..batch {
528 for oc in 0..out_channels {
529 let bias_val = bias_arr[IxDyn(&[oc])];
530 for oh in 0..out_h {
531 for ow in 0..out_w {
532 output[IxDyn(&[b, oc, oh, ow])] += bias_val;
533 }
534 }
535 }
536 }
537 }
538
539 Ok(output)
540}
541
542pub fn depthwise_conv2d(
547 input: &ArrayD<f64>,
548 weight: &ArrayD<f64>,
549 bias: Option<&ArrayD<f64>>,
550 config: &ConvConfig,
551) -> Result<ArrayD<f64>, ConvError> {
552 let in_shape = input.shape();
553 if in_shape.len() < 4 {
554 return Err(ConvError::InsufficientDimensions {
555 ndim: in_shape.len(),
556 required: 4,
557 });
558 }
559
560 let in_channels = in_shape[1];
561 let mut dw_config = config.clone();
562 dw_config.groups = in_channels;
563
564 conv2d(input, weight, bias, &dw_config)
565}
566
567pub fn im2col(
572 input: &ArrayD<f64>,
573 kernel_size: &[usize],
574 stride: &[usize],
575 padding: &[usize],
576 dilation: &[usize],
577) -> Result<ArrayD<f64>, ConvError> {
578 let in_shape = input.shape();
579 if in_shape.is_empty() || input.is_empty() {
580 return Err(ConvError::EmptyInput);
581 }
582 if in_shape.len() != 4 {
583 return Err(ConvError::InsufficientDimensions {
584 ndim: in_shape.len(),
585 required: 4,
586 });
587 }
588 if kernel_size.len() != 2 || stride.len() != 2 || padding.len() != 2 || dilation.len() != 2 {
589 return Err(ConvError::InvalidKernelSize(
590 "im2col requires exactly 2 spatial dimensions".to_string(),
591 ));
592 }
593
594 let batch = in_shape[0];
595 let channels = in_shape[1];
596 let in_h = in_shape[2];
597 let in_w = in_shape[3];
598 let k_h = kernel_size[0];
599 let k_w = kernel_size[1];
600 let s_h = stride[0];
601 let s_w = stride[1];
602 let p_h = padding[0];
603 let p_w = padding[1];
604 let d_h = dilation[0];
605 let d_w = dilation[1];
606
607 let eff_k_h = d_h * (k_h - 1) + 1;
608 let eff_k_w = d_w * (k_w - 1) + 1;
609 let out_h = (in_h + 2 * p_h - eff_k_h) / s_h + 1;
610 let out_w = (in_w + 2 * p_w - eff_k_w) / s_w + 1;
611
612 let col_rows = channels * k_h * k_w;
613 let col_cols = out_h * out_w;
614 let mut cols = ArrayD::zeros(IxDyn(&[batch, col_rows, col_cols]));
615
616 for b in 0..batch {
617 let mut col_idx = 0;
618 for c in 0..channels {
619 for kh in 0..k_h {
620 for kw in 0..k_w {
621 let mut spatial_idx = 0;
622 for oh in 0..out_h {
623 for ow in 0..out_w {
624 let ih_raw = oh as isize * s_h as isize + kh as isize * d_h as isize
625 - p_h as isize;
626 let iw_raw = ow as isize * s_w as isize + kw as isize * d_w as isize
627 - p_w as isize;
628 let val = if ih_raw >= 0
629 && (ih_raw as usize) < in_h
630 && iw_raw >= 0
631 && (iw_raw as usize) < in_w
632 {
633 input[IxDyn(&[b, c, ih_raw as usize, iw_raw as usize])]
634 } else {
635 0.0
636 };
637 cols[IxDyn(&[b, col_idx, spatial_idx])] = val;
638 spatial_idx += 1;
639 }
640 }
641 col_idx += 1;
642 }
643 }
644 }
645 }
646
647 Ok(cols)
648}
649
650pub fn col2im(
657 cols: &ArrayD<f64>,
658 output_size: &[usize],
659 kernel_size: &[usize],
660 stride: &[usize],
661 padding: &[usize],
662 dilation: &[usize],
663) -> Result<ArrayD<f64>, ConvError> {
664 let col_shape = cols.shape();
665 if col_shape.is_empty() || cols.is_empty() {
666 return Err(ConvError::EmptyInput);
667 }
668 if col_shape.len() != 3 {
669 return Err(ConvError::InsufficientDimensions {
670 ndim: col_shape.len(),
671 required: 3,
672 });
673 }
674 if output_size.len() != 4 {
675 return Err(ConvError::InvalidKernelSize(
676 "output_size must have 4 elements [batch, channels, H, W]".to_string(),
677 ));
678 }
679
680 let batch = output_size[0];
681 let channels = output_size[1];
682 let out_h_img = output_size[2];
683 let out_w_img = output_size[3];
684
685 let k_h = kernel_size[0];
686 let k_w = kernel_size[1];
687 let s_h = stride[0];
688 let s_w = stride[1];
689 let p_h = padding[0];
690 let p_w = padding[1];
691 let d_h = dilation[0];
692 let d_w = dilation[1];
693
694 let eff_k_h = d_h * (k_h - 1) + 1;
695 let eff_k_w = d_w * (k_w - 1) + 1;
696 let col_out_h = (out_h_img + 2 * p_h - eff_k_h) / s_h + 1;
697 let col_out_w = (out_w_img + 2 * p_w - eff_k_w) / s_w + 1;
698
699 let mut output = ArrayD::zeros(IxDyn(&[batch, channels, out_h_img, out_w_img]));
700
701 for b in 0..batch {
702 let mut col_idx = 0;
703 for c in 0..channels {
704 for kh in 0..k_h {
705 for kw in 0..k_w {
706 let mut spatial_idx = 0;
707 for oh in 0..col_out_h {
708 for ow in 0..col_out_w {
709 let ih_raw = oh as isize * s_h as isize + kh as isize * d_h as isize
710 - p_h as isize;
711 let iw_raw = ow as isize * s_w as isize + kw as isize * d_w as isize
712 - p_w as isize;
713 if ih_raw >= 0
714 && (ih_raw as usize) < out_h_img
715 && iw_raw >= 0
716 && (iw_raw as usize) < out_w_img
717 {
718 output[IxDyn(&[b, c, ih_raw as usize, iw_raw as usize])] +=
719 cols[IxDyn(&[b, col_idx, spatial_idx])];
720 }
721 spatial_idx += 1;
722 }
723 }
724 col_idx += 1;
725 }
726 }
727 }
728 }
729
730 Ok(output)
731}
732
/// Summary statistics for a single convolution layer.
#[derive(Debug, Clone)]
pub struct ConvStats {
    /// Input shape `[batch, in_channels, spatial...]`.
    pub input_shape: Vec<usize>,
    /// Computed output shape `[batch, out_channels, spatial...]`.
    pub output_shape: Vec<usize>,
    /// Weight tensor shape.
    pub kernel_shape: Vec<usize>,
    /// Learnable parameter count (weights plus one bias per output channel).
    pub num_parameters: usize,
    /// Estimated floating-point operations (2 per multiply-accumulate).
    pub flops: u64,
    /// Receptive-field extent per spatial dimension (dilated kernel extent).
    pub receptive_field: Vec<usize>,
}
749
750impl ConvStats {
751 pub fn compute(
756 input_shape: &[usize],
757 weight_shape: &[usize],
758 config: &ConvConfig,
759 ) -> Result<Self, ConvError> {
760 config.validate()?;
761
762 if input_shape.len() < 3 {
763 return Err(ConvError::InsufficientDimensions {
764 ndim: input_shape.len(),
765 required: 3,
766 });
767 }
768 if weight_shape.len() < 3 {
769 return Err(ConvError::InsufficientDimensions {
770 ndim: weight_shape.len(),
771 required: 3,
772 });
773 }
774
775 let batch = input_shape[0];
776 let out_channels = weight_shape[0];
777 let ndim = config.num_spatial_dims();
778
779 let mut output_spatial = Vec::with_capacity(ndim);
781 for d in 0..ndim {
782 let in_size = input_shape[2 + d];
783 output_spatial.push(config.output_size(in_size, d));
784 }
785
786 let mut output_shape = vec![batch, out_channels];
787 output_shape.extend_from_slice(&output_spatial);
788
789 let weight_params: usize = weight_shape.iter().product();
791 let num_parameters = weight_params + out_channels; let kernel_volume: usize = config.kernel_size.iter().product();
796 let in_channels_per_group = if config.groups > 0 {
797 weight_shape[1]
798 } else {
799 return Err(ConvError::InvalidGroups {
800 groups: 0,
801 in_channels: 0,
802 out_channels: 0,
803 });
804 };
805 let output_elements: u64 = output_shape.iter().map(|&s| s as u64).product();
806 let macs_per_element = (kernel_volume * in_channels_per_group) as u64;
807 let flops = output_elements * macs_per_element * 2;
808
809 let receptive_field: Vec<usize> = (0..ndim)
811 .map(|d| config.dilation[d] * (config.kernel_size[d] - 1) + 1)
812 .collect();
813
814 Ok(Self {
815 input_shape: input_shape.to_vec(),
816 output_shape,
817 kernel_shape: weight_shape.to_vec(),
818 num_parameters,
819 flops,
820 receptive_field,
821 })
822 }
823
824 pub fn summary(&self) -> String {
826 format!(
827 "ConvStats {{ input: {:?}, output: {:?}, kernel: {:?}, \
828 params: {}, flops: {}, receptive_field: {:?} }}",
829 self.input_shape,
830 self.output_shape,
831 self.kernel_shape,
832 self.num_parameters,
833 self.flops,
834 self.receptive_field,
835 )
836 }
837}
838
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{ArrayD, IxDyn};

    #[test]
    fn test_conv_config_output_size() {
        // 3x3 kernel, stride 1, padding 1 is shape-preserving: (8 + 2 - 3) + 1 = 8.
        let cfg = ConvConfig::new(vec![3, 3]).with_padding(vec![1, 1]);
        assert_eq!(cfg.output_size(8, 0), 8);
        assert_eq!(cfg.output_size(8, 1), 8);
    }

    #[test]
    fn test_conv_config_validate_valid() {
        let cfg = ConvConfig::new(vec![3, 3])
            .with_stride(vec![1, 1])
            .with_padding(vec![1, 1])
            .with_dilation(vec![1, 1])
            .with_groups(1);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_conv_config_validate_zero_kernel() {
        // A zero kernel entry must be rejected with a message naming the field.
        let cfg = ConvConfig::new(vec![0, 3]);
        let err = cfg.validate();
        assert!(err.is_err());
        let msg = format!("{}", err.expect_err("expected error"));
        assert!(msg.contains("kernel_size"));
    }

    #[test]
    fn test_conv1d_basic() {
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 1, 5]), vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .expect("input shape");
        let weight =
            ArrayD::from_shape_vec(IxDyn(&[1, 1, 3]), vec![1.0, 1.0, 1.0]).expect("weight shape");
        let cfg = ConvConfig::new(vec![3]);

        let out = conv1d(&input, &weight, None, &cfg).expect("conv1d");
        assert_eq!(out.shape(), &[1, 1, 3]);
        // All-ones kernel -> sliding sums: 1+2+3, 2+3+4, 3+4+5.
        assert!((out[IxDyn(&[0, 0, 0])] - 6.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 0, 1])] - 9.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 0, 2])] - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_conv1d_with_bias() {
        let input =
            ArrayD::from_shape_vec(IxDyn(&[1, 1, 3]), vec![1.0, 2.0, 3.0]).expect("input shape");
        // Output channel 0 picks the first input element, channel 1 the last.
        let weight = ArrayD::from_shape_vec(IxDyn(&[2, 1, 3]), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
            .expect("weight shape");
        let bias = ArrayD::from_shape_vec(IxDyn(&[2]), vec![10.0, 20.0]).expect("bias shape");
        let cfg = ConvConfig::new(vec![3]);

        let out = conv1d(&input, &weight, Some(&bias), &cfg).expect("conv1d");
        assert_eq!(out.shape(), &[1, 2, 1]);
        // 1 + 10 and 3 + 20.
        assert!((out[IxDyn(&[0, 0, 0])] - 11.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 1, 0])] - 23.0).abs() < 1e-10);
    }

    #[test]
    fn test_conv2d_identity_kernel() {
        let input = ArrayD::from_shape_vec(
            IxDyn(&[1, 2, 2, 2]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        )
        .expect("input shape");
        // 1x1 all-ones kernel over 2 input channels: output = per-pixel channel sum.
        let weight =
            ArrayD::from_shape_vec(IxDyn(&[1, 2, 1, 1]), vec![1.0, 1.0]).expect("weight shape");
        let cfg = ConvConfig::new(vec![1, 1]);

        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        // 1+5, 2+6, 3+7, 4+8.
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 6.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 0, 0, 1])] - 8.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 0, 1, 0])] - 10.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 0, 1, 1])] - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_conv2d_basic() {
        let input =
            ArrayD::from_shape_vec(IxDyn(&[1, 1, 4, 4]), (1..=16).map(|x| x as f64).collect())
                .expect("input shape");
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]);

        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 2, 2]);

        // Top-left 3x3 window of the 1..16 grid: 1+2+3 + 5+6+7 + 9+10+11 = 54.
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 54.0).abs() < 1e-10);
    }

    #[test]
    fn test_conv2d_with_padding() {
        let input = ArrayD::ones(IxDyn(&[1, 1, 4, 4]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_padding(vec![1, 1]);

        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 4, 4]);

        // Interior position sees the full 3x3 window of ones.
        assert!((out[IxDyn(&[0, 0, 1, 1])] - 9.0).abs() < 1e-10);
        // Corner sees only a 2x2 window (the rest is zero padding).
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_conv2d_stride2() {
        let input = ArrayD::ones(IxDyn(&[1, 1, 4, 4]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3])
            .with_stride(vec![2, 2])
            .with_padding(vec![1, 1]);

        // (4 + 2 - 3) / 2 + 1 = 2 per dimension.
        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
    }

    #[test]
    fn test_conv2d_groups() {
        // Channel 0 all ones, channel 1 all twos.
        let input = ArrayD::from_shape_vec(
            IxDyn(&[1, 2, 3, 3]),
            vec![
                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
                2.0, 2.0,
            ],
        )
        .expect("input shape");
        let weight = ArrayD::ones(IxDyn(&[2, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_groups(2);

        // Each group convolves its own channel: 9*1 and 9*2.
        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 2, 1, 1]);
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 9.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 1, 0, 0])] - 18.0).abs() < 1e-10);
    }

    #[test]
    fn test_conv2d_dilation() {
        let input = ArrayD::ones(IxDyn(&[1, 1, 7, 7]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_dilation(vec![2, 2]);

        // Effective kernel extent 2*(3-1)+1 = 5, so output is (7-5)+1 = 3.
        let out = conv2d(&input, &weight, None, &cfg).expect("conv2d");
        assert_eq!(out.shape(), &[1, 1, 3, 3]);
        assert!((out[IxDyn(&[0, 0, 1, 1])] - 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_conv_transpose2d_basic() {
        let input = ArrayD::ones(IxDyn(&[1, 1, 2, 2]));
        let weight = ArrayD::ones(IxDyn(&[1, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]).with_stride(vec![2, 2]);

        // (2-1)*2 + (3-1) + 1 = 5 per dimension.
        let out = conv_transpose2d(&input, &weight, None, &cfg, &[]).expect("conv_transpose2d");
        assert_eq!(out.shape(), &[1, 1, 5, 5]);
        // Center receives contributions from all four input elements.
        assert!((out[IxDyn(&[0, 0, 2, 2])] - 4.0).abs() < 1e-10);
        // Corner receives only one.
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_depthwise_conv2d() {
        // Channel 0 all ones, channel 1 all twos.
        let input = ArrayD::from_shape_vec(
            IxDyn(&[1, 2, 3, 3]),
            vec![
                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
                2.0, 2.0,
            ],
        )
        .expect("input shape");
        let weight = ArrayD::ones(IxDyn(&[2, 1, 3, 3]));
        let cfg = ConvConfig::new(vec![3, 3]);

        // Depthwise: each channel convolved independently -> 9*1 and 9*2.
        let out = depthwise_conv2d(&input, &weight, None, &cfg).expect("depthwise");
        assert_eq!(out.shape(), &[1, 2, 1, 1]);
        assert!((out[IxDyn(&[0, 0, 0, 0])] - 9.0).abs() < 1e-10);
        assert!((out[IxDyn(&[0, 1, 0, 0])] - 18.0).abs() < 1e-10);
    }

    #[test]
    fn test_im2col_shape() {
        // 2 channels * 3*3 taps = 18 rows; 2*2 output positions = 4 columns.
        let input = ArrayD::ones(IxDyn(&[1, 2, 4, 4]));
        let cols = im2col(&input, &[3, 3], &[1, 1], &[0, 0], &[1, 1]).expect("im2col");
        assert_eq!(cols.shape(), &[1, 18, 4]);
    }

    #[test]
    fn test_im2col_values() {
        let input = ArrayD::from_shape_vec(
            IxDyn(&[1, 1, 3, 3]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .expect("input shape");

        let cols = im2col(&input, &[2, 2], &[1, 1], &[0, 0], &[1, 1]).expect("im2col");
        assert_eq!(cols.shape(), &[1, 4, 4]);

        // First column is the top-left 2x2 patch, one row per kernel tap.
        assert!((cols[IxDyn(&[0, 0, 0])] - 1.0).abs() < 1e-10);
        assert!((cols[IxDyn(&[0, 1, 0])] - 2.0).abs() < 1e-10);
        assert!((cols[IxDyn(&[0, 2, 0])] - 4.0).abs() < 1e-10);
        assert!((cols[IxDyn(&[0, 3, 0])] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_col2im_roundtrip_no_overlap() {
        let input =
            ArrayD::from_shape_vec(IxDyn(&[1, 1, 4, 4]), (1..=16).map(|x| x as f64).collect())
                .expect("input shape");

        // 2x2 kernel at stride 2 tiles the image with no overlap, so
        // im2col followed by col2im must reproduce the input exactly.
        let kernel = [2, 2];
        let stride = [2, 2];
        let padding = [0, 0];
        let dilation = [1, 1];

        let cols = im2col(&input, &kernel, &stride, &padding, &dilation).expect("im2col");
        let reconstructed =
            col2im(&cols, &[1, 1, 4, 4], &kernel, &stride, &padding, &dilation).expect("col2im");

        assert_eq!(reconstructed.shape(), input.shape());
        for (a, b) in input.iter().zip(reconstructed.iter()) {
            assert!((a - b).abs() < 1e-10, "mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_conv_stats_flops() {
        let cfg = ConvConfig::new(vec![3, 3]);
        let stats = ConvStats::compute(&[1, 3, 32, 32], &[16, 3, 3, 3], &cfg).expect("conv stats");
        assert!(stats.flops > 0);
    }

    #[test]
    fn test_conv_stats_parameters() {
        // 16*3*3*3 = 432 weights plus 16 biases.
        let cfg = ConvConfig::new(vec![3, 3]);
        let stats = ConvStats::compute(&[1, 3, 32, 32], &[16, 3, 3, 3], &cfg).expect("conv stats");
        assert_eq!(stats.num_parameters, 432 + 16);
    }

    #[test]
    fn test_conv_stats_summary_nonempty() {
        let cfg = ConvConfig::new(vec![3, 3]);
        let stats = ConvStats::compute(&[1, 3, 32, 32], &[16, 3, 3, 3], &cfg).expect("conv stats");
        let s = stats.summary();
        assert!(!s.is_empty());
        assert!(s.contains("ConvStats"));
    }

    #[test]
    fn test_conv_error_display() {
        // Every variant must render a non-empty Display message.
        let errors: Vec<ConvError> = vec![
            ConvError::InvalidKernelSize("zero".to_string()),
            ConvError::InvalidStride("zero".to_string()),
            ConvError::InvalidPadding("negative".to_string()),
            ConvError::InvalidDilation("zero".to_string()),
            ConvError::ShapeMismatch {
                expected: vec![1, 2],
                got: vec![3, 4],
            },
            ConvError::InsufficientDimensions {
                ndim: 2,
                required: 4,
            },
            ConvError::InvalidGroups {
                groups: 3,
                in_channels: 4,
                out_channels: 6,
            },
            ConvError::EmptyInput,
        ];
        for err in &errors {
            let msg = format!("{err}");
            assert!(!msg.is_empty(), "error display should be non-empty");
        }
    }
}