//! Element iteration example for the `train_station` tensor library.

use train_station::Tensor;
57fn main() -> Result<(), Box<dyn std::error::Error>> {
62 println!("Starting Element Iteration Example");
63
64 demonstrate_basic_iteration()?;
65 demonstrate_standard_methods()?;
66 demonstrate_gradient_tracking()?;
67 demonstrate_advanced_patterns()?;
68
69 println!("Element Iteration Example completed successfully!");
70 Ok(())
71}
72
73fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
78 println!("\n--- Basic Element Iteration ---");
79
80 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
82 println!("Original tensor: {:?}", tensor.data());
83
84 println!("\nBasic iteration with for loop:");
86 for (i, element) in tensor.iter().enumerate() {
87 println!(
88 " Element {}: value = {:.1}, shape = {:?}",
89 i,
90 element.value(),
91 element.shape().dims
92 );
93 }
94
95 println!("\nElement-wise transformation (2x + 1):");
97 let transformed: Tensor = tensor
98 .iter()
99 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
100 .collect();
101 println!(" Result: {:?}", transformed.data());
102
103 println!("\nFiltering elements (values > 3.0):");
105 let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
106 println!(" Filtered: {:?}", filtered.data());
107
108 Ok(())
109}
110
111fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
116 println!("\n--- Standard Iterator Methods ---");
117
118 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
119
120 println!("\nMap transformation (square each element):");
122 let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
123 println!(" Squared: {:?}", squared.data());
124
125 println!("\nEnumerate with indexed operations:");
127 let indexed: Tensor = tensor
128 .iter()
129 .enumerate()
130 .map(|(i, elem)| elem.add_scalar(i as f32))
131 .collect();
132 println!(" Indexed: {:?}", indexed.data());
133
134 println!("\nFold for sum calculation:");
136 let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
137 println!(" Sum: {:.1}", sum);
138
139 println!("\nFind specific element:");
141 if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
142 println!(" Found element with value 3.0: {:.1}", found.value());
143 }
144
145 println!("\nCondition checking:");
147 let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
148 let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
149 println!(" All positive: {}", all_positive);
150 println!(" Any > 4.0: {}", any_large);
151
152 Ok(())
153}
154
155fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
160 println!("\n--- Gradient Tracking ---");
161
162 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
164 println!("Input tensor (requires_grad): {:?}", tensor.data());
165
166 let result: Tensor = tensor
168 .iter()
169 .map(|elem| {
170 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
172 })
173 .collect();
174
175 println!("Result tensor: {:?}", result.data());
176 println!("Result requires_grad: {}", result.requires_grad());
177
178 let mut loss = result.sum();
180 loss.backward(None);
181
182 println!("Loss: {:.6}", loss.value());
183 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
184
185 Ok(())
186}
187
188fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
193 println!("\n--- Advanced Iterator Patterns ---");
194
195 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
196 println!("Input tensor: {:?}", tensor.data());
197
198 println!("\nComplex chain (even indices only, add index to value):");
200 let result: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, _)| i % 2 == 0) .map(|(i, elem)| elem.add_scalar(i as f32)) .collect();
206 println!(" Result: {:?}", result.data());
207
208 println!("\nWindowing with take and skip:");
210 let window1: Tensor = tensor.iter().take(3).collect();
211 let window2: Tensor = tensor.iter().skip(2).take(3).collect();
212 println!(" Window 1 (first 3): {:?}", window1.data());
213 println!(" Window 2 (middle 3): {:?}", window2.data());
214
215 println!("\nReverse iteration:");
217 let reversed: Tensor = tensor.iter().rev().collect();
218 println!(" Reversed: {:?}", reversed.data());
219
220 println!("\nMathematical operation chain:");
222 let math_result: Tensor = tensor
223 .iter()
224 .map(|elem| elem.exp()) .filter(|elem| elem.value() < 50.0) .map(|elem| elem.log()) .collect();
228 println!(" Math chain result: {:?}", math_result.data());
229
230 println!("\nElement-wise combination with zip:");
232 let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
233 let combined: Tensor = tensor
234 .iter()
235 .zip(tensor2.iter())
236 .map(|(a, b)| a.mul_tensor(&b)) .collect();
238 println!(" Combined: {:?}", combined.data());
239
240 Ok(())
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Iterating a tensor yields one element per entry, in order.
    #[test]
    fn test_basic_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let elements: Vec<Tensor> = tensor.iter().collect();

        assert_eq!(elements.len(), 3);
        assert_eq!(elements[0].value(), 1.0);
        assert_eq!(elements[1].value(), 2.0);
        assert_eq!(elements[2].value(), 3.0);
    }

    /// A map over elements collected into a Tensor applies element-wise.
    #[test]
    fn test_element_transformation() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let doubled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
    }

    /// Collecting mapped elements preserves the requires_grad flag.
    #[test]
    fn test_gradient_tracking() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();

        let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        assert!(result.requires_grad());
        assert_eq!(result.data(), &[2.0, 4.0]);
    }
}