tensor_operators/
tensor_operators.rs

//! Tensor Operators Example
//!
//! This example demonstrates Rust operator overloading for tensors in Train Station:
//! - Tensor-tensor operations (+, -, *, /)
//! - Tensor-scalar operations (+, -, *, /)
//! - Compound assignment operators with scalar right-hand sides (+=, -=, *=, /=)
//! - Operator chaining and complex expressions
//! - Broadcasting behavior
//! - Equivalence between operators and method calls
//!
//! # Learning Objectives
//!
//! - Understand how Train Station implements Rust operator overloading
//! - Learn to use natural mathematical expressions with tensors
//! - Explore tensor broadcasting and shape compatibility
//! - Compare operator syntax with explicit method calls
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see tensor_basics.rs)
//! - Familiarity with operator overloading concepts
//!
//! # Usage
//!
//! ```bash
//! cargo run --example tensor_operators
//! ```
29
30use train_station::Tensor;
31
32fn main() {
33    println!("=== Tensor Operators Example ===\n");
34
35    demonstrate_basic_operators();
36    demonstrate_scalar_operators();
37    demonstrate_operator_assignment();
38    demonstrate_operator_chaining();
39    demonstrate_broadcasting();
40    demonstrate_method_equivalence();
41
42    println!("\n=== Example completed successfully! ===");
43}
44
45/// Demonstrate basic tensor-tensor operators
46fn demonstrate_basic_operators() {
47    println!("--- Basic Tensor-Tensor Operators ---");
48
49    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
50    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
51
52    println!("Tensor A: {:?}", a.data());
53    println!("Tensor B: {:?}", b.data());
54
55    // Addition
56    let c = &a + &b;
57    println!("A + B: {:?}", c.data());
58
59    // Subtraction
60    let d = &a - &b;
61    println!("A - B: {:?}", d.data());
62
63    // Multiplication
64    let e = &a * &b;
65    println!("A * B: {:?}", e.data());
66
67    // Division
68    let f = &a / &b;
69    println!("A / B: {:?}", f.data());
70}
71
72/// Demonstrate tensor-scalar operators
73fn demonstrate_scalar_operators() {
74    println!("\n--- Tensor-Scalar Operators ---");
75
76    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
77    println!("Original tensor: {:?}", tensor.data());
78
79    // Tensor + scalar
80    let result1 = &tensor + 5.0;
81    println!("Tensor + 5.0: {:?}", result1.data());
82
83    // Scalar + tensor
84    let result2 = 5.0 + &tensor;
85    println!("5.0 + Tensor: {:?}", result2.data());
86
87    // Tensor - scalar
88    let result3 = &tensor - 2.0;
89    println!("Tensor - 2.0: {:?}", result3.data());
90
91    // Tensor * scalar
92    let result4 = &tensor * 3.0;
93    println!("Tensor * 3.0: {:?}", result4.data());
94
95    // Scalar * tensor
96    let result5 = 3.0 * &tensor;
97    println!("3.0 * Tensor: {:?}", result5.data());
98
99    // Tensor / scalar
100    let result6 = &tensor / 2.0;
101    println!("Tensor / 2.0: {:?}", result6.data());
102}
103
104/// Demonstrate assignment operators
105fn demonstrate_operator_assignment() {
106    println!("\n--- Assignment Operators ---");
107
108    let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109    println!("Original tensor: {:?}", tensor.data());
110
111    // In-place addition
112    tensor += 5.0;
113    println!("After += 5.0: {:?}", tensor.data());
114
115    // In-place subtraction
116    tensor -= 2.0;
117    println!("After -= 2.0: {:?}", tensor.data());
118
119    // In-place multiplication
120    tensor *= 3.0;
121    println!("After *= 3.0: {:?}", tensor.data());
122
123    // In-place division
124    tensor /= 2.0;
125    println!("After /= 2.0: {:?}", tensor.data());
126}
127
128/// Demonstrate operator chaining and complex expressions
129fn demonstrate_operator_chaining() {
130    println!("\n--- Operator Chaining ---");
131
132    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
133    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
134    let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
135
136    println!("Tensor A: {:?}", a.data());
137    println!("Tensor B: {:?}", b.data());
138    println!("Tensor C: {:?}", c.data());
139
140    // Complex expression: (A + B) * C - 5
141    let result = (&a + &b) * &c - 5.0;
142    println!("(A + B) * C - 5: {:?}", result.data());
143
144    // Another complex expression: A * 2 + B / 2
145    let result2 = &a * 2.0 + &b / 2.0;
146    println!("A * 2 + B / 2: {:?}", result2.data());
147
148    // Negation and addition: -A + B * C
149    let result3 = -&a + &b * &c;
150    println!("-A + B * C: {:?}", result3.data());
151
152    // Division with parentheses: (A + B) / (C - 1)
153    let result4 = (&a + &b) / (&c - 1.0);
154    println!("(A + B) / (C - 1): {:?}", result4.data());
155}
156
157/// Demonstrate broadcasting behavior
158fn demonstrate_broadcasting() {
159    println!("\n--- Broadcasting ---");
160
161    // 2D tensor
162    let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
163    println!(
164        "2D tensor: shape {:?}, data: {:?}",
165        tensor_2d.shape().dims,
166        tensor_2d.data()
167    );
168
169    // 1D tensor (will be broadcasted)
170    let tensor_1d = Tensor::from_slice(&[10.0, 20.0], vec![2]).unwrap();
171    println!(
172        "1D tensor: shape {:?}, data: {:?}",
173        tensor_1d.shape().dims,
174        tensor_1d.data()
175    );
176
177    // Broadcasting addition
178    let broadcast_sum = &tensor_2d + &tensor_1d;
179    println!(
180        "Broadcast sum: shape {:?}, data: {:?}",
181        broadcast_sum.shape().dims,
182        broadcast_sum.data()
183    );
184
185    // Broadcasting multiplication
186    let broadcast_mul = &tensor_2d * &tensor_1d;
187    println!(
188        "Broadcast multiplication: shape {:?}, data: {:?}",
189        broadcast_mul.shape().dims,
190        broadcast_mul.data()
191    );
192
193    // Broadcasting with scalar
194    let broadcast_scalar = &tensor_2d + 100.0;
195    println!(
196        "Broadcast scalar: shape {:?}, data: {:?}",
197        broadcast_scalar.shape().dims,
198        broadcast_scalar.data()
199    );
200}
201
202/// Demonstrate equivalence between operators and method calls
203fn demonstrate_method_equivalence() {
204    println!("\n--- Operator vs Method Call Equivalence ---");
205
206    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209    // Addition: operator vs method
210    let operator_result = &a + &b;
211    let method_result = a.add_tensor(&b);
212
213    println!("A + B (operator): {:?}", operator_result.data());
214    println!("A.add_tensor(B): {:?}", method_result.data());
215    println!(
216        "Results are equal: {}",
217        operator_result.data() == method_result.data()
218    );
219
220    // Multiplication: operator vs method
221    let operator_result = &a * &b;
222    let method_result = a.mul_tensor(&b);
223
224    println!("A * B (operator): {:?}", operator_result.data());
225    println!("A.mul_tensor(B): {:?}", method_result.data());
226    println!(
227        "Results are equal: {}",
228        operator_result.data() == method_result.data()
229    );
230
231    // Scalar addition: operator vs method
232    let operator_result = &a + 5.0;
233    let method_result = a.add_scalar(5.0);
234
235    println!("A + 5.0 (operator): {:?}", operator_result.data());
236    println!("A.add_scalar(5.0): {:?}", method_result.data());
237    println!(
238        "Results are equal: {}",
239        operator_result.data() == method_result.data()
240    );
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tensor-tensor binary operator on small 1-D inputs.
    #[test]
    fn test_basic_operators() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        let sum = &a + &b;
        assert_eq!(sum.data(), &[4.0, 6.0]);

        let difference = &a - &b;
        assert_eq!(difference.data(), &[-2.0, -2.0]);

        let product = &a * &b;
        assert_eq!(product.data(), &[3.0, 8.0]);

        // Divisors are powers of two, so the quotients are exactly representable.
        let quotient = &b / &a;
        assert_eq!(quotient.data(), &[3.0, 2.0]);
    }

    /// Every tensor-scalar operator used by the example, including both
    /// commutative orders for `+` and `*`.
    #[test]
    fn test_scalar_operators() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();

        let result = &tensor + 5.0;
        assert_eq!(result.data(), &[6.0, 7.0]);

        let result = 5.0 + &tensor;
        assert_eq!(result.data(), &[6.0, 7.0]);

        let result = &tensor - 2.0;
        assert_eq!(result.data(), &[-1.0, 0.0]);

        let result = &tensor * 3.0;
        assert_eq!(result.data(), &[3.0, 6.0]);

        let result = 3.0 * &tensor;
        assert_eq!(result.data(), &[3.0, 6.0]);

        // Division by two is exact in binary floating point.
        let result = &tensor / 2.0;
        assert_eq!(result.data(), &[0.5, 1.0]);
    }

    /// All four compound assignment operators, applied cumulatively.
    #[test]
    fn test_assignment_operators() {
        let mut tensor = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();

        tensor += 3.0;
        assert_eq!(tensor.data(), &[4.0, 5.0]);

        tensor *= 2.0;
        assert_eq!(tensor.data(), &[8.0, 10.0]);

        tensor -= 2.0;
        assert_eq!(tensor.data(), &[6.0, 8.0]);

        tensor /= 2.0;
        assert_eq!(tensor.data(), &[3.0, 4.0]);
    }

    /// Chained expressions, including unary negation of a borrowed tensor.
    #[test]
    fn test_operator_chaining() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        let result = (&a + &b) * 2.0;
        assert_eq!(result.data(), &[8.0, 12.0]);

        // -[1, 2] + [9, 16]
        let result = -&a + &b * &b;
        assert_eq!(result.data(), &[8.0, 14.0]);
    }

    /// Operators must agree exactly with their named method counterparts.
    #[test]
    fn test_operator_method_equivalence() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        assert_eq!((&a + &b).data(), a.add_tensor(&b).data());
        assert_eq!((&a * &b).data(), a.mul_tensor(&b).data());
        assert_eq!((&a + 5.0).data(), a.add_scalar(5.0).data());
    }
}