train_station/tensor/core/operators.rs
1//! Tensor operator overloading for natural mathematical expressions
2//!
3//! This module provides comprehensive operator overloading for tensor operations,
4//! enabling natural mathematical expressions with tensors and scalars. All operators
5//! are zero-cost abstractions that delegate to the underlying tensor operations.
6//!
7//! # Supported Operations
8//!
9//! ## Tensor-Tensor Operations
10//! - **Addition**: `Tensor + Tensor`, `&Tensor + &Tensor`, `Tensor + &Tensor`, `&Tensor + Tensor`
11//! - **Subtraction**: `Tensor - Tensor`, `&Tensor - &Tensor`, `Tensor - &Tensor`, `&Tensor - Tensor`
12//! - **Multiplication**: `Tensor * Tensor`, `&Tensor * &Tensor`, `Tensor * &Tensor`, `&Tensor * Tensor`
13//! - **Division**: `Tensor / Tensor`, `&Tensor / &Tensor`, `Tensor / &Tensor`, `&Tensor / Tensor`
14//!
15//! ## Tensor-Scalar Operations
16//! - **Addition**: `Tensor + f32`, `&Tensor + f32`, `f32 + Tensor`, `f32 + &Tensor`
17//! - **Subtraction**: `Tensor - f32`, `&Tensor - f32`, `f32 - Tensor`, `f32 - &Tensor`
18//! - **Multiplication**: `Tensor * f32`, `&Tensor * f32`, `f32 * Tensor`, `f32 * &Tensor`
19//! - **Division**: `Tensor / f32`, `&Tensor / f32`, `f32 / Tensor`, `f32 / &Tensor`
20//!
21//! ## Assignment Operations
22//! - **In-place addition**: `Tensor += Tensor`, `Tensor += &Tensor`, `Tensor += f32`
23//! - **In-place subtraction**: `Tensor -= Tensor`, `Tensor -= &Tensor`, `Tensor -= f32`
24//! - **In-place multiplication**: `Tensor *= Tensor`, `Tensor *= &Tensor`, `Tensor *= f32`
25//! - **In-place division**: `Tensor /= Tensor`, `Tensor /= &Tensor`, `Tensor /= f32`
26//!
27//! ## Unary Operations
28//! - **Negation**: `-Tensor`, `-&Tensor`
29//!
30//! # Performance Characteristics
31//!
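//! - **Zero-Cost Abstractions**: Operators are thin wrappers over the underlying tensor methods; note that compound assignment (`+=`, `-=`, `*=`, `/=`) computes a new tensor and stores it in place of the old value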
33//! - **SIMD Optimization**: Underlying operations use SIMD acceleration
34//! - **Memory Efficiency**: Operations are optimized for cache performance
35//! - **Thread Safety**: All operations are thread-safe
36//!
37//! # Examples
38//!
39//! ## Basic Tensor Operations
40//!
41//! ```
42//! use train_station::Tensor;
43//!
44//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
45//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
46//!
47//! // Tensor addition
48//! let result = a.clone() + b.clone();
49//! assert_eq!(result.get(&[0, 0]), 6.0);
50//!
51//! // Element-wise multiplication
52//! let result = a.clone() * b.clone();
53//! assert_eq!(result.get(&[0, 0]), 5.0);
54//! ```
55//!
56//! ## Scalar Operations
57//!
58//! ```
59//! use train_station::Tensor;
60//!
61//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
62//!
63//! // Tensor + scalar
64//! let result = tensor.clone() + 5.0;
65//! assert_eq!(result.get(&[0, 0]), 6.0);
66//!
67//! // Scalar + tensor
68//! let result = 5.0 + tensor.clone();
69//! assert_eq!(result.get(&[0, 0]), 6.0);
70//!
71//! // Tensor * scalar
72//! let result = tensor.clone() * 3.0;
73//! assert_eq!(result.get(&[0, 0]), 3.0);
74//! ```
75//!
76//! ## Compound Expressions
77//!
78//! ```
79//! use train_station::Tensor;
80//!
81//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
82//! let b = Tensor::from_slice(&[2.0, 3.0, 4.0, 5.0], vec![2, 2]).unwrap();
83//!
84//! // Complex mathematical expression
85//! let result = (a.clone() + b.clone()) * 2.0 - 1.0;
86//! assert_eq!(result.get(&[0, 0]), 5.0); // (1+2)*2-1 = 5
87//! ```
88//!
89//! ## Assignment Operators
90//!
91//! ```
92//! use train_station::Tensor;
93//!
94//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
95//!
96//! // In-place operations
97//! tensor += 5.0;
98//! assert_eq!(tensor.get(&[0, 0]), 6.0);
99//!
100//! tensor *= 2.0;
101//! assert_eq!(tensor.get(&[0, 0]), 12.0);
102//! ```
103//!
104//! # Thread Safety
105//!
106//! All operator implementations are thread-safe and can be used concurrently
107//! across multiple threads. Operations on different tensors can be performed
108//! simultaneously without synchronization.
109
110use super::Tensor;
111
112use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
113
114// ===== Tensor-Tensor Operations =====
115
116/// Tensor addition operator implementations
117///
118/// Provides addition operations between tensors with various reference combinations.
119/// All implementations delegate to the underlying `add_tensor` method for optimal performance.
120impl Add for Tensor {
121 type Output = Tensor;
122
123 /// Adds two tensors element-wise
124 ///
125 /// # Returns
126 ///
127 /// A new tensor containing the element-wise sum
128 fn add(self, other: Tensor) -> Tensor {
129 self.add_tensor(&other)
130 }
131}
132
133impl Add for &Tensor {
134 type Output = Tensor;
135
136 /// Adds two tensors element-wise (reference version)
137 ///
138 /// # Returns
139 ///
140 /// A new tensor containing the element-wise sum
141 fn add(self, other: &Tensor) -> Tensor {
142 self.add_tensor(other)
143 }
144}
145
146impl Add<&Tensor> for Tensor {
147 type Output = Tensor;
148
149 /// Adds a tensor and a tensor reference element-wise
150 ///
151 /// # Returns
152 ///
153 /// A new tensor containing the element-wise sum
154 fn add(self, other: &Tensor) -> Tensor {
155 self.add_tensor(other)
156 }
157}
158
159impl Add<Tensor> for &Tensor {
160 type Output = Tensor;
161
162 /// Adds a tensor reference and a tensor element-wise
163 ///
164 /// # Returns
165 ///
166 /// A new tensor containing the element-wise sum
167 fn add(self, other: Tensor) -> Tensor {
168 self.add_tensor(&other)
169 }
170}
171
172/// Tensor addition assignment operator implementations
173///
174/// Provides in-place addition operations between tensors.
175/// All implementations delegate to the underlying `add_tensor` method.
176impl AddAssign for Tensor {
177 /// Adds another tensor to this tensor in-place
178 fn add_assign(&mut self, other: Tensor) {
179 *self = self.add_tensor(&other);
180 }
181}
182
183impl AddAssign<&Tensor> for Tensor {
184 /// Adds another tensor reference to this tensor in-place
185 fn add_assign(&mut self, other: &Tensor) {
186 *self = self.add_tensor(other);
187 }
188}
189
190/// Tensor subtraction operator implementations
191///
192/// Provides subtraction operations between tensors with various reference combinations.
193/// All implementations delegate to the underlying `sub_tensor` method for optimal performance.
194impl Sub for Tensor {
195 type Output = Tensor;
196
197 /// Subtracts two tensors element-wise
198 ///
199 /// # Returns
200 ///
201 /// A new tensor containing the element-wise difference
202 fn sub(self, other: Tensor) -> Tensor {
203 self.sub_tensor(&other)
204 }
205}
206
207impl Sub for &Tensor {
208 type Output = Tensor;
209
210 /// Subtracts two tensors element-wise (reference version)
211 ///
212 /// # Returns
213 ///
214 /// A new tensor containing the element-wise difference
215 fn sub(self, other: &Tensor) -> Tensor {
216 self.sub_tensor(other)
217 }
218}
219
220impl Sub<&Tensor> for Tensor {
221 type Output = Tensor;
222
223 /// Subtracts a tensor reference from a tensor element-wise
224 ///
225 /// # Returns
226 ///
227 /// A new tensor containing the element-wise difference
228 fn sub(self, other: &Tensor) -> Tensor {
229 self.sub_tensor(other)
230 }
231}
232
233impl Sub<Tensor> for &Tensor {
234 type Output = Tensor;
235
236 /// Subtracts a tensor from a tensor reference element-wise
237 ///
238 /// # Returns
239 ///
240 /// A new tensor containing the element-wise difference
241 fn sub(self, other: Tensor) -> Tensor {
242 self.sub_tensor(&other)
243 }
244}
245
246/// Tensor subtraction assignment operator implementations
247///
248/// Provides in-place subtraction operations between tensors.
249/// All implementations delegate to the underlying `sub_tensor` method.
250impl SubAssign for Tensor {
251 /// Subtracts another tensor from this tensor in-place
252 fn sub_assign(&mut self, other: Tensor) {
253 *self = self.sub_tensor(&other);
254 }
255}
256
257impl SubAssign<&Tensor> for Tensor {
258 /// Subtracts another tensor reference from this tensor in-place
259 fn sub_assign(&mut self, other: &Tensor) {
260 *self = self.sub_tensor(other);
261 }
262}
263
264/// Tensor multiplication operator implementations
265///
266/// Provides element-wise multiplication operations between tensors with various reference combinations.
267/// All implementations delegate to the underlying `mul_tensor` method for optimal performance.
268impl Mul for Tensor {
269 type Output = Tensor;
270
271 /// Multiplies two tensors element-wise
272 ///
273 /// # Returns
274 ///
275 /// A new tensor containing the element-wise product
276 fn mul(self, other: Tensor) -> Tensor {
277 self.mul_tensor(&other)
278 }
279}
280
281impl Mul for &Tensor {
282 type Output = Tensor;
283
284 /// Multiplies two tensors element-wise (reference version)
285 ///
286 /// # Returns
287 ///
288 /// A new tensor containing the element-wise product
289 fn mul(self, other: &Tensor) -> Tensor {
290 self.mul_tensor(other)
291 }
292}
293
294impl Mul<&Tensor> for Tensor {
295 type Output = Tensor;
296
297 /// Multiplies a tensor and a tensor reference element-wise
298 ///
299 /// # Returns
300 ///
301 /// A new tensor containing the element-wise product
302 fn mul(self, other: &Tensor) -> Tensor {
303 self.mul_tensor(other)
304 }
305}
306
307impl Mul<Tensor> for &Tensor {
308 type Output = Tensor;
309
310 /// Multiplies a tensor reference and a tensor element-wise
311 ///
312 /// # Returns
313 ///
314 /// A new tensor containing the element-wise product
315 fn mul(self, other: Tensor) -> Tensor {
316 self.mul_tensor(&other)
317 }
318}
319
320/// Tensor multiplication assignment operator implementations
321///
322/// Provides in-place multiplication operations between tensors.
323/// All implementations delegate to the underlying `mul_tensor` method.
324impl MulAssign for Tensor {
325 /// Multiplies this tensor by another tensor in-place
326 fn mul_assign(&mut self, other: Tensor) {
327 *self = self.mul_tensor(&other);
328 }
329}
330
331impl MulAssign<&Tensor> for Tensor {
332 /// Multiplies this tensor by another tensor reference in-place
333 fn mul_assign(&mut self, other: &Tensor) {
334 *self = self.mul_tensor(other);
335 }
336}
337
338/// Tensor division operator implementations
339///
340/// Provides element-wise division operations between tensors with various reference combinations.
341/// All implementations delegate to the underlying `div_tensor` method for optimal performance.
342impl Div for Tensor {
343 type Output = Tensor;
344
345 /// Divides two tensors element-wise
346 ///
347 /// # Returns
348 ///
349 /// A new tensor containing the element-wise quotient
350 fn div(self, other: Tensor) -> Tensor {
351 self.div_tensor(&other)
352 }
353}
354
355impl Div for &Tensor {
356 type Output = Tensor;
357
358 /// Divides two tensors element-wise (reference version)
359 ///
360 /// # Returns
361 ///
362 /// A new tensor containing the element-wise quotient
363 fn div(self, other: &Tensor) -> Tensor {
364 self.div_tensor(other)
365 }
366}
367
368impl Div<&Tensor> for Tensor {
369 type Output = Tensor;
370
371 /// Divides a tensor by a tensor reference element-wise
372 ///
373 /// # Returns
374 ///
375 /// A new tensor containing the element-wise quotient
376 fn div(self, other: &Tensor) -> Tensor {
377 self.div_tensor(other)
378 }
379}
380
381impl Div<Tensor> for &Tensor {
382 type Output = Tensor;
383
384 /// Divides a tensor reference by a tensor element-wise
385 ///
386 /// # Returns
387 ///
388 /// A new tensor containing the element-wise quotient
389 fn div(self, other: Tensor) -> Tensor {
390 self.div_tensor(&other)
391 }
392}
393
394/// Tensor division assignment operator implementations
395///
396/// Provides in-place division operations between tensors.
397/// All implementations delegate to the underlying `div_tensor` method.
398impl DivAssign for Tensor {
399 /// Divides this tensor by another tensor in-place
400 fn div_assign(&mut self, other: Tensor) {
401 *self = self.div_tensor(&other);
402 }
403}
404
405impl DivAssign<&Tensor> for Tensor {
406 /// Divides this tensor by another tensor reference in-place
407 fn div_assign(&mut self, other: &Tensor) {
408 *self = self.div_tensor(other);
409 }
410}
411
412// ===== Scalar Operations =====
413
414/// Tensor-scalar addition operator implementations
415///
416/// Provides addition operations between tensors and scalars.
417/// All implementations delegate to the underlying `add_scalar` method.
418impl Add<f32> for Tensor {
419 type Output = Tensor;
420
421 /// Adds a scalar to each element of the tensor
422 ///
423 /// # Returns
424 ///
425 /// A new tensor with the scalar added to each element
426 fn add(self, scalar: f32) -> Tensor {
427 self.add_scalar(scalar)
428 }
429}
430
431impl Add<f32> for &Tensor {
432 type Output = Tensor;
433
434 /// Adds a scalar to each element of the tensor (reference version)
435 ///
436 /// # Returns
437 ///
438 /// A new tensor with the scalar added to each element
439 fn add(self, scalar: f32) -> Tensor {
440 self.add_scalar(scalar)
441 }
442}
443
444/// Scalar-tensor addition operator implementations
445///
446/// Provides addition operations between scalars and tensors.
447/// All implementations delegate to the underlying `add_scalar` method.
448impl Add<Tensor> for f32 {
449 type Output = Tensor;
450
451 /// Adds a scalar to each element of the tensor
452 ///
453 /// # Returns
454 ///
455 /// A new tensor with the scalar added to each element
456 fn add(self, tensor: Tensor) -> Tensor {
457 tensor.add_scalar(self)
458 }
459}
460
461impl Add<&Tensor> for f32 {
462 type Output = Tensor;
463
464 /// Adds a scalar to each element of the tensor (reference version)
465 ///
466 /// # Returns
467 ///
468 /// A new tensor with the scalar added to each element
469 fn add(self, tensor: &Tensor) -> Tensor {
470 tensor.add_scalar(self)
471 }
472}
473
474/// Tensor-scalar addition assignment operator implementations
475///
476/// Provides in-place addition operations between tensors and scalars.
477impl AddAssign<f32> for Tensor {
478 /// Adds a scalar to each element of this tensor in-place
479 fn add_assign(&mut self, scalar: f32) {
480 *self = self.add_scalar(scalar);
481 }
482}
483
484/// Tensor-scalar subtraction operator implementations
485///
486/// Provides subtraction operations between tensors and scalars.
487/// All implementations delegate to the underlying `sub_scalar` method.
488impl Sub<f32> for Tensor {
489 type Output = Tensor;
490
491 /// Subtracts a scalar from each element of the tensor
492 ///
493 /// # Returns
494 ///
495 /// A new tensor with the scalar subtracted from each element
496 fn sub(self, scalar: f32) -> Tensor {
497 self.sub_scalar(scalar)
498 }
499}
500
501impl Sub<f32> for &Tensor {
502 type Output = Tensor;
503
504 /// Subtracts a scalar from each element of the tensor (reference version)
505 ///
506 /// # Returns
507 ///
508 /// A new tensor with the scalar subtracted from each element
509 fn sub(self, scalar: f32) -> Tensor {
510 self.sub_scalar(scalar)
511 }
512}
513
514/// Scalar-tensor subtraction operator implementations
515///
516/// Provides subtraction operations between scalars and tensors.
517/// Computes `scalar - tensor` by negating the tensor and adding the scalar.
518impl Sub<Tensor> for f32 {
519 type Output = Tensor;
520
521 /// Subtracts each element of the tensor from the scalar
522 ///
523 /// # Returns
524 ///
525 /// A new tensor with each element subtracted from the scalar
526 fn sub(self, tensor: Tensor) -> Tensor {
527 // For scalar - tensor, we need to negate the tensor and add the scalar
528 // This is equivalent to: scalar + (-tensor)
529 let mut result = tensor;
530 result.negate_inplace();
531 result.add_scalar(self)
532 }
533}
534
535impl Sub<&Tensor> for f32 {
536 type Output = Tensor;
537
538 /// Subtracts each element of the tensor from the scalar (reference version)
539 ///
540 /// # Returns
541 ///
542 /// A new tensor with each element subtracted from the scalar
543 fn sub(self, tensor: &Tensor) -> Tensor {
544 // For scalar - tensor, we need to negate the tensor and add the scalar
545 let mut result = tensor.clone();
546 result.negate_inplace();
547 result.add_scalar(self)
548 }
549}
550
551/// Tensor-scalar subtraction assignment operator implementations
552///
553/// Provides in-place subtraction operations between tensors and scalars.
554impl SubAssign<f32> for Tensor {
555 /// Subtracts a scalar from each element of this tensor in-place
556 fn sub_assign(&mut self, scalar: f32) {
557 *self = self.sub_scalar(scalar);
558 }
559}
560
561/// Tensor-scalar multiplication operator implementations
562///
563/// Provides multiplication operations between tensors and scalars.
564/// All implementations delegate to the underlying `mul_scalar` method.
565impl Mul<f32> for Tensor {
566 type Output = Tensor;
567
568 /// Multiplies each element of the tensor by a scalar
569 ///
570 /// # Returns
571 ///
572 /// A new tensor with each element multiplied by the scalar
573 fn mul(self, scalar: f32) -> Tensor {
574 self.mul_scalar(scalar)
575 }
576}
577
578impl Mul<f32> for &Tensor {
579 type Output = Tensor;
580
581 /// Multiplies each element of the tensor by a scalar (reference version)
582 ///
583 /// # Returns
584 ///
585 /// A new tensor with each element multiplied by the scalar
586 fn mul(self, scalar: f32) -> Tensor {
587 self.mul_scalar(scalar)
588 }
589}
590
591/// Scalar-tensor multiplication operator implementations
592///
593/// Provides multiplication operations between scalars and tensors.
594/// All implementations delegate to the underlying `mul_scalar` method.
595impl Mul<Tensor> for f32 {
596 type Output = Tensor;
597
598 /// Multiplies each element of the tensor by a scalar
599 ///
600 /// # Returns
601 ///
602 /// A new tensor with each element multiplied by the scalar
603 fn mul(self, tensor: Tensor) -> Tensor {
604 tensor.mul_scalar(self)
605 }
606}
607
608impl Mul<&Tensor> for f32 {
609 type Output = Tensor;
610
611 /// Multiplies each element of the tensor by a scalar (reference version)
612 ///
613 /// # Returns
614 ///
615 /// A new tensor with each element multiplied by the scalar
616 fn mul(self, tensor: &Tensor) -> Tensor {
617 tensor.mul_scalar(self)
618 }
619}
620
621/// Tensor-scalar multiplication assignment operator implementations
622///
623/// Provides in-place multiplication operations between tensors and scalars.
624impl MulAssign<f32> for Tensor {
625 /// Multiplies each element of this tensor by a scalar in-place
626 fn mul_assign(&mut self, scalar: f32) {
627 *self = self.mul_scalar(scalar);
628 }
629}
630
631/// Tensor-scalar division operator implementations
632///
633/// Provides division operations between tensors and scalars.
634/// All implementations delegate to the underlying `div_scalar` method.
635impl Div<f32> for Tensor {
636 type Output = Tensor;
637
638 /// Divides each element of the tensor by a scalar
639 ///
640 /// # Returns
641 ///
642 /// A new tensor with each element divided by the scalar
643 fn div(self, scalar: f32) -> Tensor {
644 self.div_scalar(scalar)
645 }
646}
647
648impl Div<f32> for &Tensor {
649 type Output = Tensor;
650
651 /// Divides each element of the tensor by a scalar (reference version)
652 ///
653 /// # Returns
654 ///
655 /// A new tensor with each element divided by the scalar
656 fn div(self, scalar: f32) -> Tensor {
657 self.div_scalar(scalar)
658 }
659}
660
661/// Scalar-tensor division operator implementations
662///
663/// Provides division operations between scalars and tensors.
664/// Computes `scalar / tensor` by computing the reciprocal of the tensor and multiplying by the scalar.
665impl Div<Tensor> for f32 {
666 type Output = Tensor;
667
668 /// Divides a scalar by each element of the tensor
669 ///
670 /// # Returns
671 ///
672 /// A new tensor with the scalar divided by each element
673 fn div(self, tensor: Tensor) -> Tensor {
674 // For scalar / tensor, we need to compute scalar / each element
675 // This is equivalent to: scalar * (1 / tensor)
676 tensor.pow_scalar(-1.0).mul_scalar(self)
677 }
678}
679
680impl Div<&Tensor> for f32 {
681 type Output = Tensor;
682
683 /// Divides a scalar by each element of the tensor (reference version)
684 ///
685 /// # Returns
686 ///
687 /// A new tensor with the scalar divided by each element
688 fn div(self, tensor: &Tensor) -> Tensor {
689 // For scalar / tensor, we need to compute scalar / each element
690 tensor.pow_scalar(-1.0).mul_scalar(self)
691 }
692}
693
694/// Tensor-scalar division assignment operator implementations
695///
696/// Provides in-place division operations between tensors and scalars.
697impl DivAssign<f32> for Tensor {
698 /// Divides each element of this tensor by a scalar in-place
699 fn div_assign(&mut self, scalar: f32) {
700 *self = self.div_scalar(scalar);
701 }
702}
703
704// ===== Negation =====
705
706use std::ops::Neg;
707
708/// Tensor negation operator implementations
709///
710/// Provides unary negation operations for tensors.
711/// All implementations delegate to the underlying `mul_scalar` method with -1.0.
712impl Neg for Tensor {
713 type Output = Tensor;
714
715 /// Negates each element of the tensor
716 ///
717 /// # Returns
718 ///
719 /// A new tensor with each element negated
720 fn neg(self) -> Tensor {
721 self.mul_scalar(-1.0)
722 }
723}
724
725impl Neg for &Tensor {
726 type Output = Tensor;
727
728 /// Negates each element of the tensor (reference version)
729 ///
730 /// # Returns
731 ///
732 /// A new tensor with each element negated
733 fn neg(self) -> Tensor {
734 self.mul_scalar(-1.0)
735 }
736}
737
#[cfg(test)]
mod tests {
    use super::*;

    // ===== Operator Overloading Tests =====

    /// Builds a 2x2 tensor from four values, panicking if construction fails.
    fn square(values: [f32; 4]) -> Tensor {
        Tensor::from_slice(&values, vec![2, 2]).unwrap()
    }

    /// Asserts all four elements of a 2x2 tensor, in the order
    /// `[0, 0]`, `[0, 1]`, `[1, 0]`, `[1, 1]`.
    fn assert_elements(actual: &Tensor, expected: [f32; 4]) {
        assert_eq!(actual.get(&[0, 0]), expected[0]);
        assert_eq!(actual.get(&[0, 1]), expected[1]);
        assert_eq!(actual.get(&[1, 0]), expected[2]);
        assert_eq!(actual.get(&[1, 1]), expected[3]);
    }

    /// Asserts only the `[0, 0]` element, used as a smoke check for operator variants.
    fn assert_first(actual: &Tensor, expected: f32) {
        assert_eq!(actual.get(&[0, 0]), expected);
    }

    /// Exercises `+` and `+=` between tensors for every owned/borrowed pairing.
    #[test]
    fn test_tensor_addition_operators() {
        let a = square([1.0, 2.0, 3.0, 4.0]);
        let b = square([5.0, 6.0, 7.0, 8.0]);

        // Tensor + Tensor
        assert_elements(&(a.clone() + b.clone()), [6.0, 8.0, 10.0, 12.0]);

        // &Tensor + &Tensor
        assert_first(&(&a + &b), 6.0);

        // Tensor + &Tensor
        assert_first(&(a.clone() + &b), 6.0);

        // &Tensor + Tensor
        assert_first(&(&a + b.clone()), 6.0);

        // Tensor += Tensor
        let mut by_value = a.clone();
        by_value += b.clone();
        assert_first(&by_value, 6.0);

        // Tensor += &Tensor
        let mut by_reference = a.clone();
        by_reference += &b;
        assert_first(&by_reference, 6.0);
    }

    /// Exercises `-` and `-=` between tensors for every owned/borrowed pairing.
    #[test]
    fn test_tensor_subtraction_operators() {
        let a = square([5.0, 6.0, 7.0, 8.0]);
        let b = square([1.0, 2.0, 3.0, 4.0]);

        // Tensor - Tensor
        assert_elements(&(a.clone() - b.clone()), [4.0, 4.0, 4.0, 4.0]);

        // &Tensor - &Tensor
        assert_first(&(&a - &b), 4.0);

        // Tensor - &Tensor
        assert_first(&(a.clone() - &b), 4.0);

        // &Tensor - Tensor
        assert_first(&(&a - b.clone()), 4.0);

        // Tensor -= Tensor
        let mut by_value = a.clone();
        by_value -= b.clone();
        assert_first(&by_value, 4.0);

        // Tensor -= &Tensor
        let mut by_reference = a.clone();
        by_reference -= &b;
        assert_first(&by_reference, 4.0);
    }

    /// Exercises `*` and `*=` between tensors for every owned/borrowed pairing.
    #[test]
    fn test_tensor_multiplication_operators() {
        let a = square([1.0, 2.0, 3.0, 4.0]);
        let b = square([2.0, 3.0, 4.0, 5.0]);

        // Tensor * Tensor
        assert_elements(&(a.clone() * b.clone()), [2.0, 6.0, 12.0, 20.0]);

        // &Tensor * &Tensor
        assert_first(&(&a * &b), 2.0);

        // Tensor * &Tensor
        assert_first(&(a.clone() * &b), 2.0);

        // &Tensor * Tensor
        assert_first(&(&a * b.clone()), 2.0);

        // Tensor *= Tensor
        let mut by_value = a.clone();
        by_value *= b.clone();
        assert_first(&by_value, 2.0);

        // Tensor *= &Tensor
        let mut by_reference = a.clone();
        by_reference *= &b;
        assert_first(&by_reference, 2.0);
    }

    /// Exercises `/` and `/=` between tensors for every owned/borrowed pairing.
    #[test]
    fn test_tensor_division_operators() {
        let a = square([10.0, 20.0, 30.0, 40.0]);
        let b = square([2.0, 4.0, 5.0, 8.0]);

        // Tensor / Tensor
        assert_elements(&(a.clone() / b.clone()), [5.0, 5.0, 6.0, 5.0]);

        // &Tensor / &Tensor
        assert_first(&(&a / &b), 5.0);

        // Tensor / &Tensor
        assert_first(&(a.clone() / &b), 5.0);

        // &Tensor / Tensor
        assert_first(&(&a / b.clone()), 5.0);

        // Tensor /= Tensor
        let mut by_value = a.clone();
        by_value /= b.clone();
        assert_first(&by_value, 5.0);

        // Tensor /= &Tensor
        let mut by_reference = a.clone();
        by_reference /= &b;
        assert_first(&by_reference, 5.0);
    }

    /// Exercises every tensor/scalar operator in both operand orders, plus the
    /// scalar compound-assignment operators.
    #[test]
    fn test_scalar_operations() {
        let a = square([1.0, 2.0, 3.0, 4.0]);

        // ----- addition -----

        // Tensor + f32
        assert_elements(&(a.clone() + 5.0), [6.0, 7.0, 8.0, 9.0]);

        // &Tensor + f32
        assert_first(&(&a + 5.0), 6.0);

        // f32 + Tensor
        assert_first(&(5.0 + a.clone()), 6.0);

        // f32 + &Tensor
        assert_first(&(5.0 + &a), 6.0);

        // Tensor += f32
        let mut shifted_up = a.clone();
        shifted_up += 5.0;
        assert_first(&shifted_up, 6.0);

        // ----- subtraction -----

        // Tensor - f32
        assert_elements(&(a.clone() - 2.0), [-1.0, 0.0, 1.0, 2.0]);

        // &Tensor - f32
        assert_first(&(&a - 2.0), -1.0);

        // f32 - Tensor
        assert_elements(&(10.0 - a.clone()), [9.0, 8.0, 7.0, 6.0]);

        // f32 - &Tensor
        assert_first(&(10.0 - &a), 9.0);

        // Tensor -= f32
        let mut shifted_down = a.clone();
        shifted_down -= 2.0;
        assert_first(&shifted_down, -1.0);

        // ----- multiplication -----

        // Tensor * f32
        assert_elements(&(a.clone() * 3.0), [3.0, 6.0, 9.0, 12.0]);

        // &Tensor * f32
        assert_first(&(&a * 3.0), 3.0);

        // f32 * Tensor
        assert_first(&(3.0 * a.clone()), 3.0);

        // f32 * &Tensor
        assert_first(&(3.0 * &a), 3.0);

        // Tensor *= f32
        let mut scaled_up = a.clone();
        scaled_up *= 3.0;
        assert_first(&scaled_up, 3.0);

        // ----- division -----

        // Tensor / f32
        assert_elements(&(a.clone() / 2.0), [0.5, 1.0, 1.5, 2.0]);

        // &Tensor / f32
        assert_first(&(&a / 2.0), 0.5);

        // f32 / Tensor — 10/3 is not exactly representable, so that element is
        // compared with a tolerance while the others are compared exactly.
        let quotient = 10.0 / a.clone();
        assert_eq!(quotient.get(&[0, 0]), 10.0);
        assert_eq!(quotient.get(&[0, 1]), 5.0);
        assert!((quotient.get(&[1, 0]) - (10.0 / 3.0)).abs() < 1e-6);
        assert_eq!(quotient.get(&[1, 1]), 2.5);

        // f32 / &Tensor
        assert_first(&(10.0 / &a), 10.0);

        // Tensor /= f32
        let mut scaled_down = a.clone();
        scaled_down /= 2.0;
        assert_first(&scaled_down, 0.5);
    }

    /// Exercises unary `-` on owned and borrowed tensors with mixed-sign input.
    #[test]
    fn test_negation_operator() {
        let a = square([1.0, -2.0, 3.0, -4.0]);

        // -Tensor
        assert_elements(&(-a.clone()), [-1.0, 2.0, -3.0, 4.0]);

        // -&Tensor
        assert_first(&(-&a), -1.0);
    }

    /// Exercises a compound expression that mixes tensor and scalar operators,
    /// once with owned operands and once with borrowed operands.
    #[test]
    fn test_operator_chaining() {
        let a = square([1.0, 2.0, 3.0, 4.0]);
        let b = square([2.0, 3.0, 4.0, 5.0]);

        // (a + b) * 2 - 1, element by element:
        // (1+2)*2-1 = 5, (2+3)*2-1 = 9, (3+4)*2-1 = 13, (4+5)*2-1 = 17
        let owned = (a.clone() + b.clone()) * 2.0 - 1.0;
        assert_elements(&owned, [5.0, 9.0, 13.0, 17.0]);

        // With references: (&a + &b) * 2 - 1
        let borrowed = (&a + &b) * 2.0 - 1.0;
        assert_first(&borrowed, 5.0);
    }
}
1032}