train_station/tensor/transform/
flatten.rs

1//! Tensor flattening operations
2//!
3//! This module provides tensor flattening functionality that transforms
4//! multi-dimensional tensors into 1D representations. Flattening is a
5//! fundamental tensor transformation operation used in machine learning
6//! for preparing data for linear layers, feature extraction, and
7//! dimensionality reduction.
8//!
9//! # Operations
10//!
11//! * `flatten()` - Flatten a tensor into a 1D representation
12//!
13//! # Performance Characteristics
14//!
15//! * **Zero-Copy Operation**: Returns a view when possible, avoiding data copying
16//! * **Memory Efficient**: Reuses existing tensor data through reshape operations
17//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Element Count Preservation**: Keeps the total number of elements unchanged while collapsing the shape to one dimension
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Flatten a 2D tensor
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
27//! let flattened = tensor.flatten();
28//! assert_eq!(flattened.shape().dims, vec![4]);
29//!
30//! // Flatten a 3D tensor
31//! let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
32//! let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
33//! let flattened = tensor.flatten();
34//! assert_eq!(flattened.shape().dims, vec![12]);
35//! ```
36
37use crate::tensor::Tensor;
38
39impl Tensor {
40    /// Flatten the tensor into a 1D representation
41    ///
42    /// Transforms a multi-dimensional tensor into a 1D tensor by reshaping
43    /// all dimensions into a single dimension. This is equivalent to
44    /// `reshape(vec![-1])` where `-1` automatically calculates the size
45    /// based on the total number of elements.
46    ///
47    /// The flatten operation preserves the total number of elements while
48    /// changing the tensor's shape to have a single dimension. This is
49    /// commonly used in neural networks to prepare tensor data for linear
50    /// layers or feature extraction.
51    ///
52    /// # Returns
53    ///
54    /// A 1D tensor containing the same data as the original tensor, with
55    /// shape `[total_elements]` where `total_elements` is the product of
56    /// all original dimensions.
57    ///
58    /// # Examples
59    ///
60    /// ```
61    /// use train_station::Tensor;
62    ///
63    /// // Flatten a 2D tensor
64    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
65    /// let flattened = tensor.flatten();
66    /// assert_eq!(flattened.shape().dims, vec![4]);
67    /// assert_eq!(flattened.get(&[0]), 1.0);
68    /// assert_eq!(flattened.get(&[1]), 2.0);
69    /// assert_eq!(flattened.get(&[2]), 3.0);
70    /// assert_eq!(flattened.get(&[3]), 4.0);
71    /// ```
72    ///
73    /// ```
74    /// use train_station::Tensor;
75    ///
76    /// // Flatten a 3D tensor
77    /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
78    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
79    /// let flattened = tensor.flatten();
80    /// assert_eq!(flattened.shape().dims, vec![12]);
81    /// assert_eq!(flattened.size(), 12);
82    /// ```
83    ///
84    /// ```
85    /// use train_station::Tensor;
86    ///
87    /// // Flatten with gradient tracking
88    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
89    /// tensor.set_requires_grad(true);
90    ///
91    /// let flattened = tensor.flatten();
92    /// assert!(flattened.requires_grad());
93    /// assert_eq!(flattened.shape().dims, vec![4]);
94    /// ```
95    ///
96    /// ```
97    /// use train_station::Tensor;
98    ///
99    /// // Flatten an already 1D tensor (no change)
100    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
101    /// let flattened = tensor.flatten();
102    /// assert_eq!(flattened.shape().dims, vec![3]);
103    /// assert_eq!(flattened.size(), 3);
104    /// ```
105    ///
106    /// # Performance
107    ///
108    /// - **Time Complexity**: O(1) - Returns a view when possible
109    /// - **Memory Usage**: No additional memory allocation for view operations
110    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
111    ///
112    /// # Relationship to Other Operations
113    ///
114    /// This operation is equivalent to:
115    /// ```rust
116    /// use train_station::Tensor;
117    ///
118    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
119    /// let flattened = tensor.reshape(vec![-1]);
120    /// ```
121    ///
122    /// Where `-1` is a special value that automatically calculates the
123    /// dimension size based on the total number of elements in the tensor.
124    pub fn flatten(&self) -> Tensor {
125        self.reshape(vec![-1])
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 tensor collapses to one dimension holding all four elements.
    #[test]
    fn test_flatten() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let flat = input.flatten();

        assert_eq!(flat.shape().dims, vec![4]);
        assert_eq!(flat.size(), 4);
    }

    /// A 2x3x4 tensor collapses to one dimension of length 24.
    #[test]
    fn test_flatten_3d() {
        let values: Vec<f32> = (0u8..24).map(f32::from).collect();
        let input = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();
        let flat = input.flatten();

        assert_eq!(flat.shape().dims, vec![24]);
        assert_eq!(flat.size(), 24);
    }

    /// Flattening a tensor that is already 1D leaves shape and size alone.
    #[test]
    fn test_flatten_already_1d() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let flat = input.flatten();

        assert_eq!(flat.shape().dims, vec![3]);
        assert_eq!(flat.size(), 3);
    }

    /// Every element is readable at its original linear position afterwards.
    #[test]
    fn test_flatten_preserves_data() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let input = Tensor::from_slice(&values, vec![2, 3]).unwrap();
        let flat = input.flatten();

        // Walk the source values in order and compare against the flat view.
        for (index, expected) in values.iter().copied().enumerate() {
            assert_eq!(flat.get(&[index]), expected);
        }
    }
}
172}