train_station/tensor/core/shape.rs
1//! Tensor shape and memory layout management
2//!
3//! This module provides the `Shape` struct and related components for managing
4//! tensor dimensions, memory strides, and layout information. The shape system
5//! enables efficient view operations, broadcasting, and memory access optimization.
6//!
7//! # Architecture
8//!
9//! The shape system consists of:
10//! - **Shape**: Main struct containing dimensions, strides, and layout information
11//! - **MemoryLayout**: Enum describing memory layout types (Contiguous, Strided, View)
12//! - **Stride Calculation**: Efficient computation of memory access patterns
13//! - **Broadcasting**: NumPy-compatible broadcasting rules implementation
14//!
15//! # Key Features
16//!
17//! - **Memory Layout Tracking**: Contiguous, strided, and view layout types
18//! - **Stride Optimization**: Efficient memory access pattern calculation
19//! - **Broadcasting Support**: NumPy-compatible broadcasting rules
20//! - **View Operations**: Zero-copy tensor transformations
21//! - **Performance Hints**: Layout information for operation optimization
22//! - **Memory Safety**: Bounds checking and validation
23//!
24//! # Performance Characteristics
25//!
26//! - **Zero-Cost Layout**: Layout information computed once and cached
27//! - **Efficient Strides**: Row-major stride calculation for optimal memory access
28//! - **Broadcasting**: O(rank) complexity for broadcasting compatibility checks
29//! - **Memory Access**: O(1) offset calculation for multi-dimensional indices
30//! - **View Efficiency**: Zero-copy view creation with minimal overhead
31//!
32//! # Memory Layout Types
33//!
34//! - **Contiguous**: Standard row-major layout with sequential memory access
35//! - **Strided**: Custom stride layout for non-contiguous memory access
36//! - **View**: Non-contiguous reference to existing tensor data
37//!
38//! # Examples
39//!
40//! ## Basic Shape Operations
41//!
42//! ```
43//! use train_station::tensor::Shape;
44//!
45//! // Create contiguous shape
46//! let shape = Shape::new(vec![2, 3, 4]);
47//! assert_eq!(shape.size, 24);
48//! assert!(shape.is_contiguous());
49//!
50//! // Create view shape
51//! let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
52//! assert!(view_shape.is_view());
53//!
54//! // Check broadcasting compatibility
55//! let shape1 = Shape::new(vec![2, 3, 4]);
56//! let shape2 = Shape::new(vec![1, 3, 4]);
57//! assert!(shape1.is_broadcastable_with(&shape2));
58//!
59//! // Calculate memory offset
60//! let offset = shape.offset(&[1, 2, 3]);
61//! assert_eq!(offset, 12 + 8 + 3);
62//! ```
63//!
64//! # Design Principles
65//!
66//! - **Memory Efficiency**: Optimized for cache-friendly access patterns
67//! - **Zero-Cost Abstractions**: Minimal overhead for shape operations
68//! - **NumPy Compatibility**: Broadcasting rules match NumPy behavior
69//! - **Type Safety**: Strong typing for memory layout and dimensions
70//! - **Performance First**: All operations optimized for speed
71
/// Memory layout information for tensors
///
/// Describes how tensor data is arranged in memory for optimized access patterns
/// and view operations. This enum provides performance hints for operation
/// selection and memory access optimization.
///
/// # Variants
///
/// * `Contiguous` - Standard row-major layout with sequential memory access
/// * `Strided` - Custom stride layout for non-contiguous memory access
/// * `View` - Non-contiguous reference to existing tensor data
///
/// # Performance Characteristics
///
/// - **Contiguous**: Optimal for SIMD operations and cache efficiency
/// - **Strided**: Requires custom memory access patterns
/// - **View**: Zero-copy operations with shared memory management
///
/// # Implementation Details
///
/// This enum is used internally by the shape system to track memory layout
/// information for optimization decisions. The layout type determines which
/// operations can be used efficiently on the tensor data.
///
/// `Shape::with_strides` selects `Contiguous` or `Strided` automatically by
/// comparing the supplied strides against the computed row-major strides,
/// while `Shape::as_view` always tags the result as `View`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryLayout {
    /// Contiguous memory layout (standard row-major); the last dimension has
    /// stride 1 and elements are stored sequentially
    Contiguous,
    /// Strided memory layout with custom stride information
    Strided,
    /// Non-contiguous view of another tensor (shares the underlying data)
    View,
}
104
/// Represents the shape/dimensions of a tensor with stride tracking
///
/// This struct holds the dimensional information for a tensor, including the size
/// of each dimension, memory strides, and layout information for efficient view
/// operations and transformations. The shape system enables zero-copy tensor
/// views and optimized memory access patterns.
///
/// # Key Features
///
/// - **Dimension Management**: Multi-dimensional shape representation
/// - **Stride Calculation**: Efficient memory access pattern computation
/// - **Layout Tracking**: Contiguous, strided, and view layout types
/// - **Broadcasting**: NumPy-compatible broadcasting rules
/// - **Memory Safety**: Bounds checking and validation
///
/// # Performance Characteristics
///
/// - **Zero-Cost Layout**: Layout information computed once and cached
/// - **Efficient Strides**: Row-major stride calculation for optimal access
/// - **Memory Access**: O(1) offset calculation for multi-dimensional indices
/// - **View Efficiency**: Zero-copy view creation with minimal overhead
///
/// # Examples
///
/// ```
/// use train_station::tensor::Shape;
///
/// let shape = Shape::new(vec![2, 3, 4]);
/// assert_eq!(shape.size, 24);
/// assert!(shape.is_contiguous());
/// assert_eq!(shape.strides(), &[12, 4, 1]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    /// The dimensions of the tensor (e.g., [2, 3, 4] for a 2x3x4 tensor)
    pub dims: Vec<usize>,
    /// Total number of elements in the tensor.
    /// Invariant: always equals `dims.iter().product()` (so a rank-0 shape has size 1)
    pub size: usize,
    /// Memory strides for each dimension, measured in elements (not bytes).
    /// For a contiguous tensor with shape [2, 3, 4], strides would be [12, 4, 1].
    /// Invariant: `strides.len() == dims.len()` (enforced by the constructors)
    pub strides: Vec<usize>,
    /// Memory layout type for optimization decisions
    pub layout: MemoryLayout,
}
149
150impl Shape {
151 /// Creates a new contiguous shape from a vector of dimensions
152 ///
153 /// Computes the total size and contiguous strides for the given dimensions.
154 /// The resulting shape uses row-major memory layout optimized for cache
155 /// efficiency and SIMD operations.
156 ///
157 /// # Arguments
158 ///
159 /// * `dims` - Vector of dimension sizes defining the tensor shape
160 ///
161 /// # Returns
162 ///
163 /// A new Shape with calculated size, contiguous strides, and contiguous layout
164 ///
165 /// # Performance
166 ///
167 /// - **Time Complexity**: O(rank) for stride calculation
168 /// - **Memory**: Single allocation for dimensions and strides
169 /// - **Optimization**: Row-major layout for cache efficiency
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// use train_station::tensor::Shape;
175 ///
176 /// let shape = Shape::new(vec![2, 3, 4]);
177 /// assert_eq!(shape.size, 24);
178 /// assert!(shape.is_contiguous());
179 /// assert_eq!(shape.strides(), &[12, 4, 1]);
180 /// ```
181 #[inline]
182 #[track_caller]
183 pub fn new(dims: Vec<usize>) -> Self {
184 let size = dims.iter().product();
185 let strides = Self::compute_contiguous_strides(&dims);
186 Self {
187 dims,
188 size,
189 strides,
190 layout: MemoryLayout::Contiguous,
191 }
192 }
193
194 /// Creates a new shape with custom strides
195 ///
196 /// Creates a shape with user-defined strides for non-contiguous memory layouts.
197 /// Automatically detects if the strides represent a contiguous layout and sets
198 /// the appropriate layout type.
199 ///
200 /// # Arguments
201 ///
202 /// * `dims` - Vector of dimension sizes defining the tensor shape
203 /// * `strides` - Vector of memory strides for each dimension
204 ///
205 /// # Returns
206 ///
207 /// A new Shape with the given dimensions and strides, with layout type
208 /// automatically determined
209 ///
210 /// # Panics
211 ///
212 /// Panics if dimensions and strides have different lengths
213 ///
214 /// # Performance
215 ///
216 /// - **Layout Detection**: O(rank) comparison with contiguous strides
217 /// - **Memory**: Single allocation for shape data
218 /// - **Optimization**: Automatic layout type detection
219 ///
220 /// # Examples
221 ///
222 /// ```
223 /// use train_station::tensor::Shape;
224 ///
225 /// let shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
226 /// assert_eq!(shape.size, 6);
227 /// assert!(!shape.is_contiguous());
228 /// assert_eq!(shape.strides(), &[6, 2]);
229 /// ```
230 #[inline]
231 #[track_caller]
232 pub fn with_strides(dims: Vec<usize>, strides: Vec<usize>) -> Self {
233 assert_eq!(
234 dims.len(),
235 strides.len(),
236 "Dimensions and strides must have same length"
237 );
238 let size = dims.iter().product();
239 let layout = if strides == Self::compute_contiguous_strides(&dims) {
240 MemoryLayout::Contiguous
241 } else {
242 MemoryLayout::Strided
243 };
244 Self {
245 dims,
246 size,
247 strides,
248 layout,
249 }
250 }
251
252 /// Creates a view shape (non-contiguous reference to existing tensor)
253 ///
254 /// Creates a shape representing a view of existing tensor data with custom
255 /// dimensions and strides. View shapes enable zero-copy tensor transformations
256 /// by sharing memory with the original tensor.
257 ///
258 /// # Arguments
259 ///
260 /// * `dims` - Vector of dimension sizes for the view
261 /// * `strides` - Vector of memory strides for the view
262 ///
263 /// # Returns
264 ///
265 /// A new Shape marked as a view with the given dimensions and strides
266 ///
267 /// # Panics
268 ///
269 /// Panics if dimensions and strides have different lengths
270 ///
271 /// # Performance
272 ///
273 /// - **Zero-Copy**: No data copying, only metadata creation
274 /// - **Memory Efficient**: Shares memory with original tensor
275 /// - **View Optimization**: Enables view-specific operation optimizations
276 ///
277 /// # Examples
278 ///
279 /// ```
280 /// use train_station::tensor::Shape;
281 ///
282 /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
283 /// assert!(view_shape.is_view());
284 /// assert!(!view_shape.is_contiguous());
285 /// ```
286 #[inline]
287 #[track_caller]
288 pub fn as_view(dims: Vec<usize>, strides: Vec<usize>) -> Self {
289 assert_eq!(
290 dims.len(),
291 strides.len(),
292 "Dimensions and strides must have same length"
293 );
294 let size = dims.iter().product();
295 Self {
296 dims,
297 size,
298 strides,
299 layout: MemoryLayout::View,
300 }
301 }
302
303 /// Computes contiguous strides for given dimensions (row-major order)
304 ///
305 /// Calculates the memory strides for a contiguous row-major layout.
306 /// This is used internally for shape creation and layout detection.
307 ///
308 /// # Arguments
309 ///
310 /// * `dims` - Vector of dimension sizes
311 ///
312 /// # Returns
313 ///
314 /// Vector of strides for contiguous row-major layout
315 ///
316 /// # Implementation Details
317 ///
318 /// This method is used internally by the shape system to compute
319 /// contiguous strides for new shapes and to detect if custom strides
320 /// represent a contiguous layout.
321 fn compute_contiguous_strides(dims: &[usize]) -> Vec<usize> {
322 let mut strides = Vec::with_capacity(dims.len());
323 if dims.is_empty() {
324 return strides;
325 }
326
327 let mut stride = 1;
328 for &dim in dims.iter().rev() {
329 strides.push(stride);
330 stride *= dim;
331 }
332 strides.reverse();
333 strides
334 }
335
336 /// Returns the number of dimensions (rank) of the tensor
337 ///
338 /// # Returns
339 ///
340 /// The number of dimensions in the tensor shape
341 ///
342 /// # Examples
343 ///
344 /// ```
345 /// use train_station::tensor::Shape;
346 ///
347 /// let shape = Shape::new(vec![2, 3, 4]);
348 /// assert_eq!(shape.rank(), 3);
349 /// ```
350 #[inline]
351 #[track_caller]
352 pub fn rank(&self) -> usize {
353 self.dims.len()
354 }
355
356 /// Checks if the tensor has contiguous memory layout
357 ///
358 /// # Returns
359 ///
360 /// `true` if the tensor data is stored contiguously in memory
361 ///
362 /// # Examples
363 ///
364 /// ```
365 /// use train_station::tensor::Shape;
366 ///
367 /// let shape = Shape::new(vec![2, 3, 4]);
368 /// assert!(shape.is_contiguous());
369 /// ```
370 #[inline]
371 #[track_caller]
372 pub fn is_contiguous(&self) -> bool {
373 matches!(self.layout, MemoryLayout::Contiguous)
374 }
375
376 /// Checks if the tensor is a view of another tensor
377 ///
378 /// # Returns
379 ///
380 /// `true` if this tensor is a view (non-contiguous reference)
381 ///
382 /// # Examples
383 ///
384 /// ```
385 /// use train_station::tensor::Shape;
386 ///
387 /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
388 /// assert!(view_shape.is_view());
389 /// ```
390 #[inline]
391 #[track_caller]
392 pub fn is_view(&self) -> bool {
393 matches!(self.layout, MemoryLayout::View)
394 }
395
396 /// Gets the memory stride for a specific dimension
397 ///
398 /// # Arguments
399 ///
400 /// * `dim` - The dimension index
401 ///
402 /// # Returns
403 ///
404 /// The memory stride for the given dimension
405 ///
406 /// # Panics
407 ///
408 /// Panics if `dim` is out of bounds
409 ///
410 /// # Examples
411 ///
412 /// ```
413 /// use train_station::tensor::Shape;
414 ///
415 /// let shape = Shape::new(vec![2, 3, 4]);
416 /// assert_eq!(shape.stride(0), 12);
417 /// assert_eq!(shape.stride(1), 4);
418 /// assert_eq!(shape.stride(2), 1);
419 /// ```
420 #[inline]
421 #[track_caller]
422 pub fn stride(&self, dim: usize) -> usize {
423 self.strides[dim]
424 }
425
426 /// Gets all memory strides
427 ///
428 /// # Returns
429 ///
430 /// Reference to the stride vector
431 ///
432 /// # Examples
433 ///
434 /// ```
435 /// use train_station::tensor::Shape;
436 ///
437 /// let shape = Shape::new(vec![2, 3, 4]);
438 /// assert_eq!(shape.strides(), &[12, 4, 1]);
439 /// ```
440 #[inline]
441 #[track_caller]
442 pub fn strides(&self) -> &[usize] {
443 &self.strides
444 }
445
446 /// Gets the memory layout type
447 ///
448 /// # Returns
449 ///
450 /// Reference to the memory layout
451 ///
452 /// # Implementation Details
453 ///
454 /// This method returns the memory layout type which can be used for
455 /// optimization decisions in tensor operations.
456 #[inline]
457 #[track_caller]
458 pub fn layout(&self) -> &MemoryLayout {
459 &self.layout
460 }
461
462 /// Calculates the linear memory offset for given indices
463 ///
464 /// Computes the linear memory offset for multi-dimensional tensor indices
465 /// using the stored stride information. This enables efficient direct memory
466 /// access for tensor operations.
467 ///
468 /// # Arguments
469 ///
470 /// * `indices` - Vector of indices for each dimension
471 ///
472 /// # Returns
473 ///
474 /// Linear memory offset for the given multi-dimensional indices
475 ///
476 /// # Panics
477 ///
478 /// Panics if indices length doesn't match tensor rank
479 ///
480 /// # Performance
481 ///
482 /// - **Time Complexity**: O(rank) for offset calculation
483 /// - **Memory**: No allocation, uses existing stride data
484 /// - **Optimization**: Efficient dot product of indices and strides
485 ///
486 /// # Examples
487 ///
488 /// ```
489 /// use train_station::tensor::Shape;
490 ///
491 /// let shape = Shape::new(vec![2, 3, 4]);
492 /// let offset = shape.offset(&[1, 2, 3]);
493 /// assert_eq!(offset, 12 + 8 + 3); // 1*12 + 2*4 + 3*1
494 /// ```
495 #[inline]
496 #[track_caller]
497 pub fn offset(&self, indices: &[usize]) -> usize {
498 assert_eq!(indices.len(), self.rank(), "Indices must match tensor rank");
499 indices
500 .iter()
501 .zip(self.strides.iter())
502 .map(|(&idx, &stride)| idx * stride)
503 .sum()
504 }
505
506 /// Checks if this shape is broadcastable with another shape
507 ///
508 /// Determines if two shapes can be broadcast together according to NumPy
509 /// broadcasting rules. Broadcasting enables element-wise operations between
510 /// tensors with different shapes by expanding singleton dimensions.
511 ///
512 /// # Arguments
513 ///
514 /// * `other` - The other shape to check broadcasting compatibility
515 ///
516 /// # Returns
517 ///
518 /// `true` if the shapes are broadcastable according to NumPy broadcasting rules
519 ///
520 /// # Performance
521 ///
522 /// - **Time Complexity**: O(max(rank1, rank2)) for broadcasting check
523 /// - **Memory**: No allocation, uses existing dimension data
524 /// - **Optimization**: Right-aligned dimension comparison
525 ///
526 /// # Broadcasting Rules
527 ///
528 /// - Dimensions are compared from right to left
529 /// - Dimensions are compatible if they are equal, or one is 1
530 /// - Missing dimensions are treated as 1
531 ///
532 /// # Examples
533 ///
534 /// ```
535 /// use train_station::tensor::Shape;
536 ///
537 /// let shape1 = Shape::new(vec![2, 3, 4]);
538 /// let shape2 = Shape::new(vec![1, 3, 4]);
539 /// assert!(shape1.is_broadcastable_with(&shape2));
540 ///
541 /// let shape3 = Shape::new(vec![4]);
542 /// assert!(shape1.is_broadcastable_with(&shape3));
543 /// ```
544 #[track_caller]
545 pub fn is_broadcastable_with(&self, other: &Shape) -> bool {
546 let max_rank = self.rank().max(other.rank());
547
548 for i in 0..max_rank {
549 let self_dim = if i < self.rank() {
550 self.dims[self.rank() - 1 - i]
551 } else {
552 1
553 };
554
555 let other_dim = if i < other.rank() {
556 other.dims[other.rank() - 1 - i]
557 } else {
558 1
559 };
560
561 if self_dim != other_dim && self_dim != 1 && other_dim != 1 {
562 return false;
563 }
564 }
565
566 true
567 }
568}
569
#[cfg(test)]
mod tests {
    //! Unit tests for `Shape` and `MemoryLayout`
    //!
    //! Exercises shape construction, stride computation, layout detection,
    //! view handling, offset arithmetic, and NumPy-style broadcasting rules,
    //! including edge cases such as zero-sized dimensions.

    use super::*;

    /// A 3-D contiguous shape reports the expected size, rank, layout,
    /// and row-major strides.
    #[test]
    fn test_shape_creation() {
        let s = Shape::new(vec![2, 3, 4]);
        assert_eq!(s.strides(), &[12, 4, 1]);
        assert_eq!(s.size, 24);
        assert_eq!(s.rank(), 3);
        assert!(!s.is_view());
        assert!(s.is_contiguous());
    }

    /// A rank-1 shape has a single stride of 1.
    #[test]
    fn test_shape_1d() {
        let s = Shape::new(vec![10]);
        assert!(s.is_contiguous());
        assert_eq!(s.size, 10);
        assert_eq!(s.rank(), 1);
        assert_eq!(s.strides(), &[1]);
    }

    /// A rank-2 shape uses row-major strides `[cols, 1]`.
    #[test]
    fn test_shape_2d() {
        let s = Shape::new(vec![5, 6]);
        assert_eq!(s.strides(), &[6, 1]);
        assert_eq!(s.size, 30);
        assert_eq!(s.rank(), 2);
        assert!(s.is_contiguous());
    }

    /// A zero-length dimension yields size 0 but still has a valid stride.
    #[test]
    fn test_zero_sized_shape() {
        let empty = Shape::new(vec![0]);
        assert_eq!(empty.size, 0);
        assert_eq!(empty.rank(), 1);
        assert_eq!(empty.strides(), &[1]);
    }

    /// Custom, non-row-major strides are preserved and detected as strided.
    #[test]
    fn test_shape_with_strides() {
        let strided = Shape::with_strides(vec![2, 3], vec![6, 2]);
        assert_eq!(strided.stride(0), 6);
        assert_eq!(strided.stride(1), 2);
        assert_eq!(strided.strides(), &[6, 2]);
        assert_eq!(strided.size, 6);
        assert_eq!(strided.rank(), 2);
        assert!(!strided.is_contiguous());
    }

    /// View shapes are flagged as views and never as contiguous.
    #[test]
    fn test_shape_as_view() {
        let view = Shape::as_view(vec![2, 2], vec![4, 1]);
        assert!(view.is_view());
        assert!(!view.is_contiguous());
        assert_eq!(view.size, 4);
        assert_eq!(view.rank(), 2);
        assert_eq!(view.strides(), &[4, 1]);
    }

    /// Row-major stride computation for multi- and single-dimension shapes.
    #[test]
    fn test_stride_calculation() {
        assert_eq!(
            Shape::compute_contiguous_strides(&[2, 3, 4]),
            vec![12, 4, 1]
        );
        assert_eq!(Shape::compute_contiguous_strides(&[5]), vec![1]);
    }

    /// Linear offsets are the dot product of indices and strides.
    ///
    /// Verifies the origin, unit steps along each axis, and a mixed index.
    #[test]
    fn test_offset_calculation() {
        let s = Shape::new(vec![2, 3, 4]);

        assert_eq!(s.offset(&[0, 0, 0]), 0);
        assert_eq!(s.offset(&[0, 0, 1]), 1);
        assert_eq!(s.offset(&[0, 1, 0]), 4);
        assert_eq!(s.offset(&[1, 0, 0]), 12);
        assert_eq!(s.offset(&[1, 2, 3]), 12 + 8 + 3);
    }

    /// NumPy broadcasting: dimensions align right-to-left and are compatible
    /// when equal or when either side is 1; the relation is symmetric.
    #[test]
    fn test_broadcasting_compatibility() {
        let base = Shape::new(vec![2, 3, 4]);
        let lead_one = Shape::new(vec![1, 3, 4]);
        let row = Shape::new(vec![4]);
        let mid_one = Shape::new(vec![2, 1, 4]);
        let mismatch = Shape::new(vec![2, 2, 4]);

        assert!(base.is_broadcastable_with(&lead_one));
        assert!(base.is_broadcastable_with(&row));
        assert!(base.is_broadcastable_with(&mid_one));
        // Middle dimension is 3 vs 2 and neither is 1 -> incompatible.
        assert!(!base.is_broadcastable_with(&mismatch));

        assert!(lead_one.is_broadcastable_with(&base));
        assert!(row.is_broadcastable_with(&base));
        assert!(mid_one.is_broadcastable_with(&base));
    }

    /// `with_strides` recognizes strides that happen to be row-major
    /// contiguous and tags them accordingly.
    #[test]
    fn test_contiguous_strides_detection() {
        let detected = Shape::with_strides(vec![2, 3, 4], vec![12, 4, 1]);
        assert!(detected.is_contiguous());

        let custom = Shape::with_strides(vec![2, 3, 4], vec![12, 4, 2]);
        assert!(!custom.is_contiguous());
    }
}
698}