train_station/tensor/transform/flatten.rs
1//! Tensor flattening operations
2//!
3//! This module provides tensor flattening functionality that transforms
4//! multi-dimensional tensors into 1D representations. Flattening is a
5//! fundamental tensor transformation operation used in machine learning
6//! for preparing data for linear layers, feature extraction, and
7//! dimensionality reduction.
8//!
9//! # Operations
10//!
11//! * `flatten()` - Flatten a tensor into a 1D representation
12//!
13//! # Performance Characteristics
14//!
15//! * **Zero-Copy Operation**: Returns a view when possible, avoiding data copying
16//! * **Memory Efficient**: Reuses existing tensor data through reshape operations
17//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
18//! * **Shape Preservation**: Maintains the total number of elements while changing dimensions
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Flatten a 2D tensor
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
27//! let flattened = tensor.flatten();
28//! assert_eq!(flattened.shape().dims, vec![4]);
29//!
30//! // Flatten a 3D tensor
31//! let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
32//! let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
33//! let flattened = tensor.flatten();
34//! assert_eq!(flattened.shape().dims, vec![12]);
35//! ```
36
37use crate::tensor::Tensor;
38
39impl Tensor {
40 /// Flatten the tensor into a 1D representation
41 ///
42 /// Transforms a multi-dimensional tensor into a 1D tensor by reshaping
43 /// all dimensions into a single dimension. This is equivalent to
44 /// `reshape(vec![-1])` where `-1` automatically calculates the size
45 /// based on the total number of elements.
46 ///
47 /// The flatten operation preserves the total number of elements while
48 /// changing the tensor's shape to have a single dimension. This is
49 /// commonly used in neural networks to prepare tensor data for linear
50 /// layers or feature extraction.
51 ///
52 /// # Returns
53 ///
54 /// A 1D tensor containing the same data as the original tensor, with
55 /// shape `[total_elements]` where `total_elements` is the product of
56 /// all original dimensions.
57 ///
58 /// # Examples
59 ///
60 /// ```
61 /// use train_station::Tensor;
62 ///
63 /// // Flatten a 2D tensor
64 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
65 /// let flattened = tensor.flatten();
66 /// assert_eq!(flattened.shape().dims, vec![4]);
67 /// assert_eq!(flattened.get(&[0]), 1.0);
68 /// assert_eq!(flattened.get(&[1]), 2.0);
69 /// assert_eq!(flattened.get(&[2]), 3.0);
70 /// assert_eq!(flattened.get(&[3]), 4.0);
71 /// ```
72 ///
73 /// ```
74 /// use train_station::Tensor;
75 ///
76 /// // Flatten a 3D tensor
77 /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
78 /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
79 /// let flattened = tensor.flatten();
80 /// assert_eq!(flattened.shape().dims, vec![12]);
81 /// assert_eq!(flattened.size(), 12);
82 /// ```
83 ///
84 /// ```
85 /// use train_station::Tensor;
86 ///
87 /// // Flatten with gradient tracking
88 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
89 /// tensor.set_requires_grad(true);
90 ///
91 /// let flattened = tensor.flatten();
92 /// assert!(flattened.requires_grad());
93 /// assert_eq!(flattened.shape().dims, vec![4]);
94 /// ```
95 ///
96 /// ```
97 /// use train_station::Tensor;
98 ///
99 /// // Flatten an already 1D tensor (no change)
100 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
101 /// let flattened = tensor.flatten();
102 /// assert_eq!(flattened.shape().dims, vec![3]);
103 /// assert_eq!(flattened.size(), 3);
104 /// ```
105 ///
106 /// # Performance
107 ///
108 /// - **Time Complexity**: O(1) - Returns a view when possible
109 /// - **Memory Usage**: No additional memory allocation for view operations
110 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
111 ///
112 /// # Relationship to Other Operations
113 ///
114 /// This operation is equivalent to:
115 /// ```rust
116 /// use train_station::Tensor;
117 ///
118 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
119 /// let flattened = tensor.reshape(vec![-1]);
120 /// ```
121 ///
122 /// Where `-1` is a special value that automatically calculates the
123 /// dimension size based on the total number of elements in the tensor.
124 #[track_caller]
125 pub fn flatten(&self) -> Tensor {
126 self.reshape(vec![-1])
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Flattening a 2x2 tensor yields shape [4].
    #[test]
    fn test_flatten() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let flattened = tensor.flatten();

        assert_eq!(flattened.shape().dims, vec![4]);
        assert_eq!(flattened.size(), 4);
    }

    /// Flattening a 3D tensor collapses all three dimensions.
    #[test]
    fn test_flatten_3d() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        let flattened = tensor.flatten();

        assert_eq!(flattened.shape().dims, vec![24]);
        assert_eq!(flattened.size(), 24);
    }

    /// Flattening an already-1D tensor leaves the shape unchanged.
    #[test]
    fn test_flatten_already_1d() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let flattened = tensor.flatten();

        assert_eq!(flattened.shape().dims, vec![3]);
        assert_eq!(flattened.size(), 3);
    }

    /// Element values survive the flatten in row-major order.
    #[test]
    fn test_flatten_preserves_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
        let flattened = tensor.flatten();

        // Fixed: the previous `.take(data.len())` after `enumerate()` was a
        // no-op — the iterator is already bounded to `data.len()` elements.
        for (i, &expected) in data.iter().enumerate() {
            assert_eq!(flattened.get(&[i]), expected);
        }
    }
}
173}