1use train_station::Tensor;
31
32fn main() {
33 println!("=== Tensor Operators Example ===\n");
34
35 demonstrate_basic_operators();
36 demonstrate_scalar_operators();
37 demonstrate_operator_assignment();
38 demonstrate_operator_chaining();
39 demonstrate_broadcasting();
40 demonstrate_method_equivalence();
41
42 println!("\n=== Example completed successfully! ===");
43}
44
45fn demonstrate_basic_operators() {
47 println!("--- Basic Tensor-Tensor Operators ---");
48
49 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
50 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
51
52 println!("Tensor A: {:?}", a.data());
53 println!("Tensor B: {:?}", b.data());
54
55 let c = &a + &b;
57 println!("A + B: {:?}", c.data());
58
59 let d = &a - &b;
61 println!("A - B: {:?}", d.data());
62
63 let e = &a * &b;
65 println!("A * B: {:?}", e.data());
66
67 let f = &a / &b;
69 println!("A / B: {:?}", f.data());
70}
71
72fn demonstrate_scalar_operators() {
74 println!("\n--- Tensor-Scalar Operators ---");
75
76 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
77 println!("Original tensor: {:?}", tensor.data());
78
79 let result1 = &tensor + 5.0;
81 println!("Tensor + 5.0: {:?}", result1.data());
82
83 let result2 = 5.0 + &tensor;
85 println!("5.0 + Tensor: {:?}", result2.data());
86
87 let result3 = &tensor - 2.0;
89 println!("Tensor - 2.0: {:?}", result3.data());
90
91 let result4 = &tensor * 3.0;
93 println!("Tensor * 3.0: {:?}", result4.data());
94
95 let result5 = 3.0 * &tensor;
97 println!("3.0 * Tensor: {:?}", result5.data());
98
99 let result6 = &tensor / 2.0;
101 println!("Tensor / 2.0: {:?}", result6.data());
102}
103
104fn demonstrate_operator_assignment() {
106 println!("\n--- Assignment Operators ---");
107
108 let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109 println!("Original tensor: {:?}", tensor.data());
110
111 tensor += 5.0;
113 println!("After += 5.0: {:?}", tensor.data());
114
115 tensor -= 2.0;
117 println!("After -= 2.0: {:?}", tensor.data());
118
119 tensor *= 3.0;
121 println!("After *= 3.0: {:?}", tensor.data());
122
123 tensor /= 2.0;
125 println!("After /= 2.0: {:?}", tensor.data());
126}
127
128fn demonstrate_operator_chaining() {
130 println!("\n--- Operator Chaining ---");
131
132 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
133 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
134 let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
135
136 println!("Tensor A: {:?}", a.data());
137 println!("Tensor B: {:?}", b.data());
138 println!("Tensor C: {:?}", c.data());
139
140 let result = (&a + &b) * &c - 5.0;
142 println!("(A + B) * C - 5: {:?}", result.data());
143
144 let result2 = &a * 2.0 + &b / 2.0;
146 println!("A * 2 + B / 2: {:?}", result2.data());
147
148 let result3 = -&a + &b * &c;
150 println!("-A + B * C: {:?}", result3.data());
151
152 let result4 = (&a + &b) / (&c - 1.0);
154 println!("(A + B) / (C - 1): {:?}", result4.data());
155}
156
157fn demonstrate_broadcasting() {
159 println!("\n--- Broadcasting ---");
160
161 let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
163 println!(
164 "2D tensor: shape {:?}, data: {:?}",
165 tensor_2d.shape().dims,
166 tensor_2d.data()
167 );
168
169 let tensor_1d = Tensor::from_slice(&[10.0, 20.0], vec![2]).unwrap();
171 println!(
172 "1D tensor: shape {:?}, data: {:?}",
173 tensor_1d.shape().dims,
174 tensor_1d.data()
175 );
176
177 let broadcast_sum = &tensor_2d + &tensor_1d;
179 println!(
180 "Broadcast sum: shape {:?}, data: {:?}",
181 broadcast_sum.shape().dims,
182 broadcast_sum.data()
183 );
184
185 let broadcast_mul = &tensor_2d * &tensor_1d;
187 println!(
188 "Broadcast multiplication: shape {:?}, data: {:?}",
189 broadcast_mul.shape().dims,
190 broadcast_mul.data()
191 );
192
193 let broadcast_scalar = &tensor_2d + 100.0;
195 println!(
196 "Broadcast scalar: shape {:?}, data: {:?}",
197 broadcast_scalar.shape().dims,
198 broadcast_scalar.data()
199 );
200}
201
202fn demonstrate_method_equivalence() {
204 println!("\n--- Operator vs Method Call Equivalence ---");
205
206 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209 let operator_result = &a + &b;
211 let method_result = a.add_tensor(&b);
212
213 println!("A + B (operator): {:?}", operator_result.data());
214 println!("A.add_tensor(B): {:?}", method_result.data());
215 println!(
216 "Results are equal: {}",
217 operator_result.data() == method_result.data()
218 );
219
220 let operator_result = &a * &b;
222 let method_result = a.mul_tensor(&b);
223
224 println!("A * B (operator): {:?}", operator_result.data());
225 println!("A.mul_tensor(B): {:?}", method_result.data());
226 println!(
227 "Results are equal: {}",
228 operator_result.data() == method_result.data()
229 );
230
231 let operator_result = &a + 5.0;
233 let method_result = a.add_scalar(5.0);
234
235 println!("A + 5.0 (operator): {:?}", operator_result.data());
236 println!("A.add_scalar(5.0): {:?}", method_result.data());
237 println!(
238 "Results are equal: {}",
239 operator_result.data() == method_result.data()
240 );
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Element-wise tensor-tensor `+`, `-`, `*` and `/` on equally shaped inputs.
    #[test]
    fn test_basic_operators() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        let sum = &a + &b;
        assert_eq!(sum.data(), &[4.0, 6.0]);

        let difference = &a - &b;
        assert_eq!(difference.data(), &[-2.0, -2.0]);

        let product = &a * &b;
        assert_eq!(product.data(), &[3.0, 8.0]);

        // Operands chosen so the quotients are exactly representable.
        let quotient = &b / &a;
        assert_eq!(quotient.data(), &[3.0, 2.0]);
    }

    /// Tensor-scalar operators, with the scalar on either side where supported.
    #[test]
    fn test_scalar_operators() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();

        let result = &tensor + 5.0;
        assert_eq!(result.data(), &[6.0, 7.0]);

        let result = 5.0 + &tensor;
        assert_eq!(result.data(), &[6.0, 7.0]);

        let result = &tensor - 2.0;
        assert_eq!(result.data(), &[-1.0, 0.0]);

        let result = &tensor * 3.0;
        assert_eq!(result.data(), &[3.0, 6.0]);

        let result = 3.0 * &tensor;
        assert_eq!(result.data(), &[3.0, 6.0]);

        let result = &tensor / 2.0;
        assert_eq!(result.data(), &[0.5, 1.0]);
    }

    /// Compound assignment with scalars; each step builds on the previous value.
    #[test]
    fn test_assignment_operators() {
        let mut tensor = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();

        tensor += 3.0;
        assert_eq!(tensor.data(), &[4.0, 5.0]);

        tensor *= 2.0;
        assert_eq!(tensor.data(), &[8.0, 10.0]);

        tensor -= 2.0;
        assert_eq!(tensor.data(), &[6.0, 8.0]);

        tensor /= 2.0;
        assert_eq!(tensor.data(), &[3.0, 4.0]);
    }

    /// Chained expressions mixing owned intermediates, borrows and scalars.
    #[test]
    fn test_operator_chaining() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        let result = (&a + &b) * 2.0;
        assert_eq!(result.data(), &[8.0, 12.0]);

        // (a + b) / (b - 1) = [4, 6] / [2, 3]
        let result = (&a + &b) / (&b - 1.0);
        assert_eq!(result.data(), &[2.0, 2.0]);
    }

    /// Unary negation of a borrowed tensor, alone and inside a larger expression.
    #[test]
    fn test_negation() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        let negated = -&a;
        assert_eq!(negated.data(), &[-1.0, -2.0]);

        // -a + b * b = [-1 + 9, -2 + 16]
        let result = -&a + &b * &b;
        assert_eq!(result.data(), &[8.0, 14.0]);
    }

    /// A `[2]` tensor combined with a `[2, 2]` tensor, as in `demonstrate_broadcasting`.
    /// Expected values assume trailing-dimension (NumPy-style) broadcasting over
    /// row-major data.
    #[test]
    fn test_broadcasting() {
        let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let vector = Tensor::from_slice(&[10.0, 20.0], vec![2]).unwrap();

        let sum = &matrix + &vector;
        assert_eq!(&sum.shape().dims[..], &[2, 2]);
        assert_eq!(sum.data(), &[11.0, 22.0, 13.0, 24.0]);

        let product = &matrix * &vector;
        assert_eq!(&product.shape().dims[..], &[2, 2]);
        assert_eq!(product.data(), &[10.0, 40.0, 30.0, 80.0]);
    }

    /// Overloaded operators agree exactly with their named-method counterparts.
    #[test]
    fn test_method_equivalence() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        assert_eq!((&a + &b).data(), a.add_tensor(&b).data());
        assert_eq!((&a * &b).data(), a.mul_tensor(&b).data());
        assert_eq!((&a + 5.0).data(), a.add_scalar(5.0).data());
    }
}