use train_station::tensor::{TensorCollectExt, ValuesCollectExt};
use train_station::{gradtrack::with_no_grad, Tensor};

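/// Runs each demonstration in turn, propagating the first error encountered.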
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Starting Element Iteration Example");

    demonstrate_basic_iteration()?;
    demonstrate_standard_methods()?;
    demonstrate_gradient_tracking()?;
    demonstrate_advanced_patterns()?;
    demonstrate_row_wise_collect_shape()?;
    demonstrate_nograd_and_streaming()?;

    println!("Element Iteration Example completed successfully!");
    Ok(())
}

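/// Demonstrates the basics: a `for` loop over `tensor.iter()`, an
/// element-wise transformation collected back into a `Tensor`, and
/// value-based filtering.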
fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Basic Element Iteration ---");

    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
    println!("Original tensor: {:?}", tensor.data());

    println!("\nBasic iteration with for loop:");
    for (i, element) in tensor.iter().enumerate() {
        println!(
            "  Element {}: value = {:.1}, shape = {:?}",
            i,
            element.value(),
            element.shape().dims()
        );
    }

    println!("\nElement-wise transformation (2x + 1):");
    let transformed: Tensor = tensor
        .iter()
        .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
        .collect();
    println!("  Result: {:?}", transformed.data());

    println!("\nFiltering elements (values > 3.0):");
    let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
    println!("  Filtered: {:?}", filtered.data());

    Ok(())
}

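/// Applies standard `Iterator` adapters (`map`, `enumerate`, `fold`, `find`,
/// `all`, `any`) to tensor elements.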
fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Standard Iterator Methods ---");

    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;

    println!("\nMap transformation (square each element):");
    let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
    println!("  Squared: {:?}", squared.data());

    println!("\nEnumerate with indexed operations:");
    let indexed: Tensor = tensor
        .iter()
        .enumerate()
        .map(|(i, elem)| elem.add_scalar(i as f32))
        .collect();
    println!("  Indexed: {:?}", indexed.data());

    println!("\nFold for sum calculation:");
    let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
    println!("  Sum: {:.1}", sum);

    println!("\nFind specific element:");
    if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
        println!("  Found element with value 3.0: {:.1}", found.value());
    }

    println!("\nCondition checking:");
    let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
    let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
    println!("  All positive: {}", all_positive);
    println!("  Any > 4.0: {}", any_large);

    Ok(())
}

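/// Shows that iterator-based element operations keep gradient tracking
/// intact: the collected result still supports `backward`.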
fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Gradient Tracking ---");

    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
    println!("Input tensor (requires_grad): {:?}", tensor.data());

    // Each element flows through 2 * (x^2 + 1); the collected tensor keeps
    // the operation history for backpropagation.
    let result: Tensor = tensor
        .iter()
        .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0))
        .collect();

    println!("Result tensor: {:?}", result.data());
    println!("Result requires_grad: {}", result.requires_grad());

    // Reduce to a scalar loss and backpropagate through the element graph.
    let mut loss = result.sum();
    loss.backward(None);

    println!("Loss: {:.6}", loss.value());
    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));

    Ok(())
}

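/// Combines longer adapter chains: filtered enumeration, `take`/`skip`
/// windows, reverse iteration, a multi-step math chain, and `zip`.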
fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Advanced Iterator Patterns ---");

    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
    println!("Input tensor: {:?}", tensor.data());

    println!("\nComplex chain (even indices only, add index to value):");
    let result: Tensor = tensor
        .iter()
        .enumerate()
        .filter(|(i, _)| i % 2 == 0)
        .map(|(i, elem)| elem.add_scalar(i as f32))
        .collect();
    println!("  Result: {:?}", result.data());

    println!("\nWindowing with take and skip:");
    let window1: Tensor = tensor.iter().take(3).collect();
    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
    println!("  Window 1 (first 3): {:?}", window1.data());
    println!("  Window 2 (middle 3): {:?}", window2.data());

    println!("\nReverse iteration:");
    let reversed: Tensor = tensor.iter().rev().collect();
    println!("  Reversed: {:?}", reversed.data());

    println!("\nMathematical operation chain:");
    let math_result: Tensor = tensor
        .iter()
        .map(|elem| elem.exp())
        .filter(|elem| elem.value() < 50.0)
        .map(|elem| elem.log())
        .collect();
    println!("  Math chain result: {:?}", math_result.data());

    println!("\nElement-wise combination with zip:");
    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
    let combined: Tensor = tensor
        .iter()
        .zip(tensor2.iter())
        .map(|(a, b)| a.mul_tensor(&b))
        .collect();
    println!("  Combined: {:?}", combined.data());

    Ok(())
}

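/// Uses `collect_shape` to give the collected output an explicit shape
/// instead of the default flat layout.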
fn demonstrate_row_wise_collect_shape() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Row-wise iteration with collect_shape ---");
    let mat = Tensor::from_slice(&(1..=12).map(|x| x as f32).collect::<Vec<_>>(), vec![3, 4])?;
    println!("Input shape: {:?}", mat.shape().dims());

    let out: Tensor = mat
        .iter()
        .map(|row| row.mul_scalar(1.1).add_scalar(0.5))
        .collect_shape(vec![3, 4]);
    println!("  Output shape: {:?}", out.shape().dims());

    Ok(())
}

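/// Contrasts two gradient-free paths inside `with_no_grad`: streaming raw
/// `f32` values from `data()` versus mapping over element views, then
/// verifies both produce identical results.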
fn demonstrate_nograd_and_streaming() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- NoGrad & Streaming (Inference Fast Paths) ---");

    let input = Tensor::from_slice(
        &(0..24).map(|i| i as f32 * 0.25).collect::<Vec<_>>(),
        vec![4, 6],
    )?;
    println!("Input shape: {:?}", input.shape().dims());

    // Fast path: stream raw f32 values straight from `data()`, apply the
    // affine map on plain floats, and collect into a shaped tensor.
    let out = with_no_grad(|| {
        input
            .data()
            .iter()
            .copied()
            .map(|x| 1.2 * x - 0.3)
            .collect_shape(vec![4, 6])
    });
    println!(
        "  NoGrad streamed map (1.2x-0.3) -> shape {:?}",
        out.shape().dims()
    );

    // Same computation through element views and tensor scalar ops.
    let out_view: Tensor = with_no_grad(|| {
        input
            .iter()
            .map(|e| e.mul_scalar(1.2).add_scalar(-0.3))
            .collect_shape(vec![4, 6])
    });
    println!(
        "  NoGrad view-based map shape {:?}",
        out_view.shape().dims()
    );

    // Both paths perform the same f32 arithmetic, so the outputs match exactly.
    assert_eq!(out.data(), out_view.data());
    println!("  Parity check passed.");

    // `collect_shape` can also reshape on the fly while streaming values.
    let reshaped = with_no_grad(|| input.data().iter().copied().collect_shape(vec![6, 4]));
    println!(
        "  Reshaped via streaming collect_shape: {:?}",
        reshaped.shape().dims()
    );

    Ok(())
}

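/// Unit tests covering the iteration, transformation, and gradient-tracking
/// behavior demonstrated above.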
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let elements: Vec<Tensor> = tensor.iter().collect();

        assert_eq!(elements.len(), 3);
        assert_eq!(elements[0].value(), 1.0);
        assert_eq!(elements[1].value(), 2.0);
        assert_eq!(elements[2].value(), 3.0);
    }

    #[test]
    fn test_element_transformation() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let doubled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_gradient_tracking() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();

        let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        assert!(result.requires_grad());
        assert_eq!(result.data(), &[2.0, 4.0]);
    }
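
    // A small extra check, sketched under the assumption that `collect_shape`
    // (from `TensorCollectExt`, used in the demonstrations above) collects
    // element tensors into the requested shape while preserving data order.
    #[test]
    fn test_collect_shape() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let reshaped: Tensor = tensor.iter().collect_shape(vec![2, 2]);

        // Element order is unchanged; only the shape metadata differs.
        assert_eq!(reshaped.data(), &[1.0, 2.0, 3.0, 4.0]);
    }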
}