// train_station/tensor/transform/contiguous.rs

1//! Contiguous tensor transformation operation
2//!
3//! This module provides functionality to create contiguous copies of tensors,
4//! ensuring that tensor data is stored in a linear, cache-friendly memory layout.
5//! Contiguous tensors are essential for optimal performance in many operations,
6//! particularly SIMD-optimized computations and operations that require
7//! sequential memory access patterns.
8//!
9//! # Memory Layout
10//!
11//! A tensor is considered contiguous when its elements are stored in memory
12//! in row-major order without gaps. Non-contiguous tensors can arise from
13//! operations like transpose, permute, or slice views that change the
14//! memory layout without copying data.
15//!
16//! # Performance Characteristics
17//!
18//! - **Already Contiguous**: O(1) time, returns a clone
19//! - **Small Tensors (≤64 elements)**: Simple copy with coordinate conversion
20//! - **Medium Tensors (65-1023 elements)**: Unrolled copy for better performance
21//! - **Large Tensors (≥1024 elements)**: Blocked copy with cache optimization
22//!
23//! # Examples
24//!
25//! ```
26//! use train_station::Tensor;
27//!
28//! // Create a contiguous tensor
29//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30//! assert!(tensor.is_contiguous());
31//!
32//! // Create a non-contiguous tensor through transpose
33//! let transposed = tensor.transpose(0, 1);
34//! assert!(!transposed.is_contiguous());
35//!
36//! // Make it contiguous again
37//! let contiguous = transposed.contiguous();
38//! assert!(contiguous.is_contiguous());
39//! assert_eq!(contiguous.shape().dims, vec![2, 2]);
40//! ```
41//!
42//! ```
43//! use train_station::Tensor;
44//!
45//! // Contiguous preserves gradient tracking
46//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
47//! tensor.set_requires_grad(true);
48//!
49//! let transposed = tensor.transpose(0, 1);
50//! let contiguous = transposed.contiguous();
51//! assert!(contiguous.requires_grad());
52//! ```
53//!
54//! # Gradient Tracking
55//!
56//! The contiguous operation supports automatic gradient tracking through
57//! the GradTrack system. When `requires_grad` is enabled, the operation
58//! registers a gradient function that ensures proper gradient flow during
59//! backward passes.
60
61use crate::gradtrack::{GradEngine, GradFn};
62use crate::tensor::Tensor;
63
64impl Tensor {
65    /// Creates a contiguous copy of the tensor
66    ///
67    /// This operation ensures that the tensor data is stored in a linear,
68    /// cache-friendly memory layout. If the tensor is already contiguous,
69    /// this operation returns a clone. For non-contiguous tensors, it
70    /// creates a new tensor with the same data but in contiguous memory layout.
71    ///
72    /// The operation uses different optimization strategies based on tensor size:
73    /// - Small tensors (≤64 elements): Simple coordinate-based copy
74    /// - Medium tensors (65-1023 elements): Unrolled copy for better performance
75    /// - Large tensors (≥1024 elements): Blocked copy with cache optimization
76    ///
77    /// # Returns
78    ///
79    /// A new tensor with contiguous memory layout containing the same data
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use train_station::Tensor;
85    ///
86    /// // Already contiguous tensor
87    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
88    /// let contiguous = tensor.contiguous();
89    /// assert!(contiguous.is_contiguous());
90    /// assert_eq!(contiguous.shape().dims, vec![2, 2]);
91    /// ```
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// // Non-contiguous tensor from transpose
97    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
98    /// let transposed = tensor.transpose(0, 1);
99    /// assert!(!transposed.is_contiguous());
100    ///
101    /// let contiguous = transposed.contiguous();
102    /// assert!(contiguous.is_contiguous());
103    /// assert_eq!(contiguous.get(&[0, 0]), 1.0);
104    /// assert_eq!(contiguous.get(&[0, 1]), 3.0);
105    /// ```
106    ///
107    /// ```
108    /// use train_station::Tensor;
109    ///
110    /// // Preserves gradient tracking
111    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
112    /// tensor.set_requires_grad(true);
113    ///
114    /// let contiguous = tensor.contiguous();
115    /// assert!(contiguous.requires_grad());
116    /// ```
117    ///
118    /// # Performance
119    ///
120    /// - **Already contiguous**: O(1) time complexity, returns a clone
121    /// - **Non-contiguous**: O(n) time complexity with size-dependent optimizations
122    /// - **Memory usage**: Creates a new tensor with the same size as the original
123    pub fn contiguous(&self) -> Tensor {
124        if self.is_contiguous() {
125            let mut cloned = self.clone();
126            // Ensure gradient requirements are preserved
127            if self.requires_grad() {
128                cloned.set_requires_grad(true);
129                // Register gradient function even for already-contiguous tensors
130                let grad_fn = GradFn::Contiguous {
131                    input_shape: self.shape().dims.clone(),
132                };
133                cloned.set_grad_fn(grad_fn.clone());
134                GradEngine::register_operation(cloned.id(), vec![self.id()], grad_fn);
135            }
136            return cloned;
137        }
138
139        // Create new contiguous tensor and copy via optimized methods
140        let mut result = Tensor::new(self.shape().dims.clone());
141
142        unsafe {
143            self.copy_to_contiguous_optimized(&mut result);
144        }
145
146        // Preserve gradient requirements and register gradient function
147        if self.requires_grad() {
148            result.set_requires_grad(true);
149            let grad_fn = GradFn::Contiguous {
150                input_shape: self.shape().dims.clone(),
151            };
152            result.set_grad_fn(grad_fn.clone());
153            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
154        }
155
156        result
157    }
158
159    /// Internal optimized contiguous copy operation
160    ///
161    /// This function dispatches to the appropriate copy strategy based on
162    /// tensor size and rank for optimal performance.
163    ///
164    /// # Arguments
165    ///
166    /// * `result` - The destination tensor to copy data into
167    ///
168    /// # Safety
169    ///
170    /// The caller must ensure:
171    /// * `result` has the same shape as `self`
172    /// * `result` is properly allocated and initialized
173    /// * Both tensors are valid and not moved during the operation
174    #[inline]
175    unsafe fn copy_to_contiguous_optimized(&self, result: &mut Tensor) {
176        let size = self.size();
177        let rank = self.shape().rank();
178        let _src_ptr = self.as_ptr();
179        let _dst_ptr = result.as_mut_ptr();
180
181        // For simple 1D tensors or very small tensors, use simple copy
182        if rank <= 1 || size <= 64 {
183            self.copy_to_contiguous_simple(result, rank);
184            return;
185        }
186
187        // For larger multi-dimensional tensors, use optimized stride-aware copy
188        if size >= 1024 {
189            self.copy_to_contiguous_large(result, rank);
190        } else {
191            self.copy_to_contiguous_medium(result, rank);
192        }
193    }
194
195    /// Simple copy for small tensors or 1D tensors
196    ///
197    /// This function performs a straightforward coordinate-based copy
198    /// suitable for small tensors where the overhead of more complex
199    /// optimizations would not be beneficial.
200    ///
201    /// # Arguments
202    ///
203    /// * `result` - The destination tensor
204    /// * `rank` - The rank of the tensor
205    ///
206    /// # Safety
207    ///
208    /// The caller must ensure both tensors are valid and properly allocated.
209    #[inline]
210    unsafe fn copy_to_contiguous_simple(&self, result: &mut Tensor, rank: usize) {
211        let size = self.size();
212        let src_ptr = self.as_ptr();
213        let dst_ptr = result.as_mut_ptr();
214
215        for dst_idx in 0..size {
216            // Compute destination coordinates under contiguous strides
217            let mut coords = vec![0usize; rank];
218            let mut tmp = dst_idx;
219            for i in (0..rank).rev() {
220                let dim_size = self.shape().dims[i];
221                coords[i] = tmp % dim_size;
222                tmp /= dim_size;
223            }
224            let src_off = self.shape().offset(&coords);
225            *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
226        }
227    }
228
229    /// Optimized copy for medium-sized tensors with unrolling
230    ///
231    /// This function uses loop unrolling to improve performance for
232    /// medium-sized tensors by reducing loop overhead and improving
233    /// instruction-level parallelism.
234    ///
235    /// # Arguments
236    ///
237    /// * `result` - The destination tensor
238    /// * `rank` - The rank of the tensor
239    ///
240    /// # Safety
241    ///
242    /// The caller must ensure both tensors are valid and properly allocated.
243    #[inline]
244    unsafe fn copy_to_contiguous_medium(&self, result: &mut Tensor, rank: usize) {
245        let size = self.size();
246        let src_ptr = self.as_ptr();
247        let dst_ptr = result.as_mut_ptr();
248        let unroll_factor = 4;
249        let unroll_count = size / unroll_factor;
250        let mut dst_idx = 0;
251
252        // Unrolled loop for better performance
253        for _ in 0..unroll_count {
254            for unroll_i in 0..unroll_factor {
255                let coords = self.linear_to_coords(dst_idx + unroll_i, rank);
256                let src_off = self.shape().offset(&coords);
257                *dst_ptr.add(dst_idx + unroll_i) = *src_ptr.add(src_off);
258            }
259            dst_idx += unroll_factor;
260        }
261
262        // Handle remaining elements
263        for i in dst_idx..size {
264            let coords = self.linear_to_coords(i, rank);
265            let src_off = self.shape().offset(&coords);
266            *dst_ptr.add(i) = *src_ptr.add(src_off);
267        }
268    }
269
270    /// Cache-optimized copy for large tensors with blocking
271    ///
272    /// This function uses blocking to improve cache locality for large tensors.
273    /// It processes the tensor in blocks to maximize cache hit rates and
274    /// combines blocking with loop unrolling for optimal performance.
275    ///
276    /// # Arguments
277    ///
278    /// * `result` - The destination tensor
279    /// * `rank` - The rank of the tensor
280    ///
281    /// # Safety
282    ///
283    /// The caller must ensure both tensors are valid and properly allocated.
284    #[inline]
285    unsafe fn copy_to_contiguous_large(&self, result: &mut Tensor, rank: usize) {
286        let size = self.size();
287        let src_ptr = self.as_ptr();
288        let dst_ptr = result.as_mut_ptr();
289
290        // Use blocking to improve cache locality
291        let block_size = 1024; // Process 1024 elements per block
292        let num_blocks = (size + block_size - 1) / block_size;
293
294        for block in 0..num_blocks {
295            let start_idx = block * block_size;
296            let end_idx = (start_idx + block_size).min(size);
297            let block_len = end_idx - start_idx;
298
299            // Process block with unrolling
300            let unroll_factor = 4;
301            let unroll_count = block_len / unroll_factor;
302            let mut local_idx = 0;
303
304            for _ in 0..unroll_count {
305                for unroll_i in 0..unroll_factor {
306                    let dst_idx = start_idx + local_idx + unroll_i;
307                    let coords = self.linear_to_coords(dst_idx, rank);
308                    let src_off = self.shape().offset(&coords);
309                    *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
310                }
311                local_idx += unroll_factor;
312            }
313
314            // Handle remaining elements in this block
315            for i in local_idx..block_len {
316                let dst_idx = start_idx + i;
317                let coords = self.linear_to_coords(dst_idx, rank);
318                let src_off = self.shape().offset(&coords);
319                *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
320            }
321        }
322    }
323
324    /// Helper function to convert linear index to coordinates
325    ///
326    /// Converts a linear (flat) index into multi-dimensional coordinates
327    /// based on the tensor's shape. This is used for coordinate-based
328    /// memory access in non-contiguous tensors.
329    ///
330    /// # Arguments
331    ///
332    /// * `idx` - The linear index to convert
333    /// * `rank` - The rank of the tensor
334    ///
335    /// # Returns
336    ///
337    /// A vector of coordinates representing the multi-dimensional position
338    #[inline]
339    fn linear_to_coords(&self, mut idx: usize, rank: usize) -> Vec<usize> {
340        let mut coords = vec![0usize; rank];
341        for i in (0..rank).rev() {
342            let dim_size = self.shape().dims[i];
343            coords[i] = idx % dim_size;
344            idx /= dim_size;
345        }
346        coords
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies that `contiguous()` on a contiguous tensor yields a correct copy
    /// with the same shape and element values.
    #[test]
    fn test_contiguous_copy() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        // Test that contiguous() returns a proper copy
        let contiguous = tensor.contiguous();
        assert!(contiguous.is_contiguous());
        assert_eq!(contiguous.shape().dims, tensor.shape().dims);

        // Verify data is preserved (first and last elements of the 2x3 layout)
        assert_eq!(contiguous.get(&[0, 0]), 1.0);
        assert_eq!(contiguous.get(&[1, 2]), 6.0);
    }

    /// Verifies the fast path: an already-contiguous tensor is cloned rather
    /// than recopied, keeping shape and size intact.
    #[test]
    fn test_contiguous_already_contiguous() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        // For already contiguous tensors, should return a clone
        let contiguous = tensor.contiguous();
        assert!(contiguous.is_contiguous());
        assert_eq!(contiguous.shape().dims, tensor.shape().dims);
        assert_eq!(contiguous.size(), tensor.size());
    }

    /// Verifies that the `requires_grad` flag survives the contiguous operation.
    #[test]
    fn test_contiguous_preserves_gradient_tracking() {
        let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        tensor.set_requires_grad(true);

        let contiguous = tensor.contiguous();
        assert!(contiguous.requires_grad());
    }

    /// End-to-end gradient check: transpose -> contiguous -> sum -> backward
    /// must deliver a gradient of ones back to the original tensor.
    #[test]
    fn test_contiguous_gradient_flow() {
        // Test that gradients flow correctly through contiguous operation
        let mut x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        x.set_requires_grad(true);

        // Create a non-contiguous tensor through transpose
        let x_transposed = x.transpose(0, 1);
        assert!(!x_transposed.is_contiguous());

        // Make it contiguous
        let x_contiguous = x_transposed.contiguous();
        assert!(x_contiguous.is_contiguous());
        assert!(x_contiguous.requires_grad());

        // Do a simple operation and backward
        let mut result = x_contiguous.sum();
        result.backward(None);

        // Check that the original tensor received gradients
        let grad = x.grad_by_value().expect("Gradient should exist");
        assert_eq!(grad.shape().dims, vec![2, 2]);

        // All gradients should be 1.0 since sum operation
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(grad.get(&[i, j]), 1.0);
            }
        }
    }

    /// Verifies the 1D (rank <= 1) path preserves every element in order.
    #[test]
    fn test_contiguous_1d() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let contiguous = tensor.contiguous();

        assert!(contiguous.is_contiguous());
        assert_eq!(contiguous.shape().dims, vec![3]);

        // Verify data preservation
        for i in 0..3 {
            assert_eq!(contiguous.get(&[i]), (i + 1) as f32);
        }
    }

    /// Verifies shape and size preservation for a rank-3 tensor.
    #[test]
    fn test_contiguous_3d() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        let contiguous = tensor.contiguous();

        assert!(contiguous.is_contiguous());
        assert_eq!(contiguous.shape().dims, vec![2, 3, 4]);
        assert_eq!(contiguous.size(), 24);
    }
}