1use train_station::Tensor;
28
29fn main() {
30 println!("=== Tensor Basics Example ===\n");
31
32 demonstrate_tensor_creation();
33 demonstrate_basic_operations();
34 demonstrate_shape_operations();
35 demonstrate_data_access();
36 demonstrate_utility_functions();
37
38 println!("\n=== Example completed successfully! ===");
39}
40
41fn demonstrate_tensor_creation() {
43 println!("--- Tensor Creation ---");
44
45 let zeros = Tensor::zeros(vec![2, 3]);
47 println!(
48 "Zeros tensor: shape {:?}, data: {:?}",
49 zeros.shape().dims(),
50 zeros.data()
51 );
52
53 let ones = Tensor::ones(vec![3, 2]);
54 println!(
55 "Ones tensor: shape {:?}, data: {:?}",
56 ones.shape().dims(),
57 ones.data()
58 );
59
60 let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62 let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63 println!(
64 "From slice: shape {:?}, data: {:?}",
65 from_slice.shape().dims(),
66 from_slice.data()
67 );
68
69 let mut filled = Tensor::new(vec![2, 2]);
71 {
72 let data = filled.data_mut();
73 for value in data.iter_mut() {
74 *value = 42.0;
75 }
76 }
77 println!("Filled with 42: {:?}", filled.data());
78
79 let random = Tensor::randn(vec![2, 2], Some(42));
81 println!(
82 "Random tensor: shape {:?}, data: {:?}",
83 random.shape().dims(),
84 random.data()
85 );
86}
87
88fn demonstrate_basic_operations() {
90 println!("\n--- Basic Operations ---");
91
92 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95 let sum = a.add_tensor(&b);
97 println!("A + B: {:?}", sum.data());
98
99 let diff = a.sub_tensor(&b);
101 println!("A - B: {:?}", diff.data());
102
103 let product = a.mul_tensor(&b);
105 println!("A * B: {:?}", product.data());
106
107 let quotient = a.div_tensor(&b);
109 println!("A / B: {:?}", quotient.data());
110
111 let scalar_add = a.add_scalar(5.0);
113 println!("A + 5.0: {:?}", scalar_add.data());
114
115 let scalar_mul = a.mul_scalar(2.0);
116 println!("A * 2.0: {:?}", scalar_mul.data());
117}
118
119fn demonstrate_shape_operations() {
121 println!("\n--- Shape Operations ---");
122
123 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
124 println!(
125 "Original: shape {:?}, data: {:?}",
126 tensor.shape().dims(),
127 tensor.data()
128 );
129
130 let reshaped = tensor.view(vec![3, 2]);
132 println!(
133 "Reshaped to [3, 2]: shape {:?}, data: {:?}",
134 reshaped.shape().dims(),
135 reshaped.data()
136 );
137
138 let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
140 println!(
141 "2D tensor: shape {:?}, data: {:?}",
142 tensor_2d.shape().dims(),
143 tensor_2d.data()
144 );
145
146 let tensor_1d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
148 println!(
149 "1D tensor: shape {:?}, data: {:?}",
150 tensor_1d.shape().dims(),
151 tensor_1d.data()
152 );
153}
154
155fn demonstrate_data_access() {
157 println!("\n--- Data Access ---");
158
159 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
160
161 println!("Element [0, 0]: {}", tensor.get(&[0, 0]));
163 println!("Element [0, 1]: {}", tensor.get(&[0, 1]));
164 println!("Element [1, 0]: {}", tensor.get(&[1, 0]));
165 println!("Element [1, 1]: {}", tensor.get(&[1, 1]));
166
167 let data = tensor.data();
169 println!("Data as slice: {:?}", data);
170
171 println!("Elements:");
173 for (i, &value) in data.iter().enumerate() {
174 println!(" [{}]: {}", i, value);
175 }
176}
177
178fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-filled construction reports the requested shape, size, and data.
    #[test]
    fn test_tensor_creation() {
        let t = Tensor::zeros(vec![2, 3]);
        assert_eq!(t.shape().dims(), vec![2, 3]);
        assert_eq!(t.size(), 6);
        assert_eq!(t.data(), &[0.0; 6]);
    }

    /// Elementwise and scalar addition produce the expected values.
    #[test]
    fn test_basic_operations() {
        let lhs = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let rhs = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        assert_eq!(lhs.add_tensor(&rhs).data(), &[4.0, 6.0]);
        assert_eq!(lhs.add_scalar(5.0).data(), &[6.0, 7.0]);
    }

    /// view() changes the reported shape while the data stays the same.
    #[test]
    fn test_shape_operations() {
        let square = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let flattened = square.view(vec![4]);

        assert_eq!(flattened.shape().dims(), vec![4]);
        assert_eq!(flattened.data(), square.data());
    }

    /// get() reads elements in row-major order.
    #[test]
    fn test_data_access() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        for (expected, index) in [(1.0, [0, 0]), (2.0, [0, 1]), (3.0, [1, 0]), (4.0, [1, 1])] {
            assert_eq!(t.get(&index), expected);
        }
    }
}