// train_station/tensor/transform/flatten.rs
1//! Tensor flattening operations
2//!
3//! This module provides tensor flattening functionality that transforms
4//! multi-dimensional tensors into 1D representations. Flattening is a
5//! fundamental tensor transformation operation used in machine learning
6//! for preparing data for linear layers, feature extraction, and
7//! dimensionality reduction.
8//!
9//! # Operations
10//!
11//! * `flatten()` - Flatten a tensor into a 1D representation
12//!
13//! # Performance Characteristics
14//!
15//! * **Zero-Copy Operation**: Returns a view when possible, avoiding data copying
16//! * **Memory Efficient**: Reuses existing tensor data through reshape operations
17//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
18//! * **Shape Preservation**: Maintains the total number of elements while changing dimensions
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Flatten a 2D tensor
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
27//! let flattened = tensor.flatten();
28//! assert_eq!(flattened.shape().dims, vec![4]);
29//!
30//! // Flatten a 3D tensor
31//! let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
32//! let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
33//! let flattened = tensor.flatten();
34//! assert_eq!(flattened.shape().dims, vec![12]);
35//! ```
36
37use crate::tensor::Tensor;
38
39impl Tensor {
40 /// Flatten the tensor into a 1D representation
41 ///
42 /// Transforms a multi-dimensional tensor into a 1D tensor by reshaping
43 /// all dimensions into a single dimension. This is equivalent to
44 /// `reshape(vec![-1])` where `-1` automatically calculates the size
45 /// based on the total number of elements.
46 ///
47 /// The flatten operation preserves the total number of elements while
48 /// changing the tensor's shape to have a single dimension. This is
49 /// commonly used in neural networks to prepare tensor data for linear
50 /// layers or feature extraction.
51 ///
52 /// # Returns
53 ///
54 /// A 1D tensor containing the same data as the original tensor, with
55 /// shape `[total_elements]` where `total_elements` is the product of
56 /// all original dimensions.
57 ///
58 /// # Examples
59 ///
60 /// ```
61 /// use train_station::Tensor;
62 ///
63 /// // Flatten a 2D tensor
64 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
65 /// let flattened = tensor.flatten();
66 /// assert_eq!(flattened.shape().dims, vec![4]);
67 /// assert_eq!(flattened.get(&[0]), 1.0);
68 /// assert_eq!(flattened.get(&[1]), 2.0);
69 /// assert_eq!(flattened.get(&[2]), 3.0);
70 /// assert_eq!(flattened.get(&[3]), 4.0);
71 /// ```
72 ///
73 /// ```
74 /// use train_station::Tensor;
75 ///
76 /// // Flatten a 3D tensor
77 /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
78 /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
79 /// let flattened = tensor.flatten();
80 /// assert_eq!(flattened.shape().dims, vec![12]);
81 /// assert_eq!(flattened.size(), 12);
82 /// ```
83 ///
84 /// ```
85 /// use train_station::Tensor;
86 ///
87 /// // Flatten with gradient tracking
88 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
89 /// tensor.set_requires_grad(true);
90 ///
91 /// let flattened = tensor.flatten();
92 /// assert!(flattened.requires_grad());
93 /// assert_eq!(flattened.shape().dims, vec![4]);
94 /// ```
95 ///
96 /// ```
97 /// use train_station::Tensor;
98 ///
99 /// // Flatten an already 1D tensor (no change)
100 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
101 /// let flattened = tensor.flatten();
102 /// assert_eq!(flattened.shape().dims, vec![3]);
103 /// assert_eq!(flattened.size(), 3);
104 /// ```
105 ///
106 /// # Performance
107 ///
108 /// - **Time Complexity**: O(1) - Returns a view when possible
109 /// - **Memory Usage**: No additional memory allocation for view operations
110 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
111 ///
112 /// # Relationship to Other Operations
113 ///
114 /// This operation is equivalent to:
115 /// ```rust
116 /// use train_station::Tensor;
117 ///
118 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
119 /// let flattened = tensor.reshape(vec![-1]);
120 /// ```
121 ///
122 /// Where `-1` is a special value that automatically calculates the
123 /// dimension size based on the total number of elements in the tensor.
124 pub fn flatten(&self) -> Tensor {
125 self.reshape(vec![-1])
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Flattening a 2D tensor yields a 1D tensor with the same element count.
    #[test]
    fn test_flatten() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let flattened = tensor.flatten();

        assert_eq!(flattened.shape().dims, vec![4]);
        assert_eq!(flattened.size(), 4);
    }

    /// Flattening a 3D tensor collapses all three dimensions into one.
    #[test]
    fn test_flatten_3d() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        let flattened = tensor.flatten();

        assert_eq!(flattened.shape().dims, vec![24]);
        assert_eq!(flattened.size(), 24);
    }

    /// A tensor that is already 1D is unchanged by flatten.
    #[test]
    fn test_flatten_already_1d() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let flattened = tensor.flatten();

        assert_eq!(flattened.shape().dims, vec![3]);
        assert_eq!(flattened.size(), 3);
    }

    /// Every element of the source data appears at the matching flat index.
    #[test]
    fn test_flatten_preserves_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
        let flattened = tensor.flatten();

        // `.take(data.len())` on a fully-enumerated iterator was a no-op
        // and has been removed; enumerate alone covers every element.
        for (i, &expected) in data.iter().enumerate() {
            assert_eq!(flattened.get(&[i]), expected);
        }
    }
}