train_station/tensor/transform/contiguous.rs
1//! Contiguous tensor transformation operation
2//!
3//! This module provides functionality to create contiguous copies of tensors,
4//! ensuring that tensor data is stored in a linear, cache-friendly memory layout.
5//! Contiguous tensors are essential for optimal performance in many operations,
6//! particularly SIMD-optimized computations and operations that require
7//! sequential memory access patterns.
8//!
9//! # Memory Layout
10//!
11//! A tensor is considered contiguous when its elements are stored in memory
12//! in row-major order without gaps. Non-contiguous tensors can arise from
13//! operations like transpose, permute, or slice views that change the
14//! memory layout without copying data.
15//!
16//! # Performance Characteristics
17//!
18//! - **Already Contiguous**: O(1) time, returns a clone
19//! - **Small Tensors (≤64 elements)**: Simple copy with coordinate conversion
20//! - **Medium Tensors (65-1023 elements)**: Unrolled copy for better performance
21//! - **Large Tensors (≥1024 elements)**: Blocked copy with cache optimization
22//!
23//! # Examples
24//!
25//! ```
26//! use train_station::Tensor;
27//!
28//! // Create a contiguous tensor
29//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30//! assert!(tensor.is_contiguous());
31//!
32//! // Create a non-contiguous tensor through transpose
33//! let transposed = tensor.transpose(0, 1);
34//! assert!(!transposed.is_contiguous());
35//!
36//! // Make it contiguous again
37//! let contiguous = transposed.contiguous();
38//! assert!(contiguous.is_contiguous());
39//! assert_eq!(contiguous.shape().dims, vec![2, 2]);
40//! ```
41//!
42//! ```
43//! use train_station::Tensor;
44//!
45//! // Contiguous preserves gradient tracking
46//! let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
47//! tensor.set_requires_grad(true);
48//!
49//! let transposed = tensor.transpose(0, 1);
50//! let contiguous = transposed.contiguous();
51//! assert!(contiguous.requires_grad());
52//! ```
53//!
54//! # Gradient Tracking
55//!
56//! The contiguous operation supports automatic gradient tracking through
57//! the GradTrack system. When `requires_grad` is enabled, the operation
58//! registers a gradient function that ensures proper gradient flow during
59//! backward passes.
60
61use crate::gradtrack::{GradEngine, GradFn};
62use crate::tensor::Tensor;
63
64impl Tensor {
65 /// Creates a contiguous copy of the tensor
66 ///
67 /// This operation ensures that the tensor data is stored in a linear,
68 /// cache-friendly memory layout. If the tensor is already contiguous,
69 /// this operation returns a clone. For non-contiguous tensors, it
70 /// creates a new tensor with the same data but in contiguous memory layout.
71 ///
72 /// The operation uses different optimization strategies based on tensor size:
73 /// - Small tensors (≤64 elements): Simple coordinate-based copy
74 /// - Medium tensors (65-1023 elements): Unrolled copy for better performance
75 /// - Large tensors (≥1024 elements): Blocked copy with cache optimization
76 ///
77 /// # Returns
78 ///
79 /// A new tensor with contiguous memory layout containing the same data
80 ///
81 /// # Examples
82 ///
83 /// ```
84 /// use train_station::Tensor;
85 ///
86 /// // Already contiguous tensor
87 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
88 /// let contiguous = tensor.contiguous();
89 /// assert!(contiguous.is_contiguous());
90 /// assert_eq!(contiguous.shape().dims, vec![2, 2]);
91 /// ```
92 ///
93 /// ```
94 /// use train_station::Tensor;
95 ///
96 /// // Non-contiguous tensor from transpose
97 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
98 /// let transposed = tensor.transpose(0, 1);
99 /// assert!(!transposed.is_contiguous());
100 ///
101 /// let contiguous = transposed.contiguous();
102 /// assert!(contiguous.is_contiguous());
103 /// assert_eq!(contiguous.get(&[0, 0]), 1.0);
104 /// assert_eq!(contiguous.get(&[0, 1]), 3.0);
105 /// ```
106 ///
107 /// ```
108 /// use train_station::Tensor;
109 ///
110 /// // Preserves gradient tracking
111 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
112 /// tensor.set_requires_grad(true);
113 ///
114 /// let contiguous = tensor.contiguous();
115 /// assert!(contiguous.requires_grad());
116 /// ```
117 ///
118 /// # Performance
119 ///
120 /// - **Already contiguous**: O(1) time complexity, returns a clone
121 /// - **Non-contiguous**: O(n) time complexity with size-dependent optimizations
122 /// - **Memory usage**: Creates a new tensor with the same size as the original
123 pub fn contiguous(&self) -> Tensor {
124 if self.is_contiguous() {
125 let mut cloned = self.clone();
126 // Ensure gradient requirements are preserved
127 if self.requires_grad() {
128 cloned.set_requires_grad(true);
129 // Register gradient function even for already-contiguous tensors
130 let grad_fn = GradFn::Contiguous {
131 input_shape: self.shape().dims.clone(),
132 };
133 cloned.set_grad_fn(grad_fn.clone());
134 GradEngine::register_operation(cloned.id(), vec![self.id()], grad_fn);
135 }
136 return cloned;
137 }
138
139 // Create new contiguous tensor and copy via optimized methods
140 let mut result = Tensor::new(self.shape().dims.clone());
141
142 unsafe {
143 self.copy_to_contiguous_optimized(&mut result);
144 }
145
146 // Preserve gradient requirements and register gradient function
147 if self.requires_grad() {
148 result.set_requires_grad(true);
149 let grad_fn = GradFn::Contiguous {
150 input_shape: self.shape().dims.clone(),
151 };
152 result.set_grad_fn(grad_fn.clone());
153 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
154 }
155
156 result
157 }
158
159 /// Internal optimized contiguous copy operation
160 ///
161 /// This function dispatches to the appropriate copy strategy based on
162 /// tensor size and rank for optimal performance.
163 ///
164 /// # Arguments
165 ///
166 /// * `result` - The destination tensor to copy data into
167 ///
168 /// # Safety
169 ///
170 /// The caller must ensure:
171 /// * `result` has the same shape as `self`
172 /// * `result` is properly allocated and initialized
173 /// * Both tensors are valid and not moved during the operation
174 #[inline]
175 unsafe fn copy_to_contiguous_optimized(&self, result: &mut Tensor) {
176 let size = self.size();
177 let rank = self.shape().rank();
178 let _src_ptr = self.as_ptr();
179 let _dst_ptr = result.as_mut_ptr();
180
181 // For simple 1D tensors or very small tensors, use simple copy
182 if rank <= 1 || size <= 64 {
183 self.copy_to_contiguous_simple(result, rank);
184 return;
185 }
186
187 // For larger multi-dimensional tensors, use optimized stride-aware copy
188 if size >= 1024 {
189 self.copy_to_contiguous_large(result, rank);
190 } else {
191 self.copy_to_contiguous_medium(result, rank);
192 }
193 }
194
195 /// Simple copy for small tensors or 1D tensors
196 ///
197 /// This function performs a straightforward coordinate-based copy
198 /// suitable for small tensors where the overhead of more complex
199 /// optimizations would not be beneficial.
200 ///
201 /// # Arguments
202 ///
203 /// * `result` - The destination tensor
204 /// * `rank` - The rank of the tensor
205 ///
206 /// # Safety
207 ///
208 /// The caller must ensure both tensors are valid and properly allocated.
209 #[inline]
210 unsafe fn copy_to_contiguous_simple(&self, result: &mut Tensor, rank: usize) {
211 let size = self.size();
212 let src_ptr = self.as_ptr();
213 let dst_ptr = result.as_mut_ptr();
214
215 for dst_idx in 0..size {
216 // Compute destination coordinates under contiguous strides
217 let mut coords = vec![0usize; rank];
218 let mut tmp = dst_idx;
219 for i in (0..rank).rev() {
220 let dim_size = self.shape().dims[i];
221 coords[i] = tmp % dim_size;
222 tmp /= dim_size;
223 }
224 let src_off = self.shape().offset(&coords);
225 *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
226 }
227 }
228
229 /// Optimized copy for medium-sized tensors with unrolling
230 ///
231 /// This function uses loop unrolling to improve performance for
232 /// medium-sized tensors by reducing loop overhead and improving
233 /// instruction-level parallelism.
234 ///
235 /// # Arguments
236 ///
237 /// * `result` - The destination tensor
238 /// * `rank` - The rank of the tensor
239 ///
240 /// # Safety
241 ///
242 /// The caller must ensure both tensors are valid and properly allocated.
243 #[inline]
244 unsafe fn copy_to_contiguous_medium(&self, result: &mut Tensor, rank: usize) {
245 let size = self.size();
246 let src_ptr = self.as_ptr();
247 let dst_ptr = result.as_mut_ptr();
248 let unroll_factor = 4;
249 let unroll_count = size / unroll_factor;
250 let mut dst_idx = 0;
251
252 // Unrolled loop for better performance
253 for _ in 0..unroll_count {
254 for unroll_i in 0..unroll_factor {
255 let coords = self.linear_to_coords(dst_idx + unroll_i, rank);
256 let src_off = self.shape().offset(&coords);
257 *dst_ptr.add(dst_idx + unroll_i) = *src_ptr.add(src_off);
258 }
259 dst_idx += unroll_factor;
260 }
261
262 // Handle remaining elements
263 for i in dst_idx..size {
264 let coords = self.linear_to_coords(i, rank);
265 let src_off = self.shape().offset(&coords);
266 *dst_ptr.add(i) = *src_ptr.add(src_off);
267 }
268 }
269
270 /// Cache-optimized copy for large tensors with blocking
271 ///
272 /// This function uses blocking to improve cache locality for large tensors.
273 /// It processes the tensor in blocks to maximize cache hit rates and
274 /// combines blocking with loop unrolling for optimal performance.
275 ///
276 /// # Arguments
277 ///
278 /// * `result` - The destination tensor
279 /// * `rank` - The rank of the tensor
280 ///
281 /// # Safety
282 ///
283 /// The caller must ensure both tensors are valid and properly allocated.
284 #[inline]
285 unsafe fn copy_to_contiguous_large(&self, result: &mut Tensor, rank: usize) {
286 let size = self.size();
287 let src_ptr = self.as_ptr();
288 let dst_ptr = result.as_mut_ptr();
289
290 // Use blocking to improve cache locality
291 let block_size = 1024; // Process 1024 elements per block
292 let num_blocks = (size + block_size - 1) / block_size;
293
294 for block in 0..num_blocks {
295 let start_idx = block * block_size;
296 let end_idx = (start_idx + block_size).min(size);
297 let block_len = end_idx - start_idx;
298
299 // Process block with unrolling
300 let unroll_factor = 4;
301 let unroll_count = block_len / unroll_factor;
302 let mut local_idx = 0;
303
304 for _ in 0..unroll_count {
305 for unroll_i in 0..unroll_factor {
306 let dst_idx = start_idx + local_idx + unroll_i;
307 let coords = self.linear_to_coords(dst_idx, rank);
308 let src_off = self.shape().offset(&coords);
309 *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
310 }
311 local_idx += unroll_factor;
312 }
313
314 // Handle remaining elements in this block
315 for i in local_idx..block_len {
316 let dst_idx = start_idx + i;
317 let coords = self.linear_to_coords(dst_idx, rank);
318 let src_off = self.shape().offset(&coords);
319 *dst_ptr.add(dst_idx) = *src_ptr.add(src_off);
320 }
321 }
322 }
323
324 /// Helper function to convert linear index to coordinates
325 ///
326 /// Converts a linear (flat) index into multi-dimensional coordinates
327 /// based on the tensor's shape. This is used for coordinate-based
328 /// memory access in non-contiguous tensors.
329 ///
330 /// # Arguments
331 ///
332 /// * `idx` - The linear index to convert
333 /// * `rank` - The rank of the tensor
334 ///
335 /// # Returns
336 ///
337 /// A vector of coordinates representing the multi-dimensional position
338 #[inline]
339 fn linear_to_coords(&self, mut idx: usize, rank: usize) -> Vec<usize> {
340 let mut coords = vec![0usize; rank];
341 for i in (0..rank).rev() {
342 let dim_size = self.shape().dims[i];
343 coords[i] = idx % dim_size;
344 idx /= dim_size;
345 }
346 coords
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_contiguous_copy() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let copy = source.contiguous();

        // The result must be contiguous and keep the source shape.
        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, source.shape().dims);

        // Element values survive the copy at both corners.
        assert_eq!(copy.get(&[0, 0]), 1.0);
        assert_eq!(copy.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_contiguous_already_contiguous() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        // An already-contiguous input yields an equivalent contiguous tensor.
        let copy = source.contiguous();
        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, source.shape().dims);
        assert_eq!(copy.size(), source.size());
    }

    #[test]
    fn test_contiguous_preserves_gradient_tracking() {
        let mut source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        source.set_requires_grad(true);

        assert!(source.contiguous().requires_grad());
    }

    #[test]
    fn test_contiguous_gradient_flow() {
        // Gradients must flow through contiguous() back to the source tensor.
        let mut input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        input.set_requires_grad(true);

        // transpose produces a non-contiguous view of the same data.
        let view = input.transpose(0, 1);
        assert!(!view.is_contiguous());

        let packed = view.contiguous();
        assert!(packed.is_contiguous());
        assert!(packed.requires_grad());

        // Reduce to a scalar and run backward.
        let mut total = packed.sum();
        total.backward(None);

        let grad = input.grad_by_value().expect("Gradient should exist");
        assert_eq!(grad.shape().dims, vec![2, 2]);

        // d(sum)/dx == 1 for every element.
        for row in 0..2 {
            for col in 0..2 {
                assert_eq!(grad.get(&[row, col]), 1.0);
            }
        }
    }

    #[test]
    fn test_contiguous_1d() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, vec![3]);

        // Every element is preserved in order.
        for (idx, expected) in [1.0f32, 2.0, 3.0].iter().enumerate() {
            assert_eq!(copy.get(&[idx]), *expected);
        }
    }

    #[test]
    fn test_contiguous_3d() {
        let values: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let source = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        let copy = source.contiguous();

        assert!(copy.is_contiguous());
        assert_eq!(copy.shape().dims, vec![2, 3, 4]);
        assert_eq!(copy.size(), 24);
    }
}
443}