train_station/tensor/transform/contiguous.rs
1//! Contiguous tensor transformation operation
2//!
3//! This module provides functionality to create contiguous copies of tensors,
4//! ensuring that tensor data is stored in a linear, cache-friendly memory layout.
5//! Contiguous tensors are essential for optimal performance in many operations,
6//! particularly SIMD-optimized computations and operations that require
7//! sequential memory access patterns.
8//!
9//! # Memory Layout
10//!
11//! A tensor is considered contiguous when its elements are stored in memory
12//! in row-major order without gaps. Non-contiguous tensors can arise from
13//! operations like transpose, permute, or slice views that change the
14//! memory layout without copying data.
15//!
16//! # Performance Characteristics
17//!
18//! - **Already Contiguous**: O(1) time, returns a clone
19//! - **Small Tensors (≤64 elements)**: Simple copy with coordinate conversion
20//! - **Medium Tensors (65-1023 elements)**: Unrolled copy for better performance
21//! - **Large Tensors (≥1024 elements)**: Blocked copy with cache optimization
22//!
23//! # Examples
24//!
25//! ```
26//! use train_station::Tensor;
27//!
28//! // Create a contiguous tensor
29//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30//! assert!(tensor.is_contiguous());
31//!
32//! // Create a non-contiguous tensor through transpose
33//! let transposed = tensor.transpose(0, 1);
34//! assert!(!transposed.is_contiguous());
35//!
36//! // Make it contiguous again
37//! let contiguous = transposed.contiguous();
38//! assert!(contiguous.is_contiguous());
39//! assert_eq!(contiguous.shape().dims, vec![2, 2]);
40//! ```
41//!
42//! ```
43//! use train_station::Tensor;
44//!
45//! // Contiguous preserves gradient tracking
46//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
47//! tensor.set_requires_grad(true);
48//!
49//! let transposed = tensor.transpose(0, 1);
50//! let contiguous = transposed.contiguous();
51//! assert!(contiguous.requires_grad());
52//! ```
53//!
54//! # Gradient Tracking
55//!
56//! The contiguous operation supports automatic gradient tracking through
57//! the GradTrack system. When `requires_grad` is enabled, the operation
58//! registers a gradient function that ensures proper gradient flow during
59//! backward passes.
60
61use crate::gradtrack::{GradEngine, GradFn};
62use crate::tensor::Tensor;
63
64impl Tensor {
65 /// Creates a contiguous copy of the tensor
66 ///
67 /// This operation ensures that the tensor data is stored in a linear,
68 /// cache-friendly memory layout. If the tensor is already contiguous,
69 /// this operation returns a clone. For non-contiguous tensors, it
70 /// creates a new tensor with the same data but in contiguous memory layout.
71 ///
72 /// The operation uses different optimization strategies based on tensor size:
73 /// - Small tensors (≤64 elements): Simple coordinate-based copy
74 /// - Medium tensors (65-1023 elements): Unrolled copy for better performance
75 /// - Large tensors (≥1024 elements): Blocked copy with cache optimization
76 ///
77 /// # Returns
78 ///
79 /// A new tensor with contiguous memory layout containing the same data
80 ///
81 /// # Examples
82 ///
83 /// ```
84 /// use train_station::Tensor;
85 ///
86 /// // Already contiguous tensor
87 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
88 /// let contiguous = tensor.contiguous();
89 /// assert!(contiguous.is_contiguous());
90 /// assert_eq!(contiguous.shape().dims, vec![2, 2]);
91 /// ```
92 ///
93 /// ```
94 /// use train_station::Tensor;
95 ///
96 /// // Non-contiguous tensor from transpose
97 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
98 /// let transposed = tensor.transpose(0, 1);
99 /// assert!(!transposed.is_contiguous());
100 ///
101 /// let contiguous = transposed.contiguous();
102 /// assert!(contiguous.is_contiguous());
103 /// assert_eq!(contiguous.get(&[0, 0]), 1.0);
104 /// assert_eq!(contiguous.get(&[0, 1]), 3.0);
105 /// ```
106 ///
107 /// ```
108 /// use train_station::Tensor;
109 ///
110 /// // Preserves gradient tracking
111 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
112 /// tensor.set_requires_grad(true);
113 ///
114 /// let contiguous = tensor.contiguous();
115 /// assert!(contiguous.requires_grad());
116 /// ```
117 ///
118 /// # Performance
119 ///
120 /// - **Already contiguous**: O(1) time complexity, returns a clone
121 /// - **Non-contiguous**: O(n) time complexity with size-dependent optimizations
122 /// - **Memory usage**: Creates a new tensor with the same size as the original
123 #[track_caller]
124 pub fn contiguous(&self) -> Tensor {
125 if self.is_contiguous() {
126 let mut cloned = self.clone();
127 // Ensure gradient requirements are preserved
128 if self.requires_grad() {
129 cloned.set_requires_grad(true);
130 // Register gradient function even for already-contiguous tensors
131 let grad_fn = GradFn::Contiguous {
132 input_shape: self.shape().dims.clone(),
133 };
134 cloned.set_grad_fn(grad_fn.clone());
135 GradEngine::register_operation(cloned.id(), vec![self.id()], grad_fn);
136 }
137 return cloned;
138 }
139
140 // Create new contiguous tensor and copy via optimized methods
141 let mut result = Tensor::new(self.shape().dims.clone());
142
143 unsafe {
144 self.copy_to_contiguous_optimized(&mut result);
145 }
146
147 // Preserve gradient requirements and register gradient function
148 if self.requires_grad() {
149 result.set_requires_grad(true);
150 let grad_fn = GradFn::Contiguous {
151 input_shape: self.shape().dims.clone(),
152 };
153 result.set_grad_fn(grad_fn.clone());
154 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
155 }
156
157 result
158 }
159
160 /// Internal optimized contiguous copy operation
161 ///
162 /// This function dispatches to the appropriate copy strategy based on
163 /// tensor size and rank for optimal performance.
164 ///
165 /// # Arguments
166 ///
167 /// * `result` - The destination tensor to copy data into
168 ///
169 /// # Safety
170 ///
171 /// The caller must ensure:
172 /// * `result` has the same shape as `self`
173 /// * `result` is properly allocated and initialized
174 /// * Both tensors are valid and not moved during the operation
175 #[inline]
176 unsafe fn copy_to_contiguous_optimized(&self, result: &mut Tensor) {
177 let size = self.size();
178 let rank = self.shape().rank();
179 let _src_ptr = self.as_ptr();
180 let _dst_ptr = result.as_mut_ptr();
181
182 // For simple 1D tensors or very small tensors, use simple copy
183 if rank <= 1 || size <= 64 {
184 self.copy_to_contiguous_simple(result, rank);
185 return;
186 }
187
188 // For larger multi-dimensional tensors, use optimized stride-aware copy
189 if size >= 1024 {
190 self.copy_to_contiguous_large(result, rank);
191 } else {
192 self.copy_to_contiguous_medium(result, rank);
193 }
194 }
195
196 /// Simple copy for small tensors or 1D tensors
197 ///
198 /// This function performs a straightforward coordinate-based copy
199 /// suitable for small tensors where the overhead of more complex
200 /// optimizations would not be beneficial.
201 ///
202 /// # Arguments
203 ///
204 /// * `result` - The destination tensor
205 /// * `rank` - The rank of the tensor
206 ///
207 /// # Safety
208 ///
209 /// The caller must ensure both tensors are valid and properly allocated.
210 #[inline]
211 unsafe fn copy_to_contiguous_simple(&self, result: &mut Tensor, rank: usize) {
212 let size = self.size();
213 let src_ptr = self.as_ptr();
214 let dst_ptr = result.as_mut_ptr();
215
216 for dst_idx in 0..size {
217 // Compute destination coordinates under contiguous strides
218 let mut coords = vec![0usize; rank];
219 let mut tmp = dst_idx;
220 for i in (0..rank).rev() {
221 let dim_size = self.shape().dims[i];
222 coords[i] = tmp % dim_size;
223 tmp /= dim_size;
224 }
225 let src_off = self.shape().offset(&coords);
226 *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
227 }
228 }
229
230 /// Optimized copy for medium-sized tensors with unrolling
231 ///
232 /// This function uses loop unrolling to improve performance for
233 /// medium-sized tensors by reducing loop overhead and improving
234 /// instruction-level parallelism.
235 ///
236 /// # Arguments
237 ///
238 /// * `result` - The destination tensor
239 /// * `rank` - The rank of the tensor
240 ///
241 /// # Safety
242 ///
243 /// The caller must ensure both tensors are valid and properly allocated.
244 #[inline]
245 unsafe fn copy_to_contiguous_medium(&self, result: &mut Tensor, rank: usize) {
246 let size = self.size();
247 let src_ptr = self.as_ptr();
248 let dst_ptr = result.as_mut_ptr();
249 let unroll_factor = 4;
250 let unroll_count = size / unroll_factor;
251 let mut dst_idx = 0;
252
253 // Unrolled loop for better performance
254 for _ in 0..unroll_count {
255 for unroll_i in 0..unroll_factor {
256 let coords = self.linear_to_coords(dst_idx + unroll_i, rank);
257 let src_off = self.shape().offset(&coords);
258 *dst_ptr.add(dst_idx + unroll_i) = *src_ptr.add(src_off);
259 }
260 dst_idx += unroll_factor;
261 }
262
263 // Handle remaining elements
264 for i in dst_idx..size {
265 let coords = self.linear_to_coords(i, rank);
266 let src_off = self.shape().offset(&coords);
267 *dst_ptr.add(i) = *src_ptr.add(src_off);
268 }
269 }
270
271 /// Cache-optimized copy for large tensors with blocking
272 ///
273 /// This function uses blocking to improve cache locality for large tensors.
274 /// It processes the tensor in blocks to maximize cache hit rates and
275 /// combines blocking with loop unrolling for optimal performance.
276 ///
277 /// # Arguments
278 ///
279 /// * `result` - The destination tensor
280 /// * `rank` - The rank of the tensor
281 ///
282 /// # Safety
283 ///
284 /// The caller must ensure both tensors are valid and properly allocated.
285 #[inline]
286 unsafe fn copy_to_contiguous_large(&self, result: &mut Tensor, rank: usize) {
287 let size = self.size();
288 let src_ptr = self.as_ptr();
289 let dst_ptr = result.as_mut_ptr();
290
291 // Use blocking to improve cache locality
292 let block_size = 1024; // Process 1024 elements per block
293 let num_blocks = (size + block_size - 1) / block_size;
294
295 for block in 0..num_blocks {
296 let start_idx = block * block_size;
297 let end_idx = (start_idx + block_size).min(size);
298 let block_len = end_idx - start_idx;
299
300 // Process block with unrolling
301 let unroll_factor = 4;
302 let unroll_count = block_len / unroll_factor;
303 let mut local_idx = 0;
304
305 for _ in 0..unroll_count {
306 for unroll_i in 0..unroll_factor {
307 let dst_idx = start_idx + local_idx + unroll_i;
308 let coords = self.linear_to_coords(dst_idx, rank);
309 let src_off = self.shape().offset(&coords);
310 *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
311 }
312 local_idx += unroll_factor;
313 }
314
315 // Handle remaining elements in this block
316 for i in local_idx..block_len {
317 let dst_idx = start_idx + i;
318 let coords = self.linear_to_coords(dst_idx, rank);
319 let src_off = self.shape().offset(&coords);
320 *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
321 }
322 }
323 }
324
325 /// Helper function to convert linear index to coordinates
326 ///
327 /// Converts a linear (flat) index into multi-dimensional coordinates
328 /// based on the tensor's shape. This is used for coordinate-based
329 /// memory access in non-contiguous tensors.
330 ///
331 /// # Arguments
332 ///
333 /// * `idx` - The linear index to convert
334 /// * `rank` - The rank of the tensor
335 ///
336 /// # Returns
337 ///
338 /// A vector of coordinates representing the multi-dimensional position
339 #[inline]
340 fn linear_to_coords(&self, mut idx: usize, rank: usize) -> Vec<usize> {
341 let mut coords = vec![0usize; rank];
342 for i in (0..rank).rev() {
343 let dim_size = self.shape().dims[i];
344 coords[i] = idx % dim_size;
345 idx /= dim_size;
346 }
347 coords
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_contiguous_copy() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let copy = source.contiguous();

        // The result must be contiguous and shape-identical to the source.
        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, source.shape().dims);

        // The element values must survive the copy.
        assert_eq!(copy.get(&[0, 0]), 1.0);
        assert_eq!(copy.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_contiguous_already_contiguous() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        // An already-contiguous tensor yields an equivalent clone.
        let copy = source.contiguous();
        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, source.shape().dims);
        assert_eq!(copy.size(), source.size());
    }

    #[test]
    fn test_contiguous_preserves_gradient_tracking() {
        let mut source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        source.set_requires_grad(true);

        // Gradient tracking carries over to the contiguous copy.
        assert!(source.contiguous().requires_grad());
    }

    #[test]
    fn test_contiguous_gradient_flow() {
        // Gradients must flow through contiguous() back to the source tensor.
        let mut x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        x.set_requires_grad(true);

        // transpose produces a non-contiguous view...
        let view = x.transpose(0, 1);
        assert!(!view.is_contiguous());

        // ...which contiguous() linearizes while keeping grad tracking intact.
        let dense = view.contiguous();
        assert!(dense.is_contiguous());
        assert!(dense.requires_grad());

        // Run a simple reduction and backpropagate.
        let mut total = dense.sum();
        total.backward(None);

        let grad = x.grad_by_value().expect("Gradient should exist");
        assert_eq!(grad.shape().dims, vec![2, 2]);

        // sum() contributes a gradient of exactly 1.0 per element.
        for row in 0..2 {
            for col in 0..2 {
                assert_eq!(grad.get(&[row, col]), 1.0);
            }
        }
    }

    #[test]
    fn test_contiguous_1d() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, vec![3]);

        // Each element must be preserved in order.
        for (i, expected) in [1.0f32, 2.0, 3.0].iter().enumerate() {
            assert_eq!(copy.get(&[i]), *expected);
        }
    }

    #[test]
    fn test_contiguous_3d() {
        let values: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let source = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, vec![2, 3, 4]);
        assert_eq!(copy.size(), 24);
    }
}
444}