// train_station/tensor/transform/contiguous.rs

1//! Contiguous tensor transformation operation
2//!
3//! This module provides functionality to create contiguous copies of tensors,
4//! ensuring that tensor data is stored in a linear, cache-friendly memory layout.
5//! Contiguous tensors are essential for optimal performance in many operations,
6//! particularly SIMD-optimized computations and operations that require
7//! sequential memory access patterns.
8//!
9//! # Memory Layout
10//!
11//! A tensor is considered contiguous when its elements are stored in memory
12//! in row-major order without gaps. Non-contiguous tensors can arise from
13//! operations like transpose, permute, or slice views that change the
14//! memory layout without copying data.
15//!
//! # Performance Characteristics
//!
//! - **Already Contiguous**: O(1) time, returns a clone
//! - **Contiguous Last Dimension (stride 1)**: row-by-row SIMD-optimized block copies
//! - **General Strided Views**: element-wise coordinate-based fallback copy
22//!
23//! # Examples
24//!
25//! ```
26//! use train_station::Tensor;
27//!
28//! // Create a contiguous tensor
29//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30//! assert!(tensor.is_contiguous());
31//!
32//! // Create a non-contiguous tensor through transpose
33//! let transposed = tensor.transpose(0, 1);
34//! assert!(!transposed.is_contiguous());
35//!
36//! // Make it contiguous again
37//! let contiguous = transposed.contiguous();
38//! assert!(contiguous.is_contiguous());
39//! assert_eq!(contiguous.shape().dims(), vec![2, 2]);
40//! ```
41//!
42//! ```
43//! use train_station::Tensor;
44//!
45//! // Contiguous preserves gradient tracking
46//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
47//! tensor.set_requires_grad(true);
48//!
49//! let transposed = tensor.transpose(0, 1);
50//! let contiguous = transposed.contiguous();
51//! assert!(contiguous.requires_grad());
52//! ```
53//!
54//! # Gradient Tracking
55//!
56//! The contiguous operation supports automatic gradient tracking through
57//! the GradTrack system. When `requires_grad` is enabled, the operation
58//! registers a gradient function that ensures proper gradient flow during
59//! backward passes.
60
61use crate::gradtrack::{GradEngine, GradFn};
62use crate::tensor::iterator::collect::optimized_copy;
63use crate::tensor::Tensor;
64
65impl Tensor {
66    /// Creates a contiguous copy of the tensor
67    ///
68    /// This operation ensures that the tensor data is stored in a linear,
69    /// cache-friendly memory layout. If the tensor is already contiguous,
70    /// this operation returns a clone. For non-contiguous tensors, it
71    /// creates a new tensor with the same data but in contiguous memory layout.
72    ///
73    /// The operation uses different optimization strategies based on tensor size:
74    /// - Small tensors (≤64 elements): Simple coordinate-based copy
75    /// - Medium tensors (65-1023 elements): Unrolled copy for better performance
76    /// - Large tensors (≥1024 elements): Blocked copy with cache optimization
77    ///
78    /// # Returns
79    ///
80    /// A new tensor with contiguous memory layout containing the same data
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use train_station::Tensor;
86    ///
87    /// // Already contiguous tensor
88    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
89    /// let contiguous = tensor.contiguous();
90    /// assert!(contiguous.is_contiguous());
91    /// assert_eq!(contiguous.shape().dims(), vec![2, 2]);
92    /// ```
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // Non-contiguous tensor from transpose
98    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
99    /// let transposed = tensor.transpose(0, 1);
100    /// assert!(!transposed.is_contiguous());
101    ///
102    /// let contiguous = transposed.contiguous();
103    /// assert!(contiguous.is_contiguous());
104    /// assert_eq!(contiguous.get(&[0, 0]), 1.0);
105    /// assert_eq!(contiguous.get(&[0, 1]), 3.0);
106    /// ```
107    ///
108    /// ```
109    /// use train_station::Tensor;
110    ///
111    /// // Preserves gradient tracking
112    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
113    /// tensor.set_requires_grad(true);
114    ///
115    /// let contiguous = tensor.contiguous();
116    /// assert!(contiguous.requires_grad());
117    /// ```
118    ///
119    /// # Performance
120    ///
121    /// - **Already contiguous**: O(1) time complexity, returns a clone
122    /// - **Non-contiguous**: O(n) time complexity with size-dependent optimizations
123    /// - **Memory usage**: Creates a new tensor with the same size as the original
124    #[track_caller]
125    pub fn contiguous(&self) -> Tensor {
126        if self.is_contiguous() {
127            let mut cloned = self.clone();
128            // Ensure gradient requirements are preserved
129            if self.requires_grad() {
130                cloned.set_requires_grad(true);
131                // Register gradient function even for already-contiguous tensors
132                let grad_fn = GradFn::Contiguous {
133                    input_shape: self.shape().dims().to_vec(),
134                };
135                cloned.set_grad_fn(grad_fn.clone());
136                GradEngine::register_operation(cloned.id(), vec![self.id()], grad_fn);
137            }
138            return cloned;
139        }
140
141        // Create new contiguous tensor and copy via optimized methods
142        let mut result = Tensor::new(self.shape().dims().to_vec());
143
144        unsafe {
145            self.copy_to_contiguous_optimized(&mut result);
146        }
147
148        // Preserve gradient requirements and register gradient function
149        if self.requires_grad() {
150            result.set_requires_grad(true);
151            let grad_fn = GradFn::Contiguous {
152                input_shape: self.shape().dims().to_vec(),
153            };
154            result.set_grad_fn(grad_fn.clone());
155            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
156        }
157
158        result
159    }
160
161    /// Internal optimized contiguous copy operation
162    ///
163    /// This function dispatches to the appropriate copy strategy based on
164    /// tensor size and rank for optimal performance.
165    ///
166    /// # Arguments
167    ///
168    /// * `result` - The destination tensor to copy data into
169    ///
170    /// # Safety
171    ///
172    /// The caller must ensure:
173    /// * `result` has the same shape as `self`
174    /// * `result` is properly allocated and initialized
175    /// * Both tensors are valid and not moved during the operation
176    #[inline]
177    unsafe fn copy_to_contiguous_optimized(&self, result: &mut Tensor) {
178        let size = self.size();
179        let rank = self.shape().rank();
180
181        if size == 0 {
182            return;
183        }
184
185        // Fast path: if the last dimension is contiguous in the source view,
186        // copy row-by-row using SIMD-optimized contiguous copies.
187        if rank >= 1 && self.stride(rank - 1) == 1 {
188            let dims = self.shape().dims();
189            let row_len = dims[rank - 1];
190
191            // Number of outer rows to copy (all dims except the last)
192            let outer: usize = if rank == 1 {
193                1
194            } else {
195                dims[..rank - 1].iter().product()
196            };
197
198            let src_base = self.as_ptr();
199            let dst_base = result.as_mut_ptr();
200            let strides = self.strides();
201
202            // Coordinate vector for outer dimensions (exclude last)
203            let mut coords = vec![0usize; rank];
204
205            for outer_idx in 0..outer {
206                // Compute multi-index over dims[0..rank-1) in row-major order
207                if rank > 1 {
208                    let mut tmp = outer_idx;
209                    for i in (0..rank - 1).rev() {
210                        let d = dims[i];
211                        coords[i] = if d == 0 { 0 } else { tmp % d };
212                        if d != 0 {
213                            tmp /= d;
214                        }
215                    }
216                }
217                coords[rank - 1] = 0; // start of the contiguous row
218
219                // Compute source offset via strides
220                let mut src_off = 0usize;
221                for i in 0..rank {
222                    src_off += coords[i] * strides[i];
223                }
224
225                // Destination offset: linear index over outer dims times row_len
226                let mut dst_row_index = 0usize;
227                if rank > 1 {
228                    for i in 0..rank - 1 {
229                        dst_row_index = dst_row_index * dims[i] + coords[i];
230                    }
231                }
232                let dst_off = dst_row_index * row_len;
233
234                // Copy the entire row as a contiguous block
235                optimized_copy(src_base.add(src_off), dst_base.add(dst_off), row_len);
236            }
237            return;
238        }
239
240        // Fallback: general coordinate-based copy (works for any strided view)
241        self.copy_to_contiguous_simple(result, rank);
242    }
243
244    /// Simple copy for small tensors or 1D tensors
245    ///
246    /// This function performs a straightforward coordinate-based copy
247    /// suitable for small tensors where the overhead of more complex
248    /// optimizations would not be beneficial.
249    ///
250    /// # Arguments
251    ///
252    /// * `result` - The destination tensor
253    /// * `rank` - The rank of the tensor
254    ///
255    /// # Safety
256    ///
257    /// The caller must ensure both tensors are valid and properly allocated.
258    #[inline]
259    unsafe fn copy_to_contiguous_simple(&self, result: &mut Tensor, rank: usize) {
260        let size = self.size();
261        let src_ptr = self.as_ptr();
262        let dst_ptr = result.as_mut_ptr();
263
264        for dst_idx in 0..size {
265            // Compute destination coordinates under contiguous strides
266            let mut coords = vec![0usize; rank];
267            let mut tmp = dst_idx;
268            for i in (0..rank).rev() {
269                let dim_size = self.shape().dims()[i];
270                coords[i] = tmp % dim_size;
271                tmp /= dim_size;
272            }
273            let src_off = self.shape().offset(&coords);
274            *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
275        }
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_contiguous_copy() {
        // A contiguous copy of a 2x3 tensor keeps its shape and data intact.
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let copy = source.contiguous();
        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims(), source.shape().dims());

        // Corner elements confirm the data survived the copy.
        assert_eq!(copy.get(&[0, 0]), 1.0);
        assert_eq!(copy.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_contiguous_already_contiguous() {
        // Calling contiguous() on an already-contiguous tensor yields a clone.
        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        let clone = original.contiguous();
        assert!(clone.is_contiguous());
        assert_eq!(clone.shape().dims(), original.shape().dims());
        assert_eq!(clone.size(), original.size());
    }

    #[test]
    fn test_contiguous_preserves_gradient_tracking() {
        // requires_grad must carry over to the contiguous result.
        let mut leaf = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        leaf.set_requires_grad(true);

        let copy = leaf.contiguous();
        assert!(copy.requires_grad());
    }

    #[test]
    fn test_contiguous_gradient_flow() {
        // Gradients must flow back through contiguous() to the original leaf.
        let mut leaf = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        leaf.set_requires_grad(true);

        // Transpose produces a non-contiguous view of the leaf.
        let view = leaf.transpose(0, 1);
        assert!(!view.is_contiguous());

        // Materialize it; gradient tracking must be preserved.
        let materialized = view.contiguous();
        assert!(materialized.is_contiguous());
        assert!(materialized.requires_grad());

        // Reduce to a scalar and run the backward pass.
        let mut total = materialized.sum();
        total.backward(None);

        // The leaf should have received a gradient of all ones (d(sum)/dx = 1).
        let grad = leaf.grad_owned().expect("Gradient should exist");
        assert_eq!(grad.shape().dims(), vec![2, 2]);
        for row in 0..2 {
            for col in 0..2 {
                assert_eq!(grad.get(&[row, col]), 1.0);
            }
        }
    }

    #[test]
    fn test_contiguous_1d() {
        // 1D tensors round-trip through contiguous() unchanged.
        let vector = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        let copy = vector.contiguous();
        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims(), vec![3]);

        // Each element must match its source value.
        for (idx, expected) in [1.0f32, 2.0, 3.0].iter().enumerate() {
            assert_eq!(copy.get(&[idx]), *expected);
        }
    }

    #[test]
    fn test_contiguous_3d() {
        // A 3D tensor keeps its shape and element count through the copy.
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let cube = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        let copy = cube.contiguous();
        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims(), vec![2, 3, 4]);
        assert_eq!(copy.size(), 24);
    }
}