train_station/tensor/core/mod.rs
1//! Tensor (core): PyTorch‑inspired, zero‑dependency, maximum‑performance array with autograd
2//!
3//! This module contains the central `Tensor` type and its core building blocks
4//! (allocation, memory, shape, operators, views). The public API is Tensor‑centric:
5//! you construct and operate via `Tensor` methods; submodules exist for internal
6//! organization and are referenced from the `Tensor` API.
7//!
8//! # Highlights
9//!
10//! - **Initialization**: `new`, `zeros`, `ones`, `randn`, `from_slice`, `new_on_device`,
11//! `new_uninitialized`, `new_uninitialized_aligned`
12//! - **Ops**: element‑wise (+, -, *, /), scalar ops, reductions (`sum`), `matmul`
13//! - **Broadcasting**: NumPy‑compatible rules for element‑wise and batched ops
14//! - **Views**: zero‑copy `view` (reshape), `slice_view`, `element_view`, transpose/strides
15//! - **Iterator API**: idiomatic iteration over elements/chunks/dimensions/windows that yields
16//! view tensors and preserves autograd; collect back with `collect_shape`
17//! - **Autograd (GradTrack)**: thread‑safe, fast backward with `retain_grad`, `grad_owned`
18//! - **Performance**: SIMD‑aligned memory, cache‑aware kernels, thread‑local memory pool
19//! - **Controls**: `with_no_mem_pool` for cross‑thread ownership, `with_no_mem_padding` for exact sizes
20//!
21//! # Quick start
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Init (many options)
27//! let a = Tensor::zeros(vec![2, 3]);
28//! let b = Tensor::ones(vec![3]).with_requires_grad();
29//! let x = Tensor::randn(vec![2, 3], None);
30//!
31//! // Element‑wise ops and reductions
32//! let y = a.add_scalar(1.0).mul_scalar(2.0);
33//! let s = y.sum();
34//! assert_eq!(s.size(), 1);
35//! ```
36//!
37//! ## Broadcasting
38//!
39//! ```
40//! use train_station::Tensor;
41//!
42//! let a = Tensor::ones(vec![2, 1, 4]);
43//! let b = Tensor::ones(vec![3, 1]);
44//! let c = a.add_tensor(&b); // [2,1,4] + [3,1] -> [2,3,4]
45//! assert_eq!(c.shape().dims(), &[2, 3, 4]);
46//! ```
47//!
48//! ## Views (zero‑copy)
49//!
50//! ```
51//! use train_station::Tensor;
52//!
53//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
54//! let v = x.view(vec![2, 2]);
55//! assert_eq!(v.shape().dims(), &[2, 2]);
56//! let e = x.element_view(2);
57//! assert_eq!(e.value(), 3.0);
58//! ```
59//!
60//! ## Iterator‑first API
61//!
62//! ```
63//! use train_station::{Tensor, tensor::TensorCollectExt};
64//!
65//! let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
66//! let mat = t.iter_chunks(2)
67//! .map(|chunk| chunk.mul_scalar(2.0))
68//! .collect_shape(vec![3, 2]);
69//! assert_eq!(mat.shape().dims(), &[3, 2]);
70//! assert_eq!(mat.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
71//! ```
72//!
73//! ## Autograd (GradTrack)
74//!
75//! ```
76//! use train_station::Tensor;
77//!
78//! let x = Tensor::ones(vec![2, 3]).with_requires_grad();
79//! let mut loss = x.add_scalar(5.0).sum();
80//! loss.backward(None);
81//! let gx = x.grad_owned().unwrap();
82//! assert_eq!(gx.shape().dims(), &[2, 3]);
83//! ```
84//!
85//! ## Cross‑thread memory pool control
86//!
87//! ```
88//! use train_station::{Tensor, tensor::with_no_mem_pool};
89//! use std::thread;
90//!
91//! // Create in worker and return to main: prefer system allocator
92//! let handle = thread::spawn(|| {
93//! with_no_mem_pool(|| Tensor::ones(vec![10]))
94//! });
95//! let _t = handle.join().unwrap();
96//! ```
97//!
98//! # Memory layout & performance
99//!
100//! Row‑major layout with runtime‑detected SIMD alignment: typically 16/32/64‑byte alignment and
101//! lane‑multiple capacity for vectorized kernels. Zero‑copy views preserve allocation ownership.
102
103pub mod allocation;
104pub mod memory;
105pub mod operators;
106pub mod serialization;
107pub mod shape;
108// Deprecated: legacy thread_pool kept only if other crates depend on it
109// pub mod thread_pool;
110pub mod utils;
111pub mod view;
112
113use std::marker::PhantomData;
114use std::ptr::NonNull;
115use std::sync::atomic::AtomicUsize;
116use std::sync::Arc;
117
118use crate::device::Device;
119use crate::gradtrack::engine::GraphGroupRef;
120use crate::gradtrack::GradFn;
121
122pub use allocation::Allocation;
123pub use memory::{with_no_mem_pool, NoMemPoolGuard};
124pub use shape::{MemoryLayout, Shape};
125
126// Note: Prefetching functions are now in ops/add.rs where they're used
127
/// Global counter for unique tensor IDs
///
/// Provides thread-safe, unique identifiers for tensor gradtrack tracking.
/// Uses atomic operations to ensure uniqueness across concurrent tensor creation.
///
/// NOTE(review): the counter starts at 1, presumably reserving 0 as a
/// "no tensor" sentinel in the gradtrack engine — TODO confirm against
/// `crate::gradtrack`.
static TENSOR_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);
133
134/// High-performance multi-dimensional tensor with automatic differentiation support
135///
136/// The core data structure for machine learning operations, designed for maximum
137/// performance with zero-cost abstractions. Supports arbitrary dimensionality,
138/// SIMD optimization, gradient tracking, device placement, and natural mathematical
139/// expressions through operator overloading.
140///
141/// # Key Features
142///
143/// - **Raw Pointer Storage**: Zero-overhead memory access for maximum performance
144/// - **SIMD Optimization**: AVX2 alignment and vectorized operations
145/// - **Memory Efficiency**: Optimized alignment strategies for different tensor sizes
146/// - **gradtrack Integration**: Built-in gradient tracking and computation
147/// - **Device Support**: CPU and future CUDA device placement
148/// - **View Tensors**: Zero-copy tensor views with shared memory management
149/// - **Thread Safety**: Send + Sync implementation for concurrent usage
150/// - **Operator Overloading**: Natural mathematical expressions (+, -, *, /, +=, -=, *=, /=)
151///
152/// # Memory Layout
153///
154/// Tensors use row-major memory layout with size-dependent alignment:
155/// - **Small tensors** (≤8 elements): 16-byte SSE alignment
156/// - **Medium tensors** (8-1024 elements): 32-byte AVX2 alignment
157/// - **Large tensors** (>1024 elements): 64-byte cache-line alignment
158///
159/// # Performance Characteristics
160///
161/// - **Memory Overhead**: ~64 bytes per tensor (excluding data)
162/// - **SIMD Ready**: Properly aligned for vectorized operations
163/// - **Cache Friendly**: Optimized memory layout for CPU cache hierarchies
164/// - **Zero-Cost Views**: View tensors share memory without copying
165/// - **Thread Safe**: Atomic ID generation and lock-free operations
166/// - **Operator Performance**: Zero-cost operator overloading for mathematical expressions
167///
168/// # Safety
169///
170/// This struct uses unsafe code for performance. The following invariants must be maintained:
171/// - `data` must be valid for `shape.size` elements
172/// - `data` must be properly aligned for `f32`
173/// - `data` must not be aliased while the tensor exists
174/// - `shape.size` must match the actual allocated memory
175/// - `allocation_owner` must be valid if present
176///
177/// # Examples
178///
179/// ## Basic Tensor Operations
180///
181/// ```
182/// use train_station::Tensor;
183///
184/// // Create tensors with different configurations
185/// let tensor = Tensor::new(vec![2, 3]);
186/// let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();
187///
188/// // Access tensor properties
189/// assert_eq!(tensor.size(), 6);
190/// assert_eq!(tensor.shape().dims(), vec![2, 3]);
191/// assert!(tensor.is_contiguous());
192/// ```
193///
194/// ## Operator Overloading
195///
196/// ```
197/// use train_station::Tensor;
198///
199/// // Create tensors for operations
200/// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
201/// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
202///
203/// // Tensor operations with operators
204/// let result = a.clone() + b.clone(); // Tensor addition
205/// let result = a.clone() * b.clone(); // Element-wise multiplication
206/// let result = a.clone() - b.clone(); // Tensor subtraction
207/// let result = a.clone() / b.clone(); // Element-wise division
208///
209/// // Scalar operations
210/// let result = a.clone() + 5.0; // Tensor + scalar
211/// let result = 5.0 + a.clone(); // Scalar + tensor
212/// let result = a.clone() * 3.0; // Tensor * scalar
213/// let result = 3.0 * a.clone(); // Scalar * tensor
214///
215/// // Compound expressions
216/// let result = (a.clone() + b.clone()) * 2.0 - 1.0; // Complex mathematical expressions
217///
218/// // Assignment operators
219/// let mut c = a.clone();
220/// c += b.clone(); // In-place addition
221/// c *= 2.0; // In-place scalar multiplication
222///
223/// // Negation
224/// let result = -a; // Negate all elements
225/// ```
226///
227/// # Thread Safety
228///
229/// This type is `Send + Sync` and can be safely shared between threads.
230/// All operations are thread-safe through atomic ID generation and
231/// thread-local gradtrack storage.
232pub struct Tensor {
233 /// Raw pointer to the tensor data in memory
234 ///
235 /// Provides zero-overhead access to tensor elements for maximum performance.
236 /// The pointer is guaranteed to be valid for `shape.size` elements and properly
237 /// aligned for SIMD operations. This field enables direct memory access without
238 /// bounds checking overhead.
239 ///
240 /// # Safety
241 ///
242 /// - Must be valid for `shape.size` elements
243 /// - Must be properly aligned for `f32` operations
244 /// - Must not be aliased while tensor exists
245 data: NonNull<f32>,
246
247 /// The shape and dimensional information of the tensor
248 ///
249 /// Contains the dimensions, size, strides, and memory layout information.
250 /// This field determines how the raw data is interpreted as a multi-dimensional
251 /// tensor and enables efficient memory access patterns.
252 shape: Shape,
253
254 /// Device where this tensor is located (CPU/GPU)
255 ///
256 /// Determines the physical location of the tensor data and which operations
257 /// can be performed on it. Currently supports CPU with future CUDA support.
258 device: Device,
259
260 /// Unique identifier for gradtrack tracking
261 ///
262 /// Thread-safe, globally unique ID used by the gradtrack system to track
263 /// tensor operations and gradient computation. Generated atomically to
264 /// ensure uniqueness across concurrent tensor creation.
265 id: usize,
266
267 /// Whether this tensor requires gradient computation
268 ///
269 /// Controls whether the gradtrack system tracks operations on this tensor
270 /// and computes gradients during backward pass. When `true`, operations
271 /// are recorded in the computation graph for gradient propagation.
272 requires_grad: bool,
273
274 /// Whether this tensor should retain its gradient after backward even if non-leaf
275 ///
276 /// When set via `retain_grad()`/`retain_grad_()`, users can materialize the
277 /// gradient into `self.grad` after backward using `grad_or_fetch()` so that
278 /// `grad()` returns `Some(&Tensor)` for non-leaf tensors too.
279 retain_grad: bool,
280
281 /// Accumulated gradients from backward pass
282 ///
283 /// Stores the computed gradients for this tensor after calling `backward()`.
284 /// `None` if `requires_grad=false` or no gradients have been computed yet.
285 /// Uses `Arc` for efficient sharing between view tensors.
286 grad: Option<Arc<Tensor>>,
287
288 /// Gradient function for gradtrack computation
289 ///
290 /// Records the operation that created this tensor for gradient computation
291 /// during backward pass. Contains the necessary information to compute
292 /// gradients with respect to input tensors.
293 grad_fn: GradFn,
294
295 /// Shared allocation owner for view tensors
296 ///
297 /// Enables zero-copy tensor views by sharing memory allocation between
298 /// multiple tensors. `None` for tensors that own their memory directly.
299 /// Uses `Arc` for thread-safe reference counting and automatic cleanup.
300 allocation_owner: Option<std::sync::Arc<Allocation>>,
301
302 /// Optional graph group reference for implicit cross-thread autograd context.
303 /// None when gradients are disabled. When present, this tensor participates in
304 /// the associated computation graph (local or shared).
305 graph_group: Option<std::sync::Arc<GraphGroupRef>>,
306
307 /// Phantom data to ensure proper lifetime management
308 ///
309 /// Ensures the tensor has the correct lifetime parameters for the `f32`
310 /// data type. This prevents lifetime issues when working with raw pointers.
311 _phantom: PhantomData<f32>,
312}
313
// Make Tensor Send + Sync for thread-safe usage
//
// SAFETY: The raw pointer is properly managed through RAII patterns and
// the data is not shared between threads without proper synchronization.
// All tensor operations are thread-safe through atomic ID generation and
// thread-local gradtrack storage.
//
// NOTE(review): `Sync` additionally requires that no `&self` method mutates
// through `data` without synchronization, and that view tensors sharing an
// `Allocation` never race on writes — TODO confirm against the ops/view
// modules, which are outside this file.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}

// No custom Drop: memory is managed by the shared `Allocation` owner when present.
324
325impl std::fmt::Debug for Tensor {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 f.debug_struct("Tensor")
328 .field("shape", &self.shape)
329 .field("size", &self.size())
330 .field("id", &self.id)
331 .field("requires_grad", &self.requires_grad)
332 .field("has_grad", &self.grad.is_some())
333 .field("has_grad_fn", &!matches!(self.grad_fn, GradFn::None))
334 .finish()
335 }
336}
337
338/// Clone implementation for Tensor
339///
340/// Creates a deep copy of the tensor data but resets gradtrack state
341/// (new tensor won't track gradients unless explicitly set)
342impl Clone for Tensor {
343 fn clone(&self) -> Self {
344 // Fast path for contiguous tensors: direct linear copy
345 if self.is_contiguous() || self.size() == 0 {
346 let mut cloned = Self::new(self.shape().dims().to_vec());
347 unsafe {
348 let src = self.as_ptr();
349 let dst = cloned.as_mut_ptr();
350 std::ptr::copy_nonoverlapping(src, dst, self.size());
351 }
352 return cloned;
353 }
354
355 // Non-contiguous view: materialize into a contiguous copy respecting strides
356 let mut result = Tensor::new(self.shape().dims().to_vec());
357 let rank = self.shape().rank();
358 unsafe {
359 let dst_ptr = result.as_mut_ptr();
360 for dst_idx in 0..result.size() {
361 // Compute destination coordinates under contiguous strides
362 let mut coords = vec![0usize; rank];
363 let mut tmp = dst_idx;
364 for i in (0..rank).rev() {
365 let dim_size = self.shape().dims()[i];
366 coords[i] = tmp % dim_size;
367 tmp /= dim_size;
368 }
369 let src_off = self.shape().offset(&coords);
370 *dst_ptr.add(dst_idx) = *self.as_ptr().add(src_off);
371 }
372 }
373
374 result
375 }
376}
377
#[cfg(test)]
mod tests {
    //! Core tensor functionality tests
    //!
    //! Comprehensive tests for tensor creation, memory layout, operator overloading,
    //! device management, and optimization information. Tests cover all major
    //! functionality including edge cases and performance characteristics.

    use super::*;

    /// Test basic tensor creation and properties
    ///
    /// Verifies that tensors are created with correct dimensions, size, and rank.
    /// Tests the fundamental tensor creation functionality.
    #[test]
    fn test_tensor_creation() {
        let tensor = Tensor::new(vec![2, 3, 4]);
        assert_eq!(tensor.size(), 24);
        assert_eq!(tensor.shape().rank(), 3);
    }

    /// Test 1-D tensor creation: size equals the single dimension, rank is 1
    #[test]
    fn test_tensor_1d() {
        let tensor = Tensor::new(vec![10]);
        assert_eq!(tensor.size(), 10);
        assert_eq!(tensor.shape().rank(), 1);
    }

    /// Test 2-D tensor creation: size is the product of both dimensions
    #[test]
    fn test_tensor_2d() {
        let tensor = Tensor::new(vec![3, 4]);
        assert_eq!(tensor.size(), 12);
        assert_eq!(tensor.shape().rank(), 2);
    }

    /// Test that a tensor with a zero dimension is valid and reports size 0
    #[test]
    fn test_zero_sized_tensor() {
        let tensor = Tensor::new(vec![0]);
        assert_eq!(tensor.size(), 0);
    }

    /// Test NumPy-style broadcasting compatibility checks
    ///
    /// Dimensions are compatible (right-aligned) when equal or when either is 1;
    /// the final case fails because 3 vs 2 match neither rule.
    #[test]
    fn test_broadcasting_compatibility() {
        let a = Tensor::new(vec![2, 3, 4]);
        let b = Tensor::new(vec![1, 3, 4]);
        let c = Tensor::new(vec![4]);
        let d = Tensor::new(vec![2, 1, 4]);
        let e = Tensor::new(vec![2, 2, 4]);

        assert!(a.is_broadcastable_with(&b));
        assert!(a.is_broadcastable_with(&c));
        assert!(a.is_broadcastable_with(&d));
        assert!(!a.is_broadcastable_with(&e)); // 3 != 2 and neither is 1
    }

    /// Test that tensors default to the CPU device with consistent device flags
    #[test]
    fn test_tensor_device_cpu() {
        use crate::device::Device;

        let tensor = Tensor::new(vec![2, 3]);
        assert_eq!(tensor.device(), Device::cpu());
        assert!(tensor.device().is_cpu());
        assert!(!tensor.device().is_cuda());
    }

    /// Test explicit device placement via `new_on_device` with the CPU device
    #[test]
    fn test_tensor_new_on_device_cpu() {
        use crate::device::Device;

        let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 6);
    }

    /// Test that requesting a CUDA device panics when the feature is disabled
    ///
    /// NOTE(review): despite the test name, no tensor is ever created — the
    /// panic is expected from `Device::cuda(0)` itself, so tensor construction
    /// on CUDA is not actually exercised here.
    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_tensor_new_on_cuda_panics() {
        use crate::device::Device;

        // This should panic since CUDA feature is not enabled
        // The panic occurs when trying to create the CUDA device
        Device::cuda(0);
    }

    /// Test that tensor creation respects the ambient device context
    #[test]
    fn test_device_context_integration() {
        use crate::device::{with_device, Device};

        // Test that tensors created in different device contexts get the right device
        let tensor1 = Tensor::new(vec![2]);
        assert_eq!(tensor1.device(), Device::cpu());

        with_device(Device::cpu(), || {
            let tensor2 = Tensor::new(vec![3]);
            assert_eq!(tensor2.device(), Device::cpu());
        });
    }

    /// Test device placement for a zero-sized tensor (edge case)
    #[test]
    fn test_device_zero_sized_tensor() {
        use crate::device::Device;

        let tensor = Tensor::new_on_device(vec![0], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 0);
    }

    /// Test data() and data_mut() methods for safe tensor data access
    #[test]
    fn test_data_access_methods() {
        // Test data() method
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let data = tensor.data();

        assert_eq!(data.len(), 4);
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);

        // Test data_mut() method
        let mut tensor = Tensor::new(vec![2, 2]);
        let data_mut = tensor.data_mut();
        data_mut[0] = 10.0;
        data_mut[1] = 20.0;
        data_mut[2] = 30.0;
        data_mut[3] = 40.0;

        // Verify changes via multi-dimensional indexing (row-major layout)
        assert_eq!(tensor.get(&[0, 0]), 10.0);
        assert_eq!(tensor.get(&[0, 1]), 20.0);
        assert_eq!(tensor.get(&[1, 0]), 30.0);
        assert_eq!(tensor.get(&[1, 1]), 40.0);

        // Test with zero-sized tensor
        let empty = Tensor::new(vec![0]);
        assert_eq!(empty.data().len(), 0);

        let mut empty_mut = Tensor::new(vec![0]);
        assert_eq!(empty_mut.data_mut().len(), 0);
    }

    /// Test data() method with standard library operations
    #[test]
    fn test_data_with_std_operations() {
        let tensor = Tensor::from_slice(&[1.0, -2.0, 3.0, -4.0, 5.0], vec![5]).unwrap();
        let data = tensor.data();

        // Test iterator methods
        let sum: f32 = data.iter().sum();
        assert_eq!(sum, 3.0);

        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        assert_eq!(max, 5.0);

        let positive_count = data.iter().filter(|&&x| x > 0.0).count();
        assert_eq!(positive_count, 3);

        // Test indexing
        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 5.0);
    }

    /// Test value() method for scalar tensor access
    #[test]
    fn test_value_method() {
        // Test single-element tensor
        let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        assert_eq!(scalar.value(), 42.0);

        // Test with different shapes that have size 1
        let scalar_2d = Tensor::from_slice(&[std::f32::consts::PI], vec![1, 1]).unwrap();
        assert_eq!(scalar_2d.value(), std::f32::consts::PI);

        let scalar_3d = Tensor::from_slice(&[-1.5], vec![1, 1, 1]).unwrap();
        assert_eq!(scalar_3d.value(), -1.5);

        // Test with result from iterator: element views are size-1 tensors
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let first_elem = tensor.iter_elements().next().unwrap();
        assert_eq!(first_elem.value(), 1.0);
        assert_eq!(first_elem.shape().dims(), vec![1]);
        assert_eq!(first_elem.size(), 1);
    }

    /// Test value() method error handling
    #[test]
    #[should_panic(expected = "value() can only be called on tensors with exactly one element")]
    fn test_value_method_panics_on_multi_element() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = tensor.value(); // Should panic
    }

    /// Test value() method with empty tensor
    #[test]
    #[should_panic(expected = "value() can only be called on tensors with exactly one element")]
    fn test_value_method_panics_on_empty() {
        let empty = Tensor::new(vec![0]);
        let _ = empty.value(); // Should panic
    }
}
575}