// train_station/tensor/core/utils.rs
//! Tensor utility functions and core implementation methods
//!
//! This module provides essential utility functions for tensor creation, memory management,
//! gradient tracking, and optimization. It contains the core implementation methods
//! that enable efficient tensor operations and gradtrack functionality.
//!
//! # Key Features
//!
//! - **Tensor Creation**: Optimized constructors with memory alignment
//! - **Memory Management**: Safe memory access and allocation utilities
//! - **Gradient Tracking**: GradTrack system integration and gradient management
//! - **Performance Optimization**: SIMD-ready memory layout and alignment
//! - **Device Management**: CPU and future CUDA device support
//! - **Memory Layout**: Contiguous, strided, and view memory access patterns
//!
//! # Performance Characteristics
//!
//! - **Memory Alignment**: 16-byte SSE, 32-byte AVX2, 64-byte cache-line alignment
//! - **SIMD Optimization**: Properly aligned memory for vectorized operations
//! - **Zero-Cost Abstractions**: Minimal overhead for utility operations
//! - **Thread Safety**: Atomic operations for gradient tracking and ID generation
//! - **Memory Efficiency**: Optimized allocation strategies for different tensor sizes
//!
//! # Examples
//!
//! ## Basic Tensor Creation
//!
//! ```
//! use train_station::Tensor;
//!
//! // Create tensors of different sizes
//! let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
//! let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
//! let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
//!
//! // Initialize data before use
//! let mut tensor = Tensor::new(vec![2, 3]);
//! tensor.fill(0.0); // Initialize with zeros
//! ```
//!
//! ## Gradient Tracking
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
//! assert!(tensor.requires_grad());
//! ```
//!
//! ## Memory Access
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let value = tensor.get(&[0, 1]);
//! assert_eq!(value, 2.0);
//!
//! let mut tensor = Tensor::new(vec![2, 2]);
//! tensor.set(&[0, 1], 42.0);
//! assert_eq!(tensor.get(&[0, 1]), 42.0);
//! ```
64use std::alloc::Layout;
65use std::marker::PhantomData;
66use std::ptr::NonNull;
67use std::sync::atomic::Ordering;
68use std::sync::Arc;
69
70use crate::device::current_device;
71use crate::gradtrack::engine::ensure_local_group_for_tensor;
72use crate::gradtrack::{self, GradEngine, GradFn};
73use crate::tensor::core::memory::{
74 compute_allocation_params, detect_runtime_simd, simd_lane_width_elems, use_pool_alloc_enabled,
75 SimdLevel,
76};
77use crate::tensor::core::{Allocation, Device, TENSOR_ID_COUNTER};
78use crate::tensor::Shape;
79
80use super::Tensor;
81
82impl Tensor {
    /// Ensures that this tensor has unique ownership of its allocation before mutation.
    ///
    /// If the underlying allocation is shared (i.e., multiple views), this performs
    /// a copy-on-write: allocates a new buffer and copies data so that mutating this tensor
    /// does not affect existing views.
    ///
    /// # Implementation Details
    ///
    /// Sharing is detected via `Arc::strong_count` on the allocation owner; a count
    /// greater than 1 means at least one other tensor/view holds the same buffer.
    /// The replacement buffer uses the same alignment/padding policy as the
    /// constructors (`compute_allocation_params`), and only `self.size()` elements
    /// are copied — any padding tail is left uninitialized.
    ///
    /// NOTE(review): the copy assumes `self.as_ptr()` addresses `self.size()`
    /// contiguous elements; confirm callers never invoke this on non-contiguous views.
    fn ensure_unique_allocation(&mut self) {
        // Zero-sized tensors have no allocation (dangling pointer); nothing to do.
        if self.size() == 0 {
            return;
        }
        if let Some(owner) = self.allocation_owner.as_ref() {
            if std::sync::Arc::strong_count(owner) > 1 {
                // Allocate a new buffer with the same size/alignment policy as constructors
                let (alignment, padded_elems) = compute_allocation_params(self.shape.size());
                let total_size = padded_elems * std::mem::size_of::<f32>();
                let layout = Layout::from_size_align(total_size, alignment)
                    .expect("Failed to create layout for tensor data");

                let alloc_obj = if use_pool_alloc_enabled() {
                    Allocation::new_pooled(padded_elems, alignment, layout)
                } else {
                    Allocation::new_uninitialized(padded_elems, alignment, layout)
                };

                let new_ptr = alloc_obj.ptr;
                // SAFETY: source and destination are distinct allocations, each valid
                // for at least `self.size()` f32 elements, so they cannot overlap.
                unsafe {
                    std::ptr::copy_nonoverlapping(self.as_ptr(), new_ptr.as_ptr(), self.size());
                }

                // Replace pointer and owner with the new unique allocation
                self.data = new_ptr;
                self.allocation_owner = Some(std::sync::Arc::new(alloc_obj));
            }
        }
    }
117 // ===== Runtime SIMD capability helpers for ops selection =====
118
119 /// Returns the highest runtime-detected SIMD level on this CPU
120 #[inline]
121 pub(crate) fn simd_runtime_level() -> SimdLevel {
122 detect_runtime_simd()
123 }
124
125 /// Returns the SIMD lane width (elements per vector) for f32 at runtime
126 #[inline]
127 pub(crate) fn simd_lane_width_elems_runtime() -> usize {
128 simd_lane_width_elems(Self::simd_runtime_level())
129 }
130
131 /// Checks whether this tensor's data pointer is aligned for the specified SIMD level
132 #[inline]
133 #[cfg(test)]
134 #[cfg(target_arch = "x86_64")]
135 pub(crate) fn is_aligned_for_level(&self, level: SimdLevel) -> bool {
136 use crate::tensor::core::memory::simd_alignment_bytes;
137
138 let required = simd_alignment_bytes(level);
139 (self.data.as_ptr() as usize).is_multiple_of(required)
140 }
141
142 /// Returns true if this tensor is aligned for the current runtime SIMD level
143 #[inline]
144 #[cfg(test)]
145 #[cfg(target_arch = "x86_64")]
146 pub(crate) fn is_aligned_for_runtime_level(&self) -> bool {
147 self.is_aligned_for_level(Self::simd_runtime_level())
148 }
149
150 /// Returns true if ops should use aligned SIMD loads/stores without a tail branch
151 /// (i.e., pointer is aligned for the runtime SIMD level and length is a multiple of lane)
152 #[inline]
153 #[cfg(test)]
154 #[cfg(target_arch = "x86_64")]
155 pub(crate) fn prefer_aligned_simd_ops(&self) -> bool {
156 let lane = Self::simd_lane_width_elems_runtime();
157 self.is_aligned_for_runtime_level() && self.size().is_multiple_of(lane)
158 }
159
160 /// Returns true if ops should use unaligned SIMD loads/stores (mm_loadu)
161 /// because pointer alignment or length does not meet aligned requirements.
162 /// Returns false when no SIMD is available.
163 #[inline]
164 #[cfg(target_arch = "x86_64")]
165 #[cfg(test)]
166 pub(crate) fn should_use_unaligned_simd_ops(&self) -> bool {
167 match Self::simd_runtime_level() {
168 SimdLevel::Scalar => false,
169 _ => !self.prefer_aligned_simd_ops(),
170 }
171 }
172
173 /// Returns the allocated capacity in elements, which may be padded beyond logical size
174 #[inline]
175 pub fn capacity_elems(&self) -> usize {
176 if let Some(owner) = self.allocation_owner() {
177 owner.capacity_elems()
178 } else {
179 self.size()
180 }
181 }
182 /// Creates a new tensor with the specified shape and optimized memory layout
183 ///
184 /// Allocates memory with size-dependent alignment for optimal performance:
185 /// - Small tensors (≤8 elements): 16-byte SSE alignment
186 /// - Medium tensors (8-1024 elements): 32-byte AVX2 alignment
187 /// - Large tensors (>1024 elements): 64-byte cache-line alignment
188 ///
189 /// # Arguments
190 ///
191 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
192 ///
193 /// # Returns
194 ///
195 /// A new tensor with uninitialized data. The data must be initialized
196 /// before use to avoid undefined behavior.
197 ///
198 /// # Performance
199 ///
200 /// - **Memory Allocation**: Single allocation with optimized alignment
201 /// - **SIMD Ready**: Properly aligned for vectorized operations
202 /// - **Cache Friendly**: Optimized for CPU cache hierarchies
203 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
204 ///
205 /// # Safety
206 ///
207 /// The returned tensor contains uninitialized memory. You must initialize
208 /// the data before performing any operations that read from it.
209 ///
210 /// # Examples
211 ///
212 /// ```
213 /// use train_station::Tensor;
214 ///
215 /// // Create tensors of different sizes
216 /// let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
217 /// let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
218 /// let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
219 ///
220 /// // Initialize data before use
221 /// let mut tensor = Tensor::new(vec![2, 3]);
222 /// tensor.fill(0.0); // Initialize with zeros
223 /// ```
224 #[inline]
225 #[track_caller]
226 pub fn new(shape_dims: Vec<usize>) -> Self {
227 let shape = Shape::new(shape_dims);
228 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
229
230 if shape.size() == 0 {
231 // Handle zero-sized tensors
232 return Self {
233 data: NonNull::dangling(),
234 shape,
235 device: current_device(),
236 id,
237 requires_grad: false,
238 retain_grad: false,
239 grad: None,
240 grad_fn: GradFn::None,
241 allocation_owner: None,
242 graph_group: None,
243 _phantom: PhantomData,
244 };
245 }
246
247 // Compute alignment and padded element count based on runtime SIMD
248 let (alignment, padded_elems) = compute_allocation_params(shape.size());
249 let total_size = padded_elems * std::mem::size_of::<f32>();
250
251 let layout = Layout::from_size_align(total_size, alignment)
252 .expect("Failed to create layout for tensor data");
253
254 // Allocation policy: prefer pool for tiny tensors and medium/large classes.
255 // Route certain small sizes to system allocator to keep pool stats semantics used by tests.
256 let alloc_obj = if use_pool_alloc_enabled() {
257 Allocation::new_pooled(padded_elems, alignment, layout)
258 } else {
259 Allocation::new(padded_elems, alignment, layout)
260 };
261 let ptr = alloc_obj.ptr;
262
263 debug_assert!(
264 alloc_obj.capacity_elems() >= padded_elems,
265 "Allocation capacity ({}) smaller than padded elements ({})",
266 alloc_obj.capacity_elems(),
267 padded_elems
268 );
269
270 Self {
271 data: ptr,
272 shape,
273 device: current_device(),
274 id,
275 requires_grad: false,
276 retain_grad: false,
277 grad: None,
278 grad_fn: GradFn::None,
279 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
280 graph_group: None,
281 _phantom: PhantomData,
282 }
283 }
284
285 /// Returns the shape and dimensional information of the tensor
286 ///
287 /// Provides access to the tensor's dimensions, size, strides, and memory
288 /// layout information. This is used for shape validation, memory access
289 /// calculations, and optimization decisions.
290 ///
291 /// # Returns
292 ///
293 /// Reference to the tensor's shape information containing dimensions,
294 /// size, strides, and memory layout type.
295 ///
296 /// # Performance
297 ///
298 /// - **Time Complexity**: O(1) - direct field access
299 /// - **Memory**: No allocation - returns reference to existing data
300 ///
301 /// # Examples
302 ///
303 /// ```
304 /// use train_station::Tensor;
305 ///
306 /// let tensor = Tensor::new(vec![2, 3, 4]);
307 /// let shape = tensor.shape();
308 /// assert_eq!(shape.dims(), vec![2, 3, 4]);
309 /// assert_eq!(shape.size(), 24);
310 /// assert_eq!(shape.rank(), 3);
311 /// ```
312 #[inline]
313 #[track_caller]
314 pub fn shape(&self) -> &Shape {
315 &self.shape
316 }
317
318 /// Returns the total number of elements in the tensor
319 ///
320 /// Provides the total count of elements across all dimensions. This is
321 /// used for memory allocation, iteration bounds, and performance optimization.
322 ///
323 /// # Returns
324 ///
325 /// Total number of elements as `usize`
326 ///
327 /// # Performance
328 ///
329 /// - **Time Complexity**: O(1) - direct field access
330 /// - **Memory**: No allocation - returns stored value
331 ///
332 /// # Examples
333 ///
334 /// ```
335 /// use train_station::Tensor;
336 ///
337 /// let tensor = Tensor::new(vec![2, 3, 4]);
338 /// assert_eq!(tensor.size(), 24); // 2 * 3 * 4
339 ///
340 /// let scalar = Tensor::new(vec![1]);
341 /// assert_eq!(scalar.size(), 1);
342 ///
343 /// let empty = Tensor::new(vec![0]);
344 /// assert_eq!(empty.size(), 0);
345 /// ```
346 #[inline]
347 #[track_caller]
348 pub fn size(&self) -> usize {
349 self.shape().size()
350 }
351
352 /// Returns the device where this tensor is located
353 ///
354 /// Provides the physical location of the tensor data (CPU/GPU). This
355 /// determines which operations can be performed on the tensor and where
356 /// computations will be executed.
357 ///
358 /// # Returns
359 ///
360 /// Device enum indicating the tensor's physical location
361 ///
362 /// # Performance
363 ///
364 /// - **Time Complexity**: O(1) - direct field access
365 /// - **Memory**: No allocation - returns stored value
366 ///
367 /// # Examples
368 ///
369 /// ```
370 /// use train_station::Tensor;
371 ///
372 /// let tensor = Tensor::new(vec![2, 3]);
373 /// assert!(tensor.device().is_cpu());
374 /// assert!(!tensor.device().is_cuda());
375 /// ```
376 #[inline]
377 #[track_caller]
378 pub fn device(&self) -> Device {
379 self.device
380 }
381
    /// Creates a new tensor with the specified shape on a specific device
    ///
    /// Allocates memory on the specified device with the same optimized alignment
    /// strategy as `new()`. Currently supports CPU device with future CUDA support.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    /// * `device` - The device where the tensor should be allocated
    ///
    /// # Returns
    ///
    /// A new tensor with uninitialized data on the specified device
    ///
    /// # Panics
    ///
    /// Panics if the specified device is not supported (currently only CPU is supported)
    ///
    /// # Performance
    ///
    /// - **Memory Allocation**: Single allocation with optimized alignment
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, Device};
    ///
    /// // Create tensor on CPU device
    /// let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
    /// assert_eq!(tensor.device(), Device::cpu());
    /// assert_eq!(tensor.size(), 6);
    /// ```
    #[track_caller]
    pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self {
        // For now, only CPU is supported
        if !device.is_cpu() {
            panic!("Only CPU device is currently supported. CUDA support is planned for future releases.");
        }

        let shape = Shape::new(shape_dims);
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size() == 0 {
            // Zero-sized tensors carry no allocation; the dangling pointer is never read.
            return Self {
                data: NonNull::dangling(),
                shape,
                device,
                id,
                requires_grad: false,
                retain_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                graph_group: None,
                _phantom: PhantomData,
            };
        }

        let (alignment, padded_elems) = compute_allocation_params(shape.size());
        let total_size = padded_elems * std::mem::size_of::<f32>();
        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Same allocation policy as new(): prefer pool for tiny and medium/large classes
        let alloc_obj = if use_pool_alloc_enabled() {
            Allocation::new_pooled(padded_elems, alignment, layout)
        } else {
            Allocation::new(padded_elems, alignment, layout)
        };
        let ptr = alloc_obj.ptr;

        debug_assert!(
            alloc_obj.capacity_elems() >= padded_elems,
            "Allocation capacity ({}) smaller than padded elements ({})",
            alloc_obj.capacity_elems(),
            padded_elems
        );

        Self {
            data: ptr,
            shape,
            device,
            id,
            requires_grad: false,
            retain_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            graph_group: None,
            _phantom: PhantomData,
        }
    }
507
508 /// Enable gradient computation for this tensor
509 ///
510 /// Builder method that enables automatic gradient tracking for this tensor.
511 /// When enabled, all operations involving this tensor will be recorded in
512 /// the computation graph for gradient computation during backward pass.
513 ///
514 /// # Returns
515 ///
516 /// `self` with gradient tracking enabled
517 ///
518 /// # Performance
519 ///
520 /// - **Time Complexity**: O(1) - simple field assignment
521 /// - **Memory**: No additional allocation
522 /// - **Overhead**: Minimal gradtrack tracking overhead when gradients computed
523 ///
524 /// # Examples
525 ///
526 /// ```
527 /// use train_station::Tensor;
528 ///
529 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
530 /// assert!(tensor.requires_grad());
531 /// ```
532 #[track_caller]
533 pub fn with_requires_grad(mut self) -> Self {
534 self.requires_grad = true;
535 if self.graph_group.is_none() {
536 self.graph_group = Some(ensure_local_group_for_tensor(self.id()));
537 }
538 self
539 }
540
541 /// Set gradient tracking for this tensor
542 ///
543 /// Controls whether the gradtrack system tracks operations on this tensor
544 /// and computes gradients during backward pass. When disabled, clears
545 /// any existing gradients and gradient functions.
546 ///
547 /// # Arguments
548 ///
549 /// * `requires_grad` - Whether to track gradients for this tensor
550 ///
551 /// # Performance
552 ///
553 /// - **Time Complexity**: O(1) - simple field assignment
554 /// - **Memory**: May free gradient storage when disabled
555 /// - **Overhead**: Zero overhead when gradients disabled
556 ///
557 /// # Examples
558 ///
559 /// ```
560 /// use train_station::Tensor;
561 ///
562 /// let mut tensor = Tensor::ones(vec![2, 3]);
563 /// tensor.set_requires_grad(true);
564 /// assert!(tensor.requires_grad());
565 ///
566 /// // Disable gradient tracking
567 /// tensor.set_requires_grad(false);
568 /// assert!(!tensor.requires_grad());
569 /// ```
570 #[track_caller]
571 pub fn set_requires_grad(&mut self, requires_grad: bool) {
572 self.requires_grad = requires_grad;
573 if !requires_grad {
574 self.grad = None;
575 self.grad_fn = GradFn::None;
576 self.graph_group = None;
577 } else {
578 // Lazily bind to a local graph group for this tensor id
579 if self.graph_group.is_none() {
580 self.graph_group = Some(ensure_local_group_for_tensor(self.id()));
581 }
582 }
583 }
584
585 /// Mark this tensor to retain gradients after backward, even if it is non-leaf.
586 ///
587 /// Builder-style API: returns self with `retain_grad=true`.
588 /// Call `materialize_grad()` or `grad_or_fetch()` after backward to copy the
589 /// accumulated gradient from the GradGraph into `self.grad` so `grad()` works.
590 #[track_caller]
591 pub fn retain_grad(mut self) -> Self {
592 self.retain_grad = true;
593 // Register intent with grad engine to keep gradient available after backward
594 crate::gradtrack::engine::mark_retain_grad(self.id);
595 self
596 }
597
598 /// In-place variant to enable or disable gradient retention for non-leaf tensors
599 #[track_caller]
600 pub fn retain_grad_(&mut self, enable: bool) {
601 self.retain_grad = enable;
602 if enable {
603 crate::gradtrack::engine::mark_retain_grad(self.id);
604 }
605 }
606
607 /// Check if this tensor requires gradients
608 ///
609 /// # Returns
610 ///
611 /// `true` if gradient tracking is enabled for this tensor
612 ///
613 /// # Examples
614 ///
615 /// ```
616 /// use train_station::Tensor;
617 ///
618 /// let tensor = Tensor::new(vec![2, 3]);
619 /// assert!(!tensor.requires_grad());
620 ///
621 /// let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
622 /// assert!(grad_tensor.requires_grad());
623 /// ```
624 #[track_caller]
625 pub fn requires_grad(&self) -> bool {
626 self.requires_grad
627 }
628
629 /// Get the accumulated gradients (if any)
630 ///
631 /// Returns a reference to the gradient tensor if gradients have been computed
632 /// and this tensor has gradient tracking enabled.
633 ///
634 /// # Returns
635 ///
636 /// Optional reference to the gradient tensor, or `None` if no gradients exist
637 ///
638 /// # Examples
639 ///
640 /// ```
641 /// use train_station::Tensor;
642 ///
643 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
644 /// assert!(tensor.grad().is_none()); // No gradients computed yet
645 /// ```
646 #[track_caller]
647 pub fn grad(&self) -> Option<&Tensor> {
648 // First check if we have a gradient stored directly
649 if let Some(grad) = self.grad.as_ref() {
650 return Some(grad.as_ref());
651 }
652
653 None
654 }
655
656 /// Fetch the accumulated gradient after backward and cache it on this tensor if `retain_grad` is enabled.
657 ///
658 /// Returns true if a gradient was found and cached. After a successful call,
659 /// `grad()` will return `Some(&Tensor)` even for non-leaf tensors.
660 #[track_caller]
661 pub fn materialize_grad(&mut self) -> bool {
662 if self.grad.is_some() {
663 return true;
664 }
665 if !self.retain_grad {
666 return false;
667 }
668 if let Some(g) = self.grad_owned() {
669 self.set_grad(g);
670 return true;
671 }
672 false
673 }
674
675 /// Convenience accessor: if `retain_grad` is enabled, fetch and cache the gradient
676 /// on first access so callers can immediately get a reference.
677 #[track_caller]
678 pub fn grad_or_fetch(&mut self) -> Option<&Tensor> {
679 if self.grad.is_none() && self.retain_grad {
680 if let Some(g) = self.grad_owned() {
681 self.set_grad(g);
682 }
683 }
684 self.grad.as_deref()
685 }
686
687 /// Get the accumulated gradient as an owned tensor
688 ///
689 /// Returns the gradient by value (owned). This complements `grad()` which returns
690 /// a reference. Useful when you need to take ownership of the gradient data
691 /// (e.g., move into another structure or thread).
692 ///
693 /// This function does not clear internal gradient state. If a locally cached
694 /// gradient is not present, it consults shared autograd storage to fetch an
695 /// up-to-date gradient for this tensor ID.
696 ///
697 /// # Returns
698 ///
699 /// `Some(Tensor)` containing the gradient when available, otherwise `None`.
700 ///
701 /// # Examples
702 ///
703 /// ```
704 /// use train_station::Tensor;
705 ///
706 /// // Before backward: no gradient yet
707 /// let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
708 /// assert!(x.grad_owned().is_none());
709 ///
710 /// // Compute a simple loss and backprop
711 /// let mut loss = x.sum();
712 /// loss.backward(None);
713 ///
714 /// // Fetch the gradient by value (owned)
715 /// let g = x.grad_owned().unwrap();
716 /// assert_eq!(g.shape().dims(), vec![2, 3]);
717 /// ```
718 #[track_caller]
719 pub fn grad_owned(&self) -> Option<Tensor> {
720 // First check if we have a gradient stored directly
721 if let Some(grad) = self.grad.as_ref() {
722 return Some((**grad).clone());
723 }
724
725 // Always consult the global/shared gradient storage so gradients accumulated
726 // in a promoted/merged shared group are visible regardless of the calling thread.
727 // This is critical for cross-thread forward/backward validation and parity with LibTorch.
728 use crate::gradtrack;
729 gradtrack::get_accumulated_gradient(self.id)
730 }
731
732 /// Get the unique ID of this tensor
733 ///
734 /// Returns the unique identifier assigned to this tensor during creation.
735 /// This ID is used for gradtrack tracking and tensor identification.
736 ///
737 /// # Returns
738 ///
739 /// Unique tensor ID as `usize`
740 ///
741 /// # Examples
742 ///
743 /// ```
744 /// use train_station::Tensor;
745 ///
746 /// let tensor1 = Tensor::new(vec![2, 3]);
747 /// let tensor2 = Tensor::new(vec![2, 3]);
748 /// assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
749 /// ```
750 #[track_caller]
751 pub fn id(&self) -> usize {
752 self.id
753 }
754
    /// Detach this tensor from the computation graph
    ///
    /// Returns a new tensor with the same data but no gradient tracking
    /// (`Tensor::new` defaults to `requires_grad = false`). Useful for using
    /// a value in inference without affecting the computation graph.
    ///
    /// # Returns
    ///
    /// A new tensor with copied data and gradient tracking disabled
    ///
    /// # Implementation Details
    ///
    /// Copies `self.size()` consecutive elements starting at `self.as_ptr()`.
    /// NOTE(review): this assumes the source data is contiguous — confirm
    /// behavior for strided/view tensors. The result is allocated via
    /// `Self::new`, i.e. on the current device, not necessarily `self.device`.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// let detached = tensor.detach();
    /// assert!(!detached.requires_grad());
    /// assert_eq!(tensor.size(), detached.size());
    /// ```
    #[track_caller]
    pub fn detach(&self) -> Self {
        let mut detached = Self::new(self.shape().dims().to_vec());

        // Copy data
        // SAFETY: both buffers are valid for `self.size()` f32 elements and
        // cannot overlap because the destination was freshly allocated above.
        unsafe {
            let src = self.as_ptr();
            let dst = detached.as_mut_ptr();
            std::ptr::copy_nonoverlapping(src, dst, self.size());
        }

        detached
    }
788
789 /// Create a new tensor that doesn't track gradients from this one
790 ///
791 /// Similar to detach() but modifies this tensor in place. This is useful
792 /// when you want to disable gradient tracking for the current tensor
793 /// without creating a copy.
794 ///
795 /// # Examples
796 ///
797 /// ```
798 /// use train_station::Tensor;
799 ///
800 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
801 /// assert!(tensor.requires_grad());
802 /// tensor.detach_();
803 /// assert!(!tensor.requires_grad());
804 /// ```
805 #[track_caller]
806 pub fn detach_(&mut self) {
807 self.requires_grad = false;
808 self.grad = None;
809 self.grad_fn = GradFn::None;
810 }
811
812 /// Entry point for backward pass on this tensor
813 ///
814 /// Computes gradients for all tensors in the computation graph that have
815 /// `requires_grad` set to true. This is the main entry point for automatic
816 /// differentiation.
817 ///
818 /// # Arguments
819 ///
820 /// * `grad_output` - Optional gradient tensor for the output. If None, assumes
821 /// the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
822 ///
823 /// # Examples
824 ///
825 /// ```
826 /// use train_station::Tensor;
827 ///
828 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
829 /// let mut result = tensor.add_scalar(5.0);
830 /// result.backward(None);
831 /// // Note: Gradient computation depends on the gradtrack system implementation
832 /// ```
833 #[track_caller]
834 pub fn backward(&mut self, grad_output: Option<Tensor>) {
835 GradEngine::backward(self, grad_output);
836 }
837
838 /// Returns a raw pointer to the tensor data for unsafe operations
839 ///
840 /// # Safety
841 ///
842 /// This is unsafe because it provides direct access to the underlying memory.
843 /// The caller must ensure:
844 /// - The tensor is not dropped while the pointer is used
845 /// - No concurrent mutable access occurs
846 /// - Bounds are respected
847 #[inline]
848 pub unsafe fn as_ptr(&self) -> *const f32 {
849 self.data.as_ptr()
850 }
851
852 /// Returns a mutable raw pointer to the tensor data for unsafe operations
853 ///
854 /// # Safety
855 ///
856 /// This is unsafe because it provides direct mutable access to the underlying memory.
857 /// The caller must ensure:
858 /// - The tensor is not dropped while the pointer is used
859 /// - No concurrent access occurs
860 /// - Bounds are respected
861 #[inline]
862 pub unsafe fn as_mut_ptr(&mut self) -> *mut f32 {
863 self.data.as_ptr()
864 }
865
866 /// Internal method to set gradient function (used by operations)
867 ///
868 /// Sets the gradient function for this tensor. This is used internally
869 /// by tensor operations to record the computation graph for gradtrack.
870 ///
871 /// # Arguments
872 ///
873 /// * `grad_fn` - The gradient function to set
874 ///
875 /// # Implementation Details
876 ///
877 /// This method is called by tensor operations to register the gradient
878 /// computation function. It only sets the gradient function if gradient
879 /// tracking is enabled for this tensor.
880 pub(crate) fn set_grad_fn(&mut self, grad_fn: GradFn) {
881 if self.requires_grad {
882 self.grad_fn = grad_fn;
883 }
884 }
885
886 /// Get a reference to the gradient function (for gradtrack)
887 ///
888 /// Returns a reference to the gradient function associated with this tensor.
889 /// This is used internally by the gradtrack system to compute gradients.
890 ///
891 /// # Returns
892 ///
893 /// Reference to the gradient function
894 ///
895 /// # Implementation Details
896 ///
897 /// This method is used by the gradtrack engine to access the gradient
898 /// computation function during backward pass.
899 #[track_caller]
900 pub fn grad_fn(&self) -> &GradFn {
901 &self.grad_fn
902 }
903
904 /// Internal method to set requires_grad (used by gradtrack operations)
905 ///
906 /// Sets the gradient tracking flag for this tensor. This is used internally
907 /// by gradtrack operations to control gradient computation.
908 ///
909 /// # Arguments
910 ///
911 /// * `requires_grad` - Whether to enable gradient tracking
912 ///
913 /// # Implementation Details
914 ///
915 /// This method is used internally by the gradtrack system to control
916 /// gradient tracking without triggering additional side effects.
917 pub(crate) fn set_requires_grad_internal(&mut self, requires_grad: bool) {
918 self.requires_grad = requires_grad;
919 }
920
921 /// Internal method to accumulate gradients with optimized operations
922 ///
923 /// Accumulates a gradient tensor with any existing gradients for this tensor.
924 /// This is used internally by the gradtrack system to handle gradient accumulation
925 /// during backward pass.
926 ///
927 /// # Arguments
928 ///
929 /// * `grad` - The gradient tensor to accumulate
930 ///
931 /// # Implementation Details
932 ///
933 /// This method is used internally by the gradtrack engine to accumulate
934 /// gradients from multiple backward passes or operations. It only accumulates
935 /// gradients if gradient tracking is enabled for this tensor.
936 pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
937 if !self.requires_grad {
938 return;
939 }
940
941 match &self.grad {
942 Some(existing_grad) => {
943 // Use optimized tensor addition but create new tensor for safety
944 let accumulated = existing_grad.add_tensor_optimized(&grad);
945 self.grad = Some(Arc::new(accumulated));
946 }
947 None => {
948 self.grad = Some(Arc::new(grad));
949 }
950 }
951 }
952
953 /// Set gradient from external source
954 ///
955 /// Sets the gradient tensor for this tensor. This is used internally by the
956 /// gradtrack system to set gradients during backward pass.
957 ///
958 /// # Arguments
959 ///
960 /// * `grad` - The gradient tensor to set
961 ///
962 /// # Implementation Details
963 ///
964 /// This method is used internally by the gradtrack engine to set gradients
965 /// during backward pass. It only sets the gradient if gradient tracking is
966 /// enabled for this tensor.
967 #[track_caller]
968 pub fn set_grad(&mut self, grad: Tensor) {
969 if self.requires_grad {
970 self.grad = Some(Arc::new(grad));
971 }
972 }
973
974 /// Clear accumulated gradients for this tensor
975 ///
976 /// This method is used by optimizers to zero gradients before each backward pass.
977 /// It clears any accumulated gradients, allowing for fresh gradient computation.
978 ///
979 /// # Examples
980 ///
981 /// ```
982 /// use train_station::Tensor;
983 ///
984 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
985 /// tensor.set_grad(Tensor::ones(vec![2, 3]));
986 /// assert!(tensor.grad().is_some());
987 /// tensor.zero_grad();
988 /// assert!(tensor.grad().is_none());
989 /// ```
990 #[track_caller]
991 pub fn zero_grad(&mut self) {
992 // Clear locally cached gradient
993 self.grad = None;
994 // Ensure gradient is also cleared from autograd storage so future backward runs
995 // don't observe stale gradients when mixing TLS and shared groups.
996 crate::gradtrack::engine::clear_gradient_for_tensor(self.id);
997 }
998
999 /// Negate all elements in the tensor in-place
1000 ///
1001 /// This is used internally for gradient computation in subtraction operations.
1002 /// For tensor - tensor operations, the second operand gets a negated gradient.
1003 ///
1004 /// # Implementation Details
1005 ///
1006 /// This method is used internally by the gradtrack system to compute gradients
1007 /// for subtraction operations. It uses SIMD optimization when available for
1008 /// better performance.
1009 #[inline]
1010 pub(crate) fn negate_inplace(&mut self) {
1011 if self.shape.size() == 0 {
1012 return;
1013 }
1014
1015 unsafe {
1016 let ptr = self.data.as_ptr();
1017
1018 #[cfg(target_arch = "x86_64")]
1019 {
1020 // Use SIMD for better performance when available
1021 if is_x86_feature_detected!("avx2") {
1022 self.negate_simd_avx2(ptr);
1023 return;
1024 }
1025 }
1026
1027 // Fallback to scalar operations
1028 for i in 0..self.shape.size() {
1029 *ptr.add(i) = -*ptr.add(i);
1030 }
1031 }
1032 }
1033
1034 /// SIMD-optimized negation using AVX2 instructions
1035 ///
1036 /// Performs in-place negation of tensor elements using AVX2 SIMD instructions
1037 /// for improved performance on x86_64 architectures.
1038 ///
1039 /// # Arguments
1040 ///
1041 /// * `ptr` - Raw pointer to the tensor data
1042 ///
1043 /// # Safety
1044 ///
1045 /// The caller must ensure:
1046 /// - `ptr` is a valid pointer to tensor data
1047 /// - The tensor size matches the actual data size
1048 /// - The tensor is not moved or dropped during this operation
1049 ///
1050 /// # Implementation Details
1051 ///
1052 /// This method is used internally by `negate_inplace` when AVX2 is available.
1053 /// It processes 8 elements per iteration using AVX2 instructions.
1054 #[cfg(target_arch = "x86_64")]
1055 #[inline]
1056 unsafe fn negate_simd_avx2(&self, ptr: *mut f32) {
1057 use std::arch::x86_64::_mm256_setzero_ps;
1058
1059 let size = self.shape().size();
1060 let zero_vec = _mm256_setzero_ps();
1061 let simd_count = size / 8; // Process 8 elements per iteration
1062 let mut offset = 0;
1063
1064 // SIMD loop for negation
1065 for _ in 0..simd_count {
1066 use std::arch::x86_64::{_mm256_load_ps, _mm256_store_ps, _mm256_sub_ps};
1067
1068 let vec = _mm256_load_ps(ptr.add(offset));
1069 let neg_vec = _mm256_sub_ps(zero_vec, vec);
1070 _mm256_store_ps(ptr.add(offset), neg_vec);
1071 offset += 8;
1072 }
1073
1074 // Handle remaining elements
1075 for i in offset..size {
1076 *ptr.add(i) = -*ptr.add(i);
1077 }
1078 }
1079
1080 // ===== Memory Layout and Optimization API =====
1081
    /// Checks if the tensor data is stored contiguously in memory
    ///
    /// # Returns
    ///
    /// `true` if the tensor data is contiguous, enabling optimized SIMD operations
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// assert!(tensor.is_contiguous());
    /// ```
    #[inline]
    #[track_caller]
    pub fn is_contiguous(&self) -> bool {
        // Contiguity is a property of the shape/stride metadata; delegate.
        self.shape().is_contiguous()
    }
1101
    /// Gets the memory strides for all dimensions
    ///
    /// # Returns
    ///
    /// Reference to the stride vector for efficient memory access calculations
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// assert_eq!(tensor.strides(), &[12, 4, 1]);
    /// ```
    #[inline]
    #[track_caller]
    pub fn strides(&self) -> &[usize] {
        // Zero-cost borrow of the stride metadata stored on the shape.
        self.shape.strides()
    }
1121
    /// Gets the memory stride for a specific dimension
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension index
    ///
    /// # Returns
    ///
    /// The memory stride for the given dimension
    ///
    /// # Panics
    ///
    /// Panics if `dim` is out of bounds
    #[inline]
    #[track_caller]
    pub fn stride(&self, dim: usize) -> usize {
        // Bounds checking (and the panic) is handled by Shape::stride.
        self.shape.stride(dim)
    }
1140
    /// Calculates the linear memory offset for given multi-dimensional indices
    ///
    /// # Arguments
    ///
    /// * `indices` - Vector of indices for each dimension
    ///
    /// # Returns
    ///
    /// Linear memory offset for direct memory access
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// let offset = tensor.memory_offset(&[1, 2, 3]);
    /// // offset = 1*12 + 2*4 + 3*1 = 23
    /// ```
    #[inline]
    #[track_caller]
    pub fn memory_offset(&self, indices: &[usize]) -> usize {
        // Sum of index * stride per dimension, computed by the shape.
        self.shape.offset(indices)
    }
1165
    /// Broadcast this tensor with another tensor for element-wise operations
    ///
    /// Returns a tuple containing:
    /// - Broadcasted view of self
    /// - Broadcasted view of other
    /// - Result shape for the operation
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to broadcast with
    ///
    /// # Returns
    ///
    /// A tuple `(broadcasted_self, broadcasted_other, result_shape)`
    ///
    /// # Errors
    ///
    /// Returns a `BroadcastError` when the two shapes are not broadcastable
    /// (see `is_broadcastable_with`).
    #[track_caller]
    pub(crate) fn broadcast_with(
        &self,
        other: &Tensor,
    ) -> Result<
        (Tensor, Tensor, crate::tensor::Shape),
        crate::tensor::ops::broadcasting::BroadcastError,
    > {
        // All broadcasting logic lives in the broadcasting module.
        crate::tensor::ops::broadcasting::broadcast_shapes(self, other)
    }
1190
    /// Checks if the tensor data is properly aligned for SIMD operations
    ///
    /// # Returns
    ///
    /// `true` if the tensor data is aligned to 32-byte boundaries for AVX2
    ///
    /// NOTE(review): this tests the tensor's current base pointer, so a view
    /// starting at an element offset may fail this check even when the owning
    /// allocation itself is 32-byte aligned.
    #[inline]
    #[track_caller]
    #[cfg(target_arch = "x86_64")]
    pub(crate) fn is_simd_aligned(&self) -> bool {
        // Maintain AVX2 compatibility for existing SIMD paths
        (self.data.as_ptr() as usize).is_multiple_of(32)
    }
1203
1204 /// Gets the memory alignment of the tensor data
1205 ///
1206 /// # Returns
1207 ///
1208 /// The memory alignment in bytes (typically 32 for SIMD optimization)
1209 #[inline]
1210 #[track_caller]
1211 pub fn memory_alignment(&self) -> usize {
1212 if let Some(owner) = self.allocation_owner() {
1213 owner.alignment()
1214 } else {
1215 // Zero-sized or non-owned views default to conservative 32
1216 32
1217 }
1218 }
1219
    /// Checks if this tensor is broadcastable with another tensor
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to check broadcasting compatibility
    ///
    /// # Returns
    ///
    /// `true` if the tensors are broadcastable according to NumPy broadcasting rules
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::new(vec![2, 3, 4]);
    /// let b = Tensor::new(vec![1, 3, 4]);
    /// assert!(a.is_broadcastable_with(&b));
    /// ```
    #[inline]
    #[track_caller]
    pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
        // Pure shape-metadata check; no tensor data is touched.
        self.shape.is_broadcastable_with(&other.shape)
    }
1244
    /// Gets the total number of bytes allocated for this tensor
    ///
    /// # Returns
    ///
    /// Total memory footprint in bytes
    ///
    /// NOTE(review): computed from the logical element count, so any
    /// allocator padding (see `capacity_elems`) is not included — the real
    /// allocation may be larger.
    #[inline]
    #[track_caller]
    pub fn memory_footprint(&self) -> usize {
        self.shape.size() * std::mem::size_of::<f32>()
    }
1255
1256 /// Get a single element from the tensor at the specified indices
1257 ///
1258 /// # Arguments
1259 ///
1260 /// * `indices` - Multi-dimensional indices to access the element
1261 ///
1262 /// # Returns
1263 ///
1264 /// The value at the specified position
1265 ///
1266 /// # Panics
1267 ///
1268 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1269 ///
1270 /// # Examples
1271 ///
1272 /// ```
1273 /// use train_station::Tensor;
1274 ///
1275 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1276 /// let value = tensor.get(&[0, 1]);
1277 /// assert_eq!(value, 2.0);
1278 /// ```
1279 #[track_caller]
1280 pub fn get(&self, indices: &[usize]) -> f32 {
1281 assert_eq!(
1282 indices.len(),
1283 self.shape().rank(),
1284 "Indices length must match tensor rank"
1285 );
1286
1287 // Check bounds
1288 for (i, &idx) in indices.iter().enumerate() {
1289 assert!(
1290 idx < self.shape().dims()[i],
1291 "Index {} out of bounds for dimension {}",
1292 idx,
1293 i
1294 );
1295 }
1296
1297 let offset = self.memory_offset(indices);
1298 unsafe { *self.as_ptr().add(offset) }
1299 }
1300
1301 /// Set a single element in the tensor at the specified indices
1302 ///
1303 /// # Arguments
1304 ///
1305 /// * `indices` - Multi-dimensional indices to set the element
1306 /// * `value` - The value to set
1307 ///
1308 /// # Panics
1309 ///
1310 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1311 ///
1312 /// # Examples
1313 ///
1314 /// ```
1315 /// use train_station::Tensor;
1316 ///
1317 /// let mut tensor = Tensor::new(vec![2, 2]);
1318 /// tensor.set(&[0, 1], 42.0);
1319 /// assert_eq!(tensor.get(&[0, 1]), 42.0);
1320 /// ```
1321 #[track_caller]
1322 pub fn set(&mut self, indices: &[usize], value: f32) {
1323 assert_eq!(
1324 indices.len(),
1325 self.shape().rank(),
1326 "Indices length must match tensor rank"
1327 );
1328
1329 // Check bounds
1330 for (i, &idx) in indices.iter().enumerate() {
1331 assert!(
1332 idx < self.shape().dims()[i],
1333 "Index {} out of bounds for dimension {}",
1334 idx,
1335 i
1336 );
1337 }
1338
1339 let offset = self.memory_offset(indices);
1340 unsafe {
1341 *self.as_mut_ptr().add(offset) = value;
1342 }
1343 }
1344
1345 /// Returns a safe slice of the tensor's underlying data
1346 ///
1347 /// Provides safe access to the tensor's data without requiring unsafe pointer operations.
1348 /// This is the preferred way to access tensor data for reading values, comparisons,
1349 /// and other operations that don't require direct pointer manipulation.
1350 ///
1351 /// # Returns
1352 ///
1353 /// A slice containing all tensor elements in row-major order
1354 ///
1355 /// # Performance
1356 ///
1357 /// - **Zero-Cost**: Direct slice creation with no copying
1358 /// - **Cache-Friendly**: Sequential memory access pattern
1359 /// - **Safe**: No unsafe code required for basic data access
1360 ///
1361 /// # Examples
1362 ///
1363 /// ```
1364 /// use train_station::Tensor;
1365 ///
1366 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1367 /// let data = tensor.data();
1368 ///
1369 /// // Safe indexing and comparisons
1370 /// assert_eq!(data[0], 1.0);
1371 /// assert_eq!(data.len(), tensor.size());
1372 /// ```
1373 #[inline]
1374 #[track_caller]
1375 pub fn data(&self) -> &[f32] {
1376 if self.size() == 0 {
1377 return &[];
1378 }
1379 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }
1380 }
1381
1382 /// Returns a mutable slice of the tensor's underlying data
1383 ///
1384 /// Provides safe mutable access to the tensor's data without requiring unsafe
1385 /// pointer operations. Use this for in-place modifications of tensor values.
1386 ///
1387 /// # Returns
1388 ///
1389 /// A mutable slice containing all tensor elements in row-major order
1390 ///
1391 /// # Performance
1392 ///
1393 /// - **Zero-Cost**: Direct slice creation with no copying
1394 /// - **Cache-Friendly**: Sequential memory access pattern
1395 /// - **Safe**: No unsafe code required for basic data modification
1396 ///
1397 /// # Examples
1398 ///
1399 /// ```
1400 /// use train_station::Tensor;
1401 ///
1402 /// let mut tensor = Tensor::new(vec![2, 2]);
1403 /// let data = tensor.data_mut();
1404 ///
1405 /// // Safe indexing for modification
1406 /// data[0] = 1.0;
1407 /// data[1] = 2.0;
1408 ///
1409 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
1410 /// assert_eq!(tensor.get(&[0, 1]), 2.0);
1411 /// ```
1412 #[inline]
1413 #[track_caller]
1414 pub fn data_mut(&mut self) -> &mut [f32] {
1415 if self.size() == 0 {
1416 return &mut [];
1417 }
1418 // Copy-on-write to protect existing views
1419 self.ensure_unique_allocation();
1420 unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
1421 }
1422
1423 /// Extract scalar value from single-element tensor
1424 ///
1425 /// This method provides a convenient way to extract the scalar value from
1426 /// tensors that contain exactly one element. This is commonly used with
1427 /// element iterator results and scalar tensor operations.
1428 ///
1429 /// # Returns
1430 ///
1431 /// The scalar value contained in this tensor
1432 ///
1433 /// # Panics
1434 ///
1435 /// Panics if the tensor does not contain exactly one element
1436 ///
1437 /// # Examples
1438 ///
1439 /// ```
1440 /// use train_station::Tensor;
1441 ///
1442 /// // Single-element tensor
1443 /// let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
1444 /// assert_eq!(scalar.value(), 42.0);
1445 /// ```
1446 #[inline]
1447 #[track_caller]
1448 pub fn value(&self) -> f32 {
1449 assert_eq!(
1450 self.size(),
1451 1,
1452 "value() can only be called on tensors with exactly one element. \
1453 This tensor has {} elements with shape {:?}",
1454 self.size(),
1455 self.shape().dims()
1456 );
1457 self.data()[0]
1458 }
1459
1460 /// Create a view with a new shape (requires contiguous memory)
1461 ///
1462 /// Behaves like PyTorch `view`: tensor must be contiguous and the total
1463 /// number of elements must remain the same. Supports -1 inference for one dimension.
1464 ///
1465 /// # Arguments
1466 ///
1467 /// * `new_shape` - New shape for the tensor (can contain -1 for inference)
1468 ///
1469 /// # Returns
1470 ///
1471 /// A tensor viewing the same data with a new shape
1472 ///
1473 /// # Examples
1474 ///
1475 /// ```
1476 /// use train_station::Tensor;
1477 ///
1478 /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1479 /// let y = x.view(vec![2, 2]);
1480 /// assert_eq!(y.shape().dims(), vec![2, 2]);
1481 /// ```
1482 #[track_caller]
1483 pub fn view(&self, new_shape: Vec<i32>) -> Tensor {
1484 // PyTorch-like view with single -1 inference; requires contiguity
1485 let size = self.size();
1486 let mut infer_idx: Option<usize> = None;
1487 let mut product: usize = 1;
1488 let mut dims: Vec<usize> = Vec::with_capacity(new_shape.len());
1489 for (i, d) in new_shape.iter().enumerate() {
1490 if *d == -1 {
1491 assert!(infer_idx.is_none(), "Only one -1 is allowed in view shape");
1492 infer_idx = Some(i);
1493 dims.push(1);
1494 } else {
1495 assert!(*d > 0, "Negative dims not supported in view shape");
1496 let du = *d as usize;
1497 product = product.saturating_mul(du);
1498 dims.push(du);
1499 }
1500 }
1501 if let Some(pos) = infer_idx {
1502 assert!(
1503 product > 0 && size.is_multiple_of(product),
1504 "View shape incompatible with tensor size"
1505 );
1506 dims[pos] = size / product;
1507 } else {
1508 assert!(
1509 product == size,
1510 "View shape has different number of elements"
1511 );
1512 }
1513
1514 let mut v = match crate::tensor::core::view::reshape_view(self, &dims) {
1515 Ok(v) => v,
1516 Err(e) => panic!("view reshape error: {:?}", e),
1517 };
1518 if self.requires_grad() && gradtrack::is_grad_enabled() {
1519 v.set_requires_grad(true);
1520 let grad_fn = GradFn::Reshape {
1521 original_shape: self.shape().dims().to_vec(),
1522 };
1523 v.set_grad_fn(grad_fn.clone());
1524 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1525 }
1526 v
1527 }
1528
1529 /// Create an element view for the specified index
1530 ///
1531 /// Returns a scalar tensor (shape \[1\]) that views a single element
1532 /// of the source tensor. Maintains gradient tracking.
1533 ///
1534 /// # Arguments
1535 ///
1536 /// * `index` - Linear index of the element to view
1537 ///
1538 /// # Returns
1539 ///
1540 /// A scalar tensor viewing the specified element
1541 ///
1542 /// # Examples
1543 ///
1544 /// ```
1545 /// use train_station::Tensor;
1546 ///
1547 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1548 /// let element = tensor.element_view(1);
1549 /// assert_eq!(element.value(), 2.0);
1550 /// ```
1551 #[track_caller]
1552 pub fn element_view(&self, index: usize) -> Tensor {
1553 let mut v = match crate::tensor::core::view::element_view_linear(self, index) {
1554 Ok(v) => v,
1555 Err(e) => panic!("element_view error: {:?}", e),
1556 };
1557 if self.requires_grad() && gradtrack::is_grad_enabled() {
1558 v.set_requires_grad(true);
1559 // Reuse SliceView grad for single-element selection to propagate back to source
1560 let grad_fn = GradFn::View {
1561 mapping: crate::gradtrack::grad_fn::ViewMapping::LinearRange {
1562 start: index,
1563 step: 1,
1564 length: 1,
1565 },
1566 input_shape: self.shape().dims().to_vec(),
1567 };
1568 v.set_grad_fn(grad_fn.clone());
1569 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1570 }
1571 v
1572 }
1573
1574 /// Create a slice view of the tensor
1575 ///
1576 /// Returns a view of a contiguous or strided slice of the source tensor.
1577 ///
1578 /// # Arguments
1579 ///
1580 /// * `start` - Starting index
1581 /// * `step` - Step size (1 for contiguous)
1582 /// * `length` - Number of elements
1583 ///
1584 /// # Returns
1585 ///
1586 /// A tensor viewing the specified slice
1587 ///
1588 /// # Examples
1589 ///
1590 /// ```
1591 /// use train_station::Tensor;
1592 ///
1593 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1594 /// let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
1595 /// assert_eq!(slice.get(&[0]), 2.0);
1596 /// assert_eq!(slice.get(&[1]), 4.0);
1597 /// ```
1598 #[track_caller]
1599 pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor {
1600 let mut v = match crate::tensor::core::view::slice_view_linear(self, start, step, length) {
1601 Ok(v) => v,
1602 Err(e) => panic!("slice_view error: {:?}", e),
1603 };
1604 // Ensure correct data() representation for stepped views
1605 // Offset-only (step==1, start>0) is already contiguous via base pointer
1606 if step != 1 {
1607 v = v.contiguous();
1608 }
1609 if self.requires_grad() && gradtrack::is_grad_enabled() {
1610 v.set_requires_grad(true);
1611 let grad_fn = GradFn::View {
1612 mapping: crate::gradtrack::grad_fn::ViewMapping::LinearRange {
1613 start,
1614 step,
1615 length,
1616 },
1617 input_shape: self.shape().dims().to_vec(),
1618 };
1619 v.set_grad_fn(grad_fn.clone());
1620 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1621 }
1622 v
1623 }
1624
    /// Create a tensor view from raw components
    ///
    /// Creates a tensor that views existing memory with the specified shape and device.
    /// The tensor shares memory with the original allocation through the allocation_owner.
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// - `data` is valid for the number of elements specified by `shape`
    /// - `data` remains valid for the lifetime of the returned tensor
    /// - `allocation_owner` properly manages the memory lifecycle
    ///
    /// # Arguments
    ///
    /// * `data` - Raw pointer to the tensor data
    /// * `shape` - Shape of the tensor view
    /// * `device` - Device where the tensor is located
    /// * `allocation_owner` - Optional shared allocation owner
    ///
    /// # Returns
    ///
    /// A new tensor that views the specified memory
    ///
    /// # Implementation Details
    ///
    /// This method is used internally to create tensor views that share memory
    /// with other tensors. It's primarily used for view operations and memory
    /// management.
    pub(crate) fn from_raw_view(
        data: *const f32,
        shape: Shape,
        device: Device,
        allocation_owner: Option<Arc<Allocation>>,
    ) -> Self {
        // Views get a fresh tensor id so gradtrack can track them
        // independently of the source tensor.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        Self {
            // Cast away constness for storage; mutation rules are governed by
            // the caller per the # Safety contract above.
            data: NonNull::new(data as *mut f32).expect("Data pointer cannot be null"),
            shape,
            device,
            id,
            // Views start without any gradient state; callers enable tracking
            // explicitly when required.
            requires_grad: false,
            retain_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner,
            graph_group: None,
            _phantom: PhantomData,
        }
    }
1675
    /// Get the allocation owner for this tensor
    ///
    /// Returns the shared allocation owner if this tensor is a view,
    /// or None if this tensor owns its memory directly.
    ///
    /// # Returns
    ///
    /// Optional reference to the allocation owner
    ///
    /// # Implementation Details
    ///
    /// This method is used internally to manage memory lifecycle for tensor views.
    /// It helps determine whether a tensor shares memory with another tensor.
    #[track_caller]
    pub fn allocation_owner(&self) -> Option<&Arc<Allocation>> {
        // Borrow only; cloning the Arc (refcount bump) is left to the caller.
        self.allocation_owner.as_ref()
    }
1693
1694 /// Create a new tensor with uninitialized memory
1695 ///
1696 /// This method allocates memory for a tensor without initializing it to any value.
1697 /// This is useful for performance-critical operations where the memory will be
1698 /// immediately overwritten, such as matrix multiplication results.
1699 ///
1700 /// # Safety
1701 ///
1702 /// The caller must ensure that all memory is written before reading from the tensor.
1703 /// Reading from uninitialized memory is undefined behavior.
1704 ///
1705 /// # Arguments
1706 ///
1707 /// * `shape_dims` - The dimensions of the tensor
1708 ///
1709 /// # Returns
1710 ///
1711 /// A tensor with uninitialized memory
1712 ///
1713 /// # Performance
1714 ///
1715 /// - **Zero Initialization**: Skips memory initialization for maximum performance
1716 /// - **SIMD Ready**: Properly aligned for vectorized operations
1717 /// - **Memory Efficient**: Uses optimized alignment strategies
1718 ///
1719 /// # Example
1720 ///
1721 /// ```
1722 /// use train_station::Tensor;
1723 ///
1724 /// // Create uninitialized tensor for matmul result
1725 /// let mut result = Tensor::new_uninitialized(vec![100, 100]);
1726 /// // Initialize the memory before use
1727 /// for value in result.data_mut() {
1728 /// *value = 0.0;
1729 /// }
1730 /// ```
1731 #[inline]
1732 #[track_caller]
1733 pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self {
1734 let shape = Shape::new(shape_dims);
1735 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1736
1737 if shape.size() == 0 {
1738 // Handle zero-sized tensors
1739 return Self {
1740 data: NonNull::dangling(),
1741 shape,
1742 device: current_device(),
1743 id,
1744 requires_grad: false,
1745 retain_grad: false,
1746 grad: None,
1747 grad_fn: GradFn::None,
1748 allocation_owner: None,
1749 graph_group: None,
1750 _phantom: PhantomData,
1751 };
1752 }
1753
1754 let (alignment, padded_elems) = compute_allocation_params(shape.size());
1755 let total_size = padded_elems * std::mem::size_of::<f32>();
1756 let layout = Layout::from_size_align(total_size, alignment)
1757 .expect("Failed to create layout for tensor data");
1758
1759 // Allocate memory via shared Allocation (uninitialized)
1760 let alloc_obj = if use_pool_alloc_enabled() {
1761 Allocation::new_pooled(padded_elems, alignment, layout)
1762 } else {
1763 Allocation::new_uninitialized(padded_elems, alignment, layout)
1764 };
1765 let ptr = alloc_obj.ptr;
1766
1767 debug_assert!(
1768 alloc_obj.capacity_elems() >= padded_elems,
1769 "Allocation capacity ({}) smaller than padded elements ({})",
1770 alloc_obj.capacity_elems(),
1771 padded_elems
1772 );
1773
1774 Self {
1775 data: ptr,
1776 shape,
1777 device: current_device(),
1778 id,
1779 requires_grad: false,
1780 retain_grad: false,
1781 grad: None,
1782 grad_fn: GradFn::None,
1783 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1784 graph_group: None,
1785 _phantom: PhantomData,
1786 }
1787 }
1788
1789 /// Create a new uninitialized tensor with an explicit alignment request (in bytes)
1790 ///
1791 /// This is intended for internal high-performance paths (e.g., packed GEMM panels)
1792 /// where stronger alignment such as 64 bytes is desired even on AVX2 systems.
1793 #[inline]
1794 #[track_caller]
1795 pub fn new_uninitialized_aligned(shape_dims: Vec<usize>, alignment_bytes: usize) -> Self {
1796 let shape = Shape::new(shape_dims);
1797 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1798
1799 if shape.size() == 0 {
1800 return Self {
1801 data: NonNull::dangling(),
1802 shape,
1803 device: current_device(),
1804 id,
1805 requires_grad: false,
1806 retain_grad: false,
1807 grad: None,
1808 grad_fn: GradFn::None,
1809 allocation_owner: None,
1810 graph_group: None,
1811 _phantom: PhantomData,
1812 };
1813 }
1814
1815 // Preserve the usual element padding policy
1816 let (_default_align, padded_elems) = compute_allocation_params(shape.size());
1817 // Honor explicit alignment request (at least 16)
1818 let alignment = alignment_bytes.max(16);
1819 let total_size = padded_elems * std::mem::size_of::<f32>();
1820 let layout = Layout::from_size_align(total_size, alignment)
1821 .expect("Failed to create layout for tensor data (aligned)");
1822
1823 let alloc_obj = if use_pool_alloc_enabled() {
1824 Allocation::new_pooled(padded_elems, alignment, layout)
1825 } else {
1826 Allocation::new_uninitialized(padded_elems, alignment, layout)
1827 };
1828 let ptr = alloc_obj.ptr;
1829
1830 debug_assert!(alloc_obj.capacity_elems() >= padded_elems);
1831
1832 Self {
1833 data: ptr,
1834 shape,
1835 device: current_device(),
1836 id,
1837 requires_grad: false,
1838 retain_grad: false,
1839 grad: None,
1840 grad_fn: GradFn::None,
1841 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1842 graph_group: None,
1843 _phantom: PhantomData,
1844 }
1845 }
1846}
1847
#[cfg(test)]
mod memory_alloc_tests {
    use super::*;
    use crate::tensor::core::memory::{
        detect_runtime_simd, simd_alignment_bytes, simd_lane_width_elems, with_no_mem_padding,
        TensorMemoryPool,
    };

    /// Padded allocations are lane-multiple in capacity and SIMD-aligned.
    #[test]
    fn test_padding_and_alignment_pool_enabled() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let align = simd_alignment_bytes(detect_runtime_simd());

        let req = lane * 3 + 1; // force padding
        let t = Tensor::new(vec![req]);

        assert_eq!(t.capacity_elems() % lane, 0);
        unsafe {
            assert_eq!((t.as_ptr() as usize) % align, 0);
        }
        // Logical size is not a multiple of lane, so aligned SIMD ops are not preferred
        #[cfg(target_arch = "x86_64")]
        assert!(!t.prefer_aligned_simd_ops());

        drop(t);
    }

    /// With padding disabled, logical size equals the requested size.
    #[test]
    fn test_no_padding_with_guard() {
        with_no_mem_padding(|| {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            let req = lane * 3 + 1; // non-multiple
            let t = Tensor::new(vec![req]);

            assert_eq!(t.size(), req);
            #[cfg(target_arch = "x86_64")]
            assert!(!t.prefer_aligned_simd_ops());
        });
    }

    /// Thread-local pool counters start at zero and track new allocations.
    #[test]
    fn test_system_alloc_no_pool_counters_unchanged() {
        let before = TensorMemoryPool::thread_stats();
        let _t1 = Tensor::new(vec![128]);
        let _t2 = Tensor::new(vec![257]);
        let after = TensorMemoryPool::thread_stats();
        assert_eq!(before.allocations, 0);
        assert_eq!(after.allocations, 2);
        assert_eq!(before.deallocations, 0);
    }

    /// Allocations dropped within a scope are matched by deallocations.
    #[test]
    fn test_pool_alloc_and_dealloc_counters_match() {
        let before = TensorMemoryPool::thread_stats();
        {
            let _t1 = Tensor::new(vec![64]);
            let _t2 = Tensor::new(vec![2048]);
            let _t3 = Tensor::new(vec![131072]);
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(after.allocations >= before.allocations + 3);
        assert!(after.deallocations >= before.deallocations + 3);
    }

    /// Mixed tensor sizes in one scope leave the pool stats balanced.
    #[test]
    fn test_mixed_modes_no_leak_in_pool_stats_scope() {
        let before = TensorMemoryPool::thread_stats();
        {
            let _a = Tensor::new(vec![1000]);
            let _b = Tensor::new(vec![1000]);
            let _c = Tensor::new(vec![2000]);
            // Was a second `let _c` that shadowed the binding above; renamed
            // to `_d` so it is clear all four tensors live to end of scope.
            let _d = Tensor::new(vec![2000]);
        }

        let after = TensorMemoryPool::thread_stats();
        assert_eq!(
            after.allocations - before.allocations,
            after.deallocations - before.deallocations
        );
    }

    /// Alignment hint helpers agree with padding state on x86_64.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_alignment_helpers_and_hints() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let req = lane * 4; // exact multiple
        let t = Tensor::new(vec![req]);
        assert!(t.is_aligned_for_runtime_level());
        assert!(t.prefer_aligned_simd_ops());
        assert!(!t.should_use_unaligned_simd_ops());

        with_no_mem_padding(|| {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            let req = lane * 4 + 1; // not multiple
            let t = Tensor::new(vec![req]);
            assert!(!t.prefer_aligned_simd_ops());
            assert!(t.should_use_unaligned_simd_ops());
        });
    }
}
1948
1949#[cfg(test)]
1950mod memory_alloc_additional_tests {
1951 use super::*;
1952 use crate::tensor::core::memory::{
1953 detect_runtime_simd, simd_lane_width_elems, with_no_mem_padding, with_no_mem_pool,
1954 TensorMemoryPool,
1955 };
1956 use std::time::Instant;
1957
1958 #[test]
1959 fn test_zero_size_tensors_pool_and_system() {
1960 let t = Tensor::new(vec![0]);
1961 assert_eq!(t.size(), 0);
1962 assert_eq!(t.capacity_elems(), 0);
1963 assert_eq!(t.data().len(), 0);
1964 let t = Tensor::new(vec![0]);
1965 assert_eq!(t.size(), 0);
1966 assert_eq!(t.capacity_elems(), 0);
1967 assert_eq!(t.data().len(), 0);
1968 }
1969
    /// Capacity is at least lane-padded and a lane multiple across a spread
    /// of sizes; with padding disabled, logical size is exactly as requested.
    #[test]
    fn test_capacity_across_various_sizes_padding_modes() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        // Boundary-heavy size selection: 0, 1, around one lane, around two
        // lanes, plus a couple of "ordinary" sizes.
        let sizes = [
            0,
            1,
            lane - 1,
            lane,
            lane + 1,
            2 * lane - 1,
            2 * lane + 1,
            1000,
            1001,
        ];

        for &n in &sizes {
            let t = Tensor::new(vec![n]);
            let expected_padded = if n == 0 { 0 } else { n.div_ceil(lane) * lane };
            // Pool may round up further than lane padding; ensure at least padded and lane-aligned
            assert!(
                t.capacity_elems() >= expected_padded,
                "n={} lane={} cap={}",
                n,
                lane,
                t.capacity_elems()
            );
            if t.capacity_elems() > 0 {
                assert_eq!(
                    t.capacity_elems() % lane,
                    0,
                    "capacity not lane-multiple: {}",
                    t.capacity_elems()
                );
            }
        }
        with_no_mem_padding(|| {
            for &n in &sizes {
                let t = Tensor::new(vec![n]);
                assert_eq!(t.size(), n);
            }
        });
    }
2012
    /// At each pool size-class boundary the planned capacity matches what a
    /// real allocation reports.
    #[test]
    fn test_class_boundary_planned_capacity_pool() {
        // One size at and one just past each size-class threshold.
        let boundaries = [
            crate::tensor::core::memory::SMALL_BUFFER_SIZE,
            crate::tensor::core::memory::SMALL_BUFFER_SIZE + 1,
            crate::tensor::core::memory::MEDIUM_BUFFER_SIZE,
            crate::tensor::core::memory::MEDIUM_BUFFER_SIZE + 1,
            crate::tensor::core::memory::LARGE_BUFFER_SIZE,
            crate::tensor::core::memory::LARGE_BUFFER_SIZE + 1,
        ];
        let lane = simd_lane_width_elems(detect_runtime_simd());
        for &n in &boundaries {
            // Mirror the lane-padding computation the allocator applies.
            let padded = if n == 0 { 0 } else { n.div_ceil(lane) * lane };
            let planned = TensorMemoryPool::planned_capacity_elems(padded);
            let t = Tensor::new(vec![n]);
            assert_eq!(t.capacity_elems(), planned, "boundary n={}", n);
        }
    }
2031
    /// Creating a view must not perform any new allocation.
    #[test]
    fn test_view_does_not_allocate_additional_memory() {
        let before = TensorMemoryPool::thread_stats();
        let t = Tensor::new(vec![256]);
        let _view = t.view(vec![16, 16]);
        let after = TensorMemoryPool::thread_stats();
        // Exactly one allocation happened (for t), view didn't allocate
        assert_eq!(after.allocations, before.allocations + 1);
    }
2041
    /// Pool allocation should not be dramatically slower than the system
    /// allocator for small tensors (generous 10x bound to avoid flakiness).
    #[test]
    fn test_performance_pool_vs_system_small_allocs() {
        let iters = 1000;
        // Warm up the pool's small-size class before timing.
        let _ = Tensor::new(vec![64]);
        let _ = Tensor::new(vec![64]);
        let _ = Tensor::new(vec![64]);

        let pool_time = {
            let start = Instant::now();
            for _ in 0..iters {
                let _t = Tensor::new(vec![64]);
            }
            start.elapsed()
        };
        let sys_time = with_no_mem_pool(|| {
            let start = Instant::now();
            for _ in 0..iters {
                let _t = Tensor::new(vec![64]);
            }
            start.elapsed()
        });
        assert!(
            pool_time <= sys_time * 10,
            "pool {:?} vs system {:?}",
            pool_time,
            sys_time
        );
    }
2070
    /// Filling a padded tensor should not be pathologically slower than an
    /// unpadded one (very loose 100x bound; a smoke test, not a benchmark).
    #[test]
    fn test_performance_padded_vs_unpadded_fill() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let n = lane * 2048;

        let padded_time = {
            let t = Tensor::new(vec![n]);
            let mut x = t.clone();
            let start = Instant::now();
            x.fill(1.2345);
            start.elapsed()
        };
        let unpadded_time = with_no_mem_padding(|| {
            // n + 1 forces a non-lane-multiple size under no-padding mode.
            let t = Tensor::new(vec![n + 1]);
            let mut x = t.clone();
            let start = Instant::now();
            x.fill(1.2345);
            start.elapsed()
        });
        assert!(
            padded_time <= unpadded_time * 100,
            "padded {:?} vs unpadded {:?}",
            padded_time,
            unpadded_time
        );
    }
2097}