// train_station/tensor/core/mod.rs

1//! Tensor (core): PyTorch‑inspired, zero‑dependency, maximum‑performance array with autograd
2//!
3//! This module contains the central `Tensor` type and its core building blocks
4//! (allocation, memory, shape, operators, views). The public API is Tensor‑centric:
5//! you construct and operate via `Tensor` methods; submodules exist for internal
6//! organization and are referenced from the `Tensor` API.
7//!
8//! # Highlights
9//!
10//! - **Initialization**: `new`, `zeros`, `ones`, `randn`, `from_slice`, `new_on_device`,
11//!   `new_uninitialized`, `new_uninitialized_aligned`
12//! - **Ops**: element‑wise (+, -, *, /), scalar ops, reductions (`sum`), `matmul`
13//! - **Broadcasting**: NumPy‑compatible rules for element‑wise and batched ops
14//! - **Views**: zero‑copy `view` (reshape), `slice_view`, `element_view`, transpose/strides
15//! - **Iterator API**: idiomatic iteration over elements/chunks/dimensions/windows that yields
16//!   view tensors and preserves autograd; collect back with `collect_shape`
17//! - **Autograd (GradTrack)**: thread‑safe, fast backward with `retain_grad`, `grad_owned`
18//! - **Performance**: SIMD‑aligned memory, cache‑aware kernels, thread‑local memory pool
19//! - **Controls**: `with_no_mem_pool` for cross‑thread ownership, `with_no_mem_padding` for exact sizes
20//!
21//! # Quick start
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Init (many options)
27//! let a = Tensor::zeros(vec![2, 3]);
28//! let b = Tensor::ones(vec![3]).with_requires_grad();
29//! let x = Tensor::randn(vec![2, 3], None);
30//!
31//! // Element‑wise ops and reductions
32//! let y = a.add_scalar(1.0).mul_scalar(2.0);
33//! let s = y.sum();
34//! assert_eq!(s.size(), 1);
35//! ```
36//!
37//! ## Broadcasting
38//!
39//! ```
40//! use train_station::Tensor;
41//!
42//! let a = Tensor::ones(vec![2, 1, 4]);
43//! let b = Tensor::ones(vec![3, 1]);
44//! let c = a.add_tensor(&b); // [2,1,4] + [3,1] -> [2,3,4]
45//! assert_eq!(c.shape().dims(), &[2, 3, 4]);
46//! ```
47//!
48//! ## Views (zero‑copy)
49//!
50//! ```
51//! use train_station::Tensor;
52//!
53//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
54//! let v = x.view(vec![2, 2]);
55//! assert_eq!(v.shape().dims(), &[2, 2]);
56//! let e = x.element_view(2);
57//! assert_eq!(e.value(), 3.0);
58//! ```
59//!
60//! ## Iterator‑first API
61//!
62//! ```
63//! use train_station::{Tensor, tensor::TensorCollectExt};
64//!
65//! let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
66//! let mat = t.iter_chunks(2)
67//!     .map(|chunk| chunk.mul_scalar(2.0))
68//!     .collect_shape(vec![3, 2]);
69//! assert_eq!(mat.shape().dims(), &[3, 2]);
70//! assert_eq!(mat.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
71//! ```
72//!
73//! ## Autograd (GradTrack)
74//!
75//! ```
76//! use train_station::Tensor;
77//!
78//! let x = Tensor::ones(vec![2, 3]).with_requires_grad();
79//! let mut loss = x.add_scalar(5.0).sum();
80//! loss.backward(None);
81//! let gx = x.grad_owned().unwrap();
82//! assert_eq!(gx.shape().dims(), &[2, 3]);
83//! ```
84//!
85//! ## Cross‑thread memory pool control
86//!
87//! ```
88//! use train_station::{Tensor, tensor::with_no_mem_pool};
89//! use std::thread;
90//!
91//! // Create in worker and return to main: prefer system allocator
92//! let handle = thread::spawn(|| {
93//!     with_no_mem_pool(|| Tensor::ones(vec![10]))
94//! });
95//! let _t = handle.join().unwrap();
96//! ```
97//!
98//! # Memory layout & performance
99//!
100//! Row‑major layout with runtime‑detected SIMD alignment: typically 16/32/64‑byte alignment and
101//! lane‑multiple capacity for vectorized kernels. Zero‑copy views preserve allocation ownership.
102
103pub mod allocation;
104pub mod memory;
105pub mod operators;
106pub mod serialization;
107pub mod shape;
108// Deprecated: legacy thread_pool kept only if other crates depend on it
109// pub mod thread_pool;
110pub mod utils;
111pub mod view;
112
113use std::marker::PhantomData;
114use std::ptr::NonNull;
115use std::sync::atomic::AtomicUsize;
116use std::sync::Arc;
117
118use crate::device::Device;
119use crate::gradtrack::engine::GraphGroupRef;
120use crate::gradtrack::GradFn;
121
122pub use allocation::Allocation;
123pub use memory::{with_no_mem_pool, NoMemPoolGuard};
124pub use shape::{MemoryLayout, Shape};
125
126// Note: Prefetching functions are now in ops/add.rs where they're used
127
/// Global counter for unique tensor IDs
///
/// Provides thread-safe, unique identifiers for tensor gradtrack tracking.
/// Uses atomic operations to ensure uniqueness across concurrent tensor creation.
/// Starts at 1 — presumably so that 0 can serve as a "no tensor" sentinel
/// elsewhere; NOTE(review): confirm against the gradtrack engine.
static TENSOR_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);
133
134/// High-performance multi-dimensional tensor with automatic differentiation support
135///
136/// The core data structure for machine learning operations, designed for maximum
137/// performance with zero-cost abstractions. Supports arbitrary dimensionality,
138/// SIMD optimization, gradient tracking, device placement, and natural mathematical
139/// expressions through operator overloading.
140///
141/// # Key Features
142///
143/// - **Raw Pointer Storage**: Zero-overhead memory access for maximum performance
144/// - **SIMD Optimization**: AVX2 alignment and vectorized operations
145/// - **Memory Efficiency**: Optimized alignment strategies for different tensor sizes
146/// - **gradtrack Integration**: Built-in gradient tracking and computation
147/// - **Device Support**: CPU and future CUDA device placement
148/// - **View Tensors**: Zero-copy tensor views with shared memory management
149/// - **Thread Safety**: Send + Sync implementation for concurrent usage
150/// - **Operator Overloading**: Natural mathematical expressions (+, -, *, /, +=, -=, *=, /=)
151///
152/// # Memory Layout
153///
154/// Tensors use row-major memory layout with size-dependent alignment:
155/// - **Small tensors** (≤8 elements): 16-byte SSE alignment
156/// - **Medium tensors** (8-1024 elements): 32-byte AVX2 alignment  
157/// - **Large tensors** (>1024 elements): 64-byte cache-line alignment
158///
159/// # Performance Characteristics
160///
161/// - **Memory Overhead**: ~64 bytes per tensor (excluding data)
162/// - **SIMD Ready**: Properly aligned for vectorized operations
163/// - **Cache Friendly**: Optimized memory layout for CPU cache hierarchies
164/// - **Zero-Cost Views**: View tensors share memory without copying
165/// - **Thread Safe**: Atomic ID generation and lock-free operations
166/// - **Operator Performance**: Zero-cost operator overloading for mathematical expressions
167///
168/// # Safety
169///
170/// This struct uses unsafe code for performance. The following invariants must be maintained:
171/// - `data` must be valid for `shape.size` elements
172/// - `data` must be properly aligned for `f32`
173/// - `data` must not be aliased while the tensor exists
174/// - `shape.size` must match the actual allocated memory
175/// - `allocation_owner` must be valid if present
176///
177/// # Examples
178///
179/// ## Basic Tensor Operations
180///
181/// ```
182/// use train_station::Tensor;
183///
184/// // Create tensors with different configurations
185/// let tensor = Tensor::new(vec![2, 3]);
186/// let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();
187///
188/// // Access tensor properties
189/// assert_eq!(tensor.size(), 6);
190/// assert_eq!(tensor.shape().dims(), vec![2, 3]);
191/// assert!(tensor.is_contiguous());
192/// ```
193///
194/// ## Operator Overloading
195///
196/// ```
197/// use train_station::Tensor;
198///
199/// // Create tensors for operations
200/// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
201/// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
202///
203/// // Tensor operations with operators
204/// let result = a.clone() + b.clone();                    // Tensor addition
205/// let result = a.clone() * b.clone();                    // Element-wise multiplication
206/// let result = a.clone() - b.clone();                    // Tensor subtraction
207/// let result = a.clone() / b.clone();                    // Element-wise division
208///
209/// // Scalar operations
210/// let result = a.clone() + 5.0;                          // Tensor + scalar
211/// let result = 5.0 + a.clone();                          // Scalar + tensor
212/// let result = a.clone() * 3.0;                          // Tensor * scalar
213/// let result = 3.0 * a.clone();                          // Scalar * tensor
214///
215/// // Compound expressions
216/// let result = (a.clone() + b.clone()) * 2.0 - 1.0;      // Complex mathematical expressions
217///
218/// // Assignment operators
219/// let mut c = a.clone();
220/// c += b.clone();                                        // In-place addition
221/// c *= 2.0;                                              // In-place scalar multiplication
222///
223/// // Negation
224/// let result = -a;                                       // Negate all elements
225/// ```
226///
227/// # Thread Safety
228///
229/// This type is `Send + Sync` and can be safely shared between threads.
230/// All operations are thread-safe through atomic ID generation and
231/// thread-local gradtrack storage.
232pub struct Tensor {
233    /// Raw pointer to the tensor data in memory
234    ///
235    /// Provides zero-overhead access to tensor elements for maximum performance.
236    /// The pointer is guaranteed to be valid for `shape.size` elements and properly
237    /// aligned for SIMD operations. This field enables direct memory access without
238    /// bounds checking overhead.
239    ///
240    /// # Safety
241    ///
242    /// - Must be valid for `shape.size` elements
243    /// - Must be properly aligned for `f32` operations
244    /// - Must not be aliased while tensor exists
245    data: NonNull<f32>,
246
247    /// The shape and dimensional information of the tensor
248    ///
249    /// Contains the dimensions, size, strides, and memory layout information.
250    /// This field determines how the raw data is interpreted as a multi-dimensional
251    /// tensor and enables efficient memory access patterns.
252    shape: Shape,
253
254    /// Device where this tensor is located (CPU/GPU)
255    ///
256    /// Determines the physical location of the tensor data and which operations
257    /// can be performed on it. Currently supports CPU with future CUDA support.
258    device: Device,
259
260    /// Unique identifier for gradtrack tracking
261    ///
262    /// Thread-safe, globally unique ID used by the gradtrack system to track
263    /// tensor operations and gradient computation. Generated atomically to
264    /// ensure uniqueness across concurrent tensor creation.
265    id: usize,
266
267    /// Whether this tensor requires gradient computation
268    ///
269    /// Controls whether the gradtrack system tracks operations on this tensor
270    /// and computes gradients during backward pass. When `true`, operations
271    /// are recorded in the computation graph for gradient propagation.
272    requires_grad: bool,
273
274    /// Whether this tensor should retain its gradient after backward even if non-leaf
275    ///
276    /// When set via `retain_grad()`/`retain_grad_()`, users can materialize the
277    /// gradient into `self.grad` after backward using `grad_or_fetch()` so that
278    /// `grad()` returns `Some(&Tensor)` for non-leaf tensors too.
279    retain_grad: bool,
280
281    /// Accumulated gradients from backward pass
282    ///
283    /// Stores the computed gradients for this tensor after calling `backward()`.
284    /// `None` if `requires_grad=false` or no gradients have been computed yet.
285    /// Uses `Arc` for efficient sharing between view tensors.
286    grad: Option<Arc<Tensor>>,
287
288    /// Gradient function for gradtrack computation
289    ///
290    /// Records the operation that created this tensor for gradient computation
291    /// during backward pass. Contains the necessary information to compute
292    /// gradients with respect to input tensors.
293    grad_fn: GradFn,
294
295    /// Shared allocation owner for view tensors
296    ///
297    /// Enables zero-copy tensor views by sharing memory allocation between
298    /// multiple tensors. `None` for tensors that own their memory directly.
299    /// Uses `Arc` for thread-safe reference counting and automatic cleanup.
300    allocation_owner: Option<std::sync::Arc<Allocation>>,
301
302    /// Optional graph group reference for implicit cross-thread autograd context.
303    /// None when gradients are disabled. When present, this tensor participates in
304    /// the associated computation graph (local or shared).
305    graph_group: Option<std::sync::Arc<GraphGroupRef>>,
306
307    /// Phantom data to ensure proper lifetime management
308    ///
309    /// Ensures the tensor has the correct lifetime parameters for the `f32`
310    /// data type. This prevents lifetime issues when working with raw pointers.
311    _phantom: PhantomData<f32>,
312}
313
// Make Tensor Send + Sync for thread-safe usage
//
// SAFETY: `data` is a uniquely-managed NonNull pointer whose backing memory is
// owned either by this tensor or by the shared `Allocation` owner, so moving a
// Tensor to another thread (Send) transfers that ownership intact. Tensor IDs
// are generated atomically and gradtrack state is thread-local.
//
// NOTE(review): the Sync impl additionally relies on the invariant that no
// `&self` method mutates the element buffer through the raw pointer — that
// invariant is upheld in the ops modules, not visible here; confirm when
// adding new `&self` methods.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
322
323// No custom Drop: memory is managed by the shared `Allocation` owner when present.
324
325impl std::fmt::Debug for Tensor {
326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        f.debug_struct("Tensor")
328            .field("shape", &self.shape)
329            .field("size", &self.size())
330            .field("id", &self.id)
331            .field("requires_grad", &self.requires_grad)
332            .field("has_grad", &self.grad.is_some())
333            .field("has_grad_fn", &!matches!(self.grad_fn, GradFn::None))
334            .finish()
335    }
336}
337
338/// Clone implementation for Tensor
339///
340/// Creates a deep copy of the tensor data but resets gradtrack state
341/// (new tensor won't track gradients unless explicitly set)
342impl Clone for Tensor {
343    fn clone(&self) -> Self {
344        // Fast path for contiguous tensors: direct linear copy
345        if self.is_contiguous() || self.size() == 0 {
346            let mut cloned = Self::new(self.shape().dims().to_vec());
347            unsafe {
348                let src = self.as_ptr();
349                let dst = cloned.as_mut_ptr();
350                std::ptr::copy_nonoverlapping(src, dst, self.size());
351            }
352            return cloned;
353        }
354
355        // Non-contiguous view: materialize into a contiguous copy respecting strides
356        let mut result = Tensor::new(self.shape().dims().to_vec());
357        let rank = self.shape().rank();
358        unsafe {
359            let dst_ptr = result.as_mut_ptr();
360            for dst_idx in 0..result.size() {
361                // Compute destination coordinates under contiguous strides
362                let mut coords = vec![0usize; rank];
363                let mut tmp = dst_idx;
364                for i in (0..rank).rev() {
365                    let dim_size = self.shape().dims()[i];
366                    coords[i] = tmp % dim_size;
367                    tmp /= dim_size;
368                }
369                let src_off = self.shape().offset(&coords);
370                *dst_ptr.add(dst_idx) = *self.as_ptr().add(src_off);
371            }
372        }
373
374        result
375    }
376}
377
#[cfg(test)]
mod tests {
    //! Core tensor functionality tests
    //!
    //! Comprehensive tests for tensor creation, memory layout, operator overloading,
    //! device management, and optimization information. Tests cover all major
    //! functionality including edge cases and performance characteristics.

    use super::*;

    /// Test basic tensor creation and properties
    ///
    /// Verifies that tensors are created with correct dimensions, size, and rank.
    /// Tests the fundamental tensor creation functionality.
    #[test]
    fn test_tensor_creation() {
        let tensor = Tensor::new(vec![2, 3, 4]);
        assert_eq!(tensor.size(), 24);
        assert_eq!(tensor.shape().rank(), 3);
    }

    /// Test 1D tensor creation: size equals the single dimension, rank is 1
    #[test]
    fn test_tensor_1d() {
        let tensor = Tensor::new(vec![10]);
        assert_eq!(tensor.size(), 10);
        assert_eq!(tensor.shape().rank(), 1);
    }

    /// Test 2D tensor creation: size is the product of dimensions, rank is 2
    #[test]
    fn test_tensor_2d() {
        let tensor = Tensor::new(vec![3, 4]);
        assert_eq!(tensor.size(), 12);
        assert_eq!(tensor.shape().rank(), 2);
    }

    /// Test that a zero-length dimension yields a valid tensor of size 0
    #[test]
    fn test_zero_sized_tensor() {
        let tensor = Tensor::new(vec![0]);
        assert_eq!(tensor.size(), 0);
    }

    /// Test NumPy-style broadcasting compatibility checks
    ///
    /// Dimensions are compatible right-aligned when equal or when either is 1;
    /// missing leading dimensions broadcast implicitly.
    #[test]
    fn test_broadcasting_compatibility() {
        let a = Tensor::new(vec![2, 3, 4]);
        let b = Tensor::new(vec![1, 3, 4]);
        let c = Tensor::new(vec![4]);
        let d = Tensor::new(vec![2, 1, 4]);
        let e = Tensor::new(vec![2, 2, 4]);

        assert!(a.is_broadcastable_with(&b));
        assert!(a.is_broadcastable_with(&c));
        assert!(a.is_broadcastable_with(&d));
        assert!(!a.is_broadcastable_with(&e)); // 3 != 2 and neither is 1
    }

    /// Test that tensors default to the CPU device
    #[test]
    fn test_tensor_device_cpu() {
        use crate::device::Device;

        let tensor = Tensor::new(vec![2, 3]);
        assert_eq!(tensor.device(), Device::cpu());
        assert!(tensor.device().is_cpu());
        assert!(!tensor.device().is_cuda());
    }

    /// Test explicit device placement via `new_on_device` with CPU
    #[test]
    fn test_tensor_new_on_device_cpu() {
        use crate::device::Device;

        let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 6);
    }

    /// Test that requesting a CUDA device panics when the feature is disabled
    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_tensor_new_on_cuda_panics() {
        use crate::device::Device;

        // This should panic since CUDA feature is not enabled
        // The panic occurs when trying to create the CUDA device
        Device::cuda(0);
    }

    /// Test that tensors created inside a `with_device` context pick up that device
    #[test]
    fn test_device_context_integration() {
        use crate::device::{with_device, Device};

        // Test that tensors created in different device contexts get the right device
        let tensor1 = Tensor::new(vec![2]);
        assert_eq!(tensor1.device(), Device::cpu());

        with_device(Device::cpu(), || {
            let tensor2 = Tensor::new(vec![3]);
            assert_eq!(tensor2.device(), Device::cpu());
        });
    }

    /// Test device placement for a zero-sized tensor (edge case)
    #[test]
    fn test_device_zero_sized_tensor() {
        use crate::device::Device;

        let tensor = Tensor::new_on_device(vec![0], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 0);
    }

    /// Test data() and data_mut() methods for safe tensor data access
    #[test]
    fn test_data_access_methods() {
        // Test data() method
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let data = tensor.data();

        assert_eq!(data.len(), 4);
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);

        // Test data_mut() method
        let mut tensor = Tensor::new(vec![2, 2]);
        let data_mut = tensor.data_mut();
        data_mut[0] = 10.0;
        data_mut[1] = 20.0;
        data_mut[2] = 30.0;
        data_mut[3] = 40.0;

        // Verify changes
        assert_eq!(tensor.get(&[0, 0]), 10.0);
        assert_eq!(tensor.get(&[0, 1]), 20.0);
        assert_eq!(tensor.get(&[1, 0]), 30.0);
        assert_eq!(tensor.get(&[1, 1]), 40.0);

        // Test with zero-sized tensor
        let empty = Tensor::new(vec![0]);
        assert_eq!(empty.data().len(), 0);

        let mut empty_mut = Tensor::new(vec![0]);
        assert_eq!(empty_mut.data_mut().len(), 0);
    }

    /// Test data() method with standard library operations
    #[test]
    fn test_data_with_std_operations() {
        let tensor = Tensor::from_slice(&[1.0, -2.0, 3.0, -4.0, 5.0], vec![5]).unwrap();
        let data = tensor.data();

        // Test iterator methods
        let sum: f32 = data.iter().sum();
        assert_eq!(sum, 3.0);

        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        assert_eq!(max, 5.0);

        let positive_count = data.iter().filter(|&&x| x > 0.0).count();
        assert_eq!(positive_count, 3);

        // Test indexing
        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 5.0);
    }

    /// Test value() method for scalar tensor access
    #[test]
    fn test_value_method() {
        // Test single-element tensor
        let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        assert_eq!(scalar.value(), 42.0);

        // Test with different shapes that have size 1
        let scalar_2d = Tensor::from_slice(&[std::f32::consts::PI], vec![1, 1]).unwrap();
        assert_eq!(scalar_2d.value(), std::f32::consts::PI);

        let scalar_3d = Tensor::from_slice(&[-1.5], vec![1, 1, 1]).unwrap();
        assert_eq!(scalar_3d.value(), -1.5);

        // Test with result from iterator
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let first_elem = tensor.iter_elements().next().unwrap();
        assert_eq!(first_elem.value(), 1.0);
        assert_eq!(first_elem.shape().dims(), vec![1]);
        assert_eq!(first_elem.size(), 1);
    }

    /// Test value() method error handling
    #[test]
    #[should_panic(expected = "value() can only be called on tensors with exactly one element")]
    fn test_value_method_panics_on_multi_element() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = tensor.value(); // Should panic
    }

    /// Test value() method with empty tensor
    #[test]
    #[should_panic(expected = "value() can only be called on tensors with exactly one element")]
    fn test_value_method_panics_on_empty() {
        let empty = Tensor::new(vec![0]);
        let _ = empty.value(); // Should panic
    }
}
575}