train_station/tensor/core/shape.rs
1//! Tensor shape and memory layout management
2//!
3//! This module provides the `Shape` struct and related components for managing
4//! tensor dimensions, memory strides, and layout information. The shape system
5//! enables efficient view operations, broadcasting, and memory access optimization.
6//!
7//! # Architecture
8//!
9//! The shape system consists of:
10//! - **Shape**: Main struct containing dimensions, strides, and layout information
11//! - **MemoryLayout**: Enum describing memory layout types (Contiguous, Strided, View)
12//! - **Stride Calculation**: Efficient computation of memory access patterns
13//! - **Broadcasting**: NumPy-compatible broadcasting rules implementation
14//!
15//! # Key Features
16//!
17//! - **Memory Layout Tracking**: Contiguous, strided, and view layout types
18//! - **Stride Optimization**: Efficient memory access pattern calculation
19//! - **Broadcasting Support**: NumPy-compatible broadcasting rules
20//! - **View Operations**: Zero-copy tensor transformations
21//! - **Performance Hints**: Layout information for operation optimization
22//! - **Memory Safety**: Bounds checking and validation
23//!
24//! # Performance Characteristics
25//!
26//! - **Zero-Cost Layout**: Layout information computed once and cached
27//! - **Efficient Strides**: Row-major stride calculation for optimal memory access
28//! - **Broadcasting**: O(rank) complexity for broadcasting compatibility checks
29//! - **Memory Access**: O(1) offset calculation for multi-dimensional indices
30//! - **View Efficiency**: Zero-copy view creation with minimal overhead
31//!
32//! # Memory Layout Types
33//!
34//! - **Contiguous**: Standard row-major layout with sequential memory access
35//! - **Strided**: Custom stride layout for non-contiguous memory access
36//! - **View**: Non-contiguous reference to existing tensor data
37//!
38//! # Examples
39//!
40//! ## Basic Shape Operations
41//!
42//! ```
43//! use train_station::tensor::Shape;
44//!
45//! // Create contiguous shape
46//! let shape = Shape::new(vec![2, 3, 4]);
47//! assert_eq!(shape.size, 24);
48//! assert!(shape.is_contiguous());
49//!
50//! // Create view shape
51//! let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
52//! assert!(view_shape.is_view());
53//!
54//! // Check broadcasting compatibility
55//! let shape1 = Shape::new(vec![2, 3, 4]);
56//! let shape2 = Shape::new(vec![1, 3, 4]);
57//! assert!(shape1.is_broadcastable_with(&shape2));
58//!
59//! // Calculate memory offset
60//! let offset = shape.offset(&[1, 2, 3]);
61//! assert_eq!(offset, 12 + 8 + 3);
62//! ```
63//!
64//! # Design Principles
65//!
66//! - **Memory Efficiency**: Optimized for cache-friendly access patterns
67//! - **Zero-Cost Abstractions**: Minimal overhead for shape operations
68//! - **NumPy Compatibility**: Broadcasting rules match NumPy behavior
69//! - **Type Safety**: Strong typing for memory layout and dimensions
70//! - **Performance First**: All operations optimized for speed
71
/// Memory layout information for tensors
///
/// Describes how tensor data is arranged in memory for optimized access patterns
/// and view operations. This enum provides performance hints for operation
/// selection and memory access optimization.
///
/// # Variants
///
/// * `Contiguous` - Standard row-major layout with sequential memory access
/// * `Strided` - Custom stride layout for non-contiguous memory access
/// * `View` - Non-contiguous reference to existing tensor data
///
/// # Performance Characteristics
///
/// - **Contiguous**: Optimal for SIMD operations and cache efficiency
/// - **Strided**: Requires custom memory access patterns
/// - **View**: Zero-copy operations with shared memory management
///
/// # Implementation Details
///
/// This enum is used internally by the shape system to track memory layout
/// information for optimization decisions. The layout type determines which
/// operations can be used efficiently on the tensor data.
///
/// Because this is a fieldless (C-like) enum, it derives `Copy` in addition
/// to `Clone`: copying is a trivial byte move, so callers can pass layout
/// values around without explicit clones. `Hash` is derived alongside
/// `PartialEq`/`Eq` so layouts can be used as keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Contiguous memory layout (standard row-major)
    Contiguous,
    /// Strided memory layout with custom stride information
    Strided,
    /// Non-contiguous view of another tensor
    View,
}
104
/// Represents the shape/dimensions of a tensor with stride tracking
///
/// This struct holds the dimensional information for a tensor, including the size
/// of each dimension, memory strides, and layout information for efficient view
/// operations and transformations. The shape system enables zero-copy tensor
/// views and optimized memory access patterns.
///
/// # Invariants
///
/// All constructors (`new`, `with_strides`, `as_view`) maintain:
/// - `dims.len() == strides.len()`
/// - `size == dims.iter().product()` (an empty `dims` yields `size == 1`)
///
/// The fields are public, so code mutating them directly is responsible for
/// preserving these invariants.
///
/// # Key Features
///
/// - **Dimension Management**: Multi-dimensional shape representation
/// - **Stride Calculation**: Efficient memory access pattern computation
/// - **Layout Tracking**: Contiguous, strided, and view layout types
/// - **Broadcasting**: NumPy-compatible broadcasting rules
/// - **Memory Safety**: Bounds checking and validation
///
/// # Performance Characteristics
///
/// - **Zero-Cost Layout**: Layout information computed once and cached
/// - **Efficient Strides**: Row-major stride calculation for optimal access
/// - **Memory Access**: O(1) offset calculation for multi-dimensional indices
/// - **View Efficiency**: Zero-copy view creation with minimal overhead
///
/// # Examples
///
/// ```
/// use train_station::tensor::Shape;
///
/// let shape = Shape::new(vec![2, 3, 4]);
/// assert_eq!(shape.size, 24);
/// assert!(shape.is_contiguous());
/// assert_eq!(shape.strides(), &[12, 4, 1]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    /// The dimensions of the tensor (e.g., [2, 3, 4] for a 2x3x4 tensor)
    pub dims: Vec<usize>,
    /// Total number of elements in the tensor (product of all dimensions)
    pub size: usize,
    /// Memory strides for each dimension (elements between consecutive indices)
    /// For a contiguous tensor with shape [2, 3, 4], strides would be [12, 4, 1]
    pub strides: Vec<usize>,
    /// Memory layout type for optimization decisions
    pub layout: MemoryLayout,
}
149
150impl Shape {
151 /// Creates a new contiguous shape from a vector of dimensions
152 ///
153 /// Computes the total size and contiguous strides for the given dimensions.
154 /// The resulting shape uses row-major memory layout optimized for cache
155 /// efficiency and SIMD operations.
156 ///
157 /// # Arguments
158 ///
159 /// * `dims` - Vector of dimension sizes defining the tensor shape
160 ///
161 /// # Returns
162 ///
163 /// A new Shape with calculated size, contiguous strides, and contiguous layout
164 ///
165 /// # Performance
166 ///
167 /// - **Time Complexity**: O(rank) for stride calculation
168 /// - **Memory**: Single allocation for dimensions and strides
169 /// - **Optimization**: Row-major layout for cache efficiency
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// use train_station::tensor::Shape;
175 ///
176 /// let shape = Shape::new(vec![2, 3, 4]);
177 /// assert_eq!(shape.size, 24);
178 /// assert!(shape.is_contiguous());
179 /// assert_eq!(shape.strides(), &[12, 4, 1]);
180 /// ```
181 #[inline]
182 pub fn new(dims: Vec<usize>) -> Self {
183 let size = dims.iter().product();
184 let strides = Self::compute_contiguous_strides(&dims);
185 Self {
186 dims,
187 size,
188 strides,
189 layout: MemoryLayout::Contiguous,
190 }
191 }
192
193 /// Creates a new shape with custom strides
194 ///
195 /// Creates a shape with user-defined strides for non-contiguous memory layouts.
196 /// Automatically detects if the strides represent a contiguous layout and sets
197 /// the appropriate layout type.
198 ///
199 /// # Arguments
200 ///
201 /// * `dims` - Vector of dimension sizes defining the tensor shape
202 /// * `strides` - Vector of memory strides for each dimension
203 ///
204 /// # Returns
205 ///
206 /// A new Shape with the given dimensions and strides, with layout type
207 /// automatically determined
208 ///
209 /// # Panics
210 ///
211 /// Panics if dimensions and strides have different lengths
212 ///
213 /// # Performance
214 ///
215 /// - **Layout Detection**: O(rank) comparison with contiguous strides
216 /// - **Memory**: Single allocation for shape data
217 /// - **Optimization**: Automatic layout type detection
218 ///
219 /// # Examples
220 ///
221 /// ```
222 /// use train_station::tensor::Shape;
223 ///
224 /// let shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
225 /// assert_eq!(shape.size, 6);
226 /// assert!(!shape.is_contiguous());
227 /// assert_eq!(shape.strides(), &[6, 2]);
228 /// ```
229 #[inline]
230 pub fn with_strides(dims: Vec<usize>, strides: Vec<usize>) -> Self {
231 assert_eq!(
232 dims.len(),
233 strides.len(),
234 "Dimensions and strides must have same length"
235 );
236 let size = dims.iter().product();
237 let layout = if strides == Self::compute_contiguous_strides(&dims) {
238 MemoryLayout::Contiguous
239 } else {
240 MemoryLayout::Strided
241 };
242 Self {
243 dims,
244 size,
245 strides,
246 layout,
247 }
248 }
249
250 /// Creates a view shape (non-contiguous reference to existing tensor)
251 ///
252 /// Creates a shape representing a view of existing tensor data with custom
253 /// dimensions and strides. View shapes enable zero-copy tensor transformations
254 /// by sharing memory with the original tensor.
255 ///
256 /// # Arguments
257 ///
258 /// * `dims` - Vector of dimension sizes for the view
259 /// * `strides` - Vector of memory strides for the view
260 ///
261 /// # Returns
262 ///
263 /// A new Shape marked as a view with the given dimensions and strides
264 ///
265 /// # Panics
266 ///
267 /// Panics if dimensions and strides have different lengths
268 ///
269 /// # Performance
270 ///
271 /// - **Zero-Copy**: No data copying, only metadata creation
272 /// - **Memory Efficient**: Shares memory with original tensor
273 /// - **View Optimization**: Enables view-specific operation optimizations
274 ///
275 /// # Examples
276 ///
277 /// ```
278 /// use train_station::tensor::Shape;
279 ///
280 /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
281 /// assert!(view_shape.is_view());
282 /// assert!(!view_shape.is_contiguous());
283 /// ```
284 #[inline]
285 pub fn as_view(dims: Vec<usize>, strides: Vec<usize>) -> Self {
286 assert_eq!(
287 dims.len(),
288 strides.len(),
289 "Dimensions and strides must have same length"
290 );
291 let size = dims.iter().product();
292 Self {
293 dims,
294 size,
295 strides,
296 layout: MemoryLayout::View,
297 }
298 }
299
300 /// Computes contiguous strides for given dimensions (row-major order)
301 ///
302 /// Calculates the memory strides for a contiguous row-major layout.
303 /// This is used internally for shape creation and layout detection.
304 ///
305 /// # Arguments
306 ///
307 /// * `dims` - Vector of dimension sizes
308 ///
309 /// # Returns
310 ///
311 /// Vector of strides for contiguous row-major layout
312 ///
313 /// # Implementation Details
314 ///
315 /// This method is used internally by the shape system to compute
316 /// contiguous strides for new shapes and to detect if custom strides
317 /// represent a contiguous layout.
318 fn compute_contiguous_strides(dims: &[usize]) -> Vec<usize> {
319 let mut strides = Vec::with_capacity(dims.len());
320 if dims.is_empty() {
321 return strides;
322 }
323
324 let mut stride = 1;
325 for &dim in dims.iter().rev() {
326 strides.push(stride);
327 stride *= dim;
328 }
329 strides.reverse();
330 strides
331 }
332
    /// Returns the number of dimensions (rank) of the tensor
    ///
    /// This is simply the length of the dimension vector; a shape built from
    /// an empty `dims` vector therefore has rank 0.
    ///
    /// # Returns
    ///
    /// The number of dimensions in the tensor shape
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::tensor::Shape;
    ///
    /// let shape = Shape::new(vec![2, 3, 4]);
    /// assert_eq!(shape.rank(), 3);
    /// ```
    #[inline]
    pub fn rank(&self) -> usize {
        self.dims.len()
    }
351
352 /// Checks if the tensor has contiguous memory layout
353 ///
354 /// # Returns
355 ///
356 /// `true` if the tensor data is stored contiguously in memory
357 ///
358 /// # Examples
359 ///
360 /// ```
361 /// use train_station::tensor::Shape;
362 ///
363 /// let shape = Shape::new(vec![2, 3, 4]);
364 /// assert!(shape.is_contiguous());
365 /// ```
366 #[inline]
367 pub fn is_contiguous(&self) -> bool {
368 matches!(self.layout, MemoryLayout::Contiguous)
369 }
370
371 /// Checks if the tensor is a view of another tensor
372 ///
373 /// # Returns
374 ///
375 /// `true` if this tensor is a view (non-contiguous reference)
376 ///
377 /// # Examples
378 ///
379 /// ```
380 /// use train_station::tensor::Shape;
381 ///
382 /// let view_shape = Shape::as_view(vec![2, 2], vec![4, 1]);
383 /// assert!(view_shape.is_view());
384 /// ```
385 #[inline]
386 pub fn is_view(&self) -> bool {
387 matches!(self.layout, MemoryLayout::View)
388 }
389
    /// Gets the memory stride for a specific dimension
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension index
    ///
    /// # Returns
    ///
    /// The memory stride for the given dimension
    ///
    /// # Panics
    ///
    /// Panics if `dim` is out of bounds (indexes `strides` directly, so the
    /// panic comes from the slice bounds check)
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::tensor::Shape;
    ///
    /// let shape = Shape::new(vec![2, 3, 4]);
    /// assert_eq!(shape.stride(0), 12);
    /// assert_eq!(shape.stride(1), 4);
    /// assert_eq!(shape.stride(2), 1);
    /// ```
    #[inline]
    pub fn stride(&self, dim: usize) -> usize {
        self.strides[dim]
    }
418
419 /// Gets all memory strides
420 ///
421 /// # Returns
422 ///
423 /// Reference to the stride vector
424 ///
425 /// # Examples
426 ///
427 /// ```
428 /// use train_station::tensor::Shape;
429 ///
430 /// let shape = Shape::new(vec![2, 3, 4]);
431 /// assert_eq!(shape.strides(), &[12, 4, 1]);
432 /// ```
433 #[inline]
434 pub fn strides(&self) -> &[usize] {
435 &self.strides
436 }
437
    /// Gets the memory layout type
    ///
    /// # Returns
    ///
    /// Reference to the memory layout tag (`Contiguous`, `Strided`, or `View`)
    ///
    /// # Implementation Details
    ///
    /// This method returns the memory layout type which can be used for
    /// optimization decisions in tensor operations. The tag is set once at
    /// construction time by `new`, `with_strides`, or `as_view`.
    #[inline]
    pub fn layout(&self) -> &MemoryLayout {
        &self.layout
    }
452
453 /// Calculates the linear memory offset for given indices
454 ///
455 /// Computes the linear memory offset for multi-dimensional tensor indices
456 /// using the stored stride information. This enables efficient direct memory
457 /// access for tensor operations.
458 ///
459 /// # Arguments
460 ///
461 /// * `indices` - Vector of indices for each dimension
462 ///
463 /// # Returns
464 ///
465 /// Linear memory offset for the given multi-dimensional indices
466 ///
467 /// # Panics
468 ///
469 /// Panics if indices length doesn't match tensor rank
470 ///
471 /// # Performance
472 ///
473 /// - **Time Complexity**: O(rank) for offset calculation
474 /// - **Memory**: No allocation, uses existing stride data
475 /// - **Optimization**: Efficient dot product of indices and strides
476 ///
477 /// # Examples
478 ///
479 /// ```
480 /// use train_station::tensor::Shape;
481 ///
482 /// let shape = Shape::new(vec![2, 3, 4]);
483 /// let offset = shape.offset(&[1, 2, 3]);
484 /// assert_eq!(offset, 12 + 8 + 3); // 1*12 + 2*4 + 3*1
485 /// ```
486 #[inline]
487 pub fn offset(&self, indices: &[usize]) -> usize {
488 assert_eq!(indices.len(), self.rank(), "Indices must match tensor rank");
489 indices
490 .iter()
491 .zip(self.strides.iter())
492 .map(|(&idx, &stride)| idx * stride)
493 .sum()
494 }
495
496 /// Checks if this shape is broadcastable with another shape
497 ///
498 /// Determines if two shapes can be broadcast together according to NumPy
499 /// broadcasting rules. Broadcasting enables element-wise operations between
500 /// tensors with different shapes by expanding singleton dimensions.
501 ///
502 /// # Arguments
503 ///
504 /// * `other` - The other shape to check broadcasting compatibility
505 ///
506 /// # Returns
507 ///
508 /// `true` if the shapes are broadcastable according to NumPy broadcasting rules
509 ///
510 /// # Performance
511 ///
512 /// - **Time Complexity**: O(max(rank1, rank2)) for broadcasting check
513 /// - **Memory**: No allocation, uses existing dimension data
514 /// - **Optimization**: Right-aligned dimension comparison
515 ///
516 /// # Broadcasting Rules
517 ///
518 /// - Dimensions are compared from right to left
519 /// - Dimensions are compatible if they are equal, or one is 1
520 /// - Missing dimensions are treated as 1
521 ///
522 /// # Examples
523 ///
524 /// ```
525 /// use train_station::tensor::Shape;
526 ///
527 /// let shape1 = Shape::new(vec![2, 3, 4]);
528 /// let shape2 = Shape::new(vec![1, 3, 4]);
529 /// assert!(shape1.is_broadcastable_with(&shape2));
530 ///
531 /// let shape3 = Shape::new(vec![4]);
532 /// assert!(shape1.is_broadcastable_with(&shape3));
533 /// ```
534 pub fn is_broadcastable_with(&self, other: &Shape) -> bool {
535 let max_rank = self.rank().max(other.rank());
536
537 for i in 0..max_rank {
538 let self_dim = if i < self.rank() {
539 self.dims[self.rank() - 1 - i]
540 } else {
541 1
542 };
543
544 let other_dim = if i < other.rank() {
545 other.dims[other.rank() - 1 - i]
546 } else {
547 1
548 };
549
550 if self_dim != other_dim && self_dim != 1 && other_dim != 1 {
551 return false;
552 }
553 }
554
555 true
556 }
557}
558
#[cfg(test)]
mod tests {
    //! Shape and memory layout tests
    //!
    //! Comprehensive tests for shape creation, memory layout detection, stride
    //! calculation, broadcasting compatibility, and offset computation. Tests
    //! cover all major functionality including edge cases and performance characteristics.

    use super::*;

    /// Test basic shape creation and properties
    ///
    /// Verifies that shapes are created with correct dimensions, size, rank,
    /// and layout information. Tests the fundamental shape creation functionality.
    #[test]
    fn test_shape_creation() {
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.size, 24);
        assert_eq!(shape.rank(), 3);
        assert!(shape.is_contiguous());
        assert!(!shape.is_view());
        assert_eq!(shape.strides(), &[12, 4, 1]);
    }

    /// Test shape creation for a 1-dimensional tensor
    ///
    /// Verifies size, rank, layout, and the trivial stride of a vector shape.
    #[test]
    fn test_shape_1d() {
        let shape = Shape::new(vec![10]);
        assert_eq!(shape.size, 10);
        assert_eq!(shape.rank(), 1);
        assert!(shape.is_contiguous());
        assert_eq!(shape.strides(), &[1]);
    }

    /// Test shape creation for a 2-dimensional tensor
    ///
    /// Verifies size, rank, layout, and row-major strides of a matrix shape.
    #[test]
    fn test_shape_2d() {
        let shape = Shape::new(vec![5, 6]);
        assert_eq!(shape.size, 30);
        assert_eq!(shape.rank(), 2);
        assert!(shape.is_contiguous());
        assert_eq!(shape.strides(), &[6, 1]);
    }

    /// Test shape creation with a zero-sized dimension
    ///
    /// A dimension of 0 yields size 0 while keeping rank and strides well-defined.
    #[test]
    fn test_zero_sized_shape() {
        let shape = Shape::new(vec![0]);
        assert_eq!(shape.size, 0);
        assert_eq!(shape.rank(), 1);
        assert_eq!(shape.strides(), &[1]);
    }

    /// Test shape creation with caller-supplied (non-contiguous) strides
    ///
    /// Verifies that custom strides are stored as-is and that the layout is
    /// detected as non-contiguous.
    #[test]
    fn test_shape_with_strides() {
        let shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
        assert_eq!(shape.size, 6);
        assert_eq!(shape.rank(), 2);
        assert!(!shape.is_contiguous());
        assert_eq!(shape.strides(), &[6, 2]);
        assert_eq!(shape.stride(0), 6);
        assert_eq!(shape.stride(1), 2);
    }

    /// Test view shape creation
    ///
    /// Verifies that `as_view` marks the shape as a view (and therefore not
    /// contiguous) while preserving the supplied dimensions and strides.
    #[test]
    fn test_shape_as_view() {
        let shape = Shape::as_view(vec![2, 2], vec![4, 1]);
        assert_eq!(shape.size, 4);
        assert_eq!(shape.rank(), 2);
        assert!(shape.is_view());
        assert!(!shape.is_contiguous());
        assert_eq!(shape.strides(), &[4, 1]);
    }

    /// Test row-major stride computation directly
    ///
    /// Exercises the internal `compute_contiguous_strides` helper for multi-
    /// and single-dimensional inputs.
    #[test]
    fn test_stride_calculation() {
        let dims = vec![2, 3, 4];
        let strides = Shape::compute_contiguous_strides(&dims);
        assert_eq!(strides, vec![12, 4, 1]);

        let dims = vec![5];
        let strides = Shape::compute_contiguous_strides(&dims);
        assert_eq!(strides, vec![1]);
    }

    /// Test memory offset calculation for multi-dimensional indices
    ///
    /// Verifies that linear memory offsets are correctly calculated for various
    /// multi-dimensional index combinations using stride information.
    #[test]
    fn test_offset_calculation() {
        let shape = Shape::new(vec![2, 3, 4]);

        // Test corner cases
        assert_eq!(shape.offset(&[0, 0, 0]), 0);
        assert_eq!(shape.offset(&[1, 2, 3]), 12 + 8 + 3);
        assert_eq!(shape.offset(&[1, 0, 0]), 12);
        assert_eq!(shape.offset(&[0, 1, 0]), 4);
        assert_eq!(shape.offset(&[0, 0, 1]), 1);
    }

    /// Test broadcasting compatibility rules
    ///
    /// Verifies that NumPy-compatible broadcasting rules are correctly implemented
    /// for various shape combinations including singleton dimensions and different ranks.
    #[test]
    fn test_broadcasting_compatibility() {
        let shape1 = Shape::new(vec![2, 3, 4]);
        let shape2 = Shape::new(vec![1, 3, 4]);
        let shape3 = Shape::new(vec![4]);
        let shape4 = Shape::new(vec![2, 1, 4]);
        let shape5 = Shape::new(vec![2, 2, 4]);

        assert!(shape1.is_broadcastable_with(&shape2));
        assert!(shape1.is_broadcastable_with(&shape3));
        assert!(shape1.is_broadcastable_with(&shape4));
        assert!(!shape1.is_broadcastable_with(&shape5)); // 3 != 2 and neither is 1

        // Broadcasting compatibility is symmetric.
        assert!(shape2.is_broadcastable_with(&shape1));
        assert!(shape3.is_broadcastable_with(&shape1));
        assert!(shape4.is_broadcastable_with(&shape1));
    }

    /// Test automatic layout detection in `with_strides`
    ///
    /// Strides matching the row-major pattern yield a contiguous layout; any
    /// deviation yields a strided layout.
    #[test]
    fn test_contiguous_strides_detection() {
        let shape1 = Shape::with_strides(vec![2, 3, 4], vec![12, 4, 1]);
        assert!(shape1.is_contiguous());

        let shape2 = Shape::with_strides(vec![2, 3, 4], vec![12, 4, 2]);
        assert!(!shape2.is_contiguous());
    }
}