// rustorch_core/ops.rs

use crate::autograd::BackwardOp;
use crate::storage::Storage;
use crate::Tensor;
use parking_lot::Mutex;
use rayon::prelude::*;
use std::collections::HashMap;
use std::hint::black_box;
use std::sync::{Arc, OnceLock};
use std::time::Instant;
use wide::f32x8;

pub mod activations;
pub mod conv;
pub mod embedding;
pub mod norm;
pub mod pool;
pub mod view;

pub use activations::{sigmoid, softmax, tanh};
pub use conv::conv2d;
pub use embedding::embedding;
pub use norm::{batch_norm2d, layer_norm};
pub use pool::max_pool2d;
pub use view::ReshapeBackward;

#[derive(Debug)]
pub struct MulBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}

impl BackwardOp for MulBackward {
    fn backward(&self, grad: &Tensor) {
        let same_input = Arc::ptr_eq(&self.lhs.inner, &self.rhs.inner);
        if same_input && self.lhs.requires_grad() {
            let mut grad_lhs = crate::ops::mul(grad, &self.rhs);
            let mut grad_rhs = crate::ops::mul(grad, &self.lhs);
            if grad_lhs.shape() != self.lhs.shape() {
                grad_lhs = sum_to(&grad_lhs, self.lhs.shape());
            }
            if grad_rhs.shape() != self.lhs.shape() {
                grad_rhs = sum_to(&grad_rhs, self.lhs.shape());
            }
            let grad_total = add(&grad_lhs, &grad_rhs);
            self.lhs.accumulate_grad(&grad_total);
            self.lhs.backward_step();
            return;
        }

        if self.lhs.requires_grad() {
            let mut grad_lhs = crate::ops::mul(grad, &self.rhs);
            if grad_lhs.shape() != self.lhs.shape() {
                grad_lhs = sum_to(&grad_lhs, self.lhs.shape());
            }
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            let mut grad_rhs = crate::ops::mul(grad, &self.lhs);
            if grad_rhs.shape() != self.rhs.shape() {
                grad_rhs = sum_to(&grad_rhs, self.rhs.shape());
            }
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}

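/// Element-wise multiply with broadcasting and autograd. Dispatches to the
/// wgpu backend when both operands live on the GPU, otherwise falls back to
/// the auto-tuned CPU kernel (`elemwise_auto`).
///
/// A minimal usage sketch (hedged: assumes `Tensor::new(&data, &shape)` as
/// used elsewhere in this file):
///
/// ```ignore
/// let a = Tensor::new(&[1.0, 2.0, 3.0], &[3]);
/// let b = Tensor::new(&[10.0, 20.0, 30.0], &[3]);
/// let c = crate::ops::mul(&a, &b); // data: [10.0, 40.0, 90.0]
/// ```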
pub fn mul(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let (Some(lhs_buf), Some(rhs_buf)) =
            (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
        {
            let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
                .expect("Shapes not broadcastable");

            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let output_buf = elementwise_wgpu_buffer(
                lhs_buf,
                lhs.shape(),
                lhs.strides(),
                Some((rhs_buf, rhs.shape(), rhs.strides())),
                &target_shape,
                ElementwiseOp::Mul,
                None,
            );

            let size: usize = target_shape.iter().product();
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, &target_shape);

            if lhs.requires_grad() || rhs.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(MulBackward {
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                }));
            }
            return tensor;
        }
    }

    if lhs.shape() != rhs.shape() {
        // CPU Broadcast (simplified)
        let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
            .expect("Shapes not broadcastable");
        let lhs_expanded = lhs.expand(&target_shape);
        let rhs_expanded = rhs.expand(&target_shape);
        return mul(&lhs_expanded, &rhs_expanded);
    }

    let lhs_contig = if lhs.is_contiguous() {
        lhs.clone()
    } else {
        lhs.contiguous()
    };
    let rhs_contig = if rhs.is_contiguous() {
        rhs.clone()
    } else {
        rhs.contiguous()
    };

    let lhs_guard = lhs_contig.data();
    let rhs_guard = rhs_contig.data();
    let lhs_data = &*lhs_guard;
    let rhs_data = &*rhs_guard;

    let result_data = elemwise_auto(lhs_data, rhs_data, ElemwiseKind::Mul);

    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, lhs.shape());

    if lhs.requires_grad() || rhs.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(MulBackward {
            lhs: lhs.clone(),
            rhs: rhs.clone(),
        }));
    }

    tensor
}

pub fn div(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    // Basic Div
    if lhs.shape() != rhs.shape() {
        let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
            .expect("Shapes not broadcastable");
        let lhs_expanded = lhs.expand(&target_shape);
        let rhs_expanded = rhs.expand(&target_shape);
        return div(&lhs_expanded, &rhs_expanded);
    }

    let lhs_contig = if lhs.is_contiguous() {
        lhs.clone()
    } else {
        lhs.contiguous()
    };
    let rhs_contig = if rhs.is_contiguous() {
        rhs.clone()
    } else {
        rhs.contiguous()
    };

    let lhs_guard = lhs_contig.data();
    let rhs_guard = rhs_contig.data();
    let result_data: Vec<f32> = lhs_guard
        .par_iter()
        .zip(rhs_guard.par_iter())
        .map(|(a, b)| a / b)
        .collect();
    let storage = Storage::new(result_data);
    let tensor = Tensor::new_with_storage(storage, lhs.shape());

    // DivBackward... (omitted for brevity unless needed)
    tensor
}

// --- Add ---
#[derive(Debug)]
pub struct AddBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}

impl BackwardOp for AddBackward {
    fn backward(&self, grad: &Tensor) {
        let same_input = Arc::ptr_eq(&self.lhs.inner, &self.rhs.inner);
        if same_input && self.lhs.requires_grad() {
            let grad_lhs = if grad.shape() != self.lhs.shape() {
                sum_to(grad, self.lhs.shape())
            } else {
                grad.clone()
            };
            let grad_total = add(&grad_lhs, &grad_lhs);
            self.lhs.accumulate_grad(&grad_total);
            self.lhs.backward_step();
            return;
        }

        if self.lhs.requires_grad() {
            let grad_lhs = if grad.shape() != self.lhs.shape() {
                sum_to(grad, self.lhs.shape())
            } else {
                grad.clone()
            };
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            let grad_rhs = if grad.shape() != self.rhs.shape() {
                sum_to(grad, self.rhs.shape())
            } else {
                grad.clone()
            };
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}

pub fn add(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let (Some(lhs_buf), Some(rhs_buf)) =
            (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
        {
            let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
                .expect("Shapes not broadcastable");

            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let output_buf = elementwise_wgpu_buffer(
                lhs_buf,
                lhs.shape(),
                lhs.strides(),
                Some((rhs_buf, rhs.shape(), rhs.strides())),
                &target_shape,
                ElementwiseOp::Add,
                None,
            );

            let size: usize = target_shape.iter().product();
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, &target_shape);

            if lhs.requires_grad() || rhs.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(AddBackward {
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                }));
            }
            return tensor;
        }
    }

    if lhs.shape() != rhs.shape() {
        let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
            .expect("Shapes not broadcastable");
        let lhs_expanded = lhs.expand(&target_shape);
        let rhs_expanded = rhs.expand(&target_shape);
        return add(&lhs_expanded, &rhs_expanded);
    }

    let lhs_contig = if lhs.is_contiguous() {
        lhs.clone()
    } else {
        lhs.contiguous()
    };
    let rhs_contig = if rhs.is_contiguous() {
        rhs.clone()
    } else {
        rhs.contiguous()
    };

    let lhs_guard = lhs_contig.data();
    let rhs_guard = rhs_contig.data();
    let result_data = elemwise_auto(&lhs_guard, &rhs_guard, ElemwiseKind::Add);
    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, lhs.shape());
    if lhs.requires_grad() || rhs.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(AddBackward {
            lhs: lhs.clone(),
            rhs: rhs.clone(),
        }));
    }
    tensor
}

#[derive(Debug)]
pub struct SubBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}

impl BackwardOp for SubBackward {
    fn backward(&self, grad: &Tensor) {
        let same_input = Arc::ptr_eq(&self.lhs.inner, &self.rhs.inner);
        if same_input && self.lhs.requires_grad() {
            // x - x: the lhs and rhs gradients cancel exactly to zero, so
            // there is nothing to accumulate and propagation stops here.
            return;
        }

        if self.lhs.requires_grad() {
            let mut grad_lhs = grad.clone();
            if grad_lhs.shape() != self.lhs.shape() {
                grad_lhs = sum_to(&grad_lhs, self.lhs.shape());
            }
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            let mut grad_rhs = neg(grad);
            if grad_rhs.shape() != self.rhs.shape() {
                grad_rhs = sum_to(&grad_rhs, self.rhs.shape());
            }
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}

pub fn sub(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let (Some(lhs_buf), Some(rhs_buf)) =
            (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
        {
            let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
                .expect("Shapes not broadcastable");

            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let output_buf = elementwise_wgpu_buffer(
                lhs_buf,
                lhs.shape(),
                lhs.strides(),
                Some((rhs_buf, rhs.shape(), rhs.strides())),
                &target_shape,
                ElementwiseOp::Sub,
                None,
            );

            let size: usize = target_shape.iter().product();
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, &target_shape);

            if lhs.requires_grad() || rhs.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(SubBackward {
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                }));
            }
            return tensor;
        }
    }

    if lhs.shape() != rhs.shape() {
        let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
            .expect("Shapes not broadcastable");
        let lhs_expanded = lhs.expand(&target_shape);
        let rhs_expanded = rhs.expand(&target_shape);
        return sub(&lhs_expanded, &rhs_expanded);
    }

    let lhs_contig = if lhs.is_contiguous() {
        lhs.clone()
    } else {
        lhs.contiguous()
    };
    let rhs_contig = if rhs.is_contiguous() {
        rhs.clone()
    } else {
        rhs.contiguous()
    };
    let lhs_guard = lhs_contig.data();
    let rhs_guard = rhs_contig.data();
    let result_data = elemwise_auto(&lhs_guard, &rhs_guard, ElemwiseKind::Sub);
    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, lhs.shape());

    if lhs.requires_grad() || rhs.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(SubBackward {
            lhs: lhs.clone(),
            rhs: rhs.clone(),
        }));
    }
    tensor
}

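/// Element-wise negation on the CPU. Unlike `mul`/`add`/`sub`, this helper
/// does not attach a backward op; within this file it is applied only to
/// gradient tensors (see `SubBackward`), where no further tracking is needed.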
pub fn neg(input: &Tensor) -> Tensor {
    let input_guard = input.data();
    let result_data: Vec<f32> = input_guard.par_iter().map(|x| -x).collect();
    let storage = Storage::new(result_data);
    Tensor::new_with_storage(storage, input.shape())
}

#[derive(Debug)]
pub struct ReluBackward {
    pub input: Tensor,
    pub output: Tensor,
}

impl BackwardOp for ReluBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            {
                if let (Some(out_buf), Some(grad_buf)) = (
                    self.output.storage().wgpu_buffer(),
                    grad.storage().wgpu_buffer(),
                ) {
                    use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
                    let grad_input_buf = elementwise_wgpu_buffer(
                        out_buf,
                        self.output.shape(),
                        self.output.strides(),
                        Some((grad_buf, grad.shape(), grad.strides())),
                        grad.shape(),
                        ElementwiseOp::ReLUBackward,
                        None,
                    );
                    let size: usize = grad.shape().iter().product();
                    let storage = Storage::new_wgpu(grad_input_buf, size, 0);
                    let grad_input = Tensor::new_with_storage(storage, grad.shape());
                    self.input.accumulate_grad(&grad_input);
                    self.input.backward_step();
                    return;
                }
            }

            let input_guard = self.input.data();
            let grad_guard = grad.data();
            let grad_input: Vec<f32> = input_guard
                .par_iter()
                .zip(grad_guard.par_iter())
                .map(|(x, g)| if *x > 0.0 { *g } else { 0.0 })
                .collect();
            let storage = Storage::new(grad_input);
            let grad_input_tensor = Tensor::new_with_storage(storage, grad.shape());
            self.input.accumulate_grad(&grad_input_tensor);
            self.input.backward_step();
        }
    }
}

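/// ReLU activation: `max(x, 0)` element-wise, with autograd. A detached copy
/// of the output is stored so the GPU backward pass can mask the incoming
/// gradient where the activation was zero.
///
/// A minimal sketch (hedged: assumes `Tensor::new(&data, &shape)`):
///
/// ```ignore
/// let x = Tensor::new(&[-1.0, 0.5, 2.0], &[3]);
/// let y = relu(&x); // data: [0.0, 0.5, 2.0]
/// ```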
pub fn relu(input: &Tensor) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let Some(buf) = input.storage().wgpu_buffer() {
            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let output_buf = elementwise_wgpu_buffer(
                buf,
                input.shape(),
                input.strides(),
                None,
                input.shape(),
                ElementwiseOp::ReLU,
                None,
            );
            let size: usize = input.shape().iter().product();
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, input.shape());

            if input.requires_grad() {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(ReluBackward {
                    input: input.clone(),
                    output: tensor.detach(),
                }));
            }
            return tensor;
        }
    }

    let input_guard = input.data();
    let result_data: Vec<f32> = input_guard.par_iter().map(|x| x.max(0.0)).collect();
    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, input.shape());

    if input.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(ReluBackward {
            input: input.clone(),
            output: tensor.detach(),
        }));
    }
    tensor
}

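/// One SGD update, `param - lr * grad`, returned as a new tensor rather than
/// applied in place. A worked example: with `param = [1.0, 2.0]`,
/// `grad = [0.5, -1.0]` and `lr = 0.1`, the result is `[0.95, 2.1]`.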
pub fn sgd_step(param: &Tensor, grad: &Tensor, lr: f32) -> Tensor {
    // param - lr * grad
    #[cfg(feature = "wgpu_backend")]
    {
        if let (Some(p_buf), Some(g_buf)) =
            (param.storage().wgpu_buffer(), grad.storage().wgpu_buffer())
        {
            use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
            let output_buf = elementwise_wgpu_buffer(
                p_buf,
                param.shape(),
                param.strides(),
                Some((g_buf, grad.shape(), grad.strides())),
                param.shape(),
                ElementwiseOp::SGDStep,
                Some(lr),
            );
            let size: usize = param.shape().iter().product();
            let storage = Storage::new_wgpu(output_buf, size, 0);
            return Tensor::new_with_storage(storage, param.shape());
        }
    }

    // CPU
    let p_data = param.data();
    let g_data = grad.data();
    let res_data: Vec<f32> = p_data
        .par_iter()
        .zip(g_data.par_iter())
        .map(|(p, g)| p - lr * g)
        .collect();
    let storage = Storage::new(res_data);
    Tensor::new_with_storage(storage, param.shape())
}

#[derive(Debug)]
pub struct MatmulBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}

impl BackwardOp for MatmulBackward {
    fn backward(&self, grad: &Tensor) {
        #[cfg(feature = "wgpu_backend")]
        {
            let grad_is_wgpu = grad.storage().wgpu_buffer().is_some();
            let lhs_is_wgpu = self.lhs.storage().wgpu_buffer().is_some();
            let rhs_is_wgpu = self.rhs.storage().wgpu_buffer().is_some();

            if grad_is_wgpu && lhs_is_wgpu && rhs_is_wgpu {
                if self.lhs.requires_grad() {
                    let rhs_t = self.rhs.t();
                    let grad_lhs = matmul(grad, &rhs_t).detach();
                    self.lhs.accumulate_grad(&grad_lhs);
                    self.lhs.backward_step();
                }
                if self.rhs.requires_grad() {
                    let grad_rhs = matmul(&self.lhs.t(), grad).detach();
                    self.rhs.accumulate_grad(&grad_rhs);
                    self.rhs.backward_step();
                }
                return;
            }
        }

        if self.lhs.requires_grad() {
            let rhs_t = self.rhs.t();
            let grad_lhs = matmul(grad, &rhs_t);
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            let grad_rhs = matmul(&self.lhs.t(), grad);
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}

#[cfg(feature = "wgpu_backend")]
#[allow(dead_code)]
fn matmul_gpu_aware_no_grad(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();

    if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
        panic!("Matmul only supports 2D");
    }

    let m = lhs_shape[0];
    let k = lhs_shape[1];
    let k2 = rhs_shape[0];
    let n = rhs_shape[1];

    if k != k2 {
        panic!("Matmul dimension mismatch");
    }

    if let (Some(_), Some(_)) = (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer()) {
        let lhs_contig = if lhs.is_contiguous() {
            lhs.clone()
        } else {
            lhs.contiguous()
        };
        let rhs_contig = if rhs.is_contiguous() {
            rhs.clone()
        } else {
            rhs.contiguous()
        };

        if lhs_contig.storage().wgpu_buffer().is_none()
            || rhs_contig.storage().wgpu_buffer().is_none()
        {
            return matmul(&lhs_contig, &rhs_contig);
        }

        use crate::backend::wgpu::{matmul_wgpu_buffer, Activation};
        let output_buf = matmul_wgpu_buffer(
            lhs_contig.storage().wgpu_buffer().unwrap(),
            lhs_contig.shape(),
            rhs_contig.storage().wgpu_buffer().unwrap(),
            rhs_contig.shape(),
            Activation::None,
        );

        let storage = Storage::new_wgpu(output_buf, m * n, 0);
        return Tensor::new_with_storage(storage, &[m, n]);
    }

    matmul(lhs, rhs)
}

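/// Reduce `tensor` back down to `shape` by summing over broadcast
/// dimensions. This is the adjoint of broadcasting: if a `[3]` bias was
/// broadcast to `[2, 3]` in the forward pass, its gradient is recovered by
/// summing the `[2, 3]` gradient over axis 0 back to `[3]`.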
pub fn sum_to(tensor: &Tensor, shape: &[usize]) -> Tensor {
    if tensor.shape() == shape {
        return tensor.clone();
    }
    view::sum_to(tensor, shape)
}

#[derive(Debug)]
pub struct SumBackward {
    pub input: Tensor,
}

impl BackwardOp for SumBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            let grad_cpu = if grad.storage().device().is_wgpu() {
                grad.to_cpu()
            } else {
                grad.clone()
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let grad_cpu = grad.clone();

            let grad_val = grad_cpu.data()[0];
            let mut grad_input = Tensor::full(self.input.shape(), grad_val);
            #[cfg(feature = "wgpu_backend")]
            if self.input.storage().device().is_wgpu() {
                grad_input = grad_input.to_wgpu();
            }
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}

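/// Sum of all elements, returned as a scalar (rank-0) tensor. The backward
/// pass broadcasts the upstream scalar gradient uniformly over the input
/// shape, since d(sum)/dx_i = 1 for every element.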
pub fn sum(tensor: &Tensor) -> Tensor {
    let total_size: usize = tensor.shape().iter().product();

    let mut output = {
        #[cfg(feature = "wgpu_backend")]
        {
            if let Some(input_buf) = tensor.storage().wgpu_buffer() {
                let output_buf = crate::backend::wgpu::reduce_sum_all_wgpu(input_buf, total_size);
                let storage = Storage::new_wgpu(output_buf, 1, 0);
                Tensor::new_with_storage(storage, &[])
            } else {
                let data = tensor.data();
                let sum_val = sum_auto(&data);
                Tensor::new_with_storage(Storage::new(vec![sum_val]), &[])
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            let data = tensor.data();
            let sum_val = sum_auto(&data);
            Tensor::new_with_storage(Storage::new(vec![sum_val]), &[])
        }
    };

    if tensor.requires_grad() {
        output.set_requires_grad_mut(true);
        output.set_op(Arc::new(SumBackward {
            input: tensor.clone(),
        }));
    }

    output
}

#[derive(Debug)]
pub struct MeanBackward {
    pub input: Tensor,
}

impl BackwardOp for MeanBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            let grad_cpu = if grad.storage().device().is_wgpu() {
                grad.to_cpu()
            } else {
                grad.clone()
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let grad_cpu = grad.clone();

            let grad_val = grad_cpu.data()[0];
            let numel = self.input.shape().iter().product::<usize>() as f32;
            let mut grad_input = Tensor::full(self.input.shape(), grad_val / numel);
            #[cfg(feature = "wgpu_backend")]
            if self.input.storage().device().is_wgpu() {
                grad_input = grad_input.to_wgpu();
            }
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}

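/// Mean of all elements as a scalar tensor. The gradient of the mean is
/// uniform, d(mean)/dx_i = 1/N, so `MeanBackward` fills the input shape with
/// `grad / numel`.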
pub fn mean(tensor: &Tensor) -> Tensor {
    let t_cpu = {
        #[cfg(feature = "wgpu_backend")]
        {
            if tensor.storage().device().is_wgpu() {
                tensor.to_cpu()
            } else {
                tensor.clone()
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            tensor.clone()
        }
    };
    let data = t_cpu.data();
    let numel = data.len() as f32;
    let mean_val = sum_auto(&data) / numel;
    let mut out = Tensor::new_with_storage(Storage::new(vec![mean_val]), &[]);
    if tensor.requires_grad() {
        out.set_requires_grad_mut(true);
        out.set_op(Arc::new(MeanBackward {
            input: tensor.clone(),
        }));
    }
    out
}

#[derive(Debug)]
pub struct VarBackward {
    pub input: Tensor,
    pub mean: f32,
}

impl BackwardOp for VarBackward {
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            let grad_cpu = if grad.storage().device().is_wgpu() {
                grad.to_cpu()
            } else {
                grad.clone()
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let grad_cpu = grad.clone();
            let grad_val = grad_cpu.data()[0];
            let input_cpu = {
                #[cfg(feature = "wgpu_backend")]
                {
                    if self.input.storage().device().is_wgpu() {
                        self.input.to_cpu()
                    } else {
                        self.input.clone()
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    self.input.clone()
                }
            };
            let input_data = input_cpu.data();
            let numel = input_data.len() as f32;
            let scale = grad_val * 2.0 / numel;
            let grad_data: Vec<f32> = input_data.iter().map(|x| (x - self.mean) * scale).collect();
            let mut grad_input = Tensor::new(&grad_data, self.input.shape());
            #[cfg(feature = "wgpu_backend")]
            if self.input.storage().device().is_wgpu() {
                grad_input = grad_input.to_wgpu();
            }
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}

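/// Population variance of all elements: `Var(x) = (1/N) * sum((x_i - mean)^2)`.
/// Note this divides by `N`, not `N - 1` (no Bessel correction). The matching
/// gradient used by `VarBackward` is `dVar/dx_i = 2 * (x_i - mean) / N`.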
pub fn var(tensor: &Tensor) -> Tensor {
    let t_cpu = {
        #[cfg(feature = "wgpu_backend")]
        {
            if tensor.storage().device().is_wgpu() {
                tensor.to_cpu()
            } else {
                tensor.clone()
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            tensor.clone()
        }
    };
    let data = t_cpu.data();
    let numel = data.len() as f32;
    let m = sum_auto(&data) / numel;
    let sq: Vec<f32> = data.iter().map(|x| (x - m) * (x - m)).collect();
    let v = sum_auto(&sq) / numel;
    let mut out = Tensor::new_with_storage(Storage::new(vec![v]), &[]);
    if tensor.requires_grad() {
        out.set_requires_grad_mut(true);
        out.set_op(Arc::new(VarBackward {
            input: tensor.clone(),
            mean: m,
        }));
    }
    out
}

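/// Fused MSE loss and gradient computation for a linear layer, avoiding a
/// full autograd pass. For `loss = mean((y - t)^2)` over `batch * out_dim`
/// elements, the per-element upstream gradient is `g_o = 2 * (y - t) / N`,
/// giving `grad_b[o] = sum_b g_o` and `grad_w[o][i] = sum_b g_o * x[b][i]`.
/// Returns `(loss, grad_w, grad_b)` with `grad_w` shaped `[out_dim, in_dim]`.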
pub fn linear_mse_grads(input: &Tensor, output: &Tensor, target: &Tensor) -> (f32, Tensor, Tensor) {
    let x = if input.is_contiguous() {
        input.clone()
    } else {
        input.contiguous()
    };
    let y = if output.is_contiguous() {
        output.clone()
    } else {
        output.contiguous()
    };
    let t = if target.is_contiguous() {
        target.clone()
    } else {
        target.contiguous()
    };

    let x_shape = x.shape();
    let y_shape = y.shape();
    let t_shape = t.shape();
    if x_shape.len() != 2 || y_shape.len() != 2 || t_shape.len() != 2 {
        panic!("linear_mse_grads expects 2D tensors");
    }
    if y_shape != t_shape {
        panic!("linear_mse_grads output and target shape mismatch");
    }
    if x_shape[0] != y_shape[0] {
        panic!("linear_mse_grads batch mismatch");
    }

    let batch = x_shape[0];
    let in_dim = x_shape[1];
    let out_dim = y_shape[1];
    let numel = (batch * out_dim) as f32;
    let grad_scale = 2.0 / numel;

    let x_data = x.data();
    let y_data = y.data();
    let t_data = t.data();

    let mut grad_w = vec![0.0f32; out_dim * in_dim];
    let mut grad_b = vec![0.0f32; out_dim];
    let mut loss = 0.0f32;

    for b in 0..batch {
        for (o, gb) in grad_b.iter_mut().enumerate().take(out_dim) {
            let idx = b * out_dim + o;
            let d = y_data[idx] - t_data[idx];
            loss += d * d;
            let go = d * grad_scale;
            *gb += go;
            let w_row_offset = o * in_dim;
            let x_row_offset = b * in_dim;
            for i in 0..in_dim {
                grad_w[w_row_offset + i] += go * x_data[x_row_offset + i];
            }
        }
    }

    let loss = loss / numel;
    let grad_w_t = Tensor::new(&grad_w, &[out_dim, in_dim]);
    let grad_b_t = Tensor::new(&grad_b, &[out_dim]);
    (loss, grad_w_t, grad_b_t)
}

#[derive(Debug)]
pub struct FusedMatmulBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
    pub bias: Option<Tensor>,
    pub output: Tensor,
    pub activation: crate::backend::Activation,
}

impl BackwardOp for FusedMatmulBackward {
    fn backward(&self, grad_output: &Tensor) {
        let grad_pre_act = match self.activation {
            crate::backend::Activation::ReLU => {
                #[cfg(feature = "wgpu_backend")]
                {
                    if let (Some(out_buf), Some(grad_buf)) = (
                        self.output.storage().wgpu_buffer(),
                        grad_output.storage().wgpu_buffer(),
                    ) {
915                        // ... (existing code)
916                        use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
917                        let target_shape = grad_output.shape();
918                        let out_buf = elementwise_wgpu_buffer(
919                            out_buf,
920                            self.output.shape(),
921                            self.output.strides(),
922                            Some((grad_buf, grad_output.shape(), grad_output.strides())),
923                            target_shape,
924                            ElementwiseOp::ReLUBackward,
925                            None,
926                        );
927                        let storage = Storage::new_wgpu(out_buf, target_shape.iter().product(), 0);
928                        Tensor::new_with_storage(storage, target_shape)
929                    } else {
930                        grad_output.clone()
931                    }
932                }
933                #[cfg(not(feature = "wgpu_backend"))]
934                grad_output.clone()
935            }
936            crate::backend::Activation::Sigmoid => {
937                #[cfg(feature = "wgpu_backend")]
938                {
939                    if let (Some(out_buf), Some(grad_buf)) = (
940                        self.output.storage().wgpu_buffer(),
941                        grad_output.storage().wgpu_buffer(),
942                    ) {
943                        use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
944                        let target_shape = grad_output.shape();
945                        let out_buf = elementwise_wgpu_buffer(
946                            out_buf,
947                            self.output.shape(),
948                            self.output.strides(),
949                            Some((grad_buf, grad_output.shape(), grad_output.strides())),
950                            target_shape,
951                            ElementwiseOp::SigmoidBackward,
952                            None,
953                        );
954                        let storage = Storage::new_wgpu(out_buf, target_shape.iter().product(), 0);
955                        Tensor::new_with_storage(storage, target_shape)
956                    } else {
957                        grad_output.clone()
958                    }
959                }
960                #[cfg(not(feature = "wgpu_backend"))]
961                grad_output.clone()
962            }
963            crate::backend::Activation::Tanh => {
964                #[cfg(feature = "wgpu_backend")]
965                {
966                    if let (Some(out_buf), Some(grad_buf)) = (
967                        self.output.storage().wgpu_buffer(),
968                        grad_output.storage().wgpu_buffer(),
969                    ) {
970                        use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
971                        let target_shape = grad_output.shape();
972                        let out_buf = elementwise_wgpu_buffer(
973                            out_buf,
974                            self.output.shape(),
975                            self.output.strides(),
976                            Some((grad_buf, grad_output.shape(), grad_output.strides())),
977                            target_shape,
978                            ElementwiseOp::TanhBackward,
979                            None,
980                        );
981                        let storage = Storage::new_wgpu(out_buf, target_shape.iter().product(), 0);
982                        Tensor::new_with_storage(storage, target_shape)
983                    } else {
984                        grad_output.clone()
985                    }
986                }
987                #[cfg(not(feature = "wgpu_backend"))]
988                grad_output.clone()
989            }
990            crate::backend::Activation::None => grad_output.clone(),
991        };
992
993        if let Some(bias) = &self.bias {
994            if bias.requires_grad() {
995                let grad_bias = sum_to(&grad_pre_act, bias.shape());
996                bias.accumulate_grad(&grad_bias);
997                bias.backward_step();
998            }
999        }
1000
1001        if self.lhs.requires_grad() {
1002            let rhs_t = self.rhs.t();
1003            let grad_lhs = matmul(&grad_pre_act, &rhs_t);
1004            self.lhs.accumulate_grad(&grad_lhs);
1005            self.lhs.backward_step();
1006        }
1007
1008        if self.rhs.requires_grad() {
1009            let grad_rhs = matmul(&self.lhs.t(), &grad_pre_act);
1010            self.rhs.accumulate_grad(&grad_rhs);
1011            self.rhs.backward_step();
1012        }
1013    }
1014}
1015
1016#[inline]
1017fn parse_usize_env(key: &str, default: usize) -> usize {
1018    std::env::var(key)
1019        .ok()
1020        .and_then(|s| s.parse::<usize>().ok())
1021        .unwrap_or(default)
1022}
1023
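// CPU matmul kernel selection is tunable via environment variables. A sketch
// of the knobs read below (values shown are the defaults):
//
//   RUSTORCH_CPU_MATMUL_STRATEGY=auto|sgemm|parallel|profile
//   RUSTORCH_CPU_MATMUL_MIN_M=128   RUSTORCH_CPU_MATMUL_MIN_K=256
//   RUSTORCH_CPU_MATMUL_MAX_N=128   RUSTORCH_CPU_MATMUL_PROFILE_ITERS=2
//
// e.g. `RUSTORCH_CPU_MATMUL_STRATEGY=profile` times both kernels per shape
// and caches the winner for the life of the process. The elementwise and
// reduction kernels further down read analogous RUSTORCH_CPU_ELEMWISE_* and
// RUSTORCH_CPU_REDUCTION_* variables.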
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuMatmulStrategy {
    Auto,
    Profile,
    Sgemm,
    Parallel,
}

#[derive(Clone, Copy)]
struct CpuMatmulConfig {
    strategy: CpuMatmulStrategy,
    min_m: usize,
    min_k: usize,
    max_n: usize,
    profile_iters: usize,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuKernelChoice {
    Sgemm,
    Parallel,
}

type MatmulPerfKey = (usize, usize, usize, bool);
type MatmulPerfCache = HashMap<MatmulPerfKey, CpuKernelChoice>;

fn cpu_matmul_config() -> CpuMatmulConfig {
    static CFG: OnceLock<CpuMatmulConfig> = OnceLock::new();
    *CFG.get_or_init(|| {
        let strategy = match std::env::var("RUSTORCH_CPU_MATMUL_STRATEGY")
            .unwrap_or_else(|_| "auto".to_string())
            .to_ascii_lowercase()
            .as_str()
        {
            "parallel" => CpuMatmulStrategy::Parallel,
            "sgemm" => CpuMatmulStrategy::Sgemm,
            "profile" => CpuMatmulStrategy::Profile,
            _ => CpuMatmulStrategy::Auto,
        };

        CpuMatmulConfig {
            strategy,
            min_m: parse_usize_env("RUSTORCH_CPU_MATMUL_MIN_M", 128),
            min_k: parse_usize_env("RUSTORCH_CPU_MATMUL_MIN_K", 256),
            max_n: parse_usize_env("RUSTORCH_CPU_MATMUL_MAX_N", 128),
            profile_iters: parse_usize_env("RUSTORCH_CPU_MATMUL_PROFILE_ITERS", 2),
        }
    })
}

#[inline]
fn should_use_parallel_auto(m: usize, k: usize, n: usize) -> bool {
    let cfg = cpu_matmul_config();
    m >= cfg.min_m && k >= cfg.min_k && n <= cfg.max_n
}

fn matmul_profile_cache() -> &'static Mutex<MatmulPerfCache> {
    static CACHE: OnceLock<Mutex<MatmulPerfCache>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

fn matmul_cpu_sgemm_core(
    lhs_data: &[f32],
    rhs_data: &[f32],
    m: usize,
    k: usize,
    n: usize,
    lhs_stride0: isize,
    lhs_stride1: isize,
    rhs_stride0: isize,
    rhs_stride1: isize,
    bias: Option<&[f32]>,
) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    unsafe {
        matrixmultiply::sgemm(
            m,
            k,
            n,
            1.0,
            lhs_data.as_ptr(),
            lhs_stride0,
            lhs_stride1,
            rhs_data.as_ptr(),
            rhs_stride0,
            rhs_stride1,
            0.0,
            out.as_mut_ptr(),
            n as isize,
            1,
        );
    }
    if let Some(bias_data) = bias {
        out.par_chunks_mut(n).for_each(|row| {
            row.iter_mut()
                .zip(bias_data.iter())
                .for_each(|(v, b)| *v += *b);
        });
    }
    out
}

fn bench_kernel<F: Fn() -> Vec<f32>>(f: F, iters: usize) -> u128 {
    let mut total_ns = 0u128;
    let mut acc = 0.0f32;
    for _ in 0..iters {
        let t0 = Instant::now();
        let out = f();
        total_ns += t0.elapsed().as_nanos();
        if let Some(v) = out.first() {
            acc += *v;
        }
        black_box(acc);
    }
    total_ns
}

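/// Picks the CPU matmul kernel for a given problem size. `Auto` applies the
/// size heuristic above; `Profile` times both kernels on the live data,
/// keyed by `(m, k, n, has_bias)`, and caches the winner for reuse.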
fn choose_cpu_kernel(
    m: usize,
    k: usize,
    n: usize,
    has_bias: bool,
    lhs_data: &[f32],
    rhs_data: &[f32],
    lhs_stride0: isize,
    lhs_stride1: isize,
    rhs_stride0: isize,
    rhs_stride1: isize,
    bias: Option<&[f32]>,
) -> CpuKernelChoice {
    let cfg = cpu_matmul_config();
    match cfg.strategy {
        CpuMatmulStrategy::Parallel => CpuKernelChoice::Parallel,
        CpuMatmulStrategy::Sgemm => CpuKernelChoice::Sgemm,
        CpuMatmulStrategy::Auto => {
            if should_use_parallel_auto(m, k, n) {
                CpuKernelChoice::Parallel
            } else {
                CpuKernelChoice::Sgemm
            }
        }
        CpuMatmulStrategy::Profile => {
            let key = (m, k, n, has_bias);
            if let Some(cached) = matmul_profile_cache().lock().get(&key).copied() {
                return cached;
            }

            let iters = cfg.profile_iters.max(1);
            let sgemm_ns = bench_kernel(
                || {
                    matmul_cpu_sgemm_core(
                        lhs_data,
                        rhs_data,
                        m,
                        k,
                        n,
                        lhs_stride0,
                        lhs_stride1,
                        rhs_stride0,
                        rhs_stride1,
                        bias,
                    )
                },
                iters,
            );
            let parallel_ns = bench_kernel(
                || matmul_cpu_parallel_core(lhs_data, rhs_data, m, k, n, bias),
                iters,
            );
            let choice = if parallel_ns < sgemm_ns {
                CpuKernelChoice::Parallel
            } else {
                CpuKernelChoice::Sgemm
            };
            matmul_profile_cache().lock().insert(key, choice);
            choice
        }
    }
}

fn matmul_cpu_parallel_core(
    lhs_data: &[f32],
    rhs_data: &[f32],
    m: usize,
    k: usize,
    n: usize,
    bias: Option<&[f32]>,
) -> Vec<f32> {
    let mut result = vec![0.0f32; m * n];
    result.par_chunks_mut(n).enumerate().for_each(|(i, row)| {
        let lhs_row = &lhs_data[i * k..(i + 1) * k];
        for j in 0..n {
            let mut sum = bias.map_or(0.0, |b| b[j]);
            let mut p = 0usize;
            while p + 8 <= k {
                sum += lhs_row[p] * rhs_data[p * n + j];
                sum += lhs_row[p + 1] * rhs_data[(p + 1) * n + j];
                sum += lhs_row[p + 2] * rhs_data[(p + 2) * n + j];
                sum += lhs_row[p + 3] * rhs_data[(p + 3) * n + j];
                sum += lhs_row[p + 4] * rhs_data[(p + 4) * n + j];
                sum += lhs_row[p + 5] * rhs_data[(p + 5) * n + j];
                sum += lhs_row[p + 6] * rhs_data[(p + 6) * n + j];
                sum += lhs_row[p + 7] * rhs_data[(p + 7) * n + j];
                p += 8;
            }
            while p < k {
                sum += lhs_row[p] * rhs_data[p * n + j];
                p += 1;
            }
            row[j] = sum;
        }
    });
    result
}

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum ElemwiseKind {
    Add,
    Sub,
    Mul,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuElemwiseStrategy {
    Auto,
    Profile,
    Scalar,
    Simd,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum ElemwiseKernelChoice {
    Scalar,
    Simd,
}

#[derive(Clone, Copy)]
struct CpuElemwiseConfig {
    strategy: CpuElemwiseStrategy,
    min_len: usize,
    profile_iters: usize,
}

type ElemwisePerfKey = (usize, ElemwiseKind);
type ElemwisePerfCache = HashMap<ElemwisePerfKey, ElemwiseKernelChoice>;

fn cpu_elemwise_config() -> CpuElemwiseConfig {
    static CFG: OnceLock<CpuElemwiseConfig> = OnceLock::new();
    *CFG.get_or_init(|| {
        let strategy = match std::env::var("RUSTORCH_CPU_ELEMWISE_STRATEGY")
            .unwrap_or_else(|_| "auto".to_string())
            .to_ascii_lowercase()
            .as_str()
        {
            "simd" => CpuElemwiseStrategy::Simd,
            "scalar" => CpuElemwiseStrategy::Scalar,
            "profile" => CpuElemwiseStrategy::Profile,
            _ => CpuElemwiseStrategy::Auto,
        };
        CpuElemwiseConfig {
            strategy,
            min_len: parse_usize_env("RUSTORCH_CPU_ELEMWISE_MIN_LEN", 2048),
            profile_iters: parse_usize_env("RUSTORCH_CPU_ELEMWISE_PROFILE_ITERS", 2),
        }
    })
}

fn elemwise_profile_cache() -> &'static Mutex<ElemwisePerfCache> {
    static CACHE: OnceLock<Mutex<ElemwisePerfCache>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

#[inline]
fn apply_elemwise_scalar(a: f32, b: f32, kind: ElemwiseKind) -> f32 {
    match kind {
        ElemwiseKind::Add => a + b,
        ElemwiseKind::Sub => a - b,
        ElemwiseKind::Mul => a * b,
    }
}

#[inline]
fn apply_elemwise_simd(a: f32x8, b: f32x8, kind: ElemwiseKind) -> f32x8 {
    match kind {
        ElemwiseKind::Add => a + b,
        ElemwiseKind::Sub => a - b,
        ElemwiseKind::Mul => a * b,
    }
}

fn elemwise_scalar(lhs: &[f32], rhs: &[f32], kind: ElemwiseKind) -> Vec<f32> {
    lhs.par_iter()
        .zip(rhs.par_iter())
        .map(|(a, b)| apply_elemwise_scalar(*a, *b, kind))
        .collect()
}

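/// SIMD element-wise kernel: the vectorizable prefix (`len / 8 * 8` elements)
/// is processed in parallel 1024-element chunks, each chunk in 8-wide
/// `f32x8` lanes, with scalar fallbacks for the sub-lane remainder of a
/// chunk and for the final few elements past the prefix.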
fn elemwise_simd(lhs: &[f32], rhs: &[f32], kind: ElemwiseKind) -> Vec<f32> {
    let len = lhs.len();
    let mut out = vec![0.0f32; len];
    let vec_len = len / 8 * 8;
    let lanes = 8usize;

    out[..vec_len]
        .par_chunks_mut(1024)
        .enumerate()
        .for_each(|(chunk_idx, out_chunk)| {
            let chunk_start = chunk_idx * 1024;
            let lhs_chunk = &lhs[chunk_start..chunk_start + out_chunk.len()];
            let rhs_chunk = &rhs[chunk_start..chunk_start + out_chunk.len()];
            let mut i = 0usize;
            while i + lanes <= out_chunk.len() {
                let mut la = [0.0f32; 8];
                let mut lb = [0.0f32; 8];
                la.copy_from_slice(&lhs_chunk[i..i + lanes]);
                lb.copy_from_slice(&rhs_chunk[i..i + lanes]);
                let va = f32x8::from(la);
                let vb = f32x8::from(lb);
                let vc = apply_elemwise_simd(va, vb, kind);
                let oc: [f32; 8] = vc.into();
                out_chunk[i..i + lanes].copy_from_slice(&oc);
                i += lanes;
            }
            while i < out_chunk.len() {
                out_chunk[i] = apply_elemwise_scalar(lhs_chunk[i], rhs_chunk[i], kind);
                i += 1;
            }
        });

    for i in vec_len..len {
        out[i] = apply_elemwise_scalar(lhs[i], rhs[i], kind);
    }
    out
}

fn choose_elemwise_kernel(
    len: usize,
    kind: ElemwiseKind,
    lhs: &[f32],
    rhs: &[f32],
) -> ElemwiseKernelChoice {
    let cfg = cpu_elemwise_config();
    match cfg.strategy {
        CpuElemwiseStrategy::Simd => ElemwiseKernelChoice::Simd,
        CpuElemwiseStrategy::Scalar => ElemwiseKernelChoice::Scalar,
        CpuElemwiseStrategy::Auto => {
            if len >= cfg.min_len {
                ElemwiseKernelChoice::Simd
            } else {
                ElemwiseKernelChoice::Scalar
            }
        }
        CpuElemwiseStrategy::Profile => {
            let key = (len, kind);
            if let Some(cached) = elemwise_profile_cache().lock().get(&key).copied() {
                return cached;
            }
            let iters = cfg.profile_iters.max(1);
            let scalar_ns = {
                let mut total = 0u128;
                for _ in 0..iters {
                    let t0 = Instant::now();
                    let out = elemwise_scalar(lhs, rhs, kind);
                    black_box(out.len());
                    total += t0.elapsed().as_nanos();
                }
                total
            };
            let simd_ns = {
                let mut total = 0u128;
                for _ in 0..iters {
                    let t0 = Instant::now();
                    let out = elemwise_simd(lhs, rhs, kind);
                    black_box(out.len());
                    total += t0.elapsed().as_nanos();
                }
                total
            };
            let choice = if simd_ns < scalar_ns {
                ElemwiseKernelChoice::Simd
            } else {
                ElemwiseKernelChoice::Scalar
            };
            elemwise_profile_cache().lock().insert(key, choice);
            choice
        }
    }
}

fn elemwise_auto(lhs: &[f32], rhs: &[f32], kind: ElemwiseKind) -> Vec<f32> {
    match choose_elemwise_kernel(lhs.len(), kind, lhs, rhs) {
        ElemwiseKernelChoice::Simd => elemwise_simd(lhs, rhs, kind),
        ElemwiseKernelChoice::Scalar => elemwise_scalar(lhs, rhs, kind),
    }
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuReductionStrategy {
    Auto,
    Profile,
    Scalar,
    Simd,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum ReductionKernelChoice {
    Scalar,
    Simd,
}

#[derive(Clone, Copy)]
struct CpuReductionConfig {
    strategy: CpuReductionStrategy,
    min_len: usize,
    profile_iters: usize,
}

fn cpu_reduction_config() -> CpuReductionConfig {
    static CFG: OnceLock<CpuReductionConfig> = OnceLock::new();
    *CFG.get_or_init(|| {
        let strategy = match std::env::var("RUSTORCH_CPU_REDUCTION_STRATEGY")
            .unwrap_or_else(|_| "auto".to_string())
            .to_ascii_lowercase()
            .as_str()
        {
            "simd" => CpuReductionStrategy::Simd,
            "scalar" => CpuReductionStrategy::Scalar,
            "profile" => CpuReductionStrategy::Profile,
            _ => CpuReductionStrategy::Auto,
        };
        CpuReductionConfig {
            strategy,
            min_len: parse_usize_env("RUSTORCH_CPU_REDUCTION_MIN_LEN", 4096),
            profile_iters: parse_usize_env("RUSTORCH_CPU_REDUCTION_PROFILE_ITERS", 2),
        }
    })
}

fn reduction_profile_cache() -> &'static Mutex<HashMap<usize, ReductionKernelChoice>> {
    static CACHE: OnceLock<Mutex<HashMap<usize, ReductionKernelChoice>>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

fn sum_scalar(data: &[f32]) -> f32 {
    data.par_iter().copied().sum()
}

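/// Chunked SIMD reduction: each 4096-element chunk is accumulated into an
/// 8-lane `f32x8` register with a scalar tail, then the per-chunk sums are
/// combined in parallel. Note the accumulation order differs from a
/// sequential sum, so results can differ from `sum_scalar` in the last ulps.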
fn sum_simd_chunk(chunk: &[f32]) -> f32 {
    let lanes = 8usize;
    let vec_len = chunk.len() / lanes * lanes;
    let mut acc = f32x8::from([0.0; 8]);
    let mut i = 0usize;
    while i < vec_len {
        let mut v = [0.0f32; 8];
        v.copy_from_slice(&chunk[i..i + lanes]);
        acc += f32x8::from(v);
        i += lanes;
    }
    let a: [f32; 8] = acc.into();
    let mut s = a.iter().sum::<f32>();
    while i < chunk.len() {
        s += chunk[i];
        i += 1;
    }
    s
}

fn sum_simd(data: &[f32]) -> f32 {
    data.par_chunks(4096).map(sum_simd_chunk).sum()
}

fn choose_reduction_kernel(len: usize, data: &[f32]) -> ReductionKernelChoice {
    let cfg = cpu_reduction_config();
    match cfg.strategy {
        CpuReductionStrategy::Simd => ReductionKernelChoice::Simd,
        CpuReductionStrategy::Scalar => ReductionKernelChoice::Scalar,
        CpuReductionStrategy::Auto => {
            if len >= cfg.min_len {
                ReductionKernelChoice::Simd
            } else {
                ReductionKernelChoice::Scalar
            }
        }
        CpuReductionStrategy::Profile => {
            if let Some(cached) = reduction_profile_cache().lock().get(&len).copied() {
                return cached;
            }
            let iters = cfg.profile_iters.max(1);
            let mut scalar_ns = 0u128;
            let mut simd_ns = 0u128;
            for _ in 0..iters {
                let t0 = Instant::now();
                let s = sum_scalar(data);
                scalar_ns += t0.elapsed().as_nanos();
                black_box(s);

                let t1 = Instant::now();
                let v = sum_simd(data);
                simd_ns += t1.elapsed().as_nanos();
                black_box(v);
            }
            let choice = if simd_ns < scalar_ns {
                ReductionKernelChoice::Simd
            } else {
                ReductionKernelChoice::Scalar
            };
            reduction_profile_cache().lock().insert(len, choice);
            choice
        }
    }
}

pub(crate) fn sum_auto(data: &[f32]) -> f32 {
    match choose_reduction_kernel(data.len(), data) {
        ReductionKernelChoice::Simd => sum_simd(data),
        ReductionKernelChoice::Scalar => sum_scalar(data),
    }
}

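/// Matmul with optional bias add and fused activation. On wgpu the whole
/// expression runs as one kernel; on the CPU the `Activation::None` case is
/// fused (bias folded into the matmul kernels), while other activations fall
/// through to the staged `matmul + add + activation` path at the bottom.
///
/// A minimal sketch (hedged: assumes `Tensor::new(&data, &shape)`):
///
/// ```ignore
/// let x = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
/// let w = Tensor::new(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);
/// let y = matmul_fused(&x, &w, None, crate::backend::Activation::None);
/// ```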
pub fn matmul_fused(
    lhs: &Tensor,
    rhs: &Tensor,
    bias: Option<&Tensor>,
    activation: crate::backend::Activation,
) -> Tensor {
    #[cfg(feature = "wgpu_backend")]
    {
        if let (Some(lhs_buf), Some(rhs_buf)) =
            (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
        {
            let m = lhs.shape()[0];
            let _k = lhs.shape()[1];
            let n = rhs.shape()[1];

            let bias_data =
                bias.and_then(|b| b.storage().wgpu_buffer().map(|buf| (buf, b.shape())));

            use crate::backend::wgpu::matmul_fused_wgpu_buffer;
            let output_buf = matmul_fused_wgpu_buffer(
                lhs_buf,
                lhs.shape(),
                rhs_buf,
                rhs.shape(),
                bias_data,
                activation,
            );

            let storage = Storage::new_wgpu(output_buf, m * n, 0);
            let mut tensor = Tensor::new_with_storage(storage, &[m, n]);

            if lhs.requires_grad()
                || rhs.requires_grad()
                || bias.is_some_and(|b| b.requires_grad())
            {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(FusedMatmulBackward {
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                    bias: bias.cloned(),
                    output: tensor.detach(),
                    activation,
                }));
            }
            return tensor;
        }
    }

    if matches!(activation, crate::backend::Activation::None) {
        let lhs_shape = lhs.shape();
        let rhs_shape = rhs.shape();
        if lhs_shape.len() == 2 && rhs_shape.len() == 2 && lhs_shape[1] == rhs_shape[0] {
            let m = lhs_shape[0];
            let k = lhs_shape[1];
            let n = rhs_shape[1];

            let lhs_contig = if lhs.is_contiguous() {
                lhs.clone()
            } else {
                lhs.contiguous()
            };
            let rhs_contig = if rhs.is_contiguous() {
                rhs.clone()
            } else {
                rhs.contiguous()
            };

            #[cfg(feature = "wgpu_backend")]
            let (lhs_contig, rhs_contig) = {
                let l = if lhs_contig.storage().device().is_wgpu() {
                    lhs_contig.to_cpu()
                } else {
                    lhs_contig
                };
                let r = if rhs_contig.storage().device().is_wgpu() {
                    rhs_contig.to_cpu()
                } else {
                    rhs_contig
                };
                (l, r)
            };

            let lhs_guard = lhs_contig.data();
            let rhs_guard = rhs_contig.data();
            let lhs_data = &*lhs_guard;
            let rhs_data = &*rhs_guard;

            let bias_vec = bias.and_then(|b| {
                if b.shape().len() == 1 && b.shape()[0] == n {
                    let b_cpu = {
                        #[cfg(feature = "wgpu_backend")]
                        {
                            if b.storage().device().is_wgpu() {
                                b.to_cpu()
                            } else {
                                b.clone()
                            }
                        }
                        #[cfg(not(feature = "wgpu_backend"))]
                        {
                            b.clone()
                        }
                    };
                    let bg = b_cpu.data();
                    Some(bg.to_vec())
                } else {
                    None
                }
            });
            let bias_slice = bias_vec.as_deref();

            let lhs_s0 = lhs_contig.strides()[0] as isize;
            let lhs_s1 = lhs_contig.strides()[1] as isize;
            let rhs_s0 = rhs_contig.strides()[0] as isize;
            let rhs_s1 = rhs_contig.strides()[1] as isize;
            let kernel = choose_cpu_kernel(
                m,
                k,
                n,
                bias_slice.is_some(),
                lhs_data,
                rhs_data,
                lhs_s0,
                lhs_s1,
                rhs_s0,
                rhs_s1,
                bias_slice,
            );
            let result_data = match kernel {
                CpuKernelChoice::Parallel => {
                    matmul_cpu_parallel_core(lhs_data, rhs_data, m, k, n, bias_slice)
                }
                CpuKernelChoice::Sgemm => matmul_cpu_sgemm_core(
                    lhs_data, rhs_data, m, k, n, lhs_s0, lhs_s1, rhs_s0, rhs_s1, bias_slice,
                ),
            };

            let storage = Storage::new(result_data);
            let mut tensor = Tensor::new_with_storage(storage, &[m, n]);
            if lhs.requires_grad()
                || rhs.requires_grad()
                || bias.is_some_and(|b| b.requires_grad())
            {
                tensor.set_requires_grad_mut(true);
                tensor.set_op(Arc::new(FusedMatmulBackward {
                    lhs: lhs.clone(),
                    rhs: rhs.clone(),
                    bias: bias.cloned(),
                    output: tensor.detach(),
                    activation,
                }));
            }
            return tensor;
        }
    }

    bump_pipeline_stat("staged");
    let mut out = lhs.matmul(rhs);
    if let Some(b) = bias {
        out = out.add(b);
    }
    match activation {
        crate::backend::Activation::ReLU => out.relu(),
1706        crate::backend::Activation::Sigmoid => crate::ops::activations::sigmoid(&out),
1707        crate::backend::Activation::Tanh => crate::ops::activations::tanh(&out),
1708        crate::backend::Activation::None => out,
1709    }
1710}
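
// Usage sketch for the fused entry point (test-style; CPU backend and
// bias broadcasting over rows are assumed; uses only the `Storage::new` /
// `Tensor::new_with_storage` constructors that appear elsewhere in this file).
#[cfg(test)]
mod matmul_fused_sketch {
    use super::*;

    #[test]
    fn bias_and_relu_are_applied() {
        // [1, 2] times the 2x2 identity, plus bias [-10, 1], then ReLU.
        let lhs = Tensor::new_with_storage(Storage::new(vec![1.0, 2.0]), &[1, 2]);
        let rhs = Tensor::new_with_storage(Storage::new(vec![1.0, 0.0, 0.0, 1.0]), &[2, 2]);
        let bias = Tensor::new_with_storage(Storage::new(vec![-10.0, 1.0]), &[2]);
        let out = matmul_fused(&lhs, &rhs, Some(&bias), crate::backend::Activation::ReLU);
        let vals = out.data();
        // ReLU clamps 1 - 10 = -9 to 0; 2 + 1 = 3 passes through.
        assert!((vals[0] - 0.0).abs() < 1e-6);
        assert!((vals[1] - 3.0).abs() < 1e-6);
    }
}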

/// How `matmul_bias_norm_activation` decides between the staged and fused
/// execution paths.
#[derive(Clone, Copy, PartialEq, Eq)]
enum FusedPipelineStrategy {
    Auto,
    Profile,
    Staged,
    Fused,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum FusedPipelineChoice {
    Staged,
    Fused,
}

#[derive(Clone, Copy)]
struct FusedPipelineConfig {
    strategy: FusedPipelineStrategy,
    profile_iters: usize,
}

/// Profile-cache key: (m, k, n, has_norm_weight, has_norm_bias, activation).
type FusedPipelineKey = (usize, usize, usize, bool, bool, i32);

/// Reads the pipeline strategy once from `RUSTORCH_FUSED_PIPELINE_STRATEGY`
/// ("auto" | "profile" | "staged" | "fused", defaulting to "auto") and the
/// profiling iteration count from `RUSTORCH_FUSED_PIPELINE_PROFILE_ITERS`.
fn fused_pipeline_config() -> FusedPipelineConfig {
    static CFG: OnceLock<FusedPipelineConfig> = OnceLock::new();
    *CFG.get_or_init(|| {
        let strategy = match std::env::var("RUSTORCH_FUSED_PIPELINE_STRATEGY")
            .unwrap_or_else(|_| "auto".to_string())
            .to_ascii_lowercase()
            .as_str()
        {
            "fused" => FusedPipelineStrategy::Fused,
            "staged" => FusedPipelineStrategy::Staged,
            "profile" => FusedPipelineStrategy::Profile,
            _ => FusedPipelineStrategy::Auto,
        };
        FusedPipelineConfig {
            strategy,
            profile_iters: parse_usize_env("RUSTORCH_FUSED_PIPELINE_PROFILE_ITERS", 1),
        }
    })
}
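
// Example configuration (hypothetical shell session): force profiling with
// three timed iterations per previously unseen pipeline key.
//
//     RUSTORCH_FUSED_PIPELINE_STRATEGY=profile \
//     RUSTORCH_FUSED_PIPELINE_PROFILE_ITERS=3 \
//     cargo run --release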

fn fused_pipeline_cache() -> &'static Mutex<HashMap<FusedPipelineKey, FusedPipelineChoice>> {
    static CACHE: OnceLock<Mutex<HashMap<FusedPipelineKey, FusedPipelineChoice>>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

fn fused_pipeline_stats() -> &'static Mutex<HashMap<String, u64>> {
    static STATS: OnceLock<Mutex<HashMap<String, u64>>> = OnceLock::new();
    STATS.get_or_init(|| Mutex::new(HashMap::new()))
}

fn bump_pipeline_stat(key: &str) {
    let mut s = fused_pipeline_stats().lock();
    *s.entry(key.to_string()).or_insert(0) += 1;
}

/// Returns a snapshot of how often each pipeline path ("fused", "staged") has
/// been executed in this process.
pub fn get_fused_pipeline_stats() -> HashMap<String, u64> {
    fused_pipeline_stats().lock().clone()
}
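
// Observability sketch: after a workload, the counters show how often each
// path ran. The keys currently emitted are "fused" and "staged".
//
//     for (path, count) in crate::ops::get_fused_pipeline_stats() {
//         eprintln!("pipeline path {path}: {count} calls");
//     }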

/// Applies `activation` elementwise on top of an already computed tensor.
fn apply_activation(t: Tensor, activation: crate::backend::Activation) -> Tensor {
    match activation {
        crate::backend::Activation::ReLU => t.relu(),
        crate::backend::Activation::Sigmoid => crate::ops::activations::sigmoid(&t),
        crate::backend::Activation::Tanh => crate::ops::activations::tanh(&t),
        crate::backend::Activation::None => t,
    }
}

/// Reference pipeline: matmul, bias add, layer norm, and activation executed
/// as separate ops.
fn pipeline_staged(
    lhs: &Tensor,
    rhs: &Tensor,
    bias: Option<&Tensor>,
    norm_weight: Option<&Tensor>,
    norm_bias: Option<&Tensor>,
    eps: f32,
    activation: crate::backend::Activation,
) -> Tensor {
    let mut out = matmul(lhs, rhs);
    if let Some(b) = bias {
        out = add(&out, b);
    }
    let norm_shape = [rhs.shape()[1]];
    out = layer_norm(&out, &norm_shape, norm_weight, norm_bias, eps);
    apply_activation(out, activation)
}

/// Fused-leaning pipeline: matmul and bias go through `matmul_fused`; the
/// layer norm and activation still run as separate ops on the result.
fn pipeline_fused(
    lhs: &Tensor,
    rhs: &Tensor,
    bias: Option<&Tensor>,
    norm_weight: Option<&Tensor>,
    norm_bias: Option<&Tensor>,
    eps: f32,
    activation: crate::backend::Activation,
) -> Tensor {
    let out = matmul_fused(lhs, rhs, bias, crate::backend::Activation::None);
    let norm_shape = [rhs.shape()[1]];
    let out = layer_norm(&out, &norm_shape, norm_weight, norm_bias, eps);
    apply_activation(out, activation)
}
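
// Consistency sketch: the two pipelines should be numerically interchangeable,
// which is what makes the profile-and-cache strategy safe. Test-style, CPU
// backend assumed.
#[cfg(test)]
mod pipeline_equivalence_sketch {
    use super::*;

    #[test]
    fn staged_and_fused_agree() {
        let lhs = Tensor::new_with_storage(Storage::new(vec![1.0, 2.0, 3.0, 4.0]), &[2, 2]);
        let rhs = Tensor::new_with_storage(Storage::new(vec![0.5, -0.5, 1.0, 1.0]), &[2, 2]);
        let act = crate::backend::Activation::Tanh;
        let s = pipeline_staged(&lhs, &rhs, None, None, None, 1e-5, act);
        let f = pipeline_fused(&lhs, &rhs, None, None, None, 1e-5, act);
        let (sg, fg) = (s.data(), f.data());
        for i in 0..4 {
            assert!((sg[i] - fg[i]).abs() < 1e-5);
        }
    }
}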

/// Matmul + bias + layer norm + activation, choosing between the staged and
/// fused execution paths according to the configured strategy.
pub fn matmul_bias_norm_activation(
    lhs: &Tensor,
    rhs: &Tensor,
    bias: Option<&Tensor>,
    norm_weight: Option<&Tensor>,
    norm_bias: Option<&Tensor>,
    eps: f32,
    activation: crate::backend::Activation,
) -> Tensor {
    let m = lhs.shape()[0];
    let k = lhs.shape()[1];
    let n = rhs.shape()[1];
    let key: FusedPipelineKey = (
        m,
        k,
        n,
        norm_weight.is_some(),
        norm_bias.is_some(),
        activation as i32,
    );
    let cfg = fused_pipeline_config();
    let choice = match cfg.strategy {
        FusedPipelineStrategy::Fused => FusedPipelineChoice::Fused,
        FusedPipelineStrategy::Staged => FusedPipelineChoice::Staged,
        FusedPipelineStrategy::Auto => {
            // Heuristic: fusion only pays off once the product is large
            // enough to amortize dispatch overhead.
            if m >= 128 && k >= 128 && n >= 32 {
                FusedPipelineChoice::Fused
            } else {
                FusedPipelineChoice::Staged
            }
        }
        FusedPipelineStrategy::Profile => {
            if let Some(cached) = fused_pipeline_cache().lock().get(&key).copied() {
                cached
            } else {
                // First sighting of this (shape, activation) key: time both
                // pipelines and remember the winner.
                let iters = cfg.profile_iters.max(1);
                let mut staged_ns = 0u128;
                let mut fused_ns = 0u128;
                for _ in 0..iters {
                    let t0 = Instant::now();
                    let s =
                        pipeline_staged(lhs, rhs, bias, norm_weight, norm_bias, eps, activation);
                    staged_ns += t0.elapsed().as_nanos();
                    black_box(s.shape()[0]);

                    let t1 = Instant::now();
                    let f = pipeline_fused(lhs, rhs, bias, norm_weight, norm_bias, eps, activation);
                    fused_ns += t1.elapsed().as_nanos();
                    black_box(f.shape()[0]);
                }
                let c = if fused_ns < staged_ns {
                    FusedPipelineChoice::Fused
                } else {
                    FusedPipelineChoice::Staged
                };
                fused_pipeline_cache().lock().insert(key, c);
                c
            }
        }
    };

    match choice {
        FusedPipelineChoice::Fused => {
            bump_pipeline_stat("fused");
            pipeline_fused(lhs, rhs, bias, norm_weight, norm_bias, eps, activation)
        }
        FusedPipelineChoice::Staged => {
            bump_pipeline_stat("staged");
            pipeline_staged(lhs, rhs, bias, norm_weight, norm_bias, eps, activation)
        }
    }
}
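
// Caller-side sketch (hypothetical tensors `x`, `w`, `b`, `gamma`, `beta`):
// under the Auto strategy a 256x256 by 256x64 product satisfies
// m >= 128 && k >= 128 && n >= 32 and takes the fused path, while a
// 16x16 by 16x16 product stays staged.
//
//     let y = matmul_bias_norm_activation(
//         &x, &w, Some(&b), Some(&gamma), Some(&beta), 1e-5,
//         crate::backend::Activation::ReLU,
//     );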

/// Plain 2D matrix multiply with autograd. GPU operands use the wgpu kernel
/// (transposed inputs are made contiguous first); everything else goes
/// through the profiled CPU kernels.
pub fn matmul(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();

    if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
        panic!("Matmul only supports 2D tensors");
    }

    let m = lhs_shape[0];
    let k = lhs_shape[1];
    let k2 = rhs_shape[0];
    let n = rhs_shape[1];

    if k != k2 {
        panic!("Matmul dimension mismatch: lhs is {m}x{k}, rhs is {k2}x{n}");
    }

    #[cfg(feature = "wgpu_backend")]
    {
        // Ensure both tensors are on the same device before dispatching.
        let lhs_is_wgpu = lhs.storage().device().is_wgpu();
        let rhs_is_wgpu = rhs.storage().device().is_wgpu();

        let (lhs, rhs) = if lhs_is_wgpu && !rhs_is_wgpu {
            // LHS is on the GPU, RHS on the CPU: move RHS to the GPU.
            (lhs.clone(), rhs.to_wgpu())
        } else if !lhs_is_wgpu && rhs_is_wgpu {
            // LHS is on the CPU, RHS on the GPU: move LHS to the GPU.
            (lhs.to_wgpu(), rhs.clone())
        } else {
            (lhs.clone(), rhs.clone())
        };

        if let (Some(lhs_buf), Some(rhs_buf)) =
            (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
        {
            let lhs_strides = lhs.strides();
            let rhs_strides = rhs.strides();

            let lhs_is_contig = lhs.is_contiguous();
            let rhs_is_contig = rhs.is_contiguous();

            // A transposed view of a contiguous [rows, cols] tensor has shape
            // [cols, rows] and strides [1, cols], so a leading stride of 1 on
            // a non-contiguous 2D tensor identifies a transpose.
            let lhs_is_transposed = !lhs_is_contig && lhs_strides[0] == 1;
            let rhs_is_transposed = !rhs_is_contig && rhs_strides[0] == 1;

            if lhs_is_contig && rhs_is_contig {
                use crate::backend::wgpu::{matmul_wgpu_buffer, Activation};
                let output_buf =
                    matmul_wgpu_buffer(lhs_buf, lhs_shape, rhs_buf, rhs_shape, Activation::None);

                let storage = Storage::new_wgpu(output_buf, m * n, 0);
                let mut tensor = Tensor::new_with_storage(storage, &[m, n]);

                if lhs.requires_grad() || rhs.requires_grad() {
                    tensor.set_requires_grad_mut(true);
                    tensor.set_op(Arc::new(MatmulBackward {
                        lhs: lhs.clone(),
                        rhs: rhs.clone(),
                    }));
                }
                return tensor;
            }

            if lhs_is_transposed || rhs_is_transposed {
                // Ensure previously queued GPU commands have executed before
                // reading the transposed data.
                crate::backend::wgpu::flush_queue();

                // The transposed operand's data is stored in the transposed
                // layout; materialize a contiguous copy rather than teaching
                // the matmul kernel about strides.
                let lhs_contig = if lhs_is_transposed {
                    lhs.contiguous()
                } else {
                    lhs.clone()
                };
                let rhs_contig = if rhs_is_transposed {
                    rhs.contiguous()
                } else {
                    rhs.clone()
                };

                // Make sure the contiguous copies have been written out.
                crate::backend::wgpu::flush_queue();

                if lhs_contig.storage().wgpu_buffer().is_some()
                    && rhs_contig.storage().wgpu_buffer().is_some()
                {
                    let lhs_buf = lhs_contig.storage().wgpu_buffer().unwrap();
                    let rhs_buf = rhs_contig.storage().wgpu_buffer().unwrap();

                    use crate::backend::wgpu::{matmul_wgpu_buffer, Activation};
                    let output_buf = matmul_wgpu_buffer(
                        lhs_buf,
                        lhs_contig.shape(),
                        rhs_buf,
                        rhs_contig.shape(),
                        Activation::None,
                    );

                    let storage = Storage::new_wgpu(output_buf, m * n, 0);
                    let mut tensor = Tensor::new_with_storage(storage, &[m, n]);

                    if lhs.requires_grad() || rhs.requires_grad() {
                        tensor.set_requires_grad_mut(true);
                        tensor.set_op(Arc::new(MatmulBackward {
                            lhs: lhs.clone(),
                            rhs: rhs.clone(),
                        }));
                    }
                    return tensor;
                }
            }

            // Any other non-contiguous layout: retry with contiguous copies.
            return matmul(&lhs.contiguous(), &rhs.contiguous());
        }
    }

    // CPU matrix multiply.
    let lhs_contig = if lhs.is_contiguous() {
        lhs.clone()
    } else {
        lhs.contiguous()
    };
    let rhs_contig = if rhs.is_contiguous() {
        rhs.clone()
    } else {
        rhs.contiguous()
    };

    #[cfg(feature = "wgpu_backend")]
    let (lhs_contig, rhs_contig) = {
        let l = if lhs_contig.storage().device().is_wgpu() {
            lhs_contig.to_cpu()
        } else {
            lhs_contig
        };
        let r = if rhs_contig.storage().device().is_wgpu() {
            rhs_contig.to_cpu()
        } else {
            rhs_contig
        };
        (l, r)
    };

    let lhs_guard = lhs_contig.data();
    let rhs_guard = rhs_contig.data();
    let lhs_data = &*lhs_guard;
    let rhs_data = &*rhs_guard;

    let lhs_s0 = lhs_contig.strides()[0] as isize;
    let lhs_s1 = lhs_contig.strides()[1] as isize;
    let rhs_s0 = rhs_contig.strides()[0] as isize;
    let rhs_s1 = rhs_contig.strides()[1] as isize;
    let kernel = choose_cpu_kernel(
        m, k, n, false, lhs_data, rhs_data, lhs_s0, lhs_s1, rhs_s0, rhs_s1, None,
    );
    let result_data = match kernel {
        CpuKernelChoice::Parallel => matmul_cpu_parallel_core(lhs_data, rhs_data, m, k, n, None),
        CpuKernelChoice::Sgemm => matmul_cpu_sgemm_core(
            lhs_data, rhs_data, m, k, n, lhs_s0, lhs_s1, rhs_s0, rhs_s1, None,
        ),
    };

    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, &[m, n]);

    if lhs.requires_grad() || rhs.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(MatmulBackward {
            lhs: lhs.clone(),
            rhs: rhs.clone(),
        }));
    }

    tensor
}
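
// Smoke-test sketch for the CPU path (test-style; the expected values are
// exact in f32 for these small integer-valued inputs).
#[cfg(test)]
mod matmul_sketch {
    use super::*;

    #[test]
    fn two_by_two_product() {
        let a = Tensor::new_with_storage(Storage::new(vec![1.0, 2.0, 3.0, 4.0]), &[2, 2]);
        let b = Tensor::new_with_storage(Storage::new(vec![5.0, 6.0, 7.0, 8.0]), &[2, 2]);
        let c = matmul(&a, &b);
        let g = c.data();
        // Row-major result of [[1,2],[3,4]] x [[5,6],[7,8]].
        assert_eq!([g[0], g[1], g[2], g[3]], [19.0, 22.0, 43.0, 50.0]);
    }
}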