train_station/tensor/core/
operators.rs

1//! Tensor operator overloading for natural mathematical expressions
2//!
3//! This module provides comprehensive operator overloading for tensor operations,
4//! enabling natural mathematical expressions with tensors and scalars. All operators
5//! are zero-cost abstractions that delegate to the underlying tensor operations.
6//!
7//! # Supported Operations
8//!
9//! ## Tensor-Tensor Operations
10//! - **Addition**: `Tensor + Tensor`, `&Tensor + &Tensor`, `Tensor + &Tensor`, `&Tensor + Tensor`
11//! - **Subtraction**: `Tensor - Tensor`, `&Tensor - &Tensor`, `Tensor - &Tensor`, `&Tensor - Tensor`
12//! - **Multiplication**: `Tensor * Tensor`, `&Tensor * &Tensor`, `Tensor * &Tensor`, `&Tensor * Tensor`
13//! - **Division**: `Tensor / Tensor`, `&Tensor / &Tensor`, `Tensor / &Tensor`, `&Tensor / Tensor`
14//!
15//! ## Tensor-Scalar Operations
16//! - **Addition**: `Tensor + f32`, `&Tensor + f32`, `f32 + Tensor`, `f32 + &Tensor`
17//! - **Subtraction**: `Tensor - f32`, `&Tensor - f32`, `f32 - Tensor`, `f32 - &Tensor`
18//! - **Multiplication**: `Tensor * f32`, `&Tensor * f32`, `f32 * Tensor`, `f32 * &Tensor`
19//! - **Division**: `Tensor / f32`, `&Tensor / f32`, `f32 / Tensor`, `f32 / &Tensor`
20//!
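//! ## Assignment Operations
//! - **Compound addition**: `Tensor += Tensor`, `Tensor += &Tensor`, `Tensor += f32`
//! - **Compound subtraction**: `Tensor -= Tensor`, `Tensor -= &Tensor`, `Tensor -= f32`
//! - **Compound multiplication**: `Tensor *= Tensor`, `Tensor *= &Tensor`, `Tensor *= f32`
//! - **Compound division**: `Tensor /= Tensor`, `Tensor /= &Tensor`, `Tensor /= f32`
//!
//! Each compound assignment computes the result with the corresponding
//! non-assigning method (for example `add_tensor` or `add_scalar`) and then
//! stores that result back into the left-hand tensor.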
26//!
27//! ## Unary Operations
28//! - **Negation**: `-Tensor`, `-&Tensor`
29//!
30//! # Performance Characteristics
31//!
32//! - **Zero-Cost Abstractions**: All operators have no runtime overhead
33//! - **SIMD Optimization**: Underlying operations use SIMD acceleration
34//! - **Memory Efficiency**: Operations are optimized for cache performance
35//! - **Thread Safety**: All operations are thread-safe
36//!
37//! # Examples
38//!
39//! ## Basic Tensor Operations
40//!
41//! ```
42//! use train_station::Tensor;
43//!
44//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
45//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
46//!
47//! // Tensor addition
48//! let result = a.clone() + b.clone();
49//! assert_eq!(result.get(&[0, 0]), 6.0);
50//!
51//! // Element-wise multiplication
52//! let result = a.clone() * b.clone();
53//! assert_eq!(result.get(&[0, 0]), 5.0);
54//! ```
55//!
56//! ## Scalar Operations
57//!
58//! ```
59//! use train_station::Tensor;
60//!
61//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
62//!
63//! // Tensor + scalar
64//! let result = tensor.clone() + 5.0;
65//! assert_eq!(result.get(&[0, 0]), 6.0);
66//!
67//! // Scalar + tensor
68//! let result = 5.0 + tensor.clone();
69//! assert_eq!(result.get(&[0, 0]), 6.0);
70//!
71//! // Tensor * scalar
72//! let result = tensor.clone() * 3.0;
73//! assert_eq!(result.get(&[0, 0]), 3.0);
74//! ```
75//!
76//! ## Compound Expressions
77//!
78//! ```
79//! use train_station::Tensor;
80//!
81//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
82//! let b = Tensor::from_slice(&[2.0, 3.0, 4.0, 5.0], vec![2, 2]).unwrap();
83//!
84//! // Complex mathematical expression
85//! let result = (a.clone() + b.clone()) * 2.0 - 1.0;
86//! assert_eq!(result.get(&[0, 0]), 5.0); // (1+2)*2-1 = 5
87//! ```
88//!
89//! ## Assignment Operators
90//!
91//! ```
92//! use train_station::Tensor;
93//!
94//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
95//!
96//! // In-place operations
97//! tensor += 5.0;
98//! assert_eq!(tensor.get(&[0, 0]), 6.0);
99//!
100//! tensor *= 2.0;
101//! assert_eq!(tensor.get(&[0, 0]), 12.0);
102//! ```
103//!
104//! # Thread Safety
105//!
106//! All operator implementations are thread-safe and can be used concurrently
107//! across multiple threads. Operations on different tensors can be performed
108//! simultaneously without synchronization.
109
110use super::Tensor;
111
112use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
113
114// ===== Tensor-Tensor Operations =====
115
116/// Tensor addition operator implementations
117///
118/// Provides addition operations between tensors with various reference combinations.
119/// All implementations delegate to the underlying `add_tensor` method for optimal performance.
120impl Add for Tensor {
121    type Output = Tensor;
122
123    /// Adds two tensors element-wise
124    ///
125    /// # Returns
126    ///
127    /// A new tensor containing the element-wise sum
128    fn add(self, other: Tensor) -> Tensor {
129        self.add_tensor(&other)
130    }
131}
132
133impl Add for &Tensor {
134    type Output = Tensor;
135
136    /// Adds two tensors element-wise (reference version)
137    ///
138    /// # Returns
139    ///
140    /// A new tensor containing the element-wise sum
141    fn add(self, other: &Tensor) -> Tensor {
142        self.add_tensor(other)
143    }
144}
145
146impl Add<&Tensor> for Tensor {
147    type Output = Tensor;
148
149    /// Adds a tensor and a tensor reference element-wise
150    ///
151    /// # Returns
152    ///
153    /// A new tensor containing the element-wise sum
154    fn add(self, other: &Tensor) -> Tensor {
155        self.add_tensor(other)
156    }
157}
158
159impl Add<Tensor> for &Tensor {
160    type Output = Tensor;
161
162    /// Adds a tensor reference and a tensor element-wise
163    ///
164    /// # Returns
165    ///
166    /// A new tensor containing the element-wise sum
167    fn add(self, other: Tensor) -> Tensor {
168        self.add_tensor(&other)
169    }
170}
171
172/// Tensor addition assignment operator implementations
173///
174/// Provides in-place addition operations between tensors.
175/// All implementations delegate to the underlying `add_tensor` method.
176impl AddAssign for Tensor {
177    /// Adds another tensor to this tensor in-place
178    fn add_assign(&mut self, other: Tensor) {
179        *self = self.add_tensor(&other);
180    }
181}
182
183impl AddAssign<&Tensor> for Tensor {
184    /// Adds another tensor reference to this tensor in-place
185    fn add_assign(&mut self, other: &Tensor) {
186        *self = self.add_tensor(other);
187    }
188}
189
190/// Tensor subtraction operator implementations
191///
192/// Provides subtraction operations between tensors with various reference combinations.
193/// All implementations delegate to the underlying `sub_tensor` method for optimal performance.
194impl Sub for Tensor {
195    type Output = Tensor;
196
197    /// Subtracts two tensors element-wise
198    ///
199    /// # Returns
200    ///
201    /// A new tensor containing the element-wise difference
202    fn sub(self, other: Tensor) -> Tensor {
203        self.sub_tensor(&other)
204    }
205}
206
207impl Sub for &Tensor {
208    type Output = Tensor;
209
210    /// Subtracts two tensors element-wise (reference version)
211    ///
212    /// # Returns
213    ///
214    /// A new tensor containing the element-wise difference
215    fn sub(self, other: &Tensor) -> Tensor {
216        self.sub_tensor(other)
217    }
218}
219
220impl Sub<&Tensor> for Tensor {
221    type Output = Tensor;
222
223    /// Subtracts a tensor reference from a tensor element-wise
224    ///
225    /// # Returns
226    ///
227    /// A new tensor containing the element-wise difference
228    fn sub(self, other: &Tensor) -> Tensor {
229        self.sub_tensor(other)
230    }
231}
232
233impl Sub<Tensor> for &Tensor {
234    type Output = Tensor;
235
236    /// Subtracts a tensor from a tensor reference element-wise
237    ///
238    /// # Returns
239    ///
240    /// A new tensor containing the element-wise difference
241    fn sub(self, other: Tensor) -> Tensor {
242        self.sub_tensor(&other)
243    }
244}
245
246/// Tensor subtraction assignment operator implementations
247///
248/// Provides in-place subtraction operations between tensors.
249/// All implementations delegate to the underlying `sub_tensor` method.
250impl SubAssign for Tensor {
251    /// Subtracts another tensor from this tensor in-place
252    fn sub_assign(&mut self, other: Tensor) {
253        *self = self.sub_tensor(&other);
254    }
255}
256
257impl SubAssign<&Tensor> for Tensor {
258    /// Subtracts another tensor reference from this tensor in-place
259    fn sub_assign(&mut self, other: &Tensor) {
260        *self = self.sub_tensor(other);
261    }
262}
263
264/// Tensor multiplication operator implementations
265///
266/// Provides element-wise multiplication operations between tensors with various reference combinations.
267/// All implementations delegate to the underlying `mul_tensor` method for optimal performance.
268impl Mul for Tensor {
269    type Output = Tensor;
270
271    /// Multiplies two tensors element-wise
272    ///
273    /// # Returns
274    ///
275    /// A new tensor containing the element-wise product
276    fn mul(self, other: Tensor) -> Tensor {
277        self.mul_tensor(&other)
278    }
279}
280
281impl Mul for &Tensor {
282    type Output = Tensor;
283
284    /// Multiplies two tensors element-wise (reference version)
285    ///
286    /// # Returns
287    ///
288    /// A new tensor containing the element-wise product
289    fn mul(self, other: &Tensor) -> Tensor {
290        self.mul_tensor(other)
291    }
292}
293
294impl Mul<&Tensor> for Tensor {
295    type Output = Tensor;
296
297    /// Multiplies a tensor and a tensor reference element-wise
298    ///
299    /// # Returns
300    ///
301    /// A new tensor containing the element-wise product
302    fn mul(self, other: &Tensor) -> Tensor {
303        self.mul_tensor(other)
304    }
305}
306
307impl Mul<Tensor> for &Tensor {
308    type Output = Tensor;
309
310    /// Multiplies a tensor reference and a tensor element-wise
311    ///
312    /// # Returns
313    ///
314    /// A new tensor containing the element-wise product
315    fn mul(self, other: Tensor) -> Tensor {
316        self.mul_tensor(&other)
317    }
318}
319
320/// Tensor multiplication assignment operator implementations
321///
322/// Provides in-place multiplication operations between tensors.
323/// All implementations delegate to the underlying `mul_tensor` method.
324impl MulAssign for Tensor {
325    /// Multiplies this tensor by another tensor in-place
326    fn mul_assign(&mut self, other: Tensor) {
327        *self = self.mul_tensor(&other);
328    }
329}
330
331impl MulAssign<&Tensor> for Tensor {
332    /// Multiplies this tensor by another tensor reference in-place
333    fn mul_assign(&mut self, other: &Tensor) {
334        *self = self.mul_tensor(other);
335    }
336}
337
338/// Tensor division operator implementations
339///
340/// Provides element-wise division operations between tensors with various reference combinations.
341/// All implementations delegate to the underlying `div_tensor` method for optimal performance.
342impl Div for Tensor {
343    type Output = Tensor;
344
345    /// Divides two tensors element-wise
346    ///
347    /// # Returns
348    ///
349    /// A new tensor containing the element-wise quotient
350    fn div(self, other: Tensor) -> Tensor {
351        self.div_tensor(&other)
352    }
353}
354
355impl Div for &Tensor {
356    type Output = Tensor;
357
358    /// Divides two tensors element-wise (reference version)
359    ///
360    /// # Returns
361    ///
362    /// A new tensor containing the element-wise quotient
363    fn div(self, other: &Tensor) -> Tensor {
364        self.div_tensor(other)
365    }
366}
367
368impl Div<&Tensor> for Tensor {
369    type Output = Tensor;
370
371    /// Divides a tensor by a tensor reference element-wise
372    ///
373    /// # Returns
374    ///
375    /// A new tensor containing the element-wise quotient
376    fn div(self, other: &Tensor) -> Tensor {
377        self.div_tensor(other)
378    }
379}
380
381impl Div<Tensor> for &Tensor {
382    type Output = Tensor;
383
384    /// Divides a tensor reference by a tensor element-wise
385    ///
386    /// # Returns
387    ///
388    /// A new tensor containing the element-wise quotient
389    fn div(self, other: Tensor) -> Tensor {
390        self.div_tensor(&other)
391    }
392}
393
394/// Tensor division assignment operator implementations
395///
396/// Provides in-place division operations between tensors.
397/// All implementations delegate to the underlying `div_tensor` method.
398impl DivAssign for Tensor {
399    /// Divides this tensor by another tensor in-place
400    fn div_assign(&mut self, other: Tensor) {
401        *self = self.div_tensor(&other);
402    }
403}
404
405impl DivAssign<&Tensor> for Tensor {
406    /// Divides this tensor by another tensor reference in-place
407    fn div_assign(&mut self, other: &Tensor) {
408        *self = self.div_tensor(other);
409    }
410}
411
412// ===== Scalar Operations =====
413
414/// Tensor-scalar addition operator implementations
415///
416/// Provides addition operations between tensors and scalars.
417/// All implementations delegate to the underlying `add_scalar` method.
418impl Add<f32> for Tensor {
419    type Output = Tensor;
420
421    /// Adds a scalar to each element of the tensor
422    ///
423    /// # Returns
424    ///
425    /// A new tensor with the scalar added to each element
426    fn add(self, scalar: f32) -> Tensor {
427        self.add_scalar(scalar)
428    }
429}
430
431impl Add<f32> for &Tensor {
432    type Output = Tensor;
433
434    /// Adds a scalar to each element of the tensor (reference version)
435    ///
436    /// # Returns
437    ///
438    /// A new tensor with the scalar added to each element
439    fn add(self, scalar: f32) -> Tensor {
440        self.add_scalar(scalar)
441    }
442}
443
444/// Scalar-tensor addition operator implementations
445///
446/// Provides addition operations between scalars and tensors.
447/// All implementations delegate to the underlying `add_scalar` method.
448impl Add<Tensor> for f32 {
449    type Output = Tensor;
450
451    /// Adds a scalar to each element of the tensor
452    ///
453    /// # Returns
454    ///
455    /// A new tensor with the scalar added to each element
456    fn add(self, tensor: Tensor) -> Tensor {
457        tensor.add_scalar(self)
458    }
459}
460
461impl Add<&Tensor> for f32 {
462    type Output = Tensor;
463
464    /// Adds a scalar to each element of the tensor (reference version)
465    ///
466    /// # Returns
467    ///
468    /// A new tensor with the scalar added to each element
469    fn add(self, tensor: &Tensor) -> Tensor {
470        tensor.add_scalar(self)
471    }
472}
473
474/// Tensor-scalar addition assignment operator implementations
475///
476/// Provides in-place addition operations between tensors and scalars.
477impl AddAssign<f32> for Tensor {
478    /// Adds a scalar to each element of this tensor in-place
479    fn add_assign(&mut self, scalar: f32) {
480        *self = self.add_scalar(scalar);
481    }
482}
483
484/// Tensor-scalar subtraction operator implementations
485///
486/// Provides subtraction operations between tensors and scalars.
487/// All implementations delegate to the underlying `sub_scalar` method.
488impl Sub<f32> for Tensor {
489    type Output = Tensor;
490
491    /// Subtracts a scalar from each element of the tensor
492    ///
493    /// # Returns
494    ///
495    /// A new tensor with the scalar subtracted from each element
496    fn sub(self, scalar: f32) -> Tensor {
497        self.sub_scalar(scalar)
498    }
499}
500
501impl Sub<f32> for &Tensor {
502    type Output = Tensor;
503
504    /// Subtracts a scalar from each element of the tensor (reference version)
505    ///
506    /// # Returns
507    ///
508    /// A new tensor with the scalar subtracted from each element
509    fn sub(self, scalar: f32) -> Tensor {
510        self.sub_scalar(scalar)
511    }
512}
513
514/// Scalar-tensor subtraction operator implementations
515///
516/// Provides subtraction operations between scalars and tensors.
517/// Computes `scalar - tensor` by negating the tensor and adding the scalar.
518impl Sub<Tensor> for f32 {
519    type Output = Tensor;
520
521    /// Subtracts each element of the tensor from the scalar
522    ///
523    /// # Returns
524    ///
525    /// A new tensor with each element subtracted from the scalar
526    fn sub(self, tensor: Tensor) -> Tensor {
527        // For scalar - tensor, we need to negate the tensor and add the scalar
528        // This is equivalent to: scalar + (-tensor)
529        let mut result = tensor;
530        result.negate_inplace();
531        result.add_scalar(self)
532    }
533}
534
535impl Sub<&Tensor> for f32 {
536    type Output = Tensor;
537
538    /// Subtracts each element of the tensor from the scalar (reference version)
539    ///
540    /// # Returns
541    ///
542    /// A new tensor with each element subtracted from the scalar
543    fn sub(self, tensor: &Tensor) -> Tensor {
544        // For scalar - tensor, we need to negate the tensor and add the scalar
545        let mut result = tensor.clone();
546        result.negate_inplace();
547        result.add_scalar(self)
548    }
549}
550
551/// Tensor-scalar subtraction assignment operator implementations
552///
553/// Provides in-place subtraction operations between tensors and scalars.
554impl SubAssign<f32> for Tensor {
555    /// Subtracts a scalar from each element of this tensor in-place
556    fn sub_assign(&mut self, scalar: f32) {
557        *self = self.sub_scalar(scalar);
558    }
559}
560
561/// Tensor-scalar multiplication operator implementations
562///
563/// Provides multiplication operations between tensors and scalars.
564/// All implementations delegate to the underlying `mul_scalar` method.
565impl Mul<f32> for Tensor {
566    type Output = Tensor;
567
568    /// Multiplies each element of the tensor by a scalar
569    ///
570    /// # Returns
571    ///
572    /// A new tensor with each element multiplied by the scalar
573    fn mul(self, scalar: f32) -> Tensor {
574        self.mul_scalar(scalar)
575    }
576}
577
578impl Mul<f32> for &Tensor {
579    type Output = Tensor;
580
581    /// Multiplies each element of the tensor by a scalar (reference version)
582    ///
583    /// # Returns
584    ///
585    /// A new tensor with each element multiplied by the scalar
586    fn mul(self, scalar: f32) -> Tensor {
587        self.mul_scalar(scalar)
588    }
589}
590
591/// Scalar-tensor multiplication operator implementations
592///
593/// Provides multiplication operations between scalars and tensors.
594/// All implementations delegate to the underlying `mul_scalar` method.
595impl Mul<Tensor> for f32 {
596    type Output = Tensor;
597
598    /// Multiplies each element of the tensor by a scalar
599    ///
600    /// # Returns
601    ///
602    /// A new tensor with each element multiplied by the scalar
603    fn mul(self, tensor: Tensor) -> Tensor {
604        tensor.mul_scalar(self)
605    }
606}
607
608impl Mul<&Tensor> for f32 {
609    type Output = Tensor;
610
611    /// Multiplies each element of the tensor by a scalar (reference version)
612    ///
613    /// # Returns
614    ///
615    /// A new tensor with each element multiplied by the scalar
616    fn mul(self, tensor: &Tensor) -> Tensor {
617        tensor.mul_scalar(self)
618    }
619}
620
621/// Tensor-scalar multiplication assignment operator implementations
622///
623/// Provides in-place multiplication operations between tensors and scalars.
624impl MulAssign<f32> for Tensor {
625    /// Multiplies each element of this tensor by a scalar in-place
626    fn mul_assign(&mut self, scalar: f32) {
627        *self = self.mul_scalar(scalar);
628    }
629}
630
631/// Tensor-scalar division operator implementations
632///
633/// Provides division operations between tensors and scalars.
634/// All implementations delegate to the underlying `div_scalar` method.
635impl Div<f32> for Tensor {
636    type Output = Tensor;
637
638    /// Divides each element of the tensor by a scalar
639    ///
640    /// # Returns
641    ///
642    /// A new tensor with each element divided by the scalar
643    fn div(self, scalar: f32) -> Tensor {
644        self.div_scalar(scalar)
645    }
646}
647
648impl Div<f32> for &Tensor {
649    type Output = Tensor;
650
651    /// Divides each element of the tensor by a scalar (reference version)
652    ///
653    /// # Returns
654    ///
655    /// A new tensor with each element divided by the scalar
656    fn div(self, scalar: f32) -> Tensor {
657        self.div_scalar(scalar)
658    }
659}
660
661/// Scalar-tensor division operator implementations
662///
663/// Provides division operations between scalars and tensors.
664/// Computes `scalar / tensor` by computing the reciprocal of the tensor and multiplying by the scalar.
665impl Div<Tensor> for f32 {
666    type Output = Tensor;
667
668    /// Divides a scalar by each element of the tensor
669    ///
670    /// # Returns
671    ///
672    /// A new tensor with the scalar divided by each element
673    fn div(self, tensor: Tensor) -> Tensor {
674        // For scalar / tensor, we need to compute scalar / each element
675        // This is equivalent to: scalar * (1 / tensor)
676        tensor.pow_scalar(-1.0).mul_scalar(self)
677    }
678}
679
680impl Div<&Tensor> for f32 {
681    type Output = Tensor;
682
683    /// Divides a scalar by each element of the tensor (reference version)
684    ///
685    /// # Returns
686    ///
687    /// A new tensor with the scalar divided by each element
688    fn div(self, tensor: &Tensor) -> Tensor {
689        // For scalar / tensor, we need to compute scalar / each element
690        tensor.pow_scalar(-1.0).mul_scalar(self)
691    }
692}
693
694/// Tensor-scalar division assignment operator implementations
695///
696/// Provides in-place division operations between tensors and scalars.
697impl DivAssign<f32> for Tensor {
698    /// Divides each element of this tensor by a scalar in-place
699    fn div_assign(&mut self, scalar: f32) {
700        *self = self.div_scalar(scalar);
701    }
702}
703
704// ===== Negation =====
705
706use std::ops::Neg;
707
708/// Tensor negation operator implementations
709///
710/// Provides unary negation operations for tensors.
711/// All implementations delegate to the underlying `mul_scalar` method with -1.0.
712impl Neg for Tensor {
713    type Output = Tensor;
714
715    /// Negates each element of the tensor
716    ///
717    /// # Returns
718    ///
719    /// A new tensor with each element negated
720    fn neg(self) -> Tensor {
721        self.mul_scalar(-1.0)
722    }
723}
724
725impl Neg for &Tensor {
726    type Output = Tensor;
727
728    /// Negates each element of the tensor (reference version)
729    ///
730    /// # Returns
731    ///
732    /// A new tensor with each element negated
733    fn neg(self) -> Tensor {
734        self.mul_scalar(-1.0)
735    }
736}
737
#[cfg(test)]
mod tests {
    use super::*;

    // ===== Test helpers =====

    /// Builds a 2x2 tensor from four values given in row-major order.
    fn tensor_2x2(values: [f32; 4]) -> Tensor {
        Tensor::from_slice(&values, vec![2, 2]).unwrap()
    }

    /// Asserts that every element of a 2x2 tensor equals `expected`
    /// (row-major order), so no operator variant is verified on a single
    /// element only.
    #[track_caller]
    fn assert_all(actual: &Tensor, expected: [f32; 4]) {
        for (flat, want) in expected.iter().enumerate() {
            let (row, col) = (flat / 2, flat % 2);
            assert_eq!(
                actual.get(&[row, col]),
                *want,
                "unexpected value at [{}, {}]",
                row,
                col
            );
        }
    }

    /// Asserts the result of `10.0 / [1, 2, 3, 4]`.
    ///
    /// The element `10 / 3` is compared with a tolerance because scalar / tensor
    /// is computed through a reciprocal; the remaining elements are exact.
    #[track_caller]
    fn assert_ten_divided_by_one_to_four(actual: &Tensor) {
        assert_eq!(actual.get(&[0, 0]), 10.0);
        assert_eq!(actual.get(&[0, 1]), 5.0);
        assert!((actual.get(&[1, 0]) - (10.0 / 3.0)).abs() < 1e-6);
        assert_eq!(actual.get(&[1, 1]), 2.5);
    }

    // ===== Operator Overloading Tests =====

    /// Test tensor addition operator overloading
    ///
    /// Verifies every element for all tensor addition operator combinations:
    /// Tensor + Tensor, &Tensor + &Tensor, Tensor + &Tensor, &Tensor + Tensor,
    /// and the assignment operators (+=).
    #[test]
    fn test_tensor_addition_operators() {
        let a = tensor_2x2([1.0, 2.0, 3.0, 4.0]);
        let b = tensor_2x2([5.0, 6.0, 7.0, 8.0]);
        let expected = [6.0, 8.0, 10.0, 12.0];

        // Tensor + Tensor
        assert_all(&(a.clone() + b.clone()), expected);

        // &Tensor + &Tensor
        assert_all(&(&a + &b), expected);

        // Tensor + &Tensor
        assert_all(&(a.clone() + &b), expected);

        // &Tensor + Tensor
        assert_all(&(&a + b.clone()), expected);

        // Tensor += Tensor
        let mut c = a.clone();
        c += b.clone();
        assert_all(&c, expected);

        // Tensor += &Tensor
        let mut c = a.clone();
        c += &b;
        assert_all(&c, expected);
    }

    /// Test tensor subtraction operator overloading
    ///
    /// Verifies every element for all owned/borrowed operand combinations of
    /// `-` and for both `-=` variants.
    #[test]
    fn test_tensor_subtraction_operators() {
        let a = tensor_2x2([5.0, 6.0, 7.0, 8.0]);
        let b = tensor_2x2([1.0, 2.0, 3.0, 4.0]);
        let expected = [4.0, 4.0, 4.0, 4.0];

        // Tensor - Tensor
        assert_all(&(a.clone() - b.clone()), expected);

        // &Tensor - &Tensor
        assert_all(&(&a - &b), expected);

        // Tensor - &Tensor
        assert_all(&(a.clone() - &b), expected);

        // &Tensor - Tensor
        assert_all(&(&a - b.clone()), expected);

        // Tensor -= Tensor
        let mut c = a.clone();
        c -= b.clone();
        assert_all(&c, expected);

        // Tensor -= &Tensor
        let mut c = a.clone();
        c -= &b;
        assert_all(&c, expected);
    }

    /// Test tensor multiplication operator overloading
    ///
    /// Verifies every element for all owned/borrowed operand combinations of
    /// `*` and for both `*=` variants.
    #[test]
    fn test_tensor_multiplication_operators() {
        let a = tensor_2x2([1.0, 2.0, 3.0, 4.0]);
        let b = tensor_2x2([2.0, 3.0, 4.0, 5.0]);
        let expected = [2.0, 6.0, 12.0, 20.0];

        // Tensor * Tensor
        assert_all(&(a.clone() * b.clone()), expected);

        // &Tensor * &Tensor
        assert_all(&(&a * &b), expected);

        // Tensor * &Tensor
        assert_all(&(a.clone() * &b), expected);

        // &Tensor * Tensor
        assert_all(&(&a * b.clone()), expected);

        // Tensor *= Tensor
        let mut c = a.clone();
        c *= b.clone();
        assert_all(&c, expected);

        // Tensor *= &Tensor
        let mut c = a.clone();
        c *= &b;
        assert_all(&c, expected);
    }

    /// Test tensor division operator overloading
    ///
    /// Verifies every element for all owned/borrowed operand combinations of
    /// `/` and for both `/=` variants.
    #[test]
    fn test_tensor_division_operators() {
        let a = tensor_2x2([10.0, 20.0, 30.0, 40.0]);
        let b = tensor_2x2([2.0, 4.0, 5.0, 8.0]);
        let expected = [5.0, 5.0, 6.0, 5.0];

        // Tensor / Tensor
        assert_all(&(a.clone() / b.clone()), expected);

        // &Tensor / &Tensor
        assert_all(&(&a / &b), expected);

        // Tensor / &Tensor
        assert_all(&(a.clone() / &b), expected);

        // &Tensor / Tensor
        assert_all(&(&a / b.clone()), expected);

        // Tensor /= Tensor
        let mut c = a.clone();
        c /= b.clone();
        assert_all(&c, expected);

        // Tensor /= &Tensor
        let mut c = a.clone();
        c /= &b;
        assert_all(&c, expected);
    }

    /// Test tensor-scalar and scalar-tensor operator overloading
    ///
    /// Verifies every element for each of `+`, `-`, `*`, `/` with the scalar on
    /// either side, for owned and borrowed tensors, and for the scalar
    /// assignment operators.
    #[test]
    fn test_scalar_operations() {
        let a = tensor_2x2([1.0, 2.0, 3.0, 4.0]);

        // ----- Addition -----
        let sum = [6.0, 7.0, 8.0, 9.0];

        // Tensor + f32
        assert_all(&(a.clone() + 5.0), sum);

        // &Tensor + f32
        assert_all(&(&a + 5.0), sum);

        // f32 + Tensor
        assert_all(&(5.0 + a.clone()), sum);

        // f32 + &Tensor
        assert_all(&(5.0 + &a), sum);

        // Tensor += f32
        let mut b = a.clone();
        b += 5.0;
        assert_all(&b, sum);

        // ----- Subtraction -----
        let tensor_minus_two = [-1.0, 0.0, 1.0, 2.0];
        let ten_minus_tensor = [9.0, 8.0, 7.0, 6.0];

        // Tensor - f32
        assert_all(&(a.clone() - 2.0), tensor_minus_two);

        // &Tensor - f32
        assert_all(&(&a - 2.0), tensor_minus_two);

        // f32 - Tensor
        assert_all(&(10.0 - a.clone()), ten_minus_tensor);

        // f32 - &Tensor
        assert_all(&(10.0 - &a), ten_minus_tensor);

        // The borrowed operand must be left untouched by `f32 - &Tensor`
        assert_all(&a, [1.0, 2.0, 3.0, 4.0]);

        // Tensor -= f32
        let mut b = a.clone();
        b -= 2.0;
        assert_all(&b, tensor_minus_two);

        // ----- Multiplication -----
        let tripled = [3.0, 6.0, 9.0, 12.0];

        // Tensor * f32
        assert_all(&(a.clone() * 3.0), tripled);

        // &Tensor * f32
        assert_all(&(&a * 3.0), tripled);

        // f32 * Tensor
        assert_all(&(3.0 * a.clone()), tripled);

        // f32 * &Tensor
        assert_all(&(3.0 * &a), tripled);

        // Tensor *= f32
        let mut b = a.clone();
        b *= 3.0;
        assert_all(&b, tripled);

        // ----- Division -----
        let halved = [0.5, 1.0, 1.5, 2.0];

        // Tensor / f32
        assert_all(&(a.clone() / 2.0), halved);

        // &Tensor / f32
        assert_all(&(&a / 2.0), halved);

        // f32 / Tensor
        assert_ten_divided_by_one_to_four(&(10.0 / a.clone()));

        // f32 / &Tensor
        assert_ten_divided_by_one_to_four(&(10.0 / &a));

        // Tensor /= f32
        let mut b = a.clone();
        b /= 2.0;
        assert_all(&b, halved);
    }

    /// Test unary negation operator overloading
    ///
    /// Verifies every element for `-Tensor` and `-&Tensor`, using an input with
    /// mixed signs.
    #[test]
    fn test_negation_operator() {
        let a = tensor_2x2([1.0, -2.0, 3.0, -4.0]);
        let expected = [-1.0, 2.0, -3.0, 4.0];

        // -Tensor
        assert_all(&(-a.clone()), expected);

        // -&Tensor
        assert_all(&(-&a), expected);
    }

    /// Test complex operator chaining and compound expressions
    ///
    /// Verifies that complex mathematical expressions with multiple operators
    /// work correctly, including parentheses and operator precedence.
    #[test]
    fn test_operator_chaining() {
        let a = tensor_2x2([1.0, 2.0, 3.0, 4.0]);
        let b = tensor_2x2([2.0, 3.0, 4.0, 5.0]);

        // (1+2)*2-1 = 5, (2+3)*2-1 = 9, (3+4)*2-1 = 13, (4+5)*2-1 = 17
        let expected = [5.0, 9.0, 13.0, 17.0];

        // Complex expression: (a + b) * 2 - 1
        assert_all(&((a.clone() + b.clone()) * 2.0 - 1.0), expected);

        // With references: (&a + &b) * 2 - 1
        assert_all(&((&a + &b) * 2.0 - 1.0), expected);
    }
}