Skip to main content

torsh_backend/
convolution.rs

//! Convolution operations for all backends.
//!
//! This module defines a unified, backend-agnostic interface for convolution
//! ([`ConvolutionOps`] and [`ConvolutionConfig`]) together with heuristics for choosing
//! between direct, im2col + GEMM, Winograd and FFT-based algorithms. The CPU reference
//! code in [`algorithms`] implements direct and im2col convolution; its Winograd entry
//! point currently delegates to direct convolution.
6
7use crate::{BackendResult, Buffer, Device};
8use torsh_core::dtype::DType;
9
10#[cfg(not(feature = "std"))]
11use alloc::{boxed::Box, vec::Vec};
12
/// The family of convolution being described or executed.
///
/// Backends advertise which of these they handle through
/// `ConvolutionOps::supported_conv_types`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvolutionType {
    /// Convolution over a single spatial axis.
    Conv1D,
    /// Convolution over two spatial axes (the common image case).
    Conv2D,
    /// Convolution over three spatial axes.
    Conv3D,
    /// 2D convolution with one filter per input channel.
    DepthwiseConv2D,
    /// 2D convolution factored into separate stages.
    SeparableConv2D,
    /// Transposed 2D convolution, sometimes called deconvolution.
    ConvTranspose2D,
    /// 2D convolution with a dilated (spaced-out) kernel.
    DilatedConv2D,
    /// 2D convolution whose channels are split into independent groups.
    GroupedConv2D,
}
33
/// Strategy used to compute a convolution.
///
/// `Auto` defers the choice to `ConvolutionOps::select_algorithm`; every other
/// variant names a concrete strategy.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvolutionAlgorithm {
    /// Let the backend pick.
    Auto,
    /// Straightforward nested-loop convolution.
    Direct,
    /// Unfold the input into columns (im2col), then multiply with GEMM.
    Im2col,
    /// Winograd minimal-filtering algorithm, suited to small kernels.
    Winograd,
    /// Convolution computed in the frequency domain, suited to large kernels.
    FftBased,
    /// A backend-specific tuned implementation.
    Optimized,
}
50
/// How the input border is padded before the kernel is swept over it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PaddingMode {
    /// No padding; the kernel only visits fully in-bounds windows.
    Valid,
    /// Zero padding chosen so the output keeps the input's spatial size.
    Same,
    /// Explicit padding amounts, taken from `ConvolutionConfig::padding`.
    Custom,
}
61
62/// Convolution configuration
63#[derive(Debug, Clone)]
64pub struct ConvolutionConfig {
65    /// Convolution type
66    pub conv_type: ConvolutionType,
67    /// Input dimensions [batch, channels, height, width] for 2D
68    pub input_dims: Vec<usize>,
69    /// Output dimensions [batch, channels, height, width] for 2D
70    pub output_dims: Vec<usize>,
71    /// Kernel dimensions [out_channels, in_channels, height, width] for 2D
72    pub kernel_dims: Vec<usize>,
73    /// Stride in each dimension
74    pub strides: Vec<usize>,
75    /// Padding in each dimension
76    pub padding: Vec<usize>,
77    /// Dilation in each dimension
78    pub dilation: Vec<usize>,
79    /// Number of groups for grouped convolution
80    pub groups: usize,
81    /// Padding mode
82    pub padding_mode: PaddingMode,
83    /// Data type
84    pub dtype: DType,
85    /// Preferred algorithm
86    pub algorithm: ConvolutionAlgorithm,
87}
88
89impl ConvolutionConfig {
90    /// Create a new 2D convolution configuration
91    pub fn conv2d(
92        batch_size: usize,
93        in_channels: usize,
94        out_channels: usize,
95        input_size: (usize, usize),
96        kernel_size: (usize, usize),
97        stride: (usize, usize),
98        padding: (usize, usize),
99    ) -> Self {
100        let (in_h, in_w) = input_size;
101        let (k_h, k_w) = kernel_size;
102        let (s_h, s_w) = stride;
103        let (p_h, p_w) = padding;
104
105        // Calculate output dimensions
106        let out_h = (in_h + 2 * p_h - k_h) / s_h + 1;
107        let out_w = (in_w + 2 * p_w - k_w) / s_w + 1;
108
109        Self {
110            conv_type: ConvolutionType::Conv2D,
111            input_dims: vec![batch_size, in_channels, in_h, in_w],
112            output_dims: vec![batch_size, out_channels, out_h, out_w],
113            kernel_dims: vec![out_channels, in_channels, k_h, k_w],
114            strides: vec![s_h, s_w],
115            padding: vec![p_h, p_w],
116            dilation: vec![1, 1],
117            groups: 1,
118            padding_mode: PaddingMode::Custom,
119            dtype: DType::F32,
120            algorithm: ConvolutionAlgorithm::Auto,
121        }
122    }
123
124    /// Create a depthwise convolution configuration
125    pub fn depthwise_conv2d(
126        batch_size: usize,
127        channels: usize,
128        input_size: (usize, usize),
129        kernel_size: (usize, usize),
130        stride: (usize, usize),
131        padding: (usize, usize),
132    ) -> Self {
133        let mut config = Self::conv2d(
134            batch_size,
135            channels,
136            channels,
137            input_size,
138            kernel_size,
139            stride,
140            padding,
141        );
142        config.conv_type = ConvolutionType::DepthwiseConv2D;
143        config.groups = channels;
144        config.kernel_dims = vec![channels, 1, kernel_size.0, kernel_size.1];
145        config
146    }
147
148    /// Set the algorithm preference
149    pub fn with_algorithm(mut self, algorithm: ConvolutionAlgorithm) -> Self {
150        self.algorithm = algorithm;
151        self
152    }
153
154    /// Set the data type
155    pub fn with_dtype(mut self, dtype: DType) -> Self {
156        self.dtype = dtype;
157        self
158    }
159
160    /// Set dilation
161    pub fn with_dilation(mut self, dilation: Vec<usize>) -> Self {
162        self.dilation = dilation;
163        self
164    }
165
166    /// Calculate total input elements
167    pub fn input_elements(&self) -> usize {
168        self.input_dims.iter().product()
169    }
170
171    /// Calculate total output elements
172    pub fn output_elements(&self) -> usize {
173        self.output_dims.iter().product()
174    }
175
176    /// Calculate total kernel elements
177    pub fn kernel_elements(&self) -> usize {
178        self.kernel_dims.iter().product()
179    }
180
181    /// Get input buffer size in bytes
182    pub fn input_buffer_size(&self) -> usize {
183        let element_size = match self.dtype {
184            DType::F32 => 4,
185            DType::F64 => 8,
186            DType::F16 => 2,
187            _ => 4,
188        };
189        self.input_elements() * element_size
190    }
191
192    /// Get output buffer size in bytes
193    pub fn output_buffer_size(&self) -> usize {
194        let element_size = match self.dtype {
195            DType::F32 => 4,
196            DType::F64 => 8,
197            DType::F16 => 2,
198            _ => 4,
199        };
200        self.output_elements() * element_size
201    }
202
203    /// Get kernel buffer size in bytes
204    pub fn kernel_buffer_size(&self) -> usize {
205        let element_size = match self.dtype {
206            DType::F32 => 4,
207            DType::F64 => 8,
208            DType::F16 => 2,
209            _ => 4,
210        };
211        self.kernel_elements() * element_size
212    }
213
214    /// Check if the configuration is valid
215    pub fn is_valid(&self) -> bool {
216        !self.input_dims.is_empty()
217            && !self.output_dims.is_empty()
218            && !self.kernel_dims.is_empty()
219            && self.input_dims.iter().all(|&d| d > 0)
220            && self.output_dims.iter().all(|&d| d > 0)
221            && self.kernel_dims.iter().all(|&d| d > 0)
222            && self.groups > 0
223    }
224}
225
226/// Convolution operations trait
227#[async_trait::async_trait]
228pub trait ConvolutionOps: Send + Sync {
229    /// Execute a convolution operation
230    async fn convolution(
231        &self,
232        device: &Device,
233        input: &Buffer,
234        kernel: &Buffer,
235        bias: Option<&Buffer>,
236        output: &Buffer,
237        config: &ConvolutionConfig,
238    ) -> BackendResult<()>;
239
240    /// Execute a 2D convolution
241    async fn conv2d(
242        &self,
243        device: &Device,
244        input: &Buffer,
245        kernel: &Buffer,
246        bias: Option<&Buffer>,
247        output: &Buffer,
248        stride: (usize, usize),
249        padding: (usize, usize),
250        dilation: (usize, usize),
251    ) -> BackendResult<()>;
252
253    /// Execute a depthwise convolution
254    async fn depthwise_conv2d(
255        &self,
256        device: &Device,
257        input: &Buffer,
258        kernel: &Buffer,
259        bias: Option<&Buffer>,
260        output: &Buffer,
261        stride: (usize, usize),
262        padding: (usize, usize),
263    ) -> BackendResult<()>;
264
265    /// Execute a transposed convolution
266    async fn conv_transpose2d(
267        &self,
268        device: &Device,
269        input: &Buffer,
270        kernel: &Buffer,
271        bias: Option<&Buffer>,
272        output: &Buffer,
273        stride: (usize, usize),
274        padding: (usize, usize),
275        output_padding: (usize, usize),
276    ) -> BackendResult<()>;
277
278    /// Execute a grouped convolution
279    async fn grouped_conv2d(
280        &self,
281        device: &Device,
282        input: &Buffer,
283        kernel: &Buffer,
284        bias: Option<&Buffer>,
285        output: &Buffer,
286        groups: usize,
287        stride: (usize, usize),
288        padding: (usize, usize),
289    ) -> BackendResult<()>;
290
291    /// Get the best algorithm for given configuration
292    fn select_algorithm(&self, config: &ConvolutionConfig) -> ConvolutionAlgorithm;
293
294    /// Check if convolution operations are supported
295    fn supports_convolution(&self) -> bool;
296
297    /// Get supported convolution types
298    fn supported_conv_types(&self) -> Vec<ConvolutionType>;
299
300    /// Get supported algorithms
301    fn supported_algorithms(&self) -> Vec<ConvolutionAlgorithm>;
302}
303
304/// Performance characteristics for algorithm selection
305#[derive(Debug, Clone)]
306pub struct ConvolutionPerformanceHints {
307    /// Optimal algorithm for small kernels (3x3, 5x5)
308    pub small_kernel_algorithm: ConvolutionAlgorithm,
309    /// Optimal algorithm for large kernels (7x7, 9x9+)
310    pub large_kernel_algorithm: ConvolutionAlgorithm,
311    /// Threshold for switching to FFT-based convolution
312    pub fft_threshold: usize,
313    /// Threshold for using Winograd algorithm
314    pub winograd_threshold: usize,
315    /// Preferred tile size for tiled algorithms
316    pub tile_size: (usize, usize),
317    /// Memory bandwidth in GB/s
318    pub memory_bandwidth: f32,
319    /// Compute throughput in GOPS
320    pub compute_throughput: f32,
321}
322
323impl Default for ConvolutionPerformanceHints {
324    fn default() -> Self {
325        Self {
326            small_kernel_algorithm: ConvolutionAlgorithm::Winograd,
327            large_kernel_algorithm: ConvolutionAlgorithm::FftBased,
328            fft_threshold: 7,
329            winograd_threshold: 6,
330            tile_size: (16, 16),
331            memory_bandwidth: 50.0,
332            compute_throughput: 100.0,
333        }
334    }
335}
336
337/// Default convolution operations implementation
338pub struct DefaultConvolutionOps {
339    performance_hints: ConvolutionPerformanceHints,
340}
341
342impl DefaultConvolutionOps {
343    pub fn new() -> Self {
344        Self {
345            performance_hints: ConvolutionPerformanceHints::default(),
346        }
347    }
348
349    pub fn with_performance_hints(mut self, hints: ConvolutionPerformanceHints) -> Self {
350        self.performance_hints = hints;
351        self
352    }
353}
354
355#[async_trait::async_trait]
356impl ConvolutionOps for DefaultConvolutionOps {
357    async fn convolution(
358        &self,
359        _device: &Device,
360        _input: &Buffer,
361        _kernel: &Buffer,
362        _bias: Option<&Buffer>,
363        _output: &Buffer,
364        _config: &ConvolutionConfig,
365    ) -> BackendResult<()> {
366        Err(torsh_core::error::TorshError::BackendError(
367            "Convolution operations not implemented for this backend".to_string(),
368        ))
369    }
370
371    async fn conv2d(
372        &self,
373        _device: &Device,
374        _input: &Buffer,
375        _kernel: &Buffer,
376        _bias: Option<&Buffer>,
377        _output: &Buffer,
378        _stride: (usize, usize),
379        _padding: (usize, usize),
380        _dilation: (usize, usize),
381    ) -> BackendResult<()> {
382        Err(torsh_core::error::TorshError::BackendError(
383            "Conv2D operations not implemented for this backend".to_string(),
384        ))
385    }
386
387    async fn depthwise_conv2d(
388        &self,
389        _device: &Device,
390        _input: &Buffer,
391        _kernel: &Buffer,
392        _bias: Option<&Buffer>,
393        _output: &Buffer,
394        _stride: (usize, usize),
395        _padding: (usize, usize),
396    ) -> BackendResult<()> {
397        Err(torsh_core::error::TorshError::BackendError(
398            "Depthwise convolution not implemented for this backend".to_string(),
399        ))
400    }
401
402    async fn conv_transpose2d(
403        &self,
404        _device: &Device,
405        _input: &Buffer,
406        _kernel: &Buffer,
407        _bias: Option<&Buffer>,
408        _output: &Buffer,
409        _stride: (usize, usize),
410        _padding: (usize, usize),
411        _output_padding: (usize, usize),
412    ) -> BackendResult<()> {
413        Err(torsh_core::error::TorshError::BackendError(
414            "Transposed convolution not implemented for this backend".to_string(),
415        ))
416    }
417
418    async fn grouped_conv2d(
419        &self,
420        _device: &Device,
421        _input: &Buffer,
422        _kernel: &Buffer,
423        _bias: Option<&Buffer>,
424        _output: &Buffer,
425        _groups: usize,
426        _stride: (usize, usize),
427        _padding: (usize, usize),
428    ) -> BackendResult<()> {
429        Err(torsh_core::error::TorshError::BackendError(
430            "Grouped convolution not implemented for this backend".to_string(),
431        ))
432    }
433
434    fn select_algorithm(&self, config: &ConvolutionConfig) -> ConvolutionAlgorithm {
435        if config.algorithm != ConvolutionAlgorithm::Auto {
436            return config.algorithm;
437        }
438
439        // Auto-select based on kernel size and configuration
440        match config.conv_type {
441            ConvolutionType::Conv2D => {
442                if config.kernel_dims.len() >= 4 {
443                    let kernel_h = config.kernel_dims[2];
444                    let kernel_w = config.kernel_dims[3];
445                    let kernel_size = kernel_h.max(kernel_w);
446
447                    if kernel_size <= self.performance_hints.winograd_threshold {
448                        ConvolutionAlgorithm::Winograd
449                    } else if kernel_size >= self.performance_hints.fft_threshold {
450                        ConvolutionAlgorithm::FftBased
451                    } else {
452                        ConvolutionAlgorithm::Im2col
453                    }
454                } else {
455                    ConvolutionAlgorithm::Direct
456                }
457            }
458            ConvolutionType::DepthwiseConv2D => ConvolutionAlgorithm::Direct,
459            ConvolutionType::SeparableConv2D => ConvolutionAlgorithm::Direct,
460            _ => ConvolutionAlgorithm::Im2col,
461        }
462    }
463
464    fn supports_convolution(&self) -> bool {
465        false
466    }
467
468    fn supported_conv_types(&self) -> Vec<ConvolutionType> {
469        vec![]
470    }
471
472    fn supported_algorithms(&self) -> Vec<ConvolutionAlgorithm> {
473        vec![ConvolutionAlgorithm::Direct]
474    }
475}
476
477impl Default for DefaultConvolutionOps {
478    fn default() -> Self {
479        Self::new()
480    }
481}
482
483/// Convolution algorithm implementations
484pub mod algorithms {
485    use super::*;
486
487    /// Direct convolution implementation
488    pub struct DirectConvolution;
489
490    impl DirectConvolution {
491        /// Perform 2D convolution using direct approach
492        pub fn conv2d_direct(
493            input: &[f32],
494            kernel: &[f32],
495            output: &mut [f32],
496            input_dims: &[usize],
497            kernel_dims: &[usize],
498            output_dims: &[usize],
499            stride: (usize, usize),
500            padding: (usize, usize),
501        ) -> BackendResult<()> {
502            let (batch, in_channels, in_h, in_w) =
503                (input_dims[0], input_dims[1], input_dims[2], input_dims[3]);
504            let (out_channels, _, k_h, k_w) = (
505                kernel_dims[0],
506                kernel_dims[1],
507                kernel_dims[2],
508                kernel_dims[3],
509            );
510            let (_, _, out_h, out_w) = (
511                output_dims[0],
512                output_dims[1],
513                output_dims[2],
514                output_dims[3],
515            );
516            let (s_h, s_w) = stride;
517            let (p_h, p_w) = padding;
518
519            for b in 0..batch {
520                for oc in 0..out_channels {
521                    for oh in 0..out_h {
522                        for ow in 0..out_w {
523                            let mut sum = 0.0;
524
525                            for ic in 0..in_channels {
526                                for kh in 0..k_h {
527                                    for kw in 0..k_w {
528                                        let ih = oh * s_h + kh;
529                                        let iw = ow * s_w + kw;
530
531                                        if ih >= p_h
532                                            && iw >= p_w
533                                            && ih < in_h + p_h
534                                            && iw < in_w + p_w
535                                        {
536                                            let input_h = ih - p_h;
537                                            let input_w = iw - p_w;
538
539                                            if input_h < in_h && input_w < in_w {
540                                                let input_idx = b * in_channels * in_h * in_w
541                                                    + ic * in_h * in_w
542                                                    + input_h * in_w
543                                                    + input_w;
544                                                let kernel_idx = oc * in_channels * k_h * k_w
545                                                    + ic * k_h * k_w
546                                                    + kh * k_w
547                                                    + kw;
548
549                                                sum += input[input_idx] * kernel[kernel_idx];
550                                            }
551                                        }
552                                    }
553                                }
554                            }
555
556                            let output_idx = b * out_channels * out_h * out_w
557                                + oc * out_h * out_w
558                                + oh * out_w
559                                + ow;
560                            output[output_idx] = sum;
561                        }
562                    }
563                }
564            }
565
566            Ok(())
567        }
568    }
569
570    /// Im2col convolution implementation
571    pub struct Im2colConvolution;
572
573    impl Im2colConvolution {
574        /// Convert input to column matrix for GEMM-based convolution
575        pub fn im2col(
576            input: &[f32],
577            output: &mut [f32],
578            input_dims: &[usize],
579            kernel_size: (usize, usize),
580            stride: (usize, usize),
581            padding: (usize, usize),
582        ) -> BackendResult<()> {
583            let (batch, channels, height, width) =
584                (input_dims[0], input_dims[1], input_dims[2], input_dims[3]);
585            let (k_h, k_w) = kernel_size;
586            let (s_h, s_w) = stride;
587            let (p_h, p_w) = padding;
588
589            let out_h = (height + 2 * p_h - k_h) / s_h + 1;
590            let out_w = (width + 2 * p_w - k_w) / s_w + 1;
591
592            for b in 0..batch {
593                for c in 0..channels {
594                    for kh in 0..k_h {
595                        for kw in 0..k_w {
596                            for oh in 0..out_h {
597                                for ow in 0..out_w {
598                                    let ih = oh * s_h + kh;
599                                    let iw = ow * s_w + kw;
600
601                                    let value = if ih >= p_h
602                                        && iw >= p_w
603                                        && ih < height + p_h
604                                        && iw < width + p_w
605                                    {
606                                        let input_h = ih - p_h;
607                                        let input_w = iw - p_w;
608
609                                        if input_h < height && input_w < width {
610                                            let input_idx = b * channels * height * width
611                                                + c * height * width
612                                                + input_h * width
613                                                + input_w;
614                                            input[input_idx]
615                                        } else {
616                                            0.0
617                                        }
618                                    } else {
619                                        0.0
620                                    };
621
622                                    let col_idx =
623                                        (b * channels * k_h * k_w + c * k_h * k_w + kh * k_w + kw)
624                                            * out_h
625                                            * out_w
626                                            + oh * out_w
627                                            + ow;
628
629                                    if col_idx < output.len() {
630                                        output[col_idx] = value;
631                                    }
632                                }
633                            }
634                        }
635                    }
636                }
637            }
638
639            Ok(())
640        }
641    }
642
643    /// Winograd convolution implementation
644    pub struct WinogradConvolution;
645
646    impl WinogradConvolution {
647        /// Check if Winograd can be applied
648        pub fn can_apply(kernel_size: (usize, usize), stride: (usize, usize)) -> bool {
649            let (k_h, k_w) = kernel_size;
650            let (s_h, s_w) = stride;
651
652            // Winograd is most effective for 3x3 kernels with stride 1
653            k_h == 3 && k_w == 3 && s_h == 1 && s_w == 1
654        }
655
656        /// Perform Winograd convolution (simplified F(2,3) implementation)
657        pub fn conv2d_winograd(
658            input: &[f32],
659            kernel: &[f32],
660            output: &mut [f32],
661            input_dims: &[usize],
662            kernel_dims: &[usize],
663            output_dims: &[usize],
664        ) -> BackendResult<()> {
665            // For now, fall back to direct convolution
666            // A full Winograd implementation would involve complex matrix transformations
667            DirectConvolution::conv2d_direct(
668                input,
669                kernel,
670                output,
671                input_dims,
672                kernel_dims,
673                output_dims,
674                (1, 1),
675                (1, 1),
676            )
677        }
678    }
679}
680
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 3 -> 16 channels, 32x32 input, 3x3 kernel, stride 1, padding 1.
    fn same_padded_3x3(batch: usize) -> ConvolutionConfig {
        ConvolutionConfig::conv2d(batch, 3, 16, (32, 32), (3, 3), (1, 1), (1, 1))
    }

    #[test]
    fn test_convolution_config_creation() {
        let cfg = same_padded_3x3(1);

        assert_eq!(cfg.conv_type, ConvolutionType::Conv2D);
        assert_eq!(cfg.input_dims, vec![1, 3, 32, 32]);
        assert_eq!(cfg.output_dims, vec![1, 16, 32, 32]);
        assert_eq!(cfg.kernel_dims, vec![16, 3, 3, 3]);
        assert!(cfg.is_valid());
    }

    #[test]
    fn test_depthwise_config_creation() {
        let cfg = ConvolutionConfig::depthwise_conv2d(1, 16, (32, 32), (3, 3), (1, 1), (1, 1));

        assert_eq!(cfg.conv_type, ConvolutionType::DepthwiseConv2D);
        assert_eq!(cfg.groups, 16);
        assert_eq!(cfg.kernel_dims, vec![16, 1, 3, 3]);
        assert!(cfg.is_valid());
    }

    #[test]
    fn test_algorithm_selection() {
        let ops = DefaultConvolutionOps::new();
        // (kernel, padding, expected); input is 32x32 with stride 1 in every case.
        let cases = [
            ((3, 3), (1, 1), ConvolutionAlgorithm::Winograd), // small kernel
            ((9, 9), (4, 4), ConvolutionAlgorithm::FftBased), // large kernel
        ];

        for &(kernel, padding, expected) in cases.iter() {
            let cfg = ConvolutionConfig::conv2d(1, 3, 16, (32, 32), kernel, (1, 1), padding);
            assert_eq!(ops.select_algorithm(&cfg), expected, "kernel {:?}", kernel);
        }
    }

    #[test]
    fn test_buffer_size_calculations() {
        const F32_BYTES: usize = 4;
        let cfg = same_padded_3x3(2);
        let input_count = 2 * 3 * 32 * 32;
        let output_count = 2 * 16 * 32 * 32;
        let kernel_count = 16 * 3 * 3 * 3;

        assert_eq!(cfg.input_elements(), input_count);
        assert_eq!(cfg.output_elements(), output_count);
        assert_eq!(cfg.kernel_elements(), kernel_count);

        assert_eq!(cfg.input_buffer_size(), input_count * F32_BYTES);
        assert_eq!(cfg.output_buffer_size(), output_count * F32_BYTES);
        assert_eq!(cfg.kernel_buffer_size(), kernel_count * F32_BYTES);
    }

    #[test]
    fn test_winograd_applicability() {
        use super::algorithms::WinogradConvolution;
        // (kernel, stride, expected)
        let cases = [
            ((3, 3), (1, 1), true),
            ((5, 5), (1, 1), false),
            ((3, 3), (2, 2), false),
        ];

        for &(kernel, stride, expected) in cases.iter() {
            assert_eq!(
                WinogradConvolution::can_apply(kernel, stride),
                expected,
                "kernel {:?} stride {:?}",
                kernel,
                stride
            );
        }
    }
}