1use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::Tensor;
4use parking_lot::Mutex;
5use rayon::prelude::*;
6use std::collections::HashMap;
7use std::hint::black_box;
8use std::sync::{Arc, OnceLock};
9use std::time::Instant;
10
/// Backward node for 2D convolution.
///
/// Saves the forward operands plus the convolution hyper-parameters so the
/// gradients w.r.t. the input and the weight can be recomputed when
/// `backward` runs.
#[derive(Debug)]
pub struct Conv2dBackward {
    // Input saved from the forward pass, shape [N, C_in, H_in, W_in].
    pub input: Tensor,
    // Weight saved from the forward pass, shape [C_out, C_in, K_h, K_w].
    pub weight: Tensor,
    // (stride_h, stride_w) used in the forward pass.
    pub stride: (usize, usize),
    // (pad_h, pad_w) used in the forward pass.
    pub padding: (usize, usize),
}
23
/// Forward-pass kernel selection strategy, parsed from
/// `RUSTORCH_CPU_CONV2D_STRATEGY` (unrecognized values fall back to `Auto`).
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuConv2dStrategy {
    // Pick by work-size heuristic (>= min_work => im2col).
    Auto,
    // Time both kernels once per shape and cache the winner.
    Profile,
    // Always use the direct (naive) kernel.
    Direct,
    // Always use im2col + GEMM.
    Im2col,
}
31
/// Concrete kernel picked for one conv2d invocation (the resolved outcome
/// of a strategy; also the value stored in the profile caches).
#[derive(Clone, Copy, PartialEq, Eq)]
enum Conv2dKernelChoice {
    Direct,
    Im2col,
}
37
/// Process-wide forward conv2d configuration, read once from the
/// environment (see `cpu_conv2d_config`).
#[derive(Clone, Copy)]
struct CpuConv2dConfig {
    strategy: CpuConv2dStrategy,
    // Auto strategy threshold: multiply-accumulate count at or above which
    // im2col is preferred. From RUSTORCH_CPU_CONV2D_MIN_WORK (default 65536).
    min_work: usize,
    // Timing repetitions per kernel under the Profile strategy.
    // From RUSTORCH_CPU_CONV2D_PROFILE_ITERS (default 1).
    profile_iters: usize,
}
44
/// Backward-pass kernel selection strategy, parsed from
/// `RUSTORCH_CPU_CONV2D_BWD_STRATEGY` (unrecognized values fall back to `Auto`).
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuConv2dBwdStrategy {
    // Pick by per-target work-size heuristic.
    Auto,
    // Time both kernels once per (target, shape) and cache the winner.
    Profile,
    // Always use the direct kernel.
    Direct,
    // Always use im2col + GEMM.
    Im2col,
}
52
/// Process-wide backward conv2d configuration, read once from the
/// environment (see `cpu_conv2d_bwd_config`). Thresholds and iteration
/// counts are independent for grad-input and grad-weight, each falling
/// back to a shared `RUSTORCH_CPU_CONV2D_BWD_*` default.
#[derive(Clone, Copy)]
struct CpuConv2dBwdConfig {
    strategy: CpuConv2dBwdStrategy,
    // Auto threshold for the grad-input kernel.
    min_work_grad_input: usize,
    // Auto threshold for the grad-weight kernel.
    min_work_grad_weight: usize,
    // Profile timing repetitions for the grad-input kernel.
    profile_iters_grad_input: usize,
    // Profile timing repetitions for the grad-weight kernel.
    profile_iters_grad_weight: usize,
}
61
/// Which gradient a backward kernel computes; part of the backward profile
/// cache key so grad-input and grad-weight are profiled separately.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Conv2dBwdTarget {
    GradInput,
    GradWeight,
}
67
/// Shape/hyper-parameter signature of a conv2d problem, used as a profile
/// cache key. Field order matches the construction sites in `conv2d` and
/// `Conv2dBackward::backward`.
type Conv2dPerfKey = (
    usize, // n (batch)
    usize, // c_in
    usize, // h_in
    usize, // w_in
    usize, // c_out
    usize, // k_h
    usize, // k_w
    usize, // stride_h
    usize, // stride_w
    usize, // pad_h
    usize, // pad_w
);
/// Backward profile cache key: gradient target plus the problem signature.
type Conv2dBwdPerfKey = (Conv2dBwdTarget, Conv2dPerfKey);
82
/// Reads the environment variable `key` and parses it as `usize`,
/// returning `default` when the variable is unset, not valid UTF-8,
/// or fails to parse.
fn parse_usize_env(key: &str, default: usize) -> usize {
    match std::env::var(key) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
89
90fn cpu_conv2d_config() -> CpuConv2dConfig {
91 static CFG: OnceLock<CpuConv2dConfig> = OnceLock::new();
92 *CFG.get_or_init(|| {
93 let strategy = match std::env::var("RUSTORCH_CPU_CONV2D_STRATEGY")
94 .unwrap_or_else(|_| "auto".to_string())
95 .to_ascii_lowercase()
96 .as_str()
97 {
98 "im2col" => CpuConv2dStrategy::Im2col,
99 "direct" => CpuConv2dStrategy::Direct,
100 "profile" => CpuConv2dStrategy::Profile,
101 _ => CpuConv2dStrategy::Auto,
102 };
103 CpuConv2dConfig {
104 strategy,
105 min_work: parse_usize_env("RUSTORCH_CPU_CONV2D_MIN_WORK", 65536),
106 profile_iters: parse_usize_env("RUSTORCH_CPU_CONV2D_PROFILE_ITERS", 1),
107 }
108 })
109}
110
111fn conv2d_profile_cache() -> &'static Mutex<HashMap<Conv2dPerfKey, Conv2dKernelChoice>> {
112 static CACHE: OnceLock<Mutex<HashMap<Conv2dPerfKey, Conv2dKernelChoice>>> = OnceLock::new();
113 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
114}
115
116fn cpu_conv2d_bwd_config() -> CpuConv2dBwdConfig {
117 static CFG: OnceLock<CpuConv2dBwdConfig> = OnceLock::new();
118 *CFG.get_or_init(|| {
119 let strategy = match std::env::var("RUSTORCH_CPU_CONV2D_BWD_STRATEGY")
120 .unwrap_or_else(|_| "auto".to_string())
121 .to_ascii_lowercase()
122 .as_str()
123 {
124 "im2col" => CpuConv2dBwdStrategy::Im2col,
125 "direct" => CpuConv2dBwdStrategy::Direct,
126 "profile" => CpuConv2dBwdStrategy::Profile,
127 _ => CpuConv2dBwdStrategy::Auto,
128 };
129 CpuConv2dBwdConfig {
130 strategy,
131 min_work_grad_input: parse_usize_env(
132 "RUSTORCH_CPU_CONV2D_BWD_MIN_WORK_GRAD_INPUT",
133 parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_MIN_WORK", 65536),
134 ),
135 min_work_grad_weight: parse_usize_env(
136 "RUSTORCH_CPU_CONV2D_BWD_MIN_WORK_GRAD_WEIGHT",
137 parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_MIN_WORK", 65536),
138 ),
139 profile_iters_grad_input: parse_usize_env(
140 "RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS_GRAD_INPUT",
141 parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS", 1),
142 ),
143 profile_iters_grad_weight: parse_usize_env(
144 "RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS_GRAD_WEIGHT",
145 parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS", 1),
146 ),
147 }
148 })
149}
150
151fn conv2d_bwd_profile_cache() -> &'static Mutex<HashMap<Conv2dBwdPerfKey, Conv2dKernelChoice>> {
152 static CACHE: OnceLock<Mutex<HashMap<Conv2dBwdPerfKey, Conv2dKernelChoice>>> = OnceLock::new();
153 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
154}
155
156fn conv2d_direct_core(
157 input_data: &[f32],
158 weight_data: &[f32],
159 n: usize,
160 c_in: usize,
161 h_in: usize,
162 w_in: usize,
163 c_out: usize,
164 k_h: usize,
165 k_w: usize,
166 h_out: usize,
167 w_out: usize,
168 stride_h: usize,
169 stride_w: usize,
170 pad_h: usize,
171 pad_w: usize,
172) -> Vec<f32> {
173 let total_elements = n * c_out * h_out * w_out;
174 (0..total_elements)
175 .into_par_iter()
176 .map(|idx| {
177 let wo = idx % w_out;
178 let ho = (idx / w_out) % h_out;
179 let co = (idx / (w_out * h_out)) % c_out;
180 let b = idx / (w_out * h_out * c_out);
181
182 let mut sum: f64 = 0.0;
183 for ci in 0..c_in {
184 for kh in 0..k_h {
185 for kw in 0..k_w {
186 let h_in_idx = ho * stride_h + kh;
187 let w_in_idx = wo * stride_w + kw;
188 if h_in_idx >= pad_h && w_in_idx >= pad_w {
189 let hi = h_in_idx - pad_h;
190 let wi = w_in_idx - pad_w;
191 if hi < h_in && wi < w_in {
192 let val_in =
193 input_data[((b * c_in + ci) * h_in + hi) * w_in + wi] as f64;
194 let val_w =
195 weight_data[((co * c_in + ci) * k_h + kh) * k_w + kw] as f64;
196 sum += val_in * val_w;
197 }
198 }
199 }
200 }
201 }
202 sum as f32
203 })
204 .collect()
205}
206
/// Forward conv2d via im2col + GEMM.
///
/// Unpacks every receptive field into one row of an
/// `[n * h_out * w_out, c_in * k_h * k_w]` column matrix (zero-filling
/// positions that fall in the padding border), multiplies by the flattened
/// weights via `matmul`, then scatters the `[patches, c_out]` product back
/// into NCHW order.
fn conv2d_im2col_core(
    input_data: &[f32],
    weight_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    let k_size = c_in * k_h * k_w;
    let patches = n * h_out * w_out;
    // Column matrix: one row per output spatial position per batch element.
    let mut col = vec![0.0f32; patches * k_size];
    for b in 0..n {
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = (b * h_out + ho) * w_out + wo;
                // col_idx walks the (ci, kh, kw) axes in weight-flattening order.
                let mut col_idx = 0usize;
                for ci in 0..c_in {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            let h_in_idx = ho * stride_h + kh;
                            let w_in_idx = wo * stride_w + kw;
                            // Zero-fill taps that fall in the padding border.
                            let val = if h_in_idx >= pad_h && w_in_idx >= pad_w {
                                let hi = h_in_idx - pad_h;
                                let wi = w_in_idx - pad_w;
                                if hi < h_in && wi < w_in {
                                    input_data[((b * c_in + ci) * h_in + hi) * w_in + wi]
                                } else {
                                    0.0
                                }
                            } else {
                                0.0
                            };
                            col[row * k_size + col_idx] = val;
                            col_idx += 1;
                        }
                    }
                }
            }
        }
    }

    // GEMM: [patches, k_size] x [k_size, c_out] -> [patches, c_out].
    let col_t = Tensor::new(&col, &[patches, k_size]);
    let weight_t = Tensor::new(weight_data, &[c_out, k_size]).t();
    let out2d = crate::ops::matmul(&col_t, &weight_t);
    let out_data = out2d.data();
    // Scatter the row-major [patches, c_out] product back into NCHW layout.
    let mut result = vec![0.0f32; n * c_out * h_out * w_out];
    for b in 0..n {
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = (b * h_out + ho) * w_out + wo;
                for co in 0..c_out {
                    result[((b * c_out + co) * h_out + ho) * w_out + wo] =
                        out_data[row * c_out + co];
                }
            }
        }
    }
    result
}
275
276fn choose_conv2d_bwd_kernel<Fd, Fi>(
277 target: Conv2dBwdTarget,
278 key: Conv2dPerfKey,
279 work: usize,
280 direct_fn: Fd,
281 im2col_fn: Fi,
282) -> Conv2dKernelChoice
283where
284 Fd: Fn() -> Vec<f32>,
285 Fi: Fn() -> Vec<f32>,
286{
287 let cfg = cpu_conv2d_bwd_config();
288 let min_work = match target {
289 Conv2dBwdTarget::GradInput => cfg.min_work_grad_input,
290 Conv2dBwdTarget::GradWeight => cfg.min_work_grad_weight,
291 };
292 let profile_iters = match target {
293 Conv2dBwdTarget::GradInput => cfg.profile_iters_grad_input,
294 Conv2dBwdTarget::GradWeight => cfg.profile_iters_grad_weight,
295 };
296 match cfg.strategy {
297 CpuConv2dBwdStrategy::Direct => Conv2dKernelChoice::Direct,
298 CpuConv2dBwdStrategy::Im2col => Conv2dKernelChoice::Im2col,
299 CpuConv2dBwdStrategy::Auto => {
300 if work >= min_work {
301 Conv2dKernelChoice::Im2col
302 } else {
303 Conv2dKernelChoice::Direct
304 }
305 }
306 CpuConv2dBwdStrategy::Profile => {
307 let cache_key = (target, key);
308 if let Some(cached) = conv2d_bwd_profile_cache().lock().get(&cache_key).copied() {
309 return cached;
310 }
311 let iters = profile_iters.max(1);
312 let mut direct_ns = 0u128;
313 let mut im2col_ns = 0u128;
314 for _ in 0..iters {
315 let t0 = Instant::now();
316 let d = direct_fn();
317 direct_ns += t0.elapsed().as_nanos();
318 black_box(d.len());
319
320 let t1 = Instant::now();
321 let c = im2col_fn();
322 im2col_ns += t1.elapsed().as_nanos();
323 black_box(c.len());
324 }
325 let choice = if im2col_ns < direct_ns {
326 Conv2dKernelChoice::Im2col
327 } else {
328 Conv2dKernelChoice::Direct
329 };
330 conv2d_bwd_profile_cache().lock().insert(cache_key, choice);
331 choice
332 }
333 }
334}
335
/// Direct computation of dL/d(input) for conv2d.
///
/// Parallelized over input elements with rayon. For each input element
/// (b, ci, h, w) it sums grad * weight over every output position whose
/// receptive field covers it. Accumulates in f64, returns f32 in NCHW order.
fn conv2d_grad_input_direct_core(
    grad_data: &[f32],
    weight_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    let total_elements = n * c_in * h_in * w_in;
    (0..total_elements)
        .into_par_iter()
        .map(|idx| {
            // Decompose the flat index into (b, ci, h, w) of the input element.
            let w = idx % w_in;
            let h = (idx / w_in) % h_in;
            let ci = (idx / (w_in * h_in)) % c_in;
            let b = idx / (w_in * h_in * c_in);
            let mut sum: f64 = 0.0;

            // Output rows ho that read input row h satisfy
            //   ho * stride_h <= h + pad_h < ho * stride_h + k_h,
            // i.e. ho in [ceil((h + pad_h - k_h + 1) / stride_h),
            //             (h + pad_h) / stride_h], clamped to [0, h_out).
            let h_out_start = if h + pad_h >= k_h {
                (h + pad_h - k_h + 1).div_ceil(stride_h)
            } else {
                0
            };
            let h_out_end = std::cmp::min(h_out, (h + pad_h) / stride_h + 1);

            for ho in h_out_start..h_out_end {
                // Kernel row tap connecting h and ho; in [0, k_h) by the
                // window bounds above.
                let kh = h + pad_h - ho * stride_h;
                // Same windowing derivation along the width axis.
                let w_out_start = if w + pad_w >= k_w {
                    (w + pad_w - k_w + 1).div_ceil(stride_w)
                } else {
                    0
                };
                let w_out_end = std::cmp::min(w_out, (w + pad_w) / stride_w + 1);

                for wo in w_out_start..w_out_end {
                    let kw = w + pad_w - wo * stride_w;
                    for co in 0..c_out {
                        let g_val = grad_data[((b * c_out + co) * h_out + ho) * w_out + wo] as f64;
                        let w_val = weight_data[((co * c_in + ci) * k_h + kh) * k_w + kw] as f64;
                        sum += g_val * w_val;
                    }
                }
            }
            sum as f32
        })
        .collect()
}
392
/// dL/d(input) via GEMM: per batch element, multiplies the re-laid-out
/// output gradient `[patches, c_out]` by the flattened weights
/// `[c_out, k_size]` to get the column-space gradient `[patches, k_size]`,
/// then scatter-adds (col2im) each column entry back to its input position.
/// Overlapping receptive fields accumulate via `+=`.
fn conv2d_grad_input_im2col_core(
    grad_data: &[f32],
    weight_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    // Here patches counts spatial positions for ONE batch element
    // (contrast with the forward im2col, which folds the batch in).
    let patches = h_out * w_out;
    let k_size = c_in * k_h * k_w;
    let weight_flat = Tensor::new(weight_data, &[c_out, k_size]);
    let mut grad_input_data = vec![0.0f32; n * c_in * h_in * w_in];

    for b in 0..n {
        // Repack this batch element's output gradient from NCHW into
        // row-major [patches, c_out].
        let mut gy = vec![0.0f32; patches * c_out];
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = ho * w_out + wo;
                for co in 0..c_out {
                    gy[row * c_out + co] = grad_data[((b * c_out + co) * h_out + ho) * w_out + wo];
                }
            }
        }
        // dcol = gy x W : [patches, c_out] x [c_out, k_size] -> [patches, k_size].
        let gy_t = Tensor::new(&gy, &[patches, c_out]);
        let dcol_t = crate::ops::matmul(&gy_t, &weight_flat);
        let dcol = dcol_t.data();

        // col2im: scatter-add each column-gradient entry back to the input
        // element it was gathered from; padding positions are dropped.
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = ho * w_out + wo;
                // col_idx walks (ci, kh, kw) in the same order used to
                // flatten the weights.
                let mut col_idx = 0usize;
                for ci in 0..c_in {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            let h_in_idx = ho * stride_h + kh;
                            let w_in_idx = wo * stride_w + kw;
                            if h_in_idx >= pad_h && w_in_idx >= pad_w {
                                let hi = h_in_idx - pad_h;
                                let wi = w_in_idx - pad_w;
                                if hi < h_in && wi < w_in {
                                    grad_input_data[((b * c_in + ci) * h_in + hi) * w_in + wi] +=
                                        dcol[row * k_size + col_idx];
                                }
                            }
                            col_idx += 1;
                        }
                    }
                }
            }
        }
    }

    grad_input_data
}
456
457fn conv2d_grad_weight_direct_core(
458 input_data: &[f32],
459 grad_data: &[f32],
460 n: usize,
461 c_in: usize,
462 h_in: usize,
463 w_in: usize,
464 c_out: usize,
465 k_h: usize,
466 k_w: usize,
467 h_out: usize,
468 w_out: usize,
469 stride_h: usize,
470 stride_w: usize,
471 pad_h: usize,
472 pad_w: usize,
473) -> Vec<f32> {
474 let total_elements = c_out * c_in * k_h * k_w;
475 (0..total_elements)
476 .into_par_iter()
477 .map(|idx| {
478 let kw = idx % k_w;
479 let kh = (idx / k_w) % k_h;
480 let ci = (idx / (k_w * k_h)) % c_in;
481 let co = idx / (k_w * k_h * c_in);
482 let mut sum: f64 = 0.0;
483 for b in 0..n {
484 for ho in 0..h_out {
485 for wo in 0..w_out {
486 let h_in_idx = ho * stride_h + kh;
487 let w_in_idx = wo * stride_w + kw;
488 if h_in_idx >= pad_h && w_in_idx >= pad_w {
489 let hi = h_in_idx - pad_h;
490 let wi = w_in_idx - pad_w;
491 if hi < h_in && wi < w_in {
492 let val_in =
493 input_data[((b * c_in + ci) * h_in + hi) * w_in + wi] as f64;
494 let val_g =
495 grad_data[((b * c_out + co) * h_out + ho) * w_out + wo] as f64;
496 sum += val_in * val_g;
497 }
498 }
499 }
500 }
501 }
502 sum as f32
503 })
504 .collect()
505}
506
/// dL/d(weight) via GEMM: per batch element, rebuilds the im2col column
/// matrix `[patches, k_size]` and the re-laid-out output gradient
/// `[patches, c_out]`, computes gy^T x col = `[c_out, k_size]`, and
/// accumulates the per-batch products into the final weight gradient.
fn conv2d_grad_weight_im2col_core(
    input_data: &[f32],
    grad_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    // patches counts spatial positions for one batch element.
    let patches = h_out * w_out;
    let k_size = c_in * k_h * k_w;
    // Accumulator, flattened as [c_out, k_size] == [C_out, C_in, K_h, K_w].
    let mut grad_weight = vec![0.0f32; c_out * k_size];

    for b in 0..n {
        let mut col = vec![0.0f32; patches * k_size];
        let mut gy = vec![0.0f32; patches * c_out];

        // Build both matrices for this batch element in a single sweep over
        // the output positions.
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = ho * w_out + wo;
                // col_idx walks (ci, kh, kw) in weight-flattening order.
                let mut col_idx = 0usize;
                for ci in 0..c_in {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            let h_in_idx = ho * stride_h + kh;
                            let w_in_idx = wo * stride_w + kw;
                            // Zero-fill taps that fall in the padding border.
                            let val = if h_in_idx >= pad_h && w_in_idx >= pad_w {
                                let hi = h_in_idx - pad_h;
                                let wi = w_in_idx - pad_w;
                                if hi < h_in && wi < w_in {
                                    input_data[((b * c_in + ci) * h_in + hi) * w_in + wi]
                                } else {
                                    0.0
                                }
                            } else {
                                0.0
                            };
                            col[row * k_size + col_idx] = val;
                            col_idx += 1;
                        }
                    }
                }
                for co in 0..c_out {
                    gy[row * c_out + co] = grad_data[((b * c_out + co) * h_out + ho) * w_out + wo];
                }
            }
        }

        // gy^T x col : [c_out, patches] x [patches, k_size] -> [c_out, k_size].
        let gy_t = Tensor::new(&gy, &[patches, c_out]).t();
        let col_t = Tensor::new(&col, &[patches, k_size]);
        let gw_batch = crate::ops::matmul(&gy_t, &col_t);
        let gw_data = gw_batch.data();
        // Sum the contribution of this batch element into the accumulator.
        for i in 0..grad_weight.len() {
            grad_weight[i] += gw_data[i];
        }
    }

    grad_weight
}
574
impl BackwardOp for Conv2dBackward {
    /// Computes dL/d(input) and dL/d(weight) from the incoming output
    /// gradient and propagates each to the corresponding saved tensor.
    ///
    /// For each gradient, the direct-vs-im2col kernel is resolved through
    /// `choose_conv2d_bwd_kernel`; under the Profile strategy that helper
    /// runs both kernels once to time them, and the chosen kernel is then
    /// re-run below to produce the actual result (the winner is cached per
    /// shape, so the double cost is paid once per shape per process).
    fn backward(&self, grad: &Tensor) {
        let (stride_h, stride_w) = self.stride;
        let (pad_h, pad_w) = self.padding;

        let input_shape = self.input.shape();
        let weight_shape = self.weight.shape();
        let grad_shape = grad.shape();

        // Input is [N, C_in, H_in, W_in].
        let n = input_shape[0];
        let c_in = input_shape[1];
        let h_in = input_shape[2];
        let w_in = input_shape[3];

        // Weight is [C_out, C_in, K_h, K_w]; index 1 is not needed here.
        let c_out = weight_shape[0];
        let k_h = weight_shape[2];
        let k_w = weight_shape[3];

        // Output-gradient spatial extent, [N, C_out, H_out, W_out].
        let h_out = grad_shape[2];
        let w_out = grad_shape[3];

        if self.input.requires_grad() {
            // Hold the data guards for the whole computation so the closures
            // below can borrow raw slices from them.
            let grad_guard = grad.data();
            let weight_guard = self.weight.data();
            let grad_data = &*grad_guard;
            let weight_data = &*weight_guard;
            // Total multiply-accumulate count; heuristic input for Auto.
            let work = n * c_in * h_in * w_in * c_out * k_h * k_w;
            let key: Conv2dPerfKey = (
                n, c_in, h_in, w_in, c_out, k_h, k_w, stride_h, stride_w, pad_h, pad_w,
            );
            let kernel = choose_conv2d_bwd_kernel(
                Conv2dBwdTarget::GradInput,
                key,
                work,
                || {
                    conv2d_grad_input_direct_core(
                        grad_data,
                        weight_data,
                        n,
                        c_in,
                        h_in,
                        w_in,
                        c_out,
                        k_h,
                        k_w,
                        h_out,
                        w_out,
                        stride_h,
                        stride_w,
                        pad_h,
                        pad_w,
                    )
                },
                || {
                    conv2d_grad_input_im2col_core(
                        grad_data,
                        weight_data,
                        n,
                        c_in,
                        h_in,
                        w_in,
                        c_out,
                        k_h,
                        k_w,
                        h_out,
                        w_out,
                        stride_h,
                        stride_w,
                        pad_h,
                        pad_w,
                    )
                },
            );
            // Run the chosen kernel to produce the actual gradient.
            let grad_input_data = match kernel {
                Conv2dKernelChoice::Direct => conv2d_grad_input_direct_core(
                    grad_data,
                    weight_data,
                    n,
                    c_in,
                    h_in,
                    w_in,
                    c_out,
                    k_h,
                    k_w,
                    h_out,
                    w_out,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                ),
                Conv2dKernelChoice::Im2col => conv2d_grad_input_im2col_core(
                    grad_data,
                    weight_data,
                    n,
                    c_in,
                    h_in,
                    w_in,
                    c_out,
                    k_h,
                    k_w,
                    h_out,
                    w_out,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                ),
            };

            let grad_input_tensor =
                Tensor::new_with_storage(Storage::new(grad_input_data), self.input.shape());
            self.input.accumulate_grad(&grad_input_tensor);
            self.input.backward_step();
        }

        if self.weight.requires_grad() {
            let input_guard = self.input.data();
            let grad_guard = grad.data();
            let input_data = &*input_guard;
            let grad_data = &*grad_guard;
            // Same work estimate and shape key as the grad-input branch; the
            // cache entries differ only by Conv2dBwdTarget.
            let work = n * c_in * h_in * w_in * c_out * k_h * k_w;
            let key: Conv2dPerfKey = (
                n, c_in, h_in, w_in, c_out, k_h, k_w, stride_h, stride_w, pad_h, pad_w,
            );
            let kernel = choose_conv2d_bwd_kernel(
                Conv2dBwdTarget::GradWeight,
                key,
                work,
                || {
                    conv2d_grad_weight_direct_core(
                        input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                        stride_h, stride_w, pad_h, pad_w,
                    )
                },
                || {
                    conv2d_grad_weight_im2col_core(
                        input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                        stride_h, stride_w, pad_h, pad_w,
                    )
                },
            );
            let grad_weight_data = match kernel {
                Conv2dKernelChoice::Direct => conv2d_grad_weight_direct_core(
                    input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                    stride_h, stride_w, pad_h, pad_w,
                ),
                Conv2dKernelChoice::Im2col => conv2d_grad_weight_im2col_core(
                    input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                    stride_h, stride_w, pad_h, pad_w,
                ),
            };

            let grad_weight_tensor =
                Tensor::new_with_storage(Storage::new(grad_weight_data), self.weight.shape());
            self.weight.accumulate_grad(&grad_weight_tensor);
            self.weight.backward_step();
        }
    }
}
737
738pub fn conv2d(
739 input: &Tensor,
740 weight: &Tensor,
741 stride: (usize, usize),
742 padding: (usize, usize),
743) -> Tensor {
744 let input_shape = input.shape();
745 let weight_shape = weight.shape();
746
747 if input_shape.len() != 4 || weight_shape.len() != 4 {
748 panic!("Conv2d requires 4D tensors");
749 }
750
751 let n = input_shape[0];
752 let c_in = input_shape[1];
753 let h_in = input_shape[2];
754 let w_in = input_shape[3];
755
756 let c_out = weight_shape[0];
757 let k_h = weight_shape[2];
758 let k_w = weight_shape[3];
759
760 if weight_shape[1] != c_in {
761 panic!(
762 "Weight input channels {} must match input channels {}",
763 weight_shape[1], c_in
764 );
765 }
766
767 let (stride_h, stride_w) = stride;
768 let (pad_h, pad_w) = padding;
769
770 let h_out = (h_in + 2 * pad_h - k_h) / stride_h + 1;
771 let w_out = (w_in + 2 * pad_w - k_w) / stride_w + 1;
772
773 let input_contig = if input.is_contiguous() {
774 input.clone()
775 } else {
776 input.contiguous()
777 };
778 let weight_contig = if weight.is_contiguous() {
779 weight.clone()
780 } else {
781 weight.contiguous()
782 };
783
784 let input_guard = input_contig.data();
785 let weight_guard = weight_contig.data();
786 let input_data = &*input_guard;
787 let weight_data = &*weight_guard;
788
789 let cfg = cpu_conv2d_config();
790 let work = n * c_out * h_out * w_out * c_in * k_h * k_w;
791 let key: Conv2dPerfKey = (
792 n, c_in, h_in, w_in, c_out, k_h, k_w, stride_h, stride_w, pad_h, pad_w,
793 );
794 let kernel = match cfg.strategy {
795 CpuConv2dStrategy::Direct => Conv2dKernelChoice::Direct,
796 CpuConv2dStrategy::Im2col => Conv2dKernelChoice::Im2col,
797 CpuConv2dStrategy::Auto => {
798 if work >= cfg.min_work {
799 Conv2dKernelChoice::Im2col
800 } else {
801 Conv2dKernelChoice::Direct
802 }
803 }
804 CpuConv2dStrategy::Profile => {
805 if let Some(cached) = conv2d_profile_cache().lock().get(&key).copied() {
806 cached
807 } else {
808 let iters = cfg.profile_iters.max(1);
809 let mut direct_ns = 0u128;
810 let mut im2col_ns = 0u128;
811 for _ in 0..iters {
812 let t0 = Instant::now();
813 let d = conv2d_direct_core(
814 input_data,
815 weight_data,
816 n,
817 c_in,
818 h_in,
819 w_in,
820 c_out,
821 k_h,
822 k_w,
823 h_out,
824 w_out,
825 stride_h,
826 stride_w,
827 pad_h,
828 pad_w,
829 );
830 direct_ns += t0.elapsed().as_nanos();
831 black_box(d.len());
832 let t1 = Instant::now();
833 let c = conv2d_im2col_core(
834 input_data,
835 weight_data,
836 n,
837 c_in,
838 h_in,
839 w_in,
840 c_out,
841 k_h,
842 k_w,
843 h_out,
844 w_out,
845 stride_h,
846 stride_w,
847 pad_h,
848 pad_w,
849 );
850 im2col_ns += t1.elapsed().as_nanos();
851 black_box(c.len());
852 }
853 let choice = if im2col_ns < direct_ns {
854 Conv2dKernelChoice::Im2col
855 } else {
856 Conv2dKernelChoice::Direct
857 };
858 conv2d_profile_cache().lock().insert(key, choice);
859 choice
860 }
861 }
862 };
863
864 let result_data = match kernel {
865 Conv2dKernelChoice::Direct => conv2d_direct_core(
866 input_data,
867 weight_data,
868 n,
869 c_in,
870 h_in,
871 w_in,
872 c_out,
873 k_h,
874 k_w,
875 h_out,
876 w_out,
877 stride_h,
878 stride_w,
879 pad_h,
880 pad_w,
881 ),
882 Conv2dKernelChoice::Im2col => conv2d_im2col_core(
883 input_data,
884 weight_data,
885 n,
886 c_in,
887 h_in,
888 w_in,
889 c_out,
890 k_h,
891 k_w,
892 h_out,
893 w_out,
894 stride_h,
895 stride_w,
896 pad_h,
897 pad_w,
898 ),
899 };
900
901 let storage = Storage::new(result_data);
902 let mut tensor = Tensor::new_with_storage(storage, &[n, c_out, h_out, w_out]);
903
904 if input.requires_grad() || weight.requires_grad() {
905 tensor.set_requires_grad_mut(true);
906 tensor.set_op(Arc::new(Conv2dBackward {
907 input: input.clone(),
908 weight: weight.clone(),
909 stride,
910 padding,
911 }));
912 }
913
914 tensor
919}