// train_station/tensor/transform/reshape.rs

//! Tensor reshape operations
//!
//! This module provides tensor reshape functionality that changes the
//! dimensions of a tensor while preserving the total number of elements.
//! Reshaping is a fundamental tensor transformation operation used in
//! machine learning for preparing data for different layer types,
//! implementing complex tensor manipulations, and adapting tensor shapes
//! for specific operations.
//!
//! # Operations
//!
//! * `reshape()` - Reshape tensor to specified dimensions with automatic inference
//!
//! # Performance Characteristics
//!
//! * **Zero-Copy Operation**: Returns a view when tensor is contiguous
//! * **Memory Efficient**: Reuses existing tensor data through view operations
//! * **Automatic Inference**: Supports -1 dimension for automatic size calculation
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Validation**: Comprehensive error checking for invalid reshape operations
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Basic reshape
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let reshaped = tensor.reshape(vec![3, 2]);
//! assert_eq!(reshaped.shape().dims(), vec![3, 2]);
//!
//! // Automatic dimension inference with -1
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//! let reshaped = tensor.reshape(vec![2, -1]);
//! assert_eq!(reshaped.shape().dims(), vec![2, 2]);
//! ```

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;
// Shape is referenced via core view helpers; direct import not needed here

42impl Tensor {
43    /// Reshape the tensor to the specified dimensions
44    ///
45    /// Changes the shape of the tensor while preserving the total number of elements.
46    /// This operation returns a view when the tensor is contiguous, avoiding data
47    /// copying. For non-contiguous tensors, data is copied to ensure the reshape
48    /// is valid.
49    ///
50    /// The reshape operation supports automatic dimension inference using -1,
51    /// which allows one dimension to be automatically calculated based on the
52    /// total number of elements and the other specified dimensions.
53    ///
54    /// # Arguments
55    ///
56    /// * `new_shape` - Target shape for the tensor. Use -1 for one dimension
57    ///   to have it automatically inferred from the total size.
58    ///
59    /// # Returns
60    ///
61    /// A new tensor with the specified shape containing the same data as the
62    /// original tensor.
63    ///
64    /// # Panics
65    ///
66    /// * If more than one dimension is -1
67    /// * If the total number of elements doesn't match the original tensor
68    /// * If any dimension size is 0 or less than -1
69    /// * If the inferred dimension size is not a whole number
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// use train_station::Tensor;
75    ///
76    /// // Basic reshape
77    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
78    /// let reshaped = tensor.reshape(vec![3, 2]);
79    /// assert_eq!(reshaped.shape().dims(), vec![3, 2]);
80    /// assert_eq!(reshaped.get(&[0, 0]), 1.0);
81    /// assert_eq!(reshaped.get(&[2, 1]), 6.0);
82    /// ```
83    ///
84    /// ```
85    /// use train_station::Tensor;
86    ///
87    /// // Using -1 for automatic dimension inference
88    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
89    /// let reshaped = tensor.reshape(vec![2, -1]);
90    /// assert_eq!(reshaped.shape().dims(), vec![2, 2]);
91    /// assert_eq!(reshaped.get(&[0, 0]), 1.0);
92    /// assert_eq!(reshaped.get(&[1, 1]), 4.0);
93    /// ```
94    ///
95    /// ```
96    /// use train_station::Tensor;
97    ///
98    /// // Reshape with gradient tracking
99    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
100    /// tensor.set_requires_grad(true);
101    ///
102    /// let reshaped = tensor.reshape(vec![4]);
103    /// assert!(reshaped.requires_grad());
104    /// assert_eq!(reshaped.shape().dims(), vec![4]);
105    /// ```
106    ///
107    /// ```
108    /// use train_station::Tensor;
109    ///
110    /// // Reshape 3D tensor
111    /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
112    /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
113    /// let reshaped = tensor.reshape(vec![6, 4]);
114    /// assert_eq!(reshaped.shape().dims(), vec![6, 4]);
115    /// assert_eq!(reshaped.size(), 24);
116    /// ```
117    ///
118    /// # Performance
119    ///
120    /// - **Contiguous tensors**: O(1) time complexity, returns a view
121    /// - **Non-contiguous tensors**: O(n) time complexity with data copying
122    /// - **Memory usage**: No additional allocation for view operations
123    /// - **Gradient tracking**: Preserves gradient requirements and tracking
124    ///
125    /// # Automatic Dimension Inference
126    ///
127    /// When using -1 for a dimension, the size is automatically calculated:
128    /// ```rust
129    /// use train_station::Tensor;
130    ///
131    /// // For a tensor with 12 elements
132    /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
133    /// let tensor = Tensor::from_slice(&data, vec![3, 4]).unwrap();
134    ///
135    /// let reshaped1 = tensor.reshape(vec![3, -1]);  // Results in shape [3, 4]
136    /// let reshaped2 = tensor.reshape(vec![-1, 6]);  // Results in shape [2, 6]
137    /// let reshaped3 = tensor.reshape(vec![-1]);     // Results in shape [12]
138    /// ```
139    #[track_caller]
140    pub fn reshape(&self, new_shape: Vec<i32>) -> Tensor {
141        // Validate and process the new shape
142        let processed_shape = self.process_reshape_dimensions(new_shape);
143
144        // Validate that total size matches
145        let new_size: usize = processed_shape.iter().product();
146        assert_eq!(
147            new_size,
148            self.size(),
149            "Cannot reshape tensor of size {} to shape {:?} (size {})",
150            self.size(),
151            processed_shape,
152            new_size
153        );
154        // Zero-copy reshape using core reshape_view validation
155        match crate::tensor::core::view::reshape_view(self, &processed_shape) {
156            Ok(mut v) => {
157                if self.requires_grad() {
158                    v.set_requires_grad(true);
159                    let grad_fn = GradFn::Reshape {
160                        original_shape: self.shape().dims().to_vec(),
161                    };
162                    v.set_grad_fn(grad_fn.clone());
163                    GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
164                }
165                v
166            }
167            Err(_) => {
168                // If core reshape fails (non-contiguous), materialize contiguous then view
169                let contiguous = self.contiguous();
170                let mut v =
171                    match crate::tensor::core::view::reshape_view(&contiguous, &processed_shape) {
172                        Ok(v) => v,
173                        Err(e) => panic!("reshape error: {:?}", e),
174                    };
175                if self.requires_grad() {
176                    v.set_requires_grad(true);
177                    let grad_fn = GradFn::Reshape {
178                        original_shape: self.shape().dims().to_vec(),
179                    };
180                    v.set_grad_fn(grad_fn.clone());
181                    GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
182                }
183                v
184            }
185        }
186    }
187
188    /// Process reshape dimensions and handle -1 inference
189    ///
190    /// Validates reshape dimensions and automatically infers the size of any
191    /// dimension marked as -1. This method ensures that the reshape operation
192    /// is valid and calculates the appropriate dimension sizes.
193    ///
194    /// # Arguments
195    ///
196    /// * `new_shape` - Target shape with possible -1 for inference
197    ///
198    /// # Returns
199    ///
200    /// Processed shape with all dimensions as positive usize values
201    ///
202    /// # Panics
203    ///
204    /// * If more than one dimension is -1
205    /// * If any dimension size is 0 or less than -1
206    /// * If the total size is not divisible by the known dimensions
207    ///
208    /// # Examples
209    ///
210    /// ```
211    /// use train_station::Tensor;
212    ///
213    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
214    /// // This internally calls process_reshape_dimensions
215    /// let reshaped = tensor.reshape(vec![2, -1]);
216    /// assert_eq!(reshaped.shape().dims(), vec![2, 2]);
217    /// ```
218    pub(crate) fn process_reshape_dimensions(&self, new_shape: Vec<i32>) -> Vec<usize> {
219        // Validate input dimensions
220        let mut infer_dim = None;
221        let mut known_size = 1usize;
222
223        for (i, &dim) in new_shape.iter().enumerate() {
224            if dim == -1 {
225                if infer_dim.is_some() {
226                    panic!("Only one dimension can be -1 for automatic inference");
227                }
228                infer_dim = Some(i);
229            } else if dim <= 0 {
230                panic!("Dimension sizes must be positive, got {}", dim);
231            } else {
232                known_size *= dim as usize;
233            }
234        }
235
236        // Convert to usize and infer -1 dimension
237        let mut processed: Vec<usize> = new_shape
238            .iter()
239            .map(|&d| if d == -1 { 0 } else { d as usize })
240            .collect();
241
242        if let Some(infer_idx) = infer_dim {
243            let total_size = self.size();
244            if known_size == 0 || !total_size.is_multiple_of(known_size) {
245                panic!(
246                    "Cannot infer dimension size: total size {} not divisible by known size {}",
247                    total_size, known_size
248                );
249            }
250            processed[infer_idx] = total_size / known_size;
251        }
252
253        processed
254    }
255
256    // removed: private reshape_view in favor of core view::reshape_view
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_reshape() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let out = source.reshape(vec![3, 2]);

        // Shape and element count must both reflect the new layout.
        assert_eq!(out.shape().dims(), vec![3, 2]);
        assert_eq!(out.size(), 6);

        // Data order is preserved across the reshape.
        assert_eq!(out.get(&[0, 0]), 1.0);
        assert_eq!(out.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_reshape_with_inference() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // The -1 dimension is inferred as 4 / 2 = 2.
        let out = source.reshape(vec![2, -1]);
        assert_eq!(out.shape().dims(), vec![2, 2]);
        assert_eq!(out.size(), 4);
    }

    #[test]
    fn test_reshape_autograd() {
        let mut source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        source.set_requires_grad(true);

        // The view inherits gradient tracking and a non-trivial grad_fn.
        let out = source.reshape(vec![4]);
        assert!(out.requires_grad());
        assert!(!matches!(out.grad_fn(), GradFn::None));
    }

    #[test]
    #[should_panic(expected = "Only one dimension can be -1")]
    fn test_multiple_infer_dimensions() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        source.reshape(vec![-1, -1]);
    }

    #[test]
    #[should_panic(expected = "Cannot reshape tensor of size 4")]
    fn test_invalid_size() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        source.reshape(vec![2, 3]); // 2*3 = 6 != 4
    }

    #[test]
    #[should_panic(expected = "Dimension sizes must be positive")]
    fn test_negative_dimension() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        source.reshape(vec![2, -2]);
    }

    #[test]
    fn test_large_tensor_reshape() {
        // Larger tensor exercises the same path at scale.
        let values: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let source = Tensor::from_slice(&values, vec![10, 100]).unwrap();

        let out = source.reshape(vec![25, 40]);
        assert_eq!(out.shape().dims(), vec![25, 40]);
        assert_eq!(out.size(), 1000);

        // First and last elements survive the reshape unchanged.
        assert_eq!(out.get(&[0, 0]), 0.0);
        assert_eq!(out.get(&[24, 39]), 999.0);
    }

    #[test]
    fn test_reshape_edge_cases() {
        // Size-1 tensor reshaped via inference.
        let single = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let inferred = single.reshape(vec![-1]);
        assert_eq!(inferred.shape().dims(), vec![1]);

        // Size-1 tensor reshaped explicitly to the same shape.
        let same = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let explicit = same.reshape(vec![1]);
        assert_eq!(explicit.shape().dims(), vec![1]);
    }

    #[test]
    fn test_multi_operation_with_reshape() {
        // Chained reshapes compose cleanly.
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let flat = source.reshape(vec![4]);
        let row = flat.reshape(vec![1, 4]);

        assert_eq!(row.shape().dims(), vec![1, 4]);
        assert_eq!(row.get(&[0, 3]), 4.0);
    }

    #[test]
    fn test_reshape_with_autograd_chain() {
        // Reshape participating in a small computation chain with autograd.
        let mut lhs = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::from_slice(&[0.5, 0.5, 0.5, 0.5], vec![4]).unwrap();

        lhs.set_requires_grad(true);

        // Reshape lhs to match rhs's shape, then add.
        let lhs_flat = lhs.reshape(vec![4]);
        assert!(lhs_flat.requires_grad());

        let sum = lhs_flat.add_tensor_optimized(&rhs);

        assert_eq!(sum.shape().dims(), vec![4]);
        // Note: add_tensor_optimized may not preserve gradients for mixed operations
        // In a full implementation, we'd use the AutogradTensor trait methods

        // Verify values
        assert_eq!(sum.get(&[0]), 1.5); // 1.0 + 0.5
        assert_eq!(sum.get(&[3]), 4.5); // 4.0 + 0.5
    }
}
377}