train_station/tensor/mod.rs
1//! Tensor module for high-performance multi-dimensional data structures
2//!
3//! This module provides the foundational building blocks for tensor operations.
4//! Public API is intentionally Tensor-centric: developers and agents should
5//! interact primarily through methods on `Tensor` (and re-exported helpers).
6//!
7//! Internally, the implementation is organized into specialized submodules for
8//! maximum performance and maintainability. These submodules are not part of the
9//! public API surface.
10//!
11//! The tensor system is designed for zero-cost abstractions with SIMD optimization
12//! and comprehensive automatic differentiation support.
13//!
14//! # Organization
15//!
16//! Internal organization (for context; not public API):
17//! - core: `Tensor` memory, views, and operators
18//! - shape: dimension/stride management and broadcasting logic
19//! - ops: SIMD-optimized math (add/sub/mul/div/matmul, activations, etc.)
20//! - transform: reshape/transpose/permute; concat/stack utilities
21//! - indexing: select/gather/masked operations
22//! - reductions: sum/mean/min/max/std/var
23//! - init: constructors and initialization helpers
24//!
25//! # Key Features
26//!
27//! - **Zero-Cost Abstractions**: Minimal overhead for tensor operations
28//! - **SIMD Optimization**: AVX2 optimizations for x86_64 architectures
29//! - **Memory Efficiency**: Optimized alignment and layout strategies
30//! - **GradTrack Integration**: Built-in gradient tracking and computation
31//! - **Operator Overloading**: Natural mathematical expressions (+, -, *, /, +=, -=, *=, /=)
32//! - **Thread Safety**: Send + Sync implementation for concurrent usage
33//! - **Device Support**: CPU and future CUDA device placement
34//! - **View Tensors**: Zero-copy tensor views with shared memory
35//! - **Broadcasting**: NumPy-style for element-wise ops; batched ND matmul
36//! - **Iterator-first API**: chunks, windows, dims, values with collect helpers
37//! - **PyTorch-inspired API**: familiar ergonomics for easy adoption
38//!
39//! ## Initialization capabilities (Tensor-centric)
40//!
41//! - `Tensor::new(dims)` for uninitialized memory (initialize before reading)
42//! - `Tensor::zeros(dims)`, `Tensor::ones(dims)`, `Tensor::randn(dims, seed)`
43//! - `Tensor::from_slice(values, dims)` for zero-copy ingest-then-own
44//! - `Tensor::new_uninitialized(dims)` and `Tensor::new_uninitialized_aligned(dims, align)` for perf paths
//! - And more — see `Tensor` methods in the docs for the full set of constructors
46//!
47//! # Performance Characteristics
48//!
49//! - **Memory Overhead**: ~64 bytes per tensor (excluding data)
50//! - **SIMD Alignment**: 32-byte alignment for AVX2 operations
51//! - **Cache Optimization**: Cache-line alignment for large tensors
52//! - **View Efficiency**: Zero-copy views with shared memory management
53//! - **Operator Performance**: Zero-cost operator overloading for mathematical expressions
54//! - **Thread Safety**: Lock-free operations with atomic ID generation
55//!
56//! # Examples
57//!
58//! ## Basic Tensor Operations
59//!
60//! ```
61//! use train_station::Tensor;
62//!
63//! // Create tensors with different configurations
64//! let tensor = Tensor::new(vec![2, 3, 4]);
65//! let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();
66//! let z = Tensor::zeros(vec![2, 3]);
67//! let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
68//!
69//! // Access tensor properties
70//! assert_eq!(tensor.size(), 24);
71//! assert_eq!(tensor.shape().dims(), vec![2, 3, 4]);
72//! assert!(tensor.is_contiguous());
73//! ```
74//!
75//! ## Operator Overloading
76//!
77//! ```
78//! use train_station::Tensor;
79//!
80//! // Create tensors for operations
81//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
82//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
83//!
84//! // Tensor operations with operators
85//! let result = a.clone() + b.clone(); // Tensor addition
86//! let result = a.clone() * b.clone(); // Element-wise multiplication
87//! let result = a.clone() - b.clone(); // Tensor subtraction
88//! let result = a.clone() / b.clone(); // Element-wise division
89//!
90//! // Scalar operations
91//! let result = a.clone() + 5.0; // Tensor + scalar
92//! let result = 5.0 + a.clone(); // Scalar + tensor
93//! let result = a.clone() * 3.0; // Tensor * scalar
94//! let result = 3.0 * a.clone(); // Scalar * tensor
95//!
96//! // Compound expressions
97//! let result = (a.clone() + b.clone()) * 2.0 - 1.0; // Complex mathematical expressions
98//!
99//! // Assignment operators
100//! let mut c = a.clone();
101//! c += b.clone(); // In-place addition
102//! c *= 2.0; // In-place scalar multiplication
103//!
104//! // Negation
105//! let result = -a; // Negate all elements
106//! ```
107//!
108//! ## Automatic Differentiation
109//!
110//! ```
111//! use train_station::Tensor;
112//!
113//! // Create tensors with gradient tracking
114//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap().with_requires_grad();
115//! let y = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap().with_requires_grad();
116//!
117//! // Perform operations (gradients are automatically tracked)
118//! let z = x.clone() * y.clone() + 2.0;
119//! let mut loss = z.sum();
120//!
121//! // Compute gradients
122//! loss.backward(None);
123//!
124//! // Access gradients (gradients are computed and stored)
125//! // Note: Gradient availability depends on the computation graph
126//! let x_grad = x.grad();
127//! let y_grad = y.grad();
128//! ```
129//!
130//! ## Broadcasting
131//!
132//! ```
133//! use train_station::Tensor;
134//!
135//! let a = Tensor::ones(vec![2, 3]);
136//! let b = Tensor::ones(vec![1, 3]);
137//! let c = a.add_tensor(&b); // [2,3] + [1,3] -> [2,3]
138//! assert_eq!(c.shape().dims(), vec![2, 3]);
139//! ```
140//!
141//! ## Iterators and collect helpers
142//!
143//! ```
144//! use train_station::Tensor;
145//!
146//! let t = Tensor::from_slice(&(0..6).map(|x| x as f32).collect::<Vec<_>>(), vec![6]).unwrap();
147//! let mat = t.chunks(2).collect_shape(vec![3, 2]);
148//! assert_eq!(mat.shape().dims(), &[3, 2]);
149//!
150//! // Conversions
151//! let from_vec: Tensor = vec![1.0, 2.0, 3.0].into();
152//! let back: Vec<f32> = from_vec.into();
153//! ```
154//!
155//! # Thread Safety
156//!
157//! All tensor operations are thread-safe and implement `Send + Sync`. Tensors can be
158//! safely shared between threads for concurrent read access. Write operations should
159//! be synchronized externally if multiple threads need to modify the same tensor.
160//!
161//! Note: submodules listed above are internal; users should access functionality via
162//! methods on `Tensor` (and a few re-exported helpers) for a clean, PyTorch-inspired API.
163//!
//! Memory pool note: allocations are served by a thread-local pool by default. If you
//! create tensors in a worker and return them to another thread, consider wrapping
//! creation in `train_station::tensor::with_no_mem_pool(|| ...)` (the public re-export;
//! the internal `core` module is not public) so those allocations use the system
//! allocator instead of a thread-local pool.
168//!
169//! # Design Principles
170//!
171//! - **Performance First**: Every design decision optimized for speed
172//! - **Memory Safety**: RAII patterns with justified unsafe usage
173//! - **Zero Dependencies**: Only standard library dependencies
174//! - **SIMD Ready**: Optimized for vectorized operations
175//! - **Future Proof**: Foundation for advanced ML operations
176//! - **Natural API**: Operator overloading for intuitive mathematical expressions
177//! - **Modular Organization**: Specialized submodules for maintainability
178//! - **Comprehensive Testing**: 100% coverage with FFI mathematical validation
179
// Internal implementation submodules. These are deliberately `pub(crate)`:
// the public API surface is the `Tensor` type plus the re-exports below,
// matching the Tensor-centric design described in the module docs above.
pub(crate) mod core; // `Tensor` memory, views, and operators
pub(crate) mod indexing; // select/gather/masked operations
pub(crate) mod init; // constructors and initialization helpers
pub(crate) mod iterator; // chunks/windows/dims/values iterators and collect helpers
pub(crate) mod ops; // SIMD-optimized math (add/sub/mul/div/matmul, activations, etc.)
pub(crate) mod reductions; // sum/mean/min/max/std/var
pub(crate) mod transform; // reshape/transpose/permute; concat/stack utilities

// Crate-internal re-export; `MemoryLayout` is not part of the public API.
pub(crate) use core::MemoryLayout;
// Public surface: the tensor type, its shape, and the memory-pool opt-out helpers.
pub use core::{with_no_mem_pool, NoMemPoolGuard, Shape, Tensor};

// Re-export iterator helpers/traits so users can access collect_shape without deep paths
pub use iterator::{TensorCollectExt, ValuesCollectExt};