// train_station/tensor/transform/reshape.rs
1//! Tensor reshape operations
2//!
3//! This module provides tensor reshape functionality that changes the
4//! dimensions of a tensor while preserving the total number of elements.
5//! Reshaping is a fundamental tensor transformation operation used in
6//! machine learning for preparing data for different layer types,
7//! implementing complex tensor manipulations, and adapting tensor shapes
8//! for specific operations.
9//!
10//! # Operations
11//!
12//! * `reshape()` - Reshape tensor to specified dimensions with automatic inference
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operation**: Returns a view when tensor is contiguous
17//! * **Memory Efficient**: Reuses existing tensor data through view operations
18//! * **Automatic Inference**: Supports -1 dimension for automatic size calculation
19//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
20//! * **Validation**: Comprehensive error checking for invalid reshape operations
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Basic reshape
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
29//! let reshaped = tensor.reshape(vec![3, 2]);
30//! assert_eq!(reshaped.shape().dims(), vec![3, 2]);
31//!
32//! // Automatic dimension inference with -1
33//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
34//! let reshaped = tensor.reshape(vec![2, -1]);
35//! assert_eq!(reshaped.shape().dims(), vec![2, 2]);
36//! ```
37
38use crate::gradtrack::{GradEngine, GradFn};
39use crate::tensor::core::Tensor;
40// Shape is referenced via core view helpers; direct import not needed here
41
42impl Tensor {
43 /// Reshape the tensor to the specified dimensions
44 ///
45 /// Changes the shape of the tensor while preserving the total number of elements.
46 /// This operation returns a view when the tensor is contiguous, avoiding data
47 /// copying. For non-contiguous tensors, data is copied to ensure the reshape
48 /// is valid.
49 ///
50 /// The reshape operation supports automatic dimension inference using -1,
51 /// which allows one dimension to be automatically calculated based on the
52 /// total number of elements and the other specified dimensions.
53 ///
54 /// # Arguments
55 ///
56 /// * `new_shape` - Target shape for the tensor. Use -1 for one dimension
57 /// to have it automatically inferred from the total size.
58 ///
59 /// # Returns
60 ///
61 /// A new tensor with the specified shape containing the same data as the
62 /// original tensor.
63 ///
64 /// # Panics
65 ///
66 /// * If more than one dimension is -1
67 /// * If the total number of elements doesn't match the original tensor
68 /// * If any dimension size is 0 or less than -1
69 /// * If the inferred dimension size is not a whole number
70 ///
71 /// # Examples
72 ///
73 /// ```
74 /// use train_station::Tensor;
75 ///
76 /// // Basic reshape
77 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
78 /// let reshaped = tensor.reshape(vec![3, 2]);
79 /// assert_eq!(reshaped.shape().dims(), vec![3, 2]);
80 /// assert_eq!(reshaped.get(&[0, 0]), 1.0);
81 /// assert_eq!(reshaped.get(&[2, 1]), 6.0);
82 /// ```
83 ///
84 /// ```
85 /// use train_station::Tensor;
86 ///
87 /// // Using -1 for automatic dimension inference
88 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
89 /// let reshaped = tensor.reshape(vec![2, -1]);
90 /// assert_eq!(reshaped.shape().dims(), vec![2, 2]);
91 /// assert_eq!(reshaped.get(&[0, 0]), 1.0);
92 /// assert_eq!(reshaped.get(&[1, 1]), 4.0);
93 /// ```
94 ///
95 /// ```
96 /// use train_station::Tensor;
97 ///
98 /// // Reshape with gradient tracking
99 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
100 /// tensor.set_requires_grad(true);
101 ///
102 /// let reshaped = tensor.reshape(vec![4]);
103 /// assert!(reshaped.requires_grad());
104 /// assert_eq!(reshaped.shape().dims(), vec![4]);
105 /// ```
106 ///
107 /// ```
108 /// use train_station::Tensor;
109 ///
110 /// // Reshape 3D tensor
111 /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
112 /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
113 /// let reshaped = tensor.reshape(vec![6, 4]);
114 /// assert_eq!(reshaped.shape().dims(), vec![6, 4]);
115 /// assert_eq!(reshaped.size(), 24);
116 /// ```
117 ///
118 /// # Performance
119 ///
120 /// - **Contiguous tensors**: O(1) time complexity, returns a view
121 /// - **Non-contiguous tensors**: O(n) time complexity with data copying
122 /// - **Memory usage**: No additional allocation for view operations
123 /// - **Gradient tracking**: Preserves gradient requirements and tracking
124 ///
125 /// # Automatic Dimension Inference
126 ///
127 /// When using -1 for a dimension, the size is automatically calculated:
128 /// ```rust
129 /// use train_station::Tensor;
130 ///
131 /// // For a tensor with 12 elements
132 /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
133 /// let tensor = Tensor::from_slice(&data, vec![3, 4]).unwrap();
134 ///
135 /// let reshaped1 = tensor.reshape(vec![3, -1]); // Results in shape [3, 4]
136 /// let reshaped2 = tensor.reshape(vec![-1, 6]); // Results in shape [2, 6]
137 /// let reshaped3 = tensor.reshape(vec![-1]); // Results in shape [12]
138 /// ```
139 #[track_caller]
140 pub fn reshape(&self, new_shape: Vec<i32>) -> Tensor {
141 // Validate and process the new shape
142 let processed_shape = self.process_reshape_dimensions(new_shape);
143
144 // Validate that total size matches
145 let new_size: usize = processed_shape.iter().product();
146 assert_eq!(
147 new_size,
148 self.size(),
149 "Cannot reshape tensor of size {} to shape {:?} (size {})",
150 self.size(),
151 processed_shape,
152 new_size
153 );
154 // Zero-copy reshape using core reshape_view validation
155 match crate::tensor::core::view::reshape_view(self, &processed_shape) {
156 Ok(mut v) => {
157 if self.requires_grad() {
158 v.set_requires_grad(true);
159 let grad_fn = GradFn::Reshape {
160 original_shape: self.shape().dims().to_vec(),
161 };
162 v.set_grad_fn(grad_fn.clone());
163 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
164 }
165 v
166 }
167 Err(_) => {
168 // If core reshape fails (non-contiguous), materialize contiguous then view
169 let contiguous = self.contiguous();
170 let mut v =
171 match crate::tensor::core::view::reshape_view(&contiguous, &processed_shape) {
172 Ok(v) => v,
173 Err(e) => panic!("reshape error: {:?}", e),
174 };
175 if self.requires_grad() {
176 v.set_requires_grad(true);
177 let grad_fn = GradFn::Reshape {
178 original_shape: self.shape().dims().to_vec(),
179 };
180 v.set_grad_fn(grad_fn.clone());
181 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
182 }
183 v
184 }
185 }
186 }
187
188 /// Process reshape dimensions and handle -1 inference
189 ///
190 /// Validates reshape dimensions and automatically infers the size of any
191 /// dimension marked as -1. This method ensures that the reshape operation
192 /// is valid and calculates the appropriate dimension sizes.
193 ///
194 /// # Arguments
195 ///
196 /// * `new_shape` - Target shape with possible -1 for inference
197 ///
198 /// # Returns
199 ///
200 /// Processed shape with all dimensions as positive usize values
201 ///
202 /// # Panics
203 ///
204 /// * If more than one dimension is -1
205 /// * If any dimension size is 0 or less than -1
206 /// * If the total size is not divisible by the known dimensions
207 ///
208 /// # Examples
209 ///
210 /// ```
211 /// use train_station::Tensor;
212 ///
213 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
214 /// // This internally calls process_reshape_dimensions
215 /// let reshaped = tensor.reshape(vec![2, -1]);
216 /// assert_eq!(reshaped.shape().dims(), vec![2, 2]);
217 /// ```
218 pub(crate) fn process_reshape_dimensions(&self, new_shape: Vec<i32>) -> Vec<usize> {
219 // Validate input dimensions
220 let mut infer_dim = None;
221 let mut known_size = 1usize;
222
223 for (i, &dim) in new_shape.iter().enumerate() {
224 if dim == -1 {
225 if infer_dim.is_some() {
226 panic!("Only one dimension can be -1 for automatic inference");
227 }
228 infer_dim = Some(i);
229 } else if dim <= 0 {
230 panic!("Dimension sizes must be positive, got {}", dim);
231 } else {
232 known_size *= dim as usize;
233 }
234 }
235
236 // Convert to usize and infer -1 dimension
237 let mut processed: Vec<usize> = new_shape
238 .iter()
239 .map(|&d| if d == -1 { 0 } else { d as usize })
240 .collect();
241
242 if let Some(infer_idx) = infer_dim {
243 let total_size = self.size();
244 if known_size == 0 || !total_size.is_multiple_of(known_size) {
245 panic!(
246 "Cannot infer dimension size: total size {} not divisible by known size {}",
247 total_size, known_size
248 );
249 }
250 processed[infer_idx] = total_size / known_size;
251 }
252
253 processed
254 }
255
256 // removed: private reshape_view in favor of core view::reshape_view
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_reshape() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let r = t.reshape(vec![3, 2]);

        // New shape, same element count, same row-major data order.
        assert_eq!(r.shape().dims(), vec![3, 2]);
        assert_eq!(r.size(), 6);
        assert_eq!(r.get(&[0, 0]), 1.0);
        assert_eq!(r.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_reshape_with_inference() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let r = t.reshape(vec![2, -1]);

        // The -1 dimension resolves to 4 / 2 = 2.
        assert_eq!(r.shape().dims(), vec![2, 2]);
        assert_eq!(r.size(), 4);
    }

    #[test]
    fn test_reshape_autograd() {
        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        t.set_requires_grad(true);

        let r = t.reshape(vec![4]);

        // Gradient tracking carries over and a reshape grad_fn is recorded.
        assert!(r.requires_grad());
        assert!(!matches!(r.grad_fn(), GradFn::None));
    }

    #[test]
    #[should_panic(expected = "Only one dimension can be -1")]
    fn test_multiple_infer_dimensions() {
        Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .reshape(vec![-1, -1]);
    }

    #[test]
    #[should_panic(expected = "Cannot reshape tensor of size 4")]
    fn test_invalid_size() {
        // 2 * 3 = 6 elements requested, but the tensor only holds 4.
        Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .reshape(vec![2, 3]);
    }

    #[test]
    #[should_panic(expected = "Dimension sizes must be positive")]
    fn test_negative_dimension() {
        Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .reshape(vec![2, -2]);
    }

    #[test]
    fn test_large_tensor_reshape() {
        // 1000 elements: 10 x 100 reshaped to 25 x 40.
        let values: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&values, vec![10, 100]).unwrap();

        let r = t.reshape(vec![25, 40]);
        assert_eq!(r.shape().dims(), vec![25, 40]);
        assert_eq!(r.size(), 1000);

        // First and last elements remain in place.
        assert_eq!(r.get(&[0, 0]), 0.0);
        assert_eq!(r.get(&[24, 39]), 999.0);
    }

    #[test]
    fn test_reshape_edge_cases() {
        // Size-1 tensor reshaped via -1 inference.
        let single = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        assert_eq!(single.reshape(vec![-1]).shape().dims(), vec![1]);

        // Size-1 tensor reshaped with an explicit shape.
        let single = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        assert_eq!(single.reshape(vec![1]).shape().dims(), vec![1]);
    }

    #[test]
    fn test_multi_operation_with_reshape() {
        // Chained reshapes compose cleanly.
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let flat = t.reshape(vec![4]);
        let row = flat.reshape(vec![1, 4]);

        assert_eq!(row.shape().dims(), vec![1, 4]);
        assert_eq!(row.get(&[0, 3]), 4.0);
    }

    #[test]
    fn test_reshape_with_autograd_chain() {
        // Reshape feeding into an elementwise add.
        let mut a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[0.5, 0.5, 0.5, 0.5], vec![4]).unwrap();

        a.set_requires_grad(true);

        let flat_a = a.reshape(vec![4]);
        assert!(flat_a.requires_grad());

        let sum = flat_a.add_tensor_optimized(&b);
        assert_eq!(sum.shape().dims(), vec![4]);
        // Note: add_tensor_optimized may not preserve gradients for mixed
        // operations; a full pipeline would use the AutogradTensor trait
        // methods instead.

        assert_eq!(sum.get(&[0]), 1.5); // 1.0 + 0.5
        assert_eq!(sum.get(&[3]), 4.5); // 4.0 + 0.5
    }
}