train_station/tensor/transform/
contiguous.rs

1//! Contiguous tensor transformation operation
2//!
3//! This module provides functionality to create contiguous copies of tensors,
4//! ensuring that tensor data is stored in a linear, cache-friendly memory layout.
5//! Contiguous tensors are essential for optimal performance in many operations,
6//! particularly SIMD-optimized computations and operations that require
7//! sequential memory access patterns.
8//!
9//! # Memory Layout
10//!
11//! A tensor is considered contiguous when its elements are stored in memory
12//! in row-major order without gaps. Non-contiguous tensors can arise from
13//! operations like transpose, permute, or slice views that change the
14//! memory layout without copying data.
15//!
16//! # Performance Characteristics
17//!
18//! - **Already Contiguous**: O(1) time, returns a clone
19//! - **Small Tensors (≤64 elements)**: Simple copy with coordinate conversion
20//! - **Medium Tensors (65-1023 elements)**: Unrolled copy for better performance
21//! - **Large Tensors (≥1024 elements)**: Blocked copy with cache optimization
22//!
23//! # Examples
24//!
25//! ```
26//! use train_station::Tensor;
27//!
28//! // Create a contiguous tensor
29//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30//! assert!(tensor.is_contiguous());
31//!
32//! // Create a non-contiguous tensor through transpose
33//! let transposed = tensor.transpose(0, 1);
34//! assert!(!transposed.is_contiguous());
35//!
36//! // Make it contiguous again
37//! let contiguous = transposed.contiguous();
38//! assert!(contiguous.is_contiguous());
39//! assert_eq!(contiguous.shape().dims, vec![2, 2]);
40//! ```
41//!
42//! ```
43//! use train_station::Tensor;
44//!
45//! // Contiguous preserves gradient tracking
46//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
47//! tensor.set_requires_grad(true);
48//!
49//! let transposed = tensor.transpose(0, 1);
50//! let contiguous = transposed.contiguous();
51//! assert!(contiguous.requires_grad());
52//! ```
53//!
54//! # Gradient Tracking
55//!
56//! The contiguous operation supports automatic gradient tracking through
57//! the GradTrack system. When `requires_grad` is enabled, the operation
58//! registers a gradient function that ensures proper gradient flow during
59//! backward passes.
60
61use crate::gradtrack::{GradEngine, GradFn};
62use crate::tensor::Tensor;
63
64impl Tensor {
65    /// Creates a contiguous copy of the tensor
66    ///
67    /// This operation ensures that the tensor data is stored in a linear,
68    /// cache-friendly memory layout. If the tensor is already contiguous,
69    /// this operation returns a clone. For non-contiguous tensors, it
70    /// creates a new tensor with the same data but in contiguous memory layout.
71    ///
72    /// The operation uses different optimization strategies based on tensor size:
73    /// - Small tensors (≤64 elements): Simple coordinate-based copy
74    /// - Medium tensors (65-1023 elements): Unrolled copy for better performance
75    /// - Large tensors (≥1024 elements): Blocked copy with cache optimization
76    ///
77    /// # Returns
78    ///
79    /// A new tensor with contiguous memory layout containing the same data
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use train_station::Tensor;
85    ///
86    /// // Already contiguous tensor
87    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
88    /// let contiguous = tensor.contiguous();
89    /// assert!(contiguous.is_contiguous());
90    /// assert_eq!(contiguous.shape().dims, vec![2, 2]);
91    /// ```
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// // Non-contiguous tensor from transpose
97    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
98    /// let transposed = tensor.transpose(0, 1);
99    /// assert!(!transposed.is_contiguous());
100    ///
101    /// let contiguous = transposed.contiguous();
102    /// assert!(contiguous.is_contiguous());
103    /// assert_eq!(contiguous.get(&[0, 0]), 1.0);
104    /// assert_eq!(contiguous.get(&[0, 1]), 3.0);
105    /// ```
106    ///
107    /// ```
108    /// use train_station::Tensor;
109    ///
110    /// // Preserves gradient tracking
111    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
112    /// tensor.set_requires_grad(true);
113    ///
114    /// let contiguous = tensor.contiguous();
115    /// assert!(contiguous.requires_grad());
116    /// ```
117    ///
118    /// # Performance
119    ///
120    /// - **Already contiguous**: O(1) time complexity, returns a clone
121    /// - **Non-contiguous**: O(n) time complexity with size-dependent optimizations
122    /// - **Memory usage**: Creates a new tensor with the same size as the original
123    #[track_caller]
124    pub fn contiguous(&self) -> Tensor {
125        if self.is_contiguous() {
126            let mut cloned = self.clone();
127            // Ensure gradient requirements are preserved
128            if self.requires_grad() {
129                cloned.set_requires_grad(true);
130                // Register gradient function even for already-contiguous tensors
131                let grad_fn = GradFn::Contiguous {
132                    input_shape: self.shape().dims.clone(),
133                };
134                cloned.set_grad_fn(grad_fn.clone());
135                GradEngine::register_operation(cloned.id(), vec![self.id()], grad_fn);
136            }
137            return cloned;
138        }
139
140        // Create new contiguous tensor and copy via optimized methods
141        let mut result = Tensor::new(self.shape().dims.clone());
142
143        unsafe {
144            self.copy_to_contiguous_optimized(&mut result);
145        }
146
147        // Preserve gradient requirements and register gradient function
148        if self.requires_grad() {
149            result.set_requires_grad(true);
150            let grad_fn = GradFn::Contiguous {
151                input_shape: self.shape().dims.clone(),
152            };
153            result.set_grad_fn(grad_fn.clone());
154            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
155        }
156
157        result
158    }
159
160    /// Internal optimized contiguous copy operation
161    ///
162    /// This function dispatches to the appropriate copy strategy based on
163    /// tensor size and rank for optimal performance.
164    ///
165    /// # Arguments
166    ///
167    /// * `result` - The destination tensor to copy data into
168    ///
169    /// # Safety
170    ///
171    /// The caller must ensure:
172    /// * `result` has the same shape as `self`
173    /// * `result` is properly allocated and initialized
174    /// * Both tensors are valid and not moved during the operation
175    #[inline]
176    unsafe fn copy_to_contiguous_optimized(&self, result: &mut Tensor) {
177        let size = self.size();
178        let rank = self.shape().rank();
179        let _src_ptr = self.as_ptr();
180        let _dst_ptr = result.as_mut_ptr();
181
182        // For simple 1D tensors or very small tensors, use simple copy
183        if rank <= 1 || size <= 64 {
184            self.copy_to_contiguous_simple(result, rank);
185            return;
186        }
187
188        // For larger multi-dimensional tensors, use optimized stride-aware copy
189        if size >= 1024 {
190            self.copy_to_contiguous_large(result, rank);
191        } else {
192            self.copy_to_contiguous_medium(result, rank);
193        }
194    }
195
196    /// Simple copy for small tensors or 1D tensors
197    ///
198    /// This function performs a straightforward coordinate-based copy
199    /// suitable for small tensors where the overhead of more complex
200    /// optimizations would not be beneficial.
201    ///
202    /// # Arguments
203    ///
204    /// * `result` - The destination tensor
205    /// * `rank` - The rank of the tensor
206    ///
207    /// # Safety
208    ///
209    /// The caller must ensure both tensors are valid and properly allocated.
210    #[inline]
211    unsafe fn copy_to_contiguous_simple(&self, result: &mut Tensor, rank: usize) {
212        let size = self.size();
213        let src_ptr = self.as_ptr();
214        let dst_ptr = result.as_mut_ptr();
215
216        for dst_idx in 0..size {
217            // Compute destination coordinates under contiguous strides
218            let mut coords = vec![0usize; rank];
219            let mut tmp = dst_idx;
220            for i in (0..rank).rev() {
221                let dim_size = self.shape().dims[i];
222                coords[i] = tmp % dim_size;
223                tmp /= dim_size;
224            }
225            let src_off = self.shape().offset(&coords);
226            *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
227        }
228    }
229
230    /// Optimized copy for medium-sized tensors with unrolling
231    ///
232    /// This function uses loop unrolling to improve performance for
233    /// medium-sized tensors by reducing loop overhead and improving
234    /// instruction-level parallelism.
235    ///
236    /// # Arguments
237    ///
238    /// * `result` - The destination tensor
239    /// * `rank` - The rank of the tensor
240    ///
241    /// # Safety
242    ///
243    /// The caller must ensure both tensors are valid and properly allocated.
244    #[inline]
245    unsafe fn copy_to_contiguous_medium(&self, result: &mut Tensor, rank: usize) {
246        let size = self.size();
247        let src_ptr = self.as_ptr();
248        let dst_ptr = result.as_mut_ptr();
249        let unroll_factor = 4;
250        let unroll_count = size / unroll_factor;
251        let mut dst_idx = 0;
252
253        // Unrolled loop for better performance
254        for _ in 0..unroll_count {
255            for unroll_i in 0..unroll_factor {
256                let coords = self.linear_to_coords(dst_idx + unroll_i, rank);
257                let src_off = self.shape().offset(&coords);
258                *dst_ptr.add(dst_idx + unroll_i) = *src_ptr.add(src_off);
259            }
260            dst_idx += unroll_factor;
261        }
262
263        // Handle remaining elements
264        for i in dst_idx..size {
265            let coords = self.linear_to_coords(i, rank);
266            let src_off = self.shape().offset(&coords);
267            *dst_ptr.add(i) = *src_ptr.add(src_off);
268        }
269    }
270
271    /// Cache-optimized copy for large tensors with blocking
272    ///
273    /// This function uses blocking to improve cache locality for large tensors.
274    /// It processes the tensor in blocks to maximize cache hit rates and
275    /// combines blocking with loop unrolling for optimal performance.
276    ///
277    /// # Arguments
278    ///
279    /// * `result` - The destination tensor
280    /// * `rank` - The rank of the tensor
281    ///
282    /// # Safety
283    ///
284    /// The caller must ensure both tensors are valid and properly allocated.
285    #[inline]
286    unsafe fn copy_to_contiguous_large(&self, result: &mut Tensor, rank: usize) {
287        let size = self.size();
288        let src_ptr = self.as_ptr();
289        let dst_ptr = result.as_mut_ptr();
290
291        // Use blocking to improve cache locality
292        let block_size = 1024; // Process 1024 elements per block
293        let num_blocks = (size + block_size - 1) / block_size;
294
295        for block in 0..num_blocks {
296            let start_idx = block * block_size;
297            let end_idx = (start_idx + block_size).min(size);
298            let block_len = end_idx - start_idx;
299
300            // Process block with unrolling
301            let unroll_factor = 4;
302            let unroll_count = block_len / unroll_factor;
303            let mut local_idx = 0;
304
305            for _ in 0..unroll_count {
306                for unroll_i in 0..unroll_factor {
307                    let dst_idx = start_idx + local_idx + unroll_i;
308                    let coords = self.linear_to_coords(dst_idx, rank);
309                    let src_off = self.shape().offset(&coords);
310                    *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
311                }
312                local_idx += unroll_factor;
313            }
314
315            // Handle remaining elements in this block
316            for i in local_idx..block_len {
317                let dst_idx = start_idx + i;
318                let coords = self.linear_to_coords(dst_idx, rank);
319                let src_off = self.shape().offset(&coords);
320                *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
321            }
322        }
323    }
324
325    /// Helper function to convert linear index to coordinates
326    ///
327    /// Converts a linear (flat) index into multi-dimensional coordinates
328    /// based on the tensor's shape. This is used for coordinate-based
329    /// memory access in non-contiguous tensors.
330    ///
331    /// # Arguments
332    ///
333    /// * `idx` - The linear index to convert
334    /// * `rank` - The rank of the tensor
335    ///
336    /// # Returns
337    ///
338    /// A vector of coordinates representing the multi-dimensional position
339    #[inline]
340    fn linear_to_coords(&self, mut idx: usize, rank: usize) -> Vec<usize> {
341        let mut coords = vec![0usize; rank];
342        for i in (0..rank).rev() {
343            let dim_size = self.shape().dims[i];
344            coords[i] = idx % dim_size;
345            idx /= dim_size;
346        }
347        coords
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_contiguous_copy() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let dense = source.contiguous();

        // The copy must be contiguous and keep the original shape.
        assert!(dense.is_contiguous());
        assert_eq!(dense.shape().dims, source.shape().dims);

        // Element values must survive the copy unchanged.
        assert_eq!(dense.get(&[0, 0]), 1.0);
        assert_eq!(dense.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_contiguous_already_contiguous() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        // An already-contiguous tensor comes back as a clone with the same
        // shape and size.
        let dense = source.contiguous();
        assert!(dense.is_contiguous());
        assert_eq!(dense.shape().dims, source.shape().dims);
        assert_eq!(dense.size(), source.size());
    }

    #[test]
    fn test_contiguous_preserves_gradient_tracking() {
        let mut source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        source.set_requires_grad(true);

        // Gradient tracking must carry over to the contiguous copy.
        assert!(source.contiguous().requires_grad());
    }

    #[test]
    fn test_contiguous_gradient_flow() {
        // Gradients must flow back through contiguous() to the source tensor.
        let mut x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        x.set_requires_grad(true);

        // Transposing produces a non-contiguous view.
        let view = x.transpose(0, 1);
        assert!(!view.is_contiguous());

        // Materialize the view into linear memory.
        let dense = view.contiguous();
        assert!(dense.is_contiguous());
        assert!(dense.requires_grad());

        // Reduce and backpropagate.
        let mut total = dense.sum();
        total.backward(None);

        let grad = x.grad_by_value().expect("Gradient should exist");
        assert_eq!(grad.shape().dims, vec![2, 2]);

        // sum() contributes a gradient of 1.0 to every input element.
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(grad.get(&[i, j]), 1.0);
            }
        }
    }

    #[test]
    fn test_contiguous_1d() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let dense = source.contiguous();

        assert!(dense.is_contiguous());
        assert_eq!(dense.shape().dims, vec![3]);

        // Values 1.0, 2.0, 3.0 must appear in order.
        for (i, expected) in (1..=3).enumerate() {
            assert_eq!(dense.get(&[i]), expected as f32);
        }
    }

    #[test]
    fn test_contiguous_3d() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let dense = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap().contiguous();

        assert!(dense.is_contiguous());
        assert_eq!(dense.shape().dims, vec![2, 3, 4]);
        assert_eq!(dense.size(), 24);
    }
}