use crate::{BackendResult, Buffer, Device};
use torsh_core::dtype::DType;

#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::ToString, vec, vec::Vec};
/// Kinds of convolution operations a backend may implement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvolutionType {
    /// 1D convolution.
    Conv1D,
    /// Standard 2D convolution.
    Conv2D,
    /// 3D convolution.
    Conv3D,
    /// Depthwise 2D convolution (one filter per input channel; groups == channels).
    DepthwiseConv2D,
    /// Separable 2D convolution (depthwise followed by pointwise).
    SeparableConv2D,
    /// Transposed ("deconvolution") 2D convolution.
    ConvTranspose2D,
    /// 2D convolution with dilation.
    DilatedConv2D,
    /// 2D convolution over channel groups.
    GroupedConv2D,
}
33
/// Algorithm used to compute a convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvolutionAlgorithm {
    /// Let the backend choose (resolved by `ConvolutionOps::select_algorithm`).
    Auto,
    /// Naive nested-loop convolution.
    Direct,
    /// Lower to an im2col matrix, then matrix multiply.
    Im2col,
    /// Winograd minimal-filtering algorithm (intended for small kernels).
    Winograd,
    /// FFT-based convolution (intended for large kernels).
    FftBased,
    /// Backend-specific optimized path.
    Optimized,
}
50
/// Padding policy for a convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingMode {
    /// No padding.
    Valid,
    /// Pad so the output spatial size matches the input.
    Same,
    /// Explicit padding values taken from `ConvolutionConfig::padding`.
    Custom,
}
61
/// Full description of a convolution: tensor shapes, hyper-parameters,
/// element type, and the algorithm to use.
#[derive(Debug, Clone)]
pub struct ConvolutionConfig {
    /// Which convolution variant this config describes.
    pub conv_type: ConvolutionType,
    /// Input tensor dimensions (NCHW for 2D: `[batch, channels, h, w]`).
    pub input_dims: Vec<usize>,
    /// Output tensor dimensions (same layout as `input_dims`).
    pub output_dims: Vec<usize>,
    /// Kernel dimensions (`[out_ch, in_ch, k_h, k_w]` for standard 2D;
    /// `[ch, 1, k_h, k_w]` for depthwise).
    pub kernel_dims: Vec<usize>,
    /// Stride per spatial dimension.
    pub strides: Vec<usize>,
    /// Padding per spatial dimension.
    pub padding: Vec<usize>,
    /// Dilation per spatial dimension.
    pub dilation: Vec<usize>,
    /// Number of channel groups (1 = standard convolution).
    pub groups: usize,
    /// Padding policy.
    pub padding_mode: PaddingMode,
    /// Element type of all buffers.
    pub dtype: DType,
    /// Requested algorithm (`Auto` defers to the backend).
    pub algorithm: ConvolutionAlgorithm,
}
88
89impl ConvolutionConfig {
90 pub fn conv2d(
92 batch_size: usize,
93 in_channels: usize,
94 out_channels: usize,
95 input_size: (usize, usize),
96 kernel_size: (usize, usize),
97 stride: (usize, usize),
98 padding: (usize, usize),
99 ) -> Self {
100 let (in_h, in_w) = input_size;
101 let (k_h, k_w) = kernel_size;
102 let (s_h, s_w) = stride;
103 let (p_h, p_w) = padding;
104
105 let out_h = (in_h + 2 * p_h - k_h) / s_h + 1;
107 let out_w = (in_w + 2 * p_w - k_w) / s_w + 1;
108
109 Self {
110 conv_type: ConvolutionType::Conv2D,
111 input_dims: vec![batch_size, in_channels, in_h, in_w],
112 output_dims: vec![batch_size, out_channels, out_h, out_w],
113 kernel_dims: vec![out_channels, in_channels, k_h, k_w],
114 strides: vec![s_h, s_w],
115 padding: vec![p_h, p_w],
116 dilation: vec![1, 1],
117 groups: 1,
118 padding_mode: PaddingMode::Custom,
119 dtype: DType::F32,
120 algorithm: ConvolutionAlgorithm::Auto,
121 }
122 }
123
124 pub fn depthwise_conv2d(
126 batch_size: usize,
127 channels: usize,
128 input_size: (usize, usize),
129 kernel_size: (usize, usize),
130 stride: (usize, usize),
131 padding: (usize, usize),
132 ) -> Self {
133 let mut config = Self::conv2d(
134 batch_size,
135 channels,
136 channels,
137 input_size,
138 kernel_size,
139 stride,
140 padding,
141 );
142 config.conv_type = ConvolutionType::DepthwiseConv2D;
143 config.groups = channels;
144 config.kernel_dims = vec![channels, 1, kernel_size.0, kernel_size.1];
145 config
146 }
147
148 pub fn with_algorithm(mut self, algorithm: ConvolutionAlgorithm) -> Self {
150 self.algorithm = algorithm;
151 self
152 }
153
154 pub fn with_dtype(mut self, dtype: DType) -> Self {
156 self.dtype = dtype;
157 self
158 }
159
160 pub fn with_dilation(mut self, dilation: Vec<usize>) -> Self {
162 self.dilation = dilation;
163 self
164 }
165
166 pub fn input_elements(&self) -> usize {
168 self.input_dims.iter().product()
169 }
170
171 pub fn output_elements(&self) -> usize {
173 self.output_dims.iter().product()
174 }
175
176 pub fn kernel_elements(&self) -> usize {
178 self.kernel_dims.iter().product()
179 }
180
181 pub fn input_buffer_size(&self) -> usize {
183 let element_size = match self.dtype {
184 DType::F32 => 4,
185 DType::F64 => 8,
186 DType::F16 => 2,
187 _ => 4,
188 };
189 self.input_elements() * element_size
190 }
191
192 pub fn output_buffer_size(&self) -> usize {
194 let element_size = match self.dtype {
195 DType::F32 => 4,
196 DType::F64 => 8,
197 DType::F16 => 2,
198 _ => 4,
199 };
200 self.output_elements() * element_size
201 }
202
203 pub fn kernel_buffer_size(&self) -> usize {
205 let element_size = match self.dtype {
206 DType::F32 => 4,
207 DType::F64 => 8,
208 DType::F16 => 2,
209 _ => 4,
210 };
211 self.kernel_elements() * element_size
212 }
213
214 pub fn is_valid(&self) -> bool {
216 !self.input_dims.is_empty()
217 && !self.output_dims.is_empty()
218 && !self.kernel_dims.is_empty()
219 && self.input_dims.iter().all(|&d| d > 0)
220 && self.output_dims.iter().all(|&d| d > 0)
221 && self.kernel_dims.iter().all(|&d| d > 0)
222 && self.groups > 0
223 }
224}
225
/// Backend interface for convolution operations.
///
/// Buffer arguments are opaque device buffers; shapes and hyper-parameters
/// come from `config` or the explicit tuple arguments. Backends advertise
/// what they can run via the `supports_*` / `supported_*` methods.
#[async_trait::async_trait]
pub trait ConvolutionOps: Send + Sync {
    /// General convolution driven entirely by `config`.
    async fn convolution(
        &self,
        device: &Device,
        input: &Buffer,
        kernel: &Buffer,
        bias: Option<&Buffer>,
        output: &Buffer,
        config: &ConvolutionConfig,
    ) -> BackendResult<()>;

    /// Standard 2D convolution with explicit stride/padding/dilation.
    async fn conv2d(
        &self,
        device: &Device,
        input: &Buffer,
        kernel: &Buffer,
        bias: Option<&Buffer>,
        output: &Buffer,
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
    ) -> BackendResult<()>;

    /// Depthwise 2D convolution (one filter per input channel).
    async fn depthwise_conv2d(
        &self,
        device: &Device,
        input: &Buffer,
        kernel: &Buffer,
        bias: Option<&Buffer>,
        output: &Buffer,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> BackendResult<()>;

    /// Transposed 2D convolution; `output_padding` disambiguates the
    /// output size when multiple sizes map to the same input size.
    async fn conv_transpose2d(
        &self,
        device: &Device,
        input: &Buffer,
        kernel: &Buffer,
        bias: Option<&Buffer>,
        output: &Buffer,
        stride: (usize, usize),
        padding: (usize, usize),
        output_padding: (usize, usize),
    ) -> BackendResult<()>;

    /// Grouped 2D convolution over `groups` channel groups.
    async fn grouped_conv2d(
        &self,
        device: &Device,
        input: &Buffer,
        kernel: &Buffer,
        bias: Option<&Buffer>,
        output: &Buffer,
        groups: usize,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> BackendResult<()>;

    /// Resolves `ConvolutionAlgorithm::Auto` to a concrete algorithm for
    /// `config`; non-`Auto` requests are returned unchanged.
    fn select_algorithm(&self, config: &ConvolutionConfig) -> ConvolutionAlgorithm;

    /// Whether this backend implements convolution at all.
    fn supports_convolution(&self) -> bool;

    /// Convolution types this backend can execute.
    fn supported_conv_types(&self) -> Vec<ConvolutionType>;

    /// Algorithms this backend can execute.
    fn supported_algorithms(&self) -> Vec<ConvolutionAlgorithm>;
}
303
/// Tuning hints used when auto-selecting a convolution algorithm.
#[derive(Debug, Clone)]
pub struct ConvolutionPerformanceHints {
    /// Preferred algorithm for small kernels.
    pub small_kernel_algorithm: ConvolutionAlgorithm,
    /// Preferred algorithm for large kernels.
    pub large_kernel_algorithm: ConvolutionAlgorithm,
    /// Kernel size at or above which FFT-based convolution is chosen.
    pub fft_threshold: usize,
    /// Kernel size at or below which Winograd is chosen.
    pub winograd_threshold: usize,
    /// Tile size for tiled implementations.
    pub tile_size: (usize, usize),
    /// Memory bandwidth hint — units unspecified in this file
    /// (presumably GB/s; TODO confirm against the backend that fills this in).
    pub memory_bandwidth: f32,
    /// Compute throughput hint — units unspecified in this file
    /// (presumably GFLOPS; TODO confirm against the backend that fills this in).
    pub compute_throughput: f32,
}
322
impl Default for ConvolutionPerformanceHints {
    /// Conservative defaults: Winograd for kernels <= 6, FFT for
    /// kernels >= 7, 16x16 tiles.
    fn default() -> Self {
        Self {
            small_kernel_algorithm: ConvolutionAlgorithm::Winograd,
            large_kernel_algorithm: ConvolutionAlgorithm::FftBased,
            // Kernel sizes >= 7 route to FFT in `select_algorithm`.
            fft_threshold: 7,
            // Kernel sizes <= 6 route to Winograd in `select_algorithm`.
            winograd_threshold: 6,
            tile_size: (16, 16),
            // NOTE(review): units for the two figures below are not stated
            // anywhere in this file — confirm before relying on them.
            memory_bandwidth: 50.0,
            compute_throughput: 100.0,
        }
    }
}
336
/// Fallback `ConvolutionOps` implementation: reports no convolution
/// support and returns an error from every async operation, but still
/// provides the algorithm-selection heuristic.
pub struct DefaultConvolutionOps {
    // Hints driving `select_algorithm`.
    performance_hints: ConvolutionPerformanceHints,
}
341
342impl DefaultConvolutionOps {
343 pub fn new() -> Self {
344 Self {
345 performance_hints: ConvolutionPerformanceHints::default(),
346 }
347 }
348
349 pub fn with_performance_hints(mut self, hints: ConvolutionPerformanceHints) -> Self {
350 self.performance_hints = hints;
351 self
352 }
353}
354
#[async_trait::async_trait]
impl ConvolutionOps for DefaultConvolutionOps {
    /// Unsupported: always returns a `BackendError`.
    async fn convolution(
        &self,
        _device: &Device,
        _input: &Buffer,
        _kernel: &Buffer,
        _bias: Option<&Buffer>,
        _output: &Buffer,
        _config: &ConvolutionConfig,
    ) -> BackendResult<()> {
        Err(torsh_core::error::TorshError::BackendError(
            "Convolution operations not implemented for this backend".to_string(),
        ))
    }

    /// Unsupported: always returns a `BackendError`.
    async fn conv2d(
        &self,
        _device: &Device,
        _input: &Buffer,
        _kernel: &Buffer,
        _bias: Option<&Buffer>,
        _output: &Buffer,
        _stride: (usize, usize),
        _padding: (usize, usize),
        _dilation: (usize, usize),
    ) -> BackendResult<()> {
        Err(torsh_core::error::TorshError::BackendError(
            "Conv2D operations not implemented for this backend".to_string(),
        ))
    }

    /// Unsupported: always returns a `BackendError`.
    async fn depthwise_conv2d(
        &self,
        _device: &Device,
        _input: &Buffer,
        _kernel: &Buffer,
        _bias: Option<&Buffer>,
        _output: &Buffer,
        _stride: (usize, usize),
        _padding: (usize, usize),
    ) -> BackendResult<()> {
        Err(torsh_core::error::TorshError::BackendError(
            "Depthwise convolution not implemented for this backend".to_string(),
        ))
    }

    /// Unsupported: always returns a `BackendError`.
    async fn conv_transpose2d(
        &self,
        _device: &Device,
        _input: &Buffer,
        _kernel: &Buffer,
        _bias: Option<&Buffer>,
        _output: &Buffer,
        _stride: (usize, usize),
        _padding: (usize, usize),
        _output_padding: (usize, usize),
    ) -> BackendResult<()> {
        Err(torsh_core::error::TorshError::BackendError(
            "Transposed convolution not implemented for this backend".to_string(),
        ))
    }

    /// Unsupported: always returns a `BackendError`.
    async fn grouped_conv2d(
        &self,
        _device: &Device,
        _input: &Buffer,
        _kernel: &Buffer,
        _bias: Option<&Buffer>,
        _output: &Buffer,
        _groups: usize,
        _stride: (usize, usize),
        _padding: (usize, usize),
    ) -> BackendResult<()> {
        Err(torsh_core::error::TorshError::BackendError(
            "Grouped convolution not implemented for this backend".to_string(),
        ))
    }

    /// Resolves `Auto` using the stored performance hints; explicit
    /// (non-`Auto`) requests are honored unchanged.
    fn select_algorithm(&self, config: &ConvolutionConfig) -> ConvolutionAlgorithm {
        // Explicit choice wins.
        if config.algorithm != ConvolutionAlgorithm::Auto {
            return config.algorithm;
        }

        match config.conv_type {
            ConvolutionType::Conv2D => {
                // kernel_dims is [out_ch, in_ch, k_h, k_w]; pick by the
                // larger spatial extent.
                if config.kernel_dims.len() >= 4 {
                    let kernel_h = config.kernel_dims[2];
                    let kernel_w = config.kernel_dims[3];
                    let kernel_size = kernel_h.max(kernel_w);

                    // NOTE(review): with the default hints
                    // (winograd_threshold = 6, fft_threshold = 7) one of
                    // the first two branches always matches, so the
                    // Im2col fallback only fires with custom thresholds
                    // that leave a gap between them.
                    if kernel_size <= self.performance_hints.winograd_threshold {
                        ConvolutionAlgorithm::Winograd
                    } else if kernel_size >= self.performance_hints.fft_threshold {
                        ConvolutionAlgorithm::FftBased
                    } else {
                        ConvolutionAlgorithm::Im2col
                    }
                } else {
                    // Malformed kernel dims: fall back to the naive path.
                    ConvolutionAlgorithm::Direct
                }
            }
            ConvolutionType::DepthwiseConv2D => ConvolutionAlgorithm::Direct,
            ConvolutionType::SeparableConv2D => ConvolutionAlgorithm::Direct,
            _ => ConvolutionAlgorithm::Im2col,
        }
    }

    /// This default backend executes nothing.
    fn supports_convolution(&self) -> bool {
        false
    }

    /// No convolution types are executable on the default backend.
    fn supported_conv_types(&self) -> Vec<ConvolutionType> {
        vec![]
    }

    /// Only the trivial direct algorithm is nominally available.
    fn supported_algorithms(&self) -> Vec<ConvolutionAlgorithm> {
        vec![ConvolutionAlgorithm::Direct]
    }
}
476
impl Default for DefaultConvolutionOps {
    /// Equivalent to `DefaultConvolutionOps::new()`.
    fn default() -> Self {
        Self::new()
    }
}
482
/// Reference CPU implementations of convolution algorithms over plain
/// `f32` slices. These serve as correctness baselines / fallbacks, not
/// optimized kernels.
pub mod algorithms {
    use super::*;

    /// Naive direct convolution.
    pub struct DirectConvolution;

    impl DirectConvolution {
        /// Direct NCHW 2D convolution.
        ///
        /// Layouts: `input_dims = [batch, in_ch, in_h, in_w]`,
        /// `kernel_dims = [out_ch, in_ch, k_h, k_w]`,
        /// `output_dims = [batch, out_ch, out_h, out_w]`.
        ///
        /// Zero padding is implicit: taps that fall outside the real
        /// input are skipped (contribute 0). Dilation and groups are not
        /// handled here.
        ///
        /// Panics if a slice is shorter than its dims imply — indexing is
        /// unchecked; callers must size buffers correctly.
        pub fn conv2d_direct(
            input: &[f32],
            kernel: &[f32],
            output: &mut [f32],
            input_dims: &[usize],
            kernel_dims: &[usize],
            output_dims: &[usize],
            stride: (usize, usize),
            padding: (usize, usize),
        ) -> BackendResult<()> {
            let (batch, in_channels, in_h, in_w) =
                (input_dims[0], input_dims[1], input_dims[2], input_dims[3]);
            let (out_channels, _, k_h, k_w) = (
                kernel_dims[0],
                kernel_dims[1],
                kernel_dims[2],
                kernel_dims[3],
            );
            let (_, _, out_h, out_w) = (
                output_dims[0],
                output_dims[1],
                output_dims[2],
                output_dims[3],
            );
            let (s_h, s_w) = stride;
            let (p_h, p_w) = padding;

            for b in 0..batch {
                for oc in 0..out_channels {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            // Accumulate one output element.
                            let mut sum = 0.0;

                            for ic in 0..in_channels {
                                for kh in 0..k_h {
                                    for kw in 0..k_w {
                                        // (ih, iw): tap position in the
                                        // zero-padded input.
                                        let ih = oh * s_h + kh;
                                        let iw = ow * s_w + kw;

                                        // Only taps inside the real input
                                        // [p, in + p) contribute.
                                        if ih >= p_h
                                            && iw >= p_w
                                            && ih < in_h + p_h
                                            && iw < in_w + p_w
                                        {
                                            // Translate back to unpadded
                                            // coordinates.
                                            let input_h = ih - p_h;
                                            let input_w = iw - p_w;

                                            // Defensive re-check; always
                                            // true given the guard above.
                                            if input_h < in_h && input_w < in_w {
                                                let input_idx = b * in_channels * in_h * in_w
                                                    + ic * in_h * in_w
                                                    + input_h * in_w
                                                    + input_w;
                                                let kernel_idx = oc * in_channels * k_h * k_w
                                                    + ic * k_h * k_w
                                                    + kh * k_w
                                                    + kw;

                                                sum += input[input_idx] * kernel[kernel_idx];
                                            }
                                        }
                                    }
                                }
                            }

                            let output_idx = b * out_channels * out_h * out_w
                                + oc * out_h * out_w
                                + oh * out_w
                                + ow;
                            output[output_idx] = sum;
                        }
                    }
                }
            }

            Ok(())
        }
    }

    /// im2col lowering (patch unfolding) for GEMM-based convolution.
    pub struct Im2colConvolution;

    impl Im2colConvolution {
        /// Writes the im2col ("column") matrix for an NCHW input into
        /// `output`.
        ///
        /// Row index is `(b, c, kh, kw)` flattened; column index is
        /// `(oh, ow)` flattened. Out-of-bounds taps yield `0.0` (zero
        /// padding).
        ///
        /// NOTE(review): writes whose flat index falls beyond
        /// `output.len()` are silently dropped by the bounds guard below —
        /// this can mask an undersized output buffer; confirm callers
        /// allocate `channels * k_h * k_w * out_h * out_w` per batch.
        pub fn im2col(
            input: &[f32],
            output: &mut [f32],
            input_dims: &[usize],
            kernel_size: (usize, usize),
            stride: (usize, usize),
            padding: (usize, usize),
        ) -> BackendResult<()> {
            let (batch, channels, height, width) =
                (input_dims[0], input_dims[1], input_dims[2], input_dims[3]);
            let (k_h, k_w) = kernel_size;
            let (s_h, s_w) = stride;
            let (p_h, p_w) = padding;

            // Standard output-size formula (dilation 1).
            let out_h = (height + 2 * p_h - k_h) / s_h + 1;
            let out_w = (width + 2 * p_w - k_w) / s_w + 1;

            for b in 0..batch {
                for c in 0..channels {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            for oh in 0..out_h {
                                for ow in 0..out_w {
                                    // Tap position in the zero-padded input.
                                    let ih = oh * s_h + kh;
                                    let iw = ow * s_w + kw;

                                    // Read the real input when in range,
                                    // otherwise emit the zero-pad value.
                                    let value = if ih >= p_h
                                        && iw >= p_w
                                        && ih < height + p_h
                                        && iw < width + p_w
                                    {
                                        let input_h = ih - p_h;
                                        let input_w = iw - p_w;

                                        // Defensive re-check; always true
                                        // given the guard above.
                                        if input_h < height && input_w < width {
                                            let input_idx = b * channels * height * width
                                                + c * height * width
                                                + input_h * width
                                                + input_w;
                                            input[input_idx]
                                        } else {
                                            0.0
                                        }
                                    } else {
                                        0.0
                                    };

                                    // Row-major (row = (b,c,kh,kw), col = (oh,ow)).
                                    let col_idx =
                                        (b * channels * k_h * k_w + c * k_h * k_w + kh * k_w + kw)
                                            * out_h
                                            * out_w
                                            + oh * out_w
                                            + ow;

                                    // See NOTE(review) above: out-of-range
                                    // writes are silently skipped.
                                    if col_idx < output.len() {
                                        output[col_idx] = value;
                                    }
                                }
                            }
                        }
                    }
                }
            }

            Ok(())
        }
    }

    /// Winograd minimal-filtering convolution (currently a stub).
    pub struct WinogradConvolution;

    impl WinogradConvolution {
        /// Winograd as implemented here applies only to 3x3 kernels with
        /// stride 1.
        pub fn can_apply(kernel_size: (usize, usize), stride: (usize, usize)) -> bool {
            let (k_h, k_w) = kernel_size;
            let (s_h, s_w) = stride;

            k_h == 3 && k_w == 3 && s_h == 1 && s_w == 1
        }

        /// NOTE(review): not a real Winograd transform — this delegates to
        /// `DirectConvolution` with hard-coded stride (1, 1) and padding
        /// (1, 1). Results are only correct when the caller's
        /// `output_dims` were computed for padding (1, 1); replace with a
        /// true F(2x2, 3x3) transform or plumb the padding through.
        pub fn conv2d_winograd(
            input: &[f32],
            kernel: &[f32],
            output: &mut [f32],
            input_dims: &[usize],
            kernel_dims: &[usize],
            output_dims: &[usize],
        ) -> BackendResult<()> {
            DirectConvolution::conv2d_direct(
                input,
                kernel,
                output,
                input_dims,
                kernel_dims,
                output_dims,
                (1, 1),
                (1, 1),
            )
        }
    }
}
680
#[cfg(test)]
mod tests {
    use super::*;

    // A 3x3 kernel with padding 1 and stride 1 preserves spatial size.
    #[test]
    fn test_convolution_config_creation() {
        let cfg = ConvolutionConfig::conv2d(1, 3, 16, (32, 32), (3, 3), (1, 1), (1, 1));

        assert_eq!(cfg.conv_type, ConvolutionType::Conv2D);
        assert_eq!(cfg.input_dims, vec![1, 3, 32, 32]);
        assert_eq!(cfg.output_dims, vec![1, 16, 32, 32]);
        assert_eq!(cfg.kernel_dims, vec![16, 3, 3, 3]);
        assert!(cfg.is_valid());
    }

    // Depthwise: groups equals channel count and the kernel has a
    // singleton input-channel dimension.
    #[test]
    fn test_depthwise_config_creation() {
        let cfg = ConvolutionConfig::depthwise_conv2d(1, 16, (32, 32), (3, 3), (1, 1), (1, 1));

        assert_eq!(cfg.conv_type, ConvolutionType::DepthwiseConv2D);
        assert_eq!(cfg.groups, 16);
        assert_eq!(cfg.kernel_dims, vec![16, 1, 3, 3]);
        assert!(cfg.is_valid());
    }

    // Auto selection: 3x3 kernels fall under the Winograd threshold,
    // 9x9 kernels reach the FFT threshold.
    #[test]
    fn test_algorithm_selection() {
        let ops = DefaultConvolutionOps::new();

        let small = ConvolutionConfig::conv2d(1, 3, 16, (32, 32), (3, 3), (1, 1), (1, 1));
        assert_eq!(ops.select_algorithm(&small), ConvolutionAlgorithm::Winograd);

        let large = ConvolutionConfig::conv2d(1, 3, 16, (32, 32), (9, 9), (1, 1), (4, 4));
        assert_eq!(ops.select_algorithm(&large), ConvolutionAlgorithm::FftBased);
    }

    // Element counts follow from the dims; byte sizes use 4 bytes per
    // f32 element.
    #[test]
    fn test_buffer_size_calculations() {
        let cfg = ConvolutionConfig::conv2d(2, 3, 16, (32, 32), (3, 3), (1, 1), (1, 1));

        assert_eq!(cfg.input_elements(), 2 * 3 * 32 * 32);
        assert_eq!(cfg.output_elements(), 2 * 16 * 32 * 32);
        assert_eq!(cfg.kernel_elements(), 16 * 3 * 3 * 3);

        assert_eq!(cfg.input_buffer_size(), 2 * 3 * 32 * 32 * 4);
        assert_eq!(cfg.output_buffer_size(), 2 * 16 * 32 * 32 * 4);
        assert_eq!(cfg.kernel_buffer_size(), 16 * 3 * 3 * 3 * 4);
    }

    // Winograd only applies to 3x3 / stride-1 convolutions.
    #[test]
    fn test_winograd_applicability() {
        assert!(algorithms::WinogradConvolution::can_apply((3, 3), (1, 1)));
        assert!(!algorithms::WinogradConvolution::can_apply((5, 5), (1, 1)));
        assert!(!algorithms::WinogradConvolution::can_apply((3, 3), (2, 2)));
    }
}