// train_station/tensor/core/shape.rs

//! Tensor shape and memory layout management
//!
//! This module provides the `Shape` struct and related components for managing
//! tensor dimensions, memory strides, and layout information. The shape system
//! enables efficient view operations, broadcasting, and memory access optimization.
//!
//! # Architecture
//!
//! The shape system consists of:
//! - **Shape**: Main struct containing dimensions, strides, and layout information
//! - **MemoryLayout**: Enum describing memory layout types (Contiguous, Strided, View)
//! - **Stride Calculation**: Efficient computation of memory access patterns
//! - **Broadcasting**: NumPy-compatible broadcasting rules implementation
//!
//! # Key Features
//!
//! - **Memory Layout Tracking**: Contiguous, strided, and view layout types
//! - **Stride Optimization**: Efficient memory access pattern calculation
//! - **Broadcasting Support**: NumPy-compatible broadcasting rules
//! - **View Operations**: Zero-copy tensor transformations
//! - **Performance Hints**: Layout information for operation optimization
//! - **Memory Safety**: Bounds checking and validation
//!
//! # Performance Characteristics
//!
//! - **Zero-Cost Layout**: Layout information computed once and cached
//! - **Efficient Strides**: Row-major stride calculation for optimal memory access
//! - **Broadcasting**: O(rank) complexity for broadcasting compatibility checks
//! - **Memory Access**: O(rank) offset calculation for multi-dimensional indices
//! - **View Efficiency**: Zero-copy view creation with minimal overhead
//!
//! # Memory Layout Types
//!
//! - **Contiguous**: Standard row-major layout with sequential memory access
//! - **Strided**: Custom stride layout for non-contiguous memory access
//! - **View**: Non-contiguous reference to existing tensor data
//!
//! # Examples
//!
//! ## Basic Shape Operations
//!
//! ```
//! use train_station::tensor::Shape;
//!
//! // Create contiguous shape
//! let shape = Shape::new(vec![2, 3, 4]);
//! assert_eq!(shape.size, 24);
//! assert!(shape.is_contiguous());
//!
//! // Create view shape
//! let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
//! assert!(view_shape.is_view());
//!
//! // Check broadcasting compatibility
//! let shape1 = Shape::new(vec![2, 3, 4]);
//! let shape2 = Shape::new(vec![1, 3, 4]);
//! assert!(shape1.is_broadcastable_with(&shape2));
//!
//! // Calculate memory offset
//! let offset = shape.offset(&[1, 2, 3]);
//! assert_eq!(offset, 12 + 8 + 3);
//! ```
//!
//! # Design Principles
//!
//! - **Memory Efficiency**: Optimized for cache-friendly access patterns
//! - **Zero-Cost Abstractions**: Minimal overhead for shape operations
//! - **NumPy Compatibility**: Broadcasting rules match NumPy behavior
//! - **Type Safety**: Strong typing for memory layout and dimensions
//! - **Performance First**: All operations optimized for speed

/// Memory layout information for tensors
///
/// Describes how tensor data is arranged in memory so that operations can
/// pick optimized access patterns. The layout is set when a `Shape` is
/// constructed and is used purely as a performance/optimization hint.
///
/// # Variants
///
/// * `Contiguous` - Standard row-major layout with sequential memory access
/// * `Strided` - Custom stride layout for non-contiguous memory access
/// * `View` - Non-contiguous reference to existing tensor data
///
/// # Performance Characteristics
///
/// - **Contiguous**: Optimal for SIMD operations and cache efficiency
/// - **Strided**: Requires custom memory access patterns
/// - **View**: Zero-copy operations with shared memory management
///
/// # Implementation Details
///
/// `Shape::new` always produces `Contiguous`; `Shape::with_strides` detects
/// `Contiguous` vs `Strided` from the supplied strides; `Shape::as_view`
/// always produces `View`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryLayout {
    /// Contiguous memory layout (standard row-major)
    Contiguous,
    /// Strided memory layout with custom stride information
    Strided,
    /// Non-contiguous view of another tensor
    View,
}

/// Represents the shape/dimensions of a tensor with stride tracking
///
/// This struct holds the dimensional information for a tensor, including the
/// size of each dimension, memory strides, and layout information for
/// efficient view operations and transformations. The shape system enables
/// zero-copy tensor views and optimized memory access patterns.
///
/// # Key Features
///
/// - **Dimension Management**: Multi-dimensional shape representation
/// - **Stride Calculation**: Efficient memory access pattern computation
/// - **Layout Tracking**: Contiguous, strided, and view layout types
/// - **Broadcasting**: NumPy-compatible broadcasting rules
///
/// # Invariants
///
/// - `dims` and `strides` always have the same length (enforced by the
///   constructors)
/// - `size` is the product of `dims` (1 for a rank-0 / scalar shape)
///
/// # Examples
///
/// ```
/// use train_station::tensor::Shape;
///
/// let shape = Shape::new(vec![2, 3, 4]);
/// assert_eq!(shape.size, 24);
/// assert!(shape.is_contiguous());
/// assert_eq!(shape.strides(), &[12, 4, 1]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    /// The dimensions of the tensor (e.g., [2, 3, 4] for a 2x3x4 tensor)
    pub dims: Vec<usize>,
    /// Total number of elements in the tensor (product of `dims`)
    pub size: usize,
    /// Memory strides for each dimension (elements between consecutive indices)
    /// For a contiguous tensor with shape [2, 3, 4], strides would be [12, 4, 1]
    pub strides: Vec<usize>,
    /// Memory layout type for optimization decisions
    pub layout: MemoryLayout,
}

150impl Shape {
151    /// Creates a new contiguous shape from a vector of dimensions
152    ///
153    /// Computes the total size and contiguous strides for the given dimensions.
154    /// The resulting shape uses row-major memory layout optimized for cache
155    /// efficiency and SIMD operations.
156    ///
157    /// # Arguments
158    ///
159    /// * `dims` - Vector of dimension sizes defining the tensor shape
160    ///
161    /// # Returns
162    ///
163    /// A new Shape with calculated size, contiguous strides, and contiguous layout
164    ///
165    /// # Performance
166    ///
167    /// - **Time Complexity**: O(rank) for stride calculation
168    /// - **Memory**: Single allocation for dimensions and strides
169    /// - **Optimization**: Row-major layout for cache efficiency
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use train_station::tensor::Shape;
175    ///
176    /// let shape = Shape::new(vec![2, 3, 4]);
177    /// assert_eq!(shape.size, 24);
178    /// assert!(shape.is_contiguous());
179    /// assert_eq!(shape.strides(), &[12, 4, 1]);
180    /// ```
181    #[inline]
182    pub fn new(dims: Vec<usize>) -> Self {
183        let size = dims.iter().product();
184        let strides = Self::compute_contiguous_strides(&dims);
185        Self {
186            dims,
187            size,
188            strides,
189            layout: MemoryLayout::Contiguous,
190        }
191    }
192
193    /// Creates a new shape with custom strides
194    ///
195    /// Creates a shape with user-defined strides for non-contiguous memory layouts.
196    /// Automatically detects if the strides represent a contiguous layout and sets
197    /// the appropriate layout type.
198    ///
199    /// # Arguments
200    ///
201    /// * `dims` - Vector of dimension sizes defining the tensor shape
202    /// * `strides` - Vector of memory strides for each dimension
203    ///
204    /// # Returns
205    ///
206    /// A new Shape with the given dimensions and strides, with layout type
207    /// automatically determined
208    ///
209    /// # Panics
210    ///
211    /// Panics if dimensions and strides have different lengths
212    ///
213    /// # Performance
214    ///
215    /// - **Layout Detection**: O(rank) comparison with contiguous strides
216    /// - **Memory**: Single allocation for shape data
217    /// - **Optimization**: Automatic layout type detection
218    ///
219    /// # Examples
220    ///
221    /// ```
222    /// use train_station::tensor::Shape;
223    ///
224    /// let shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
225    /// assert_eq!(shape.size, 6);
226    /// assert!(!shape.is_contiguous());
227    /// assert_eq!(shape.strides(), &[6, 2]);
228    /// ```
229    #[inline]
230    pub fn with_strides(dims: Vec<usize>, strides: Vec<usize>) -> Self {
231        assert_eq!(
232            dims.len(),
233            strides.len(),
234            "Dimensions and strides must have same length"
235        );
236        let size = dims.iter().product();
237        let layout = if strides == Self::compute_contiguous_strides(&dims) {
238            MemoryLayout::Contiguous
239        } else {
240            MemoryLayout::Strided
241        };
242        Self {
243            dims,
244            size,
245            strides,
246            layout,
247        }
248    }
249
250    /// Creates a view shape (non-contiguous reference to existing tensor)
251    ///
252    /// Creates a shape representing a view of existing tensor data with custom
253    /// dimensions and strides. View shapes enable zero-copy tensor transformations
254    /// by sharing memory with the original tensor.
255    ///
256    /// # Arguments
257    ///
258    /// * `dims` - Vector of dimension sizes for the view
259    /// * `strides` - Vector of memory strides for the view
260    ///
261    /// # Returns
262    ///
263    /// A new Shape marked as a view with the given dimensions and strides
264    ///
265    /// # Panics
266    ///
267    /// Panics if dimensions and strides have different lengths
268    ///
269    /// # Performance
270    ///
271    /// - **Zero-Copy**: No data copying, only metadata creation
272    /// - **Memory Efficient**: Shares memory with original tensor
273    /// - **View Optimization**: Enables view-specific operation optimizations
274    ///
275    /// # Examples
276    ///
277    /// ```
278    /// use train_station::tensor::Shape;
279    ///
280    /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
281    /// assert!(view_shape.is_view());
282    /// assert!(!view_shape.is_contiguous());
283    /// ```
284    #[inline]
285    pub fn as_view(dims: Vec<usize>, strides: Vec<usize>) -> Self {
286        assert_eq!(
287            dims.len(),
288            strides.len(),
289            "Dimensions and strides must have same length"
290        );
291        let size = dims.iter().product();
292        Self {
293            dims,
294            size,
295            strides,
296            layout: MemoryLayout::View,
297        }
298    }
299
300    /// Computes contiguous strides for given dimensions (row-major order)
301    ///
302    /// Calculates the memory strides for a contiguous row-major layout.
303    /// This is used internally for shape creation and layout detection.
304    ///
305    /// # Arguments
306    ///
307    /// * `dims` - Vector of dimension sizes
308    ///
309    /// # Returns
310    ///
311    /// Vector of strides for contiguous row-major layout
312    ///
313    /// # Implementation Details
314    ///
315    /// This method is used internally by the shape system to compute
316    /// contiguous strides for new shapes and to detect if custom strides
317    /// represent a contiguous layout.
318    fn compute_contiguous_strides(dims: &[usize]) -> Vec<usize> {
319        let mut strides = Vec::with_capacity(dims.len());
320        if dims.is_empty() {
321            return strides;
322        }
323
324        let mut stride = 1;
325        for &dim in dims.iter().rev() {
326            strides.push(stride);
327            stride *= dim;
328        }
329        strides.reverse();
330        strides
331    }
332
333    /// Returns the number of dimensions (rank) of the tensor
334    ///
335    /// # Returns
336    ///
337    /// The number of dimensions in the tensor shape
338    ///
339    /// # Examples
340    ///
341    /// ```
342    /// use train_station::tensor::Shape;
343    ///
344    /// let shape = Shape::new(vec![2, 3, 4]);
345    /// assert_eq!(shape.rank(), 3);
346    /// ```
347    #[inline]
348    pub fn rank(&self) -> usize {
349        self.dims.len()
350    }
351
352    /// Checks if the tensor has contiguous memory layout
353    ///
354    /// # Returns
355    ///
356    /// `true` if the tensor data is stored contiguously in memory
357    ///
358    /// # Examples
359    ///
360    /// ```
361    /// use train_station::tensor::Shape;
362    ///
363    /// let shape = Shape::new(vec![2, 3, 4]);
364    /// assert!(shape.is_contiguous());
365    /// ```
366    #[inline]
367    pub fn is_contiguous(&self) -> bool {
368        matches!(self.layout, MemoryLayout::Contiguous)
369    }
370
371    /// Checks if the tensor is a view of another tensor
372    ///
373    /// # Returns
374    ///
375    /// `true` if this tensor is a view (non-contiguous reference)
376    ///
377    /// # Examples
378    ///
379    /// ```
380    /// use train_station::tensor::Shape;
381    ///
382    /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
383    /// assert!(view_shape.is_view());
384    /// ```
385    #[inline]
386    pub fn is_view(&self) -> bool {
387        matches!(self.layout, MemoryLayout::View)
388    }
389
390    /// Gets the memory stride for a specific dimension
391    ///
392    /// # Arguments
393    ///
394    /// * `dim` - The dimension index
395    ///
396    /// # Returns
397    ///
398    /// The memory stride for the given dimension
399    ///
400    /// # Panics
401    ///
402    /// Panics if `dim` is out of bounds
403    ///
404    /// # Examples
405    ///
406    /// ```
407    /// use train_station::tensor::Shape;
408    ///
409    /// let shape = Shape::new(vec![2, 3, 4]);
410    /// assert_eq!(shape.stride(0), 12);
411    /// assert_eq!(shape.stride(1), 4);
412    /// assert_eq!(shape.stride(2), 1);
413    /// ```
414    #[inline]
415    pub fn stride(&self, dim: usize) -> usize {
416        self.strides[dim]
417    }
418
419    /// Gets all memory strides
420    ///
421    /// # Returns
422    ///
423    /// Reference to the stride vector
424    ///
425    /// # Examples
426    ///
427    /// ```
428    /// use train_station::tensor::Shape;
429    ///
430    /// let shape = Shape::new(vec![2, 3, 4]);
431    /// assert_eq!(shape.strides(), &[12, 4, 1]);
432    /// ```
433    #[inline]
434    pub fn strides(&self) -> &[usize] {
435        &self.strides
436    }
437
438    /// Gets the memory layout type
439    ///
440    /// # Returns
441    ///
442    /// Reference to the memory layout
443    ///
444    /// # Implementation Details
445    ///
446    /// This method returns the memory layout type which can be used for
447    /// optimization decisions in tensor operations.
448    #[inline]
449    pub fn layout(&self) -> &MemoryLayout {
450        &self.layout
451    }
452
453    /// Calculates the linear memory offset for given indices
454    ///
455    /// Computes the linear memory offset for multi-dimensional tensor indices
456    /// using the stored stride information. This enables efficient direct memory
457    /// access for tensor operations.
458    ///
459    /// # Arguments
460    ///
461    /// * `indices` - Vector of indices for each dimension
462    ///
463    /// # Returns
464    ///
465    /// Linear memory offset for the given multi-dimensional indices
466    ///
467    /// # Panics
468    ///
469    /// Panics if indices length doesn't match tensor rank
470    ///
471    /// # Performance
472    ///
473    /// - **Time Complexity**: O(rank) for offset calculation
474    /// - **Memory**: No allocation, uses existing stride data
475    /// - **Optimization**: Efficient dot product of indices and strides
476    ///
477    /// # Examples
478    ///
479    /// ```
480    /// use train_station::tensor::Shape;
481    ///
482    /// let shape = Shape::new(vec![2, 3, 4]);
483    /// let offset = shape.offset(&[1, 2, 3]);
484    /// assert_eq!(offset, 12 + 8 + 3); // 1*12 + 2*4 + 3*1
485    /// ```
486    #[inline]
487    pub fn offset(&self, indices: &[usize]) -> usize {
488        assert_eq!(indices.len(), self.rank(), "Indices must match tensor rank");
489        indices
490            .iter()
491            .zip(self.strides.iter())
492            .map(|(&idx, &stride)| idx * stride)
493            .sum()
494    }
495
496    /// Checks if this shape is broadcastable with another shape
497    ///
498    /// Determines if two shapes can be broadcast together according to NumPy
499    /// broadcasting rules. Broadcasting enables element-wise operations between
500    /// tensors with different shapes by expanding singleton dimensions.
501    ///
502    /// # Arguments
503    ///
504    /// * `other` - The other shape to check broadcasting compatibility
505    ///
506    /// # Returns
507    ///
508    /// `true` if the shapes are broadcastable according to NumPy broadcasting rules
509    ///
510    /// # Performance
511    ///
512    /// - **Time Complexity**: O(max(rank1, rank2)) for broadcasting check
513    /// - **Memory**: No allocation, uses existing dimension data
514    /// - **Optimization**: Right-aligned dimension comparison
515    ///
516    /// # Broadcasting Rules
517    ///
518    /// - Dimensions are compared from right to left
519    /// - Dimensions are compatible if they are equal, or one is 1
520    /// - Missing dimensions are treated as 1
521    ///
522    /// # Examples
523    ///
524    /// ```
525    /// use train_station::tensor::Shape;
526    ///
527    /// let shape1 = Shape::new(vec![2, 3, 4]);
528    /// let shape2 = Shape::new(vec![1, 3, 4]);
529    /// assert!(shape1.is_broadcastable_with(&shape2));
530    ///
531    /// let shape3 = Shape::new(vec![4]);
532    /// assert!(shape1.is_broadcastable_with(&shape3));
533    /// ```
534    pub fn is_broadcastable_with(&self, other: &Shape) -> bool {
535        let max_rank = self.rank().max(other.rank());
536
537        for i in 0..max_rank {
538            let self_dim = if i < self.rank() {
539                self.dims[self.rank() - 1 - i]
540            } else {
541                1
542            };
543
544            let other_dim = if i < other.rank() {
545                other.dims[other.rank() - 1 - i]
546            } else {
547                1
548            };
549
550            if self_dim != other_dim && self_dim != 1 && other_dim != 1 {
551                return false;
552            }
553        }
554
555        true
556    }
557}
558
#[cfg(test)]
mod tests {
    //! Shape and memory layout tests
    //!
    //! Covers shape construction, layout detection, stride and offset math,
    //! and NumPy-style broadcasting compatibility, including edge cases such
    //! as zero-sized dimensions.

    use super::*;

    /// Verifies dimensions, size, rank, layout, and strides of a fresh shape.
    #[test]
    fn test_shape_creation() {
        let s = Shape::new(vec![2, 3, 4]);
        assert_eq!(s.size, 24);
        assert_eq!(s.rank(), 3);
        assert!(s.is_contiguous() && !s.is_view());
        assert_eq!(s.strides(), &[12, 4, 1]);
    }

    #[test]
    fn test_shape_1d() {
        let s = Shape::new(vec![10]);
        assert_eq!((s.size, s.rank()), (10, 1));
        assert!(s.is_contiguous());
        assert_eq!(s.strides(), &[1]);
    }

    #[test]
    fn test_shape_2d() {
        let s = Shape::new(vec![5, 6]);
        assert_eq!((s.size, s.rank()), (30, 2));
        assert!(s.is_contiguous());
        assert_eq!(s.strides(), &[6, 1]);
    }

    #[test]
    fn test_zero_sized_shape() {
        let s = Shape::new(vec![0]);
        assert_eq!((s.size, s.rank()), (0, 1));
        assert_eq!(s.strides(), &[1]);
    }

    #[test]
    fn test_shape_with_strides() {
        let s = Shape::with_strides(vec![2, 3], vec![6, 2]);
        assert_eq!((s.size, s.rank()), (6, 2));
        assert!(!s.is_contiguous());
        assert_eq!(s.strides(), &[6, 2]);
        assert_eq!((s.stride(0), s.stride(1)), (6, 2));
    }

    #[test]
    fn test_shape_as_view() {
        let s = Shape::as_view(vec![2, 2], vec![4, 1]);
        assert_eq!((s.size, s.rank()), (4, 2));
        assert!(s.is_view() && !s.is_contiguous());
        assert_eq!(s.strides(), &[4, 1]);
    }

    #[test]
    fn test_stride_calculation() {
        assert_eq!(
            Shape::compute_contiguous_strides(&[2, 3, 4]),
            vec![12, 4, 1]
        );
        assert_eq!(Shape::compute_contiguous_strides(&[5]), vec![1]);
    }

    /// Verifies linear offsets for several multi-dimensional index tuples.
    #[test]
    fn test_offset_calculation() {
        let s = Shape::new(vec![2, 3, 4]);

        let cases = [
            ([0, 0, 0], 0),
            ([1, 2, 3], 12 + 8 + 3),
            ([1, 0, 0], 12),
            ([0, 1, 0], 4),
            ([0, 0, 1], 1),
        ];
        for (indices, expected) in cases {
            assert_eq!(s.offset(&indices), expected);
        }
    }

    /// Verifies NumPy broadcasting rules for equal ranks, mixed ranks,
    /// singleton dimensions, and an incompatible pair.
    #[test]
    fn test_broadcasting_compatibility() {
        let base = Shape::new(vec![2, 3, 4]);
        let singleton = Shape::new(vec![1, 3, 4]);
        let trailing = Shape::new(vec![4]);
        let mid_one = Shape::new(vec![2, 1, 4]);
        let mismatch = Shape::new(vec![2, 2, 4]);

        assert!(base.is_broadcastable_with(&singleton));
        assert!(base.is_broadcastable_with(&trailing));
        assert!(base.is_broadcastable_with(&mid_one));
        assert!(!base.is_broadcastable_with(&mismatch)); // 3 != 2 and neither is 1

        // Broadcasting compatibility is symmetric.
        assert!(singleton.is_broadcastable_with(&base));
        assert!(trailing.is_broadcastable_with(&base));
        assert!(mid_one.is_broadcastable_with(&base));
    }

    #[test]
    fn test_contiguous_strides_detection() {
        assert!(Shape::with_strides(vec![2, 3, 4], vec![12, 4, 1]).is_contiguous());
        assert!(!Shape::with_strides(vec![2, 3, 4], vec![12, 4, 2]).is_contiguous());
    }
}