// train_station/tensor/core/utils.rs

1//! Tensor utility functions and core implementation methods
2//!
3//! This module provides essential utility functions for tensor creation, memory management,
4//! gradient tracking, and optimization. It contains the core implementation methods
5//! that enable efficient tensor operations and gradtrack functionality.
6//!
7//! # Key Features
8//!
9//! - **Tensor Creation**: Optimized constructors with memory alignment
10//! - **Memory Management**: Safe memory access and allocation utilities
11//! - **Gradient Tracking**: GradTrack system integration and gradient management
12//! - **Performance Optimization**: SIMD-ready memory layout and alignment
13//! - **Device Management**: CPU and future CUDA device support
14//! - **Memory Layout**: Contiguous, strided, and view memory access patterns
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Alignment**: 16-byte SSE, 32-byte AVX2, 64-byte cache-line alignment
19//! - **SIMD Optimization**: Properly aligned memory for vectorized operations
20//! - **Zero-Cost Abstractions**: Minimal overhead for utility operations
21//! - **Thread Safety**: Atomic operations for gradient tracking and ID generation
22//! - **Memory Efficiency**: Optimized allocation strategies for different tensor sizes
23//!
24//! # Examples
25//!
26//! ## Basic Tensor Creation
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors of different sizes
32//! let small_tensor = Tensor::new(vec![2, 3]);      // 16-byte alignment
33//! let medium_tensor = Tensor::new(vec![32, 32]);   // 32-byte alignment
34//! let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
35//!
36//! // Initialize data before use
37//! let mut tensor = Tensor::new(vec![2, 3]);
38//! tensor.fill(0.0); // Initialize with zeros
39//! ```
40//!
41//! ## Gradient Tracking
42//!
43//! ```
44//! use train_station::Tensor;
45//!
46//! let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
47//! assert!(tensor.requires_grad());
48//! ```
49//!
50//! ## Memory Access
51//!
52//! ```
53//! use train_station::Tensor;
54//!
55//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
56//! let value = tensor.get(&[0, 1]);
57//! assert_eq!(value, 2.0);
58//!
59//! let mut tensor = Tensor::new(vec![2, 2]);
60//! tensor.set(&[0, 1], 42.0);
61//! assert_eq!(tensor.get(&[0, 1]), 42.0);
62//! ```
63
64use std::alloc::Layout;
65use std::marker::PhantomData;
66use std::ptr::NonNull;
67use std::sync::atomic::Ordering;
68use std::sync::Arc;
69
70use crate::device::current_device;
71use crate::gradtrack::{self, GradEngine, GradFn};
72use crate::tensor::core::{Allocation, Device, TENSOR_ID_COUNTER};
73use crate::tensor::Shape;
74
75use super::Tensor;
76
77impl Tensor {
78    /// Creates a new tensor with the specified shape and optimized memory layout
79    ///
80    /// Allocates memory with size-dependent alignment for optimal performance:
81    /// - Small tensors (≤8 elements): 16-byte SSE alignment
82    /// - Medium tensors (8-1024 elements): 32-byte AVX2 alignment
83    /// - Large tensors (>1024 elements): 64-byte cache-line alignment
84    ///
85    /// # Arguments
86    ///
87    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
88    ///
89    /// # Returns
90    ///
91    /// A new tensor with uninitialized data. The data must be initialized
92    /// before use to avoid undefined behavior.
93    ///
94    /// # Performance
95    ///
96    /// - **Memory Allocation**: Single allocation with optimized alignment
97    /// - **SIMD Ready**: Properly aligned for vectorized operations
98    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
99    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
100    ///
101    /// # Safety
102    ///
103    /// The returned tensor contains uninitialized memory. You must initialize
104    /// the data before performing any operations that read from it.
105    ///
106    /// # Examples
107    ///
108    /// ```
109    /// use train_station::Tensor;
110    ///
111    /// // Create tensors of different sizes
112    /// let small_tensor = Tensor::new(vec![2, 3]);      // 16-byte alignment
113    /// let medium_tensor = Tensor::new(vec![32, 32]);   // 32-byte alignment
114    /// let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
115    ///
116    /// // Initialize data before use
117    /// let mut tensor = Tensor::new(vec![2, 3]);
118    /// tensor.fill(0.0); // Initialize with zeros
119    /// ```
120    #[inline]
121    pub fn new(shape_dims: Vec<usize>) -> Self {
122        let shape = Shape::new(shape_dims);
123        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
124
125        if shape.size == 0 {
126            // Handle zero-sized tensors
127            return Self {
128                data: NonNull::dangling(),
129                shape,
130                device: current_device(),
131                id,
132                requires_grad: false,
133                grad: None,
134                grad_fn: GradFn::None,
135                allocation_owner: None,
136                _phantom: PhantomData,
137            };
138        }
139
140        // Optimized layout calculation for better cache performance
141        let element_size = std::mem::size_of::<f32>();
142        let total_size = shape.size * element_size;
143
144        // Use cache line alignment for large tensors, smaller alignment for small ones
145        let alignment = if total_size > 4096 {
146            64 // Cache line alignment for large tensors
147        } else if shape.size >= 8 {
148            32 // AVX2 alignment for medium tensors
149        } else {
150            16 // SSE alignment for small tensors
151        };
152
153        let layout = Layout::from_size_align(total_size, alignment)
154            .expect("Failed to create layout for tensor data");
155
156        // Allocate memory via shared Allocation
157        let alloc_obj = Allocation::new(shape.size, alignment, layout);
158        let ptr = alloc_obj.ptr;
159
160        Self {
161            data: ptr,
162            shape,
163            device: current_device(),
164            id,
165            requires_grad: false,
166            grad: None,
167            grad_fn: GradFn::None,
168            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
169            _phantom: PhantomData,
170        }
171    }
172
173    /// Returns the shape and dimensional information of the tensor
174    ///
175    /// Provides access to the tensor's dimensions, size, strides, and memory
176    /// layout information. This is used for shape validation, memory access
177    /// calculations, and optimization decisions.
178    ///
179    /// # Returns
180    ///
181    /// Reference to the tensor's shape information containing dimensions,
182    /// size, strides, and memory layout type.
183    ///
184    /// # Performance
185    ///
186    /// - **Time Complexity**: O(1) - direct field access
187    /// - **Memory**: No allocation - returns reference to existing data
188    ///
189    /// # Examples
190    ///
191    /// ```
192    /// use train_station::Tensor;
193    ///
194    /// let tensor = Tensor::new(vec![2, 3, 4]);
195    /// let shape = tensor.shape();
196    /// assert_eq!(shape.dims, vec![2, 3, 4]);
197    /// assert_eq!(shape.size, 24);
198    /// assert_eq!(shape.rank(), 3);
199    /// ```
200    #[inline]
201    pub fn shape(&self) -> &Shape {
202        &self.shape
203    }
204
205    /// Returns the total number of elements in the tensor
206    ///
207    /// Provides the total count of elements across all dimensions. This is
208    /// used for memory allocation, iteration bounds, and performance optimization.
209    ///
210    /// # Returns
211    ///
212    /// Total number of elements as `usize`
213    ///
214    /// # Performance
215    ///
216    /// - **Time Complexity**: O(1) - direct field access
217    /// - **Memory**: No allocation - returns stored value
218    ///
219    /// # Examples
220    ///
221    /// ```
222    /// use train_station::Tensor;
223    ///
224    /// let tensor = Tensor::new(vec![2, 3, 4]);
225    /// assert_eq!(tensor.size(), 24); // 2 * 3 * 4
226    ///
227    /// let scalar = Tensor::new(vec![1]);
228    /// assert_eq!(scalar.size(), 1);
229    ///
230    /// let empty = Tensor::new(vec![0]);
231    /// assert_eq!(empty.size(), 0);
232    /// ```
233    #[inline]
234    pub fn size(&self) -> usize {
235        self.shape.size
236    }
237
238    /// Returns the device where this tensor is located
239    ///
240    /// Provides the physical location of the tensor data (CPU/GPU). This
241    /// determines which operations can be performed on the tensor and where
242    /// computations will be executed.
243    ///
244    /// # Returns
245    ///
246    /// Device enum indicating the tensor's physical location
247    ///
248    /// # Performance
249    ///
250    /// - **Time Complexity**: O(1) - direct field access
251    /// - **Memory**: No allocation - returns stored value
252    ///
253    /// # Examples
254    ///
255    /// ```
256    /// use train_station::Tensor;
257    ///
258    /// let tensor = Tensor::new(vec![2, 3]);
259    /// assert!(tensor.device().is_cpu());
260    /// assert!(!tensor.device().is_cuda());
261    /// ```
262    #[inline]
263    pub fn device(&self) -> Device {
264        self.device
265    }
266
267    /// Creates a new tensor with the specified shape on a specific device
268    ///
269    /// Allocates memory on the specified device with the same optimized alignment
270    /// strategy as `new()`. Currently supports CPU device with future CUDA support.
271    ///
272    /// # Arguments
273    ///
274    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
275    /// * `device` - The device where the tensor should be allocated
276    ///
277    /// # Returns
278    ///
279    /// A new tensor with uninitialized data on the specified device
280    ///
281    /// # Performance
282    ///
283    /// - **Memory Allocation**: Device-specific allocation with optimized alignment
284    /// - **SIMD Ready**: Properly aligned for vectorized operations on target device
285    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
286    ///
287    /// # Panics
288    ///
289    /// Panics if the specified device is not supported (e.g., CUDA without feature flag)
290    ///
291    /// # Examples
292    ///
293    /// ```
294    /// use train_station::Tensor;
295    ///
296    /// let tensor = Tensor::new_on_device(vec![2, 3], train_station::Device::cpu());
297    /// assert!(tensor.device().is_cpu());
298    /// assert_eq!(tensor.size(), 6);
299    /// ```
300    ///
301    /// # Arguments
302    ///
303    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
304    /// * `device` - Device where the tensor should be allocated
305    ///
306    /// # Returns
307    ///
308    /// A new tensor with uninitialized data on the specified device
309    ///
310    /// # Panics
311    ///
312    /// Panics if the device is not supported (currently only CPU is supported)
313    ///
314    /// # Performance
315    ///
316    /// - **Memory Allocation**: Single allocation with optimized alignment
317    /// - **SIMD Ready**: Properly aligned for vectorized operations
318    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
319    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
320    ///
321    /// # Examples
322    ///
323    /// ```
324    /// use train_station::{Tensor, Device};
325    ///
326    /// // Create tensor on CPU device
327    /// let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
328    /// assert_eq!(tensor.device(), Device::cpu());
329    /// assert_eq!(tensor.size(), 6);
330    /// ```
331    pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self {
332        // For now, only CPU is supported
333        if !device.is_cpu() {
334            panic!("Only CPU device is currently supported. CUDA support is planned for future releases.");
335        }
336
337        let shape = Shape::new(shape_dims);
338        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
339
340        if shape.size == 0 {
341            // Handle zero-sized tensors
342            return Self {
343                data: NonNull::dangling(),
344                shape,
345                device,
346                id,
347                requires_grad: false,
348                grad: None,
349                grad_fn: GradFn::None,
350                allocation_owner: None,
351                _phantom: PhantomData,
352            };
353        }
354
355        // Optimized layout calculation for better cache performance
356        let element_size = std::mem::size_of::<f32>();
357        let total_size = shape.size * element_size;
358
359        // Use cache line alignment for large tensors, smaller alignment for small ones
360        let alignment = if total_size > 4096 {
361            64 // Cache line alignment for large tensors
362        } else if shape.size >= 8 {
363            32 // AVX2 alignment for medium tensors
364        } else {
365            16 // SSE alignment for small tensors
366        };
367
368        let layout = Layout::from_size_align(total_size, alignment)
369            .expect("Failed to create layout for tensor data");
370
371        // Allocate memory via shared Allocation
372        let alloc_obj = Allocation::new(shape.size, alignment, layout);
373        let ptr = alloc_obj.ptr;
374
375        Self {
376            data: ptr,
377            shape,
378            device,
379            id,
380            requires_grad: false,
381            grad: None,
382            grad_fn: GradFn::None,
383            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
384            _phantom: PhantomData,
385        }
386    }
387
388    /// Enable gradient computation for this tensor
389    ///
390    /// Builder method that enables automatic gradient tracking for this tensor.
391    /// When enabled, all operations involving this tensor will be recorded in
392    /// the computation graph for gradient computation during backward pass.
393    ///
394    /// # Returns
395    ///
396    /// `self` with gradient tracking enabled
397    ///
398    /// # Performance
399    ///
400    /// - **Time Complexity**: O(1) - simple field assignment
401    /// - **Memory**: No additional allocation
402    /// - **Overhead**: Minimal gradtrack tracking overhead when gradients computed
403    ///
404    /// # Examples
405    ///
406    /// ```
407    /// use train_station::Tensor;
408    ///
409    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
410    /// assert!(tensor.requires_grad());
411    /// ```
412    pub fn with_requires_grad(mut self) -> Self {
413        self.requires_grad = true;
414        self
415    }
416
417    /// Set gradient tracking for this tensor
418    ///
419    /// Controls whether the gradtrack system tracks operations on this tensor
420    /// and computes gradients during backward pass. When disabled, clears
421    /// any existing gradients and gradient functions.
422    ///
423    /// # Arguments
424    ///
425    /// * `requires_grad` - Whether to track gradients for this tensor
426    ///
427    /// # Performance
428    ///
429    /// - **Time Complexity**: O(1) - simple field assignment
430    /// - **Memory**: May free gradient storage when disabled
431    /// - **Overhead**: Zero overhead when gradients disabled
432    ///
433    /// # Examples
434    ///
435    /// ```
436    /// use train_station::Tensor;
437    ///
438    /// let mut tensor = Tensor::ones(vec![2, 3]);
439    /// tensor.set_requires_grad(true);
440    /// assert!(tensor.requires_grad());
441    ///
442    /// // Disable gradient tracking
443    /// tensor.set_requires_grad(false);
444    /// assert!(!tensor.requires_grad());
445    /// ```
446    pub fn set_requires_grad(&mut self, requires_grad: bool) {
447        self.requires_grad = requires_grad;
448        if !requires_grad {
449            self.grad = None;
450            self.grad_fn = GradFn::None;
451        }
452    }
453
454    /// Check if this tensor requires gradients
455    ///
456    /// # Returns
457    ///
458    /// `true` if gradient tracking is enabled for this tensor
459    ///
460    /// # Examples
461    ///
462    /// ```
463    /// use train_station::Tensor;
464    ///
465    /// let tensor = Tensor::new(vec![2, 3]);
466    /// assert!(!tensor.requires_grad());
467    ///
468    /// let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
469    /// assert!(grad_tensor.requires_grad());
470    /// ```
471    pub fn requires_grad(&self) -> bool {
472        self.requires_grad
473    }
474
475    /// Get the accumulated gradients (if any)
476    ///
477    /// Returns a reference to the gradient tensor if gradients have been computed
478    /// and this tensor has gradient tracking enabled.
479    ///
480    /// # Returns
481    ///
482    /// Optional reference to the gradient tensor, or `None` if no gradients exist
483    ///
484    /// # Examples
485    ///
486    /// ```
487    /// use train_station::Tensor;
488    ///
489    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
490    /// assert!(tensor.grad().is_none()); // No gradients computed yet
491    /// ```
492    pub fn grad(&self) -> Option<&Tensor> {
493        // First check if we have a gradient stored directly
494        if let Some(grad) = self.grad.as_ref() {
495            return Some(grad.as_ref());
496        }
497
498        // If not, check the gradient map for accumulated gradients
499        if let Some(_grad) = gradtrack::get_accumulated_gradient(self.id) {
500            // For simplicity, we'll return None here since we can't return a reference
501            // to a temporary value. In a full implementation, we'd store it in self.grad
502            return None;
503        }
504
505        None
506    }
507
508    /// Get accumulated gradient by value (helper for testing)
509    ///
510    /// Returns the gradient tensor by value, which is useful for testing and
511    /// when you need to own the gradient data.
512    ///
513    /// # Returns
514    ///
515    /// Optional gradient tensor, or `None` if no gradients exist
516    ///
517    /// # Examples
518    ///
519    /// ```
520    /// use train_station::Tensor;
521    ///
522    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
523    /// assert!(tensor.grad_by_value().is_none()); // No gradients computed yet
524    /// ```
525    pub fn grad_by_value(&self) -> Option<Tensor> {
526        // First check if we have a gradient stored directly
527        if let Some(grad) = self.grad.as_ref() {
528            return Some((**grad).clone());
529        }
530
531        // If not, check the gradient map for accumulated gradients
532        use crate::gradtrack;
533        gradtrack::get_accumulated_gradient(self.id)
534    }
535
536    /// Get the unique ID of this tensor
537    ///
538    /// Returns the unique identifier assigned to this tensor during creation.
539    /// This ID is used for gradtrack tracking and tensor identification.
540    ///
541    /// # Returns
542    ///
543    /// Unique tensor ID as `usize`
544    ///
545    /// # Examples
546    ///
547    /// ```
548    /// use train_station::Tensor;
549    ///
550    /// let tensor1 = Tensor::new(vec![2, 3]);
551    /// let tensor2 = Tensor::new(vec![2, 3]);
552    /// assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
553    /// ```
554    pub fn id(&self) -> usize {
555        self.id
556    }
557
558    /// Detach this tensor from the computation graph
559    ///
560    /// Returns a new tensor with the same data but no gradient tracking.
561    /// This is useful when you want to use a tensor in inference without
562    /// affecting the computation graph.
563    ///
564    /// # Returns
565    ///
566    /// A new tensor with the same data but gradient tracking disabled
567    ///
568    /// # Examples
569    ///
570    /// ```
571    /// use train_station::Tensor;
572    ///
573    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
574    /// let detached = tensor.detach();
575    /// assert!(!detached.requires_grad());
576    /// assert_eq!(tensor.size(), detached.size());
577    /// ```
578    pub fn detach(&self) -> Self {
579        let mut detached = Self::new(self.shape.dims.clone());
580
581        // Copy data
582        unsafe {
583            let src = self.as_ptr();
584            let dst = detached.as_mut_ptr();
585            std::ptr::copy_nonoverlapping(src, dst, self.size());
586        }
587
588        detached
589    }
590
591    /// Create a new tensor that doesn't track gradients from this one
592    ///
593    /// Similar to detach() but modifies this tensor in place. This is useful
594    /// when you want to disable gradient tracking for the current tensor
595    /// without creating a copy.
596    ///
597    /// # Examples
598    ///
599    /// ```
600    /// use train_station::Tensor;
601    ///
602    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
603    /// assert!(tensor.requires_grad());
604    /// tensor.detach_();
605    /// assert!(!tensor.requires_grad());
606    /// ```
607    pub fn detach_(&mut self) {
608        self.requires_grad = false;
609        self.grad = None;
610        self.grad_fn = GradFn::None;
611    }
612
613    /// Entry point for backward pass on this tensor
614    ///
615    /// Computes gradients for all tensors in the computation graph that have
616    /// `requires_grad` set to true. This is the main entry point for automatic
617    /// differentiation.
618    ///
619    /// # Arguments
620    ///
621    /// * `grad_output` - Optional gradient tensor for the output. If None, assumes
622    ///   the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
623    ///
624    /// # Examples
625    ///
626    /// ```
627    /// use train_station::Tensor;
628    ///
629    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
630    /// let mut result = tensor.add_scalar(5.0);
631    /// result.backward(None);
632    /// // Note: Gradient computation depends on the gradtrack system implementation
633    /// ```
634    pub fn backward(&mut self, grad_output: Option<Tensor>) {
635        GradEngine::backward(self, grad_output);
636    }
637
    /// Returns a raw pointer to the tensor data for unsafe operations
    ///
    /// # Returns
    ///
    /// Raw `*const f32` to the first element. For zero-sized tensors this is
    /// a dangling (well-aligned) pointer that must never be dereferenced.
    ///
    /// # Safety
    ///
    /// This is unsafe because it provides direct access to the underlying memory.
    /// The caller must ensure:
    /// - The tensor is not dropped while the pointer is used
    /// - No concurrent mutable access occurs
    /// - Bounds are respected
    #[inline]
    pub unsafe fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }
651
    /// Returns a mutable raw pointer to the tensor data for unsafe operations
    ///
    /// # Returns
    ///
    /// Raw `*mut f32` to the first element. For zero-sized tensors this is
    /// a dangling (well-aligned) pointer that must never be dereferenced.
    ///
    /// # Safety
    ///
    /// This is unsafe because it provides direct mutable access to the underlying memory.
    /// The caller must ensure:
    /// - The tensor is not dropped while the pointer is used
    /// - No concurrent access occurs
    /// - Bounds are respected
    #[inline]
    pub unsafe fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_ptr()
    }
665
666    /// Internal method to set gradient function (used by operations)
667    ///
668    /// Sets the gradient function for this tensor. This is used internally
669    /// by tensor operations to record the computation graph for gradtrack.
670    ///
671    /// # Arguments
672    ///
673    /// * `grad_fn` - The gradient function to set
674    ///
675    /// # Implementation Details
676    ///
677    /// This method is called by tensor operations to register the gradient
678    /// computation function. It only sets the gradient function if gradient
679    /// tracking is enabled for this tensor.
680    pub(crate) fn set_grad_fn(&mut self, grad_fn: GradFn) {
681        if self.requires_grad {
682            self.grad_fn = grad_fn;
683        }
684    }
685
686    /// Get a reference to the gradient function (for gradtrack)
687    ///
688    /// Returns a reference to the gradient function associated with this tensor.
689    /// This is used internally by the gradtrack system to compute gradients.
690    ///
691    /// # Returns
692    ///
693    /// Reference to the gradient function
694    ///
695    /// # Implementation Details
696    ///
697    /// This method is used by the gradtrack engine to access the gradient
698    /// computation function during backward pass.
699    pub fn grad_fn(&self) -> &GradFn {
700        &self.grad_fn
701    }
702
703    /// Internal method to set requires_grad (used by gradtrack operations)
704    ///
705    /// Sets the gradient tracking flag for this tensor. This is used internally
706    /// by gradtrack operations to control gradient computation.
707    ///
708    /// # Arguments
709    ///
710    /// * `requires_grad` - Whether to enable gradient tracking
711    ///
712    /// # Implementation Details
713    ///
714    /// This method is used internally by the gradtrack system to control
715    /// gradient tracking without triggering additional side effects.
716    pub(crate) fn set_requires_grad_internal(&mut self, requires_grad: bool) {
717        self.requires_grad = requires_grad;
718    }
719
720    /// Internal method to accumulate gradients with optimized operations
721    ///
722    /// Accumulates a gradient tensor with any existing gradients for this tensor.
723    /// This is used internally by the gradtrack system to handle gradient accumulation
724    /// during backward pass.
725    ///
726    /// # Arguments
727    ///
728    /// * `grad` - The gradient tensor to accumulate
729    ///
730    /// # Implementation Details
731    ///
732    /// This method is used internally by the gradtrack engine to accumulate
733    /// gradients from multiple backward passes or operations. It only accumulates
734    /// gradients if gradient tracking is enabled for this tensor.
735    pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
736        if !self.requires_grad {
737            return;
738        }
739
740        match &self.grad {
741            Some(existing_grad) => {
742                // Use optimized tensor addition but create new tensor for safety
743                let accumulated = existing_grad.add_tensor_optimized(&grad);
744                self.grad = Some(Arc::new(accumulated));
745            }
746            None => {
747                self.grad = Some(Arc::new(grad));
748            }
749        }
750    }
751
752    /// Set gradient from external source
753    ///
754    /// Sets the gradient tensor for this tensor. This is used internally by the
755    /// gradtrack system to set gradients during backward pass.
756    ///
757    /// # Arguments
758    ///
759    /// * `grad` - The gradient tensor to set
760    ///
761    /// # Implementation Details
762    ///
763    /// This method is used internally by the gradtrack engine to set gradients
764    /// during backward pass. It only sets the gradient if gradient tracking is
765    /// enabled for this tensor.
766    pub fn set_grad(&mut self, grad: Tensor) {
767        if self.requires_grad {
768            self.grad = Some(Arc::new(grad));
769        }
770    }
771
772    /// Clear accumulated gradients for this tensor
773    ///
774    /// This method is used by optimizers to zero gradients before each backward pass.
775    /// It clears any accumulated gradients, allowing for fresh gradient computation.
776    ///
777    /// # Examples
778    ///
779    /// ```
780    /// use train_station::Tensor;
781    ///
782    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
783    /// tensor.set_grad(Tensor::ones(vec![2, 3]));
784    /// assert!(tensor.grad().is_some());
785    /// tensor.zero_grad();
786    /// assert!(tensor.grad().is_none());
787    /// ```
788    pub fn zero_grad(&mut self) {
789        self.grad = None;
790    }
791
    /// Negate all elements in the tensor in-place
    ///
    /// This is used internally for gradient computation in subtraction operations.
    /// For tensor - tensor operations, the second operand gets a negated gradient.
    ///
    /// # Implementation Details
    ///
    /// Dispatches to an AVX2 SIMD path on x86_64 when the CPU supports it
    /// (runtime-detected), otherwise negates element by element. Empty
    /// tensors return immediately without touching the (dangling) pointer.
    #[inline]
    pub(crate) fn negate_inplace(&mut self) {
        // Zero-element tensors hold a dangling pointer; never dereference it.
        if self.shape.size == 0 {
            return;
        }

        unsafe {
            let ptr = self.data.as_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                // (runtime feature detection, so non-AVX2 CPUs fall through).
                if is_x86_feature_detected!("avx2") {
                    self.negate_simd_avx2(ptr);
                    return;
                }
            }

            // Fallback to scalar operations (also the non-x86_64 path)
            for i in 0..self.shape.size {
                *ptr.add(i) = -*ptr.add(i);
            }
        }
    }
826
827    /// SIMD-optimized negation using AVX2 instructions
828    ///
829    /// Performs in-place negation of tensor elements using AVX2 SIMD instructions
830    /// for improved performance on x86_64 architectures.
831    ///
832    /// # Arguments
833    ///
834    /// * `ptr` - Raw pointer to the tensor data
835    ///
836    /// # Safety
837    ///
838    /// The caller must ensure:
839    /// - `ptr` is a valid pointer to tensor data
840    /// - The tensor size matches the actual data size
841    /// - The tensor is not moved or dropped during this operation
842    ///
843    /// # Implementation Details
844    ///
845    /// This method is used internally by `negate_inplace` when AVX2 is available.
846    /// It processes 8 elements per iteration using AVX2 instructions.
847    #[cfg(target_arch = "x86_64")]
848    #[inline]
849    unsafe fn negate_simd_avx2(&self, ptr: *mut f32) {
850        use std::arch::x86_64::_mm256_setzero_ps;
851
852        let size = self.shape.size;
853        let zero_vec = _mm256_setzero_ps();
854        let simd_count = size / 8; // Process 8 elements per iteration
855        let mut offset = 0;
856
857        // SIMD loop for negation
858        for _ in 0..simd_count {
859            use std::arch::x86_64::{_mm256_load_ps, _mm256_store_ps, _mm256_sub_ps};
860
861            let vec = _mm256_load_ps(ptr.add(offset));
862            let neg_vec = _mm256_sub_ps(zero_vec, vec);
863            _mm256_store_ps(ptr.add(offset), neg_vec);
864            offset += 8;
865        }
866
867        // Handle remaining elements
868        for i in offset..size {
869            *ptr.add(i) = -*ptr.add(i);
870        }
871    }
872
873    // ===== Memory Layout and Optimization API =====
874
875    /// Checks if the tensor data is stored contiguously in memory
876    ///
877    /// # Returns
878    ///
879    /// `true` if the tensor data is contiguous, enabling optimized SIMD operations
880    ///
881    /// # Examples
882    ///
883    /// ```
884    /// use train_station::Tensor;
885    ///
886    /// let tensor = Tensor::new(vec![2, 3, 4]);
887    /// assert!(tensor.is_contiguous());
888    /// ```
889    #[inline]
890    pub fn is_contiguous(&self) -> bool {
891        self.shape.is_contiguous()
892    }
893
894    /// Checks if this tensor is a view of another tensor
895    ///
896    /// # Returns
897    ///
898    /// `true` if this tensor is a view (non-contiguous reference)
899    #[inline]
900    pub fn is_view(&self) -> bool {
901        self.shape.is_view()
902    }
903
904    /// Gets the memory strides for all dimensions
905    ///
906    /// # Returns
907    ///
908    /// Reference to the stride vector for efficient memory access calculations
909    ///
910    /// # Examples
911    ///
912    /// ```
913    /// use train_station::Tensor;
914    ///
915    /// let tensor = Tensor::new(vec![2, 3, 4]);
916    /// assert_eq!(tensor.strides(), &[12, 4, 1]);
917    /// ```
918    #[inline]
919    pub fn strides(&self) -> &[usize] {
920        self.shape.strides()
921    }
922
923    /// Gets the memory stride for a specific dimension
924    ///
925    /// # Arguments
926    ///
927    /// * `dim` - The dimension index
928    ///
929    /// # Returns
930    ///
931    /// The memory stride for the given dimension
932    ///
933    /// # Panics
934    ///
935    /// Panics if `dim` is out of bounds
936    #[inline]
937    pub fn stride(&self, dim: usize) -> usize {
938        self.shape.stride(dim)
939    }
940
941    /// Gets the memory layout type for optimization decisions
942    ///
943    /// # Returns
944    ///
945    /// Reference to the memory layout information
946    #[inline]
947    pub fn layout(&self) -> &crate::tensor::MemoryLayout {
948        self.shape.layout()
949    }
950
951    /// Calculates the linear memory offset for given multi-dimensional indices
952    ///
953    /// # Arguments
954    ///
955    /// * `indices` - Vector of indices for each dimension
956    ///
957    /// # Returns
958    ///
959    /// Linear memory offset for direct memory access
960    ///
961    /// # Examples
962    ///
963    /// ```
964    /// use train_station::Tensor;
965    ///
966    /// let tensor = Tensor::new(vec![2, 3, 4]);
967    /// let offset = tensor.memory_offset(&[1, 2, 3]);
968    /// // offset = 1*12 + 2*4 + 3*1 = 23
969    /// ```
970    #[inline]
971    pub fn memory_offset(&self, indices: &[usize]) -> usize {
972        self.shape.offset(indices)
973    }
974
975    /// Broadcast this tensor with another tensor for element-wise operations
976    ///
977    /// Returns a tuple containing:
978    /// - Broadcasted view of self
979    /// - Broadcasted view of other
980    /// - Result shape for the operation
981    ///
982    /// # Arguments
983    ///
984    /// * `other` - The tensor to broadcast with
985    ///
986    /// # Returns
987    ///
988    /// A tuple `(broadcasted_self, broadcasted_other, result_shape)`
989    ///
990    /// # Examples
991    ///
992    /// ```
993    /// use train_station::Tensor;
994    ///
995    /// let a = Tensor::ones(vec![2, 1, 4]);
996    /// let b = Tensor::ones(vec![3, 1]);
997    /// let result = a.broadcast_with(&b);
998    /// assert!(result.is_ok());
999    /// ```
1000    pub fn broadcast_with(
1001        &self,
1002        other: &Tensor,
1003    ) -> Result<
1004        (Tensor, Tensor, crate::tensor::Shape),
1005        crate::tensor::ops::broadcasting::BroadcastError,
1006    > {
1007        crate::tensor::ops::broadcasting::broadcast_shapes(self, other)
1008    }
1009
1010    /// Checks if the tensor data is properly aligned for SIMD operations
1011    ///
1012    /// # Returns
1013    ///
1014    /// `true` if the tensor data is aligned to 32-byte boundaries for AVX2
1015    #[inline]
1016    pub fn is_simd_aligned(&self) -> bool {
1017        (self.data.as_ptr() as usize) % 32 == 0
1018    }
1019
1020    /// Gets the memory alignment of the tensor data
1021    ///
1022    /// # Returns
1023    ///
1024    /// The memory alignment in bytes (typically 32 for SIMD optimization)
1025    #[inline]
1026    pub fn memory_alignment(&self) -> usize {
1027        // Our tensors are allocated with 32-byte alignment for AVX2
1028        32
1029    }
1030
1031    /// Checks if this tensor is broadcastable with another tensor
1032    ///
1033    /// # Arguments
1034    ///
1035    /// * `other` - The other tensor to check broadcasting compatibility
1036    ///
1037    /// # Returns
1038    ///
1039    /// `true` if the tensors are broadcastable according to NumPy broadcasting rules
1040    ///
1041    /// # Examples
1042    ///
1043    /// ```
1044    /// use train_station::Tensor;
1045    ///
1046    /// let a = Tensor::new(vec![2, 3, 4]);
1047    /// let b = Tensor::new(vec![1, 3, 4]);
1048    /// assert!(a.is_broadcastable_with(&b));
1049    /// ```
1050    #[inline]
1051    pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
1052        self.shape.is_broadcastable_with(&other.shape)
1053    }
1054
1055    /// Gets the total number of bytes allocated for this tensor
1056    ///
1057    /// # Returns
1058    ///
1059    /// Total memory footprint in bytes
1060    #[inline]
1061    pub fn memory_footprint(&self) -> usize {
1062        self.shape.size * std::mem::size_of::<f32>()
1063    }
1064
1065    /// Get a single element from the tensor at the specified indices
1066    ///
1067    /// # Arguments
1068    ///
1069    /// * `indices` - Multi-dimensional indices to access the element
1070    ///
1071    /// # Returns
1072    ///
1073    /// The value at the specified position
1074    ///
1075    /// # Panics
1076    ///
1077    /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1078    ///
1079    /// # Examples
1080    ///
1081    /// ```
1082    /// use train_station::Tensor;
1083    ///
1084    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1085    /// let value = tensor.get(&[0, 1]);
1086    /// assert_eq!(value, 2.0);
1087    /// ```
1088    pub fn get(&self, indices: &[usize]) -> f32 {
1089        assert_eq!(
1090            indices.len(),
1091            self.shape().rank(),
1092            "Indices length must match tensor rank"
1093        );
1094
1095        // Check bounds
1096        for (i, &idx) in indices.iter().enumerate() {
1097            assert!(
1098                idx < self.shape().dims[i],
1099                "Index {} out of bounds for dimension {}",
1100                idx,
1101                i
1102            );
1103        }
1104
1105        let offset = self.memory_offset(indices);
1106        unsafe { *self.as_ptr().add(offset) }
1107    }
1108
1109    /// Set a single element in the tensor at the specified indices
1110    ///
1111    /// # Arguments
1112    ///
1113    /// * `indices` - Multi-dimensional indices to set the element
1114    /// * `value` - The value to set
1115    ///
1116    /// # Panics
1117    ///
1118    /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1119    ///
1120    /// # Examples
1121    ///
1122    /// ```
1123    /// use train_station::Tensor;
1124    ///
1125    /// let mut tensor = Tensor::new(vec![2, 2]);
1126    /// tensor.set(&[0, 1], 42.0);
1127    /// assert_eq!(tensor.get(&[0, 1]), 42.0);
1128    /// ```
1129    pub fn set(&mut self, indices: &[usize], value: f32) {
1130        assert_eq!(
1131            indices.len(),
1132            self.shape().rank(),
1133            "Indices length must match tensor rank"
1134        );
1135
1136        // Check bounds
1137        for (i, &idx) in indices.iter().enumerate() {
1138            assert!(
1139                idx < self.shape().dims[i],
1140                "Index {} out of bounds for dimension {}",
1141                idx,
1142                i
1143            );
1144        }
1145
1146        let offset = self.memory_offset(indices);
1147        unsafe {
1148            *self.as_mut_ptr().add(offset) = value;
1149        }
1150    }
1151
1152    /// Returns a safe slice of the tensor's underlying data
1153    ///
1154    /// Provides safe access to the tensor's data without requiring unsafe pointer operations.
1155    /// This is the preferred way to access tensor data for reading values, comparisons,
1156    /// and other operations that don't require direct pointer manipulation.
1157    ///
1158    /// # Returns
1159    ///
1160    /// A slice containing all tensor elements in row-major order
1161    ///
1162    /// # Performance
1163    ///
1164    /// - **Zero-Cost**: Direct slice creation with no copying
1165    /// - **Cache-Friendly**: Sequential memory access pattern
1166    /// - **Safe**: No unsafe code required for basic data access
1167    ///
1168    /// # Examples
1169    ///
1170    /// ```
1171    /// use train_station::Tensor;
1172    ///
1173    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1174    /// let data = tensor.data();
1175    ///
1176    /// // Safe indexing and comparisons
1177    /// assert_eq!(data[0], 1.0);
1178    /// assert_eq!(data.len(), tensor.size());
1179    /// ```
1180    #[inline]
1181    pub fn data(&self) -> &[f32] {
1182        if self.size() == 0 {
1183            return &[];
1184        }
1185        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }
1186    }
1187
1188    /// Returns a mutable slice of the tensor's underlying data
1189    ///
1190    /// Provides safe mutable access to the tensor's data without requiring unsafe
1191    /// pointer operations. Use this for in-place modifications of tensor values.
1192    ///
1193    /// # Returns
1194    ///
1195    /// A mutable slice containing all tensor elements in row-major order
1196    ///
1197    /// # Performance
1198    ///
1199    /// - **Zero-Cost**: Direct slice creation with no copying
1200    /// - **Cache-Friendly**: Sequential memory access pattern
1201    /// - **Safe**: No unsafe code required for basic data modification
1202    ///
1203    /// # Examples
1204    ///
1205    /// ```
1206    /// use train_station::Tensor;
1207    ///
1208    /// let mut tensor = Tensor::new(vec![2, 2]);
1209    /// let data = tensor.data_mut();
1210    ///
1211    /// // Safe indexing for modification
1212    /// data[0] = 1.0;
1213    /// data[1] = 2.0;
1214    ///
1215    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
1216    /// assert_eq!(tensor.get(&[0, 1]), 2.0);
1217    /// ```
1218    #[inline]
1219    pub fn data_mut(&mut self) -> &mut [f32] {
1220        if self.size() == 0 {
1221            return &mut [];
1222        }
1223        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
1224    }
1225
1226    /// Extract scalar value from single-element tensor
1227    ///
1228    /// This method provides a convenient way to extract the scalar value from
1229    /// tensors that contain exactly one element. This is commonly used with
1230    /// element iterator results and scalar tensor operations.
1231    ///
1232    /// # Returns
1233    ///
1234    /// The scalar value contained in this tensor
1235    ///
1236    /// # Panics
1237    ///
1238    /// Panics if the tensor does not contain exactly one element
1239    ///
1240    /// # Examples
1241    ///
1242    /// ```
1243    /// use train_station::Tensor;
1244    ///
1245    /// // Single-element tensor
1246    /// let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
1247    /// assert_eq!(scalar.value(), 42.0);
1248    /// ```
1249    #[inline]
1250    #[track_caller]
1251    pub fn value(&self) -> f32 {
1252        assert_eq!(
1253            self.size(),
1254            1,
1255            "value() can only be called on tensors with exactly one element. \
1256             This tensor has {} elements with shape {:?}",
1257            self.size(),
1258            self.shape().dims
1259        );
1260        self.data()[0]
1261    }
1262
1263    /// Create a view with a new shape (requires contiguous memory)
1264    ///
1265    /// Behaves like PyTorch `view`: tensor must be contiguous and the total
1266    /// number of elements must remain the same. Supports -1 inference for one dimension.
1267    ///
1268    /// # Arguments
1269    ///
1270    /// * `new_shape` - New shape for the tensor (can contain -1 for inference)
1271    ///
1272    /// # Returns
1273    ///
1274    /// A tensor viewing the same data with a new shape
1275    ///
1276    /// # Examples
1277    ///
1278    /// ```
1279    /// use train_station::Tensor;
1280    ///
1281    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1282    /// let y = x.view(vec![2, 2]);
1283    /// assert_eq!(y.shape().dims, vec![2, 2]);
1284    /// ```
1285    pub fn view(&self, new_shape: Vec<i32>) -> Tensor {
1286        // Use the views module implementation
1287        use crate::tensor::transform::view::TensorViewExt;
1288        TensorViewExt::view(self, new_shape)
1289    }
1290
1291    /// Create an element view for the specified index
1292    ///
1293    /// Returns a scalar tensor (shape \[1\]) that views a single element
1294    /// of the source tensor. Maintains gradient tracking.
1295    ///
1296    /// # Arguments
1297    ///
1298    /// * `index` - Linear index of the element to view
1299    ///
1300    /// # Returns
1301    ///
1302    /// A scalar tensor viewing the specified element
1303    ///
1304    /// # Examples
1305    ///
1306    /// ```
1307    /// use train_station::Tensor;
1308    ///
1309    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1310    /// let element = tensor.element_view(1);
1311    /// assert_eq!(element.value(), 2.0);
1312    /// ```
1313    pub fn element_view(&self, index: usize) -> Tensor {
1314        use crate::tensor::transform::view::TensorViewExt;
1315        TensorViewExt::element_view(self, index)
1316    }
1317
1318    /// Create a slice view of the tensor
1319    ///
1320    /// Returns a view of a contiguous or strided slice of the source tensor.
1321    ///
1322    /// # Arguments
1323    ///
1324    /// * `start` - Starting index
1325    /// * `step` - Step size (1 for contiguous)
1326    /// * `length` - Number of elements
1327    ///
1328    /// # Returns
1329    ///
1330    /// A tensor viewing the specified slice
1331    ///
1332    /// # Examples
1333    ///
1334    /// ```
1335    /// use train_station::Tensor;
1336    ///
1337    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1338    /// let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
1339    /// assert_eq!(slice.data(), &[2.0, 4.0]);
1340    /// ```
1341    pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor {
1342        use crate::tensor::transform::view::TensorViewExt;
1343        TensorViewExt::slice_view(self, start, step, length)
1344    }
1345
1346    /// Create a tensor view from raw components
1347    ///
1348    /// Creates a tensor that views existing memory with the specified shape and device.
1349    /// The tensor shares memory with the original allocation through the allocation_owner.
1350    ///
1351    /// # Safety
1352    ///
1353    /// The caller must ensure:
1354    /// - `data` is valid for the number of elements specified by `shape`
1355    /// - `data` remains valid for the lifetime of the returned tensor
1356    /// - `allocation_owner` properly manages the memory lifecycle
1357    ///
1358    /// # Arguments
1359    ///
1360    /// * `data` - Raw pointer to the tensor data
1361    /// * `shape` - Shape of the tensor view
1362    /// * `device` - Device where the tensor is located
1363    /// * `allocation_owner` - Optional shared allocation owner
1364    ///
1365    /// # Returns
1366    ///
1367    /// A new tensor that views the specified memory
1368    ///
1369    /// # Implementation Details
1370    ///
1371    /// This method is used internally to create tensor views that share memory
1372    /// with other tensors. It's primarily used for view operations and memory
1373    /// management.
1374    pub(crate) fn from_raw_view(
1375        data: *const f32,
1376        shape: Shape,
1377        device: Device,
1378        allocation_owner: Option<Arc<Allocation>>,
1379    ) -> Self {
1380        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1381
1382        Self {
1383            data: NonNull::new(data as *mut f32).expect("Data pointer cannot be null"),
1384            shape,
1385            device,
1386            id,
1387            requires_grad: false,
1388            grad: None,
1389            grad_fn: GradFn::None,
1390            allocation_owner,
1391            _phantom: PhantomData,
1392        }
1393    }
1394
1395    /// Get the allocation owner for this tensor
1396    ///
1397    /// Returns the shared allocation owner if this tensor is a view,
1398    /// or None if this tensor owns its memory directly.
1399    ///
1400    /// # Returns
1401    ///
1402    /// Optional reference to the allocation owner
1403    ///
1404    /// # Implementation Details
1405    ///
1406    /// This method is used internally to manage memory lifecycle for tensor views.
1407    /// It helps determine whether a tensor shares memory with another tensor.
1408    pub fn allocation_owner(&self) -> Option<&Arc<Allocation>> {
1409        self.allocation_owner.as_ref()
1410    }
1411
1412    /// Create a new tensor with uninitialized memory
1413    ///
1414    /// This method allocates memory for a tensor without initializing it to any value.
1415    /// This is useful for performance-critical operations where the memory will be
1416    /// immediately overwritten, such as matrix multiplication results.
1417    ///
1418    /// # Safety
1419    ///
1420    /// The caller must ensure that all memory is written before reading from the tensor.
1421    /// Reading from uninitialized memory is undefined behavior.
1422    ///
1423    /// # Arguments
1424    ///
1425    /// * `shape_dims` - The dimensions of the tensor
1426    ///
1427    /// # Returns
1428    ///
1429    /// A tensor with uninitialized memory
1430    ///
1431    /// # Performance
1432    ///
1433    /// - **Zero Initialization**: Skips memory initialization for maximum performance
1434    /// - **SIMD Ready**: Properly aligned for vectorized operations
1435    /// - **Memory Efficient**: Uses optimized alignment strategies
1436    ///
1437    /// # Example
1438    ///
1439    /// ```
1440    /// use train_station::Tensor;
1441    ///
1442    /// // Create uninitialized tensor for matmul result
1443    /// let mut result = Tensor::new_uninitialized(vec![100, 100]);
1444    /// // Initialize the memory before use
1445    /// for value in result.data_mut() {
1446    ///     *value = 0.0;
1447    /// }
1448    /// ```
1449    #[inline]
1450    pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self {
1451        let shape = Shape::new(shape_dims);
1452        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1453
1454        if shape.size == 0 {
1455            // Handle zero-sized tensors
1456            return Self {
1457                data: NonNull::dangling(),
1458                shape,
1459                device: current_device(),
1460                id,
1461                requires_grad: false,
1462                grad: None,
1463                grad_fn: GradFn::None,
1464                allocation_owner: None,
1465                _phantom: PhantomData,
1466            };
1467        }
1468
1469        // Optimized layout calculation for better cache performance
1470        let element_size = std::mem::size_of::<f32>();
1471        let total_size = shape.size * element_size;
1472
1473        // Use cache line alignment for large tensors, smaller alignment for small ones
1474        let alignment = if total_size > 4096 {
1475            64 // Cache line alignment for large tensors
1476        } else if shape.size >= 8 {
1477            32 // AVX2 alignment for medium tensors
1478        } else {
1479            16 // SSE alignment for small tensors
1480        };
1481
1482        let layout = Layout::from_size_align(total_size, alignment)
1483            .expect("Failed to create layout for tensor data");
1484
1485        // Allocate memory via shared Allocation (uninitialized)
1486        let alloc_obj = Allocation::new_uninitialized(shape.size, alignment, layout);
1487        let ptr = alloc_obj.ptr;
1488
1489        Self {
1490            data: ptr,
1491            shape,
1492            device: current_device(),
1493            id,
1494            requires_grad: false,
1495            grad: None,
1496            grad_fn: GradFn::None,
1497            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1498            _phantom: PhantomData,
1499        }
1500    }
1501}