train_station/tensor/core/mod.rs
1//! Core tensor implementation for high-performance machine learning
2//!
3//! This module provides the foundational `Tensor` struct and related components
4//! that form the backbone of the Train Station library. The tensor system is
5//! designed for maximum performance with zero-cost abstractions and SIMD optimization.
6//!
7//! # Organization
8//!
9//! The core tensor system consists of:
10//! - **Tensor**: Main multi-dimensional tensor with gradient tracking
11//! - **Allocation**: Shared memory management for view tensors
12//! - **TensorOptimizationInfo**: Performance hints for operation selection
13//! - **Shape**: Dimension and stride management (from shape module)
14//!
15//! # Key Features
16//!
17//! - **Zero-Cost Abstractions**: Minimal overhead for tensor operations
18//! - **SIMD Optimization**: AVX2 optimizations for x86_64 architectures
19//! - **Memory Efficiency**: Optimized alignment and layout strategies
20//! - **Thread Safety**: Send + Sync implementation for concurrent usage
21//! - **GradTrack Integration**: Built-in gradient tracking and computation
22//! - **Device Support**: CPU and future CUDA device placement
23//! - **View Tensors**: Zero-copy tensor views with shared memory
24//! - **Operator Overloading**: Natural mathematical expressions with operators
25//!
26//! # Performance Characteristics
27//!
28//! - **Memory Overhead**: ~64 bytes per tensor (excluding data)
29//! - **SIMD Alignment**: 32-byte alignment for AVX2 operations
30//! - **Cache Optimization**: Cache-line alignment for large tensors
31//! - **Thread Safety**: Lock-free operations with atomic ID generation
32//! - **View Efficiency**: Zero-copy views with shared memory management
33//! - **Operator Performance**: Zero-cost operator overloading for mathematical expressions
34//!
35//! # Memory Layout
36//!
37//! Tensors use row-major memory layout with optimized alignment:
38//! - **Small tensors** (≤8 elements): 16-byte SSE alignment
//! - **Medium tensors** (9-1024 elements): 32-byte AVX2 alignment
40//! - **Large tensors** (>1024 elements): 64-byte cache-line alignment
41//!
42//! # Examples
43//!
44//! ## Basic Tensor Operations
45//!
46//! ```
47//! use train_station::Tensor;
48//!
49//! // Create tensors with different configurations
50//! let tensor = Tensor::new(vec![2, 3, 4]);
51//! let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();
52//!
53//! // Access tensor properties
54//! assert_eq!(tensor.size(), 24);
55//! assert_eq!(tensor.shape().dims, vec![2, 3, 4]);
56//! assert!(tensor.is_contiguous());
57//! ```
58//!
59//! ## Operator Overloading
60//!
61//! ```
62//! use train_station::Tensor;
63//!
64//! // Create tensors for operations
65//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
66//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
67//!
68//! // Tensor operations with operators
69//! let result = a.clone() + b.clone(); // Tensor addition
70//! let result = a.clone() * b.clone(); // Element-wise multiplication
71//! let result = a.clone() - b.clone(); // Tensor subtraction
72//! let result = a.clone() / b.clone(); // Element-wise division
73//!
74//! // Scalar operations
75//! let result = a.clone() + 5.0; // Tensor + scalar
76//! let result = 5.0 + a.clone(); // Scalar + tensor
77//! let result = a.clone() * 3.0; // Tensor * scalar
78//! let result = 3.0 * a.clone(); // Scalar * tensor
79//!
80//! // Compound expressions
81//! let result = (a.clone() + b.clone()) * 2.0 - 1.0; // Complex mathematical expressions
82//!
83//! // Assignment operators
84//! let mut c = a.clone();
85//! c += b.clone(); // In-place addition
86//! c *= 2.0; // In-place scalar multiplication
87//!
88//! // Negation
89//! let result = -a; // Negate all elements
90//! ```
91//!
92//! # Thread Safety
93//!
94//! All tensor operations are thread-safe and implement `Send + Sync`. Tensors can be
95//! safely shared between threads for concurrent read access. Write operations should
96//! be synchronized externally if multiple threads need to modify the same tensor.
97//!
98//! # Design Principles
99//!
100//! - **Performance First**: Every design decision optimized for speed
101//! - **Memory Safety**: RAII patterns with justified unsafe usage
102//! - **Zero Dependencies**: Only standard library dependencies
103//! - **SIMD Ready**: Optimized for vectorized operations
104//! - **Future Proof**: Foundation for advanced ML operations
105//! - **Natural API**: Operator overloading for intuitive mathematical expressions
106
107pub mod allocation;
108pub mod operators;
109pub mod serialization;
110pub mod shape;
111pub mod utils;
112
113use std::alloc::{dealloc, Layout};
114use std::marker::PhantomData;
115use std::ptr::NonNull;
116use std::sync::atomic::AtomicUsize;
117use std::sync::Arc;
118
119use crate::device::Device;
120use crate::gradtrack::GradFn;
121
122pub use allocation::Allocation;
123pub use shape::{MemoryLayout, Shape};
124
125// Note: Prefetching functions are now in ops/add.rs where they're used
126
/// Global counter for unique tensor IDs
///
/// Provides thread-safe, unique identifiers for tensor gradtrack tracking.
/// Uses atomic operations to ensure uniqueness across concurrent tensor creation.
/// Starts at 1; presumably 0 is reserved as a "no id" sentinel — TODO confirm
/// against the gradtrack engine.
static TENSOR_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);
132
133/// High-performance multi-dimensional tensor with automatic differentiation support
134///
135/// The core data structure for machine learning operations, designed for maximum
136/// performance with zero-cost abstractions. Supports arbitrary dimensionality,
137/// SIMD optimization, gradient tracking, device placement, and natural mathematical
138/// expressions through operator overloading.
139///
140/// # Key Features
141///
142/// - **Raw Pointer Storage**: Zero-overhead memory access for maximum performance
143/// - **SIMD Optimization**: AVX2 alignment and vectorized operations
144/// - **Memory Efficiency**: Optimized alignment strategies for different tensor sizes
145/// - **gradtrack Integration**: Built-in gradient tracking and computation
146/// - **Device Support**: CPU and future CUDA device placement
147/// - **View Tensors**: Zero-copy tensor views with shared memory management
148/// - **Thread Safety**: Send + Sync implementation for concurrent usage
149/// - **Operator Overloading**: Natural mathematical expressions (+, -, *, /, +=, -=, *=, /=)
150///
151/// # Memory Layout
152///
153/// Tensors use row-major memory layout with size-dependent alignment:
154/// - **Small tensors** (≤8 elements): 16-byte SSE alignment
155/// - **Medium tensors** (8-1024 elements): 32-byte AVX2 alignment
156/// - **Large tensors** (>1024 elements): 64-byte cache-line alignment
157///
158/// # Performance Characteristics
159///
160/// - **Memory Overhead**: ~64 bytes per tensor (excluding data)
161/// - **SIMD Ready**: Properly aligned for vectorized operations
162/// - **Cache Friendly**: Optimized memory layout for CPU cache hierarchies
163/// - **Zero-Cost Views**: View tensors share memory without copying
164/// - **Thread Safe**: Atomic ID generation and lock-free operations
165/// - **Operator Performance**: Zero-cost operator overloading for mathematical expressions
166///
167/// # Safety
168///
169/// This struct uses unsafe code for performance. The following invariants must be maintained:
170/// - `data` must be valid for `shape.size` elements
171/// - `data` must be properly aligned for `f32`
172/// - `data` must not be aliased while the tensor exists
173/// - `shape.size` must match the actual allocated memory
174/// - `allocation_owner` must be valid if present
175///
176/// # Examples
177///
178/// ## Basic Tensor Operations
179///
180/// ```
181/// use train_station::Tensor;
182///
183/// // Create tensors with different configurations
184/// let tensor = Tensor::new(vec![2, 3]);
185/// let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();
186///
187/// // Access tensor properties
188/// assert_eq!(tensor.size(), 6);
189/// assert_eq!(tensor.shape().dims, vec![2, 3]);
190/// assert!(tensor.is_contiguous());
191/// ```
192///
193/// ## Operator Overloading
194///
195/// ```
196/// use train_station::Tensor;
197///
198/// // Create tensors for operations
199/// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
200/// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
201///
202/// // Tensor operations with operators
203/// let result = a.clone() + b.clone(); // Tensor addition
204/// let result = a.clone() * b.clone(); // Element-wise multiplication
205/// let result = a.clone() - b.clone(); // Tensor subtraction
206/// let result = a.clone() / b.clone(); // Element-wise division
207///
208/// // Scalar operations
209/// let result = a.clone() + 5.0; // Tensor + scalar
210/// let result = 5.0 + a.clone(); // Scalar + tensor
211/// let result = a.clone() * 3.0; // Tensor * scalar
212/// let result = 3.0 * a.clone(); // Scalar * tensor
213///
214/// // Compound expressions
215/// let result = (a.clone() + b.clone()) * 2.0 - 1.0; // Complex mathematical expressions
216///
217/// // Assignment operators
218/// let mut c = a.clone();
219/// c += b.clone(); // In-place addition
220/// c *= 2.0; // In-place scalar multiplication
221///
222/// // Negation
223/// let result = -a; // Negate all elements
224/// ```
225///
226/// # Thread Safety
227///
228/// This type is `Send + Sync` and can be safely shared between threads.
229/// All operations are thread-safe through atomic ID generation and
230/// thread-local gradtrack storage.
231pub struct Tensor {
232 /// Raw pointer to the tensor data in memory
233 ///
234 /// Provides zero-overhead access to tensor elements for maximum performance.
235 /// The pointer is guaranteed to be valid for `shape.size` elements and properly
236 /// aligned for SIMD operations. This field enables direct memory access without
237 /// bounds checking overhead.
238 ///
239 /// # Safety
240 ///
241 /// - Must be valid for `shape.size` elements
242 /// - Must be properly aligned for `f32` operations
243 /// - Must not be aliased while tensor exists
244 data: NonNull<f32>,
245
246 /// The shape and dimensional information of the tensor
247 ///
248 /// Contains the dimensions, size, strides, and memory layout information.
249 /// This field determines how the raw data is interpreted as a multi-dimensional
250 /// tensor and enables efficient memory access patterns.
251 shape: Shape,
252
253 /// Device where this tensor is located (CPU/GPU)
254 ///
255 /// Determines the physical location of the tensor data and which operations
256 /// can be performed on it. Currently supports CPU with future CUDA support.
257 device: Device,
258
259 /// Unique identifier for gradtrack tracking
260 ///
261 /// Thread-safe, globally unique ID used by the gradtrack system to track
262 /// tensor operations and gradient computation. Generated atomically to
263 /// ensure uniqueness across concurrent tensor creation.
264 id: usize,
265
266 /// Whether this tensor requires gradient computation
267 ///
268 /// Controls whether the gradtrack system tracks operations on this tensor
269 /// and computes gradients during backward pass. When `true`, operations
270 /// are recorded in the computation graph for gradient propagation.
271 requires_grad: bool,
272
273 /// Accumulated gradients from backward pass
274 ///
275 /// Stores the computed gradients for this tensor after calling `backward()`.
276 /// `None` if `requires_grad=false` or no gradients have been computed yet.
277 /// Uses `Arc` for efficient sharing between view tensors.
278 grad: Option<Arc<Tensor>>,
279
280 /// Gradient function for gradtrack computation
281 ///
282 /// Records the operation that created this tensor for gradient computation
283 /// during backward pass. Contains the necessary information to compute
284 /// gradients with respect to input tensors.
285 grad_fn: GradFn,
286
287 /// Shared allocation owner for view tensors
288 ///
289 /// Enables zero-copy tensor views by sharing memory allocation between
290 /// multiple tensors. `None` for tensors that own their memory directly.
291 /// Uses `Arc` for thread-safe reference counting and automatic cleanup.
292 allocation_owner: Option<std::sync::Arc<Allocation>>,
293
294 /// Phantom data to ensure proper lifetime management
295 ///
296 /// Ensures the tensor has the correct lifetime parameters for the `f32`
297 /// data type. This prevents lifetime issues when working with raw pointers.
298 _phantom: PhantomData<f32>,
299}
300
// Make Tensor Send + Sync for thread-safe usage
//
// Safety: The raw pointer is properly managed through RAII patterns and
// the data is not shared between threads without proper synchronization.
// All tensor operations are thread-safe through atomic ID generation and
// thread-local gradtrack storage.
//
// NOTE(review): `Sync` soundness relies on callers never mutating through a
// shared `&Tensor`; any interior-mutation paths elsewhere in the crate should
// be audited against this claim — confirm.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
309
impl Drop for Tensor {
    /// Frees the tensor's memory when it goes out of scope
    ///
    /// This ensures proper cleanup and prevents memory leaks.
    /// View tensors (those holding an `allocation_owner`) do not free here;
    /// the shared `Allocation` releases the buffer when its last owner drops.
    /// Zero-sized tensors never allocated a buffer, so there is nothing to free.
    fn drop(&mut self) {
        // If we have a shared allocation owner, memory will be freed when last owner drops.
        if self.allocation_owner.is_none() && self.shape.size > 0 {
            // SAFETY: with no `allocation_owner` this tensor is the sole owner of
            // its buffer, so the pointer is freed exactly once.
            //
            // NOTE(review): the alignment here is hard-coded to 32, but the
            // module docs describe size-dependent alignment (16/32/64 bytes).
            // If the allocation path actually uses a different alignment for
            // small or large tensors, this `Layout` would not match the one
            // used at allocation time (undefined behavior with the global
            // allocator) — confirm against the allocation code.
            unsafe {
                let layout =
                    Layout::from_size_align(self.shape.size * std::mem::size_of::<f32>(), 32)
                        .expect("Failed to create layout for deallocation");
                dealloc(self.data.as_ptr() as *mut u8, layout);
            }
        }
    }
}
326
327impl std::fmt::Debug for Tensor {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 f.debug_struct("Tensor")
330 .field("shape", &self.shape)
331 .field("size", &self.size())
332 .field("id", &self.id)
333 .field("requires_grad", &self.requires_grad)
334 .field("has_grad", &self.grad.is_some())
335 .field("has_grad_fn", &!matches!(self.grad_fn, GradFn::None))
336 .finish()
337 }
338}
339
340/// Clone implementation for Tensor
341///
342/// Creates a deep copy of the tensor data but resets gradtrack state
343/// (new tensor won't track gradients unless explicitly set)
344impl Clone for Tensor {
345 fn clone(&self) -> Self {
346 // Fast path for contiguous tensors: direct linear copy
347 if self.is_contiguous() || self.size() == 0 {
348 let mut cloned = Self::new(self.shape.dims.clone());
349 unsafe {
350 let src = self.as_ptr();
351 let dst = cloned.as_mut_ptr();
352 std::ptr::copy_nonoverlapping(src, dst, self.size());
353 }
354 return cloned;
355 }
356
357 // Non-contiguous view: materialize into a contiguous copy respecting strides
358 let mut result = Tensor::new(self.shape().dims.clone());
359 let rank = self.shape().rank();
360 unsafe {
361 let dst_ptr = result.as_mut_ptr();
362 for dst_idx in 0..result.size() {
363 // Compute destination coordinates under contiguous strides
364 let mut coords = vec![0usize; rank];
365 let mut tmp = dst_idx;
366 for i in (0..rank).rev() {
367 let dim_size = self.shape().dims[i];
368 coords[i] = tmp % dim_size;
369 tmp /= dim_size;
370 }
371 let src_off = self.shape().offset(&coords);
372 *dst_ptr.add(dst_idx) = *self.as_ptr().add(src_off);
373 }
374 }
375
376 result
377 }
378}
379
#[cfg(test)]
mod tests {
    //! Core tensor functionality tests
    //!
    //! Exercises tensor creation, memory layout queries, broadcasting
    //! compatibility, device placement, and safe data-access helpers,
    //! including edge cases such as zero-sized and scalar tensors.

    use super::*;

    /// A freshly created 3-D tensor reports the expected size and rank.
    #[test]
    fn test_tensor_creation() {
        let t = Tensor::new(vec![2, 3, 4]);
        assert_eq!(t.size(), 24);
        assert_eq!(t.shape().rank(), 3);
    }

    /// One-dimensional tensors have rank 1 and size equal to their length.
    #[test]
    fn test_tensor_1d() {
        let t = Tensor::new(vec![10]);
        assert_eq!(t.size(), 10);
        assert_eq!(t.shape().rank(), 1);
    }

    /// Two-dimensional tensors have rank 2 and size rows * cols.
    #[test]
    fn test_tensor_2d() {
        let t = Tensor::new(vec![3, 4]);
        assert_eq!(t.size(), 12);
        assert_eq!(t.shape().rank(), 2);
    }

    /// A dimension of zero yields an empty tensor.
    #[test]
    fn test_zero_sized_tensor() {
        let t = Tensor::new(vec![0]);
        assert_eq!(t.size(), 0);
    }

    /// Contiguity, strides, offsets, alignment, and footprint are all
    /// reported correctly for a freshly allocated row-major tensor.
    #[test]
    fn test_memory_layout_api() {
        let t = Tensor::new(vec![2, 3, 4]);

        // Freshly allocated tensors own their memory and are contiguous.
        assert!(t.is_contiguous());
        assert!(!t.is_view());

        // Row-major strides for shape [2, 3, 4].
        assert_eq!(t.strides(), &[12, 4, 1]);
        assert_eq!(t.stride(0), 12);
        assert_eq!(t.stride(1), 4);
        assert_eq!(t.stride(2), 1);

        // Linear offset is the dot product of coordinates and strides.
        assert_eq!(t.memory_offset(&[0, 0, 0]), 0);
        assert_eq!(t.memory_offset(&[1, 2, 3]), 12 + 8 + 3);

        // Buffer is aligned for AVX2-width SIMD access.
        assert!(t.is_simd_aligned());
        assert_eq!(t.memory_alignment(), 32);

        // 24 f32 elements at 4 bytes each.
        assert_eq!(t.memory_footprint(), 24 * 4);
    }

    /// Broadcasting accepts dimensions that match or are 1 on either side.
    #[test]
    fn test_broadcasting_compatibility() {
        let a = Tensor::new(vec![2, 3, 4]);
        let b = Tensor::new(vec![1, 3, 4]);
        let c = Tensor::new(vec![4]);
        let d = Tensor::new(vec![2, 1, 4]);
        let e = Tensor::new(vec![2, 2, 4]);

        assert!(a.is_broadcastable_with(&b));
        assert!(a.is_broadcastable_with(&c));
        assert!(a.is_broadcastable_with(&d));
        // Middle dimensions 3 vs 2 conflict and neither is 1.
        assert!(!a.is_broadcastable_with(&e));
    }

    /// Default tensor creation places data on the CPU device.
    #[test]
    fn test_tensor_device_cpu() {
        use crate::device::Device;

        let t = Tensor::new(vec![2, 3]);
        assert_eq!(t.device(), Device::cpu());
        assert!(t.device().is_cpu());
        assert!(!t.device().is_cuda());
    }

    /// Explicit CPU placement via `new_on_device` works as expected.
    #[test]
    fn test_tensor_new_on_device_cpu() {
        use crate::device::Device;

        let t = Tensor::new_on_device(vec![2, 3], Device::cpu());
        assert_eq!(t.device(), Device::cpu());
        assert_eq!(t.size(), 6);
    }

    /// Requesting a CUDA device without the feature enabled must panic.
    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_tensor_new_on_cuda_panics() {
        use crate::device::Device;

        // This should panic since CUDA feature is not enabled
        // The panic occurs when trying to create the CUDA device
        Device::cuda(0);
    }

    /// Tensors created inside a device context inherit that device.
    #[test]
    fn test_device_context_integration() {
        use crate::device::{with_device, Device};

        // Outside any context, tensors default to the CPU device.
        let outside = Tensor::new(vec![2]);
        assert_eq!(outside.device(), Device::cpu());

        with_device(Device::cpu(), || {
            let inside = Tensor::new(vec![3]);
            assert_eq!(inside.device(), Device::cpu());
        });
    }

    /// Zero-sized tensors still carry a valid device.
    #[test]
    fn test_device_zero_sized_tensor() {
        use crate::device::Device;

        let t = Tensor::new_on_device(vec![0], Device::cpu());
        assert_eq!(t.device(), Device::cpu());
        assert_eq!(t.size(), 0);
    }

    /// Test data() and data_mut() methods for safe tensor data access
    #[test]
    fn test_data_access_methods() {
        // Read access via data()
        let src = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let view = src.data();

        assert_eq!(view.len(), 4);
        assert_eq!(view, &[1.0, 2.0, 3.0, 4.0]);

        // Write access via data_mut()
        let mut target = Tensor::new(vec![2, 2]);
        let buf = target.data_mut();
        buf[0] = 10.0;
        buf[1] = 20.0;
        buf[2] = 30.0;
        buf[3] = 40.0;

        // Writes through the mutable slice are visible via indexed access.
        assert_eq!(target.get(&[0, 0]), 10.0);
        assert_eq!(target.get(&[0, 1]), 20.0);
        assert_eq!(target.get(&[1, 0]), 30.0);
        assert_eq!(target.get(&[1, 1]), 40.0);

        // Zero-sized tensors expose empty slices through both accessors.
        let empty = Tensor::new(vec![0]);
        assert_eq!(empty.data().len(), 0);

        let mut empty_mut = Tensor::new(vec![0]);
        assert_eq!(empty_mut.data_mut().len(), 0);
    }

    /// Test data() method with standard library operations
    #[test]
    fn test_data_with_std_operations() {
        let t = Tensor::from_slice(&[1.0, -2.0, 3.0, -4.0, 5.0], vec![5]).unwrap();
        let slice = t.data();

        // The returned slice supports ordinary iterator adaptors.
        let total: f32 = slice.iter().sum();
        assert_eq!(total, 3.0);

        let largest = slice.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
        assert_eq!(largest, 5.0);

        let positives = slice.iter().filter(|&&x| x > 0.0).count();
        assert_eq!(positives, 3);

        // Plain indexing works as well.
        assert_eq!(slice[0], 1.0);
        assert_eq!(slice[4], 5.0);
    }

    /// Test value() method for scalar tensor access
    #[test]
    fn test_value_method() {
        // A single-element 1-D tensor.
        let s1 = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        assert_eq!(s1.value(), 42.0);

        // Higher-rank shapes still qualify as long as total size is 1.
        let s2 = Tensor::from_slice(&[std::f32::consts::PI], vec![1, 1]).unwrap();
        assert_eq!(s2.value(), std::f32::consts::PI);

        let s3 = Tensor::from_slice(&[-1.5], vec![1, 1, 1]).unwrap();
        assert_eq!(s3.value(), -1.5);

        // Elements yielded by iteration are single-element tensors too.
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let first = t.iter().next().unwrap();
        assert_eq!(first.value(), 1.0);
        assert_eq!(first.shape().dims, vec![1]);
        assert_eq!(first.size(), 1);
    }

    /// Test value() method error handling
    #[test]
    #[should_panic(expected = "value() can only be called on tensors with exactly one element")]
    fn test_value_method_panics_on_multi_element() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = t.value(); // Should panic
    }

    /// Test value() method with empty tensor
    #[test]
    #[should_panic(expected = "value() can only be called on tensors with exactly one element")]
    fn test_value_method_panics_on_empty() {
        let empty = Tensor::new(vec![0]);
        let _ = empty.value(); // Should panic
    }
}
607}