tensor_basics/
tensor_basics.rs

//! Tensor Basics Example
//!
//! This example demonstrates fundamental tensor concepts in Train Station:
//! - Creating tensors with different initializations
//! - Basic arithmetic operations
//! - Shape manipulation and data access
//! - Utility functions and properties
//!
//! # Learning Objectives
//!
//! - Understand tensor creation and initialization
//! - Learn basic tensor operations and arithmetic
//! - Explore shape manipulation and data access patterns
//! - Discover utility functions for tensor analysis
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of multi-dimensional arrays
//!
//! # Usage
//!
//! ```bash
//! cargo run --example tensor_basics
//! ```
26
27use train_station::Tensor;
28
29fn main() {
30    println!("=== Tensor Basics Example ===\n");
31
32    demonstrate_tensor_creation();
33    demonstrate_basic_operations();
34    demonstrate_shape_operations();
35    demonstrate_data_access();
36    demonstrate_utility_functions();
37
38    println!("\n=== Example completed successfully! ===");
39}
40
41/// Demonstrate different ways to create tensors
42fn demonstrate_tensor_creation() {
43    println!("--- Tensor Creation ---");
44
45    // Create tensors with different initializations
46    let zeros = Tensor::zeros(vec![2, 3]);
47    println!(
48        "Zeros tensor: shape {:?}, data: {:?}",
49        zeros.shape().dims(),
50        zeros.data()
51    );
52
53    let ones = Tensor::ones(vec![3, 2]);
54    println!(
55        "Ones tensor: shape {:?}, data: {:?}",
56        ones.shape().dims(),
57        ones.data()
58    );
59
60    // Create tensor from slice
61    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62    let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63    println!(
64        "From slice: shape {:?}, data: {:?}",
65        from_slice.shape().dims(),
66        from_slice.data()
67    );
68
69    // Create tensor with specific value
70    let mut filled = Tensor::new(vec![2, 2]);
71    {
72        let data = filled.data_mut();
73        for value in data.iter_mut() {
74            *value = 42.0;
75        }
76    }
77    println!("Filled with 42: {:?}", filled.data());
78
79    // Create tensor with random data
80    let random = Tensor::randn(vec![2, 2], Some(42));
81    println!(
82        "Random tensor: shape {:?}, data: {:?}",
83        random.shape().dims(),
84        random.data()
85    );
86}
87
88/// Demonstrate basic arithmetic operations
89fn demonstrate_basic_operations() {
90    println!("\n--- Basic Operations ---");
91
92    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95    // Addition
96    let sum = a.add_tensor(&b);
97    println!("A + B: {:?}", sum.data());
98
99    // Subtraction
100    let diff = a.sub_tensor(&b);
101    println!("A - B: {:?}", diff.data());
102
103    // Multiplication
104    let product = a.mul_tensor(&b);
105    println!("A * B: {:?}", product.data());
106
107    // Division
108    let quotient = a.div_tensor(&b);
109    println!("A / B: {:?}", quotient.data());
110
111    // Scalar operations
112    let scalar_add = a.add_scalar(5.0);
113    println!("A + 5.0: {:?}", scalar_add.data());
114
115    let scalar_mul = a.mul_scalar(2.0);
116    println!("A * 2.0: {:?}", scalar_mul.data());
117}
118
119/// Demonstrate shape manipulation operations
120fn demonstrate_shape_operations() {
121    println!("\n--- Shape Operations ---");
122
123    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
124    println!(
125        "Original: shape {:?}, data: {:?}",
126        tensor.shape().dims(),
127        tensor.data()
128    );
129
130    // Reshape (view)
131    let reshaped = tensor.view(vec![3, 2]);
132    println!(
133        "Reshaped to [3, 2]: shape {:?}, data: {:?}",
134        reshaped.shape().dims(),
135        reshaped.data()
136    );
137
138    // Create a different shaped tensor for demonstration
139    let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
140    println!(
141        "2D tensor: shape {:?}, data: {:?}",
142        tensor_2d.shape().dims(),
143        tensor_2d.data()
144    );
145
146    // Create a 1D tensor
147    let tensor_1d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
148    println!(
149        "1D tensor: shape {:?}, data: {:?}",
150        tensor_1d.shape().dims(),
151        tensor_1d.data()
152    );
153}
154
155/// Demonstrate data access patterns
156fn demonstrate_data_access() {
157    println!("\n--- Data Access ---");
158
159    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
160
161    // Access individual elements
162    println!("Element [0, 0]: {}", tensor.get(&[0, 0]));
163    println!("Element [0, 1]: {}", tensor.get(&[0, 1]));
164    println!("Element [1, 0]: {}", tensor.get(&[1, 0]));
165    println!("Element [1, 1]: {}", tensor.get(&[1, 1]));
166
167    // Access data as slice
168    let data = tensor.data();
169    println!("Data as slice: {:?}", data);
170
171    // Iterate over elements
172    println!("Elements:");
173    for (i, &value) in data.iter().enumerate() {
174        println!("  [{}]: {}", i, value);
175    }
176}
177
178/// Demonstrate utility functions
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
208
/// Unit tests for the operations this example demonstrates.
///
/// Expected values are chosen to be exactly representable in binary floating
/// point, so exact `assert_eq!` comparisons are safe.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let tensor = Tensor::zeros(vec![2, 3]);
        assert_eq!(tensor.shape().dims(), vec![2, 3]);
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.data(), &[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_fill_via_data_mut() {
        // Mirrors the "Filled with 42" demo: overwrite every element in place.
        let mut tensor = Tensor::new(vec![2, 2]);
        tensor.data_mut().fill(42.0);
        assert_eq!(tensor.shape().dims(), vec![2, 2]);
        assert_eq!(tensor.data(), &[42.0, 42.0, 42.0, 42.0]);
    }

    #[test]
    fn test_basic_operations() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();

        let sum = a.add_tensor(&b);
        assert_eq!(sum.data(), &[4.0, 6.0]);

        let scalar_add = a.add_scalar(5.0);
        assert_eq!(scalar_add.data(), &[6.0, 7.0]);
    }

    #[test]
    fn test_elementwise_operations() {
        // Same operands as `demonstrate_basic_operations`.
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();

        assert_eq!(a.sub_tensor(&b).data(), &[-4.0, -4.0, -4.0, -4.0]);
        assert_eq!(a.mul_tensor(&b).data(), &[5.0, 12.0, 21.0, 32.0]);
        assert_eq!(a.mul_scalar(2.0).data(), &[2.0, 4.0, 6.0, 8.0]);

        // Power-of-two divisors keep every quotient exact, whether the library
        // divides directly or multiplies by a reciprocal.
        let num = Tensor::from_slice(&[2.0, 9.0], vec![2]).unwrap();
        let den = Tensor::from_slice(&[4.0, 2.0], vec![2]).unwrap();
        assert_eq!(num.div_tensor(&den).data(), &[0.5, 4.5]);
    }

    #[test]
    fn test_reductions() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        assert_eq!(tensor.sum().value(), 10.0);
        assert_eq!(tensor.mean().value(), 2.5);
    }

    #[test]
    fn test_shape_operations() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        let reshaped = tensor.view(vec![4]);
        assert_eq!(reshaped.shape().dims(), vec![4]);
        assert_eq!(reshaped.data(), tensor.data());
    }

    #[test]
    fn test_data_access() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        assert_eq!(tensor.get(&[0, 0]), 1.0);
        assert_eq!(tensor.get(&[0, 1]), 2.0);
        assert_eq!(tensor.get(&[1, 0]), 3.0);
        assert_eq!(tensor.get(&[1, 1]), 4.0);
    }
}