// train_station/tensor/core/shape.rs

1//! Tensor shape and memory layout management
2//!
3//! This module provides the `Shape` struct and related components for managing
4//! tensor dimensions, memory strides, and layout information. The shape system
5//! enables efficient view operations, broadcasting, and memory access optimization.
6//!
7//! # Architecture
8//!
9//! The shape system consists of:
10//! - **Shape**: Main struct containing dimensions, strides, and layout information
11//! - **MemoryLayout**: Enum describing memory layout types (Contiguous, Strided, View)
12//! - **Stride Calculation**: Efficient computation of memory access patterns
13//! - **Broadcasting**: NumPy-compatible broadcasting rules implementation
14//!
15//! # Key Features
16//!
17//! - **Memory Layout Tracking**: Contiguous, strided, and view layout types
18//! - **Stride Optimization**: Efficient memory access pattern calculation
19//! - **Broadcasting Support**: NumPy-compatible broadcasting rules
20//! - **View Operations**: Zero-copy tensor transformations
21//! - **Performance Hints**: Layout information for operation optimization
22//! - **Memory Safety**: Bounds checking and validation
23//!
24//! # Performance Characteristics
25//!
26//! - **Zero-Cost Layout**: Layout information computed once and cached
27//! - **Efficient Strides**: Row-major stride calculation for optimal memory access
28//! - **Broadcasting**: O(rank) complexity for broadcasting compatibility checks
29//! - **Memory Access**: O(1) offset calculation for multi-dimensional indices
30//! - **View Efficiency**: Zero-copy view creation with minimal overhead
31//!
32//! # Memory Layout Types
33//!
34//! - **Contiguous**: Standard row-major layout with sequential memory access
35//! - **Strided**: Custom stride layout for non-contiguous memory access
36//! - **View**: Non-contiguous reference to existing tensor data
37//!
38//! # Examples
39//!
40//! ## Basic Shape Operations
41//!
42//! ```
43//! use train_station::tensor::Shape;
44//!
45//! // Create contiguous shape
46//! let shape = Shape::new(vec![2, 3, 4]);
47//! assert_eq!(shape.size, 24);
48//! assert!(shape.is_contiguous());
49//!
50//! // Create view shape
51//! let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
52//! assert!(view_shape.is_view());
53//!
54//! // Check broadcasting compatibility
55//! let shape1 = Shape::new(vec![2, 3, 4]);
56//! let shape2 = Shape::new(vec![1, 3, 4]);
57//! assert!(shape1.is_broadcastable_with(&shape2));
58//!
59//! // Calculate memory offset
60//! let offset = shape.offset(&[1, 2, 3]);
61//! assert_eq!(offset, 12 + 8 + 3);
62//! ```
63//!
64//! # Design Principles
65//!
66//! - **Memory Efficiency**: Optimized for cache-friendly access patterns
67//! - **Zero-Cost Abstractions**: Minimal overhead for shape operations
68//! - **NumPy Compatibility**: Broadcasting rules match NumPy behavior
69//! - **Type Safety**: Strong typing for memory layout and dimensions
70//! - **Performance First**: All operations optimized for speed
71
/// Memory layout information for tensors
///
/// Describes how tensor data is arranged in memory for optimized access patterns
/// and view operations. This enum provides performance hints for operation
/// selection and memory access optimization.
///
/// # Variants
///
/// * `Contiguous` - Standard row-major layout with sequential memory access
/// * `Strided` - Custom stride layout for non-contiguous memory access
/// * `View` - Non-contiguous reference to existing tensor data
///
/// # Performance Characteristics
///
/// - **Contiguous**: Optimal for SIMD operations and cache efficiency
/// - **Strided**: Requires custom memory access patterns
/// - **View**: Zero-copy operations with shared memory management
///
/// # Implementation Details
///
/// This enum is used internally by the shape system to track memory layout
/// information for optimization decisions. The layout type determines which
/// operations can be used efficiently on the tensor data.
///
/// Note that `Shape::as_view` assigns `View` unconditionally, even when the
/// supplied strides happen to match a row-major layout; `Shape::with_strides`
/// is the constructor that auto-detects `Contiguous` vs `Strided`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryLayout {
    /// Contiguous memory layout (standard row-major)
    Contiguous,
    /// Strided memory layout with custom stride information
    Strided,
    /// Non-contiguous view of another tensor
    View,
}
104
/// Represents the shape/dimensions of a tensor with stride tracking
///
/// This struct holds the dimensional information for a tensor, including the size
/// of each dimension, memory strides, and layout information for efficient view
/// operations and transformations. The shape system enables zero-copy tensor
/// views and optimized memory access patterns.
///
/// # Invariants
///
/// The constructors in this file (`new`, `with_strides`, `as_view`) maintain:
/// - `dims.len() == strides.len()` (asserted by the stride-taking constructors)
/// - `size == dims.iter().product()` (so an empty `dims` yields `size == 1`)
///
/// The fields are public, so code mutating them directly is responsible for
/// keeping these invariants.
///
/// # Key Features
///
/// - **Dimension Management**: Multi-dimensional shape representation
/// - **Stride Calculation**: Efficient memory access pattern computation
/// - **Layout Tracking**: Contiguous, strided, and view layout types
/// - **Broadcasting**: NumPy-compatible broadcasting rules
/// - **Memory Safety**: Bounds checking and validation
///
/// # Performance Characteristics
///
/// - **Zero-Cost Layout**: Layout information computed once and cached
/// - **Efficient Strides**: Row-major stride calculation for optimal access
/// - **Memory Access**: O(rank) offset calculation for multi-dimensional indices
/// - **View Efficiency**: Zero-copy view creation with minimal overhead
///
/// # Examples
///
/// ```
/// use train_station::tensor::Shape;
///
/// let shape = Shape::new(vec![2, 3, 4]);
/// assert_eq!(shape.size, 24);
/// assert!(shape.is_contiguous());
/// assert_eq!(shape.strides(), &[12, 4, 1]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    /// The dimensions of the tensor (e.g., [2, 3, 4] for a 2x3x4 tensor)
    pub dims: Vec<usize>,
    /// Total number of elements in the tensor
    pub size: usize,
    /// Memory strides for each dimension (elements between consecutive indices)
    /// For a contiguous tensor with shape [2, 3, 4], strides would be [12, 4, 1]
    pub strides: Vec<usize>,
    /// Memory layout type for optimization decisions
    pub layout: MemoryLayout,
}
149
150impl Shape {
151    /// Creates a new contiguous shape from a vector of dimensions
152    ///
153    /// Computes the total size and contiguous strides for the given dimensions.
154    /// The resulting shape uses row-major memory layout optimized for cache
155    /// efficiency and SIMD operations.
156    ///
157    /// # Arguments
158    ///
159    /// * `dims` - Vector of dimension sizes defining the tensor shape
160    ///
161    /// # Returns
162    ///
163    /// A new Shape with calculated size, contiguous strides, and contiguous layout
164    ///
165    /// # Performance
166    ///
167    /// - **Time Complexity**: O(rank) for stride calculation
168    /// - **Memory**: Single allocation for dimensions and strides
169    /// - **Optimization**: Row-major layout for cache efficiency
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use train_station::tensor::Shape;
175    ///
176    /// let shape = Shape::new(vec![2, 3, 4]);
177    /// assert_eq!(shape.size, 24);
178    /// assert!(shape.is_contiguous());
179    /// assert_eq!(shape.strides(), &[12, 4, 1]);
180    /// ```
181    #[inline]
182    #[track_caller]
183    pub fn new(dims: Vec<usize>) -> Self {
184        let size = dims.iter().product();
185        let strides = Self::compute_contiguous_strides(&dims);
186        Self {
187            dims,
188            size,
189            strides,
190            layout: MemoryLayout::Contiguous,
191        }
192    }
193
194    /// Creates a new shape with custom strides
195    ///
196    /// Creates a shape with user-defined strides for non-contiguous memory layouts.
197    /// Automatically detects if the strides represent a contiguous layout and sets
198    /// the appropriate layout type.
199    ///
200    /// # Arguments
201    ///
202    /// * `dims` - Vector of dimension sizes defining the tensor shape
203    /// * `strides` - Vector of memory strides for each dimension
204    ///
205    /// # Returns
206    ///
207    /// A new Shape with the given dimensions and strides, with layout type
208    /// automatically determined
209    ///
210    /// # Panics
211    ///
212    /// Panics if dimensions and strides have different lengths
213    ///
214    /// # Performance
215    ///
216    /// - **Layout Detection**: O(rank) comparison with contiguous strides
217    /// - **Memory**: Single allocation for shape data
218    /// - **Optimization**: Automatic layout type detection
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// use train_station::tensor::Shape;
224    ///
225    /// let shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
226    /// assert_eq!(shape.size, 6);
227    /// assert!(!shape.is_contiguous());
228    /// assert_eq!(shape.strides(), &[6, 2]);
229    /// ```
230    #[inline]
231    #[track_caller]
232    pub fn with_strides(dims: Vec<usize>, strides: Vec<usize>) -> Self {
233        assert_eq!(
234            dims.len(),
235            strides.len(),
236            "Dimensions and strides must have same length"
237        );
238        let size = dims.iter().product();
239        let layout = if strides == Self::compute_contiguous_strides(&dims) {
240            MemoryLayout::Contiguous
241        } else {
242            MemoryLayout::Strided
243        };
244        Self {
245            dims,
246            size,
247            strides,
248            layout,
249        }
250    }
251
252    /// Creates a view shape (non-contiguous reference to existing tensor)
253    ///
254    /// Creates a shape representing a view of existing tensor data with custom
255    /// dimensions and strides. View shapes enable zero-copy tensor transformations
256    /// by sharing memory with the original tensor.
257    ///
258    /// # Arguments
259    ///
260    /// * `dims` - Vector of dimension sizes for the view
261    /// * `strides` - Vector of memory strides for the view
262    ///
263    /// # Returns
264    ///
265    /// A new Shape marked as a view with the given dimensions and strides
266    ///
267    /// # Panics
268    ///
269    /// Panics if dimensions and strides have different lengths
270    ///
271    /// # Performance
272    ///
273    /// - **Zero-Copy**: No data copying, only metadata creation
274    /// - **Memory Efficient**: Shares memory with original tensor
275    /// - **View Optimization**: Enables view-specific operation optimizations
276    ///
277    /// # Examples
278    ///
279    /// ```
280    /// use train_station::tensor::Shape;
281    ///
282    /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
283    /// assert!(view_shape.is_view());
284    /// assert!(!view_shape.is_contiguous());
285    /// ```
286    #[inline]
287    #[track_caller]
288    pub fn as_view(dims: Vec<usize>, strides: Vec<usize>) -> Self {
289        assert_eq!(
290            dims.len(),
291            strides.len(),
292            "Dimensions and strides must have same length"
293        );
294        let size = dims.iter().product();
295        Self {
296            dims,
297            size,
298            strides,
299            layout: MemoryLayout::View,
300        }
301    }
302
303    /// Computes contiguous strides for given dimensions (row-major order)
304    ///
305    /// Calculates the memory strides for a contiguous row-major layout.
306    /// This is used internally for shape creation and layout detection.
307    ///
308    /// # Arguments
309    ///
310    /// * `dims` - Vector of dimension sizes
311    ///
312    /// # Returns
313    ///
314    /// Vector of strides for contiguous row-major layout
315    ///
316    /// # Implementation Details
317    ///
318    /// This method is used internally by the shape system to compute
319    /// contiguous strides for new shapes and to detect if custom strides
320    /// represent a contiguous layout.
321    fn compute_contiguous_strides(dims: &[usize]) -> Vec<usize> {
322        let mut strides = Vec::with_capacity(dims.len());
323        if dims.is_empty() {
324            return strides;
325        }
326
327        let mut stride = 1;
328        for &dim in dims.iter().rev() {
329            strides.push(stride);
330            stride *= dim;
331        }
332        strides.reverse();
333        strides
334    }
335
336    /// Returns the number of dimensions (rank) of the tensor
337    ///
338    /// # Returns
339    ///
340    /// The number of dimensions in the tensor shape
341    ///
342    /// # Examples
343    ///
344    /// ```
345    /// use train_station::tensor::Shape;
346    ///
347    /// let shape = Shape::new(vec![2, 3, 4]);
348    /// assert_eq!(shape.rank(), 3);
349    /// ```
350    #[inline]
351    #[track_caller]
352    pub fn rank(&self) -> usize {
353        self.dims.len()
354    }
355
356    /// Checks if the tensor has contiguous memory layout
357    ///
358    /// # Returns
359    ///
360    /// `true` if the tensor data is stored contiguously in memory
361    ///
362    /// # Examples
363    ///
364    /// ```
365    /// use train_station::tensor::Shape;
366    ///
367    /// let shape = Shape::new(vec![2, 3, 4]);
368    /// assert!(shape.is_contiguous());
369    /// ```
370    #[inline]
371    #[track_caller]
372    pub fn is_contiguous(&self) -> bool {
373        matches!(self.layout, MemoryLayout::Contiguous)
374    }
375
376    /// Checks if the tensor is a view of another tensor
377    ///
378    /// # Returns
379    ///
380    /// `true` if this tensor is a view (non-contiguous reference)
381    ///
382    /// # Examples
383    ///
384    /// ```
385    /// use train_station::tensor::Shape;
386    ///
387    /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
388    /// assert!(view_shape.is_view());
389    /// ```
390    #[inline]
391    #[track_caller]
392    pub fn is_view(&self) -> bool {
393        matches!(self.layout, MemoryLayout::View)
394    }
395
396    /// Gets the memory stride for a specific dimension
397    ///
398    /// # Arguments
399    ///
400    /// * `dim` - The dimension index
401    ///
402    /// # Returns
403    ///
404    /// The memory stride for the given dimension
405    ///
406    /// # Panics
407    ///
408    /// Panics if `dim` is out of bounds
409    ///
410    /// # Examples
411    ///
412    /// ```
413    /// use train_station::tensor::Shape;
414    ///
415    /// let shape = Shape::new(vec![2, 3, 4]);
416    /// assert_eq!(shape.stride(0), 12);
417    /// assert_eq!(shape.stride(1), 4);
418    /// assert_eq!(shape.stride(2), 1);
419    /// ```
420    #[inline]
421    #[track_caller]
422    pub fn stride(&self, dim: usize) -> usize {
423        self.strides[dim]
424    }
425
426    /// Gets all memory strides
427    ///
428    /// # Returns
429    ///
430    /// Reference to the stride vector
431    ///
432    /// # Examples
433    ///
434    /// ```
435    /// use train_station::tensor::Shape;
436    ///
437    /// let shape = Shape::new(vec![2, 3, 4]);
438    /// assert_eq!(shape.strides(), &[12, 4, 1]);
439    /// ```
440    #[inline]
441    #[track_caller]
442    pub fn strides(&self) -> &[usize] {
443        &self.strides
444    }
445
446    /// Gets the memory layout type
447    ///
448    /// # Returns
449    ///
450    /// Reference to the memory layout
451    ///
452    /// # Implementation Details
453    ///
454    /// This method returns the memory layout type which can be used for
455    /// optimization decisions in tensor operations.
456    #[inline]
457    #[track_caller]
458    pub fn layout(&self) -> &MemoryLayout {
459        &self.layout
460    }
461
462    /// Calculates the linear memory offset for given indices
463    ///
464    /// Computes the linear memory offset for multi-dimensional tensor indices
465    /// using the stored stride information. This enables efficient direct memory
466    /// access for tensor operations.
467    ///
468    /// # Arguments
469    ///
470    /// * `indices` - Vector of indices for each dimension
471    ///
472    /// # Returns
473    ///
474    /// Linear memory offset for the given multi-dimensional indices
475    ///
476    /// # Panics
477    ///
478    /// Panics if indices length doesn't match tensor rank
479    ///
480    /// # Performance
481    ///
482    /// - **Time Complexity**: O(rank) for offset calculation
483    /// - **Memory**: No allocation, uses existing stride data
484    /// - **Optimization**: Efficient dot product of indices and strides
485    ///
486    /// # Examples
487    ///
488    /// ```
489    /// use train_station::tensor::Shape;
490    ///
491    /// let shape = Shape::new(vec![2, 3, 4]);
492    /// let offset = shape.offset(&[1, 2, 3]);
493    /// assert_eq!(offset, 12 + 8 + 3); // 1*12 + 2*4 + 3*1
494    /// ```
495    #[inline]
496    #[track_caller]
497    pub fn offset(&self, indices: &[usize]) -> usize {
498        assert_eq!(indices.len(), self.rank(), "Indices must match tensor rank");
499        indices
500            .iter()
501            .zip(self.strides.iter())
502            .map(|(&idx, &stride)| idx * stride)
503            .sum()
504    }
505
506    /// Checks if this shape is broadcastable with another shape
507    ///
508    /// Determines if two shapes can be broadcast together according to NumPy
509    /// broadcasting rules. Broadcasting enables element-wise operations between
510    /// tensors with different shapes by expanding singleton dimensions.
511    ///
512    /// # Arguments
513    ///
514    /// * `other` - The other shape to check broadcasting compatibility
515    ///
516    /// # Returns
517    ///
518    /// `true` if the shapes are broadcastable according to NumPy broadcasting rules
519    ///
520    /// # Performance
521    ///
522    /// - **Time Complexity**: O(max(rank1, rank2)) for broadcasting check
523    /// - **Memory**: No allocation, uses existing dimension data
524    /// - **Optimization**: Right-aligned dimension comparison
525    ///
526    /// # Broadcasting Rules
527    ///
528    /// - Dimensions are compared from right to left
529    /// - Dimensions are compatible if they are equal, or one is 1
530    /// - Missing dimensions are treated as 1
531    ///
532    /// # Examples
533    ///
534    /// ```
535    /// use train_station::tensor::Shape;
536    ///
537    /// let shape1 = Shape::new(vec![2, 3, 4]);
538    /// let shape2 = Shape::new(vec![1, 3, 4]);
539    /// assert!(shape1.is_broadcastable_with(&shape2));
540    ///
541    /// let shape3 = Shape::new(vec![4]);
542    /// assert!(shape1.is_broadcastable_with(&shape3));
543    /// ```
544    #[track_caller]
545    pub fn is_broadcastable_with(&self, other: &Shape) -> bool {
546        let max_rank = self.rank().max(other.rank());
547
548        for i in 0..max_rank {
549            let self_dim = if i < self.rank() {
550                self.dims[self.rank() - 1 - i]
551            } else {
552                1
553            };
554
555            let other_dim = if i < other.rank() {
556                other.dims[other.rank() - 1 - i]
557            } else {
558                1
559            };
560
561            if self_dim != other_dim && self_dim != 1 && other_dim != 1 {
562                return false;
563            }
564        }
565
566        true
567    }
568}
569
#[cfg(test)]
mod tests {
    //! Unit tests for shape construction, layout detection, stride and offset
    //! computation, and NumPy-style broadcasting compatibility, including
    //! edge cases such as zero-sized dimensions.

    use super::*;

    /// A freshly constructed 3-D shape reports correct size, rank, strides,
    /// and a contiguous (non-view) layout.
    #[test]
    fn test_shape_creation() {
        let cube = Shape::new(vec![2, 3, 4]);
        assert_eq!((cube.size, cube.rank()), (24, 3));
        assert!(cube.is_contiguous());
        assert!(!cube.is_view());
        assert_eq!(cube.strides(), &[12, 4, 1]);
    }

    #[test]
    fn test_shape_1d() {
        let line = Shape::new(vec![10]);
        assert_eq!((line.size, line.rank()), (10, 1));
        assert!(line.is_contiguous());
        assert_eq!(line.strides(), &[1]);
    }

    #[test]
    fn test_shape_2d() {
        let grid = Shape::new(vec![5, 6]);
        assert_eq!((grid.size, grid.rank()), (30, 2));
        assert!(grid.is_contiguous());
        assert_eq!(grid.strides(), &[6, 1]);
    }

    #[test]
    fn test_zero_sized_shape() {
        // A zero-length dimension gives zero elements but still stride 1.
        let empty = Shape::new(vec![0]);
        assert_eq!((empty.size, empty.rank()), (0, 1));
        assert_eq!(empty.strides(), &[1]);
    }

    #[test]
    fn test_shape_with_strides() {
        let strided = Shape::with_strides(vec![2, 3], vec![6, 2]);
        assert_eq!((strided.size, strided.rank()), (6, 2));
        assert!(!strided.is_contiguous());
        assert_eq!(strided.strides(), &[6, 2]);
        assert_eq!((strided.stride(0), strided.stride(1)), (6, 2));
    }

    #[test]
    fn test_shape_as_view() {
        let view = Shape::as_view(vec![2, 2], vec![4, 1]);
        assert_eq!((view.size, view.rank()), (4, 2));
        assert!(view.is_view());
        assert!(!view.is_contiguous());
        assert_eq!(view.strides(), &[4, 1]);
    }

    #[test]
    fn test_stride_calculation() {
        assert_eq!(
            Shape::compute_contiguous_strides(&[2, 3, 4]),
            vec![12, 4, 1]
        );
        assert_eq!(Shape::compute_contiguous_strides(&[5]), vec![1]);
    }

    /// Linear offsets are the dot product of indices and strides.
    #[test]
    fn test_offset_calculation() {
        let shape = Shape::new(vec![2, 3, 4]);
        let cases: [([usize; 3], usize); 5] = [
            ([0, 0, 0], 0),
            ([1, 2, 3], 23), // 1*12 + 2*4 + 3*1
            ([1, 0, 0], 12),
            ([0, 1, 0], 4),
            ([0, 0, 1], 1),
        ];
        for &(indices, expected) in cases.iter() {
            assert_eq!(shape.offset(&indices), expected);
        }
    }

    /// Broadcasting follows NumPy rules and is symmetric.
    #[test]
    fn test_broadcasting_compatibility() {
        let base = Shape::new(vec![2, 3, 4]);
        let compatible = [
            Shape::new(vec![1, 3, 4]),
            Shape::new(vec![4]),
            Shape::new(vec![2, 1, 4]),
        ];
        for other in &compatible {
            assert!(base.is_broadcastable_with(other));
            assert!(other.is_broadcastable_with(&base));
        }
        // Middle dimension 3 vs 2: unequal and neither is 1.
        let incompatible = Shape::new(vec![2, 2, 4]);
        assert!(!base.is_broadcastable_with(&incompatible));
    }

    #[test]
    fn test_contiguous_strides_detection() {
        assert!(Shape::with_strides(vec![2, 3, 4], vec![12, 4, 1]).is_contiguous());
        assert!(!Shape::with_strides(vec![2, 3, 4], vec![12, 4, 2]).is_contiguous());
    }
}