use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
#[cfg(target_arch = "x86_64")]
use crate::tensor::core::memory::simd_alignment_bytes;
use crate::tensor::core::memory::{detect_runtime_simd, SimdLevel};
use crate::tensor::core::Tensor;
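/// Runtime-selected SIMD kernel table for addition.
///
/// Resolved once per process (see `get_cached_kernels`) so the hot paths
/// dispatch through plain function pointers instead of re-detecting CPU
/// features on every call.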
struct CachedKernels {
    simd_level: SimdLevel,
    alignment: usize,

    tensor_aligned: unsafe fn(*const f32, *const f32, *mut f32, usize),
    tensor_unaligned: unsafe fn(*const f32, *const f32, *mut f32, usize),
    tensor_stream: unsafe fn(*const f32, *const f32, *mut f32, usize),

    scalar_aligned: unsafe fn(*const f32, *mut f32, usize, f32),
    scalar_unaligned: unsafe fn(*const f32, *mut f32, usize, f32),
    scalar_stream: unsafe fn(*const f32, *mut f32, usize, f32),

    min_aligned_size: usize,
    min_stream_size: usize,
}

impl Tensor {
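    /// Element-count threshold above which non-temporal (streaming) stores
    /// are attempted: `1 << 22` f32 elements, i.e. 16 MiB per operand. Past
    /// this size the destination will not fit in a typical last-level cache,
    /// so bypassing the cache avoids polluting it.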
    #[inline]
    pub(crate) fn stream_min_elems() -> usize {
        1 << 22
    }

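    /// Element-wise addition with broadcasting and gradient tracking.
    ///
    /// Shapes that match take the fast same-shape path; otherwise the inputs
    /// are broadcast first. A minimal sketch of the call, mirroring the tests
    /// at the bottom of this file (marked `ignore` since doctest availability
    /// is not assumed here):
    ///
    /// ```ignore
    /// let a = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// let b = Tensor::ones(vec![1, 3]); // broadcast along the first axis
    /// let c = a.add_tensor(&b);         // shape [2, 3], requires_grad = true
    /// ```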
    #[inline]
    #[track_caller]
    pub fn add_tensor(&self, other: &Tensor) -> Tensor {
        if self.shape().dims() == other.shape().dims() {
            return self.add_tensor_same_shape(other);
        }

        let mut result = self.add_tensor_optimized(other);

        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: true,
                original_shapes: Some((
                    self.shape().dims().to_vec(),
                    other.shape().dims().to_vec(),
                )),
            };
            result.set_grad_fn(grad_fn.clone());

            let input_ids = vec![self.id(), other.id()];
            GradEngine::register_operation(result.id(), input_ids, grad_fn);
        }

        result
    }

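    /// Same-shape addition with gradient tracking; panics if the shapes
    /// differ. `original_shapes` is `None` because no broadcasting occurred,
    /// so the backward pass can route gradients through unchanged.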
    #[inline]
    fn add_tensor_same_shape(&self, other: &Tensor) -> Tensor {
        assert_eq!(
            self.shape(),
            other.shape(),
            "Tensor shapes must match for same-shape addition"
        );
        let mut result = self.add_tensor_same_shape_optimized(other);

        if (self.requires_grad() || other.requires_grad()) && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: true,
                original_shapes: None,
            };
            result.set_grad_fn(grad_fn.clone());

            let input_ids = vec![self.id(), other.id()];
            GradEngine::register_operation(result.id(), input_ids, grad_fn);
        }

        result
    }

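    /// Adds a scalar to every element, tracking gradients when enabled.
    ///
    /// A minimal sketch (marked `ignore`; mirrors `test_scalar_addition`
    /// below):
    ///
    /// ```ignore
    /// let t = Tensor::ones(vec![2, 2]);
    /// let r = t.add_scalar(5.0); // every element becomes 6.0
    /// ```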
    #[inline]
    #[track_caller]
    pub fn add_scalar(&self, scalar: f32) -> Tensor {
        let mut result = self.add_scalar_optimized(scalar);

        if self.requires_grad() && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::Add {
                is_tensor_add: false,
                original_shapes: None,
            };
            result.set_grad_fn(grad_fn.clone());
            let input_ids = vec![self.id()];
            GradEngine::register_operation(result.id(), input_ids, grad_fn);
        }

        result
    }

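    /// Forward computation without gradient bookkeeping. Matching shapes go
    /// straight to the same-shape kernel; otherwise both operands are
    /// broadcast (copy-on-write where possible) and then added element-wise.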
    #[inline]
    pub(crate) fn add_tensor_optimized(&self, other: &Tensor) -> Tensor {
        if self.shape() == other.shape() {
            return self.add_tensor_same_shape_optimized(other);
        }

        use crate::tensor::ops::broadcasting::{broadcast_shapes_cow, BroadcastError};

        match broadcast_shapes_cow(self, other) {
            Ok((broadcasted_self, broadcasted_other, _result_shape)) => {
                debug_assert_eq!(
                    broadcasted_self.shape().dims(),
                    broadcasted_other.shape().dims()
                );
                broadcasted_self
                    .as_ref()
                    .add_tensor_same_shape_optimized(broadcasted_other.as_ref())
            }
            Err(BroadcastError::IncompatibleShapes { shape1, shape2, .. }) => {
                panic!(
                    "Cannot broadcast tensor shapes {:?} and {:?}: shapes are incompatible",
                    shape1, shape2
                );
            }
            Err(BroadcastError::AllocationFailed) => {
                panic!("Memory allocation failed during broadcasting");
            }
        }
    }

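    /// Same-shape addition kernel dispatch. Inputs are made contiguous if
    /// needed, then the fastest applicable path is tried in order: streaming
    /// stores for very large buffers, then the cached SIMD kernels, then the
    /// unrolled scalar fallback.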
    #[inline]
    fn add_tensor_same_shape_optimized(&self, other: &Tensor) -> Tensor {
        debug_assert_eq!(
            self.shape().dims(),
            other.shape().dims(),
            "Tensor dims must match"
        );

        let (a_ptr, _a_keep): (*const f32, Option<Tensor>) = Self::get_optimized_tensor_ptr(self);
        let (b_ptr, _b_keep): (*const f32, Option<Tensor>) = Self::get_optimized_tensor_ptr(other);

        let mut output = Tensor::new(self.shape().dims().to_vec());

        unsafe {
            let a = a_ptr;
            let b = b_ptr;
            let dst = output.as_mut_ptr();
            let n = self.size();

            let stream_min = Self::stream_min_elems();
            let streamed = n >= stream_min && Self::try_add_stream_best(a, b, dst, n);
            if !streamed && !Self::try_add_simd_best(a, b, dst, n) {
                Self::add_tensors_scalar_chunk(a, b, dst, n);
            }
        }

        output
    }

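    /// Attempts the cached SIMD kernels. Returns `false` when only the
    /// scalar level is available so the caller can fall back. Three cases:
    /// all three pointers aligned (aligned kernel), all three equally
    /// misaligned (scalar prologue up to the alignment boundary, then the
    /// aligned kernel), otherwise the unaligned kernel.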
    #[inline]
    unsafe fn try_add_simd_best(a: *const f32, b: *const f32, dst: *mut f32, size: usize) -> bool {
        if size == 0 {
            return true;
        }

        let kernels = Self::get_cached_kernels();

        if matches!(kernels.simd_level, SimdLevel::Scalar) {
            return false;
        }

        let a_mod = (a as usize) % kernels.alignment;
        let b_mod = (b as usize) % kernels.alignment;
        let d_mod = (dst as usize) % kernels.alignment;

        if a_mod == 0 && b_mod == 0 && d_mod == 0 && size >= kernels.min_aligned_size {
            (kernels.tensor_aligned)(a, b, dst, size);
            return true;
        }

        if a_mod == b_mod && b_mod == d_mod && size >= kernels.min_aligned_size {
            let bytes_to_align = if a_mod == 0 {
                0
            } else {
                kernels.alignment - a_mod
            };
            let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(size);

            for i in 0..elems_to_align {
                *dst.add(i) = *a.add(i) + *b.add(i);
            }

            let rem = size - elems_to_align;
            if rem >= kernels.min_aligned_size {
                (kernels.tensor_aligned)(
                    a.add(elems_to_align),
                    b.add(elems_to_align),
                    dst.add(elems_to_align),
                    rem,
                );
            } else {
                // Tail shorter than the SIMD minimum: finish it scalar so no
                // element is left unwritten.
                for i in elems_to_align..size {
                    *dst.add(i) = *a.add(i) + *b.add(i);
                }
            }
            return true;
        }

        (kernels.tensor_unaligned)(a, b, dst, size);
        true
    }

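    /// Attempts the non-temporal (streaming) store kernel. Only used when
    /// the buffer exceeds `min_stream_size` and the destination is aligned;
    /// sources may be unaligned since the stream kernels use unaligned loads.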
    #[inline]
    unsafe fn try_add_stream_best(
        a: *const f32,
        b: *const f32,
        dst: *mut f32,
        size: usize,
    ) -> bool {
        let kernels = Self::get_cached_kernels();

        if size < kernels.min_stream_size || size == 0 {
            return false;
        }

        if (dst as usize).is_multiple_of(kernels.alignment) {
            (kernels.tensor_stream)(a, b, dst, size);
            return true;
        }

        false
    }

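    // The SIMD kernels below all share one shape: a 4x-unrolled main loop
    // over `block` elements, a single-vector loop over the remainder, and a
    // scalar loop for the final partial vector. Only the load/store flavor
    // (aligned, unaligned, or non-temporal) and the vector width differ.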
    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn add_simd_avx512_aligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 64usize;
        while offset + block <= size {
            let a1 = _mm512_load_ps(a.add(offset));
            let b1 = _mm512_load_ps(b.add(offset));
            _mm512_store_ps(dst.add(offset), _mm512_add_ps(a1, b1));

            let a2 = _mm512_load_ps(a.add(offset + 16));
            let b2 = _mm512_load_ps(b.add(offset + 16));
            _mm512_store_ps(dst.add(offset + 16), _mm512_add_ps(a2, b2));

            let a3 = _mm512_load_ps(a.add(offset + 32));
            let b3 = _mm512_load_ps(b.add(offset + 32));
            _mm512_store_ps(dst.add(offset + 32), _mm512_add_ps(a3, b3));

            let a4 = _mm512_load_ps(a.add(offset + 48));
            let b4 = _mm512_load_ps(b.add(offset + 48));
            _mm512_store_ps(dst.add(offset + 48), _mm512_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 16 {
            let av = _mm512_load_ps(a.add(offset));
            let bv = _mm512_load_ps(b.add(offset));
            _mm512_store_ps(dst.add(offset), _mm512_add_ps(av, bv));
            offset += 16;
            rem -= 16;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn add_simd_avx512_unaligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 64usize;
        while offset + block <= size {
            let a1 = _mm512_loadu_ps(a.add(offset));
            let b1 = _mm512_loadu_ps(b.add(offset));
            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(a1, b1));

            let a2 = _mm512_loadu_ps(a.add(offset + 16));
            let b2 = _mm512_loadu_ps(b.add(offset + 16));
            _mm512_storeu_ps(dst.add(offset + 16), _mm512_add_ps(a2, b2));

            let a3 = _mm512_loadu_ps(a.add(offset + 32));
            let b3 = _mm512_loadu_ps(b.add(offset + 32));
            _mm512_storeu_ps(dst.add(offset + 32), _mm512_add_ps(a3, b3));

            let a4 = _mm512_loadu_ps(a.add(offset + 48));
            let b4 = _mm512_loadu_ps(b.add(offset + 48));
            _mm512_storeu_ps(dst.add(offset + 48), _mm512_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 16 {
            let av = _mm512_loadu_ps(a.add(offset));
            let bv = _mm512_loadu_ps(b.add(offset));
            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(av, bv));
            offset += 16;
            rem -= 16;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn add_simd_avx512_stream(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 64usize;
        while offset + block <= size {
            let a1 = _mm512_loadu_ps(a.add(offset));
            let b1 = _mm512_loadu_ps(b.add(offset));
            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(a1, b1));

            let a2 = _mm512_loadu_ps(a.add(offset + 16));
            let b2 = _mm512_loadu_ps(b.add(offset + 16));
            _mm512_stream_ps(dst.add(offset + 16), _mm512_add_ps(a2, b2));

            let a3 = _mm512_loadu_ps(a.add(offset + 32));
            let b3 = _mm512_loadu_ps(b.add(offset + 32));
            _mm512_stream_ps(dst.add(offset + 32), _mm512_add_ps(a3, b3));

            let a4 = _mm512_loadu_ps(a.add(offset + 48));
            let b4 = _mm512_loadu_ps(b.add(offset + 48));
            _mm512_stream_ps(dst.add(offset + 48), _mm512_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 16 {
            let av = _mm512_loadu_ps(a.add(offset));
            let bv = _mm512_loadu_ps(b.add(offset));
            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(av, bv));
            offset += 16;
            rem -= 16;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_simd_avx2_aligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 32usize;
        while offset + block <= size {
            let a1 = _mm256_load_ps(a.add(offset));
            let b1 = _mm256_load_ps(b.add(offset));
            _mm256_store_ps(dst.add(offset), _mm256_add_ps(a1, b1));

            let a2 = _mm256_load_ps(a.add(offset + 8));
            let b2 = _mm256_load_ps(b.add(offset + 8));
            _mm256_store_ps(dst.add(offset + 8), _mm256_add_ps(a2, b2));

            let a3 = _mm256_load_ps(a.add(offset + 16));
            let b3 = _mm256_load_ps(b.add(offset + 16));
            _mm256_store_ps(dst.add(offset + 16), _mm256_add_ps(a3, b3));

            let a4 = _mm256_load_ps(a.add(offset + 24));
            let b4 = _mm256_load_ps(b.add(offset + 24));
            _mm256_store_ps(dst.add(offset + 24), _mm256_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 8 {
            let av = _mm256_load_ps(a.add(offset));
            let bv = _mm256_load_ps(b.add(offset));
            _mm256_store_ps(dst.add(offset), _mm256_add_ps(av, bv));
            offset += 8;
            rem -= 8;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_simd_avx2_unaligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 32usize;
        while offset + block <= size {
            let a1 = _mm256_loadu_ps(a.add(offset));
            let b1 = _mm256_loadu_ps(b.add(offset));
            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(a1, b1));

            let a2 = _mm256_loadu_ps(a.add(offset + 8));
            let b2 = _mm256_loadu_ps(b.add(offset + 8));
            _mm256_storeu_ps(dst.add(offset + 8), _mm256_add_ps(a2, b2));

            let a3 = _mm256_loadu_ps(a.add(offset + 16));
            let b3 = _mm256_loadu_ps(b.add(offset + 16));
            _mm256_storeu_ps(dst.add(offset + 16), _mm256_add_ps(a3, b3));

            let a4 = _mm256_loadu_ps(a.add(offset + 24));
            let b4 = _mm256_loadu_ps(b.add(offset + 24));
            _mm256_storeu_ps(dst.add(offset + 24), _mm256_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 8 {
            let av = _mm256_loadu_ps(a.add(offset));
            let bv = _mm256_loadu_ps(b.add(offset));
            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(av, bv));
            offset += 8;
            rem -= 8;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_simd_avx2_stream(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 32usize;
        while offset + block <= size {
            let a1 = _mm256_loadu_ps(a.add(offset));
            let b1 = _mm256_loadu_ps(b.add(offset));
            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(a1, b1));

            let a2 = _mm256_loadu_ps(a.add(offset + 8));
            let b2 = _mm256_loadu_ps(b.add(offset + 8));
            _mm256_stream_ps(dst.add(offset + 8), _mm256_add_ps(a2, b2));

            let a3 = _mm256_loadu_ps(a.add(offset + 16));
            let b3 = _mm256_loadu_ps(b.add(offset + 16));
            _mm256_stream_ps(dst.add(offset + 16), _mm256_add_ps(a3, b3));

            let a4 = _mm256_loadu_ps(a.add(offset + 24));
            let b4 = _mm256_loadu_ps(b.add(offset + 24));
            _mm256_stream_ps(dst.add(offset + 24), _mm256_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 8 {
            let av = _mm256_loadu_ps(a.add(offset));
            let bv = _mm256_loadu_ps(b.add(offset));
            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(av, bv));
            offset += 8;
            rem -= 8;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn add_simd_sse_aligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 16usize;
        while offset + block <= size {
            let a1 = _mm_load_ps(a.add(offset));
            let b1 = _mm_load_ps(b.add(offset));
            _mm_store_ps(dst.add(offset), _mm_add_ps(a1, b1));

            let a2 = _mm_load_ps(a.add(offset + 4));
            let b2 = _mm_load_ps(b.add(offset + 4));
            _mm_store_ps(dst.add(offset + 4), _mm_add_ps(a2, b2));

            let a3 = _mm_load_ps(a.add(offset + 8));
            let b3 = _mm_load_ps(b.add(offset + 8));
            _mm_store_ps(dst.add(offset + 8), _mm_add_ps(a3, b3));

            let a4 = _mm_load_ps(a.add(offset + 12));
            let b4 = _mm_load_ps(b.add(offset + 12));
            _mm_store_ps(dst.add(offset + 12), _mm_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 4 {
            let av = _mm_load_ps(a.add(offset));
            let bv = _mm_load_ps(b.add(offset));
            _mm_store_ps(dst.add(offset), _mm_add_ps(av, bv));
            offset += 4;
            rem -= 4;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn add_simd_sse_unaligned(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 16usize;
        while offset + block <= size {
            let a1 = _mm_loadu_ps(a.add(offset));
            let b1 = _mm_loadu_ps(b.add(offset));
            _mm_storeu_ps(dst.add(offset), _mm_add_ps(a1, b1));

            let a2 = _mm_loadu_ps(a.add(offset + 4));
            let b2 = _mm_loadu_ps(b.add(offset + 4));
            _mm_storeu_ps(dst.add(offset + 4), _mm_add_ps(a2, b2));

            let a3 = _mm_loadu_ps(a.add(offset + 8));
            let b3 = _mm_loadu_ps(b.add(offset + 8));
            _mm_storeu_ps(dst.add(offset + 8), _mm_add_ps(a3, b3));

            let a4 = _mm_loadu_ps(a.add(offset + 12));
            let b4 = _mm_loadu_ps(b.add(offset + 12));
            _mm_storeu_ps(dst.add(offset + 12), _mm_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 4 {
            let s = _mm_loadu_ps(a.add(offset));
            let t = _mm_loadu_ps(b.add(offset));
            _mm_storeu_ps(dst.add(offset), _mm_add_ps(s, t));
            offset += 4;
            rem -= 4;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn add_simd_sse_stream(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        use std::arch::x86_64::*;
        let mut offset = 0usize;
        let block = 16usize;
        while offset + block <= size {
            let a1 = _mm_loadu_ps(a.add(offset));
            let b1 = _mm_loadu_ps(b.add(offset));
            _mm_stream_ps(dst.add(offset), _mm_add_ps(a1, b1));

            let a2 = _mm_loadu_ps(a.add(offset + 4));
            let b2 = _mm_loadu_ps(b.add(offset + 4));
            _mm_stream_ps(dst.add(offset + 4), _mm_add_ps(a2, b2));

            let a3 = _mm_loadu_ps(a.add(offset + 8));
            let b3 = _mm_loadu_ps(b.add(offset + 8));
            _mm_stream_ps(dst.add(offset + 8), _mm_add_ps(a3, b3));

            let a4 = _mm_loadu_ps(a.add(offset + 12));
            let b4 = _mm_loadu_ps(b.add(offset + 12));
            _mm_stream_ps(dst.add(offset + 12), _mm_add_ps(a4, b4));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 4 {
            let s = _mm_loadu_ps(a.add(offset));
            let t = _mm_loadu_ps(b.add(offset));
            _mm_stream_ps(dst.add(offset), _mm_add_ps(s, t));
            offset += 4;
            rem -= 4;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

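    /// Scalar-broadcast counterpart of `try_add_simd_best`: same three
    /// alignment cases, with only two pointers (source and destination) to
    /// reconcile.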
    #[inline]
    unsafe fn try_add_scalar_simd_best(
        src: *const f32,
        dst: *mut f32,
        size: usize,
        scalar: f32,
    ) -> bool {
        if size == 0 {
            return true;
        }

        let kernels = Self::get_cached_kernels();

        if matches!(kernels.simd_level, SimdLevel::Scalar) {
            return false;
        }

        let s_mod = (src as usize) % kernels.alignment;
        let d_mod = (dst as usize) % kernels.alignment;

        if s_mod == 0 && d_mod == 0 && size >= kernels.min_aligned_size {
            (kernels.scalar_aligned)(src, dst, size, scalar);
            return true;
        }

        if s_mod == d_mod && size >= kernels.min_aligned_size {
            let bytes_to_align = if s_mod == 0 {
                0
            } else {
                kernels.alignment - s_mod
            };
            let elems_to_align = (bytes_to_align / std::mem::size_of::<f32>()).min(size);

            for i in 0..elems_to_align {
                *dst.add(i) = *src.add(i) + scalar;
            }

            let rem = size - elems_to_align;
            if rem >= kernels.min_aligned_size {
                (kernels.scalar_aligned)(
                    src.add(elems_to_align),
                    dst.add(elems_to_align),
                    rem,
                    scalar,
                );
            } else {
                // Tail shorter than the SIMD minimum: finish it scalar so no
                // element is left unwritten.
                for i in elems_to_align..size {
                    *dst.add(i) = *src.add(i) + scalar;
                }
            }
            return true;
        }

        (kernels.scalar_unaligned)(src, dst, size, scalar);
        true
    }

    #[inline]
    unsafe fn try_add_scalar_stream_best(
        src: *const f32,
        dst: *mut f32,
        size: usize,
        scalar: f32,
    ) -> bool {
        let kernels = Self::get_cached_kernels();

        if size < kernels.min_stream_size || size == 0 {
            return false;
        }

        if (dst as usize).is_multiple_of(kernels.alignment) {
            (kernels.scalar_stream)(src, dst, size, scalar);
            return true;
        }

        false
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn add_scalar_avx512_aligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm512_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 64usize;
        while offset + block <= size {
            let s1 = _mm512_load_ps(src.add(offset));
            _mm512_store_ps(dst.add(offset), _mm512_add_ps(s1, sv));
            let s2 = _mm512_load_ps(src.add(offset + 16));
            _mm512_store_ps(dst.add(offset + 16), _mm512_add_ps(s2, sv));
            let s3 = _mm512_load_ps(src.add(offset + 32));
            _mm512_store_ps(dst.add(offset + 32), _mm512_add_ps(s3, sv));
            let s4 = _mm512_load_ps(src.add(offset + 48));
            _mm512_store_ps(dst.add(offset + 48), _mm512_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 16 {
            let s = _mm512_load_ps(src.add(offset));
            _mm512_store_ps(dst.add(offset), _mm512_add_ps(s, sv));
            offset += 16;
            rem -= 16;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn add_scalar_avx512_unaligned(
        src: *const f32,
        dst: *mut f32,
        size: usize,
        scalar: f32,
    ) {
        use std::arch::x86_64::*;
        let sv = _mm512_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 64usize;
        while offset + block <= size {
            let s1 = _mm512_loadu_ps(src.add(offset));
            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(s1, sv));
            let s2 = _mm512_loadu_ps(src.add(offset + 16));
            _mm512_storeu_ps(dst.add(offset + 16), _mm512_add_ps(s2, sv));
            let s3 = _mm512_loadu_ps(src.add(offset + 32));
            _mm512_storeu_ps(dst.add(offset + 32), _mm512_add_ps(s3, sv));
            let s4 = _mm512_loadu_ps(src.add(offset + 48));
            _mm512_storeu_ps(dst.add(offset + 48), _mm512_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 16 {
            let s = _mm512_loadu_ps(src.add(offset));
            _mm512_storeu_ps(dst.add(offset), _mm512_add_ps(s, sv));
            offset += 16;
            rem -= 16;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx512f")]
    unsafe fn add_scalar_avx512_stream(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm512_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 64usize;
        while offset + block <= size {
            let s1 = _mm512_loadu_ps(src.add(offset));
            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(s1, sv));
            let s2 = _mm512_loadu_ps(src.add(offset + 16));
            _mm512_stream_ps(dst.add(offset + 16), _mm512_add_ps(s2, sv));
            let s3 = _mm512_loadu_ps(src.add(offset + 32));
            _mm512_stream_ps(dst.add(offset + 32), _mm512_add_ps(s3, sv));
            let s4 = _mm512_loadu_ps(src.add(offset + 48));
            _mm512_stream_ps(dst.add(offset + 48), _mm512_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 16 {
            let s = _mm512_loadu_ps(src.add(offset));
            _mm512_stream_ps(dst.add(offset), _mm512_add_ps(s, sv));
            offset += 16;
            rem -= 16;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_scalar_avx2_aligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm256_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 32usize;
        while offset + block <= size {
            let s1 = _mm256_load_ps(src.add(offset));
            _mm256_store_ps(dst.add(offset), _mm256_add_ps(s1, sv));
            let s2 = _mm256_load_ps(src.add(offset + 8));
            _mm256_store_ps(dst.add(offset + 8), _mm256_add_ps(s2, sv));
            let s3 = _mm256_load_ps(src.add(offset + 16));
            _mm256_store_ps(dst.add(offset + 16), _mm256_add_ps(s3, sv));
            let s4 = _mm256_load_ps(src.add(offset + 24));
            _mm256_store_ps(dst.add(offset + 24), _mm256_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 8 {
            let s = _mm256_load_ps(src.add(offset));
            _mm256_store_ps(dst.add(offset), _mm256_add_ps(s, sv));
            offset += 8;
            rem -= 8;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_scalar_avx2_unaligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm256_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 32usize;
        while offset + block <= size {
            let s1 = _mm256_loadu_ps(src.add(offset));
            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(s1, sv));
            let s2 = _mm256_loadu_ps(src.add(offset + 8));
            _mm256_storeu_ps(dst.add(offset + 8), _mm256_add_ps(s2, sv));
            let s3 = _mm256_loadu_ps(src.add(offset + 16));
            _mm256_storeu_ps(dst.add(offset + 16), _mm256_add_ps(s3, sv));
            let s4 = _mm256_loadu_ps(src.add(offset + 24));
            _mm256_storeu_ps(dst.add(offset + 24), _mm256_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 8 {
            let s = _mm256_loadu_ps(src.add(offset));
            _mm256_storeu_ps(dst.add(offset), _mm256_add_ps(s, sv));
            offset += 8;
            rem -= 8;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add_scalar_avx2_stream(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm256_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 32usize;
        while offset + block <= size {
            let s1 = _mm256_loadu_ps(src.add(offset));
            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(s1, sv));
            let s2 = _mm256_loadu_ps(src.add(offset + 8));
            _mm256_stream_ps(dst.add(offset + 8), _mm256_add_ps(s2, sv));
            let s3 = _mm256_loadu_ps(src.add(offset + 16));
            _mm256_stream_ps(dst.add(offset + 16), _mm256_add_ps(s3, sv));
            let s4 = _mm256_loadu_ps(src.add(offset + 24));
            _mm256_stream_ps(dst.add(offset + 24), _mm256_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 8 {
            let s = _mm256_loadu_ps(src.add(offset));
            _mm256_stream_ps(dst.add(offset), _mm256_add_ps(s, sv));
            offset += 8;
            rem -= 8;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn add_scalar_sse_aligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 16usize;
        while offset + block <= size {
            let s1 = _mm_load_ps(src.add(offset));
            _mm_store_ps(dst.add(offset), _mm_add_ps(s1, sv));
            let s2 = _mm_load_ps(src.add(offset + 4));
            _mm_store_ps(dst.add(offset + 4), _mm_add_ps(s2, sv));
            let s3 = _mm_load_ps(src.add(offset + 8));
            _mm_store_ps(dst.add(offset + 8), _mm_add_ps(s3, sv));
            let s4 = _mm_load_ps(src.add(offset + 12));
            _mm_store_ps(dst.add(offset + 12), _mm_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 4 {
            let s = _mm_load_ps(src.add(offset));
            _mm_store_ps(dst.add(offset), _mm_add_ps(s, sv));
            offset += 4;
            rem -= 4;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn add_scalar_sse_unaligned(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 16usize;
        while offset + block <= size {
            let s1 = _mm_loadu_ps(src.add(offset));
            _mm_storeu_ps(dst.add(offset), _mm_add_ps(s1, sv));
            let s2 = _mm_loadu_ps(src.add(offset + 4));
            _mm_storeu_ps(dst.add(offset + 4), _mm_add_ps(s2, sv));
            let s3 = _mm_loadu_ps(src.add(offset + 8));
            _mm_storeu_ps(dst.add(offset + 8), _mm_add_ps(s3, sv));
            let s4 = _mm_loadu_ps(src.add(offset + 12));
            _mm_storeu_ps(dst.add(offset + 12), _mm_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 4 {
            let s = _mm_loadu_ps(src.add(offset));
            _mm_storeu_ps(dst.add(offset), _mm_add_ps(s, sv));
            offset += 4;
            rem -= 4;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn add_scalar_sse_stream(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        use std::arch::x86_64::*;
        let sv = _mm_set1_ps(scalar);
        let mut offset = 0usize;
        let block = 16usize;
        while offset + block <= size {
            let s1 = _mm_loadu_ps(src.add(offset));
            _mm_stream_ps(dst.add(offset), _mm_add_ps(s1, sv));
            let s2 = _mm_loadu_ps(src.add(offset + 4));
            _mm_stream_ps(dst.add(offset + 4), _mm_add_ps(s2, sv));
            let s3 = _mm_loadu_ps(src.add(offset + 8));
            _mm_stream_ps(dst.add(offset + 8), _mm_add_ps(s3, sv));
            let s4 = _mm_loadu_ps(src.add(offset + 12));
            _mm_stream_ps(dst.add(offset + 12), _mm_add_ps(s4, sv));
            offset += block;
        }
        let mut rem = size - offset;
        while rem >= 4 {
            let s = _mm_loadu_ps(src.add(offset));
            _mm_stream_ps(dst.add(offset), _mm_add_ps(s, sv));
            offset += 4;
            rem -= 4;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

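    /// Portable fallback: 8x-unrolled scalar loop plus a remainder loop.
    /// Also serves as every kernel slot on targets without SIMD support.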
    #[inline]
    unsafe fn add_tensors_scalar_chunk(a: *const f32, b: *const f32, dst: *mut f32, size: usize) {
        let unroll_count = size / 8;
        let mut offset = 0;

        for _ in 0..unroll_count {
            *dst.add(offset) = *a.add(offset) + *b.add(offset);
            *dst.add(offset + 1) = *a.add(offset + 1) + *b.add(offset + 1);
            *dst.add(offset + 2) = *a.add(offset + 2) + *b.add(offset + 2);
            *dst.add(offset + 3) = *a.add(offset + 3) + *b.add(offset + 3);
            *dst.add(offset + 4) = *a.add(offset + 4) + *b.add(offset + 4);
            *dst.add(offset + 5) = *a.add(offset + 5) + *b.add(offset + 5);
            *dst.add(offset + 6) = *a.add(offset + 6) + *b.add(offset + 6);
            *dst.add(offset + 7) = *a.add(offset + 7) + *b.add(offset + 7);
            offset += 8;
        }
        for i in offset..size {
            *dst.add(i) = *a.add(i) + *b.add(i);
        }
    }

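    /// Forward scalar-add without gradient bookkeeping; same dispatch order
    /// as the tensor path: streaming stores, then SIMD, then the unrolled
    /// scalar fallback.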
    #[inline]
    pub(crate) fn add_scalar_optimized(&self, scalar: f32) -> Tensor {
        let (src_ptr, _src_keep): (*const f32, Option<Tensor>) =
            Self::get_optimized_tensor_ptr(self);
        let mut output = Tensor::new(self.shape().dims().to_vec());

        unsafe {
            let src = src_ptr;
            let dst = output.as_mut_ptr();
            let n = self.size();

            let stream_min = Self::stream_min_elems();
            let streamed = n >= stream_min && Self::try_add_scalar_stream_best(src, dst, n, scalar);
            if !streamed && !Self::try_add_scalar_simd_best(src, dst, n, scalar) {
                Self::add_scalar_fallback_chunk(src, dst, n, scalar);
            }
        }

        output
    }

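    /// Scalar-broadcast counterpart of `add_tensors_scalar_chunk`.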
    #[inline]
    unsafe fn add_scalar_fallback_chunk(src: *const f32, dst: *mut f32, size: usize, scalar: f32) {
        let unroll_count = size / 8;
        let mut offset = 0;

        for _ in 0..unroll_count {
            *dst.add(offset) = *src.add(offset) + scalar;
            *dst.add(offset + 1) = *src.add(offset + 1) + scalar;
            *dst.add(offset + 2) = *src.add(offset + 2) + scalar;
            *dst.add(offset + 3) = *src.add(offset + 3) + scalar;
            *dst.add(offset + 4) = *src.add(offset + 4) + scalar;
            *dst.add(offset + 5) = *src.add(offset + 5) + scalar;
            *dst.add(offset + 6) = *src.add(offset + 6) + scalar;
            *dst.add(offset + 7) = *src.add(offset + 7) + scalar;
            offset += 8;
        }
        for i in offset..size {
            *dst.add(i) = *src.add(i) + scalar;
        }
    }

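    /// Returns a raw pointer to contiguous data. For non-contiguous tensors
    /// a contiguous copy is materialized and returned as the second tuple
    /// element so the caller can keep it alive for the pointer's lifetime.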
    #[inline]
    fn get_optimized_tensor_ptr(tensor: &Tensor) -> (*const f32, Option<Tensor>) {
        unsafe {
            if tensor.is_contiguous() {
                (tensor.as_ptr(), None)
            } else {
                let tmp = tensor.contiguous();
                (tmp.as_ptr(), Some(tmp))
            }
        }
    }

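    /// Stride-aware fast path is not implemented yet: this always returns
    /// `false`, so every non-contiguous input goes through `contiguous()`.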
    #[inline]
    #[allow(dead_code)]
    fn can_use_stride_based_access(_tensor: &Tensor) -> bool {
        false
    }

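    /// Detects the best available SIMD level once (via `OnceLock`) and
    /// returns the matching kernel table. On non-x86_64 targets, and at the
    /// `Scalar` level on x86_64, every slot points at the portable fallbacks
    /// and streaming is disabled (`min_stream_size == usize::MAX`).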
    #[inline]
    fn get_cached_kernels() -> &'static CachedKernels {
        use std::sync::OnceLock;

        static CACHED_KERNELS: OnceLock<CachedKernels> = OnceLock::new();

        CACHED_KERNELS.get_or_init(|| {
            let simd_level = detect_runtime_simd();
            #[cfg(target_arch = "x86_64")]
            let alignment = simd_alignment_bytes(simd_level);

            #[cfg(target_arch = "x86_64")]
            {
                match simd_level {
                    SimdLevel::Avx512 => CachedKernels {
                        simd_level,
                        alignment,
                        tensor_aligned: Self::add_simd_avx512_aligned,
                        tensor_unaligned: Self::add_simd_avx512_unaligned,
                        tensor_stream: Self::add_simd_avx512_stream,
                        scalar_aligned: Self::add_scalar_avx512_aligned,
                        scalar_unaligned: Self::add_scalar_avx512_unaligned,
                        scalar_stream: Self::add_scalar_avx512_stream,
                        min_aligned_size: 16,
                        min_stream_size: Self::stream_min_elems(),
                    },
                    SimdLevel::Avx2 => CachedKernels {
                        simd_level,
                        alignment,
                        tensor_aligned: Self::add_simd_avx2_aligned,
                        tensor_unaligned: Self::add_simd_avx2_unaligned,
                        tensor_stream: Self::add_simd_avx2_stream,
                        scalar_aligned: Self::add_scalar_avx2_aligned,
                        scalar_unaligned: Self::add_scalar_avx2_unaligned,
                        scalar_stream: Self::add_scalar_avx2_stream,
                        min_aligned_size: 8,
                        min_stream_size: Self::stream_min_elems(),
                    },
                    SimdLevel::Sse2 => CachedKernels {
                        simd_level,
                        alignment,
                        tensor_aligned: Self::add_simd_sse_aligned,
                        tensor_unaligned: Self::add_simd_sse_unaligned,
                        tensor_stream: Self::add_simd_sse_stream,
                        scalar_aligned: Self::add_scalar_sse_aligned,
                        scalar_unaligned: Self::add_scalar_sse_unaligned,
                        scalar_stream: Self::add_scalar_sse_stream,
                        min_aligned_size: 4,
                        min_stream_size: Self::stream_min_elems(),
                    },
                    SimdLevel::Scalar => CachedKernels {
                        simd_level,
                        alignment: 4,
                        tensor_aligned: Self::add_tensors_scalar_chunk,
                        tensor_unaligned: Self::add_tensors_scalar_chunk,
                        tensor_stream: Self::add_tensors_scalar_chunk,
                        scalar_aligned: Self::add_scalar_fallback_chunk,
                        scalar_unaligned: Self::add_scalar_fallback_chunk,
                        scalar_stream: Self::add_scalar_fallback_chunk,
                        min_aligned_size: 1,
                        min_stream_size: usize::MAX,
                    },
                }
            }

            #[cfg(not(target_arch = "x86_64"))]
            {
                CachedKernels {
                    simd_level,
                    alignment: 4,
                    tensor_aligned: Self::add_tensors_scalar_chunk,
                    tensor_unaligned: Self::add_tensors_scalar_chunk,
                    tensor_stream: Self::add_tensors_scalar_chunk,
                    scalar_aligned: Self::add_scalar_fallback_chunk,
                    scalar_unaligned: Self::add_scalar_fallback_chunk,
                    scalar_stream: Self::add_scalar_fallback_chunk,
                    min_aligned_size: 1,
                    min_stream_size: usize::MAX,
                }
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_tensor_addition() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![2, 3]);
        let result = a.add_tensor_optimized(&b);

        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.size(), 6);

        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 2.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn test_thread_safety_cross_thread_ops() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let a = Arc::new(Tensor::ones(vec![2, 3]).with_requires_grad());
        let b = Arc::new(Tensor::ones(vec![2, 3]).with_requires_grad());

        let a1 = a.clone();
        let b1 = b.clone();
        let handle = thread::spawn(move || {
            let t_local1 = Tensor::from_slice(&[2.0; 6], vec![2, 3]).unwrap();
            let r1 = (*a1).add_tensor(&t_local1);
            let t_local2 = Tensor::ones(vec![2, 3]);
            let r2 = t_local2.add_tensor(&b1);
            let combined = r1.add_tensor(&r2);
            let mut loss = combined.sum();
            loss.backward(None);

            let ga = (*a1).grad_owned().expect("grad for a (thread)");
            let gb = (*b1).grad_owned().expect("grad for b (thread)");
            let ga_sum = unsafe { (0..ga.size()).map(|i| *ga.as_ptr().add(i)).sum::<f32>() };
            let gb_sum = unsafe { (0..gb.size()).map(|i| *gb.as_ptr().add(i)).sum::<f32>() };
            (
                ga.shape().dims().to_vec(),
                gb.shape().dims().to_vec(),
                ga_sum,
                gb_sum,
            )
        });

        let (ga_dims, gb_dims, ga_sum, gb_sum) = handle.join().expect("worker panicked");
        assert_eq!(ga_dims, vec![2, 3]);
        assert_eq!(gb_dims, vec![2, 3]);
        assert!((ga_sum - 6.0).abs() < 1e-6);
        assert!((gb_sum - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_thread_safety_parallel_large_add_backward() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let n = 8_388_608; // 1 << 23 elements: large enough to take the streaming-store path
        let a = Arc::new(Tensor::ones(vec![n]).with_requires_grad());
        let b = Arc::new(Tensor::ones(vec![n]).with_requires_grad());

        let at = a.clone();
        let bt = b.clone();
        let handle = thread::spawn(move || {
            let result = (*at).add_tensor(&bt);
            let mut loss = result.sum();
            loss.backward(None);
            let ga = (*at).grad_owned().expect("grad for a (thread)");
            let gb = (*bt).grad_owned().expect("grad for b (thread)");
            let ga_sum = unsafe { (0..ga.size()).map(|i| *ga.as_ptr().add(i)).sum::<f32>() };
            let gb_sum = unsafe { (0..gb.size()).map(|i| *gb.as_ptr().add(i)).sum::<f32>() };
            (
                ga.shape().dims().to_vec(),
                gb.shape().dims().to_vec(),
                ga_sum,
                gb_sum,
            )
        });

        let (ga_dims, gb_dims, ga_sum, gb_sum) = handle.join().expect("worker thread panicked");
        assert_eq!(ga_dims, vec![n]);
        assert_eq!(gb_dims, vec![n]);
        assert!((ga_sum - n as f32).abs() < 1e-3);
        assert!((gb_sum - n as f32).abs() < 1e-3);
    }

    #[test]
    fn test_scalar_addition() {
        let tensor = Tensor::ones(vec![2, 2]);
        let result = tensor.add_scalar_optimized(5.0);

        assert_eq!(result.shape().dims(), vec![2, 2]);
        assert_eq!(result.size(), 4);

        unsafe {
            for i in 0..result.size() {
                assert!((result.as_ptr().add(i).read() - 6.0).abs() < 1e-6);
            }
        }
    }

    #[test]
    #[should_panic(expected = "Cannot broadcast tensor shapes")]
    fn test_mismatched_shapes() {
        let a = Tensor::ones(vec![2, 3]);
        let b = Tensor::ones(vec![3, 2]);
        a.add_tensor_optimized(&b);
    }

    #[test]
    fn test_add_with_no_grad_guard() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        assert!(is_grad_enabled());

        let c1 = a.add_tensor(&b);
        assert!(
            c1.requires_grad(),
            "Result should require gradients normally"
        );

        {
            let _guard = NoGradTrack::new();
            assert!(
                !is_grad_enabled(),
                "Gradients should be disabled within guard"
            );

            let c2 = a.add_tensor(&b);
            assert!(
                !c2.requires_grad(),
                "Result should not require gradients within NoGradTrack"
            );

            let c3 = a.add_scalar(5.0);
            assert!(
                !c3.requires_grad(),
                "Scalar addition result should not require gradients within NoGradTrack"
            );
        }

        assert!(
            is_grad_enabled(),
            "Gradients should be restored after guard"
        );

        let c4 = a.add_tensor(&b);
        assert!(
            c4.requires_grad(),
            "Result should require gradients after guard is dropped"
        );
    }

    #[test]
    fn test_add_nested_no_grad_guards() {
        use crate::gradtrack::{is_grad_enabled, NoGradTrack};

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]).with_requires_grad();

        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            let c1 = a.add_tensor(&b);
            assert!(!c1.requires_grad());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                let c2 = a.add_tensor(&b);
                assert!(!c2.requires_grad());
            }

            assert!(!is_grad_enabled());
            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());
        }

        assert!(is_grad_enabled());
        let c4 = a.add_tensor(&b);
        assert!(c4.requires_grad());
    }

    #[test]
    fn test_add_with_mixed_requires_grad() {
        use crate::gradtrack::NoGradTrack;

        let a = Tensor::ones(vec![2, 2]).with_requires_grad();
        let b = Tensor::ones(vec![2, 2]);

        let c1 = a.add_tensor(&b);
        assert!(c1.requires_grad());

        let c2 = b.add_tensor(&a);
        assert!(c2.requires_grad());

        {
            let _guard = NoGradTrack::new();

            let c3 = a.add_tensor(&b);
            assert!(!c3.requires_grad());

            let c4 = b.add_tensor(&a);
            assert!(!c4.requires_grad());
        }
    }

    #[test]
    fn test_broadcasting_gradients_basic() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.1, 0.2, 0.3], vec![1, 3])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        assert_eq!(result.shape().dims(), vec![2, 3]);

        result.backward(None);

        let grad_a = a.grad_owned().expect("grad_a should exist");
        let grad_b = b.grad_owned().expect("grad_b should exist");

        println!(
            "Original shapes: a={:?}, b={:?}",
            a.shape().dims(),
            b.shape().dims()
        );
        println!(
            "Gradient shapes: grad_a={:?}, grad_b={:?}",
            grad_a.shape().dims(),
            grad_b.shape().dims()
        );

        assert_eq!(
            grad_a.shape().dims(),
            vec![2, 3],
            "grad_a should match original shape of a"
        );

        assert_eq!(
            grad_b.shape().dims(),
            vec![1, 3],
            "grad_b should match original shape of b"
        );

        for i in 0..grad_a.size() {
            let val = unsafe { *grad_a.as_ptr().add(i) };
            assert!(
                (val - 1.0).abs() < 1e-6,
                "grad_a[{}] = {} should be 1.0",
                i,
                val
            );
        }

        let expected_grad_b = [2.0, 2.0, 2.0];
        for (i, val) in expected_grad_b.iter().enumerate().take(grad_b.size()) {
            let actual = unsafe { *grad_b.as_ptr().add(i) };
            assert!(
                (actual - val).abs() < 1e-6,
                "grad_b[{}] = {} should be {}",
                i,
                actual,
                val
            );
        }
    }

    #[test]
    fn test_scalar_broadcasting_gradients() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let b = Tensor::from_slice(&[0.5], vec![1])
            .unwrap()
            .with_requires_grad();

        let mut result = a.add_tensor(&b);
        result.backward(None);

        let grad_a = a.grad_owned().expect("grad_a should exist");
        let grad_b = b.grad_owned().expect("grad_b should exist");

        assert_eq!(grad_a.shape().dims(), vec![2, 3]);

        println!("grad_b shape: {:?}, expected: [1]", grad_b.shape().dims());
        assert_eq!(grad_b.shape().dims(), vec![1]);

        let val = unsafe { *grad_b.as_ptr() };
        assert!((val - 6.0).abs() < 1e-6, "grad_b = {} should be 6.0", val);
    }

    #[test]
    fn test_linear_layer_bias_broadcasting() {
        use crate::gradtrack::clear_gradients;
        clear_gradients();

        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let weight = Tensor::from_slice(
            &(1..=12).map(|i| i as f32 * 0.1).collect::<Vec<_>>(),
            vec![3, 4],
        )
        .unwrap()
        .with_requires_grad();
        let bias = Tensor::from_slice(&[0.1, 0.2, 0.3, 0.4], vec![4])
            .unwrap()
            .with_requires_grad();

        let matmul_result = input.matmul(&weight);
        println!("Matmul result shape: {:?}", matmul_result.shape().dims());
        println!("Bias shape: {:?}", bias.shape().dims());

        let linear_output = matmul_result.add_tensor(&bias);
        println!("Linear output shape: {:?}", linear_output.shape().dims());

        let mut loss = linear_output.sum();
        loss.backward(None);

        let bias_grad = bias.grad_owned().expect("bias gradient should exist");
        println!("Bias gradient shape: {:?}", bias_grad.shape().dims());
        assert_eq!(
            bias_grad.shape().dims(),
            vec![4],
            "bias gradient should match bias shape"
        );

        for i in 0..4 {
            let val = unsafe { *bias_grad.as_ptr().add(i) };
            assert!(
                (val - 2.0).abs() < 1e-6,
                "bias_grad[{}] = {} should be 2.0",
                i,
                val
            );
        }

        println!("Linear layer bias broadcasting test passed!");
    }
}