// train_station/tensor/core/utils.rs

1//! Tensor utility functions and core implementation methods
2//!
3//! This module provides essential utility functions for tensor creation, memory management,
4//! gradient tracking, and optimization. It contains the core implementation methods
5//! that enable efficient tensor operations and gradtrack functionality.
6//!
7//! # Key Features
8//!
9//! - **Tensor Creation**: Optimized constructors with memory alignment
10//! - **Memory Management**: Safe memory access and allocation utilities
11//! - **Gradient Tracking**: GradTrack system integration and gradient management
12//! - **Performance Optimization**: SIMD-ready memory layout and alignment
13//! - **Device Management**: CPU and future CUDA device support
14//! - **Memory Layout**: Contiguous, strided, and view memory access patterns
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Alignment**: 16-byte SSE, 32-byte AVX2, 64-byte cache-line alignment
19//! - **SIMD Optimization**: Properly aligned memory for vectorized operations
20//! - **Zero-Cost Abstractions**: Minimal overhead for utility operations
21//! - **Thread Safety**: Atomic operations for gradient tracking and ID generation
22//! - **Memory Efficiency**: Optimized allocation strategies for different tensor sizes
23//!
24//! # Examples
25//!
26//! ## Basic Tensor Creation
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors of different sizes
32//! let small_tensor = Tensor::new(vec![2, 3]);      // 16-byte alignment
33//! let medium_tensor = Tensor::new(vec![32, 32]);   // 32-byte alignment
34//! let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
35//!
36//! // Initialize data before use
37//! let mut tensor = Tensor::new(vec![2, 3]);
38//! tensor.fill(0.0); // Initialize with zeros
39//! ```
40//!
41//! ## Gradient Tracking
42//!
43//! ```
44//! use train_station::Tensor;
45//!
46//! let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
47//! assert!(tensor.requires_grad());
48//! ```
49//!
50//! ## Memory Access
51//!
52//! ```
53//! use train_station::Tensor;
54//!
55//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
56//! let value = tensor.get(&[0, 1]);
57//! assert_eq!(value, 2.0);
58//!
59//! let mut tensor = Tensor::new(vec![2, 2]);
60//! tensor.set(&[0, 1], 42.0);
61//! assert_eq!(tensor.get(&[0, 1]), 42.0);
62//! ```
63
64use std::alloc::Layout;
65use std::marker::PhantomData;
66use std::ptr::NonNull;
67use std::sync::atomic::Ordering;
68use std::sync::Arc;
69
70use crate::device::current_device;
71use crate::gradtrack::{self, GradEngine, GradFn};
72use crate::tensor::core::{Allocation, Device, TENSOR_ID_COUNTER};
73use crate::tensor::Shape;
74
75use super::Tensor;
76
77impl Tensor {
78    /// Creates a new tensor with the specified shape and optimized memory layout
79    ///
80    /// Allocates memory with size-dependent alignment for optimal performance:
81    /// - Small tensors (≤8 elements): 16-byte SSE alignment
82    /// - Medium tensors (8-1024 elements): 32-byte AVX2 alignment
83    /// - Large tensors (>1024 elements): 64-byte cache-line alignment
84    ///
85    /// # Arguments
86    ///
87    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
88    ///
89    /// # Returns
90    ///
91    /// A new tensor with uninitialized data. The data must be initialized
92    /// before use to avoid undefined behavior.
93    ///
94    /// # Performance
95    ///
96    /// - **Memory Allocation**: Single allocation with optimized alignment
97    /// - **SIMD Ready**: Properly aligned for vectorized operations
98    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
99    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
100    ///
101    /// # Safety
102    ///
103    /// The returned tensor contains uninitialized memory. You must initialize
104    /// the data before performing any operations that read from it.
105    ///
106    /// # Examples
107    ///
108    /// ```
109    /// use train_station::Tensor;
110    ///
111    /// // Create tensors of different sizes
112    /// let small_tensor = Tensor::new(vec![2, 3]);      // 16-byte alignment
113    /// let medium_tensor = Tensor::new(vec![32, 32]);   // 32-byte alignment
114    /// let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
115    ///
116    /// // Initialize data before use
117    /// let mut tensor = Tensor::new(vec![2, 3]);
118    /// tensor.fill(0.0); // Initialize with zeros
119    /// ```
120    #[inline]
121    #[track_caller]
122    pub fn new(shape_dims: Vec<usize>) -> Self {
123        let shape = Shape::new(shape_dims);
124        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
125
126        if shape.size == 0 {
127            // Handle zero-sized tensors
128            return Self {
129                data: NonNull::dangling(),
130                shape,
131                device: current_device(),
132                id,
133                requires_grad: false,
134                grad: None,
135                grad_fn: GradFn::None,
136                allocation_owner: None,
137                _phantom: PhantomData,
138            };
139        }
140
141        // Optimized layout calculation for better cache performance
142        let element_size = std::mem::size_of::<f32>();
143        let total_size = shape.size * element_size;
144
145        // Use cache line alignment for large tensors, smaller alignment for small ones
146        let alignment = if total_size > 4096 {
147            64 // Cache line alignment for large tensors
148        } else if shape.size >= 8 {
149            32 // AVX2 alignment for medium tensors
150        } else {
151            16 // SSE alignment for small tensors
152        };
153
154        let layout = Layout::from_size_align(total_size, alignment)
155            .expect("Failed to create layout for tensor data");
156
157        // Allocate memory via shared Allocation
158        let alloc_obj = Allocation::new(shape.size, alignment, layout);
159        let ptr = alloc_obj.ptr;
160
161        Self {
162            data: ptr,
163            shape,
164            device: current_device(),
165            id,
166            requires_grad: false,
167            grad: None,
168            grad_fn: GradFn::None,
169            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
170            _phantom: PhantomData,
171        }
172    }
173
    /// Returns the shape and dimensional information of the tensor
    ///
    /// Provides access to the tensor's dimensions, size, strides, and memory
    /// layout information. This is used for shape validation, memory access
    /// calculations, and optimization decisions.
    ///
    /// # Returns
    ///
    /// Reference to the tensor's shape information containing dimensions,
    /// size, strides, and memory layout type.
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: O(1) - direct field access
    /// - **Memory**: No allocation - returns reference to existing data
    /// - **Borrowing**: The returned reference borrows `self`; no copy is made
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// let shape = tensor.shape();
    /// assert_eq!(shape.dims, vec![2, 3, 4]);
    /// assert_eq!(shape.size, 24);
    /// assert_eq!(shape.rank(), 3);
    /// ```
    #[inline]
    #[track_caller]
    pub fn shape(&self) -> &Shape {
        &self.shape
    }
206
    /// Returns the total number of elements in the tensor
    ///
    /// Provides the total count of elements across all dimensions. This is
    /// used for memory allocation, iteration bounds, and performance optimization.
    ///
    /// # Returns
    ///
    /// Total number of elements as `usize`
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: O(1) - direct field access
    /// - **Memory**: No allocation - returns stored value
    /// - **Cached**: The element count is stored on `Shape`, not recomputed
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// assert_eq!(tensor.size(), 24); // 2 * 3 * 4
    ///
    /// let scalar = Tensor::new(vec![1]);
    /// assert_eq!(scalar.size(), 1);
    ///
    /// let empty = Tensor::new(vec![0]);
    /// assert_eq!(empty.size(), 0);
    /// ```
    #[inline]
    #[track_caller]
    pub fn size(&self) -> usize {
        self.shape.size
    }
240
    /// Returns the device where this tensor is located
    ///
    /// Provides the physical location of the tensor data (CPU/GPU). This
    /// determines which operations can be performed on the tensor and where
    /// computations will be executed.
    ///
    /// # Returns
    ///
    /// Device enum indicating the tensor's physical location, returned by
    /// value (`Device` is a small copyable type).
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: O(1) - direct field access
    /// - **Memory**: No allocation - returns stored value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3]);
    /// assert!(tensor.device().is_cpu());
    /// assert!(!tensor.device().is_cuda());
    /// ```
    #[inline]
    #[track_caller]
    pub fn device(&self) -> Device {
        self.device
    }
270
271    /// Creates a new tensor with the specified shape on a specific device
272    ///
273    /// Allocates memory on the specified device with the same optimized alignment
274    /// strategy as `new()`. Currently supports CPU device with future CUDA support.
275    ///
276    /// # Arguments
277    ///
278    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
279    /// * `device` - The device where the tensor should be allocated
280    ///
281    /// # Returns
282    ///
283    /// A new tensor with uninitialized data on the specified device
284    ///
285    /// # Performance
286    ///
287    /// - **Memory Allocation**: Device-specific allocation with optimized alignment
288    /// - **SIMD Ready**: Properly aligned for vectorized operations on target device
289    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
290    ///
291    /// # Panics
292    ///
293    /// Panics if the specified device is not supported (e.g., CUDA without feature flag)
294    ///
295    /// # Examples
296    ///
297    /// ```
298    /// use train_station::Tensor;
299    ///
300    /// let tensor = Tensor::new_on_device(vec![2, 3], train_station::Device::cpu());
301    /// assert!(tensor.device().is_cpu());
302    /// assert_eq!(tensor.size(), 6);
303    /// ```
304    ///
305    /// # Arguments
306    ///
307    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
308    /// * `device` - Device where the tensor should be allocated
309    ///
310    /// # Returns
311    ///
312    /// A new tensor with uninitialized data on the specified device
313    ///
314    /// # Panics
315    ///
316    /// Panics if the device is not supported (currently only CPU is supported)
317    ///
318    /// # Performance
319    ///
320    /// - **Memory Allocation**: Single allocation with optimized alignment
321    /// - **SIMD Ready**: Properly aligned for vectorized operations
322    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
323    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
324    ///
325    /// # Examples
326    ///
327    /// ```
328    /// use train_station::{Tensor, Device};
329    ///
330    /// // Create tensor on CPU device
331    /// let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
332    /// assert_eq!(tensor.device(), Device::cpu());
333    /// assert_eq!(tensor.size(), 6);
334    /// ```
335    #[track_caller]
336    pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self {
337        // For now, only CPU is supported
338        if !device.is_cpu() {
339            panic!("Only CPU device is currently supported. CUDA support is planned for future releases.");
340        }
341
342        let shape = Shape::new(shape_dims);
343        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
344
345        if shape.size == 0 {
346            // Handle zero-sized tensors
347            return Self {
348                data: NonNull::dangling(),
349                shape,
350                device,
351                id,
352                requires_grad: false,
353                grad: None,
354                grad_fn: GradFn::None,
355                allocation_owner: None,
356                _phantom: PhantomData,
357            };
358        }
359
360        // Optimized layout calculation for better cache performance
361        let element_size = std::mem::size_of::<f32>();
362        let total_size = shape.size * element_size;
363
364        // Use cache line alignment for large tensors, smaller alignment for small ones
365        let alignment = if total_size > 4096 {
366            64 // Cache line alignment for large tensors
367        } else if shape.size >= 8 {
368            32 // AVX2 alignment for medium tensors
369        } else {
370            16 // SSE alignment for small tensors
371        };
372
373        let layout = Layout::from_size_align(total_size, alignment)
374            .expect("Failed to create layout for tensor data");
375
376        // Allocate memory via shared Allocation
377        let alloc_obj = Allocation::new(shape.size, alignment, layout);
378        let ptr = alloc_obj.ptr;
379
380        Self {
381            data: ptr,
382            shape,
383            device,
384            id,
385            requires_grad: false,
386            grad: None,
387            grad_fn: GradFn::None,
388            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
389            _phantom: PhantomData,
390        }
391    }
392
    /// Enable gradient computation for this tensor
    ///
    /// Builder method that enables automatic gradient tracking for this tensor.
    /// When enabled, all operations involving this tensor will be recorded in
    /// the computation graph for gradient computation during backward pass.
    ///
    /// Note: this only sets the tracking flag; no gradient storage is
    /// allocated, so `grad()` remains `None` until a backward pass runs.
    ///
    /// # Returns
    ///
    /// `self` with gradient tracking enabled (consumes and returns the tensor)
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: O(1) - simple field assignment
    /// - **Memory**: No additional allocation
    /// - **Overhead**: Minimal gradtrack tracking overhead when gradients computed
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// assert!(tensor.requires_grad());
    /// ```
    #[track_caller]
    pub fn with_requires_grad(mut self) -> Self {
        self.requires_grad = true;
        self
    }
422
423    /// Set gradient tracking for this tensor
424    ///
425    /// Controls whether the gradtrack system tracks operations on this tensor
426    /// and computes gradients during backward pass. When disabled, clears
427    /// any existing gradients and gradient functions.
428    ///
429    /// # Arguments
430    ///
431    /// * `requires_grad` - Whether to track gradients for this tensor
432    ///
433    /// # Performance
434    ///
435    /// - **Time Complexity**: O(1) - simple field assignment
436    /// - **Memory**: May free gradient storage when disabled
437    /// - **Overhead**: Zero overhead when gradients disabled
438    ///
439    /// # Examples
440    ///
441    /// ```
442    /// use train_station::Tensor;
443    ///
444    /// let mut tensor = Tensor::ones(vec![2, 3]);
445    /// tensor.set_requires_grad(true);
446    /// assert!(tensor.requires_grad());
447    ///
448    /// // Disable gradient tracking
449    /// tensor.set_requires_grad(false);
450    /// assert!(!tensor.requires_grad());
451    /// ```
452    #[track_caller]
453    pub fn set_requires_grad(&mut self, requires_grad: bool) {
454        self.requires_grad = requires_grad;
455        if !requires_grad {
456            self.grad = None;
457            self.grad_fn = GradFn::None;
458        }
459    }
460
    /// Check if this tensor requires gradients
    ///
    /// # Returns
    ///
    /// `true` if gradient tracking is enabled for this tensor
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: O(1) - direct field access
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3]);
    /// assert!(!tensor.requires_grad());
    ///
    /// let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// assert!(grad_tensor.requires_grad());
    /// ```
    #[track_caller]
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }
482
483    /// Get the accumulated gradients (if any)
484    ///
485    /// Returns a reference to the gradient tensor if gradients have been computed
486    /// and this tensor has gradient tracking enabled.
487    ///
488    /// # Returns
489    ///
490    /// Optional reference to the gradient tensor, or `None` if no gradients exist
491    ///
492    /// # Examples
493    ///
494    /// ```
495    /// use train_station::Tensor;
496    ///
497    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
498    /// assert!(tensor.grad().is_none()); // No gradients computed yet
499    /// ```
500    #[track_caller]
501    pub fn grad(&self) -> Option<&Tensor> {
502        // First check if we have a gradient stored directly
503        if let Some(grad) = self.grad.as_ref() {
504            return Some(grad.as_ref());
505        }
506
507        // If not, check the gradient map for accumulated gradients
508        if let Some(_grad) = gradtrack::get_accumulated_gradient(self.id) {
509            // For simplicity, we'll return None here since we can't return a reference
510            // to a temporary value. In a full implementation, we'd store it in self.grad
511            return None;
512        }
513
514        None
515    }
516
517    /// Get accumulated gradient by value (helper for testing)
518    ///
519    /// Returns the gradient tensor by value, which is useful for testing and
520    /// when you need to own the gradient data.
521    ///
522    /// # Returns
523    ///
524    /// Optional gradient tensor, or `None` if no gradients exist
525    ///
526    /// # Examples
527    ///
528    /// ```
529    /// use train_station::Tensor;
530    ///
531    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
532    /// assert!(tensor.grad_by_value().is_none()); // No gradients computed yet
533    /// ```
534    #[track_caller]
535    pub fn grad_by_value(&self) -> Option<Tensor> {
536        // First check if we have a gradient stored directly
537        if let Some(grad) = self.grad.as_ref() {
538            return Some((**grad).clone());
539        }
540
541        // If not, check the gradient map for accumulated gradients
542        use crate::gradtrack;
543        gradtrack::get_accumulated_gradient(self.id)
544    }
545
    /// Get the unique ID of this tensor
    ///
    /// Returns the unique identifier assigned to this tensor during creation.
    /// IDs are drawn from a global atomic counter, so every tensor (including
    /// copies made by `detach`) gets a distinct ID. This ID is used for
    /// gradtrack tracking and tensor identification.
    ///
    /// # Returns
    ///
    /// Unique tensor ID as `usize`
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor1 = Tensor::new(vec![2, 3]);
    /// let tensor2 = Tensor::new(vec![2, 3]);
    /// assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
    /// ```
    #[track_caller]
    pub fn id(&self) -> usize {
        self.id
    }
568
    /// Detach this tensor from the computation graph
    ///
    /// Returns a new tensor (with a fresh ID) holding a copy of this tensor's
    /// data but no gradient tracking. This is useful when you want to use a
    /// tensor in inference without affecting the computation graph.
    ///
    /// # Returns
    ///
    /// A new tensor with the same data but gradient tracking disabled
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// let detached = tensor.detach();
    /// assert!(!detached.requires_grad());
    /// assert_eq!(tensor.size(), detached.size());
    /// ```
    #[track_caller]
    pub fn detach(&self) -> Self {
        let mut detached = Self::new(self.shape.dims.clone());

        // Copy data element-for-element into the freshly allocated buffer.
        // SAFETY: both buffers hold exactly `self.size()` f32 elements and
        // never overlap (the destination was just allocated).
        // NOTE(review): this assumes the source data is contiguous starting at
        // the base pointer — confirm callers never detach a non-contiguous view.
        unsafe {
            let src = self.as_ptr();
            let dst = detached.as_mut_ptr();
            std::ptr::copy_nonoverlapping(src, dst, self.size());
        }

        detached
    }
602
    /// Detach this tensor from the computation graph in place
    ///
    /// Similar to `detach()` but modifies this tensor directly instead of
    /// returning a copy: gradient tracking is disabled and any stored
    /// gradient and gradient function are dropped. No data is copied.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// assert!(tensor.requires_grad());
    /// tensor.detach_();
    /// assert!(!tensor.requires_grad());
    /// ```
    #[track_caller]
    pub fn detach_(&mut self) {
        self.requires_grad = false;
        self.grad = None;
        self.grad_fn = GradFn::None;
    }
625
    /// Entry point for backward pass on this tensor
    ///
    /// Computes gradients for all tensors in the computation graph that have
    /// `requires_grad` set to true, by delegating to the gradtrack engine.
    /// This is the main entry point for automatic differentiation.
    ///
    /// # Arguments
    ///
    /// * `grad_output` - Optional gradient tensor for the output. If `None`,
    ///   the engine treats the tensor as a scalar (e.g. a loss value) and
    ///   seeds the backward pass accordingly.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// let mut result = tensor.add_scalar(5.0);
    /// result.backward(None);
    /// // Note: Gradient computation depends on the gradtrack system implementation
    /// ```
    #[track_caller]
    pub fn backward(&mut self, grad_output: Option<Tensor>) {
        GradEngine::backward(self, grad_output);
    }
651
    /// Returns a raw pointer to the tensor data for unsafe operations
    ///
    /// # Safety
    ///
    /// This is unsafe because it provides direct access to the underlying memory.
    /// The caller must ensure:
    /// - The tensor is not dropped while the pointer is used
    /// - No concurrent mutable access occurs
    /// - Bounds are respected
    /// - The pointer of a zero-sized tensor is dangling and must never be
    ///   dereferenced
    #[inline]
    pub unsafe fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }
665
    /// Returns a mutable raw pointer to the tensor data for unsafe operations
    ///
    /// # Safety
    ///
    /// This is unsafe because it provides direct mutable access to the underlying memory.
    /// The caller must ensure:
    /// - The tensor is not dropped while the pointer is used
    /// - No concurrent access occurs
    /// - Bounds are respected
    /// - The pointer of a zero-sized tensor is dangling and must never be
    ///   dereferenced
    #[inline]
    pub unsafe fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_ptr()
    }
679
    /// Internal method to set gradient function (used by operations)
    ///
    /// Sets the gradient function for this tensor. This is used internally
    /// by tensor operations to record the computation graph for gradtrack.
    ///
    /// # Arguments
    ///
    /// * `grad_fn` - The gradient function to set
    ///
    /// # Implementation Details
    ///
    /// This method is called by tensor operations to register the gradient
    /// computation function. It silently does nothing when gradient tracking
    /// is disabled for this tensor, so untracked tensors never carry a
    /// gradient function.
    pub(crate) fn set_grad_fn(&mut self, grad_fn: GradFn) {
        if self.requires_grad {
            self.grad_fn = grad_fn;
        }
    }
699
    /// Get a reference to the gradient function (for gradtrack)
    ///
    /// Returns a reference to the gradient function associated with this tensor.
    /// This is used internally by the gradtrack system to compute gradients.
    /// Tensors that do not participate in a tracked computation hold
    /// `GradFn::None`.
    ///
    /// # Returns
    ///
    /// Reference to the gradient function
    ///
    /// # Implementation Details
    ///
    /// This method is used by the gradtrack engine to access the gradient
    /// computation function during backward pass.
    #[track_caller]
    pub fn grad_fn(&self) -> &GradFn {
        &self.grad_fn
    }
717
    /// Internal method to set requires_grad (used by gradtrack operations)
    ///
    /// Sets the gradient tracking flag for this tensor. This is used internally
    /// by gradtrack operations to control gradient computation.
    ///
    /// # Arguments
    ///
    /// * `requires_grad` - Whether to enable gradient tracking
    ///
    /// # Implementation Details
    ///
    /// Unlike the public `set_requires_grad`, this method only flips the flag:
    /// it does NOT clear `grad` or `grad_fn` when disabling, so the gradtrack
    /// system can toggle tracking without losing gradient state.
    pub(crate) fn set_requires_grad_internal(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }
734
735    /// Internal method to accumulate gradients with optimized operations
736    ///
737    /// Accumulates a gradient tensor with any existing gradients for this tensor.
738    /// This is used internally by the gradtrack system to handle gradient accumulation
739    /// during backward pass.
740    ///
741    /// # Arguments
742    ///
743    /// * `grad` - The gradient tensor to accumulate
744    ///
745    /// # Implementation Details
746    ///
747    /// This method is used internally by the gradtrack engine to accumulate
748    /// gradients from multiple backward passes or operations. It only accumulates
749    /// gradients if gradient tracking is enabled for this tensor.
750    pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
751        if !self.requires_grad {
752            return;
753        }
754
755        match &self.grad {
756            Some(existing_grad) => {
757                // Use optimized tensor addition but create new tensor for safety
758                let accumulated = existing_grad.add_tensor_optimized(&grad);
759                self.grad = Some(Arc::new(accumulated));
760            }
761            None => {
762                self.grad = Some(Arc::new(grad));
763            }
764        }
765    }
766
    /// Set gradient from external source
    ///
    /// Sets the gradient tensor for this tensor, replacing any existing
    /// gradient. This is used internally by the gradtrack system to set
    /// gradients during backward pass.
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient tensor to set
    ///
    /// # Implementation Details
    ///
    /// The gradient is wrapped in an `Arc` for shared ownership. The call is
    /// silently ignored when gradient tracking is disabled for this tensor.
    #[track_caller]
    pub fn set_grad(&mut self, grad: Tensor) {
        if self.requires_grad {
            self.grad = Some(Arc::new(grad));
        }
    }
787
    /// Clear accumulated gradients for this tensor
    ///
    /// This method is used by optimizers to zero gradients before each backward pass.
    /// It clears the gradient stored directly on this tensor, allowing for fresh
    /// gradient computation. Note that it does not touch the gradtrack
    /// accumulation map, only the tensor-local gradient.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// tensor.set_grad(Tensor::ones(vec![2, 3]));
    /// assert!(tensor.grad().is_some());
    /// tensor.zero_grad();
    /// assert!(tensor.grad().is_none());
    /// ```
    #[track_caller]
    pub fn zero_grad(&mut self) {
        self.grad = None;
    }
808
    /// Negate all elements in the tensor in-place
    ///
    /// This is used internally for gradient computation in subtraction operations.
    /// For tensor - tensor operations, the second operand gets a negated gradient.
    ///
    /// # Implementation Details
    ///
    /// On x86_64 with AVX2 available at runtime, dispatches to the SIMD
    /// implementation; otherwise falls back to a scalar loop.
    ///
    /// NOTE(review): confirm `self.data` satisfies the alignment expectations
    /// of the SIMD path for tensors whose data pointer may be offset into a
    /// shared allocation (views).
    #[inline]
    pub(crate) fn negate_inplace(&mut self) {
        // Zero-sized tensors have a dangling data pointer; bail out before
        // touching it.
        if self.shape.size == 0 {
            return;
        }

        unsafe {
            let ptr = self.data.as_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.negate_simd_avx2(ptr);
                    return;
                }
            }

            // Fallback to scalar operations
            for i in 0..self.shape.size {
                *ptr.add(i) = -*ptr.add(i);
            }
        }
    }
843
844    /// SIMD-optimized negation using AVX2 instructions
845    ///
846    /// Performs in-place negation of tensor elements using AVX2 SIMD instructions
847    /// for improved performance on x86_64 architectures.
848    ///
849    /// # Arguments
850    ///
851    /// * `ptr` - Raw pointer to the tensor data
852    ///
853    /// # Safety
854    ///
855    /// The caller must ensure:
856    /// - `ptr` is a valid pointer to tensor data
857    /// - The tensor size matches the actual data size
858    /// - The tensor is not moved or dropped during this operation
859    ///
860    /// # Implementation Details
861    ///
862    /// This method is used internally by `negate_inplace` when AVX2 is available.
863    /// It processes 8 elements per iteration using AVX2 instructions.
864    #[cfg(target_arch = "x86_64")]
865    #[inline]
866    unsafe fn negate_simd_avx2(&self, ptr: *mut f32) {
867        use std::arch::x86_64::_mm256_setzero_ps;
868
869        let size = self.shape.size;
870        let zero_vec = _mm256_setzero_ps();
871        let simd_count = size / 8; // Process 8 elements per iteration
872        let mut offset = 0;
873
874        // SIMD loop for negation
875        for _ in 0..simd_count {
876            use std::arch::x86_64::{_mm256_load_ps, _mm256_store_ps, _mm256_sub_ps};
877
878            let vec = _mm256_load_ps(ptr.add(offset));
879            let neg_vec = _mm256_sub_ps(zero_vec, vec);
880            _mm256_store_ps(ptr.add(offset), neg_vec);
881            offset += 8;
882        }
883
884        // Handle remaining elements
885        for i in offset..size {
886            *ptr.add(i) = -*ptr.add(i);
887        }
888    }
889
890    // ===== Memory Layout and Optimization API =====
891
    /// Checks if the tensor data is stored contiguously in memory
    ///
    /// Delegates to the shape's layout information.
    ///
    /// # Returns
    ///
    /// `true` if the tensor data is contiguous, enabling optimized SIMD operations
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// assert!(tensor.is_contiguous());
    /// ```
    #[inline]
    #[track_caller]
    pub fn is_contiguous(&self) -> bool {
        self.shape.is_contiguous()
    }
911
912    /// Checks if this tensor is a view of another tensor
913    ///
914    /// # Returns
915    ///
916    /// `true` if this tensor is a view (non-contiguous reference)
917    #[inline]
918    #[track_caller]
919    pub fn is_view(&self) -> bool {
920        self.shape.is_view()
921    }
922
923    /// Gets the memory strides for all dimensions
924    ///
925    /// # Returns
926    ///
927    /// Reference to the stride vector for efficient memory access calculations
928    ///
929    /// # Examples
930    ///
931    /// ```
932    /// use train_station::Tensor;
933    ///
934    /// let tensor = Tensor::new(vec![2, 3, 4]);
935    /// assert_eq!(tensor.strides(), &[12, 4, 1]);
936    /// ```
937    #[inline]
938    #[track_caller]
939    pub fn strides(&self) -> &[usize] {
940        self.shape.strides()
941    }
942
943    /// Gets the memory stride for a specific dimension
944    ///
945    /// # Arguments
946    ///
947    /// * `dim` - The dimension index
948    ///
949    /// # Returns
950    ///
951    /// The memory stride for the given dimension
952    ///
953    /// # Panics
954    ///
955    /// Panics if `dim` is out of bounds
956    #[inline]
957    #[track_caller]
958    pub fn stride(&self, dim: usize) -> usize {
959        self.shape.stride(dim)
960    }
961
962    /// Gets the memory layout type for optimization decisions
963    ///
964    /// # Returns
965    ///
966    /// Reference to the memory layout information
967    #[inline]
968    #[track_caller]
969    pub fn layout(&self) -> &crate::tensor::MemoryLayout {
970        self.shape.layout()
971    }
972
973    /// Calculates the linear memory offset for given multi-dimensional indices
974    ///
975    /// # Arguments
976    ///
977    /// * `indices` - Vector of indices for each dimension
978    ///
979    /// # Returns
980    ///
981    /// Linear memory offset for direct memory access
982    ///
983    /// # Examples
984    ///
985    /// ```
986    /// use train_station::Tensor;
987    ///
988    /// let tensor = Tensor::new(vec![2, 3, 4]);
989    /// let offset = tensor.memory_offset(&[1, 2, 3]);
990    /// // offset = 1*12 + 2*4 + 3*1 = 23
991    /// ```
992    #[inline]
993    #[track_caller]
994    pub fn memory_offset(&self, indices: &[usize]) -> usize {
995        self.shape.offset(indices)
996    }
997
998    /// Broadcast this tensor with another tensor for element-wise operations
999    ///
1000    /// Returns a tuple containing:
1001    /// - Broadcasted view of self
1002    /// - Broadcasted view of other
1003    /// - Result shape for the operation
1004    ///
1005    /// # Arguments
1006    ///
1007    /// * `other` - The tensor to broadcast with
1008    ///
1009    /// # Returns
1010    ///
1011    /// A tuple `(broadcasted_self, broadcasted_other, result_shape)`
1012    ///
1013    /// # Examples
1014    ///
1015    /// ```
1016    /// use train_station::Tensor;
1017    ///
1018    /// let a = Tensor::ones(vec![2, 1, 4]);
1019    /// let b = Tensor::ones(vec![3, 1]);
1020    /// let result = a.broadcast_with(&b);
1021    /// assert!(result.is_ok());
1022    /// ```
1023    #[track_caller]
1024    pub fn broadcast_with(
1025        &self,
1026        other: &Tensor,
1027    ) -> Result<
1028        (Tensor, Tensor, crate::tensor::Shape),
1029        crate::tensor::ops::broadcasting::BroadcastError,
1030    > {
1031        crate::tensor::ops::broadcasting::broadcast_shapes(self, other)
1032    }
1033
1034    /// Checks if the tensor data is properly aligned for SIMD operations
1035    ///
1036    /// # Returns
1037    ///
1038    /// `true` if the tensor data is aligned to 32-byte boundaries for AVX2
1039    #[inline]
1040    #[track_caller]
1041    pub fn is_simd_aligned(&self) -> bool {
1042        (self.data.as_ptr() as usize) % 32 == 0
1043    }
1044
1045    /// Gets the memory alignment of the tensor data
1046    ///
1047    /// # Returns
1048    ///
1049    /// The memory alignment in bytes (typically 32 for SIMD optimization)
1050    #[inline]
1051    #[track_caller]
1052    pub fn memory_alignment(&self) -> usize {
1053        // Our tensors are allocated with 32-byte alignment for AVX2
1054        32
1055    }
1056
1057    /// Checks if this tensor is broadcastable with another tensor
1058    ///
1059    /// # Arguments
1060    ///
1061    /// * `other` - The other tensor to check broadcasting compatibility
1062    ///
1063    /// # Returns
1064    ///
1065    /// `true` if the tensors are broadcastable according to NumPy broadcasting rules
1066    ///
1067    /// # Examples
1068    ///
1069    /// ```
1070    /// use train_station::Tensor;
1071    ///
1072    /// let a = Tensor::new(vec![2, 3, 4]);
1073    /// let b = Tensor::new(vec![1, 3, 4]);
1074    /// assert!(a.is_broadcastable_with(&b));
1075    /// ```
1076    #[inline]
1077    #[track_caller]
1078    pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
1079        self.shape.is_broadcastable_with(&other.shape)
1080    }
1081
1082    /// Gets the total number of bytes allocated for this tensor
1083    ///
1084    /// # Returns
1085    ///
1086    /// Total memory footprint in bytes
1087    #[inline]
1088    #[track_caller]
1089    pub fn memory_footprint(&self) -> usize {
1090        self.shape.size * std::mem::size_of::<f32>()
1091    }
1092
1093    /// Get a single element from the tensor at the specified indices
1094    ///
1095    /// # Arguments
1096    ///
1097    /// * `indices` - Multi-dimensional indices to access the element
1098    ///
1099    /// # Returns
1100    ///
1101    /// The value at the specified position
1102    ///
1103    /// # Panics
1104    ///
1105    /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1106    ///
1107    /// # Examples
1108    ///
1109    /// ```
1110    /// use train_station::Tensor;
1111    ///
1112    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1113    /// let value = tensor.get(&[0, 1]);
1114    /// assert_eq!(value, 2.0);
1115    /// ```
1116    #[track_caller]
1117    pub fn get(&self, indices: &[usize]) -> f32 {
1118        assert_eq!(
1119            indices.len(),
1120            self.shape().rank(),
1121            "Indices length must match tensor rank"
1122        );
1123
1124        // Check bounds
1125        for (i, &idx) in indices.iter().enumerate() {
1126            assert!(
1127                idx < self.shape().dims[i],
1128                "Index {} out of bounds for dimension {}",
1129                idx,
1130                i
1131            );
1132        }
1133
1134        let offset = self.memory_offset(indices);
1135        unsafe { *self.as_ptr().add(offset) }
1136    }
1137
1138    /// Set a single element in the tensor at the specified indices
1139    ///
1140    /// # Arguments
1141    ///
1142    /// * `indices` - Multi-dimensional indices to set the element
1143    /// * `value` - The value to set
1144    ///
1145    /// # Panics
1146    ///
1147    /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1148    ///
1149    /// # Examples
1150    ///
1151    /// ```
1152    /// use train_station::Tensor;
1153    ///
1154    /// let mut tensor = Tensor::new(vec![2, 2]);
1155    /// tensor.set(&[0, 1], 42.0);
1156    /// assert_eq!(tensor.get(&[0, 1]), 42.0);
1157    /// ```
1158    #[track_caller]
1159    pub fn set(&mut self, indices: &[usize], value: f32) {
1160        assert_eq!(
1161            indices.len(),
1162            self.shape().rank(),
1163            "Indices length must match tensor rank"
1164        );
1165
1166        // Check bounds
1167        for (i, &idx) in indices.iter().enumerate() {
1168            assert!(
1169                idx < self.shape().dims[i],
1170                "Index {} out of bounds for dimension {}",
1171                idx,
1172                i
1173            );
1174        }
1175
1176        let offset = self.memory_offset(indices);
1177        unsafe {
1178            *self.as_mut_ptr().add(offset) = value;
1179        }
1180    }
1181
1182    /// Returns a safe slice of the tensor's underlying data
1183    ///
1184    /// Provides safe access to the tensor's data without requiring unsafe pointer operations.
1185    /// This is the preferred way to access tensor data for reading values, comparisons,
1186    /// and other operations that don't require direct pointer manipulation.
1187    ///
1188    /// # Returns
1189    ///
1190    /// A slice containing all tensor elements in row-major order
1191    ///
1192    /// # Performance
1193    ///
1194    /// - **Zero-Cost**: Direct slice creation with no copying
1195    /// - **Cache-Friendly**: Sequential memory access pattern
1196    /// - **Safe**: No unsafe code required for basic data access
1197    ///
1198    /// # Examples
1199    ///
1200    /// ```
1201    /// use train_station::Tensor;
1202    ///
1203    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1204    /// let data = tensor.data();
1205    ///
1206    /// // Safe indexing and comparisons
1207    /// assert_eq!(data[0], 1.0);
1208    /// assert_eq!(data.len(), tensor.size());
1209    /// ```
1210    #[inline]
1211    #[track_caller]
1212    pub fn data(&self) -> &[f32] {
1213        if self.size() == 0 {
1214            return &[];
1215        }
1216        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }
1217    }
1218
1219    /// Returns a mutable slice of the tensor's underlying data
1220    ///
1221    /// Provides safe mutable access to the tensor's data without requiring unsafe
1222    /// pointer operations. Use this for in-place modifications of tensor values.
1223    ///
1224    /// # Returns
1225    ///
1226    /// A mutable slice containing all tensor elements in row-major order
1227    ///
1228    /// # Performance
1229    ///
1230    /// - **Zero-Cost**: Direct slice creation with no copying
1231    /// - **Cache-Friendly**: Sequential memory access pattern
1232    /// - **Safe**: No unsafe code required for basic data modification
1233    ///
1234    /// # Examples
1235    ///
1236    /// ```
1237    /// use train_station::Tensor;
1238    ///
1239    /// let mut tensor = Tensor::new(vec![2, 2]);
1240    /// let data = tensor.data_mut();
1241    ///
1242    /// // Safe indexing for modification
1243    /// data[0] = 1.0;
1244    /// data[1] = 2.0;
1245    ///
1246    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
1247    /// assert_eq!(tensor.get(&[0, 1]), 2.0);
1248    /// ```
1249    #[inline]
1250    #[track_caller]
1251    pub fn data_mut(&mut self) -> &mut [f32] {
1252        if self.size() == 0 {
1253            return &mut [];
1254        }
1255        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
1256    }
1257
1258    /// Extract scalar value from single-element tensor
1259    ///
1260    /// This method provides a convenient way to extract the scalar value from
1261    /// tensors that contain exactly one element. This is commonly used with
1262    /// element iterator results and scalar tensor operations.
1263    ///
1264    /// # Returns
1265    ///
1266    /// The scalar value contained in this tensor
1267    ///
1268    /// # Panics
1269    ///
1270    /// Panics if the tensor does not contain exactly one element
1271    ///
1272    /// # Examples
1273    ///
1274    /// ```
1275    /// use train_station::Tensor;
1276    ///
1277    /// // Single-element tensor
1278    /// let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
1279    /// assert_eq!(scalar.value(), 42.0);
1280    /// ```
1281    #[inline]
1282    #[track_caller]
1283    pub fn value(&self) -> f32 {
1284        assert_eq!(
1285            self.size(),
1286            1,
1287            "value() can only be called on tensors with exactly one element. \
1288             This tensor has {} elements with shape {:?}",
1289            self.size(),
1290            self.shape().dims
1291        );
1292        self.data()[0]
1293    }
1294
1295    /// Create a view with a new shape (requires contiguous memory)
1296    ///
1297    /// Behaves like PyTorch `view`: tensor must be contiguous and the total
1298    /// number of elements must remain the same. Supports -1 inference for one dimension.
1299    ///
1300    /// # Arguments
1301    ///
1302    /// * `new_shape` - New shape for the tensor (can contain -1 for inference)
1303    ///
1304    /// # Returns
1305    ///
1306    /// A tensor viewing the same data with a new shape
1307    ///
1308    /// # Examples
1309    ///
1310    /// ```
1311    /// use train_station::Tensor;
1312    ///
1313    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1314    /// let y = x.view(vec![2, 2]);
1315    /// assert_eq!(y.shape().dims, vec![2, 2]);
1316    /// ```
1317    #[track_caller]
1318    pub fn view(&self, new_shape: Vec<i32>) -> Tensor {
1319        // Use the views module implementation
1320        use crate::tensor::transform::view::TensorViewExt;
1321        TensorViewExt::view(self, new_shape)
1322    }
1323
1324    /// Create an element view for the specified index
1325    ///
1326    /// Returns a scalar tensor (shape \[1\]) that views a single element
1327    /// of the source tensor. Maintains gradient tracking.
1328    ///
1329    /// # Arguments
1330    ///
1331    /// * `index` - Linear index of the element to view
1332    ///
1333    /// # Returns
1334    ///
1335    /// A scalar tensor viewing the specified element
1336    ///
1337    /// # Examples
1338    ///
1339    /// ```
1340    /// use train_station::Tensor;
1341    ///
1342    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1343    /// let element = tensor.element_view(1);
1344    /// assert_eq!(element.value(), 2.0);
1345    /// ```
1346    #[track_caller]
1347    pub fn element_view(&self, index: usize) -> Tensor {
1348        use crate::tensor::transform::view::TensorViewExt;
1349        TensorViewExt::element_view(self, index)
1350    }
1351
1352    /// Create a slice view of the tensor
1353    ///
1354    /// Returns a view of a contiguous or strided slice of the source tensor.
1355    ///
1356    /// # Arguments
1357    ///
1358    /// * `start` - Starting index
1359    /// * `step` - Step size (1 for contiguous)
1360    /// * `length` - Number of elements
1361    ///
1362    /// # Returns
1363    ///
1364    /// A tensor viewing the specified slice
1365    ///
1366    /// # Examples
1367    ///
1368    /// ```
1369    /// use train_station::Tensor;
1370    ///
1371    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1372    /// let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
1373    /// assert_eq!(slice.data(), &[2.0, 4.0]);
1374    /// ```
1375    #[track_caller]
1376    pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor {
1377        use crate::tensor::transform::view::TensorViewExt;
1378        TensorViewExt::slice_view(self, start, step, length)
1379    }
1380
1381    /// Create a tensor view from raw components
1382    ///
1383    /// Creates a tensor that views existing memory with the specified shape and device.
1384    /// The tensor shares memory with the original allocation through the allocation_owner.
1385    ///
1386    /// # Safety
1387    ///
1388    /// The caller must ensure:
1389    /// - `data` is valid for the number of elements specified by `shape`
1390    /// - `data` remains valid for the lifetime of the returned tensor
1391    /// - `allocation_owner` properly manages the memory lifecycle
1392    ///
1393    /// # Arguments
1394    ///
1395    /// * `data` - Raw pointer to the tensor data
1396    /// * `shape` - Shape of the tensor view
1397    /// * `device` - Device where the tensor is located
1398    /// * `allocation_owner` - Optional shared allocation owner
1399    ///
1400    /// # Returns
1401    ///
1402    /// A new tensor that views the specified memory
1403    ///
1404    /// # Implementation Details
1405    ///
1406    /// This method is used internally to create tensor views that share memory
1407    /// with other tensors. It's primarily used for view operations and memory
1408    /// management.
1409    pub(crate) fn from_raw_view(
1410        data: *const f32,
1411        shape: Shape,
1412        device: Device,
1413        allocation_owner: Option<Arc<Allocation>>,
1414    ) -> Self {
1415        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1416
1417        Self {
1418            data: NonNull::new(data as *mut f32).expect("Data pointer cannot be null"),
1419            shape,
1420            device,
1421            id,
1422            requires_grad: false,
1423            grad: None,
1424            grad_fn: GradFn::None,
1425            allocation_owner,
1426            _phantom: PhantomData,
1427        }
1428    }
1429
1430    /// Get the allocation owner for this tensor
1431    ///
1432    /// Returns the shared allocation owner if this tensor is a view,
1433    /// or None if this tensor owns its memory directly.
1434    ///
1435    /// # Returns
1436    ///
1437    /// Optional reference to the allocation owner
1438    ///
1439    /// # Implementation Details
1440    ///
1441    /// This method is used internally to manage memory lifecycle for tensor views.
1442    /// It helps determine whether a tensor shares memory with another tensor.
1443    #[track_caller]
1444    pub fn allocation_owner(&self) -> Option<&Arc<Allocation>> {
1445        self.allocation_owner.as_ref()
1446    }
1447
    /// Create a new tensor with uninitialized memory
    ///
    /// This method allocates memory for a tensor without initializing it to any value.
    /// This is useful for performance-critical operations where the memory will be
    /// immediately overwritten, such as matrix multiplication results.
    ///
    /// # Safety
    ///
    /// The caller must ensure that all memory is written before reading from the tensor.
    /// Reading from uninitialized memory is undefined behavior.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - The dimensions of the tensor
    ///
    /// # Returns
    ///
    /// A tensor with uninitialized memory
    ///
    /// # Performance
    ///
    /// - **Zero Initialization**: Skips memory initialization for maximum performance
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Memory Efficient**: Uses optimized alignment strategies
    ///
    /// # Example
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create uninitialized tensor for matmul result
    /// let mut result = Tensor::new_uninitialized(vec![100, 100]);
    /// // Initialize the memory before use
    /// for value in result.data_mut() {
    ///     *value = 0.0;
    /// }
    /// ```
    #[inline]
    #[track_caller]
    pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self {
        let shape = Shape::new(shape_dims);
        // A unique id is consumed even for zero-sized tensors, so ids stay
        // globally unique across all construction paths.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size == 0 {
            // Handle zero-sized tensors: no allocation is made; a dangling
            // (well-aligned, never dereferenced) pointer stands in for data.
            return Self {
                data: NonNull::dangling(),
                shape,
                device: current_device(),
                id,
                requires_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                _phantom: PhantomData,
            };
        }

        // Optimized layout calculation for better cache performance
        let element_size = std::mem::size_of::<f32>();
        let total_size = shape.size * element_size;

        // Use cache line alignment for large tensors, smaller alignment for small ones.
        // These tiers match the module-level docs (16-byte SSE / 32-byte AVX2 /
        // 64-byte cache line).
        let alignment = if total_size > 4096 {
            64 // Cache line alignment for large tensors
        } else if shape.size >= 8 {
            32 // AVX2 alignment for medium tensors
        } else {
            16 // SSE alignment for small tensors
        };

        // from_size_align only fails for non-power-of-two alignment or size
        // overflow; both are impossible for the values chosen above short of
        // an absurdly large shape.
        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Allocate memory via shared Allocation (uninitialized). The Allocation
        // is the single owner of the raw memory; this tensor holds it via Arc.
        let alloc_obj = Allocation::new_uninitialized(shape.size, alignment, layout);
        let ptr = alloc_obj.ptr;

        Self {
            data: ptr,
            shape,
            device: current_device(),
            id,
            requires_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            // Keeping the Allocation alive ties the data pointer's validity to
            // this tensor's lifetime (and to any views that clone the Arc).
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            _phantom: PhantomData,
        }
    }
1538}