// train_station/tensor/transform/contiguous.rs
1//! Contiguous tensor transformation operation
2//!
3//! This module provides functionality to create contiguous copies of tensors,
4//! ensuring that tensor data is stored in a linear, cache-friendly memory layout.
5//! Contiguous tensors are essential for optimal performance in many operations,
6//! particularly SIMD-optimized computations and operations that require
7//! sequential memory access patterns.
8//!
9//! # Memory Layout
10//!
11//! A tensor is considered contiguous when its elements are stored in memory
12//! in row-major order without gaps. Non-contiguous tensors can arise from
13//! operations like transpose, permute, or slice views that change the
14//! memory layout without copying data.
15//!
16//! # Performance Characteristics
17//!
18//! - **Already Contiguous**: O(1) time, returns a clone
//! - **Contiguous Last Dimension** (unit stride): row-by-row block copy of contiguous rows
//! - **General Strided Views**: element-wise coordinate-based copy
22//!
23//! # Examples
24//!
25//! ```
26//! use train_station::Tensor;
27//!
28//! // Create a contiguous tensor
29//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30//! assert!(tensor.is_contiguous());
31//!
32//! // Create a non-contiguous tensor through transpose
33//! let transposed = tensor.transpose(0, 1);
34//! assert!(!transposed.is_contiguous());
35//!
36//! // Make it contiguous again
37//! let contiguous = transposed.contiguous();
38//! assert!(contiguous.is_contiguous());
39//! assert_eq!(contiguous.shape().dims(), vec![2, 2]);
40//! ```
41//!
42//! ```
43//! use train_station::Tensor;
44//!
45//! // Contiguous preserves gradient tracking
46//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
47//! tensor.set_requires_grad(true);
48//!
49//! let transposed = tensor.transpose(0, 1);
50//! let contiguous = transposed.contiguous();
51//! assert!(contiguous.requires_grad());
52//! ```
53//!
54//! # Gradient Tracking
55//!
56//! The contiguous operation supports automatic gradient tracking through
57//! the GradTrack system. When `requires_grad` is enabled, the operation
58//! registers a gradient function that ensures proper gradient flow during
59//! backward passes.
60
61use crate::gradtrack::{GradEngine, GradFn};
62use crate::tensor::iterator::collect::optimized_copy;
63use crate::tensor::Tensor;
64
65impl Tensor {
66 /// Creates a contiguous copy of the tensor
67 ///
68 /// This operation ensures that the tensor data is stored in a linear,
69 /// cache-friendly memory layout. If the tensor is already contiguous,
70 /// this operation returns a clone. For non-contiguous tensors, it
71 /// creates a new tensor with the same data but in contiguous memory layout.
72 ///
73 /// The operation uses different optimization strategies based on tensor size:
74 /// - Small tensors (≤64 elements): Simple coordinate-based copy
75 /// - Medium tensors (65-1023 elements): Unrolled copy for better performance
76 /// - Large tensors (≥1024 elements): Blocked copy with cache optimization
77 ///
78 /// # Returns
79 ///
80 /// A new tensor with contiguous memory layout containing the same data
81 ///
82 /// # Examples
83 ///
84 /// ```
85 /// use train_station::Tensor;
86 ///
87 /// // Already contiguous tensor
88 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
89 /// let contiguous = tensor.contiguous();
90 /// assert!(contiguous.is_contiguous());
91 /// assert_eq!(contiguous.shape().dims(), vec![2, 2]);
92 /// ```
93 ///
94 /// ```
95 /// use train_station::Tensor;
96 ///
97 /// // Non-contiguous tensor from transpose
98 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
99 /// let transposed = tensor.transpose(0, 1);
100 /// assert!(!transposed.is_contiguous());
101 ///
102 /// let contiguous = transposed.contiguous();
103 /// assert!(contiguous.is_contiguous());
104 /// assert_eq!(contiguous.get(&[0, 0]), 1.0);
105 /// assert_eq!(contiguous.get(&[0, 1]), 3.0);
106 /// ```
107 ///
108 /// ```
109 /// use train_station::Tensor;
110 ///
111 /// // Preserves gradient tracking
112 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
113 /// tensor.set_requires_grad(true);
114 ///
115 /// let contiguous = tensor.contiguous();
116 /// assert!(contiguous.requires_grad());
117 /// ```
118 ///
119 /// # Performance
120 ///
121 /// - **Already contiguous**: O(1) time complexity, returns a clone
122 /// - **Non-contiguous**: O(n) time complexity with size-dependent optimizations
123 /// - **Memory usage**: Creates a new tensor with the same size as the original
124 #[track_caller]
125 pub fn contiguous(&self) -> Tensor {
126 if self.is_contiguous() {
127 let mut cloned = self.clone();
128 // Ensure gradient requirements are preserved
129 if self.requires_grad() {
130 cloned.set_requires_grad(true);
131 // Register gradient function even for already-contiguous tensors
132 let grad_fn = GradFn::Contiguous {
133 input_shape: self.shape().dims().to_vec(),
134 };
135 cloned.set_grad_fn(grad_fn.clone());
136 GradEngine::register_operation(cloned.id(), vec![self.id()], grad_fn);
137 }
138 return cloned;
139 }
140
141 // Create new contiguous tensor and copy via optimized methods
142 let mut result = Tensor::new(self.shape().dims().to_vec());
143
144 unsafe {
145 self.copy_to_contiguous_optimized(&mut result);
146 }
147
148 // Preserve gradient requirements and register gradient function
149 if self.requires_grad() {
150 result.set_requires_grad(true);
151 let grad_fn = GradFn::Contiguous {
152 input_shape: self.shape().dims().to_vec(),
153 };
154 result.set_grad_fn(grad_fn.clone());
155 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
156 }
157
158 result
159 }
160
161 /// Internal optimized contiguous copy operation
162 ///
163 /// This function dispatches to the appropriate copy strategy based on
164 /// tensor size and rank for optimal performance.
165 ///
166 /// # Arguments
167 ///
168 /// * `result` - The destination tensor to copy data into
169 ///
170 /// # Safety
171 ///
172 /// The caller must ensure:
173 /// * `result` has the same shape as `self`
174 /// * `result` is properly allocated and initialized
175 /// * Both tensors are valid and not moved during the operation
176 #[inline]
177 unsafe fn copy_to_contiguous_optimized(&self, result: &mut Tensor) {
178 let size = self.size();
179 let rank = self.shape().rank();
180
181 if size == 0 {
182 return;
183 }
184
185 // Fast path: if the last dimension is contiguous in the source view,
186 // copy row-by-row using SIMD-optimized contiguous copies.
187 if rank >= 1 && self.stride(rank - 1) == 1 {
188 let dims = self.shape().dims();
189 let row_len = dims[rank - 1];
190
191 // Number of outer rows to copy (all dims except the last)
192 let outer: usize = if rank == 1 {
193 1
194 } else {
195 dims[..rank - 1].iter().product()
196 };
197
198 let src_base = self.as_ptr();
199 let dst_base = result.as_mut_ptr();
200 let strides = self.strides();
201
202 // Coordinate vector for outer dimensions (exclude last)
203 let mut coords = vec![0usize; rank];
204
205 for outer_idx in 0..outer {
206 // Compute multi-index over dims[0..rank-1) in row-major order
207 if rank > 1 {
208 let mut tmp = outer_idx;
209 for i in (0..rank - 1).rev() {
210 let d = dims[i];
211 coords[i] = if d == 0 { 0 } else { tmp % d };
212 if d != 0 {
213 tmp /= d;
214 }
215 }
216 }
217 coords[rank - 1] = 0; // start of the contiguous row
218
219 // Compute source offset via strides
220 let mut src_off = 0usize;
221 for i in 0..rank {
222 src_off += coords[i] * strides[i];
223 }
224
225 // Destination offset: linear index over outer dims times row_len
226 let mut dst_row_index = 0usize;
227 if rank > 1 {
228 for i in 0..rank - 1 {
229 dst_row_index = dst_row_index * dims[i] + coords[i];
230 }
231 }
232 let dst_off = dst_row_index * row_len;
233
234 // Copy the entire row as a contiguous block
235 optimized_copy(src_base.add(src_off), dst_base.add(dst_off), row_len);
236 }
237 return;
238 }
239
240 // Fallback: general coordinate-based copy (works for any strided view)
241 self.copy_to_contiguous_simple(result, rank);
242 }
243
244 /// Simple copy for small tensors or 1D tensors
245 ///
246 /// This function performs a straightforward coordinate-based copy
247 /// suitable for small tensors where the overhead of more complex
248 /// optimizations would not be beneficial.
249 ///
250 /// # Arguments
251 ///
252 /// * `result` - The destination tensor
253 /// * `rank` - The rank of the tensor
254 ///
255 /// # Safety
256 ///
257 /// The caller must ensure both tensors are valid and properly allocated.
258 #[inline]
259 unsafe fn copy_to_contiguous_simple(&self, result: &mut Tensor, rank: usize) {
260 let size = self.size();
261 let src_ptr = self.as_ptr();
262 let dst_ptr = result.as_mut_ptr();
263
264 for dst_idx in 0..size {
265 // Compute destination coordinates under contiguous strides
266 let mut coords = vec![0usize; rank];
267 let mut tmp = dst_idx;
268 for i in (0..rank).rev() {
269 let dim_size = self.shape().dims()[i];
270 coords[i] = tmp % dim_size;
271 tmp /= dim_size;
272 }
273 let src_off = self.shape().offset(&coords);
274 *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// A contiguous copy of a row-major tensor keeps shape and values intact.
    #[test]
    fn test_contiguous_copy() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims(), source.shape().dims());

        // First and last elements survive the copy
        assert_eq!(copy.get(&[0, 0]), 1.0);
        assert_eq!(copy.get(&[1, 2]), 6.0);
    }

    /// An already-contiguous tensor yields an equivalent clone.
    #[test]
    fn test_contiguous_already_contiguous() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims(), source.shape().dims());
        assert_eq!(copy.size(), source.size());
    }

    /// The requires_grad flag propagates through contiguous().
    #[test]
    fn test_contiguous_preserves_gradient_tracking() {
        let mut source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        source.set_requires_grad(true);

        assert!(source.contiguous().requires_grad());
    }

    /// Gradients flow back through contiguous() to the original tensor.
    #[test]
    fn test_contiguous_gradient_flow() {
        let mut x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        x.set_requires_grad(true);

        // Transpose produces a non-contiguous view...
        let view = x.transpose(0, 1);
        assert!(!view.is_contiguous());

        // ...which contiguous() materializes while keeping grad tracking.
        let materialized = view.contiguous();
        assert!(materialized.is_contiguous());
        assert!(materialized.requires_grad());

        // Sum and backprop: d(sum)/dx is all ones.
        let mut total = materialized.sum();
        total.backward(None);

        let grad = x.grad_owned().expect("Gradient should exist");
        assert_eq!(grad.shape().dims(), vec![2, 2]);

        for coords in [[0, 0], [0, 1], [1, 0], [1, 1]] {
            assert_eq!(grad.get(&coords), 1.0);
        }
    }

    /// One-dimensional tensors round-trip through contiguous().
    #[test]
    fn test_contiguous_1d() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims(), vec![3]);

        // Data preserved element-by-element
        for (i, expected) in [1.0f32, 2.0, 3.0].iter().enumerate() {
            assert_eq!(copy.get(&[i]), *expected);
        }
    }

    /// Higher-rank tensors keep shape and element count.
    #[test]
    fn test_contiguous_3d() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let source = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims(), vec![2, 3, 4]);
        assert_eq!(copy.size(), 24);
    }
}
372}