1use crate::autograd::BackwardOp;
2use crate::storage::Storage;
3use crate::Tensor;
4use parking_lot::Mutex;
5use rayon::prelude::*;
6use std::collections::HashMap;
7use std::hint::black_box;
8use std::sync::{Arc, OnceLock};
9use std::time::Instant;
10use wide::f32x8;
11
12pub mod activations;
13pub mod conv;
14pub mod embedding;
15pub mod norm;
16pub mod pool;
17pub mod view;
18
19pub use activations::{sigmoid, softmax, tanh};
20pub use conv::conv2d;
21pub use embedding::embedding;
22pub use norm::{batch_norm2d, layer_norm};
23pub use pool::max_pool2d;
24pub use view::ReshapeBackward;
25
/// Backward node for element-wise multiplication (`z = lhs * rhs`).
///
/// Stores clones of both operands so the backward pass can compute
/// `dL/dlhs = grad * rhs` and `dL/drhs = grad * lhs`.
#[derive(Debug)]
pub struct MulBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}
31
impl BackwardOp for MulBackward {
    /// Product rule: dL/dlhs = grad * rhs, dL/drhs = grad * lhs.
    /// Gradients are reduced with `sum_to` when broadcasting expanded an operand.
    fn backward(&self, grad: &Tensor) {
        // When both operands alias the same tensor (x * x), accumulate both
        // partials in a single visit so the shared node's backward_step runs
        // only once: d(x*x)/dx = grad*x + grad*x.
        let same_input = Arc::ptr_eq(&self.lhs.inner, &self.rhs.inner);
        if same_input && self.lhs.requires_grad() {
            let mut grad_lhs = crate::ops::mul(grad, &self.rhs);
            let mut grad_rhs = crate::ops::mul(grad, &self.lhs);
            if grad_lhs.shape() != self.lhs.shape() {
                grad_lhs = sum_to(&grad_lhs, self.lhs.shape());
            }
            if grad_rhs.shape() != self.lhs.shape() {
                grad_rhs = sum_to(&grad_rhs, self.lhs.shape());
            }
            let grad_total = add(&grad_lhs, &grad_rhs);
            self.lhs.accumulate_grad(&grad_total);
            self.lhs.backward_step();
            return;
        }

        if self.lhs.requires_grad() {
            let mut grad_lhs = crate::ops::mul(grad, &self.rhs);
            // Fold broadcast dimensions back down to the operand's shape.
            if grad_lhs.shape() != self.lhs.shape() {
                grad_lhs = sum_to(&grad_lhs, self.lhs.shape());
            }
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            let mut grad_rhs = crate::ops::mul(grad, &self.lhs);
            if grad_rhs.shape() != self.rhs.shape() {
                grad_rhs = sum_to(&grad_rhs, self.rhs.shape());
            }
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}
68
69pub fn mul(lhs: &Tensor, rhs: &Tensor) -> Tensor {
70 #[cfg(feature = "wgpu_backend")]
71 {
72 if let (Some(lhs_buf), Some(rhs_buf)) =
73 (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
74 {
75 let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
76 .expect("Shapes not broadcastable");
77
78 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
79 let output_buf = elementwise_wgpu_buffer(
80 lhs_buf,
81 lhs.shape(),
82 lhs.strides(),
83 Some((rhs_buf, rhs.shape(), rhs.strides())),
84 &target_shape,
85 ElementwiseOp::Mul,
86 None,
87 );
88
89 let size: usize = target_shape.iter().product();
90 let storage = Storage::new_wgpu(output_buf, size, 0);
91 let mut tensor = Tensor::new_with_storage(storage, &target_shape);
92
93 if lhs.requires_grad() || rhs.requires_grad() {
94 tensor.set_requires_grad_mut(true);
95 tensor.set_op(Arc::new(MulBackward {
96 lhs: lhs.clone(),
97 rhs: rhs.clone(),
98 }));
99 }
100 return tensor;
101 }
102 }
103
104 if lhs.shape() != rhs.shape() {
105 let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
107 .expect("Shapes not broadcastable");
108 let lhs_expanded = lhs.expand(&target_shape);
109 let rhs_expanded = rhs.expand(&target_shape);
110 return mul(&lhs_expanded, &rhs_expanded);
111 }
112
113 let lhs_contig = if lhs.is_contiguous() {
114 lhs.clone()
115 } else {
116 lhs.contiguous()
117 };
118 let rhs_contig = if rhs.is_contiguous() {
119 rhs.clone()
120 } else {
121 rhs.contiguous()
122 };
123
124 let lhs_guard = lhs_contig.data();
125 let rhs_guard = rhs_contig.data();
126 let lhs_data = &*lhs_guard;
127 let rhs_data = &*rhs_guard;
128
129 let result_data = elemwise_auto(lhs_data, rhs_data, ElemwiseKind::Mul);
130
131 let storage = Storage::new(result_data);
132 let mut tensor = Tensor::new_with_storage(storage, lhs.shape());
133
134 if lhs.requires_grad() || rhs.requires_grad() {
135 tensor.set_requires_grad_mut(true);
136 tensor.set_op(Arc::new(MulBackward {
137 lhs: lhs.clone(),
138 rhs: rhs.clone(),
139 }));
140 }
141
142 tensor
143}
144
145pub fn div(lhs: &Tensor, rhs: &Tensor) -> Tensor {
146 if lhs.shape() != rhs.shape() {
148 let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
149 .expect("Shapes not broadcastable");
150 let lhs_expanded = lhs.expand(&target_shape);
151 let rhs_expanded = rhs.expand(&target_shape);
152 return div(&lhs_expanded, &rhs_expanded);
153 }
154
155 let lhs_contig = if lhs.is_contiguous() {
156 lhs.clone()
157 } else {
158 lhs.contiguous()
159 };
160 let rhs_contig = if rhs.is_contiguous() {
161 rhs.clone()
162 } else {
163 rhs.contiguous()
164 };
165
166 let lhs_guard = lhs_contig.data();
167 let rhs_guard = rhs_contig.data();
168 let result_data: Vec<f32> = lhs_guard
169 .par_iter()
170 .zip(rhs_guard.par_iter())
171 .map(|(a, b)| a / b)
172 .collect();
173 let storage = Storage::new(result_data);
174 let tensor = Tensor::new_with_storage(storage, lhs.shape());
175
176 tensor
178}
179
/// Backward node for element-wise addition (`z = lhs + rhs`).
///
/// Addition passes the incoming gradient through unchanged to both operands,
/// reduced with `sum_to` where broadcasting expanded a side.
#[derive(Debug)]
pub struct AddBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}
186
impl BackwardOp for AddBackward {
    /// dL/dlhs = grad, dL/drhs = grad (after `sum_to` for broadcast dims).
    fn backward(&self, grad: &Tensor) {
        // x + x: accumulate both contributions in one visit so the shared
        // node's backward_step runs only once. d(x+x)/dx = 2 * grad.
        let same_input = Arc::ptr_eq(&self.lhs.inner, &self.rhs.inner);
        if same_input && self.lhs.requires_grad() {
            let grad_lhs = if grad.shape() != self.lhs.shape() {
                sum_to(grad, self.lhs.shape())
            } else {
                grad.clone()
            };
            // grad + grad == 2 * grad, one contribution per operand slot.
            let grad_total = add(&grad_lhs, &grad_lhs);
            self.lhs.accumulate_grad(&grad_total);
            self.lhs.backward_step();
            return;
        }

        if self.lhs.requires_grad() {
            let grad_lhs = if grad.shape() != self.lhs.shape() {
                sum_to(grad, self.lhs.shape())
            } else {
                grad.clone()
            };
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            let grad_rhs = if grad.shape() != self.rhs.shape() {
                sum_to(grad, self.rhs.shape())
            } else {
                grad.clone()
            };
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}
222
223pub fn add(lhs: &Tensor, rhs: &Tensor) -> Tensor {
224 #[cfg(feature = "wgpu_backend")]
225 {
226 if let (Some(lhs_buf), Some(rhs_buf)) =
227 (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
228 {
229 let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
230 .expect("Shapes not broadcastable");
231
232 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
233 let output_buf = elementwise_wgpu_buffer(
234 lhs_buf,
235 lhs.shape(),
236 lhs.strides(),
237 Some((rhs_buf, rhs.shape(), rhs.strides())),
238 &target_shape,
239 ElementwiseOp::Add,
240 None,
241 );
242
243 let size: usize = target_shape.iter().product();
244 let storage = Storage::new_wgpu(output_buf, size, 0);
245 let mut tensor = Tensor::new_with_storage(storage, &target_shape);
246
247 if lhs.requires_grad() || rhs.requires_grad() {
248 tensor.set_requires_grad_mut(true);
249 tensor.set_op(Arc::new(AddBackward {
250 lhs: lhs.clone(),
251 rhs: rhs.clone(),
252 }));
253 }
254 return tensor;
255 }
256 }
257
258 if lhs.shape() != rhs.shape() {
259 let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
260 .expect("Shapes not broadcastable");
261 let lhs_expanded = lhs.expand(&target_shape);
262 let rhs_expanded = rhs.expand(&target_shape);
263 return add(&lhs_expanded, &rhs_expanded);
264 }
265
266 let lhs_contig = if lhs.is_contiguous() {
267 lhs.clone()
268 } else {
269 lhs.contiguous()
270 };
271 let rhs_contig = if rhs.is_contiguous() {
272 rhs.clone()
273 } else {
274 rhs.contiguous()
275 };
276
277 let lhs_guard = lhs_contig.data();
278 let rhs_guard = rhs_contig.data();
279 let result_data = elemwise_auto(&lhs_guard, &rhs_guard, ElemwiseKind::Add);
280 let storage = Storage::new(result_data);
281 let mut tensor = Tensor::new_with_storage(storage, lhs.shape());
282 if lhs.requires_grad() || rhs.requires_grad() {
283 tensor.set_requires_grad_mut(true);
284 tensor.set_op(Arc::new(AddBackward {
285 lhs: lhs.clone(),
286 rhs: rhs.clone(),
287 }));
288 }
289 tensor
290}
291
/// Backward node for element-wise subtraction (`z = lhs - rhs`).
///
/// The backward pass sends `grad` to `lhs` and `-grad` to `rhs`.
#[derive(Debug)]
pub struct SubBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}
297
impl BackwardOp for SubBackward {
    /// dL/dlhs = grad, dL/drhs = -grad (reduced via `sum_to` after broadcast).
    fn backward(&self, grad: &Tensor) {
        // x - x: the two partials (+grad and -grad) cancel exactly, so
        // nothing is accumulated and the shared node is not stepped again.
        let same_input = Arc::ptr_eq(&self.lhs.inner, &self.rhs.inner);
        if same_input && self.lhs.requires_grad() {
            return;
        }

        if self.lhs.requires_grad() {
            let mut grad_lhs = grad.clone();
            if grad_lhs.shape() != self.lhs.shape() {
                grad_lhs = sum_to(&grad_lhs, self.lhs.shape());
            }
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            // The rhs enters negatively, hence the negated gradient.
            let mut grad_rhs = neg(grad);
            if grad_rhs.shape() != self.rhs.shape() {
                grad_rhs = sum_to(&grad_rhs, self.rhs.shape());
            }
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}
323
324pub fn sub(lhs: &Tensor, rhs: &Tensor) -> Tensor {
325 #[cfg(feature = "wgpu_backend")]
326 {
327 if let (Some(lhs_buf), Some(rhs_buf)) =
328 (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
329 {
330 let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
331 .expect("Shapes not broadcastable");
332
333 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
334 let output_buf = elementwise_wgpu_buffer(
335 lhs_buf,
336 lhs.shape(),
337 lhs.strides(),
338 Some((rhs_buf, rhs.shape(), rhs.strides())),
339 &target_shape,
340 ElementwiseOp::Sub,
341 None,
342 );
343
344 let size: usize = target_shape.iter().product();
345 let storage = Storage::new_wgpu(output_buf, size, 0);
346 let mut tensor = Tensor::new_with_storage(storage, &target_shape);
347
348 if lhs.requires_grad() || rhs.requires_grad() {
349 tensor.set_requires_grad_mut(true);
350 tensor.set_op(Arc::new(SubBackward {
351 lhs: lhs.clone(),
352 rhs: rhs.clone(),
353 }));
354 }
355 return tensor;
356 }
357 }
358
359 if lhs.shape() != rhs.shape() {
360 let target_shape = crate::broadcast::broadcast_shapes(lhs.shape(), rhs.shape())
361 .expect("Shapes not broadcastable");
362 let lhs_expanded = lhs.expand(&target_shape);
363 let rhs_expanded = rhs.expand(&target_shape);
364 return sub(&lhs_expanded, &rhs_expanded);
365 }
366
367 let lhs_contig = if lhs.is_contiguous() {
368 lhs.clone()
369 } else {
370 lhs.contiguous()
371 };
372 let rhs_contig = if rhs.is_contiguous() {
373 rhs.clone()
374 } else {
375 rhs.contiguous()
376 };
377 let lhs_guard = lhs_contig.data();
378 let rhs_guard = rhs_contig.data();
379 let result_data = elemwise_auto(&lhs_guard, &rhs_guard, ElemwiseKind::Sub);
380 let storage = Storage::new(result_data);
381 let mut tensor = Tensor::new_with_storage(storage, lhs.shape());
382
383 if lhs.requires_grad() || rhs.requires_grad() {
384 tensor.set_requires_grad_mut(true);
385 tensor.set_op(Arc::new(SubBackward {
386 lhs: lhs.clone(),
387 rhs: rhs.clone(),
388 }));
389 }
390 tensor
391}
392
393pub fn neg(input: &Tensor) -> Tensor {
394 let input_guard = input.data();
395 let result_data: Vec<f32> = input_guard.par_iter().map(|x| -x).collect();
396 let storage = Storage::new(result_data);
397 Tensor::new_with_storage(storage, input.shape())
398}
399
/// Backward node for ReLU.
///
/// Keeps the forward `input` (used by the CPU mask) and a detached copy of
/// the forward `output` (used by the GPU ReLUBackward kernel).
#[derive(Debug)]
pub struct ReluBackward {
    pub input: Tensor,
    pub output: Tensor,
}
405
impl BackwardOp for ReluBackward {
    /// dL/dinput = grad where the unit fired (input > 0), 0 elsewhere.
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            #[cfg(feature = "wgpu_backend")]
            {
                // GPU fast path: the ReLUBackward kernel masks `grad` using
                // the saved forward output (output > 0 iff input > 0).
                if let (Some(out_buf), Some(grad_buf)) = (
                    self.output.storage().wgpu_buffer(),
                    grad.storage().wgpu_buffer(),
                ) {
                    use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
                    let grad_input_buf = elementwise_wgpu_buffer(
                        out_buf,
                        self.output.shape(),
                        self.output.strides(),
                        Some((grad_buf, grad.shape(), grad.strides())),
                        grad.shape(),
                        ElementwiseOp::ReLUBackward,
                        None,
                    );
                    let size: usize = grad.shape().iter().product();
                    let storage = Storage::new_wgpu(grad_input_buf, size, 0);
                    let grad_input = Tensor::new_with_storage(storage, grad.shape());
                    self.input.accumulate_grad(&grad_input);
                    self.input.backward_step();
                    return;
                }
            }

            // CPU path: mask the incoming gradient by the sign of the saved input.
            let input_guard = self.input.data();
            let grad_guard = grad.data();
            let grad_input: Vec<f32> = input_guard
                .par_iter()
                .zip(grad_guard.par_iter())
                .map(|(x, g)| if *x > 0.0 { *g } else { 0.0 })
                .collect();
            let storage = Storage::new(grad_input);
            let grad_input_tensor = Tensor::new_with_storage(storage, grad.shape());
            self.input.accumulate_grad(&grad_input_tensor);
            self.input.backward_step();
        }
    }
}
448
449pub fn relu(input: &Tensor) -> Tensor {
450 #[cfg(feature = "wgpu_backend")]
451 {
452 if let Some(buf) = input.storage().wgpu_buffer() {
453 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
454 let output_buf = elementwise_wgpu_buffer(
455 buf,
456 input.shape(),
457 input.strides(),
458 None,
459 input.shape(),
460 ElementwiseOp::ReLU,
461 None,
462 );
463 let size: usize = input.shape().iter().product();
464 let storage = Storage::new_wgpu(output_buf, size, 0);
465 let mut tensor = Tensor::new_with_storage(storage, input.shape());
466
467 if input.requires_grad() {
468 tensor.set_requires_grad_mut(true);
469 tensor.set_op(Arc::new(ReluBackward {
470 input: input.clone(),
471 output: tensor.detach(),
472 }));
473 }
474 return tensor;
475 }
476 }
477
478 let input_guard = input.data();
479 let result_data: Vec<f32> = input_guard.par_iter().map(|x| x.max(0.0)).collect();
480 let storage = Storage::new(result_data);
481 let mut tensor = Tensor::new_with_storage(storage, input.shape());
482
483 if input.requires_grad() {
484 tensor.set_requires_grad_mut(true);
485 tensor.set_op(Arc::new(ReluBackward {
486 input: input.clone(),
487 output: tensor.detach(),
488 }));
489 }
490 tensor
491}
492
493pub fn sgd_step(param: &Tensor, grad: &Tensor, lr: f32) -> Tensor {
494 #[cfg(feature = "wgpu_backend")]
496 {
497 if let (Some(p_buf), Some(g_buf)) =
498 (param.storage().wgpu_buffer(), grad.storage().wgpu_buffer())
499 {
500 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
501 let output_buf = elementwise_wgpu_buffer(
502 p_buf,
503 param.shape(),
504 param.strides(),
505 Some((g_buf, grad.shape(), grad.strides())),
506 param.shape(),
507 ElementwiseOp::SGDStep,
508 Some(lr),
509 );
510 let size: usize = param.shape().iter().product();
511 let storage = Storage::new_wgpu(output_buf, size, 0);
512 return Tensor::new_with_storage(storage, param.shape());
513 }
514 }
515
516 let p_data = param.data();
518 let g_data = grad.data();
519 let res_data: Vec<f32> = p_data
520 .par_iter()
521 .zip(g_data.par_iter())
522 .map(|(p, g)| p - lr * g)
523 .collect();
524 let storage = Storage::new(res_data);
525 Tensor::new_with_storage(storage, param.shape())
526}
527
/// Backward node for matrix multiplication (`z = lhs @ rhs`).
///
/// The backward pass computes `dL/dlhs = grad @ rhs^T` and
/// `dL/drhs = lhs^T @ grad`.
#[derive(Debug)]
pub struct MatmulBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
}
533
impl BackwardOp for MatmulBackward {
    /// dL/dlhs = grad @ rhs^T, dL/drhs = lhs^T @ grad.
    fn backward(&self, grad: &Tensor) {
        #[cfg(feature = "wgpu_backend")]
        {
            let grad_is_wgpu = grad.storage().wgpu_buffer().is_some();
            let lhs_is_wgpu = self.lhs.storage().wgpu_buffer().is_some();
            let rhs_is_wgpu = self.rhs.storage().wgpu_buffer().is_some();

            // All-GPU case: same math, but the results are detach()ed —
            // presumably to avoid building a second-order graph during the
            // backward pass; TODO confirm why only the GPU path detaches.
            if grad_is_wgpu && lhs_is_wgpu && rhs_is_wgpu {
                if self.lhs.requires_grad() {
                    let rhs_t = self.rhs.t();
                    let grad_lhs = matmul(grad, &rhs_t).detach();
                    self.lhs.accumulate_grad(&grad_lhs);
                    self.lhs.backward_step();
                }
                if self.rhs.requires_grad() {
                    let grad_rhs = matmul(&self.lhs.t(), grad).detach();
                    self.rhs.accumulate_grad(&grad_rhs);
                    self.rhs.backward_step();
                }
                return;
            }
        }

        if self.lhs.requires_grad() {
            let rhs_t = self.rhs.t();
            let grad_lhs = matmul(grad, &rhs_t);
            self.lhs.accumulate_grad(&grad_lhs);
            self.lhs.backward_step();
        }
        if self.rhs.requires_grad() {
            let grad_rhs = matmul(&self.lhs.t(), grad);
            self.rhs.accumulate_grad(&grad_rhs);
            self.rhs.backward_step();
        }
    }
}
571
/// 2D matmul that prefers the wgpu kernel and never attaches a backward op.
///
/// Falls back to the regular `matmul` when either operand is CPU-resident,
/// or when making an operand contiguous moved it off the GPU.
///
/// # Panics
/// Panics unless both tensors are 2D and `lhs` columns == `rhs` rows.
#[cfg(feature = "wgpu_backend")]
#[allow(dead_code)]
fn matmul_gpu_aware_no_grad(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();

    if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
        panic!("Matmul only supports 2D");
    }

    let m = lhs_shape[0];
    let k = lhs_shape[1];
    let k2 = rhs_shape[0];
    let n = rhs_shape[1];

    if k != k2 {
        panic!("Matmul dimension mismatch");
    }

    if let (Some(_), Some(_)) = (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer()) {
        // The GPU kernel requires dense inputs.
        let lhs_contig = if lhs.is_contiguous() {
            lhs.clone()
        } else {
            lhs.contiguous()
        };
        let rhs_contig = if rhs.is_contiguous() {
            rhs.clone()
        } else {
            rhs.contiguous()
        };

        // contiguous() may have produced a CPU tensor; bail out to the
        // general matmul in that case.
        if lhs_contig.storage().wgpu_buffer().is_none()
            || rhs_contig.storage().wgpu_buffer().is_none()
        {
            return matmul(&lhs_contig, &rhs_contig);
        }

        use crate::backend::wgpu::{matmul_wgpu_buffer, Activation};
        let output_buf = matmul_wgpu_buffer(
            lhs_contig.storage().wgpu_buffer().unwrap(),
            lhs_contig.shape(),
            rhs_contig.storage().wgpu_buffer().unwrap(),
            rhs_contig.shape(),
            Activation::None,
        );

        let storage = Storage::new_wgpu(output_buf, m * n, 0);
        return Tensor::new_with_storage(storage, &[m, n]);
    }

    matmul(lhs, rhs)
}
624
625pub fn sum_to(tensor: &Tensor, shape: &[usize]) -> Tensor {
626 if tensor.shape() == shape {
627 return tensor.clone();
628 }
629 view::sum_to(tensor, shape)
630}
631
/// Backward node for the full-reduction `sum`.
#[derive(Debug)]
pub struct SumBackward {
    pub input: Tensor,
}
636
impl BackwardOp for SumBackward {
    /// d(sum)/dx_i = 1 for every element, so the scalar incoming grad is
    /// broadcast (filled) over the input's full shape.
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            // The scalar grad value must be read on the CPU before filling.
            #[cfg(feature = "wgpu_backend")]
            let grad_cpu = if grad.storage().device().is_wgpu() {
                grad.to_cpu()
            } else {
                grad.clone()
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let grad_cpu = grad.clone();

            let grad_val = grad_cpu.data()[0];
            let mut grad_input = Tensor::full(self.input.shape(), grad_val);
            // Move the filled gradient back onto the input's device.
            #[cfg(feature = "wgpu_backend")]
            if self.input.storage().device().is_wgpu() {
                grad_input = grad_input.to_wgpu();
            }
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}
660
/// Sums all elements into a scalar tensor (shape `[]`).
///
/// Uses the wgpu full-reduction kernel for GPU tensors, otherwise the CPU
/// `sum_auto` reduction. Attaches `SumBackward` when grads are tracked.
pub fn sum(tensor: &Tensor) -> Tensor {
    let total_size: usize = tensor.shape().iter().product();

    let mut output = {
        #[cfg(feature = "wgpu_backend")]
        {
            if let Some(input_buf) = tensor.storage().wgpu_buffer() {
                let output_buf = crate::backend::wgpu::reduce_sum_all_wgpu(input_buf, total_size);
                let storage = Storage::new_wgpu(output_buf, 1, 0);
                Tensor::new_with_storage(storage, &[])
            } else {
                let data = tensor.data();
                let sum_val = sum_auto(&data);
                Tensor::new_with_storage(Storage::new(vec![sum_val]), &[])
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            let data = tensor.data();
            let sum_val = sum_auto(&data);
            Tensor::new_with_storage(Storage::new(vec![sum_val]), &[])
        }
    };

    if tensor.requires_grad() {
        output.set_requires_grad_mut(true);
        output.set_op(Arc::new(SumBackward {
            input: tensor.clone(),
        }));
    }

    output
}
694
/// Backward node for the full-reduction `mean`.
#[derive(Debug)]
pub struct MeanBackward {
    pub input: Tensor,
}
699
impl BackwardOp for MeanBackward {
    /// d(mean)/dx_i = 1/numel, so the input receives `grad / numel` at
    /// every element.
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            // The scalar grad value must be read on the CPU before filling.
            #[cfg(feature = "wgpu_backend")]
            let grad_cpu = if grad.storage().device().is_wgpu() {
                grad.to_cpu()
            } else {
                grad.clone()
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let grad_cpu = grad.clone();

            let grad_val = grad_cpu.data()[0];
            let numel = self.input.shape().iter().product::<usize>() as f32;
            let mut grad_input = Tensor::full(self.input.shape(), grad_val / numel);
            // Move the filled gradient back onto the input's device.
            #[cfg(feature = "wgpu_backend")]
            if self.input.storage().device().is_wgpu() {
                grad_input = grad_input.to_wgpu();
            }
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}
724
/// Arithmetic mean of all elements as a scalar tensor (shape `[]`).
///
/// Always computed on the CPU; GPU tensors are copied back first. Attaches
/// `MeanBackward` when grads are tracked.
pub fn mean(tensor: &Tensor) -> Tensor {
    // Pull the data to the CPU if it lives on the GPU.
    let t_cpu = {
        #[cfg(feature = "wgpu_backend")]
        {
            if tensor.storage().device().is_wgpu() {
                tensor.to_cpu()
            } else {
                tensor.clone()
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            tensor.clone()
        }
    };
    let data = t_cpu.data();
    let numel = data.len() as f32;
    let mean_val = sum_auto(&data) / numel;
    let mut out = Tensor::new_with_storage(Storage::new(vec![mean_val]), &[]);
    if tensor.requires_grad() {
        out.set_requires_grad_mut(true);
        out.set_op(Arc::new(MeanBackward {
            input: tensor.clone(),
        }));
    }
    out
}
752
/// Backward node for the full-reduction `var`.
///
/// Caches the mean from the forward pass so the backward formula
/// `2 * (x_i - mean) / n * grad` can reuse it.
#[derive(Debug)]
pub struct VarBackward {
    pub input: Tensor,
    pub mean: f32,
}
758
impl BackwardOp for VarBackward {
    /// d(var)/dx_i = 2 * (x_i - mean) / n, scaled by the incoming grad.
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            // The scalar grad value must be read on the CPU first.
            #[cfg(feature = "wgpu_backend")]
            let grad_cpu = if grad.storage().device().is_wgpu() {
                grad.to_cpu()
            } else {
                grad.clone()
            };
            #[cfg(not(feature = "wgpu_backend"))]
            let grad_cpu = grad.clone();
            let grad_val = grad_cpu.data()[0];
            // Likewise, the input data is read on the CPU.
            let input_cpu = {
                #[cfg(feature = "wgpu_backend")]
                {
                    if self.input.storage().device().is_wgpu() {
                        self.input.to_cpu()
                    } else {
                        self.input.clone()
                    }
                }
                #[cfg(not(feature = "wgpu_backend"))]
                {
                    self.input.clone()
                }
            };
            let input_data = input_cpu.data();
            let numel = input_data.len() as f32;
            let scale = grad_val * 2.0 / numel;
            let grad_data: Vec<f32> = input_data.iter().map(|x| (x - self.mean) * scale).collect();
            let mut grad_input = Tensor::new(&grad_data, self.input.shape());
            // Move the gradient back onto the input's device.
            #[cfg(feature = "wgpu_backend")]
            if self.input.storage().device().is_wgpu() {
                grad_input = grad_input.to_wgpu();
            }
            self.input.accumulate_grad(&grad_input);
            self.input.backward_step();
        }
    }
}
799
/// Population variance of all elements as a scalar tensor (shape `[]`).
///
/// Note: divides by `numel`, not `numel - 1` (biased estimator). Always
/// computed on the CPU; GPU tensors are copied back first. Attaches
/// `VarBackward` (carrying the cached mean) when grads are tracked.
pub fn var(tensor: &Tensor) -> Tensor {
    // Pull the data to the CPU if it lives on the GPU.
    let t_cpu = {
        #[cfg(feature = "wgpu_backend")]
        {
            if tensor.storage().device().is_wgpu() {
                tensor.to_cpu()
            } else {
                tensor.clone()
            }
        }
        #[cfg(not(feature = "wgpu_backend"))]
        {
            tensor.clone()
        }
    };
    let data = t_cpu.data();
    let numel = data.len() as f32;
    let m = sum_auto(&data) / numel;
    // Mean of squared deviations from the mean.
    let sq: Vec<f32> = data.iter().map(|x| (x - m) * (x - m)).collect();
    let v = sum_auto(&sq) / numel;
    let mut out = Tensor::new_with_storage(Storage::new(vec![v]), &[]);
    if tensor.requires_grad() {
        out.set_requires_grad_mut(true);
        out.set_op(Arc::new(VarBackward {
            input: tensor.clone(),
            mean: m,
        }));
    }
    out
}
830
831pub fn linear_mse_grads(input: &Tensor, output: &Tensor, target: &Tensor) -> (f32, Tensor, Tensor) {
832 let x = if input.is_contiguous() {
833 input.clone()
834 } else {
835 input.contiguous()
836 };
837 let y = if output.is_contiguous() {
838 output.clone()
839 } else {
840 output.contiguous()
841 };
842 let t = if target.is_contiguous() {
843 target.clone()
844 } else {
845 target.contiguous()
846 };
847
848 let x_shape = x.shape();
849 let y_shape = y.shape();
850 let t_shape = t.shape();
851 if x_shape.len() != 2 || y_shape.len() != 2 || t_shape.len() != 2 {
852 panic!("linear_mse_grads expects 2D tensors");
853 }
854 if y_shape != t_shape {
855 panic!("linear_mse_grads output and target shape mismatch");
856 }
857 if x_shape[0] != y_shape[0] {
858 panic!("linear_mse_grads batch mismatch");
859 }
860
861 let batch = x_shape[0];
862 let in_dim = x_shape[1];
863 let out_dim = y_shape[1];
864 let numel = (batch * out_dim) as f32;
865 let grad_scale = 2.0 / numel;
866
867 let x_data = x.data();
868 let y_data = y.data();
869 let t_data = t.data();
870
871 let mut grad_w = vec![0.0f32; out_dim * in_dim];
872 let mut grad_b = vec![0.0f32; out_dim];
873 let mut loss = 0.0f32;
874
875 for b in 0..batch {
876 for (o, gb) in grad_b.iter_mut().enumerate().take(out_dim) {
877 let idx = b * out_dim + o;
878 let d = y_data[idx] - t_data[idx];
879 loss += d * d;
880 let go = d * grad_scale;
881 *gb += go;
882 let w_row_offset = o * in_dim;
883 let x_row_offset = b * in_dim;
884 for i in 0..in_dim {
885 grad_w[w_row_offset + i] += go * x_data[x_row_offset + i];
886 }
887 }
888 }
889
890 let loss = loss / numel;
891 let grad_w_t = Tensor::new(&grad_w, &[out_dim, in_dim]);
892 let grad_b_t = Tensor::new(&grad_b, &[out_dim]);
893 (loss, grad_w_t, grad_b_t)
894}
895
/// Backward node for a fused `matmul + bias + activation` forward op.
///
/// Saves both operands, the optional bias, the post-activation output (fed
/// to the GPU activation-derivative kernels), and which activation was fused.
#[derive(Debug)]
pub struct FusedMatmulBackward {
    pub lhs: Tensor,
    pub rhs: Tensor,
    pub bias: Option<Tensor>,
    pub output: Tensor,
    pub activation: crate::backend::Activation,
}
904
905impl BackwardOp for FusedMatmulBackward {
906 fn backward(&self, grad_output: &Tensor) {
907 let grad_pre_act = match self.activation {
908 crate::backend::Activation::ReLU => {
909 #[cfg(feature = "wgpu_backend")]
910 {
911 if let (Some(out_buf), Some(grad_buf)) = (
912 self.output.storage().wgpu_buffer(),
913 grad_output.storage().wgpu_buffer(),
914 ) {
915 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
917 let target_shape = grad_output.shape();
918 let out_buf = elementwise_wgpu_buffer(
919 out_buf,
920 self.output.shape(),
921 self.output.strides(),
922 Some((grad_buf, grad_output.shape(), grad_output.strides())),
923 target_shape,
924 ElementwiseOp::ReLUBackward,
925 None,
926 );
927 let storage = Storage::new_wgpu(out_buf, target_shape.iter().product(), 0);
928 Tensor::new_with_storage(storage, target_shape)
929 } else {
930 grad_output.clone()
931 }
932 }
933 #[cfg(not(feature = "wgpu_backend"))]
934 grad_output.clone()
935 }
936 crate::backend::Activation::Sigmoid => {
937 #[cfg(feature = "wgpu_backend")]
938 {
939 if let (Some(out_buf), Some(grad_buf)) = (
940 self.output.storage().wgpu_buffer(),
941 grad_output.storage().wgpu_buffer(),
942 ) {
943 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
944 let target_shape = grad_output.shape();
945 let out_buf = elementwise_wgpu_buffer(
946 out_buf,
947 self.output.shape(),
948 self.output.strides(),
949 Some((grad_buf, grad_output.shape(), grad_output.strides())),
950 target_shape,
951 ElementwiseOp::SigmoidBackward,
952 None,
953 );
954 let storage = Storage::new_wgpu(out_buf, target_shape.iter().product(), 0);
955 Tensor::new_with_storage(storage, target_shape)
956 } else {
957 grad_output.clone()
958 }
959 }
960 #[cfg(not(feature = "wgpu_backend"))]
961 grad_output.clone()
962 }
963 crate::backend::Activation::Tanh => {
964 #[cfg(feature = "wgpu_backend")]
965 {
966 if let (Some(out_buf), Some(grad_buf)) = (
967 self.output.storage().wgpu_buffer(),
968 grad_output.storage().wgpu_buffer(),
969 ) {
970 use crate::backend::wgpu::{elementwise_wgpu_buffer, ElementwiseOp};
971 let target_shape = grad_output.shape();
972 let out_buf = elementwise_wgpu_buffer(
973 out_buf,
974 self.output.shape(),
975 self.output.strides(),
976 Some((grad_buf, grad_output.shape(), grad_output.strides())),
977 target_shape,
978 ElementwiseOp::TanhBackward,
979 None,
980 );
981 let storage = Storage::new_wgpu(out_buf, target_shape.iter().product(), 0);
982 Tensor::new_with_storage(storage, target_shape)
983 } else {
984 grad_output.clone()
985 }
986 }
987 #[cfg(not(feature = "wgpu_backend"))]
988 grad_output.clone()
989 }
990 crate::backend::Activation::None => grad_output.clone(),
991 };
992
993 if let Some(bias) = &self.bias {
994 if bias.requires_grad() {
995 let grad_bias = sum_to(&grad_pre_act, bias.shape());
996 bias.accumulate_grad(&grad_bias);
997 bias.backward_step();
998 }
999 }
1000
1001 if self.lhs.requires_grad() {
1002 let rhs_t = self.rhs.t();
1003 let grad_lhs = matmul(&grad_pre_act, &rhs_t);
1004 self.lhs.accumulate_grad(&grad_lhs);
1005 self.lhs.backward_step();
1006 }
1007
1008 if self.rhs.requires_grad() {
1009 let grad_rhs = matmul(&self.lhs.t(), &grad_pre_act);
1010 self.rhs.accumulate_grad(&grad_rhs);
1011 self.rhs.backward_step();
1012 }
1013 }
1014}
1015
/// Reads `key` from the environment and parses it as `usize`, falling back
/// to `default` when the variable is unset or not a valid number.
#[inline]
fn parse_usize_env(key: &str, default: usize) -> usize {
    match std::env::var(key) {
        Ok(raw) => raw.parse().unwrap_or(default),
        Err(_) => default,
    }
}
1023
/// How the CPU matmul picks its kernel; selected via the
/// `RUSTORCH_CPU_MATMUL_STRATEGY` env var (parsed in `cpu_matmul_config`).
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuMatmulStrategy {
    // Heuristic choice based on the (m, k, n) thresholds in CpuMatmulConfig.
    Auto,
    // Time both kernels per shape and cache the winner.
    Profile,
    // Always use the matrixmultiply SGEMM kernel.
    Sgemm,
    // Always use the rayon row-parallel kernel.
    Parallel,
}
1031
/// Tunables for CPU matmul kernel selection, read once from the environment
/// (see `cpu_matmul_config` for the variable names and defaults).
#[derive(Clone, Copy)]
struct CpuMatmulConfig {
    // Selected kernel-choice strategy.
    strategy: CpuMatmulStrategy,
    // Auto mode: minimum m before the parallel kernel is chosen.
    min_m: usize,
    // Auto mode: minimum k before the parallel kernel is chosen.
    min_k: usize,
    // Auto mode: maximum n for which the parallel kernel is chosen.
    max_n: usize,
    // Profile mode: timing iterations per kernel.
    profile_iters: usize,
}
1040
/// Outcome of CPU matmul kernel selection.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuKernelChoice {
    // matrixmultiply-based SGEMM (matmul_cpu_sgemm_core).
    Sgemm,
    // Rayon row-parallel naive kernel (matmul_cpu_parallel_core).
    Parallel,
}
1046
/// Cache key for profiled matmul shapes: (m, k, n, has_bias).
type MatmulPerfKey = (usize, usize, usize, bool);
/// Maps a profiled shape to the kernel that measured fastest for it.
type MatmulPerfCache = HashMap<MatmulPerfKey, CpuKernelChoice>;
1049
1050fn cpu_matmul_config() -> CpuMatmulConfig {
1051 static CFG: OnceLock<CpuMatmulConfig> = OnceLock::new();
1052 *CFG.get_or_init(|| {
1053 let strategy = match std::env::var("RUSTORCH_CPU_MATMUL_STRATEGY")
1054 .unwrap_or_else(|_| "auto".to_string())
1055 .to_ascii_lowercase()
1056 .as_str()
1057 {
1058 "parallel" => CpuMatmulStrategy::Parallel,
1059 "sgemm" => CpuMatmulStrategy::Sgemm,
1060 "profile" => CpuMatmulStrategy::Profile,
1061 _ => CpuMatmulStrategy::Auto,
1062 };
1063
1064 CpuMatmulConfig {
1065 strategy,
1066 min_m: parse_usize_env("RUSTORCH_CPU_MATMUL_MIN_M", 128),
1067 min_k: parse_usize_env("RUSTORCH_CPU_MATMUL_MIN_K", 256),
1068 max_n: parse_usize_env("RUSTORCH_CPU_MATMUL_MAX_N", 128),
1069 profile_iters: parse_usize_env("RUSTORCH_CPU_MATMUL_PROFILE_ITERS", 2),
1070 }
1071 })
1072}
1073
1074#[inline]
1075fn should_use_parallel_auto(m: usize, k: usize, n: usize) -> bool {
1076 let cfg = cpu_matmul_config();
1077 m >= cfg.min_m && k >= cfg.min_k && n <= cfg.max_n
1078}
1079
1080fn matmul_profile_cache() -> &'static Mutex<MatmulPerfCache> {
1081 static CACHE: OnceLock<Mutex<MatmulPerfCache>> = OnceLock::new();
1082 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
1083}
1084
/// CPU SGEMM via `matrixmultiply::sgemm`, with an optional bias row added to
/// every output row afterwards.
///
/// Strides are given in elements; the output is a dense row-major
/// `m x n` buffer (row stride `n`, column stride 1).
fn matmul_cpu_sgemm_core(
    lhs_data: &[f32],
    rhs_data: &[f32],
    m: usize,
    k: usize,
    n: usize,
    lhs_stride0: isize,
    lhs_stride1: isize,
    rhs_stride0: isize,
    rhs_stride1: isize,
    bias: Option<&[f32]>,
) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    // SAFETY: callers must guarantee the stride/shape pairs describe
    // in-bounds m x k (lhs) and k x n (rhs) views of the given slices, so
    // every offset sgemm dereferences is valid. `out` is freshly allocated
    // with exactly m * n elements and written with row stride n, col stride 1.
    unsafe {
        matrixmultiply::sgemm(
            m,
            k,
            n,
            1.0,
            lhs_data.as_ptr(),
            lhs_stride0,
            lhs_stride1,
            rhs_data.as_ptr(),
            rhs_stride0,
            rhs_stride1,
            0.0,
            out.as_mut_ptr(),
            n as isize,
            1,
        );
    }
    // Broadcast-add the bias (one value per output column) to each row.
    if let Some(bias_data) = bias {
        out.par_chunks_mut(n).for_each(|row| {
            row.iter_mut()
                .zip(bias_data.iter())
                .for_each(|(v, b)| *v += *b);
        });
    }
    out
}
1125
/// Times `iters` invocations of `f` and returns the total elapsed
/// nanoseconds. The first element of each result is folded into an
/// accumulator passed through `black_box` so the optimizer cannot discard
/// the kernel's work.
fn bench_kernel<F: Fn() -> Vec<f32>>(f: F, iters: usize) -> u128 {
    let mut elapsed_ns = 0u128;
    let mut sink = 0.0f32;
    for _ in 0..iters {
        let start = Instant::now();
        let result = f();
        elapsed_ns += start.elapsed().as_nanos();
        sink += result.first().copied().unwrap_or(0.0);
        black_box(sink);
    }
    elapsed_ns
}
1140
/// Decides between the hand-rolled parallel matmul and `sgemm` for a CPU
/// matrix product, honoring the strategy from `cpu_matmul_config()`:
/// forced (`Parallel`/`Sgemm`), heuristic (`Auto`), or measured (`Profile`,
/// which benchmarks both kernels on the real inputs and caches the winner
/// per `(m, k, n, has_bias)` shape).
fn choose_cpu_kernel(
    m: usize,
    k: usize,
    n: usize,
    has_bias: bool,
    lhs_data: &[f32],
    rhs_data: &[f32],
    lhs_stride0: isize,
    lhs_stride1: isize,
    rhs_stride0: isize,
    rhs_stride1: isize,
    bias: Option<&[f32]>,
) -> CpuKernelChoice {
    let cfg = cpu_matmul_config();
    match cfg.strategy {
        CpuMatmulStrategy::Parallel => CpuKernelChoice::Parallel,
        CpuMatmulStrategy::Sgemm => CpuKernelChoice::Sgemm,
        CpuMatmulStrategy::Auto => {
            if should_use_parallel_auto(m, k, n) {
                CpuKernelChoice::Parallel
            } else {
                CpuKernelChoice::Sgemm
            }
        }
        CpuMatmulStrategy::Profile => {
            let key = (m, k, n, has_bias);
            // Fast path: a previous call already profiled this shape.
            if let Some(cached) = matmul_profile_cache().lock().get(&key).copied() {
                return cached;
            }

            let iters = cfg.profile_iters.max(1);
            let sgemm_ns = bench_kernel(
                || {
                    matmul_cpu_sgemm_core(
                        lhs_data,
                        rhs_data,
                        m,
                        k,
                        n,
                        lhs_stride0,
                        lhs_stride1,
                        rhs_stride0,
                        rhs_stride1,
                        bias,
                    )
                },
                iters,
            );
            let parallel_ns = bench_kernel(
                || matmul_cpu_parallel_core(lhs_data, rhs_data, m, k, n, bias),
                iters,
            );
            // Ties go to sgemm.
            let choice = if parallel_ns < sgemm_ns {
                CpuKernelChoice::Parallel
            } else {
                CpuKernelChoice::Sgemm
            };
            matmul_profile_cache().lock().insert(key, choice);
            choice
        }
    }
}
1203
1204fn matmul_cpu_parallel_core(
1205 lhs_data: &[f32],
1206 rhs_data: &[f32],
1207 m: usize,
1208 k: usize,
1209 n: usize,
1210 bias: Option<&[f32]>,
1211) -> Vec<f32> {
1212 let mut result = vec![0.0f32; m * n];
1213 result.par_chunks_mut(n).enumerate().for_each(|(i, row)| {
1214 let lhs_row = &lhs_data[i * k..(i + 1) * k];
1215 for j in 0..n {
1216 let mut sum = bias.map_or(0.0, |b| b[j]);
1217 let mut p = 0usize;
1218 while p + 8 <= k {
1219 sum += lhs_row[p] * rhs_data[p * n + j];
1220 sum += lhs_row[p + 1] * rhs_data[(p + 1) * n + j];
1221 sum += lhs_row[p + 2] * rhs_data[(p + 2) * n + j];
1222 sum += lhs_row[p + 3] * rhs_data[(p + 3) * n + j];
1223 sum += lhs_row[p + 4] * rhs_data[(p + 4) * n + j];
1224 sum += lhs_row[p + 5] * rhs_data[(p + 5) * n + j];
1225 sum += lhs_row[p + 6] * rhs_data[(p + 6) * n + j];
1226 sum += lhs_row[p + 7] * rhs_data[(p + 7) * n + j];
1227 p += 8;
1228 }
1229 while p < k {
1230 sum += lhs_row[p] * rhs_data[p * n + j];
1231 p += 1;
1232 }
1233 row[j] = sum;
1234 }
1235 });
1236 result
1237}
1238
/// Binary element-wise operation selector shared by the scalar and SIMD
/// kernels; also part of the profiling-cache key.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum ElemwiseKind {
    Add,
    Sub,
    Mul,
}
1245
/// How the element-wise kernel is selected (from
/// `RUSTORCH_CPU_ELEMWISE_STRATEGY`): heuristic, measured, or forced.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuElemwiseStrategy {
    Auto,
    Profile,
    Scalar,
    Simd,
}
1253
/// Concrete kernel picked for one element-wise call (the resolved outcome
/// of a `CpuElemwiseStrategy`).
#[derive(Clone, Copy, PartialEq, Eq)]
enum ElemwiseKernelChoice {
    Scalar,
    Simd,
}
1259
/// Element-wise kernel-selection settings, read once from the environment.
#[derive(Clone, Copy)]
struct CpuElemwiseConfig {
    strategy: CpuElemwiseStrategy,
    min_len: usize,       // Auto mode: minimum length before SIMD is used.
    profile_iters: usize, // Profile mode: timing iterations per kernel.
}
1266
/// Profile-cache key: (element count, element-wise op kind).
type ElemwisePerfKey = (usize, ElemwiseKind);
/// Cached kernel winner per profiled `(len, kind)` pair.
type ElemwisePerfCache = HashMap<ElemwisePerfKey, ElemwiseKernelChoice>;
1269
1270fn cpu_elemwise_config() -> CpuElemwiseConfig {
1271 static CFG: OnceLock<CpuElemwiseConfig> = OnceLock::new();
1272 *CFG.get_or_init(|| {
1273 let strategy = match std::env::var("RUSTORCH_CPU_ELEMWISE_STRATEGY")
1274 .unwrap_or_else(|_| "auto".to_string())
1275 .to_ascii_lowercase()
1276 .as_str()
1277 {
1278 "simd" => CpuElemwiseStrategy::Simd,
1279 "scalar" => CpuElemwiseStrategy::Scalar,
1280 "profile" => CpuElemwiseStrategy::Profile,
1281 _ => CpuElemwiseStrategy::Auto,
1282 };
1283 CpuElemwiseConfig {
1284 strategy,
1285 min_len: parse_usize_env("RUSTORCH_CPU_ELEMWISE_MIN_LEN", 2048),
1286 profile_iters: parse_usize_env("RUSTORCH_CPU_ELEMWISE_PROFILE_ITERS", 2),
1287 }
1288 })
1289}
1290
/// Process-wide cache of profiled element-wise kernel choices, created on
/// first use.
fn elemwise_profile_cache() -> &'static Mutex<ElemwisePerfCache> {
    static CACHE: OnceLock<Mutex<ElemwisePerfCache>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}
1295
1296#[inline]
1297fn apply_elemwise_scalar(a: f32, b: f32, kind: ElemwiseKind) -> f32 {
1298 match kind {
1299 ElemwiseKind::Add => a + b,
1300 ElemwiseKind::Sub => a - b,
1301 ElemwiseKind::Mul => a * b,
1302 }
1303}
1304
1305#[inline]
1306fn apply_elemwise_simd(a: f32x8, b: f32x8, kind: ElemwiseKind) -> f32x8 {
1307 match kind {
1308 ElemwiseKind::Add => a + b,
1309 ElemwiseKind::Sub => a - b,
1310 ElemwiseKind::Mul => a * b,
1311 }
1312}
1313
1314fn elemwise_scalar(lhs: &[f32], rhs: &[f32], kind: ElemwiseKind) -> Vec<f32> {
1315 lhs.par_iter()
1316 .zip(rhs.par_iter())
1317 .map(|(a, b)| apply_elemwise_scalar(*a, *b, kind))
1318 .collect()
1319}
1320
/// SIMD element-wise kernel: the length-8-aligned prefix is processed in
/// parallel 1024-element chunks of `f32x8` vectors; the remaining tail
/// (`len % 8` elements) is finished serially with the scalar op.
/// Assumes `lhs` and `rhs` have equal length.
fn elemwise_simd(lhs: &[f32], rhs: &[f32], kind: ElemwiseKind) -> Vec<f32> {
    let len = lhs.len();
    let mut out = vec![0.0f32; len];
    // Largest multiple of 8 that fits; only this prefix is vectorized.
    let vec_len = len / 8 * 8;
    let lanes = 8usize;

    out[..vec_len]
        .par_chunks_mut(1024)
        .enumerate()
        .for_each(|(chunk_idx, out_chunk)| {
            // Map this output chunk back to the matching input ranges.
            let chunk_start = chunk_idx * 1024;
            let lhs_chunk = &lhs[chunk_start..chunk_start + out_chunk.len()];
            let rhs_chunk = &rhs[chunk_start..chunk_start + out_chunk.len()];
            let mut i = 0usize;
            while i + lanes <= out_chunk.len() {
                // Copy through fixed-size arrays to build the SIMD vectors.
                let mut la = [0.0f32; 8];
                let mut lb = [0.0f32; 8];
                la.copy_from_slice(&lhs_chunk[i..i + lanes]);
                lb.copy_from_slice(&rhs_chunk[i..i + lanes]);
                let va = f32x8::from(la);
                let vb = f32x8::from(lb);
                let vc = apply_elemwise_simd(va, vb, kind);
                let oc: [f32; 8] = vc.into();
                out_chunk[i..i + lanes].copy_from_slice(&oc);
                i += lanes;
            }
            // The final 1024-chunk may not be lane-aligned; finish it scalar.
            while i < out_chunk.len() {
                out_chunk[i] = apply_elemwise_scalar(lhs_chunk[i], rhs_chunk[i], kind);
                i += 1;
            }
        });

    // Serial tail beyond the vectorized prefix.
    for i in vec_len..len {
        out[i] = apply_elemwise_scalar(lhs[i], rhs[i], kind);
    }
    out
}
1358
1359fn choose_elemwise_kernel(
1360 len: usize,
1361 kind: ElemwiseKind,
1362 lhs: &[f32],
1363 rhs: &[f32],
1364) -> ElemwiseKernelChoice {
1365 let cfg = cpu_elemwise_config();
1366 match cfg.strategy {
1367 CpuElemwiseStrategy::Simd => ElemwiseKernelChoice::Simd,
1368 CpuElemwiseStrategy::Scalar => ElemwiseKernelChoice::Scalar,
1369 CpuElemwiseStrategy::Auto => {
1370 if len >= cfg.min_len {
1371 ElemwiseKernelChoice::Simd
1372 } else {
1373 ElemwiseKernelChoice::Scalar
1374 }
1375 }
1376 CpuElemwiseStrategy::Profile => {
1377 let key = (len, kind);
1378 if let Some(cached) = elemwise_profile_cache().lock().get(&key).copied() {
1379 return cached;
1380 }
1381 let iters = cfg.profile_iters.max(1);
1382 let scalar_ns = {
1383 let mut total = 0u128;
1384 for _ in 0..iters {
1385 let t0 = Instant::now();
1386 let out = elemwise_scalar(lhs, rhs, kind);
1387 black_box(out.len());
1388 total += t0.elapsed().as_nanos();
1389 }
1390 total
1391 };
1392 let simd_ns = {
1393 let mut total = 0u128;
1394 for _ in 0..iters {
1395 let t0 = Instant::now();
1396 let out = elemwise_simd(lhs, rhs, kind);
1397 black_box(out.len());
1398 total += t0.elapsed().as_nanos();
1399 }
1400 total
1401 };
1402 let choice = if simd_ns < scalar_ns {
1403 ElemwiseKernelChoice::Simd
1404 } else {
1405 ElemwiseKernelChoice::Scalar
1406 };
1407 elemwise_profile_cache().lock().insert(key, choice);
1408 choice
1409 }
1410 }
1411}
1412
1413fn elemwise_auto(lhs: &[f32], rhs: &[f32], kind: ElemwiseKind) -> Vec<f32> {
1414 match choose_elemwise_kernel(lhs.len(), kind, lhs, rhs) {
1415 ElemwiseKernelChoice::Simd => elemwise_simd(lhs, rhs, kind),
1416 ElemwiseKernelChoice::Scalar => elemwise_scalar(lhs, rhs, kind),
1417 }
1418}
1419
/// How the reduction (sum) kernel is selected (from
/// `RUSTORCH_CPU_REDUCTION_STRATEGY`): heuristic, measured, or forced.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CpuReductionStrategy {
    Auto,
    Profile,
    Scalar,
    Simd,
}
1427
/// Concrete kernel picked for one reduction call.
#[derive(Clone, Copy, PartialEq, Eq)]
enum ReductionKernelChoice {
    Scalar,
    Simd,
}
1433
/// Reduction kernel-selection settings, read once from the environment.
#[derive(Clone, Copy)]
struct CpuReductionConfig {
    strategy: CpuReductionStrategy,
    min_len: usize,       // Auto mode: minimum length before SIMD is used.
    profile_iters: usize, // Profile mode: timing iterations per kernel.
}
1440
1441fn cpu_reduction_config() -> CpuReductionConfig {
1442 static CFG: OnceLock<CpuReductionConfig> = OnceLock::new();
1443 *CFG.get_or_init(|| {
1444 let strategy = match std::env::var("RUSTORCH_CPU_REDUCTION_STRATEGY")
1445 .unwrap_or_else(|_| "auto".to_string())
1446 .to_ascii_lowercase()
1447 .as_str()
1448 {
1449 "simd" => CpuReductionStrategy::Simd,
1450 "scalar" => CpuReductionStrategy::Scalar,
1451 "profile" => CpuReductionStrategy::Profile,
1452 _ => CpuReductionStrategy::Auto,
1453 };
1454 CpuReductionConfig {
1455 strategy,
1456 min_len: parse_usize_env("RUSTORCH_CPU_REDUCTION_MIN_LEN", 4096),
1457 profile_iters: parse_usize_env("RUSTORCH_CPU_REDUCTION_PROFILE_ITERS", 2),
1458 }
1459 })
1460}
1461
/// Process-wide cache of profiled reduction kernel choices, keyed by input
/// length; created on first use.
fn reduction_profile_cache() -> &'static Mutex<HashMap<usize, ReductionKernelChoice>> {
    static CACHE: OnceLock<Mutex<HashMap<usize, ReductionKernelChoice>>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}
1466
1467fn sum_scalar(data: &[f32]) -> f32 {
1468 data.par_iter().copied().sum()
1469}
1470
1471fn sum_simd_chunk(chunk: &[f32]) -> f32 {
1472 let lanes = 8usize;
1473 let vec_len = chunk.len() / lanes * lanes;
1474 let mut acc = f32x8::from([0.0; 8]);
1475 let mut i = 0usize;
1476 while i < vec_len {
1477 let mut v = [0.0f32; 8];
1478 v.copy_from_slice(&chunk[i..i + lanes]);
1479 acc += f32x8::from(v);
1480 i += lanes;
1481 }
1482 let a: [f32; 8] = acc.into();
1483 let mut s = a.iter().sum::<f32>();
1484 while i < chunk.len() {
1485 s += chunk[i];
1486 i += 1;
1487 }
1488 s
1489}
1490
1491fn sum_simd(data: &[f32]) -> f32 {
1492 data.par_chunks(4096).map(sum_simd_chunk).sum()
1493}
1494
/// Picks the scalar or SIMD reduction kernel for a given input length,
/// honoring the strategy from `cpu_reduction_config()`. In `Profile` mode
/// both kernels are timed back-to-back on the real data (interleaved within
/// each iteration) and the winner is cached per length.
fn choose_reduction_kernel(len: usize, data: &[f32]) -> ReductionKernelChoice {
    let cfg = cpu_reduction_config();
    match cfg.strategy {
        CpuReductionStrategy::Simd => ReductionKernelChoice::Simd,
        CpuReductionStrategy::Scalar => ReductionKernelChoice::Scalar,
        CpuReductionStrategy::Auto => {
            if len >= cfg.min_len {
                ReductionKernelChoice::Simd
            } else {
                ReductionKernelChoice::Scalar
            }
        }
        CpuReductionStrategy::Profile => {
            // Fast path: this length was already profiled.
            if let Some(cached) = reduction_profile_cache().lock().get(&len).copied() {
                return cached;
            }
            let iters = cfg.profile_iters.max(1);
            let mut scalar_ns = 0u128;
            let mut simd_ns = 0u128;
            for _ in 0..iters {
                let t0 = Instant::now();
                let s = sum_scalar(data);
                scalar_ns += t0.elapsed().as_nanos();
                // black_box keeps the sums from being optimized away.
                black_box(s);

                let t1 = Instant::now();
                let v = sum_simd(data);
                simd_ns += t1.elapsed().as_nanos();
                black_box(v);
            }
            // Ties go to Scalar.
            let choice = if simd_ns < scalar_ns {
                ReductionKernelChoice::Simd
            } else {
                ReductionKernelChoice::Scalar
            };
            reduction_profile_cache().lock().insert(len, choice);
            choice
        }
    }
}
1535
1536pub(crate) fn sum_auto(data: &[f32]) -> f32 {
1537 match choose_reduction_kernel(data.len(), data) {
1538 ReductionKernelChoice::Simd => sum_simd(data),
1539 ReductionKernelChoice::Scalar => sum_scalar(data),
1540 }
1541}
1542
1543pub fn matmul_fused(
1544 lhs: &Tensor,
1545 rhs: &Tensor,
1546 bias: Option<&Tensor>,
1547 activation: crate::backend::Activation,
1548) -> Tensor {
1549 #[cfg(feature = "wgpu_backend")]
1550 {
1551 if let (Some(lhs_buf), Some(rhs_buf)) =
1552 (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
1553 {
1554 let m = lhs.shape()[0];
1555 let _k = lhs.shape()[1];
1556 let n = rhs.shape()[1];
1557
1558 let bias_data =
1559 bias.and_then(|b| b.storage().wgpu_buffer().map(|buf| (buf, b.shape())));
1560
1561 use crate::backend::wgpu::matmul_fused_wgpu_buffer;
1562 let output_buf = matmul_fused_wgpu_buffer(
1563 lhs_buf,
1564 lhs.shape(),
1565 rhs_buf,
1566 rhs.shape(),
1567 bias_data,
1568 activation,
1569 );
1570
1571 let storage = Storage::new_wgpu(output_buf, m * n, 0);
1572 let mut tensor = Tensor::new_with_storage(storage, &[m, n]);
1573
1574 if lhs.requires_grad()
1575 || rhs.requires_grad()
1576 || bias.map_or(false, |b| b.requires_grad())
1577 {
1578 tensor.set_requires_grad_mut(true);
1579 tensor.set_op(Arc::new(FusedMatmulBackward {
1580 lhs: lhs.clone(),
1581 rhs: rhs.clone(),
1582 bias: bias.cloned(),
1583 output: tensor.detach(),
1584 activation,
1585 }));
1586 }
1587 return tensor;
1588 }
1589 }
1590
1591 if matches!(activation, crate::backend::Activation::None) {
1592 let lhs_shape = lhs.shape();
1593 let rhs_shape = rhs.shape();
1594 if lhs_shape.len() == 2 && rhs_shape.len() == 2 && lhs_shape[1] == rhs_shape[0] {
1595 let m = lhs_shape[0];
1596 let k = lhs_shape[1];
1597 let n = rhs_shape[1];
1598
1599 let lhs_contig = if lhs.is_contiguous() {
1600 lhs.clone()
1601 } else {
1602 lhs.contiguous()
1603 };
1604 let rhs_contig = if rhs.is_contiguous() {
1605 rhs.clone()
1606 } else {
1607 rhs.contiguous()
1608 };
1609
1610 #[cfg(feature = "wgpu_backend")]
1611 let (lhs_contig, rhs_contig) = {
1612 let l = if lhs_contig.storage().device().is_wgpu() {
1613 lhs_contig.to_cpu()
1614 } else {
1615 lhs_contig
1616 };
1617 let r = if rhs_contig.storage().device().is_wgpu() {
1618 rhs_contig.to_cpu()
1619 } else {
1620 rhs_contig
1621 };
1622 (l, r)
1623 };
1624
1625 let lhs_guard = lhs_contig.data();
1626 let rhs_guard = rhs_contig.data();
1627 let lhs_data = &*lhs_guard;
1628 let rhs_data = &*rhs_guard;
1629
1630 let bias_vec = bias.and_then(|b| {
1631 if b.shape().len() == 1 && b.shape()[0] == n {
1632 let b_cpu = {
1633 #[cfg(feature = "wgpu_backend")]
1634 {
1635 if b.storage().device().is_wgpu() {
1636 b.to_cpu()
1637 } else {
1638 b.clone()
1639 }
1640 }
1641 #[cfg(not(feature = "wgpu_backend"))]
1642 {
1643 b.clone()
1644 }
1645 };
1646 let bg = b_cpu.data();
1647 Some(bg.to_vec())
1648 } else {
1649 None
1650 }
1651 });
1652 let bias_slice = bias_vec.as_deref();
1653
1654 let lhs_s0 = lhs_contig.strides()[0] as isize;
1655 let lhs_s1 = lhs_contig.strides()[1] as isize;
1656 let rhs_s0 = rhs_contig.strides()[0] as isize;
1657 let rhs_s1 = rhs_contig.strides()[1] as isize;
1658 let kernel = choose_cpu_kernel(
1659 m,
1660 k,
1661 n,
1662 bias_slice.is_some(),
1663 lhs_data,
1664 rhs_data,
1665 lhs_s0,
1666 lhs_s1,
1667 rhs_s0,
1668 rhs_s1,
1669 bias_slice,
1670 );
1671 let result_data = match kernel {
1672 CpuKernelChoice::Parallel => {
1673 matmul_cpu_parallel_core(lhs_data, rhs_data, m, k, n, bias_slice)
1674 }
1675 CpuKernelChoice::Sgemm => matmul_cpu_sgemm_core(
1676 lhs_data, rhs_data, m, k, n, lhs_s0, lhs_s1, rhs_s0, rhs_s1, bias_slice,
1677 ),
1678 };
1679
1680 let storage = Storage::new(result_data);
1681 let mut tensor = Tensor::new_with_storage(storage, &[m, n]);
1682 if lhs.requires_grad()
1683 || rhs.requires_grad()
1684 || bias.map_or(false, |b| b.requires_grad())
1685 {
1686 tensor.set_requires_grad_mut(true);
1687 tensor.set_op(Arc::new(FusedMatmulBackward {
1688 lhs: lhs.clone(),
1689 rhs: rhs.clone(),
1690 bias: bias.cloned(),
1691 output: tensor.detach(),
1692 activation,
1693 }));
1694 }
1695 return tensor;
1696 }
1697 }
1698
1699 bump_pipeline_stat("staged");
1700 let mut out = lhs.matmul(rhs);
1701 if let Some(b) = bias {
1702 out = out.add(b);
1703 }
1704 match activation {
1705 crate::backend::Activation::ReLU => out.relu(),
1706 crate::backend::Activation::Sigmoid => crate::ops::activations::sigmoid(&out),
1707 crate::backend::Activation::Tanh => crate::ops::activations::tanh(&out),
1708 crate::backend::Activation::None => out,
1709 }
1710}
1711
/// How the matmul+norm+activation pipeline is selected (from
/// `RUSTORCH_FUSED_PIPELINE_STRATEGY`): heuristic, measured, or forced.
#[derive(Clone, Copy, PartialEq, Eq)]
enum FusedPipelineStrategy {
    Auto,
    Profile,
    Staged,
    Fused,
}
1719
/// Concrete pipeline picked for one call (resolved from
/// `FusedPipelineStrategy`).
#[derive(Clone, Copy, PartialEq, Eq)]
enum FusedPipelineChoice {
    Staged,
    Fused,
}
1725
/// Fused-pipeline selection settings, read once from the environment.
#[derive(Clone, Copy)]
struct FusedPipelineConfig {
    strategy: FusedPipelineStrategy,
    profile_iters: usize, // Profile mode: timing iterations per pipeline.
}
1731
/// Profile-cache key: (m, k, n, has_norm_weight, has_norm_bias,
/// activation discriminant).
type FusedPipelineKey = (usize, usize, usize, bool, bool, i32);
1733
1734fn fused_pipeline_config() -> FusedPipelineConfig {
1735 static CFG: OnceLock<FusedPipelineConfig> = OnceLock::new();
1736 *CFG.get_or_init(|| {
1737 let strategy = match std::env::var("RUSTORCH_FUSED_PIPELINE_STRATEGY")
1738 .unwrap_or_else(|_| "auto".to_string())
1739 .to_ascii_lowercase()
1740 .as_str()
1741 {
1742 "fused" => FusedPipelineStrategy::Fused,
1743 "staged" => FusedPipelineStrategy::Staged,
1744 "profile" => FusedPipelineStrategy::Profile,
1745 _ => FusedPipelineStrategy::Auto,
1746 };
1747 FusedPipelineConfig {
1748 strategy,
1749 profile_iters: parse_usize_env("RUSTORCH_FUSED_PIPELINE_PROFILE_ITERS", 1),
1750 }
1751 })
1752}
1753
/// Process-wide cache of profiled pipeline choices, keyed by problem shape
/// and activation; created on first use.
fn fused_pipeline_cache() -> &'static Mutex<HashMap<FusedPipelineKey, FusedPipelineChoice>> {
    static CACHE: OnceLock<Mutex<HashMap<FusedPipelineKey, FusedPipelineChoice>>> = OnceLock::new();
    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}
1758
/// Process-wide counters of which pipeline path ran ("fused"/"staged"),
/// for observability; created on first use.
fn fused_pipeline_stats() -> &'static Mutex<HashMap<String, u64>> {
    static STATS: OnceLock<Mutex<HashMap<String, u64>>> = OnceLock::new();
    STATS.get_or_init(|| Mutex::new(HashMap::new()))
}
1763
1764fn bump_pipeline_stat(key: &str) {
1765 let mut s = fused_pipeline_stats().lock();
1766 *s.entry(key.to_string()).or_insert(0) += 1;
1767}
1768
/// Returns a snapshot (clone) of the pipeline path counters.
pub fn get_fused_pipeline_stats() -> HashMap<String, u64> {
    fused_pipeline_stats().lock().clone()
}
1772
/// Applies the requested activation to `t` (`None` passes it through
/// unchanged). Consumes the input tensor.
fn apply_activation(t: Tensor, activation: crate::backend::Activation) -> Tensor {
    match activation {
        crate::backend::Activation::ReLU => t.relu(),
        crate::backend::Activation::Sigmoid => crate::ops::activations::sigmoid(&t),
        crate::backend::Activation::Tanh => crate::ops::activations::tanh(&t),
        crate::backend::Activation::None => t,
    }
}
1781
1782fn pipeline_staged(
1783 lhs: &Tensor,
1784 rhs: &Tensor,
1785 bias: Option<&Tensor>,
1786 norm_weight: Option<&Tensor>,
1787 norm_bias: Option<&Tensor>,
1788 eps: f32,
1789 activation: crate::backend::Activation,
1790) -> Tensor {
1791 let mut out = matmul(lhs, rhs);
1792 if let Some(b) = bias {
1793 out = add(&out, b);
1794 }
1795 let norm_shape = [rhs.shape()[1]];
1796 out = layer_norm(&out, &norm_shape, norm_weight, norm_bias, eps);
1797 apply_activation(out, activation)
1798}
1799
1800fn pipeline_fused(
1801 lhs: &Tensor,
1802 rhs: &Tensor,
1803 bias: Option<&Tensor>,
1804 norm_weight: Option<&Tensor>,
1805 norm_bias: Option<&Tensor>,
1806 eps: f32,
1807 activation: crate::backend::Activation,
1808) -> Tensor {
1809 let out = matmul_fused(lhs, rhs, bias, crate::backend::Activation::None);
1810 let norm_shape = [rhs.shape()[1]];
1811 let out = layer_norm(&out, &norm_shape, norm_weight, norm_bias, eps);
1812 apply_activation(out, activation)
1813}
1814
/// `lhs @ rhs (+ bias)` followed by layer-norm over the last dim and an
/// activation, choosing between the staged and fused pipelines per the
/// configured strategy. Records the taken path in the pipeline stats.
///
/// NOTE(review): in `Profile` mode, both pipelines are executed once per
/// iteration to time them (including any autograd graph construction they
/// perform), their results are discarded, and the winner is then run a
/// final time below — so the first call for a shape does extra work.
pub fn matmul_bias_norm_activation(
    lhs: &Tensor,
    rhs: &Tensor,
    bias: Option<&Tensor>,
    norm_weight: Option<&Tensor>,
    norm_bias: Option<&Tensor>,
    eps: f32,
    activation: crate::backend::Activation,
) -> Tensor {
    // Assumes 2D operands; indexing panics otherwise (matmul enforces 2D).
    let m = lhs.shape()[0];
    let k = lhs.shape()[1];
    let n = rhs.shape()[1];
    let key: FusedPipelineKey = (
        m,
        k,
        n,
        norm_weight.is_some(),
        norm_bias.is_some(),
        activation as i32,
    );
    let cfg = fused_pipeline_config();
    let choice = match cfg.strategy {
        FusedPipelineStrategy::Fused => FusedPipelineChoice::Fused,
        FusedPipelineStrategy::Staged => FusedPipelineChoice::Staged,
        FusedPipelineStrategy::Auto => {
            // Heuristic: fusing pays off only for sufficiently large products.
            if m >= 128 && k >= 128 && n >= 32 {
                FusedPipelineChoice::Fused
            } else {
                FusedPipelineChoice::Staged
            }
        }
        FusedPipelineStrategy::Profile => {
            if let Some(cached) = fused_pipeline_cache().lock().get(&key).copied() {
                cached
            } else {
                let iters = cfg.profile_iters.max(1);
                let mut staged_ns = 0u128;
                let mut fused_ns = 0u128;
                for _ in 0..iters {
                    let t0 = Instant::now();
                    let s =
                        pipeline_staged(lhs, rhs, bias, norm_weight, norm_bias, eps, activation);
                    staged_ns += t0.elapsed().as_nanos();
                    // black_box keeps the pipelines from being optimized out.
                    black_box(s.shape()[0]);

                    let t1 = Instant::now();
                    let f = pipeline_fused(lhs, rhs, bias, norm_weight, norm_bias, eps, activation);
                    fused_ns += t1.elapsed().as_nanos();
                    black_box(f.shape()[0]);
                }
                // Ties go to Staged.
                let c = if fused_ns < staged_ns {
                    FusedPipelineChoice::Fused
                } else {
                    FusedPipelineChoice::Staged
                };
                fused_pipeline_cache().lock().insert(key, c);
                c
            }
        }
    };

    match choice {
        FusedPipelineChoice::Fused => {
            bump_pipeline_stat("fused");
            pipeline_fused(lhs, rhs, bias, norm_weight, norm_bias, eps, activation)
        }
        FusedPipelineChoice::Staged => {
            bump_pipeline_stat("staged");
            pipeline_staged(lhs, rhs, bias, norm_weight, norm_bias, eps, activation)
        }
    }
}
1887
/// 2D matrix product `lhs (m×k) @ rhs (k×n) -> (m×n)`, dispatching to the
/// wgpu backend when available and otherwise to a profiled CPU kernel.
/// Registers `MatmulBackward` on the result when either input requires
/// gradients.
///
/// # Panics
/// Panics if either input is not 2D or the inner dimensions disagree.
pub fn matmul(lhs: &Tensor, rhs: &Tensor) -> Tensor {
    let lhs_shape = lhs.shape();
    let rhs_shape = rhs.shape();

    if lhs_shape.len() != 2 || rhs_shape.len() != 2 {
        panic!("Matmul only supports 2D");
    }

    let m = lhs_shape[0];
    let k = lhs_shape[1];
    let k2 = rhs_shape[0];
    let n = rhs_shape[1];

    if k != k2 {
        panic!("Matmul dimension mismatch");
    }

    #[cfg(feature = "wgpu_backend")]
    {
        // If exactly one operand lives on the GPU, promote the other so the
        // whole product can run on the wgpu backend.
        let lhs_is_wgpu = lhs.storage().device().is_wgpu();
        let rhs_is_wgpu = rhs.storage().device().is_wgpu();

        let (lhs, rhs) = if lhs_is_wgpu && !rhs_is_wgpu {
            (lhs.clone(), rhs.to_wgpu())
        } else if !lhs_is_wgpu && rhs_is_wgpu {
            (lhs.to_wgpu(), rhs.clone())
        } else {
            (lhs.clone(), rhs.clone())
        };

        if let (Some(lhs_buf), Some(rhs_buf)) =
            (lhs.storage().wgpu_buffer(), rhs.storage().wgpu_buffer())
        {
            let lhs_strides = lhs.strides();
            let rhs_strides = rhs.strides();

            let lhs_is_contig = lhs.is_contiguous();
            let rhs_is_contig = rhs.is_contiguous();

            // stride[0] == 1 on a non-contiguous 2D tensor marks a
            // transposed (column-major) view.
            let lhs_is_transposed = !lhs_is_contig && lhs_strides[0] == 1;
            let rhs_is_transposed = !rhs_is_contig && rhs_strides[0] == 1;

            if lhs_is_contig && rhs_is_contig {
                use crate::backend::wgpu::{matmul_wgpu_buffer, Activation};
                let output_buf =
                    matmul_wgpu_buffer(lhs_buf, lhs_shape, rhs_buf, rhs_shape, Activation::None);

                let storage = Storage::new_wgpu(output_buf, m * n, 0);
                let mut tensor = Tensor::new_with_storage(storage, &[m, n]);

                if lhs.requires_grad() || rhs.requires_grad() {
                    tensor.set_requires_grad_mut(true);
                    tensor.set_op(Arc::new(MatmulBackward {
                        lhs: lhs.clone(),
                        rhs: rhs.clone(),
                    }));
                }
                return tensor;
            }

            if lhs_is_transposed || rhs_is_transposed {
                // Materialize contiguous copies of transposed operands;
                // flush around the copies so the GPU kernel sees them.
                crate::backend::wgpu::flush_queue();

                let lhs_contig = if lhs_is_transposed {
                    lhs.contiguous()
                } else {
                    lhs.clone()
                };
                let rhs_contig = if rhs_is_transposed {
                    rhs.contiguous()
                } else {
                    rhs.clone()
                };

                crate::backend::wgpu::flush_queue();

                if lhs_contig.storage().wgpu_buffer().is_some()
                    && rhs_contig.storage().wgpu_buffer().is_some()
                {
                    let lhs_buf = lhs_contig.storage().wgpu_buffer().unwrap();
                    let rhs_buf = rhs_contig.storage().wgpu_buffer().unwrap();

                    use crate::backend::wgpu::{matmul_wgpu_buffer, Activation};
                    let output_buf = matmul_wgpu_buffer(
                        lhs_buf,
                        lhs_contig.shape(),
                        rhs_buf,
                        rhs_contig.shape(),
                        Activation::None,
                    );

                    let storage = Storage::new_wgpu(output_buf, m * n, 0);
                    let mut tensor = Tensor::new_with_storage(storage, &[m, n]);

                    // Backward captures the original (pre-copy) operands.
                    if lhs.requires_grad() || rhs.requires_grad() {
                        tensor.set_requires_grad_mut(true);
                        tensor.set_op(Arc::new(MatmulBackward {
                            lhs: lhs.clone(),
                            rhs: rhs.clone(),
                        }));
                    }
                    return tensor;
                }
            }

            // Other non-contiguous layouts: retry once with contiguous copies.
            return matmul(&lhs.contiguous(), &rhs.contiguous());
        }
    }

    let lhs_contig = if lhs.is_contiguous() {
        lhs.clone()
    } else {
        lhs.contiguous()
    };
    let rhs_contig = if rhs.is_contiguous() {
        rhs.clone()
    } else {
        rhs.contiguous()
    };

    // The CPU kernels need host data; pull GPU tensors down first.
    #[cfg(feature = "wgpu_backend")]
    let (lhs_contig, rhs_contig) = {
        let l = if lhs_contig.storage().device().is_wgpu() {
            lhs_contig.to_cpu()
        } else {
            lhs_contig
        };
        let r = if rhs_contig.storage().device().is_wgpu() {
            rhs_contig.to_cpu()
        } else {
            rhs_contig
        };
        (l, r)
    };

    let lhs_guard = lhs_contig.data();
    let rhs_guard = rhs_contig.data();
    let lhs_data = &*lhs_guard;
    let rhs_data = &*rhs_guard;

    let lhs_s0 = lhs_contig.strides()[0] as isize;
    let lhs_s1 = lhs_contig.strides()[1] as isize;
    let rhs_s0 = rhs_contig.strides()[0] as isize;
    let rhs_s1 = rhs_contig.strides()[1] as isize;
    let kernel = choose_cpu_kernel(
        m, k, n, false, lhs_data, rhs_data, lhs_s0, lhs_s1, rhs_s0, rhs_s1, None,
    );
    let result_data = match kernel {
        CpuKernelChoice::Parallel => matmul_cpu_parallel_core(lhs_data, rhs_data, m, k, n, None),
        CpuKernelChoice::Sgemm => matmul_cpu_sgemm_core(
            lhs_data, rhs_data, m, k, n, lhs_s0, lhs_s1, rhs_s0, rhs_s1, None,
        ),
    };

    let storage = Storage::new(result_data);
    let mut tensor = Tensor::new_with_storage(storage, &[m, n]);

    if lhs.requires_grad() || rhs.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(MatmulBackward {
            lhs: lhs.clone(),
            rhs: rhs.clone(),
        }));
    }

    tensor
}