//! rustorch_core/ops/conv.rs
//!
//! CPU Conv2d forward and backward kernels with direct and im2col
//! implementations, plus environment-driven kernel selection.
1use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::Tensor;
4use parking_lot::Mutex;
5use rayon::prelude::*;
6use std::collections::HashMap;
7use std::hint::black_box;
8use std::sync::{Arc, OnceLock};
9use std::time::Instant;
10
11// --- Conv2d ---
12// Input: (N, C_in, H, W)
13// Weight: (C_out, C_in, kH, kW)
14// Output: (N, C_out, H_out, W_out)
15
/// Autograd node recorded by [`conv2d`]: keeps the forward operands and
/// hyper-parameters needed to compute gradients for input and weight.
#[derive(Debug)]
pub struct Conv2dBackward {
    /// Forward input, shape (N, C_in, H, W).
    pub input: Tensor,
    /// Forward weight, shape (C_out, C_in, kH, kW).
    pub weight: Tensor,
    /// (stride_h, stride_w) used in the forward pass.
    pub stride: (usize, usize),
    /// (pad_h, pad_w) of implicit zero padding used in the forward pass.
    pub padding: (usize, usize),
}
23
/// Forward-pass kernel-selection strategy, parsed from the
/// `RUSTORCH_CPU_CONV2D_STRATEGY` environment variable.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuConv2dStrategy {
    /// Pick by work-size heuristic (default).
    Auto,
    /// Time both kernels once per problem shape and cache the winner.
    Profile,
    /// Always use the direct (naive) kernel.
    Direct,
    /// Always use the im2col + matmul kernel.
    Im2col,
}
31
/// A resolved kernel decision (the outcome of applying a strategy).
#[derive(Clone, Copy, PartialEq, Eq)]
enum Conv2dKernelChoice {
    Direct,
    Im2col,
}
37
/// Forward-pass tuning knobs, read once from the environment.
#[derive(Clone, Copy)]
struct CpuConv2dConfig {
    strategy: CpuConv2dStrategy,
    /// In `Auto` mode, total MAC count at or above which im2col is chosen.
    min_work: usize,
    /// Timing repetitions per kernel in `Profile` mode (clamped to >= 1).
    profile_iters: usize,
}
44
/// Backward-pass kernel-selection strategy, parsed from the
/// `RUSTORCH_CPU_CONV2D_BWD_STRATEGY` environment variable.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuConv2dBwdStrategy {
    Auto,
    Profile,
    Direct,
    Im2col,
}
52
/// Backward-pass tuning knobs, read once from the environment. Each of the
/// two gradient targets (input / weight) gets its own threshold and
/// profiling iteration count.
#[derive(Clone, Copy)]
struct CpuConv2dBwdConfig {
    strategy: CpuConv2dBwdStrategy,
    /// `Auto` work threshold for the grad-input kernel.
    min_work_grad_input: usize,
    /// `Auto` work threshold for the grad-weight kernel.
    min_work_grad_weight: usize,
    /// `Profile` repetitions for the grad-input kernel.
    profile_iters_grad_input: usize,
    /// `Profile` repetitions for the grad-weight kernel.
    profile_iters_grad_weight: usize,
}
61
/// Which gradient a backward kernel computes; used to select per-target
/// config and to qualify profile-cache keys.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Conv2dBwdTarget {
    GradInput,
    GradWeight,
}
67
/// Problem-shape key for the profiling caches. Field order (see the key
/// construction in `Conv2dBackward::backward` and `conv2d`):
/// (n, c_in, h_in, w_in, c_out, k_h, k_w, stride_h, stride_w, pad_h, pad_w).
type Conv2dPerfKey = (
    usize,
    usize,
    usize,
    usize,
    usize,
    usize,
    usize,
    usize,
    usize,
    usize,
    usize,
);
/// Backward cache key: the same problem shape, qualified by which gradient
/// (input or weight) was profiled.
type Conv2dBwdPerfKey = (Conv2dBwdTarget, Conv2dPerfKey);
82
/// Reads the environment variable `key` and parses it as `usize`.
/// Returns `default` when the variable is unset, non-UTF-8, or unparsable.
fn parse_usize_env(key: &str, default: usize) -> usize {
    match std::env::var(key) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
89
90fn cpu_conv2d_config() -> CpuConv2dConfig {
91    static CFG: OnceLock<CpuConv2dConfig> = OnceLock::new();
92    *CFG.get_or_init(|| {
93        let strategy = match std::env::var("RUSTORCH_CPU_CONV2D_STRATEGY")
94            .unwrap_or_else(|_| "auto".to_string())
95            .to_ascii_lowercase()
96            .as_str()
97        {
98            "im2col" => CpuConv2dStrategy::Im2col,
99            "direct" => CpuConv2dStrategy::Direct,
100            "profile" => CpuConv2dStrategy::Profile,
101            _ => CpuConv2dStrategy::Auto,
102        };
103        CpuConv2dConfig {
104            strategy,
105            min_work: parse_usize_env("RUSTORCH_CPU_CONV2D_MIN_WORK", 65536),
106            profile_iters: parse_usize_env("RUSTORCH_CPU_CONV2D_PROFILE_ITERS", 1),
107        }
108    })
109}
110
111fn conv2d_profile_cache() -> &'static Mutex<HashMap<Conv2dPerfKey, Conv2dKernelChoice>> {
112    static CACHE: OnceLock<Mutex<HashMap<Conv2dPerfKey, Conv2dKernelChoice>>> = OnceLock::new();
113    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
114}
115
116fn cpu_conv2d_bwd_config() -> CpuConv2dBwdConfig {
117    static CFG: OnceLock<CpuConv2dBwdConfig> = OnceLock::new();
118    *CFG.get_or_init(|| {
119        let strategy = match std::env::var("RUSTORCH_CPU_CONV2D_BWD_STRATEGY")
120            .unwrap_or_else(|_| "auto".to_string())
121            .to_ascii_lowercase()
122            .as_str()
123        {
124            "im2col" => CpuConv2dBwdStrategy::Im2col,
125            "direct" => CpuConv2dBwdStrategy::Direct,
126            "profile" => CpuConv2dBwdStrategy::Profile,
127            _ => CpuConv2dBwdStrategy::Auto,
128        };
129        CpuConv2dBwdConfig {
130            strategy,
131            min_work_grad_input: parse_usize_env(
132                "RUSTORCH_CPU_CONV2D_BWD_MIN_WORK_GRAD_INPUT",
133                parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_MIN_WORK", 65536),
134            ),
135            min_work_grad_weight: parse_usize_env(
136                "RUSTORCH_CPU_CONV2D_BWD_MIN_WORK_GRAD_WEIGHT",
137                parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_MIN_WORK", 65536),
138            ),
139            profile_iters_grad_input: parse_usize_env(
140                "RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS_GRAD_INPUT",
141                parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS", 1),
142            ),
143            profile_iters_grad_weight: parse_usize_env(
144                "RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS_GRAD_WEIGHT",
145                parse_usize_env("RUSTORCH_CPU_CONV2D_BWD_PROFILE_ITERS", 1),
146            ),
147        }
148    })
149}
150
151fn conv2d_bwd_profile_cache() -> &'static Mutex<HashMap<Conv2dBwdPerfKey, Conv2dKernelChoice>> {
152    static CACHE: OnceLock<Mutex<HashMap<Conv2dBwdPerfKey, Conv2dKernelChoice>>> = OnceLock::new();
153    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
154}
155
/// Direct (naive) forward convolution on raw f32 buffers.
///
/// Layouts: `input_data` is NCHW (n, c_in, h_in, w_in), `weight_data` is
/// (c_out, c_in, k_h, k_w); the returned Vec is NCHW (n, c_out, h_out, w_out).
/// Every output element is computed independently in parallel via rayon, and
/// each inner product is accumulated in f64 before the final cast to f32 to
/// reduce rounding error.
fn conv2d_direct_core(
    input_data: &[f32],
    weight_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    let total_elements = n * c_out * h_out * w_out;
    (0..total_elements)
        .into_par_iter()
        .map(|idx| {
            // Decompose the flat output index into (batch, out-channel, row, col).
            let wo = idx % w_out;
            let ho = (idx / w_out) % h_out;
            let co = (idx / (w_out * h_out)) % c_out;
            let b = idx / (w_out * h_out * c_out);

            let mut sum: f64 = 0.0;
            for ci in 0..c_in {
                for kh in 0..k_h {
                    for kw in 0..k_w {
                        // Position of this tap in padded-input coordinates.
                        let h_in_idx = ho * stride_h + kh;
                        let w_in_idx = wo * stride_w + kw;
                        // Taps falling in the implicit zero padding are skipped.
                        if h_in_idx >= pad_h && w_in_idx >= pad_w {
                            let hi = h_in_idx - pad_h;
                            let wi = w_in_idx - pad_w;
                            if hi < h_in && wi < w_in {
                                let val_in =
                                    input_data[((b * c_in + ci) * h_in + hi) * w_in + wi] as f64;
                                let val_w =
                                    weight_data[((co * c_in + ci) * k_h + kh) * k_w + kw] as f64;
                                sum += val_in * val_w;
                            }
                        }
                    }
                }
            }
            sum as f32
        })
        .collect()
}
206
/// im2col forward convolution: lowers the input into a patch matrix and
/// reduces the convolution to one large matmul.
///
/// Layouts match `conv2d_direct_core`: NCHW input, (C_out, C_in, kH, kW)
/// weight, NCHW output. Zero padding is materialized as explicit zeros in
/// the column buffer.
fn conv2d_im2col_core(
    input_data: &[f32],
    weight_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    // One row per output position across all batches; one column per
    // (in-channel, kernel-row, kernel-col) tap, in that nesting order.
    let k_size = c_in * k_h * k_w;
    let patches = n * h_out * w_out;
    let mut col = vec![0.0f32; patches * k_size];
    for b in 0..n {
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = (b * h_out + ho) * w_out + wo;
                let mut col_idx = 0usize;
                for ci in 0..c_in {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            let h_in_idx = ho * stride_h + kh;
                            let w_in_idx = wo * stride_w + kw;
                            // Taps that land in the padding region contribute zero.
                            let val = if h_in_idx >= pad_h && w_in_idx >= pad_w {
                                let hi = h_in_idx - pad_h;
                                let wi = w_in_idx - pad_w;
                                if hi < h_in && wi < w_in {
                                    input_data[((b * c_in + ci) * h_in + hi) * w_in + wi]
                                } else {
                                    0.0
                                }
                            } else {
                                0.0
                            };
                            col[row * k_size + col_idx] = val;
                            col_idx += 1;
                        }
                    }
                }
            }
        }
    }

    // (patches, k_size) x (k_size, c_out) -> (patches, c_out).
    let col_t = Tensor::new(&col, &[patches, k_size]);
    let weight_t = Tensor::new(weight_data, &[c_out, k_size]).t();
    let out2d = crate::ops::matmul(&col_t, &weight_t);
    let out_data = out2d.data();
    // Scatter the row-major (patch, channel) product back into NCHW order.
    let mut result = vec![0.0f32; n * c_out * h_out * w_out];
    for b in 0..n {
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = (b * h_out + ho) * w_out + wo;
                for co in 0..c_out {
                    result[((b * c_out + co) * h_out + ho) * w_out + wo] =
                        out_data[row * c_out + co];
                }
            }
        }
    }
    result
}
275
276fn choose_conv2d_bwd_kernel<Fd, Fi>(
277    target: Conv2dBwdTarget,
278    key: Conv2dPerfKey,
279    work: usize,
280    direct_fn: Fd,
281    im2col_fn: Fi,
282) -> Conv2dKernelChoice
283where
284    Fd: Fn() -> Vec<f32>,
285    Fi: Fn() -> Vec<f32>,
286{
287    let cfg = cpu_conv2d_bwd_config();
288    let min_work = match target {
289        Conv2dBwdTarget::GradInput => cfg.min_work_grad_input,
290        Conv2dBwdTarget::GradWeight => cfg.min_work_grad_weight,
291    };
292    let profile_iters = match target {
293        Conv2dBwdTarget::GradInput => cfg.profile_iters_grad_input,
294        Conv2dBwdTarget::GradWeight => cfg.profile_iters_grad_weight,
295    };
296    match cfg.strategy {
297        CpuConv2dBwdStrategy::Direct => Conv2dKernelChoice::Direct,
298        CpuConv2dBwdStrategy::Im2col => Conv2dKernelChoice::Im2col,
299        CpuConv2dBwdStrategy::Auto => {
300            if work >= min_work {
301                Conv2dKernelChoice::Im2col
302            } else {
303                Conv2dKernelChoice::Direct
304            }
305        }
306        CpuConv2dBwdStrategy::Profile => {
307            let cache_key = (target, key);
308            if let Some(cached) = conv2d_bwd_profile_cache().lock().get(&cache_key).copied() {
309                return cached;
310            }
311            let iters = profile_iters.max(1);
312            let mut direct_ns = 0u128;
313            let mut im2col_ns = 0u128;
314            for _ in 0..iters {
315                let t0 = Instant::now();
316                let d = direct_fn();
317                direct_ns += t0.elapsed().as_nanos();
318                black_box(d.len());
319
320                let t1 = Instant::now();
321                let c = im2col_fn();
322                im2col_ns += t1.elapsed().as_nanos();
323                black_box(c.len());
324            }
325            let choice = if im2col_ns < direct_ns {
326                Conv2dKernelChoice::Im2col
327            } else {
328                Conv2dKernelChoice::Direct
329            };
330            conv2d_bwd_profile_cache().lock().insert(cache_key, choice);
331            choice
332        }
333    }
334}
335
/// Direct computation of dL/dInput on raw f32 buffers.
///
/// `grad_data` is the upstream gradient in NCHW (n, c_out, h_out, w_out);
/// `weight_data` is (c_out, c_in, k_h, k_w). Returns the input gradient in
/// NCHW (n, c_in, h_in, w_in). Elements are computed in parallel (rayon)
/// with f64 accumulation.
fn conv2d_grad_input_direct_core(
    grad_data: &[f32],
    weight_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    let total_elements = n * c_in * h_in * w_in;
    (0..total_elements)
        .into_par_iter()
        .map(|idx| {
            // Decompose the flat index into (batch, in-channel, row, col).
            let w = idx % w_in;
            let h = (idx / w_in) % h_in;
            let ci = (idx / (w_in * h_in)) % c_in;
            let b = idx / (w_in * h_in * c_in);
            let mut sum: f64 = 0.0;

            // Only output rows whose receptive field covers input row `h`
            // contribute, i.e. those with 0 <= h + pad_h - ho*stride_h < k_h.
            // The start bound enforces the `< k_h` side via a ceiling divide;
            // the end bound enforces non-negativity and stays within h_out.
            let h_out_start = if h + pad_h >= k_h {
                (h + pad_h - k_h + 1).div_ceil(stride_h)
            } else {
                0
            };
            let h_out_end = std::cmp::min(h_out, (h + pad_h) / stride_h + 1);

            for ho in h_out_start..h_out_end {
                // Kernel row that aligns output row `ho` with input row `h`.
                let kh = h + pad_h - ho * stride_h;
                // Same bounds derivation for the horizontal axis.
                let w_out_start = if w + pad_w >= k_w {
                    (w + pad_w - k_w + 1).div_ceil(stride_w)
                } else {
                    0
                };
                let w_out_end = std::cmp::min(w_out, (w + pad_w) / stride_w + 1);

                for wo in w_out_start..w_out_end {
                    let kw = w + pad_w - wo * stride_w;
                    for co in 0..c_out {
                        let g_val = grad_data[((b * c_out + co) * h_out + ho) * w_out + wo] as f64;
                        let w_val = weight_data[((co * c_in + ci) * k_h + kh) * k_w + kw] as f64;
                        sum += g_val * w_val;
                    }
                }
            }
            sum as f32
        })
        .collect()
}
392
/// im2col-based computation of dL/dInput.
///
/// Per batch element: gathers the upstream gradient into a (patches, c_out)
/// matrix, multiplies by the flattened weight (c_out, k_size) to get the
/// gradient of the column matrix, then scatter-adds it back into input
/// coordinates (the col2im operation). Overlapping receptive fields make
/// the `+=` accumulation essential.
fn conv2d_grad_input_im2col_core(
    grad_data: &[f32],
    weight_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    let patches = h_out * w_out;
    let k_size = c_in * k_h * k_w;
    let weight_flat = Tensor::new(weight_data, &[c_out, k_size]);
    let mut grad_input_data = vec![0.0f32; n * c_in * h_in * w_in];

    for b in 0..n {
        // Rearrange this batch's gradient from NCHW to (patch-row, channel).
        let mut gy = vec![0.0f32; patches * c_out];
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = ho * w_out + wo;
                for co in 0..c_out {
                    gy[row * c_out + co] = grad_data[((b * c_out + co) * h_out + ho) * w_out + wo];
                }
            }
        }
        // (patches, c_out) x (c_out, k_size) -> (patches, k_size).
        let gy_t = Tensor::new(&gy, &[patches, c_out]);
        let dcol_t = crate::ops::matmul(&gy_t, &weight_flat);
        let dcol = dcol_t.data();

        // col2im: scatter-add every column-gradient tap back to its input
        // position, skipping taps that fell into the padding region.
        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = ho * w_out + wo;
                let mut col_idx = 0usize;
                for ci in 0..c_in {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            let h_in_idx = ho * stride_h + kh;
                            let w_in_idx = wo * stride_w + kw;
                            if h_in_idx >= pad_h && w_in_idx >= pad_w {
                                let hi = h_in_idx - pad_h;
                                let wi = w_in_idx - pad_w;
                                if hi < h_in && wi < w_in {
                                    grad_input_data[((b * c_in + ci) * h_in + hi) * w_in + wi] +=
                                        dcol[row * k_size + col_idx];
                                }
                            }
                            col_idx += 1;
                        }
                    }
                }
            }
        }
    }

    grad_input_data
}
456
/// Direct computation of dL/dWeight on raw f32 buffers.
///
/// Each weight element (co, ci, kh, kw) is the sum over all batches and
/// output positions of input * upstream-gradient at the matching tap.
/// Returned layout is (c_out, c_in, k_h, k_w), flattened. Elements are
/// computed in parallel (rayon) with f64 accumulation.
fn conv2d_grad_weight_direct_core(
    input_data: &[f32],
    grad_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    let total_elements = c_out * c_in * k_h * k_w;
    (0..total_elements)
        .into_par_iter()
        .map(|idx| {
            // Decompose flat index into (out-channel, in-channel, kernel row/col).
            let kw = idx % k_w;
            let kh = (idx / k_w) % k_h;
            let ci = (idx / (k_w * k_h)) % c_in;
            let co = idx / (k_w * k_h * c_in);
            let mut sum: f64 = 0.0;
            for b in 0..n {
                for ho in 0..h_out {
                    for wo in 0..w_out {
                        let h_in_idx = ho * stride_h + kh;
                        let w_in_idx = wo * stride_w + kw;
                        // Taps in the zero-padding region contribute nothing.
                        if h_in_idx >= pad_h && w_in_idx >= pad_w {
                            let hi = h_in_idx - pad_h;
                            let wi = w_in_idx - pad_w;
                            if hi < h_in && wi < w_in {
                                let val_in =
                                    input_data[((b * c_in + ci) * h_in + hi) * w_in + wi] as f64;
                                let val_g =
                                    grad_data[((b * c_out + co) * h_out + ho) * w_out + wo] as f64;
                                sum += val_in * val_g;
                            }
                        }
                    }
                }
            }
            sum as f32
        })
        .collect()
}
506
/// im2col-based computation of dL/dWeight.
///
/// Per batch element: rebuilds the im2col patch matrix `col`
/// (patches, k_size) and the gradient matrix `gy` (patches, c_out), then
/// accumulates gyᵀ · col = (c_out, k_size) into the running weight gradient.
/// Returned layout is (c_out, c_in, k_h, k_w), flattened.
fn conv2d_grad_weight_im2col_core(
    input_data: &[f32],
    grad_data: &[f32],
    n: usize,
    c_in: usize,
    h_in: usize,
    w_in: usize,
    c_out: usize,
    k_h: usize,
    k_w: usize,
    h_out: usize,
    w_out: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> Vec<f32> {
    let patches = h_out * w_out;
    let k_size = c_in * k_h * k_w;
    let mut grad_weight = vec![0.0f32; c_out * k_size];

    for b in 0..n {
        let mut col = vec![0.0f32; patches * k_size];
        let mut gy = vec![0.0f32; patches * c_out];

        for ho in 0..h_out {
            for wo in 0..w_out {
                let row = ho * w_out + wo;
                let mut col_idx = 0usize;
                // Fill this patch row of the column matrix; padding taps are zero.
                for ci in 0..c_in {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            let h_in_idx = ho * stride_h + kh;
                            let w_in_idx = wo * stride_w + kw;
                            let val = if h_in_idx >= pad_h && w_in_idx >= pad_w {
                                let hi = h_in_idx - pad_h;
                                let wi = w_in_idx - pad_w;
                                if hi < h_in && wi < w_in {
                                    input_data[((b * c_in + ci) * h_in + hi) * w_in + wi]
                                } else {
                                    0.0
                                }
                            } else {
                                0.0
                            };
                            col[row * k_size + col_idx] = val;
                            col_idx += 1;
                        }
                    }
                }
                // Gather the upstream gradient for the same patch row.
                for co in 0..c_out {
                    gy[row * c_out + co] = grad_data[((b * c_out + co) * h_out + ho) * w_out + wo];
                }
            }
        }

        // (c_out, patches) x (patches, k_size) -> (c_out, k_size).
        let gy_t = Tensor::new(&gy, &[patches, c_out]).t();
        let col_t = Tensor::new(&col, &[patches, k_size]);
        let gw_batch = crate::ops::matmul(&gy_t, &col_t);
        let gw_data = gw_batch.data();
        // Accumulate this batch's contribution into the total gradient.
        for i in 0..grad_weight.len() {
            grad_weight[i] += gw_data[i];
        }
    }

    grad_weight
}
574
impl BackwardOp for Conv2dBackward {
    /// Propagates `grad` (shape (N, C_out, H_out, W_out)) to the stored
    /// input and/or weight, whichever requires gradients.
    ///
    /// For each gradient target a kernel (direct vs im2col) is selected via
    /// `choose_conv2d_bwd_kernel`; the chosen kernel is then run to produce
    /// the gradient, which is accumulated onto the operand before continuing
    /// the backward traversal with `backward_step`. Note that in `Profile`
    /// strategy the selection closures already execute both kernels once
    /// (results discarded) before the final run.
    fn backward(&self, grad: &Tensor) {
        let (stride_h, stride_w) = self.stride;
        let (pad_h, pad_w) = self.padding;

        let input_shape = self.input.shape();
        let weight_shape = self.weight.shape();
        let grad_shape = grad.shape();

        // Input: (N, C_in, H, W).
        let n = input_shape[0];
        let c_in = input_shape[1];
        let h_in = input_shape[2];
        let w_in = input_shape[3];

        // Weight: (C_out, C_in, kH, kW) — index 1 is C_in, not needed here.
        let c_out = weight_shape[0];
        let k_h = weight_shape[2];
        let k_w = weight_shape[3];

        // Spatial extent of the forward output, taken from the upstream grad.
        let h_out = grad_shape[2];
        let w_out = grad_shape[3];

        // Compute grad_input
        if self.input.requires_grad() {
            // Hold the data guards for the duration of the kernel calls.
            let grad_guard = grad.data();
            let weight_guard = self.weight.data();
            let grad_data = &*grad_guard;
            let weight_data = &*weight_guard;
            // Total MAC estimate used by the Auto-mode threshold.
            let work = n * c_in * h_in * w_in * c_out * k_h * k_w;
            let key: Conv2dPerfKey = (
                n, c_in, h_in, w_in, c_out, k_h, k_w, stride_h, stride_w, pad_h, pad_w,
            );
            let kernel = choose_conv2d_bwd_kernel(
                Conv2dBwdTarget::GradInput,
                key,
                work,
                || {
                    conv2d_grad_input_direct_core(
                        grad_data,
                        weight_data,
                        n,
                        c_in,
                        h_in,
                        w_in,
                        c_out,
                        k_h,
                        k_w,
                        h_out,
                        w_out,
                        stride_h,
                        stride_w,
                        pad_h,
                        pad_w,
                    )
                },
                || {
                    conv2d_grad_input_im2col_core(
                        grad_data,
                        weight_data,
                        n,
                        c_in,
                        h_in,
                        w_in,
                        c_out,
                        k_h,
                        k_w,
                        h_out,
                        w_out,
                        stride_h,
                        stride_w,
                        pad_h,
                        pad_w,
                    )
                },
            );
            // Run the selected kernel to produce the actual gradient buffer.
            let grad_input_data = match kernel {
                Conv2dKernelChoice::Direct => conv2d_grad_input_direct_core(
                    grad_data,
                    weight_data,
                    n,
                    c_in,
                    h_in,
                    w_in,
                    c_out,
                    k_h,
                    k_w,
                    h_out,
                    w_out,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                ),
                Conv2dKernelChoice::Im2col => conv2d_grad_input_im2col_core(
                    grad_data,
                    weight_data,
                    n,
                    c_in,
                    h_in,
                    w_in,
                    c_out,
                    k_h,
                    k_w,
                    h_out,
                    w_out,
                    stride_h,
                    stride_w,
                    pad_h,
                    pad_w,
                ),
            };

            let grad_input_tensor =
                Tensor::new_with_storage(Storage::new(grad_input_data), self.input.shape());
            self.input.accumulate_grad(&grad_input_tensor);
            self.input.backward_step();
        }

        // Compute grad_weight
        if self.weight.requires_grad() {
            let input_guard = self.input.data();
            let grad_guard = grad.data();
            let input_data = &*input_guard;
            let grad_data = &*grad_guard;
            // Same work estimate and shape key as the grad_input branch; the
            // cache entries differ only by Conv2dBwdTarget.
            let work = n * c_in * h_in * w_in * c_out * k_h * k_w;
            let key: Conv2dPerfKey = (
                n, c_in, h_in, w_in, c_out, k_h, k_w, stride_h, stride_w, pad_h, pad_w,
            );
            let kernel = choose_conv2d_bwd_kernel(
                Conv2dBwdTarget::GradWeight,
                key,
                work,
                || {
                    conv2d_grad_weight_direct_core(
                        input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                        stride_h, stride_w, pad_h, pad_w,
                    )
                },
                || {
                    conv2d_grad_weight_im2col_core(
                        input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                        stride_h, stride_w, pad_h, pad_w,
                    )
                },
            );
            let grad_weight_data = match kernel {
                Conv2dKernelChoice::Direct => conv2d_grad_weight_direct_core(
                    input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                    stride_h, stride_w, pad_h, pad_w,
                ),
                Conv2dKernelChoice::Im2col => conv2d_grad_weight_im2col_core(
                    input_data, grad_data, n, c_in, h_in, w_in, c_out, k_h, k_w, h_out, w_out,
                    stride_h, stride_w, pad_h, pad_w,
                ),
            };

            let grad_weight_tensor =
                Tensor::new_with_storage(Storage::new(grad_weight_data), self.weight.shape());
            self.weight.accumulate_grad(&grad_weight_tensor);
            self.weight.backward_step();
        }
    }
}
737
738pub fn conv2d(
739    input: &Tensor,
740    weight: &Tensor,
741    stride: (usize, usize),
742    padding: (usize, usize),
743) -> Tensor {
744    let input_shape = input.shape();
745    let weight_shape = weight.shape();
746
747    if input_shape.len() != 4 || weight_shape.len() != 4 {
748        panic!("Conv2d requires 4D tensors");
749    }
750
751    let n = input_shape[0];
752    let c_in = input_shape[1];
753    let h_in = input_shape[2];
754    let w_in = input_shape[3];
755
756    let c_out = weight_shape[0];
757    let k_h = weight_shape[2];
758    let k_w = weight_shape[3];
759
760    if weight_shape[1] != c_in {
761        panic!(
762            "Weight input channels {} must match input channels {}",
763            weight_shape[1], c_in
764        );
765    }
766
767    let (stride_h, stride_w) = stride;
768    let (pad_h, pad_w) = padding;
769
770    let h_out = (h_in + 2 * pad_h - k_h) / stride_h + 1;
771    let w_out = (w_in + 2 * pad_w - k_w) / stride_w + 1;
772
773    let input_contig = if input.is_contiguous() {
774        input.clone()
775    } else {
776        input.contiguous()
777    };
778    let weight_contig = if weight.is_contiguous() {
779        weight.clone()
780    } else {
781        weight.contiguous()
782    };
783
784    let input_guard = input_contig.data();
785    let weight_guard = weight_contig.data();
786    let input_data = &*input_guard;
787    let weight_data = &*weight_guard;
788
789    let cfg = cpu_conv2d_config();
790    let work = n * c_out * h_out * w_out * c_in * k_h * k_w;
791    let key: Conv2dPerfKey = (
792        n, c_in, h_in, w_in, c_out, k_h, k_w, stride_h, stride_w, pad_h, pad_w,
793    );
794    let kernel = match cfg.strategy {
795        CpuConv2dStrategy::Direct => Conv2dKernelChoice::Direct,
796        CpuConv2dStrategy::Im2col => Conv2dKernelChoice::Im2col,
797        CpuConv2dStrategy::Auto => {
798            if work >= cfg.min_work {
799                Conv2dKernelChoice::Im2col
800            } else {
801                Conv2dKernelChoice::Direct
802            }
803        }
804        CpuConv2dStrategy::Profile => {
805            if let Some(cached) = conv2d_profile_cache().lock().get(&key).copied() {
806                cached
807            } else {
808                let iters = cfg.profile_iters.max(1);
809                let mut direct_ns = 0u128;
810                let mut im2col_ns = 0u128;
811                for _ in 0..iters {
812                    let t0 = Instant::now();
813                    let d = conv2d_direct_core(
814                        input_data,
815                        weight_data,
816                        n,
817                        c_in,
818                        h_in,
819                        w_in,
820                        c_out,
821                        k_h,
822                        k_w,
823                        h_out,
824                        w_out,
825                        stride_h,
826                        stride_w,
827                        pad_h,
828                        pad_w,
829                    );
830                    direct_ns += t0.elapsed().as_nanos();
831                    black_box(d.len());
832                    let t1 = Instant::now();
833                    let c = conv2d_im2col_core(
834                        input_data,
835                        weight_data,
836                        n,
837                        c_in,
838                        h_in,
839                        w_in,
840                        c_out,
841                        k_h,
842                        k_w,
843                        h_out,
844                        w_out,
845                        stride_h,
846                        stride_w,
847                        pad_h,
848                        pad_w,
849                    );
850                    im2col_ns += t1.elapsed().as_nanos();
851                    black_box(c.len());
852                }
853                let choice = if im2col_ns < direct_ns {
854                    Conv2dKernelChoice::Im2col
855                } else {
856                    Conv2dKernelChoice::Direct
857                };
858                conv2d_profile_cache().lock().insert(key, choice);
859                choice
860            }
861        }
862    };
863
864    let result_data = match kernel {
865        Conv2dKernelChoice::Direct => conv2d_direct_core(
866            input_data,
867            weight_data,
868            n,
869            c_in,
870            h_in,
871            w_in,
872            c_out,
873            k_h,
874            k_w,
875            h_out,
876            w_out,
877            stride_h,
878            stride_w,
879            pad_h,
880            pad_w,
881        ),
882        Conv2dKernelChoice::Im2col => conv2d_im2col_core(
883            input_data,
884            weight_data,
885            n,
886            c_in,
887            h_in,
888            w_in,
889            c_out,
890            k_h,
891            k_w,
892            h_out,
893            w_out,
894            stride_h,
895            stride_w,
896            pad_h,
897            pad_w,
898        ),
899    };
900
901    let storage = Storage::new(result_data);
902    let mut tensor = Tensor::new_with_storage(storage, &[n, c_out, h_out, w_out]);
903
904    if input.requires_grad() || weight.requires_grad() {
905        tensor.set_requires_grad_mut(true);
906        tensor.set_op(Arc::new(Conv2dBackward {
907            input: input.clone(),
908            weight: weight.clone(),
909            stride,
910            padding,
911        }));
912    }
913
914    // if crate::graph::is_tracing() {
915    //     crate::graph::record_op(crate::graph::NodeOp::Conv2d { stride, padding }, &[input, weight], &tensor);
916    // }
917
918    tensor
919}