train_station/tensor/core/utils.rs

//! Tensor utility functions and core implementation methods
//!
//! This module provides essential utility functions for tensor creation, memory management,
//! gradient tracking, and optimization. It contains the core implementation methods
//! that enable efficient tensor operations and gradtrack functionality.
//!
//! # Key Features
//!
//! - **Tensor Creation**: Optimized constructors with memory alignment
//! - **Memory Management**: Safe memory access and allocation utilities
//! - **Gradient Tracking**: GradTrack system integration and gradient management
//! - **Performance Optimization**: SIMD-ready memory layout and alignment
//! - **Device Management**: CPU and future CUDA device support
//! - **Memory Layout**: Contiguous, strided, and view memory access patterns
//!
//! # Performance Characteristics
//!
//! - **Memory Alignment**: 16-byte SSE, 32-byte AVX2, 64-byte cache-line alignment
//! - **SIMD Optimization**: Properly aligned memory for vectorized operations
//! - **Zero-Cost Abstractions**: Minimal overhead for utility operations
//! - **Thread Safety**: Atomic operations for gradient tracking and ID generation
//! - **Memory Efficiency**: Optimized allocation strategies for different tensor sizes
//!
//! # Examples
//!
//! ## Basic Tensor Creation
//!
//! ```
//! use train_station::Tensor;
//!
//! // Create tensors of different sizes
//! let small_tensor = Tensor::new(vec![2, 3]);      // 16-byte alignment
//! let medium_tensor = Tensor::new(vec![32, 32]);   // 32-byte alignment
//! let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
//!
//! // Initialize data before use
//! let mut tensor = Tensor::new(vec![2, 3]);
//! tensor.fill(0.0); // Initialize with zeros
//! ```
//!
//! ## Gradient Tracking
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
//! assert!(tensor.requires_grad());
//! ```
//!
//! ## Memory Access
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let value = tensor.get(&[0, 1]);
//! assert_eq!(value, 2.0);
//!
//! let mut tensor = Tensor::new(vec![2, 2]);
//! tensor.set(&[0, 1], 42.0);
//! assert_eq!(tensor.get(&[0, 1]), 42.0);
//! ```

64use std::alloc::Layout;
65use std::marker::PhantomData;
66use std::ptr::NonNull;
67use std::sync::atomic::Ordering;
68use std::sync::Arc;
69
70use crate::device::current_device;
71use crate::gradtrack::engine::ensure_local_group_for_tensor;
72use crate::gradtrack::{self, GradEngine, GradFn};
73use crate::tensor::core::memory::{
74    compute_allocation_params, detect_runtime_simd, simd_lane_width_elems, use_pool_alloc_enabled,
75    SimdLevel,
76};
77use crate::tensor::core::{Allocation, Device, TENSOR_ID_COUNTER};
78use crate::tensor::Shape;
79
80use super::Tensor;
81
82impl Tensor {
83    /// Ensures that this tensor has unique ownership of its allocation before mutation.
84    ///
85    /// If the underlying allocation is shared (i.e., multiple views), this performs
86    /// a copy-on-write: allocates a new buffer and copies data so that mutating this tensor
87    /// does not affect existing views.
88    fn ensure_unique_allocation(&mut self) {
89        if self.size() == 0 {
90            return;
91        }
92        if let Some(owner) = self.allocation_owner.as_ref() {
93            if std::sync::Arc::strong_count(owner) > 1 {
94                // Allocate a new buffer with the same size/alignment policy as constructors
95                let (alignment, padded_elems) = compute_allocation_params(self.shape.size());
96                let total_size = padded_elems * std::mem::size_of::<f32>();
97                let layout = Layout::from_size_align(total_size, alignment)
98                    .expect("Failed to create layout for tensor data");
99
100                let alloc_obj = if use_pool_alloc_enabled() {
101                    Allocation::new_pooled(padded_elems, alignment, layout)
102                } else {
103                    Allocation::new_uninitialized(padded_elems, alignment, layout)
104                };
105
106                let new_ptr = alloc_obj.ptr;
107                unsafe {
108                    std::ptr::copy_nonoverlapping(self.as_ptr(), new_ptr.as_ptr(), self.size());
109                }
110
111                // Replace pointer and owner with the new unique allocation
112                self.data = new_ptr;
113                self.allocation_owner = Some(std::sync::Arc::new(alloc_obj));
114            }
115        }
116    }
117    // ===== Runtime SIMD capability helpers for ops selection =====
118
119    /// Returns the highest runtime-detected SIMD level on this CPU
120    #[inline]
121    pub(crate) fn simd_runtime_level() -> SimdLevel {
122        detect_runtime_simd()
123    }
124
125    /// Returns the SIMD lane width (elements per vector) for f32 at runtime
126    #[inline]
127    pub(crate) fn simd_lane_width_elems_runtime() -> usize {
128        simd_lane_width_elems(Self::simd_runtime_level())
129    }
130
131    /// Checks whether this tensor's data pointer is aligned for the specified SIMD level
132    #[inline]
133    #[cfg(test)]
134    #[cfg(target_arch = "x86_64")]
135    pub(crate) fn is_aligned_for_level(&self, level: SimdLevel) -> bool {
136        use crate::tensor::core::memory::simd_alignment_bytes;
137
138        let required = simd_alignment_bytes(level);
139        (self.data.as_ptr() as usize).is_multiple_of(required)
140    }
141
142    /// Returns true if this tensor is aligned for the current runtime SIMD level
143    #[inline]
144    #[cfg(test)]
145    #[cfg(target_arch = "x86_64")]
146    pub(crate) fn is_aligned_for_runtime_level(&self) -> bool {
147        self.is_aligned_for_level(Self::simd_runtime_level())
148    }
149
150    /// Returns true if ops should use aligned SIMD loads/stores without a tail branch
151    /// (i.e., pointer is aligned for the runtime SIMD level and length is a multiple of lane)
152    #[inline]
153    #[cfg(test)]
154    #[cfg(target_arch = "x86_64")]
155    pub(crate) fn prefer_aligned_simd_ops(&self) -> bool {
156        let lane = Self::simd_lane_width_elems_runtime();
157        self.is_aligned_for_runtime_level() && self.size().is_multiple_of(lane)
158    }
159
160    /// Returns true if ops should use unaligned SIMD loads/stores (mm_loadu)
161    /// because pointer alignment or length does not meet aligned requirements.
162    /// Returns false when no SIMD is available.
163    #[inline]
164    #[cfg(target_arch = "x86_64")]
165    #[cfg(test)]
166    pub(crate) fn should_use_unaligned_simd_ops(&self) -> bool {
167        match Self::simd_runtime_level() {
168            SimdLevel::Scalar => false,
169            _ => !self.prefer_aligned_simd_ops(),
170        }
171    }
172
173    /// Returns the allocated capacity in elements, which may be padded beyond logical size
174    #[inline]
175    pub fn capacity_elems(&self) -> usize {
176        if let Some(owner) = self.allocation_owner() {
177            owner.capacity_elems()
178        } else {
179            self.size()
180        }
181    }
182    /// Creates a new tensor with the specified shape and optimized memory layout
183    ///
184    /// Allocates memory with size-dependent alignment for optimal performance:
185    /// - Small tensors (≤8 elements): 16-byte SSE alignment
186    /// - Medium tensors (8-1024 elements): 32-byte AVX2 alignment
187    /// - Large tensors (>1024 elements): 64-byte cache-line alignment
188    ///
189    /// # Arguments
190    ///
191    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
192    ///
193    /// # Returns
194    ///
195    /// A new tensor with uninitialized data. The data must be initialized
196    /// before use to avoid undefined behavior.
197    ///
198    /// # Performance
199    ///
200    /// - **Memory Allocation**: Single allocation with optimized alignment
201    /// - **SIMD Ready**: Properly aligned for vectorized operations
202    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
203    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
204    ///
205    /// # Safety
206    ///
207    /// The returned tensor contains uninitialized memory. You must initialize
208    /// the data before performing any operations that read from it.
209    ///
210    /// # Examples
211    ///
212    /// ```
213    /// use train_station::Tensor;
214    ///
215    /// // Create tensors of different sizes
216    /// let small_tensor = Tensor::new(vec![2, 3]);      // 16-byte alignment
217    /// let medium_tensor = Tensor::new(vec![32, 32]);   // 32-byte alignment
218    /// let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
219    ///
220    /// // Initialize data before use
221    /// let mut tensor = Tensor::new(vec![2, 3]);
222    /// tensor.fill(0.0); // Initialize with zeros
223    /// ```
224    #[inline]
225    #[track_caller]
226    pub fn new(shape_dims: Vec<usize>) -> Self {
227        let shape = Shape::new(shape_dims);
228        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
229
230        if shape.size() == 0 {
231            // Handle zero-sized tensors
232            return Self {
233                data: NonNull::dangling(),
234                shape,
235                device: current_device(),
236                id,
237                requires_grad: false,
238                retain_grad: false,
239                grad: None,
240                grad_fn: GradFn::None,
241                allocation_owner: None,
242                graph_group: None,
243                _phantom: PhantomData,
244            };
245        }
246
247        // Compute alignment and padded element count based on runtime SIMD
248        let (alignment, padded_elems) = compute_allocation_params(shape.size());
249        let total_size = padded_elems * std::mem::size_of::<f32>();
250
251        let layout = Layout::from_size_align(total_size, alignment)
252            .expect("Failed to create layout for tensor data");
253
254        // Allocation policy: prefer pool for tiny tensors and medium/large classes.
255        // Route certain small sizes to system allocator to keep pool stats semantics used by tests.
256        let alloc_obj = if use_pool_alloc_enabled() {
257            Allocation::new_pooled(padded_elems, alignment, layout)
258        } else {
259            Allocation::new(padded_elems, alignment, layout)
260        };
261        let ptr = alloc_obj.ptr;
262
263        debug_assert!(
264            alloc_obj.capacity_elems() >= padded_elems,
265            "Allocation capacity ({}) smaller than padded elements ({})",
266            alloc_obj.capacity_elems(),
267            padded_elems
268        );
269
270        Self {
271            data: ptr,
272            shape,
273            device: current_device(),
274            id,
275            requires_grad: false,
276            retain_grad: false,
277            grad: None,
278            grad_fn: GradFn::None,
279            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
280            graph_group: None,
281            _phantom: PhantomData,
282        }
283    }
284
285    /// Returns the shape and dimensional information of the tensor
286    ///
287    /// Provides access to the tensor's dimensions, size, strides, and memory
288    /// layout information. This is used for shape validation, memory access
289    /// calculations, and optimization decisions.
290    ///
291    /// # Returns
292    ///
293    /// Reference to the tensor's shape information containing dimensions,
294    /// size, strides, and memory layout type.
295    ///
296    /// # Performance
297    ///
298    /// - **Time Complexity**: O(1) - direct field access
299    /// - **Memory**: No allocation - returns reference to existing data
300    ///
301    /// # Examples
302    ///
303    /// ```
304    /// use train_station::Tensor;
305    ///
306    /// let tensor = Tensor::new(vec![2, 3, 4]);
307    /// let shape = tensor.shape();
308    /// assert_eq!(shape.dims(), vec![2, 3, 4]);
309    /// assert_eq!(shape.size(), 24);
310    /// assert_eq!(shape.rank(), 3);
311    /// ```
312    #[inline]
313    #[track_caller]
314    pub fn shape(&self) -> &Shape {
315        &self.shape
316    }
317
318    /// Returns the total number of elements in the tensor
319    ///
320    /// Provides the total count of elements across all dimensions. This is
321    /// used for memory allocation, iteration bounds, and performance optimization.
322    ///
323    /// # Returns
324    ///
325    /// Total number of elements as `usize`
326    ///
327    /// # Performance
328    ///
329    /// - **Time Complexity**: O(1) - direct field access
330    /// - **Memory**: No allocation - returns stored value
331    ///
332    /// # Examples
333    ///
334    /// ```
335    /// use train_station::Tensor;
336    ///
337    /// let tensor = Tensor::new(vec![2, 3, 4]);
338    /// assert_eq!(tensor.size(), 24); // 2 * 3 * 4
339    ///
340    /// let scalar = Tensor::new(vec![1]);
341    /// assert_eq!(scalar.size(), 1);
342    ///
343    /// let empty = Tensor::new(vec![0]);
344    /// assert_eq!(empty.size(), 0);
345    /// ```
346    #[inline]
347    #[track_caller]
348    pub fn size(&self) -> usize {
349        self.shape().size()
350    }
351
352    /// Returns the device where this tensor is located
353    ///
354    /// Provides the physical location of the tensor data (CPU/GPU). This
355    /// determines which operations can be performed on the tensor and where
356    /// computations will be executed.
357    ///
358    /// # Returns
359    ///
360    /// Device enum indicating the tensor's physical location
361    ///
362    /// # Performance
363    ///
364    /// - **Time Complexity**: O(1) - direct field access
365    /// - **Memory**: No allocation - returns stored value
366    ///
367    /// # Examples
368    ///
369    /// ```
370    /// use train_station::Tensor;
371    ///
372    /// let tensor = Tensor::new(vec![2, 3]);
373    /// assert!(tensor.device().is_cpu());
374    /// assert!(!tensor.device().is_cuda());
375    /// ```
376    #[inline]
377    #[track_caller]
378    pub fn device(&self) -> Device {
379        self.device
380    }
381
    /// Creates a new tensor with the specified shape on a specific device
    ///
    /// Allocates memory on the specified device with the same optimized alignment
    /// strategy as `new()`. Currently supports CPU device with future CUDA support.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    /// * `device` - The device where the tensor should be allocated
    ///
    /// # Returns
    ///
    /// A new tensor with uninitialized data on the specified device
    ///
    /// # Panics
    ///
    /// Panics if the specified device is not supported (currently only CPU is
    /// supported; CUDA requires a future feature flag)
    ///
    /// # Performance
    ///
    /// - **Memory Allocation**: Single allocation with optimized alignment
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, Device};
    ///
    /// // Create tensor on CPU device
    /// let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
    /// assert_eq!(tensor.device(), Device::cpu());
    /// assert_eq!(tensor.size(), 6);
    /// ```
    #[track_caller]
    pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self {
        // For now, only CPU is supported
        if !device.is_cpu() {
            panic!("Only CPU device is currently supported. CUDA support is planned for future releases.");
        }

        let shape = Shape::new(shape_dims);
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size() == 0 {
            // Handle zero-sized tensors: dangling pointer, no owned allocation
            return Self {
                data: NonNull::dangling(),
                shape,
                device,
                id,
                requires_grad: false,
                retain_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                graph_group: None,
                _phantom: PhantomData,
            };
        }

        let (alignment, padded_elems) = compute_allocation_params(shape.size());
        let total_size = padded_elems * std::mem::size_of::<f32>();
        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Same allocation policy as new(): prefer pool for tiny and medium/large classes
        let alloc_obj = if use_pool_alloc_enabled() {
            Allocation::new_pooled(padded_elems, alignment, layout)
        } else {
            Allocation::new(padded_elems, alignment, layout)
        };
        let ptr = alloc_obj.ptr;

        debug_assert!(
            alloc_obj.capacity_elems() >= padded_elems,
            "Allocation capacity ({}) smaller than padded elements ({})",
            alloc_obj.capacity_elems(),
            padded_elems
        );

        Self {
            data: ptr,
            shape,
            device,
            id,
            requires_grad: false,
            retain_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            graph_group: None,
            _phantom: PhantomData,
        }
    }
507
508    /// Enable gradient computation for this tensor
509    ///
510    /// Builder method that enables automatic gradient tracking for this tensor.
511    /// When enabled, all operations involving this tensor will be recorded in
512    /// the computation graph for gradient computation during backward pass.
513    ///
514    /// # Returns
515    ///
516    /// `self` with gradient tracking enabled
517    ///
518    /// # Performance
519    ///
520    /// - **Time Complexity**: O(1) - simple field assignment
521    /// - **Memory**: No additional allocation
522    /// - **Overhead**: Minimal gradtrack tracking overhead when gradients computed
523    ///
524    /// # Examples
525    ///
526    /// ```
527    /// use train_station::Tensor;
528    ///
529    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
530    /// assert!(tensor.requires_grad());
531    /// ```
532    #[track_caller]
533    pub fn with_requires_grad(mut self) -> Self {
534        self.requires_grad = true;
535        if self.graph_group.is_none() {
536            self.graph_group = Some(ensure_local_group_for_tensor(self.id()));
537        }
538        self
539    }
540
541    /// Set gradient tracking for this tensor
542    ///
543    /// Controls whether the gradtrack system tracks operations on this tensor
544    /// and computes gradients during backward pass. When disabled, clears
545    /// any existing gradients and gradient functions.
546    ///
547    /// # Arguments
548    ///
549    /// * `requires_grad` - Whether to track gradients for this tensor
550    ///
551    /// # Performance
552    ///
553    /// - **Time Complexity**: O(1) - simple field assignment
554    /// - **Memory**: May free gradient storage when disabled
555    /// - **Overhead**: Zero overhead when gradients disabled
556    ///
557    /// # Examples
558    ///
559    /// ```
560    /// use train_station::Tensor;
561    ///
562    /// let mut tensor = Tensor::ones(vec![2, 3]);
563    /// tensor.set_requires_grad(true);
564    /// assert!(tensor.requires_grad());
565    ///
566    /// // Disable gradient tracking
567    /// tensor.set_requires_grad(false);
568    /// assert!(!tensor.requires_grad());
569    /// ```
570    #[track_caller]
571    pub fn set_requires_grad(&mut self, requires_grad: bool) {
572        self.requires_grad = requires_grad;
573        if !requires_grad {
574            self.grad = None;
575            self.grad_fn = GradFn::None;
576            self.graph_group = None;
577        } else {
578            // Lazily bind to a local graph group for this tensor id
579            if self.graph_group.is_none() {
580                self.graph_group = Some(ensure_local_group_for_tensor(self.id()));
581            }
582        }
583    }
584
585    /// Mark this tensor to retain gradients after backward, even if it is non-leaf.
586    ///
587    /// Builder-style API: returns self with `retain_grad=true`.
588    /// Call `materialize_grad()` or `grad_or_fetch()` after backward to copy the
589    /// accumulated gradient from the GradGraph into `self.grad` so `grad()` works.
590    #[track_caller]
591    pub fn retain_grad(mut self) -> Self {
592        self.retain_grad = true;
593        // Register intent with grad engine to keep gradient available after backward
594        crate::gradtrack::engine::mark_retain_grad(self.id);
595        self
596    }
597
598    /// In-place variant to enable or disable gradient retention for non-leaf tensors
599    #[track_caller]
600    pub fn retain_grad_(&mut self, enable: bool) {
601        self.retain_grad = enable;
602        if enable {
603            crate::gradtrack::engine::mark_retain_grad(self.id);
604        }
605    }
606
607    /// Check if this tensor requires gradients
608    ///
609    /// # Returns
610    ///
611    /// `true` if gradient tracking is enabled for this tensor
612    ///
613    /// # Examples
614    ///
615    /// ```
616    /// use train_station::Tensor;
617    ///
618    /// let tensor = Tensor::new(vec![2, 3]);
619    /// assert!(!tensor.requires_grad());
620    ///
621    /// let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
622    /// assert!(grad_tensor.requires_grad());
623    /// ```
624    #[track_caller]
625    pub fn requires_grad(&self) -> bool {
626        self.requires_grad
627    }
628
629    /// Get the accumulated gradients (if any)
630    ///
631    /// Returns a reference to the gradient tensor if gradients have been computed
632    /// and this tensor has gradient tracking enabled.
633    ///
634    /// # Returns
635    ///
636    /// Optional reference to the gradient tensor, or `None` if no gradients exist
637    ///
638    /// # Examples
639    ///
640    /// ```
641    /// use train_station::Tensor;
642    ///
643    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
644    /// assert!(tensor.grad().is_none()); // No gradients computed yet
645    /// ```
646    #[track_caller]
647    pub fn grad(&self) -> Option<&Tensor> {
648        // First check if we have a gradient stored directly
649        if let Some(grad) = self.grad.as_ref() {
650            return Some(grad.as_ref());
651        }
652
653        None
654    }
655
656    /// Fetch the accumulated gradient after backward and cache it on this tensor if `retain_grad` is enabled.
657    ///
658    /// Returns true if a gradient was found and cached. After a successful call,
659    /// `grad()` will return `Some(&Tensor)` even for non-leaf tensors.
660    #[track_caller]
661    pub fn materialize_grad(&mut self) -> bool {
662        if self.grad.is_some() {
663            return true;
664        }
665        if !self.retain_grad {
666            return false;
667        }
668        if let Some(g) = self.grad_owned() {
669            self.set_grad(g);
670            return true;
671        }
672        false
673    }
674
675    /// Convenience accessor: if `retain_grad` is enabled, fetch and cache the gradient
676    /// on first access so callers can immediately get a reference.
677    #[track_caller]
678    pub fn grad_or_fetch(&mut self) -> Option<&Tensor> {
679        if self.grad.is_none() && self.retain_grad {
680            if let Some(g) = self.grad_owned() {
681                self.set_grad(g);
682            }
683        }
684        self.grad.as_deref()
685    }
686
687    /// Get the accumulated gradient as an owned tensor
688    ///
689    /// Returns the gradient by value (owned). This complements `grad()` which returns
690    /// a reference. Useful when you need to take ownership of the gradient data
691    /// (e.g., move into another structure or thread).
692    ///
693    /// This function does not clear internal gradient state. If a locally cached
694    /// gradient is not present, it consults shared autograd storage to fetch an
695    /// up-to-date gradient for this tensor ID.
696    ///
697    /// # Returns
698    ///
699    /// `Some(Tensor)` containing the gradient when available, otherwise `None`.
700    ///
701    /// # Examples
702    ///
703    /// ```
704    /// use train_station::Tensor;
705    ///
706    /// // Before backward: no gradient yet
707    /// let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
708    /// assert!(x.grad_owned().is_none());
709    ///
710    /// // Compute a simple loss and backprop
711    /// let mut loss = x.sum();
712    /// loss.backward(None);
713    ///
714    /// // Fetch the gradient by value (owned)
715    /// let g = x.grad_owned().unwrap();
716    /// assert_eq!(g.shape().dims(), vec![2, 3]);
717    /// ```
718    #[track_caller]
719    pub fn grad_owned(&self) -> Option<Tensor> {
720        // First check if we have a gradient stored directly
721        if let Some(grad) = self.grad.as_ref() {
722            return Some((**grad).clone());
723        }
724
725        // Always consult the global/shared gradient storage so gradients accumulated
726        // in a promoted/merged shared group are visible regardless of the calling thread.
727        // This is critical for cross-thread forward/backward validation and parity with LibTorch.
728        use crate::gradtrack;
729        gradtrack::get_accumulated_gradient(self.id)
730    }
731
732    /// Get the unique ID of this tensor
733    ///
734    /// Returns the unique identifier assigned to this tensor during creation.
735    /// This ID is used for gradtrack tracking and tensor identification.
736    ///
737    /// # Returns
738    ///
739    /// Unique tensor ID as `usize`
740    ///
741    /// # Examples
742    ///
743    /// ```
744    /// use train_station::Tensor;
745    ///
746    /// let tensor1 = Tensor::new(vec![2, 3]);
747    /// let tensor2 = Tensor::new(vec![2, 3]);
748    /// assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
749    /// ```
750    #[track_caller]
751    pub fn id(&self) -> usize {
752        self.id
753    }
754
755    /// Detach this tensor from the computation graph
756    ///
757    /// Returns a new tensor with the same data but no gradient tracking.
758    /// This is useful when you want to use a tensor in inference without
759    /// affecting the computation graph.
760    ///
761    /// # Returns
762    ///
763    /// A new tensor with the same data but gradient tracking disabled
764    ///
765    /// # Examples
766    ///
767    /// ```
768    /// use train_station::Tensor;
769    ///
770    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
771    /// let detached = tensor.detach();
772    /// assert!(!detached.requires_grad());
773    /// assert_eq!(tensor.size(), detached.size());
774    /// ```
775    #[track_caller]
776    pub fn detach(&self) -> Self {
777        let mut detached = Self::new(self.shape().dims().to_vec());
778
779        // Copy data
780        unsafe {
781            let src = self.as_ptr();
782            let dst = detached.as_mut_ptr();
783            std::ptr::copy_nonoverlapping(src, dst, self.size());
784        }
785
786        detached
787    }
788
789    /// Create a new tensor that doesn't track gradients from this one
790    ///
791    /// Similar to detach() but modifies this tensor in place. This is useful
792    /// when you want to disable gradient tracking for the current tensor
793    /// without creating a copy.
794    ///
795    /// # Examples
796    ///
797    /// ```
798    /// use train_station::Tensor;
799    ///
800    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
801    /// assert!(tensor.requires_grad());
802    /// tensor.detach_();
803    /// assert!(!tensor.requires_grad());
804    /// ```
805    #[track_caller]
806    pub fn detach_(&mut self) {
807        self.requires_grad = false;
808        self.grad = None;
809        self.grad_fn = GradFn::None;
810    }
811
812    /// Entry point for backward pass on this tensor
813    ///
814    /// Computes gradients for all tensors in the computation graph that have
815    /// `requires_grad` set to true. This is the main entry point for automatic
816    /// differentiation.
817    ///
818    /// # Arguments
819    ///
820    /// * `grad_output` - Optional gradient tensor for the output. If None, assumes
821    ///   the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
822    ///
823    /// # Examples
824    ///
825    /// ```
826    /// use train_station::Tensor;
827    ///
828    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
829    /// let mut result = tensor.add_scalar(5.0);
830    /// result.backward(None);
831    /// // Note: Gradient computation depends on the gradtrack system implementation
832    /// ```
833    #[track_caller]
834    pub fn backward(&mut self, grad_output: Option<Tensor>) {
835        GradEngine::backward(self, grad_output);
836    }
837
838    /// Returns a raw pointer to the tensor data for unsafe operations
839    ///
840    /// # Safety
841    ///
842    /// This is unsafe because it provides direct access to the underlying memory.
843    /// The caller must ensure:
844    /// - The tensor is not dropped while the pointer is used
845    /// - No concurrent mutable access occurs
846    /// - Bounds are respected
847    #[inline]
848    pub unsafe fn as_ptr(&self) -> *const f32 {
849        self.data.as_ptr()
850    }
851
852    /// Returns a mutable raw pointer to the tensor data for unsafe operations
853    ///
854    /// # Safety
855    ///
856    /// This is unsafe because it provides direct mutable access to the underlying memory.
857    /// The caller must ensure:
858    /// - The tensor is not dropped while the pointer is used
859    /// - No concurrent access occurs
860    /// - Bounds are respected
861    #[inline]
862    pub unsafe fn as_mut_ptr(&mut self) -> *mut f32 {
863        self.data.as_ptr()
864    }
865
866    /// Internal method to set gradient function (used by operations)
867    ///
868    /// Sets the gradient function for this tensor. This is used internally
869    /// by tensor operations to record the computation graph for gradtrack.
870    ///
871    /// # Arguments
872    ///
873    /// * `grad_fn` - The gradient function to set
874    ///
875    /// # Implementation Details
876    ///
877    /// This method is called by tensor operations to register the gradient
878    /// computation function. It only sets the gradient function if gradient
879    /// tracking is enabled for this tensor.
880    pub(crate) fn set_grad_fn(&mut self, grad_fn: GradFn) {
881        if self.requires_grad {
882            self.grad_fn = grad_fn;
883        }
884    }
885
    /// Get a reference to the gradient function (for gradtrack)
    ///
    /// Returns a reference to the gradient function associated with this tensor.
    /// This is used internally by the gradtrack system to compute gradients.
    ///
    /// # Returns
    ///
    /// Reference to the gradient function; tensors created outside any recorded
    /// operation hold `GradFn::None`
    ///
    /// # Implementation Details
    ///
    /// This method is used by the gradtrack engine to access the gradient
    /// computation function during backward pass.
    #[track_caller]
    pub fn grad_fn(&self) -> &GradFn {
        &self.grad_fn
    }
903
    /// Internal method to set requires_grad (used by gradtrack operations)
    ///
    /// Sets the gradient tracking flag for this tensor. This is used internally
    /// by gradtrack operations to control gradient computation.
    ///
    /// # Arguments
    ///
    /// * `requires_grad` - Whether to enable gradient tracking
    ///
    /// # Implementation Details
    ///
    /// This method is used internally by the gradtrack system to control
    /// gradient tracking without triggering additional side effects.
    pub(crate) fn set_requires_grad_internal(&mut self, requires_grad: bool) {
        // Plain flag write: deliberately no graph registration or other side effects.
        self.requires_grad = requires_grad;
    }
920
921    /// Internal method to accumulate gradients with optimized operations
922    ///
923    /// Accumulates a gradient tensor with any existing gradients for this tensor.
924    /// This is used internally by the gradtrack system to handle gradient accumulation
925    /// during backward pass.
926    ///
927    /// # Arguments
928    ///
929    /// * `grad` - The gradient tensor to accumulate
930    ///
931    /// # Implementation Details
932    ///
933    /// This method is used internally by the gradtrack engine to accumulate
934    /// gradients from multiple backward passes or operations. It only accumulates
935    /// gradients if gradient tracking is enabled for this tensor.
936    pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
937        if !self.requires_grad {
938            return;
939        }
940
941        match &self.grad {
942            Some(existing_grad) => {
943                // Use optimized tensor addition but create new tensor for safety
944                let accumulated = existing_grad.add_tensor_optimized(&grad);
945                self.grad = Some(Arc::new(accumulated));
946            }
947            None => {
948                self.grad = Some(Arc::new(grad));
949            }
950        }
951    }
952
953    /// Set gradient from external source
954    ///
955    /// Sets the gradient tensor for this tensor. This is used internally by the
956    /// gradtrack system to set gradients during backward pass.
957    ///
958    /// # Arguments
959    ///
960    /// * `grad` - The gradient tensor to set
961    ///
962    /// # Implementation Details
963    ///
964    /// This method is used internally by the gradtrack engine to set gradients
965    /// during backward pass. It only sets the gradient if gradient tracking is
966    /// enabled for this tensor.
967    #[track_caller]
968    pub fn set_grad(&mut self, grad: Tensor) {
969        if self.requires_grad {
970            self.grad = Some(Arc::new(grad));
971        }
972    }
973
    /// Clear accumulated gradients for this tensor
    ///
    /// This method is used by optimizers to zero gradients before each backward pass.
    /// It clears any accumulated gradients, allowing for fresh gradient computation.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// tensor.set_grad(Tensor::ones(vec![2, 3]));
    /// assert!(tensor.grad().is_some());
    /// tensor.zero_grad();
    /// assert!(tensor.grad().is_none());
    /// ```
    #[track_caller]
    pub fn zero_grad(&mut self) {
        // Clear locally cached gradient
        self.grad = None;
        // Ensure gradient is also cleared from autograd storage so future backward runs
        // don't observe stale gradients when mixing TLS and shared groups.
        // (The engine keys its storage by tensor id, hence `self.id` here.)
        crate::gradtrack::engine::clear_gradient_for_tensor(self.id);
    }
998
999    /// Negate all elements in the tensor in-place
1000    ///
1001    /// This is used internally for gradient computation in subtraction operations.
1002    /// For tensor - tensor operations, the second operand gets a negated gradient.
1003    ///
1004    /// # Implementation Details
1005    ///
1006    /// This method is used internally by the gradtrack system to compute gradients
1007    /// for subtraction operations. It uses SIMD optimization when available for
1008    /// better performance.
1009    #[inline]
1010    pub(crate) fn negate_inplace(&mut self) {
1011        if self.shape.size() == 0 {
1012            return;
1013        }
1014
1015        unsafe {
1016            let ptr = self.data.as_ptr();
1017
1018            #[cfg(target_arch = "x86_64")]
1019            {
1020                // Use SIMD for better performance when available
1021                if is_x86_feature_detected!("avx2") {
1022                    self.negate_simd_avx2(ptr);
1023                    return;
1024                }
1025            }
1026
1027            // Fallback to scalar operations
1028            for i in 0..self.shape.size() {
1029                *ptr.add(i) = -*ptr.add(i);
1030            }
1031        }
1032    }
1033
1034    /// SIMD-optimized negation using AVX2 instructions
1035    ///
1036    /// Performs in-place negation of tensor elements using AVX2 SIMD instructions
1037    /// for improved performance on x86_64 architectures.
1038    ///
1039    /// # Arguments
1040    ///
1041    /// * `ptr` - Raw pointer to the tensor data
1042    ///
1043    /// # Safety
1044    ///
1045    /// The caller must ensure:
1046    /// - `ptr` is a valid pointer to tensor data
1047    /// - The tensor size matches the actual data size
1048    /// - The tensor is not moved or dropped during this operation
1049    ///
1050    /// # Implementation Details
1051    ///
1052    /// This method is used internally by `negate_inplace` when AVX2 is available.
1053    /// It processes 8 elements per iteration using AVX2 instructions.
1054    #[cfg(target_arch = "x86_64")]
1055    #[inline]
1056    unsafe fn negate_simd_avx2(&self, ptr: *mut f32) {
1057        use std::arch::x86_64::_mm256_setzero_ps;
1058
1059        let size = self.shape().size();
1060        let zero_vec = _mm256_setzero_ps();
1061        let simd_count = size / 8; // Process 8 elements per iteration
1062        let mut offset = 0;
1063
1064        // SIMD loop for negation
1065        for _ in 0..simd_count {
1066            use std::arch::x86_64::{_mm256_load_ps, _mm256_store_ps, _mm256_sub_ps};
1067
1068            let vec = _mm256_load_ps(ptr.add(offset));
1069            let neg_vec = _mm256_sub_ps(zero_vec, vec);
1070            _mm256_store_ps(ptr.add(offset), neg_vec);
1071            offset += 8;
1072        }
1073
1074        // Handle remaining elements
1075        for i in offset..size {
1076            *ptr.add(i) = -*ptr.add(i);
1077        }
1078    }
1079
1080    // ===== Memory Layout and Optimization API =====
1081
    /// Checks if the tensor data is stored contiguously in memory
    ///
    /// # Returns
    ///
    /// `true` if the tensor data is contiguous, enabling optimized SIMD operations
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// assert!(tensor.is_contiguous());
    /// ```
    #[inline]
    #[track_caller]
    pub fn is_contiguous(&self) -> bool {
        // Contiguity is a property of the shape's strides, so delegate to Shape.
        self.shape().is_contiguous()
    }
1101
    /// Gets the memory strides for all dimensions
    ///
    /// # Returns
    ///
    /// Reference to the stride vector (in elements, not bytes) for efficient
    /// memory access calculations
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// assert_eq!(tensor.strides(), &[12, 4, 1]);
    /// ```
    #[inline]
    #[track_caller]
    pub fn strides(&self) -> &[usize] {
        self.shape.strides()
    }
1121
    /// Gets the memory stride for a specific dimension
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension index
    ///
    /// # Returns
    ///
    /// The memory stride (in elements) for the given dimension
    ///
    /// # Panics
    ///
    /// Panics if `dim` is out of bounds
    #[inline]
    #[track_caller]
    pub fn stride(&self, dim: usize) -> usize {
        self.shape.stride(dim)
    }
1140
    /// Calculates the linear memory offset for given multi-dimensional indices
    ///
    /// Computes the dot product of indices and strides via the shape, producing
    /// the element offset from the tensor's base pointer.
    ///
    /// # Arguments
    ///
    /// * `indices` - Vector of indices for each dimension
    ///
    /// # Returns
    ///
    /// Linear memory offset (in elements) for direct memory access
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// let offset = tensor.memory_offset(&[1, 2, 3]);
    /// // offset = 1*12 + 2*4 + 3*1 = 23
    /// ```
    #[inline]
    #[track_caller]
    pub fn memory_offset(&self, indices: &[usize]) -> usize {
        self.shape.offset(indices)
    }
1165
    /// Broadcast this tensor with another tensor for element-wise operations
    ///
    /// Returns a tuple containing:
    /// - Broadcasted view of self
    /// - Broadcasted view of other
    /// - Result shape for the operation
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to broadcast with
    ///
    /// # Returns
    ///
    /// A tuple `(broadcasted_self, broadcasted_other, result_shape)` on success
    ///
    /// # Errors
    ///
    /// Returns a `BroadcastError` when the two shapes are not broadcast-compatible
    #[track_caller]
    pub(crate) fn broadcast_with(
        &self,
        other: &Tensor,
    ) -> Result<
        (Tensor, Tensor, crate::tensor::Shape),
        crate::tensor::ops::broadcasting::BroadcastError,
    > {
        // All broadcasting rules live in the broadcasting module; delegate wholesale.
        crate::tensor::ops::broadcasting::broadcast_shapes(self, other)
    }
1190
    /// Checks if the tensor data is properly aligned for SIMD operations
    ///
    /// # Returns
    ///
    /// `true` if the tensor's current data pointer is aligned to a 32-byte
    /// boundary as required by aligned AVX2 loads/stores. Note this tests the
    /// pointer itself, so offset views may fail this even when the underlying
    /// allocation is aligned.
    #[inline]
    #[track_caller]
    #[cfg(target_arch = "x86_64")]
    pub(crate) fn is_simd_aligned(&self) -> bool {
        // Maintain AVX2 compatibility for existing SIMD paths
        (self.data.as_ptr() as usize).is_multiple_of(32)
    }
1203
1204    /// Gets the memory alignment of the tensor data
1205    ///
1206    /// # Returns
1207    ///
1208    /// The memory alignment in bytes (typically 32 for SIMD optimization)
1209    #[inline]
1210    #[track_caller]
1211    pub fn memory_alignment(&self) -> usize {
1212        if let Some(owner) = self.allocation_owner() {
1213            owner.alignment()
1214        } else {
1215            // Zero-sized or non-owned views default to conservative 32
1216            32
1217        }
1218    }
1219
    /// Checks if this tensor is broadcastable with another tensor
    ///
    /// # Arguments
    ///
    /// * `other` - The other tensor to check broadcasting compatibility
    ///
    /// # Returns
    ///
    /// `true` if the tensors are broadcastable according to NumPy broadcasting rules
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::new(vec![2, 3, 4]);
    /// let b = Tensor::new(vec![1, 3, 4]);
    /// assert!(a.is_broadcastable_with(&b));
    /// ```
    #[inline]
    #[track_caller]
    pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
        // Compatibility is decided purely by the two shapes.
        self.shape.is_broadcastable_with(&other.shape)
    }
1244
    /// Gets the total number of bytes allocated for this tensor
    ///
    /// # Returns
    ///
    /// Logical memory footprint in bytes (element count x `size_of::<f32>()`).
    /// Note this does not include any padding the allocator may add beyond the
    /// logical element count.
    #[inline]
    #[track_caller]
    pub fn memory_footprint(&self) -> usize {
        self.shape.size() * std::mem::size_of::<f32>()
    }
1255
1256    /// Get a single element from the tensor at the specified indices
1257    ///
1258    /// # Arguments
1259    ///
1260    /// * `indices` - Multi-dimensional indices to access the element
1261    ///
1262    /// # Returns
1263    ///
1264    /// The value at the specified position
1265    ///
1266    /// # Panics
1267    ///
1268    /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1269    ///
1270    /// # Examples
1271    ///
1272    /// ```
1273    /// use train_station::Tensor;
1274    ///
1275    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1276    /// let value = tensor.get(&[0, 1]);
1277    /// assert_eq!(value, 2.0);
1278    /// ```
1279    #[track_caller]
1280    pub fn get(&self, indices: &[usize]) -> f32 {
1281        assert_eq!(
1282            indices.len(),
1283            self.shape().rank(),
1284            "Indices length must match tensor rank"
1285        );
1286
1287        // Check bounds
1288        for (i, &idx) in indices.iter().enumerate() {
1289            assert!(
1290                idx < self.shape().dims()[i],
1291                "Index {} out of bounds for dimension {}",
1292                idx,
1293                i
1294            );
1295        }
1296
1297        let offset = self.memory_offset(indices);
1298        unsafe { *self.as_ptr().add(offset) }
1299    }
1300
1301    /// Set a single element in the tensor at the specified indices
1302    ///
1303    /// # Arguments
1304    ///
1305    /// * `indices` - Multi-dimensional indices to set the element
1306    /// * `value` - The value to set
1307    ///
1308    /// # Panics
1309    ///
1310    /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1311    ///
1312    /// # Examples
1313    ///
1314    /// ```
1315    /// use train_station::Tensor;
1316    ///
1317    /// let mut tensor = Tensor::new(vec![2, 2]);
1318    /// tensor.set(&[0, 1], 42.0);
1319    /// assert_eq!(tensor.get(&[0, 1]), 42.0);
1320    /// ```
1321    #[track_caller]
1322    pub fn set(&mut self, indices: &[usize], value: f32) {
1323        assert_eq!(
1324            indices.len(),
1325            self.shape().rank(),
1326            "Indices length must match tensor rank"
1327        );
1328
1329        // Check bounds
1330        for (i, &idx) in indices.iter().enumerate() {
1331            assert!(
1332                idx < self.shape().dims()[i],
1333                "Index {} out of bounds for dimension {}",
1334                idx,
1335                i
1336            );
1337        }
1338
1339        let offset = self.memory_offset(indices);
1340        unsafe {
1341            *self.as_mut_ptr().add(offset) = value;
1342        }
1343    }
1344
1345    /// Returns a safe slice of the tensor's underlying data
1346    ///
1347    /// Provides safe access to the tensor's data without requiring unsafe pointer operations.
1348    /// This is the preferred way to access tensor data for reading values, comparisons,
1349    /// and other operations that don't require direct pointer manipulation.
1350    ///
1351    /// # Returns
1352    ///
1353    /// A slice containing all tensor elements in row-major order
1354    ///
1355    /// # Performance
1356    ///
1357    /// - **Zero-Cost**: Direct slice creation with no copying
1358    /// - **Cache-Friendly**: Sequential memory access pattern
1359    /// - **Safe**: No unsafe code required for basic data access
1360    ///
1361    /// # Examples
1362    ///
1363    /// ```
1364    /// use train_station::Tensor;
1365    ///
1366    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1367    /// let data = tensor.data();
1368    ///
1369    /// // Safe indexing and comparisons
1370    /// assert_eq!(data[0], 1.0);
1371    /// assert_eq!(data.len(), tensor.size());
1372    /// ```
1373    #[inline]
1374    #[track_caller]
1375    pub fn data(&self) -> &[f32] {
1376        if self.size() == 0 {
1377            return &[];
1378        }
1379        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }
1380    }
1381
    /// Returns a mutable slice of the tensor's underlying data
    ///
    /// Provides safe mutable access to the tensor's data without requiring unsafe
    /// pointer operations. Use this for in-place modifications of tensor values.
    ///
    /// # Returns
    ///
    /// A mutable slice containing all tensor elements in row-major order; empty
    /// for zero-sized tensors
    ///
    /// # Performance
    ///
    /// - **Zero-Cost**: Direct slice creation with no copying (unless a
    ///   copy-on-write detach is triggered; see below)
    /// - **Cache-Friendly**: Sequential memory access pattern
    /// - **Safe**: No unsafe code required for basic data modification
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let mut tensor = Tensor::new(vec![2, 2]);
    /// let data = tensor.data_mut();
    ///
    /// // Safe indexing for modification
    /// data[0] = 1.0;
    /// data[1] = 2.0;
    ///
    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
    /// assert_eq!(tensor.get(&[0, 1]), 2.0);
    /// ```
    #[inline]
    #[track_caller]
    pub fn data_mut(&mut self) -> &mut [f32] {
        if self.size() == 0 {
            // Zero-sized tensors carry a dangling pointer; return an empty slice
            // without touching it.
            return &mut [];
        }
        // Copy-on-write to protect existing views: detach from any shared
        // allocation BEFORE handing out mutable access.
        self.ensure_unique_allocation();
        // SAFETY: the allocation is now uniquely owned and holds `size()` valid
        // f32 elements; the returned borrow is tied to `&mut self`.
        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
    }
1422
1423    /// Extract scalar value from single-element tensor
1424    ///
1425    /// This method provides a convenient way to extract the scalar value from
1426    /// tensors that contain exactly one element. This is commonly used with
1427    /// element iterator results and scalar tensor operations.
1428    ///
1429    /// # Returns
1430    ///
1431    /// The scalar value contained in this tensor
1432    ///
1433    /// # Panics
1434    ///
1435    /// Panics if the tensor does not contain exactly one element
1436    ///
1437    /// # Examples
1438    ///
1439    /// ```
1440    /// use train_station::Tensor;
1441    ///
1442    /// // Single-element tensor
1443    /// let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
1444    /// assert_eq!(scalar.value(), 42.0);
1445    /// ```
1446    #[inline]
1447    #[track_caller]
1448    pub fn value(&self) -> f32 {
1449        assert_eq!(
1450            self.size(),
1451            1,
1452            "value() can only be called on tensors with exactly one element. \
1453             This tensor has {} elements with shape {:?}",
1454            self.size(),
1455            self.shape().dims()
1456        );
1457        self.data()[0]
1458    }
1459
    /// Create a view with a new shape (requires contiguous memory)
    ///
    /// Behaves like PyTorch `view`: tensor must be contiguous and the total
    /// number of elements must remain the same. Supports -1 inference for one dimension.
    ///
    /// # Arguments
    ///
    /// * `new_shape` - New shape for the tensor (can contain -1 for inference)
    ///
    /// # Returns
    ///
    /// A tensor viewing the same data with a new shape
    ///
    /// # Panics
    ///
    /// Panics if more than one -1 appears, if any other dimension is not strictly
    /// positive, if the element counts are incompatible, or if the underlying
    /// reshape fails (e.g., non-contiguous memory).
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
    /// let y = x.view(vec![2, 2]);
    /// assert_eq!(y.shape().dims(), vec![2, 2]);
    /// ```
    #[track_caller]
    pub fn view(&self, new_shape: Vec<i32>) -> Tensor {
        // PyTorch-like view with single -1 inference; requires contiguity
        let size = self.size();
        let mut infer_idx: Option<usize> = None;
        let mut product: usize = 1;
        let mut dims: Vec<usize> = Vec::with_capacity(new_shape.len());
        for (i, d) in new_shape.iter().enumerate() {
            if *d == -1 {
                assert!(infer_idx.is_none(), "Only one -1 is allowed in view shape");
                infer_idx = Some(i);
                // Placeholder extent; the true value is filled in after the scan.
                dims.push(1);
            } else {
                assert!(*d > 0, "Negative dims not supported in view shape");
                let du = *d as usize;
                // saturating_mul avoids overflow panics; a saturated product can
                // never equal `size`, so the size checks below still reject it.
                product = product.saturating_mul(du);
                dims.push(du);
            }
        }
        if let Some(pos) = infer_idx {
            assert!(
                product > 0 && size.is_multiple_of(product),
                "View shape incompatible with tensor size"
            );
            dims[pos] = size / product;
        } else {
            assert!(
                product == size,
                "View shape has different number of elements"
            );
        }

        let mut v = match crate::tensor::core::view::reshape_view(self, &dims) {
            Ok(v) => v,
            Err(e) => panic!("view reshape error: {:?}", e),
        };
        // Register the reshape in the gradtrack graph so gradients flowing into
        // the view are restored to the original shape on the way back.
        if self.requires_grad() && gradtrack::is_grad_enabled() {
            v.set_requires_grad(true);
            let grad_fn = GradFn::Reshape {
                original_shape: self.shape().dims().to_vec(),
            };
            v.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
        }
        v
    }
1528
1529    /// Create an element view for the specified index
1530    ///
1531    /// Returns a scalar tensor (shape \[1\]) that views a single element
1532    /// of the source tensor. Maintains gradient tracking.
1533    ///
1534    /// # Arguments
1535    ///
1536    /// * `index` - Linear index of the element to view
1537    ///
1538    /// # Returns
1539    ///
1540    /// A scalar tensor viewing the specified element
1541    ///
1542    /// # Examples
1543    ///
1544    /// ```
1545    /// use train_station::Tensor;
1546    ///
1547    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1548    /// let element = tensor.element_view(1);
1549    /// assert_eq!(element.value(), 2.0);
1550    /// ```
1551    #[track_caller]
1552    pub fn element_view(&self, index: usize) -> Tensor {
1553        let mut v = match crate::tensor::core::view::element_view_linear(self, index) {
1554            Ok(v) => v,
1555            Err(e) => panic!("element_view error: {:?}", e),
1556        };
1557        if self.requires_grad() && gradtrack::is_grad_enabled() {
1558            v.set_requires_grad(true);
1559            // Reuse SliceView grad for single-element selection to propagate back to source
1560            let grad_fn = GradFn::View {
1561                mapping: crate::gradtrack::grad_fn::ViewMapping::LinearRange {
1562                    start: index,
1563                    step: 1,
1564                    length: 1,
1565                },
1566                input_shape: self.shape().dims().to_vec(),
1567            };
1568            v.set_grad_fn(grad_fn.clone());
1569            GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1570        }
1571        v
1572    }
1573
1574    /// Create a slice view of the tensor
1575    ///
1576    /// Returns a view of a contiguous or strided slice of the source tensor.
1577    ///
1578    /// # Arguments
1579    ///
1580    /// * `start` - Starting index
1581    /// * `step` - Step size (1 for contiguous)
1582    /// * `length` - Number of elements
1583    ///
1584    /// # Returns
1585    ///
1586    /// A tensor viewing the specified slice
1587    ///
1588    /// # Examples
1589    ///
1590    /// ```
1591    /// use train_station::Tensor;
1592    ///
1593    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1594    /// let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
1595    /// assert_eq!(slice.get(&[0]), 2.0);
1596    /// assert_eq!(slice.get(&[1]), 4.0);
1597    /// ```
1598    #[track_caller]
1599    pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor {
1600        let mut v = match crate::tensor::core::view::slice_view_linear(self, start, step, length) {
1601            Ok(v) => v,
1602            Err(e) => panic!("slice_view error: {:?}", e),
1603        };
1604        // Ensure correct data() representation for stepped views
1605        // Offset-only (step==1, start>0) is already contiguous via base pointer
1606        if step != 1 {
1607            v = v.contiguous();
1608        }
1609        if self.requires_grad() && gradtrack::is_grad_enabled() {
1610            v.set_requires_grad(true);
1611            let grad_fn = GradFn::View {
1612                mapping: crate::gradtrack::grad_fn::ViewMapping::LinearRange {
1613                    start,
1614                    step,
1615                    length,
1616                },
1617                input_shape: self.shape().dims().to_vec(),
1618            };
1619            v.set_grad_fn(grad_fn.clone());
1620            GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1621        }
1622        v
1623    }
1624
    /// Create a tensor view from raw components
    ///
    /// Creates a tensor that views existing memory with the specified shape and device.
    /// The tensor shares memory with the original allocation through the allocation_owner.
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// - `data` is valid for the number of elements specified by `shape`
    /// - `data` remains valid for the lifetime of the returned tensor
    /// - `allocation_owner` properly manages the memory lifecycle
    ///
    /// # Arguments
    ///
    /// * `data` - Raw pointer to the tensor data (must be non-null)
    /// * `shape` - Shape of the tensor view
    /// * `device` - Device where the tensor is located
    /// * `allocation_owner` - Optional shared allocation owner
    ///
    /// # Returns
    ///
    /// A new tensor that views the specified memory
    ///
    /// # Panics
    ///
    /// Panics if `data` is null
    ///
    /// # Implementation Details
    ///
    /// This method is used internally to create tensor views that share memory
    /// with other tensors. It's primarily used for view operations and memory
    /// management.
    pub(crate) fn from_raw_view(
        data: *const f32,
        shape: Shape,
        device: Device,
        allocation_owner: Option<Arc<Allocation>>,
    ) -> Self {
        // Views get their own id so the gradtrack graph can distinguish them
        // from the tensor whose memory they share.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        Self {
            data: NonNull::new(data as *mut f32).expect("Data pointer cannot be null"),
            shape,
            device,
            id,
            // Gradient state starts disabled; callers opt in after construction.
            requires_grad: false,
            retain_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner,
            graph_group: None,
            _phantom: PhantomData,
        }
    }
1675
    /// Get the allocation owner for this tensor
    ///
    /// Returns the shared allocation owner if this tensor is a view,
    /// or None if this tensor owns its memory directly.
    ///
    /// # Returns
    ///
    /// Optional reference to the shared allocation owner (`None` also for
    /// zero-sized tensors, which carry no allocation)
    ///
    /// # Implementation Details
    ///
    /// This method is used internally to manage memory lifecycle for tensor views.
    /// It helps determine whether a tensor shares memory with another tensor.
    #[track_caller]
    pub fn allocation_owner(&self) -> Option<&Arc<Allocation>> {
        self.allocation_owner.as_ref()
    }
1693
    /// Create a new tensor with uninitialized memory
    ///
    /// This method allocates memory for a tensor without initializing it to any value.
    /// This is useful for performance-critical operations where the memory will be
    /// immediately overwritten, such as matrix multiplication results.
    ///
    /// # Safety
    ///
    /// The caller must ensure that all memory is written before reading from the tensor.
    /// Reading from uninitialized memory is undefined behavior.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - The dimensions of the tensor
    ///
    /// # Returns
    ///
    /// A tensor with uninitialized memory
    ///
    /// # Performance
    ///
    /// - **Zero Initialization**: Skips memory initialization for maximum performance
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Memory Efficient**: Uses optimized alignment strategies
    ///
    /// # Example
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create uninitialized tensor for matmul result
    /// let mut result = Tensor::new_uninitialized(vec![100, 100]);
    /// // Initialize the memory before use
    /// for value in result.data_mut() {
    ///     *value = 0.0;
    /// }
    /// ```
    #[inline]
    #[track_caller]
    pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self {
        let shape = Shape::new(shape_dims);
        // Unique id for gradtrack bookkeeping; Relaxed suffices for a counter.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size() == 0 {
            // Handle zero-sized tensors: no allocation, dangling (never
            // dereferenced) data pointer, no allocation owner.
            return Self {
                data: NonNull::dangling(),
                shape,
                device: current_device(),
                id,
                requires_grad: false,
                retain_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                graph_group: None,
                _phantom: PhantomData,
            };
        }

        // Alignment and element padding are size-dependent (see module docs on
        // 16/32/64-byte alignment tiers).
        let (alignment, padded_elems) = compute_allocation_params(shape.size());
        let total_size = padded_elems * std::mem::size_of::<f32>();
        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Allocate memory via shared Allocation (uninitialized)
        let alloc_obj = if use_pool_alloc_enabled() {
            Allocation::new_pooled(padded_elems, alignment, layout)
        } else {
            Allocation::new_uninitialized(padded_elems, alignment, layout)
        };
        let ptr = alloc_obj.ptr;

        debug_assert!(
            alloc_obj.capacity_elems() >= padded_elems,
            "Allocation capacity ({}) smaller than padded elements ({})",
            alloc_obj.capacity_elems(),
            padded_elems
        );

        Self {
            data: ptr,
            shape,
            device: current_device(),
            id,
            requires_grad: false,
            retain_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            // The tensor co-owns its allocation so views can extend its lifetime.
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            graph_group: None,
            _phantom: PhantomData,
        }
    }
1788
1789    /// Create a new uninitialized tensor with an explicit alignment request (in bytes)
1790    ///
1791    /// This is intended for internal high-performance paths (e.g., packed GEMM panels)
1792    /// where stronger alignment such as 64 bytes is desired even on AVX2 systems.
1793    #[inline]
1794    #[track_caller]
1795    pub fn new_uninitialized_aligned(shape_dims: Vec<usize>, alignment_bytes: usize) -> Self {
1796        let shape = Shape::new(shape_dims);
1797        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1798
1799        if shape.size() == 0 {
1800            return Self {
1801                data: NonNull::dangling(),
1802                shape,
1803                device: current_device(),
1804                id,
1805                requires_grad: false,
1806                retain_grad: false,
1807                grad: None,
1808                grad_fn: GradFn::None,
1809                allocation_owner: None,
1810                graph_group: None,
1811                _phantom: PhantomData,
1812            };
1813        }
1814
1815        // Preserve the usual element padding policy
1816        let (_default_align, padded_elems) = compute_allocation_params(shape.size());
1817        // Honor explicit alignment request (at least 16)
1818        let alignment = alignment_bytes.max(16);
1819        let total_size = padded_elems * std::mem::size_of::<f32>();
1820        let layout = Layout::from_size_align(total_size, alignment)
1821            .expect("Failed to create layout for tensor data (aligned)");
1822
1823        let alloc_obj = if use_pool_alloc_enabled() {
1824            Allocation::new_pooled(padded_elems, alignment, layout)
1825        } else {
1826            Allocation::new_uninitialized(padded_elems, alignment, layout)
1827        };
1828        let ptr = alloc_obj.ptr;
1829
1830        debug_assert!(alloc_obj.capacity_elems() >= padded_elems);
1831
1832        Self {
1833            data: ptr,
1834            shape,
1835            device: current_device(),
1836            id,
1837            requires_grad: false,
1838            retain_grad: false,
1839            grad: None,
1840            grad_fn: GradFn::None,
1841            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1842            graph_group: None,
1843            _phantom: PhantomData,
1844        }
1845    }
1846}
1847
#[cfg(test)]
mod memory_alloc_tests {
    //! Padding, alignment, and pool-counter tests for tensor allocation.
    //!
    //! NOTE(review): several tests read thread-local `TensorMemoryPool` statistics,
    //! so their assertions depend on the exact allocations performed within each
    //! test body — presumably counters are per test thread and start at zero;
    //! confirm against the pool implementation.

    use super::*;
    use crate::tensor::core::memory::{
        detect_runtime_simd, simd_alignment_bytes, simd_lane_width_elems, with_no_mem_padding,
        TensorMemoryPool,
    };

    /// A logical size that is not a lane multiple is padded up to a whole number
    /// of SIMD lanes, and the base pointer satisfies the runtime SIMD alignment.
    #[test]
    fn test_padding_and_alignment_pool_enabled() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let align = simd_alignment_bytes(detect_runtime_simd());

        let req = lane * 3 + 1; // force padding
        let t = Tensor::new(vec![req]);

        // Capacity is rounded up to a multiple of the SIMD lane width.
        assert_eq!(t.capacity_elems() % lane, 0);
        unsafe {
            // Base pointer honors the alignment required by the detected SIMD level.
            assert_eq!((t.as_ptr() as usize) % align, 0);
        }
        // Logical size is not a multiple of lane, so aligned SIMD ops are not preferred
        #[cfg(target_arch = "x86_64")]
        assert!(!t.prefer_aligned_simd_ops());

        drop(t);
    }

    /// With padding disabled via the scoped guard, the logical size is preserved
    /// exactly and aligned SIMD fast paths are not preferred.
    #[test]
    fn test_no_padding_with_guard() {
        with_no_mem_padding(|| {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            let req = lane * 3 + 1; // non-multiple
            let t = Tensor::new(vec![req]);

            assert_eq!(t.size(), req);
            #[cfg(target_arch = "x86_64")]
            assert!(!t.prefer_aligned_simd_ops());
        });
    }

    /// Allocation counters advance by exactly one per `Tensor::new` call.
    ///
    /// NOTE(review): the name says "system alloc, no pool" and "counters unchanged",
    /// but the body uses no no-pool guard and asserts the thread-local counters DO
    /// advance by 2 — verify the intended mode against `with_no_mem_pool`.
    #[test]
    fn test_system_alloc_no_pool_counters_unchanged() {
        let before = TensorMemoryPool::thread_stats();
        let _t1 = Tensor::new(vec![128]);
        let _t2 = Tensor::new(vec![257]);
        let after = TensorMemoryPool::thread_stats();
        assert_eq!(before.allocations, 0);
        assert_eq!(after.allocations, 2);
        assert_eq!(before.deallocations, 0);
    }

    /// Tensors dropped at end of scope pair every allocation with a deallocation
    /// across small/medium/large size classes.
    #[test]
    fn test_pool_alloc_and_dealloc_counters_match() {
        let before = TensorMemoryPool::thread_stats();
        {
            // One tensor per pool size class; all dropped when the scope ends.
            let _t1 = Tensor::new(vec![64]);
            let _t2 = Tensor::new(vec![2048]);
            let _t3 = Tensor::new(vec![131072]);
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(after.allocations >= before.allocations + 3);
        assert!(after.deallocations >= before.deallocations + 3);
    }

    /// After a scope that allocates and drops several tensors, allocation and
    /// deallocation deltas balance (no leaked pool buffers).
    #[test]
    fn test_mixed_modes_no_leak_in_pool_stats_scope() {
        let before = TensorMemoryPool::thread_stats();
        {
            let _a = Tensor::new(vec![1000]);
            let _b = Tensor::new(vec![1000]);
            // NOTE(review): the second `_c` shadows the first (likely a typo for
            // `_d`); both tensors still live until the end of this scope.
            let _c = Tensor::new(vec![2000]);
            let _c = Tensor::new(vec![2000]);
        }

        let after = TensorMemoryPool::thread_stats();
        assert_eq!(
            after.allocations - before.allocations,
            after.deallocations - before.deallocations
        );
    }

    /// Exact lane-multiple sizes enable the aligned-SIMD fast path; disabling
    /// padding on a non-multiple size forces the unaligned path.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_alignment_helpers_and_hints() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let req = lane * 4; // exact multiple
        let t = Tensor::new(vec![req]);
        assert!(t.is_aligned_for_runtime_level());
        assert!(t.prefer_aligned_simd_ops());
        assert!(!t.should_use_unaligned_simd_ops());

        with_no_mem_padding(|| {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            let req = lane * 4 + 1; // not multiple
            let t = Tensor::new(vec![req]);
            assert!(!t.prefer_aligned_simd_ops());
            assert!(t.should_use_unaligned_simd_ops());
        });
    }
}
1948
#[cfg(test)]
mod memory_alloc_additional_tests {
    //! Additional allocation tests: zero-sized tensors, capacity planning across
    //! size-class boundaries, view aliasing, and coarse performance sanity checks.

    use super::*;
    use crate::tensor::core::memory::{
        detect_runtime_simd, simd_lane_width_elems, with_no_mem_padding, with_no_mem_pool,
        TensorMemoryPool,
    };
    use std::time::Instant;

    /// Zero-sized tensors allocate no storage in either pooled or system mode.
    #[test]
    fn test_zero_size_tensors_pool_and_system() {
        // Pooled (default) mode.
        let t = Tensor::new(vec![0]);
        assert_eq!(t.size(), 0);
        assert_eq!(t.capacity_elems(), 0);
        assert_eq!(t.data().len(), 0);
        // System-allocator mode: the original ran the pooled case twice, which
        // never exercised the "and_system" half of this test's name.
        with_no_mem_pool(|| {
            let t = Tensor::new(vec![0]);
            assert_eq!(t.size(), 0);
            assert_eq!(t.capacity_elems(), 0);
            assert_eq!(t.data().len(), 0);
        });
    }

    /// Capacity is at least the lane-padded size (and lane-aligned) with padding
    /// on; with padding disabled the logical size is preserved exactly.
    #[test]
    fn test_capacity_across_various_sizes_padding_modes() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        // Sizes straddling lane boundaries plus a couple of "round" values.
        let sizes = [
            0,
            1,
            lane - 1,
            lane,
            lane + 1,
            2 * lane - 1,
            2 * lane + 1,
            1000,
            1001,
        ];

        for &n in &sizes {
            let t = Tensor::new(vec![n]);
            let expected_padded = if n == 0 { 0 } else { n.div_ceil(lane) * lane };
            // Pool may round up further than lane padding; ensure at least padded and lane-aligned
            assert!(
                t.capacity_elems() >= expected_padded,
                "n={} lane={} cap={}",
                n,
                lane,
                t.capacity_elems()
            );
            if t.capacity_elems() > 0 {
                assert_eq!(
                    t.capacity_elems() % lane,
                    0,
                    "capacity not lane-multiple: {}",
                    t.capacity_elems()
                );
            }
        }
        with_no_mem_padding(|| {
            for &n in &sizes {
                let t = Tensor::new(vec![n]);
                assert_eq!(t.size(), n);
            }
        });
    }

    /// Actual capacity matches the pool's planned capacity at every size-class
    /// boundary (small/medium/large, and one past each).
    #[test]
    fn test_class_boundary_planned_capacity_pool() {
        let boundaries = [
            crate::tensor::core::memory::SMALL_BUFFER_SIZE,
            crate::tensor::core::memory::SMALL_BUFFER_SIZE + 1,
            crate::tensor::core::memory::MEDIUM_BUFFER_SIZE,
            crate::tensor::core::memory::MEDIUM_BUFFER_SIZE + 1,
            crate::tensor::core::memory::LARGE_BUFFER_SIZE,
            crate::tensor::core::memory::LARGE_BUFFER_SIZE + 1,
        ];
        let lane = simd_lane_width_elems(detect_runtime_simd());
        for &n in &boundaries {
            // Mirror the allocator's lane-padding policy before asking the pool
            // what capacity it plans for that padded element count.
            let padded = if n == 0 { 0 } else { n.div_ceil(lane) * lane };
            let planned = TensorMemoryPool::planned_capacity_elems(padded);
            let t = Tensor::new(vec![n]);
            assert_eq!(t.capacity_elems(), planned, "boundary n={}", n);
        }
    }

    /// Views alias the parent tensor's buffer and must not allocate.
    #[test]
    fn test_view_does_not_allocate_additional_memory() {
        let before = TensorMemoryPool::thread_stats();
        let t = Tensor::new(vec![256]);
        let _view = t.view(vec![16, 16]);
        let after = TensorMemoryPool::thread_stats();
        // Exactly one allocation happened (for t), view didn't allocate
        assert_eq!(after.allocations, before.allocations + 1);
    }

    /// Coarse sanity check: pooled small allocations are not drastically slower
    /// than the system allocator (generous 10x bound to avoid CI flakiness).
    #[test]
    fn test_performance_pool_vs_system_small_allocs() {
        let iters = 1000;
        // Warm the pool so the timed loop measures steady-state reuse.
        let _ = Tensor::new(vec![64]);
        let _ = Tensor::new(vec![64]);
        let _ = Tensor::new(vec![64]);

        let pool_time = {
            let start = Instant::now();
            for _ in 0..iters {
                let _t = Tensor::new(vec![64]);
            }
            start.elapsed()
        };
        let sys_time = with_no_mem_pool(|| {
            let start = Instant::now();
            for _ in 0..iters {
                let _t = Tensor::new(vec![64]);
            }
            start.elapsed()
        });
        assert!(
            pool_time <= sys_time * 10,
            "pool {:?} vs system {:?}",
            pool_time,
            sys_time
        );
    }

    /// Coarse sanity check: filling a lane-padded tensor is not drastically
    /// slower than filling an unpadded one (generous 100x bound).
    #[test]
    fn test_performance_padded_vs_unpadded_fill() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let n = lane * 2048;

        let padded_time = {
            let t = Tensor::new(vec![n]);
            let mut x = t.clone();
            let start = Instant::now();
            x.fill(1.2345);
            start.elapsed()
        };
        let unpadded_time = with_no_mem_padding(|| {
            // n + 1 is not a lane multiple, so with padding disabled this buffer
            // cannot take the aligned SIMD fill path.
            let t = Tensor::new(vec![n + 1]);
            let mut x = t.clone();
            let start = Instant::now();
            x.fill(1.2345);
            start.elapsed()
        });
        assert!(
            padded_time <= unpadded_time * 100,
            "padded {:?} vs unpadded {:?}",
            padded_time,
            unpadded_time
        );
    }
}