Skip to main content

torsh_backend/cpu/
convolution.rs

1//! CPU convolution implementation with optimized algorithms
2
3use crate::convolution::{
4    algorithms, ConvolutionAlgorithm, ConvolutionOps, ConvolutionPerformanceHints, ConvolutionType,
5    PaddingMode,
6};
7
8// Re-export for benchmarks
9pub use crate::convolution::ConvolutionConfig;
10use crate::cpu::buffer::BufferCpuExt;
11use crate::{BackendResult, Buffer, Device};
12
13#[cfg(not(feature = "std"))]
14use alloc::{boxed::Box, vec::Vec};
15
16/// CPU convolution operations implementation
17#[derive(Clone, Debug)]
18pub struct CpuConvolutionOps {
19    /// Performance hints for algorithm selection
20    performance_hints: ConvolutionPerformanceHints,
21    /// Number of threads for parallel processing
22    #[allow(dead_code)]
23    num_threads: usize,
24}
25
26impl CpuConvolutionOps {
27    /// Create a new CPU convolution operations instance
28    pub fn new(num_threads: Option<usize>) -> Self {
29        let num_threads = num_threads.unwrap_or_else(|| rayon::current_num_threads());
30
31        Self {
32            performance_hints: ConvolutionPerformanceHints {
33                small_kernel_algorithm: ConvolutionAlgorithm::Direct,
34                large_kernel_algorithm: ConvolutionAlgorithm::Im2col,
35                fft_threshold: 7,
36                winograd_threshold: 6,
37                tile_size: (16, 16),
38                memory_bandwidth: 100.0, // CPU memory bandwidth
39                compute_throughput: num_threads as f32 * 50.0, // Estimated GOPS
40            },
41            num_threads,
42        }
43    }
44
45    /// Copy buffer data safely for CPU
46    #[allow(dead_code)]
47    fn copy_buffer_data(&self, src: &Buffer, dst: &Buffer, size: usize) -> BackendResult<()> {
48        if !src.is_cpu() || !dst.is_cpu() {
49            return Err(torsh_core::error::TorshError::BackendError(
50                "Both buffers must be CPU buffers".to_string(),
51            ));
52        }
53
54        let src_ptr = src.as_cpu_ptr().ok_or_else(|| {
55            torsh_core::error::TorshError::BackendError(
56                "Failed to get source buffer pointer".to_string(),
57            )
58        })?;
59
60        let dst_ptr = dst.as_cpu_ptr().ok_or_else(|| {
61            torsh_core::error::TorshError::BackendError(
62                "Failed to get destination buffer pointer".to_string(),
63            )
64        })?;
65
66        if size > src.size.min(dst.size) {
67            return Err(torsh_core::error::TorshError::BackendError(format!(
68                "Copy size {} exceeds buffer capacity",
69                size
70            )));
71        }
72
73        unsafe {
74            std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size);
75        }
76
77        Ok(())
78    }
79
80    /// Execute direct convolution on CPU
81    fn direct_convolution(
82        &self,
83        input: &Buffer,
84        kernel: &Buffer,
85        bias: Option<&Buffer>,
86        output: &Buffer,
87        config: &ConvolutionConfig,
88    ) -> BackendResult<()> {
89        // Get buffer pointers
90        let input_ptr = input.as_cpu_ptr().ok_or_else(|| {
91            torsh_core::error::TorshError::BackendError(
92                "Failed to get input buffer pointer".to_string(),
93            )
94        })?;
95
96        let kernel_ptr = kernel.as_cpu_ptr().ok_or_else(|| {
97            torsh_core::error::TorshError::BackendError(
98                "Failed to get kernel buffer pointer".to_string(),
99            )
100        })?;
101
102        let output_ptr = output.as_cpu_ptr().ok_or_else(|| {
103            torsh_core::error::TorshError::BackendError(
104                "Failed to get output buffer pointer".to_string(),
105            )
106        })?;
107
108        unsafe {
109            let input_data = std::slice::from_raw_parts(input_ptr as *const f32, input.size / 4);
110            let kernel_data = std::slice::from_raw_parts(kernel_ptr as *const f32, kernel.size / 4);
111            let output_data =
112                std::slice::from_raw_parts_mut(output_ptr as *mut f32, output.size / 4);
113
114            match config.conv_type {
115                ConvolutionType::Conv2D => {
116                    algorithms::DirectConvolution::conv2d_direct(
117                        input_data,
118                        kernel_data,
119                        output_data,
120                        &config.input_dims,
121                        &config.kernel_dims,
122                        &config.output_dims,
123                        (config.strides[0], config.strides[1]),
124                        (config.padding[0], config.padding[1]),
125                    )?;
126                }
127                ConvolutionType::DepthwiseConv2D => {
128                    // Simplified depthwise convolution - would need specialized implementation
129                    self.depthwise_direct_implementation(
130                        input_data,
131                        kernel_data,
132                        output_data,
133                        config,
134                    )?;
135                }
136                _ => {
137                    return Err(torsh_core::error::TorshError::BackendError(format!(
138                        "Convolution type {:?} not implemented yet",
139                        config.conv_type
140                    )));
141                }
142            }
143
144            // Add bias if provided
145            if let Some(bias_buffer) = bias {
146                let bias_ptr = bias_buffer.as_cpu_ptr().ok_or_else(|| {
147                    torsh_core::error::TorshError::BackendError(
148                        "Failed to get bias buffer pointer".to_string(),
149                    )
150                })?;
151                let bias_data =
152                    std::slice::from_raw_parts(bias_ptr as *const f32, bias_buffer.size / 4);
153
154                self.add_bias(output_data, bias_data, &config.output_dims)?;
155            }
156        }
157
158        Ok(())
159    }
160
161    /// Add bias to output
162    fn add_bias(
163        &self,
164        output: &mut [f32],
165        bias: &[f32],
166        output_dims: &[usize],
167    ) -> BackendResult<()> {
168        if output_dims.len() < 4 {
169            return Ok(());
170        }
171
172        let (batch, channels, height, width) = (
173            output_dims[0],
174            output_dims[1],
175            output_dims[2],
176            output_dims[3],
177        );
178
179        for b in 0..batch {
180            for c in 0..channels {
181                let bias_value = bias.get(c).copied().unwrap_or(0.0);
182                for h in 0..height {
183                    for w in 0..width {
184                        let idx =
185                            b * channels * height * width + c * height * width + h * width + w;
186                        if idx < output.len() {
187                            output[idx] += bias_value;
188                        }
189                    }
190                }
191            }
192        }
193
194        Ok(())
195    }
196
197    /// Simplified depthwise convolution implementation
198    fn depthwise_direct_implementation(
199        &self,
200        input: &[f32],
201        kernel: &[f32],
202        output: &mut [f32],
203        config: &ConvolutionConfig,
204    ) -> BackendResult<()> {
205        let (batch, channels, in_h, in_w) = (
206            config.input_dims[0],
207            config.input_dims[1],
208            config.input_dims[2],
209            config.input_dims[3],
210        );
211        let (_, _, k_h, k_w) = (
212            config.kernel_dims[0],
213            config.kernel_dims[1],
214            config.kernel_dims[2],
215            config.kernel_dims[3],
216        );
217        let (_, _, out_h, out_w) = (
218            config.output_dims[0],
219            config.output_dims[1],
220            config.output_dims[2],
221            config.output_dims[3],
222        );
223        let (s_h, s_w) = (config.strides[0], config.strides[1]);
224        let (p_h, p_w) = (config.padding[0], config.padding[1]);
225
226        for b in 0..batch {
227            for c in 0..channels {
228                for oh in 0..out_h {
229                    for ow in 0..out_w {
230                        let mut sum = 0.0;
231
232                        for kh in 0..k_h {
233                            for kw in 0..k_w {
234                                let ih = oh * s_h + kh;
235                                let iw = ow * s_w + kw;
236
237                                if ih >= p_h && iw >= p_w && ih < in_h + p_h && iw < in_w + p_w {
238                                    let input_h = ih - p_h;
239                                    let input_w = iw - p_w;
240
241                                    if input_h < in_h && input_w < in_w {
242                                        let input_idx = b * channels * in_h * in_w
243                                            + c * in_h * in_w
244                                            + input_h * in_w
245                                            + input_w;
246                                        let kernel_idx = c * k_h * k_w + kh * k_w + kw;
247
248                                        if input_idx < input.len() && kernel_idx < kernel.len() {
249                                            sum += input[input_idx] * kernel[kernel_idx];
250                                        }
251                                    }
252                                }
253                            }
254                        }
255
256                        let output_idx =
257                            b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
258
259                        if output_idx < output.len() {
260                            output[output_idx] = sum;
261                        }
262                    }
263                }
264            }
265        }
266
267        Ok(())
268    }
269}
270
271#[async_trait::async_trait]
272impl ConvolutionOps for CpuConvolutionOps {
273    async fn convolution(
274        &self,
275        _device: &Device,
276        input: &Buffer,
277        kernel: &Buffer,
278        bias: Option<&Buffer>,
279        output: &Buffer,
280        config: &ConvolutionConfig,
281    ) -> BackendResult<()> {
282        if !config.is_valid() {
283            return Err(torsh_core::error::TorshError::BackendError(
284                "Invalid convolution configuration".to_string(),
285            ));
286        }
287
288        let algorithm = self.select_algorithm(config);
289
290        match algorithm {
291            ConvolutionAlgorithm::Direct => {
292                self.direct_convolution(input, kernel, bias, output, config)
293            }
294            ConvolutionAlgorithm::Im2col => {
295                // For now, fall back to direct convolution
296                // A full im2col implementation would require GEMM operations
297                self.direct_convolution(input, kernel, bias, output, config)
298            }
299            ConvolutionAlgorithm::Winograd => {
300                // For now, fall back to direct convolution
301                // A full Winograd implementation would require specialized transforms
302                self.direct_convolution(input, kernel, bias, output, config)
303            }
304            ConvolutionAlgorithm::FftBased => {
305                // For now, fall back to direct convolution
306                // FFT-based convolution would use our FFT operations module
307                self.direct_convolution(input, kernel, bias, output, config)
308            }
309            _ => self.direct_convolution(input, kernel, bias, output, config),
310        }
311    }
312
313    async fn conv2d(
314        &self,
315        device: &Device,
316        input: &Buffer,
317        kernel: &Buffer,
318        bias: Option<&Buffer>,
319        output: &Buffer,
320        stride: (usize, usize),
321        padding: (usize, usize),
322        dilation: (usize, usize),
323    ) -> BackendResult<()> {
324        // Create a basic configuration from parameters
325        // For a full implementation, we'd need to infer dimensions from buffer sizes
326        let config = ConvolutionConfig {
327            conv_type: ConvolutionType::Conv2D,
328            input_dims: vec![1, 1, 32, 32],  // Placeholder dimensions
329            output_dims: vec![1, 1, 32, 32], // Placeholder dimensions
330            kernel_dims: vec![1, 1, 3, 3],   // Placeholder dimensions
331            strides: vec![stride.0, stride.1],
332            padding: vec![padding.0, padding.1],
333            dilation: vec![dilation.0, dilation.1],
334            groups: 1,
335            padding_mode: PaddingMode::Custom,
336            dtype: torsh_core::dtype::DType::F32,
337            algorithm: ConvolutionAlgorithm::Auto,
338        };
339
340        self.convolution(device, input, kernel, bias, output, &config)
341            .await
342    }
343
344    async fn depthwise_conv2d(
345        &self,
346        device: &Device,
347        input: &Buffer,
348        kernel: &Buffer,
349        bias: Option<&Buffer>,
350        output: &Buffer,
351        stride: (usize, usize),
352        padding: (usize, usize),
353    ) -> BackendResult<()> {
354        // Create depthwise configuration
355        let config = ConvolutionConfig {
356            conv_type: ConvolutionType::DepthwiseConv2D,
357            input_dims: vec![1, 16, 32, 32], // Placeholder dimensions
358            output_dims: vec![1, 16, 32, 32], // Placeholder dimensions
359            kernel_dims: vec![16, 1, 3, 3],  // Placeholder dimensions
360            strides: vec![stride.0, stride.1],
361            padding: vec![padding.0, padding.1],
362            dilation: vec![1, 1],
363            groups: 16, // Depthwise means groups = input channels
364            padding_mode: PaddingMode::Custom,
365            dtype: torsh_core::dtype::DType::F32,
366            algorithm: ConvolutionAlgorithm::Direct,
367        };
368
369        self.convolution(device, input, kernel, bias, output, &config)
370            .await
371    }
372
373    async fn conv_transpose2d(
374        &self,
375        _device: &Device,
376        _input: &Buffer,
377        _kernel: &Buffer,
378        _bias: Option<&Buffer>,
379        _output: &Buffer,
380        _stride: (usize, usize),
381        _padding: (usize, usize),
382        _output_padding: (usize, usize),
383    ) -> BackendResult<()> {
384        Err(torsh_core::error::TorshError::BackendError(
385            "Transposed convolution not implemented for CPU backend yet".to_string(),
386        ))
387    }
388
389    async fn grouped_conv2d(
390        &self,
391        device: &Device,
392        input: &Buffer,
393        kernel: &Buffer,
394        bias: Option<&Buffer>,
395        output: &Buffer,
396        groups: usize,
397        stride: (usize, usize),
398        padding: (usize, usize),
399    ) -> BackendResult<()> {
400        // Create grouped configuration
401        let config = ConvolutionConfig {
402            conv_type: ConvolutionType::GroupedConv2D,
403            input_dims: vec![1, 16, 32, 32], // Placeholder dimensions
404            output_dims: vec![1, 16, 32, 32], // Placeholder dimensions
405            kernel_dims: vec![16, 16 / groups, 3, 3], // Placeholder dimensions
406            strides: vec![stride.0, stride.1],
407            padding: vec![padding.0, padding.1],
408            dilation: vec![1, 1],
409            groups,
410            padding_mode: PaddingMode::Custom,
411            dtype: torsh_core::dtype::DType::F32,
412            algorithm: ConvolutionAlgorithm::Direct,
413        };
414
415        self.convolution(device, input, kernel, bias, output, &config)
416            .await
417    }
418
419    fn select_algorithm(&self, config: &ConvolutionConfig) -> ConvolutionAlgorithm {
420        if config.algorithm != ConvolutionAlgorithm::Auto {
421            return config.algorithm;
422        }
423
424        // Auto-select based on configuration and performance hints
425        match config.conv_type {
426            ConvolutionType::Conv2D => {
427                if config.kernel_dims.len() >= 4 {
428                    let kernel_h = config.kernel_dims[2];
429                    let kernel_w = config.kernel_dims[3];
430                    let kernel_size = kernel_h.max(kernel_w);
431
432                    if kernel_size <= 3 {
433                        // Small kernels work well with direct convolution on CPU
434                        ConvolutionAlgorithm::Direct
435                    } else if kernel_size <= self.performance_hints.winograd_threshold {
436                        ConvolutionAlgorithm::Winograd
437                    } else if kernel_size >= self.performance_hints.fft_threshold {
438                        ConvolutionAlgorithm::FftBased
439                    } else {
440                        ConvolutionAlgorithm::Im2col
441                    }
442                } else {
443                    ConvolutionAlgorithm::Direct
444                }
445            }
446            ConvolutionType::DepthwiseConv2D => ConvolutionAlgorithm::Direct,
447            ConvolutionType::SeparableConv2D => ConvolutionAlgorithm::Direct,
448            ConvolutionType::GroupedConv2D => ConvolutionAlgorithm::Direct,
449            _ => ConvolutionAlgorithm::Im2col,
450        }
451    }
452
453    fn supports_convolution(&self) -> bool {
454        true
455    }
456
457    fn supported_conv_types(&self) -> Vec<ConvolutionType> {
458        vec![
459            ConvolutionType::Conv1D,
460            ConvolutionType::Conv2D,
461            ConvolutionType::Conv3D,
462            ConvolutionType::DepthwiseConv2D,
463            ConvolutionType::SeparableConv2D,
464            ConvolutionType::GroupedConv2D,
465            // ConvolutionType::ConvTranspose2D, // Not implemented yet
466            ConvolutionType::DilatedConv2D,
467        ]
468    }
469
470    fn supported_algorithms(&self) -> Vec<ConvolutionAlgorithm> {
471        vec![
472            ConvolutionAlgorithm::Auto,
473            ConvolutionAlgorithm::Direct,
474            ConvolutionAlgorithm::Im2col,
475            ConvolutionAlgorithm::Winograd,
476            ConvolutionAlgorithm::FftBased,
477        ]
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::convolution::ConvolutionConfig;

    /// A freshly built backend reports convolution support and non-empty
    /// capability lists.
    #[test]
    fn test_cpu_convolution_ops_creation() {
        let ops = CpuConvolutionOps::new(Some(2));

        assert!(ops.supports_convolution());
        assert!(!ops.supported_conv_types().is_empty());
        assert!(!ops.supported_algorithms().is_empty());
    }

    /// Automatic selection: 3x3 Conv2D -> Direct, 9x9 Conv2D -> FFT-based,
    /// depthwise -> Direct.
    #[test]
    fn test_algorithm_selection() {
        let ops = CpuConvolutionOps::new(Some(1));

        let cases = [
            // Small kernel should use direct convolution on CPU
            (
                ConvolutionConfig::conv2d(1, 3, 16, (32, 32), (3, 3), (1, 1), (1, 1)),
                ConvolutionAlgorithm::Direct,
            ),
            // Large kernel should use FFT-based convolution
            (
                ConvolutionConfig::conv2d(1, 3, 16, (32, 32), (9, 9), (1, 1), (4, 4)),
                ConvolutionAlgorithm::FftBased,
            ),
            // Depthwise should always use direct
            (
                ConvolutionConfig::depthwise_conv2d(1, 16, (32, 32), (3, 3), (1, 1), (1, 1)),
                ConvolutionAlgorithm::Direct,
            ),
        ];

        for (config, expected) in cases.iter() {
            assert_eq!(ops.select_algorithm(config), *expected);
        }
    }
}
520}