1use crate::convolution::{
4 algorithms, ConvolutionAlgorithm, ConvolutionOps, ConvolutionPerformanceHints, ConvolutionType,
5 PaddingMode,
6};
7
8pub use crate::convolution::ConvolutionConfig;
10use crate::cpu::buffer::BufferCpuExt;
11use crate::{BackendResult, Buffer, Device};
12
13#[cfg(not(feature = "std"))]
14use alloc::{boxed::Box, vec::Vec};
15
/// CPU backend implementation of convolution operations.
///
/// Carries the heuristics used to pick a convolution algorithm plus the
/// worker-thread count used to estimate compute throughput.
#[derive(Clone, Debug)]
pub struct CpuConvolutionOps {
    // Thresholds, tile sizes and rough hardware estimates consulted by
    // `select_algorithm`.
    performance_hints: ConvolutionPerformanceHints,
    // Thread count captured at construction; today it only feeds the
    // throughput estimate in `new` (hence `dead_code`).
    #[allow(dead_code)]
    num_threads: usize,
}
25
26impl CpuConvolutionOps {
27 pub fn new(num_threads: Option<usize>) -> Self {
29 let num_threads = num_threads.unwrap_or_else(|| rayon::current_num_threads());
30
31 Self {
32 performance_hints: ConvolutionPerformanceHints {
33 small_kernel_algorithm: ConvolutionAlgorithm::Direct,
34 large_kernel_algorithm: ConvolutionAlgorithm::Im2col,
35 fft_threshold: 7,
36 winograd_threshold: 6,
37 tile_size: (16, 16),
38 memory_bandwidth: 100.0, compute_throughput: num_threads as f32 * 50.0, },
41 num_threads,
42 }
43 }
44
45 #[allow(dead_code)]
47 fn copy_buffer_data(&self, src: &Buffer, dst: &Buffer, size: usize) -> BackendResult<()> {
48 if !src.is_cpu() || !dst.is_cpu() {
49 return Err(torsh_core::error::TorshError::BackendError(
50 "Both buffers must be CPU buffers".to_string(),
51 ));
52 }
53
54 let src_ptr = src.as_cpu_ptr().ok_or_else(|| {
55 torsh_core::error::TorshError::BackendError(
56 "Failed to get source buffer pointer".to_string(),
57 )
58 })?;
59
60 let dst_ptr = dst.as_cpu_ptr().ok_or_else(|| {
61 torsh_core::error::TorshError::BackendError(
62 "Failed to get destination buffer pointer".to_string(),
63 )
64 })?;
65
66 if size > src.size.min(dst.size) {
67 return Err(torsh_core::error::TorshError::BackendError(format!(
68 "Copy size {} exceeds buffer capacity",
69 size
70 )));
71 }
72
73 unsafe {
74 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size);
75 }
76
77 Ok(())
78 }
79
80 fn direct_convolution(
82 &self,
83 input: &Buffer,
84 kernel: &Buffer,
85 bias: Option<&Buffer>,
86 output: &Buffer,
87 config: &ConvolutionConfig,
88 ) -> BackendResult<()> {
89 let input_ptr = input.as_cpu_ptr().ok_or_else(|| {
91 torsh_core::error::TorshError::BackendError(
92 "Failed to get input buffer pointer".to_string(),
93 )
94 })?;
95
96 let kernel_ptr = kernel.as_cpu_ptr().ok_or_else(|| {
97 torsh_core::error::TorshError::BackendError(
98 "Failed to get kernel buffer pointer".to_string(),
99 )
100 })?;
101
102 let output_ptr = output.as_cpu_ptr().ok_or_else(|| {
103 torsh_core::error::TorshError::BackendError(
104 "Failed to get output buffer pointer".to_string(),
105 )
106 })?;
107
108 unsafe {
109 let input_data = std::slice::from_raw_parts(input_ptr as *const f32, input.size / 4);
110 let kernel_data = std::slice::from_raw_parts(kernel_ptr as *const f32, kernel.size / 4);
111 let output_data =
112 std::slice::from_raw_parts_mut(output_ptr as *mut f32, output.size / 4);
113
114 match config.conv_type {
115 ConvolutionType::Conv2D => {
116 algorithms::DirectConvolution::conv2d_direct(
117 input_data,
118 kernel_data,
119 output_data,
120 &config.input_dims,
121 &config.kernel_dims,
122 &config.output_dims,
123 (config.strides[0], config.strides[1]),
124 (config.padding[0], config.padding[1]),
125 )?;
126 }
127 ConvolutionType::DepthwiseConv2D => {
128 self.depthwise_direct_implementation(
130 input_data,
131 kernel_data,
132 output_data,
133 config,
134 )?;
135 }
136 _ => {
137 return Err(torsh_core::error::TorshError::BackendError(format!(
138 "Convolution type {:?} not implemented yet",
139 config.conv_type
140 )));
141 }
142 }
143
144 if let Some(bias_buffer) = bias {
146 let bias_ptr = bias_buffer.as_cpu_ptr().ok_or_else(|| {
147 torsh_core::error::TorshError::BackendError(
148 "Failed to get bias buffer pointer".to_string(),
149 )
150 })?;
151 let bias_data =
152 std::slice::from_raw_parts(bias_ptr as *const f32, bias_buffer.size / 4);
153
154 self.add_bias(output_data, bias_data, &config.output_dims)?;
155 }
156 }
157
158 Ok(())
159 }
160
161 fn add_bias(
163 &self,
164 output: &mut [f32],
165 bias: &[f32],
166 output_dims: &[usize],
167 ) -> BackendResult<()> {
168 if output_dims.len() < 4 {
169 return Ok(());
170 }
171
172 let (batch, channels, height, width) = (
173 output_dims[0],
174 output_dims[1],
175 output_dims[2],
176 output_dims[3],
177 );
178
179 for b in 0..batch {
180 for c in 0..channels {
181 let bias_value = bias.get(c).copied().unwrap_or(0.0);
182 for h in 0..height {
183 for w in 0..width {
184 let idx =
185 b * channels * height * width + c * height * width + h * width + w;
186 if idx < output.len() {
187 output[idx] += bias_value;
188 }
189 }
190 }
191 }
192 }
193
194 Ok(())
195 }
196
197 fn depthwise_direct_implementation(
199 &self,
200 input: &[f32],
201 kernel: &[f32],
202 output: &mut [f32],
203 config: &ConvolutionConfig,
204 ) -> BackendResult<()> {
205 let (batch, channels, in_h, in_w) = (
206 config.input_dims[0],
207 config.input_dims[1],
208 config.input_dims[2],
209 config.input_dims[3],
210 );
211 let (_, _, k_h, k_w) = (
212 config.kernel_dims[0],
213 config.kernel_dims[1],
214 config.kernel_dims[2],
215 config.kernel_dims[3],
216 );
217 let (_, _, out_h, out_w) = (
218 config.output_dims[0],
219 config.output_dims[1],
220 config.output_dims[2],
221 config.output_dims[3],
222 );
223 let (s_h, s_w) = (config.strides[0], config.strides[1]);
224 let (p_h, p_w) = (config.padding[0], config.padding[1]);
225
226 for b in 0..batch {
227 for c in 0..channels {
228 for oh in 0..out_h {
229 for ow in 0..out_w {
230 let mut sum = 0.0;
231
232 for kh in 0..k_h {
233 for kw in 0..k_w {
234 let ih = oh * s_h + kh;
235 let iw = ow * s_w + kw;
236
237 if ih >= p_h && iw >= p_w && ih < in_h + p_h && iw < in_w + p_w {
238 let input_h = ih - p_h;
239 let input_w = iw - p_w;
240
241 if input_h < in_h && input_w < in_w {
242 let input_idx = b * channels * in_h * in_w
243 + c * in_h * in_w
244 + input_h * in_w
245 + input_w;
246 let kernel_idx = c * k_h * k_w + kh * k_w + kw;
247
248 if input_idx < input.len() && kernel_idx < kernel.len() {
249 sum += input[input_idx] * kernel[kernel_idx];
250 }
251 }
252 }
253 }
254 }
255
256 let output_idx =
257 b * channels * out_h * out_w + c * out_h * out_w + oh * out_w + ow;
258
259 if output_idx < output.len() {
260 output[output_idx] = sum;
261 }
262 }
263 }
264 }
265 }
266
267 Ok(())
268 }
269}
270
271#[async_trait::async_trait]
272impl ConvolutionOps for CpuConvolutionOps {
273 async fn convolution(
274 &self,
275 _device: &Device,
276 input: &Buffer,
277 kernel: &Buffer,
278 bias: Option<&Buffer>,
279 output: &Buffer,
280 config: &ConvolutionConfig,
281 ) -> BackendResult<()> {
282 if !config.is_valid() {
283 return Err(torsh_core::error::TorshError::BackendError(
284 "Invalid convolution configuration".to_string(),
285 ));
286 }
287
288 let algorithm = self.select_algorithm(config);
289
290 match algorithm {
291 ConvolutionAlgorithm::Direct => {
292 self.direct_convolution(input, kernel, bias, output, config)
293 }
294 ConvolutionAlgorithm::Im2col => {
295 self.direct_convolution(input, kernel, bias, output, config)
298 }
299 ConvolutionAlgorithm::Winograd => {
300 self.direct_convolution(input, kernel, bias, output, config)
303 }
304 ConvolutionAlgorithm::FftBased => {
305 self.direct_convolution(input, kernel, bias, output, config)
308 }
309 _ => self.direct_convolution(input, kernel, bias, output, config),
310 }
311 }
312
313 async fn conv2d(
314 &self,
315 device: &Device,
316 input: &Buffer,
317 kernel: &Buffer,
318 bias: Option<&Buffer>,
319 output: &Buffer,
320 stride: (usize, usize),
321 padding: (usize, usize),
322 dilation: (usize, usize),
323 ) -> BackendResult<()> {
324 let config = ConvolutionConfig {
327 conv_type: ConvolutionType::Conv2D,
328 input_dims: vec![1, 1, 32, 32], output_dims: vec![1, 1, 32, 32], kernel_dims: vec![1, 1, 3, 3], strides: vec![stride.0, stride.1],
332 padding: vec![padding.0, padding.1],
333 dilation: vec![dilation.0, dilation.1],
334 groups: 1,
335 padding_mode: PaddingMode::Custom,
336 dtype: torsh_core::dtype::DType::F32,
337 algorithm: ConvolutionAlgorithm::Auto,
338 };
339
340 self.convolution(device, input, kernel, bias, output, &config)
341 .await
342 }
343
344 async fn depthwise_conv2d(
345 &self,
346 device: &Device,
347 input: &Buffer,
348 kernel: &Buffer,
349 bias: Option<&Buffer>,
350 output: &Buffer,
351 stride: (usize, usize),
352 padding: (usize, usize),
353 ) -> BackendResult<()> {
354 let config = ConvolutionConfig {
356 conv_type: ConvolutionType::DepthwiseConv2D,
357 input_dims: vec![1, 16, 32, 32], output_dims: vec![1, 16, 32, 32], kernel_dims: vec![16, 1, 3, 3], strides: vec![stride.0, stride.1],
361 padding: vec![padding.0, padding.1],
362 dilation: vec![1, 1],
363 groups: 16, padding_mode: PaddingMode::Custom,
365 dtype: torsh_core::dtype::DType::F32,
366 algorithm: ConvolutionAlgorithm::Direct,
367 };
368
369 self.convolution(device, input, kernel, bias, output, &config)
370 .await
371 }
372
373 async fn conv_transpose2d(
374 &self,
375 _device: &Device,
376 _input: &Buffer,
377 _kernel: &Buffer,
378 _bias: Option<&Buffer>,
379 _output: &Buffer,
380 _stride: (usize, usize),
381 _padding: (usize, usize),
382 _output_padding: (usize, usize),
383 ) -> BackendResult<()> {
384 Err(torsh_core::error::TorshError::BackendError(
385 "Transposed convolution not implemented for CPU backend yet".to_string(),
386 ))
387 }
388
389 async fn grouped_conv2d(
390 &self,
391 device: &Device,
392 input: &Buffer,
393 kernel: &Buffer,
394 bias: Option<&Buffer>,
395 output: &Buffer,
396 groups: usize,
397 stride: (usize, usize),
398 padding: (usize, usize),
399 ) -> BackendResult<()> {
400 let config = ConvolutionConfig {
402 conv_type: ConvolutionType::GroupedConv2D,
403 input_dims: vec![1, 16, 32, 32], output_dims: vec![1, 16, 32, 32], kernel_dims: vec![16, 16 / groups, 3, 3], strides: vec![stride.0, stride.1],
407 padding: vec![padding.0, padding.1],
408 dilation: vec![1, 1],
409 groups,
410 padding_mode: PaddingMode::Custom,
411 dtype: torsh_core::dtype::DType::F32,
412 algorithm: ConvolutionAlgorithm::Direct,
413 };
414
415 self.convolution(device, input, kernel, bias, output, &config)
416 .await
417 }
418
419 fn select_algorithm(&self, config: &ConvolutionConfig) -> ConvolutionAlgorithm {
420 if config.algorithm != ConvolutionAlgorithm::Auto {
421 return config.algorithm;
422 }
423
424 match config.conv_type {
426 ConvolutionType::Conv2D => {
427 if config.kernel_dims.len() >= 4 {
428 let kernel_h = config.kernel_dims[2];
429 let kernel_w = config.kernel_dims[3];
430 let kernel_size = kernel_h.max(kernel_w);
431
432 if kernel_size <= 3 {
433 ConvolutionAlgorithm::Direct
435 } else if kernel_size <= self.performance_hints.winograd_threshold {
436 ConvolutionAlgorithm::Winograd
437 } else if kernel_size >= self.performance_hints.fft_threshold {
438 ConvolutionAlgorithm::FftBased
439 } else {
440 ConvolutionAlgorithm::Im2col
441 }
442 } else {
443 ConvolutionAlgorithm::Direct
444 }
445 }
446 ConvolutionType::DepthwiseConv2D => ConvolutionAlgorithm::Direct,
447 ConvolutionType::SeparableConv2D => ConvolutionAlgorithm::Direct,
448 ConvolutionType::GroupedConv2D => ConvolutionAlgorithm::Direct,
449 _ => ConvolutionAlgorithm::Im2col,
450 }
451 }
452
453 fn supports_convolution(&self) -> bool {
454 true
455 }
456
457 fn supported_conv_types(&self) -> Vec<ConvolutionType> {
458 vec![
459 ConvolutionType::Conv1D,
460 ConvolutionType::Conv2D,
461 ConvolutionType::Conv3D,
462 ConvolutionType::DepthwiseConv2D,
463 ConvolutionType::SeparableConv2D,
464 ConvolutionType::GroupedConv2D,
465 ConvolutionType::DilatedConv2D,
467 ]
468 }
469
470 fn supported_algorithms(&self) -> Vec<ConvolutionAlgorithm> {
471 vec![
472 ConvolutionAlgorithm::Auto,
473 ConvolutionAlgorithm::Direct,
474 ConvolutionAlgorithm::Im2col,
475 ConvolutionAlgorithm::Winograd,
476 ConvolutionAlgorithm::FftBased,
477 ]
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::convolution::ConvolutionConfig;

    /// A freshly constructed backend must advertise convolution support
    /// and non-empty capability lists.
    #[test]
    fn test_cpu_convolution_ops_creation() {
        let ops = CpuConvolutionOps::new(Some(2));

        assert!(ops.supports_convolution());

        let types = ops.supported_conv_types();
        let algorithms = ops.supported_algorithms();
        assert!(!types.is_empty());
        assert!(!algorithms.is_empty());
    }

    /// Auto selection: small kernels go direct, large kernels go FFT,
    /// depthwise always goes direct.
    #[test]
    fn test_algorithm_selection() {
        let ops = CpuConvolutionOps::new(Some(1));

        let small = ConvolutionConfig::conv2d(1, 3, 16, (32, 32), (3, 3), (1, 1), (1, 1));
        assert_eq!(ops.select_algorithm(&small), ConvolutionAlgorithm::Direct);

        let large = ConvolutionConfig::conv2d(1, 3, 16, (32, 32), (9, 9), (1, 1), (4, 4));
        assert_eq!(ops.select_algorithm(&large), ConvolutionAlgorithm::FftBased);

        let depthwise =
            ConvolutionConfig::depthwise_conv2d(1, 16, (32, 32), (3, 3), (1, 1), (1, 1));
        assert_eq!(
            ops.select_algorithm(&depthwise),
            ConvolutionAlgorithm::Direct
        );
    }
}