train_station/tensor/core/utils.rs
1//! Tensor utility functions and core implementation methods
2//!
3//! This module provides essential utility functions for tensor creation, memory management,
4//! gradient tracking, and optimization. It contains the core implementation methods
5//! that enable efficient tensor operations and gradtrack functionality.
6//!
7//! # Key Features
8//!
9//! - **Tensor Creation**: Optimized constructors with memory alignment
10//! - **Memory Management**: Safe memory access and allocation utilities
11//! - **Gradient Tracking**: GradTrack system integration and gradient management
12//! - **Performance Optimization**: SIMD-ready memory layout and alignment
13//! - **Device Management**: CPU and future CUDA device support
14//! - **Memory Layout**: Contiguous, strided, and view memory access patterns
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Alignment**: 16-byte SSE, 32-byte AVX2, 64-byte cache-line alignment
19//! - **SIMD Optimization**: Properly aligned memory for vectorized operations
20//! - **Zero-Cost Abstractions**: Minimal overhead for utility operations
21//! - **Thread Safety**: Atomic operations for gradient tracking and ID generation
22//! - **Memory Efficiency**: Optimized allocation strategies for different tensor sizes
23//!
24//! # Examples
25//!
26//! ## Basic Tensor Creation
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors of different sizes
32//! let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
33//! let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
34//! let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
35//!
36//! // Initialize data before use
37//! let mut tensor = Tensor::new(vec![2, 3]);
38//! tensor.fill(0.0); // Initialize with zeros
39//! ```
40//!
41//! ## Gradient Tracking
42//!
43//! ```
44//! use train_station::Tensor;
45//!
46//! let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
47//! assert!(tensor.requires_grad());
48//! ```
49//!
50//! ## Memory Access
51//!
52//! ```
53//! use train_station::Tensor;
54//!
55//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
56//! let value = tensor.get(&[0, 1]);
57//! assert_eq!(value, 2.0);
58//!
59//! let mut tensor = Tensor::new(vec![2, 2]);
60//! tensor.set(&[0, 1], 42.0);
61//! assert_eq!(tensor.get(&[0, 1]), 42.0);
62//! ```
63
64use std::alloc::Layout;
65use std::marker::PhantomData;
66use std::ptr::NonNull;
67use std::sync::atomic::Ordering;
68use std::sync::Arc;
69
70use crate::device::current_device;
71use crate::gradtrack::{self, GradEngine, GradFn};
72use crate::tensor::core::{Allocation, Device, TENSOR_ID_COUNTER};
73use crate::tensor::Shape;
74
75use super::Tensor;
76
77impl Tensor {
    /// Creates a new tensor with the specified shape and optimized memory layout
    ///
    /// Allocates memory with size-dependent alignment for optimal performance:
    /// - Small tensors (<8 elements): 16-byte SSE alignment
    /// - Medium tensors (8-1024 elements, i.e. ≤4096 bytes): 32-byte AVX2 alignment
    /// - Large tensors (>1024 elements, i.e. >4096 bytes): 64-byte cache-line alignment
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    ///
    /// # Returns
    ///
    /// A new tensor with uninitialized data. The data must be initialized
    /// before use to avoid undefined behavior.
    ///
    /// # Performance
    ///
    /// - **Memory Allocation**: Single allocation with optimized alignment
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
    ///
    /// # Safety
    ///
    /// The returned tensor contains uninitialized memory. You must initialize
    /// the data before performing any operations that read from it.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create tensors of different sizes
    /// let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
    /// let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
    /// let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
    ///
    /// // Initialize data before use
    /// let mut tensor = Tensor::new(vec![2, 3]);
    /// tensor.fill(0.0); // Initialize with zeros
    /// ```
    #[inline]
    #[track_caller]
    pub fn new(shape_dims: Vec<usize>) -> Self {
        let shape = Shape::new(shape_dims);
        // Unique, monotonically increasing ID. Relaxed ordering suffices:
        // only uniqueness is required, not ordering with other memory ops.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size == 0 {
            // Zero-sized tensors own no allocation. A dangling (non-null,
            // well-aligned) pointer is fine because it is never dereferenced.
            return Self {
                data: NonNull::dangling(),
                shape,
                device: current_device(),
                id,
                requires_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                _phantom: PhantomData,
            };
        }

        // Optimized layout calculation for better cache performance
        let element_size = std::mem::size_of::<f32>();
        let total_size = shape.size * element_size;

        // Use cache line alignment for large tensors, smaller alignment for small ones
        let alignment = if total_size > 4096 {
            64 // Cache line alignment for large tensors
        } else if shape.size >= 8 {
            32 // AVX2 alignment for medium tensors
        } else {
            16 // SSE alignment for small tensors
        };

        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Allocate via a shared Allocation object; it is wrapped in an Arc
        // below so ownership of the buffer can be shared (e.g. by views).
        let alloc_obj = Allocation::new(shape.size, alignment, layout);
        let ptr = alloc_obj.ptr;

        Self {
            data: ptr,
            shape,
            device: current_device(),
            id,
            requires_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            _phantom: PhantomData,
        }
    }
173
174 /// Returns the shape and dimensional information of the tensor
175 ///
176 /// Provides access to the tensor's dimensions, size, strides, and memory
177 /// layout information. This is used for shape validation, memory access
178 /// calculations, and optimization decisions.
179 ///
180 /// # Returns
181 ///
182 /// Reference to the tensor's shape information containing dimensions,
183 /// size, strides, and memory layout type.
184 ///
185 /// # Performance
186 ///
187 /// - **Time Complexity**: O(1) - direct field access
188 /// - **Memory**: No allocation - returns reference to existing data
189 ///
190 /// # Examples
191 ///
192 /// ```
193 /// use train_station::Tensor;
194 ///
195 /// let tensor = Tensor::new(vec![2, 3, 4]);
196 /// let shape = tensor.shape();
197 /// assert_eq!(shape.dims, vec![2, 3, 4]);
198 /// assert_eq!(shape.size, 24);
199 /// assert_eq!(shape.rank(), 3);
200 /// ```
201 #[inline]
202 #[track_caller]
203 pub fn shape(&self) -> &Shape {
204 &self.shape
205 }
206
207 /// Returns the total number of elements in the tensor
208 ///
209 /// Provides the total count of elements across all dimensions. This is
210 /// used for memory allocation, iteration bounds, and performance optimization.
211 ///
212 /// # Returns
213 ///
214 /// Total number of elements as `usize`
215 ///
216 /// # Performance
217 ///
218 /// - **Time Complexity**: O(1) - direct field access
219 /// - **Memory**: No allocation - returns stored value
220 ///
221 /// # Examples
222 ///
223 /// ```
224 /// use train_station::Tensor;
225 ///
226 /// let tensor = Tensor::new(vec![2, 3, 4]);
227 /// assert_eq!(tensor.size(), 24); // 2 * 3 * 4
228 ///
229 /// let scalar = Tensor::new(vec![1]);
230 /// assert_eq!(scalar.size(), 1);
231 ///
232 /// let empty = Tensor::new(vec![0]);
233 /// assert_eq!(empty.size(), 0);
234 /// ```
235 #[inline]
236 #[track_caller]
237 pub fn size(&self) -> usize {
238 self.shape.size
239 }
240
241 /// Returns the device where this tensor is located
242 ///
243 /// Provides the physical location of the tensor data (CPU/GPU). This
244 /// determines which operations can be performed on the tensor and where
245 /// computations will be executed.
246 ///
247 /// # Returns
248 ///
249 /// Device enum indicating the tensor's physical location
250 ///
251 /// # Performance
252 ///
253 /// - **Time Complexity**: O(1) - direct field access
254 /// - **Memory**: No allocation - returns stored value
255 ///
256 /// # Examples
257 ///
258 /// ```
259 /// use train_station::Tensor;
260 ///
261 /// let tensor = Tensor::new(vec![2, 3]);
262 /// assert!(tensor.device().is_cpu());
263 /// assert!(!tensor.device().is_cuda());
264 /// ```
265 #[inline]
266 #[track_caller]
267 pub fn device(&self) -> Device {
268 self.device
269 }
270
    /// Creates a new tensor with the specified shape on a specific device
    ///
    /// Allocates memory with the same size-dependent alignment strategy as
    /// `new()` (16-byte SSE, 32-byte AVX2, or 64-byte cache-line alignment).
    /// Currently only the CPU device is supported; CUDA support is planned.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    /// * `device` - Device where the tensor should be allocated
    ///
    /// # Returns
    ///
    /// A new tensor with uninitialized data on the specified device. The data
    /// must be initialized before use.
    ///
    /// # Panics
    ///
    /// Panics if the device is not supported (currently only CPU is supported)
    ///
    /// # Performance
    ///
    /// - **Memory Allocation**: Single allocation with optimized alignment
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, Device};
    ///
    /// // Create tensor on CPU device
    /// let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
    /// assert_eq!(tensor.device(), Device::cpu());
    /// assert_eq!(tensor.size(), 6);
    /// ```
    #[track_caller]
    pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self {
        // For now, only CPU is supported
        if !device.is_cpu() {
            panic!("Only CPU device is currently supported. CUDA support is planned for future releases.");
        }

        let shape = Shape::new(shape_dims);
        // Unique ID; Relaxed is sufficient because only uniqueness matters.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size == 0 {
            // Zero-sized tensors own no allocation; the dangling pointer is
            // never dereferenced.
            return Self {
                data: NonNull::dangling(),
                shape,
                device,
                id,
                requires_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                _phantom: PhantomData,
            };
        }

        // Optimized layout calculation for better cache performance
        let element_size = std::mem::size_of::<f32>();
        let total_size = shape.size * element_size;

        // Use cache line alignment for large tensors, smaller alignment for small ones
        let alignment = if total_size > 4096 {
            64 // Cache line alignment for large tensors
        } else if shape.size >= 8 {
            32 // AVX2 alignment for medium tensors
        } else {
            16 // SSE alignment for small tensors
        };

        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Allocate memory via shared Allocation
        let alloc_obj = Allocation::new(shape.size, alignment, layout);
        let ptr = alloc_obj.ptr;

        Self {
            data: ptr,
            shape,
            device,
            id,
            requires_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            _phantom: PhantomData,
        }
    }
392
393 /// Enable gradient computation for this tensor
394 ///
395 /// Builder method that enables automatic gradient tracking for this tensor.
396 /// When enabled, all operations involving this tensor will be recorded in
397 /// the computation graph for gradient computation during backward pass.
398 ///
399 /// # Returns
400 ///
401 /// `self` with gradient tracking enabled
402 ///
403 /// # Performance
404 ///
405 /// - **Time Complexity**: O(1) - simple field assignment
406 /// - **Memory**: No additional allocation
407 /// - **Overhead**: Minimal gradtrack tracking overhead when gradients computed
408 ///
409 /// # Examples
410 ///
411 /// ```
412 /// use train_station::Tensor;
413 ///
414 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
415 /// assert!(tensor.requires_grad());
416 /// ```
417 #[track_caller]
418 pub fn with_requires_grad(mut self) -> Self {
419 self.requires_grad = true;
420 self
421 }
422
423 /// Set gradient tracking for this tensor
424 ///
425 /// Controls whether the gradtrack system tracks operations on this tensor
426 /// and computes gradients during backward pass. When disabled, clears
427 /// any existing gradients and gradient functions.
428 ///
429 /// # Arguments
430 ///
431 /// * `requires_grad` - Whether to track gradients for this tensor
432 ///
433 /// # Performance
434 ///
435 /// - **Time Complexity**: O(1) - simple field assignment
436 /// - **Memory**: May free gradient storage when disabled
437 /// - **Overhead**: Zero overhead when gradients disabled
438 ///
439 /// # Examples
440 ///
441 /// ```
442 /// use train_station::Tensor;
443 ///
444 /// let mut tensor = Tensor::ones(vec![2, 3]);
445 /// tensor.set_requires_grad(true);
446 /// assert!(tensor.requires_grad());
447 ///
448 /// // Disable gradient tracking
449 /// tensor.set_requires_grad(false);
450 /// assert!(!tensor.requires_grad());
451 /// ```
452 #[track_caller]
453 pub fn set_requires_grad(&mut self, requires_grad: bool) {
454 self.requires_grad = requires_grad;
455 if !requires_grad {
456 self.grad = None;
457 self.grad_fn = GradFn::None;
458 }
459 }
460
461 /// Check if this tensor requires gradients
462 ///
463 /// # Returns
464 ///
465 /// `true` if gradient tracking is enabled for this tensor
466 ///
467 /// # Examples
468 ///
469 /// ```
470 /// use train_station::Tensor;
471 ///
472 /// let tensor = Tensor::new(vec![2, 3]);
473 /// assert!(!tensor.requires_grad());
474 ///
475 /// let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
476 /// assert!(grad_tensor.requires_grad());
477 /// ```
478 #[track_caller]
479 pub fn requires_grad(&self) -> bool {
480 self.requires_grad
481 }
482
483 /// Get the accumulated gradients (if any)
484 ///
485 /// Returns a reference to the gradient tensor if gradients have been computed
486 /// and this tensor has gradient tracking enabled.
487 ///
488 /// # Returns
489 ///
490 /// Optional reference to the gradient tensor, or `None` if no gradients exist
491 ///
492 /// # Examples
493 ///
494 /// ```
495 /// use train_station::Tensor;
496 ///
497 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
498 /// assert!(tensor.grad().is_none()); // No gradients computed yet
499 /// ```
500 #[track_caller]
501 pub fn grad(&self) -> Option<&Tensor> {
502 // First check if we have a gradient stored directly
503 if let Some(grad) = self.grad.as_ref() {
504 return Some(grad.as_ref());
505 }
506
507 // If not, check the gradient map for accumulated gradients
508 if let Some(_grad) = gradtrack::get_accumulated_gradient(self.id) {
509 // For simplicity, we'll return None here since we can't return a reference
510 // to a temporary value. In a full implementation, we'd store it in self.grad
511 return None;
512 }
513
514 None
515 }
516
517 /// Get accumulated gradient by value (helper for testing)
518 ///
519 /// Returns the gradient tensor by value, which is useful for testing and
520 /// when you need to own the gradient data.
521 ///
522 /// # Returns
523 ///
524 /// Optional gradient tensor, or `None` if no gradients exist
525 ///
526 /// # Examples
527 ///
528 /// ```
529 /// use train_station::Tensor;
530 ///
531 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
532 /// assert!(tensor.grad_by_value().is_none()); // No gradients computed yet
533 /// ```
534 #[track_caller]
535 pub fn grad_by_value(&self) -> Option<Tensor> {
536 // First check if we have a gradient stored directly
537 if let Some(grad) = self.grad.as_ref() {
538 return Some((**grad).clone());
539 }
540
541 // If not, check the gradient map for accumulated gradients
542 use crate::gradtrack;
543 gradtrack::get_accumulated_gradient(self.id)
544 }
545
546 /// Get the unique ID of this tensor
547 ///
548 /// Returns the unique identifier assigned to this tensor during creation.
549 /// This ID is used for gradtrack tracking and tensor identification.
550 ///
551 /// # Returns
552 ///
553 /// Unique tensor ID as `usize`
554 ///
555 /// # Examples
556 ///
557 /// ```
558 /// use train_station::Tensor;
559 ///
560 /// let tensor1 = Tensor::new(vec![2, 3]);
561 /// let tensor2 = Tensor::new(vec![2, 3]);
562 /// assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
563 /// ```
564 #[track_caller]
565 pub fn id(&self) -> usize {
566 self.id
567 }
568
569 /// Detach this tensor from the computation graph
570 ///
571 /// Returns a new tensor with the same data but no gradient tracking.
572 /// This is useful when you want to use a tensor in inference without
573 /// affecting the computation graph.
574 ///
575 /// # Returns
576 ///
577 /// A new tensor with the same data but gradient tracking disabled
578 ///
579 /// # Examples
580 ///
581 /// ```
582 /// use train_station::Tensor;
583 ///
584 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
585 /// let detached = tensor.detach();
586 /// assert!(!detached.requires_grad());
587 /// assert_eq!(tensor.size(), detached.size());
588 /// ```
589 #[track_caller]
590 pub fn detach(&self) -> Self {
591 let mut detached = Self::new(self.shape.dims.clone());
592
593 // Copy data
594 unsafe {
595 let src = self.as_ptr();
596 let dst = detached.as_mut_ptr();
597 std::ptr::copy_nonoverlapping(src, dst, self.size());
598 }
599
600 detached
601 }
602
603 /// Create a new tensor that doesn't track gradients from this one
604 ///
605 /// Similar to detach() but modifies this tensor in place. This is useful
606 /// when you want to disable gradient tracking for the current tensor
607 /// without creating a copy.
608 ///
609 /// # Examples
610 ///
611 /// ```
612 /// use train_station::Tensor;
613 ///
614 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
615 /// assert!(tensor.requires_grad());
616 /// tensor.detach_();
617 /// assert!(!tensor.requires_grad());
618 /// ```
619 #[track_caller]
620 pub fn detach_(&mut self) {
621 self.requires_grad = false;
622 self.grad = None;
623 self.grad_fn = GradFn::None;
624 }
625
626 /// Entry point for backward pass on this tensor
627 ///
628 /// Computes gradients for all tensors in the computation graph that have
629 /// `requires_grad` set to true. This is the main entry point for automatic
630 /// differentiation.
631 ///
632 /// # Arguments
633 ///
634 /// * `grad_output` - Optional gradient tensor for the output. If None, assumes
635 /// the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
636 ///
637 /// # Examples
638 ///
639 /// ```
640 /// use train_station::Tensor;
641 ///
642 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
643 /// let mut result = tensor.add_scalar(5.0);
644 /// result.backward(None);
645 /// // Note: Gradient computation depends on the gradtrack system implementation
646 /// ```
647 #[track_caller]
648 pub fn backward(&mut self, grad_output: Option<Tensor>) {
649 GradEngine::backward(self, grad_output);
650 }
651
652 /// Returns a raw pointer to the tensor data for unsafe operations
653 ///
654 /// # Safety
655 ///
656 /// This is unsafe because it provides direct access to the underlying memory.
657 /// The caller must ensure:
658 /// - The tensor is not dropped while the pointer is used
659 /// - No concurrent mutable access occurs
660 /// - Bounds are respected
661 #[inline]
662 pub unsafe fn as_ptr(&self) -> *const f32 {
663 self.data.as_ptr()
664 }
665
666 /// Returns a mutable raw pointer to the tensor data for unsafe operations
667 ///
668 /// # Safety
669 ///
670 /// This is unsafe because it provides direct mutable access to the underlying memory.
671 /// The caller must ensure:
672 /// - The tensor is not dropped while the pointer is used
673 /// - No concurrent access occurs
674 /// - Bounds are respected
675 #[inline]
676 pub unsafe fn as_mut_ptr(&mut self) -> *mut f32 {
677 self.data.as_ptr()
678 }
679
680 /// Internal method to set gradient function (used by operations)
681 ///
682 /// Sets the gradient function for this tensor. This is used internally
683 /// by tensor operations to record the computation graph for gradtrack.
684 ///
685 /// # Arguments
686 ///
687 /// * `grad_fn` - The gradient function to set
688 ///
689 /// # Implementation Details
690 ///
691 /// This method is called by tensor operations to register the gradient
692 /// computation function. It only sets the gradient function if gradient
693 /// tracking is enabled for this tensor.
694 pub(crate) fn set_grad_fn(&mut self, grad_fn: GradFn) {
695 if self.requires_grad {
696 self.grad_fn = grad_fn;
697 }
698 }
699
700 /// Get a reference to the gradient function (for gradtrack)
701 ///
702 /// Returns a reference to the gradient function associated with this tensor.
703 /// This is used internally by the gradtrack system to compute gradients.
704 ///
705 /// # Returns
706 ///
707 /// Reference to the gradient function
708 ///
709 /// # Implementation Details
710 ///
711 /// This method is used by the gradtrack engine to access the gradient
712 /// computation function during backward pass.
713 #[track_caller]
714 pub fn grad_fn(&self) -> &GradFn {
715 &self.grad_fn
716 }
717
718 /// Internal method to set requires_grad (used by gradtrack operations)
719 ///
720 /// Sets the gradient tracking flag for this tensor. This is used internally
721 /// by gradtrack operations to control gradient computation.
722 ///
723 /// # Arguments
724 ///
725 /// * `requires_grad` - Whether to enable gradient tracking
726 ///
727 /// # Implementation Details
728 ///
729 /// This method is used internally by the gradtrack system to control
730 /// gradient tracking without triggering additional side effects.
731 pub(crate) fn set_requires_grad_internal(&mut self, requires_grad: bool) {
732 self.requires_grad = requires_grad;
733 }
734
735 /// Internal method to accumulate gradients with optimized operations
736 ///
737 /// Accumulates a gradient tensor with any existing gradients for this tensor.
738 /// This is used internally by the gradtrack system to handle gradient accumulation
739 /// during backward pass.
740 ///
741 /// # Arguments
742 ///
743 /// * `grad` - The gradient tensor to accumulate
744 ///
745 /// # Implementation Details
746 ///
747 /// This method is used internally by the gradtrack engine to accumulate
748 /// gradients from multiple backward passes or operations. It only accumulates
749 /// gradients if gradient tracking is enabled for this tensor.
750 pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
751 if !self.requires_grad {
752 return;
753 }
754
755 match &self.grad {
756 Some(existing_grad) => {
757 // Use optimized tensor addition but create new tensor for safety
758 let accumulated = existing_grad.add_tensor_optimized(&grad);
759 self.grad = Some(Arc::new(accumulated));
760 }
761 None => {
762 self.grad = Some(Arc::new(grad));
763 }
764 }
765 }
766
767 /// Set gradient from external source
768 ///
769 /// Sets the gradient tensor for this tensor. This is used internally by the
770 /// gradtrack system to set gradients during backward pass.
771 ///
772 /// # Arguments
773 ///
774 /// * `grad` - The gradient tensor to set
775 ///
776 /// # Implementation Details
777 ///
778 /// This method is used internally by the gradtrack engine to set gradients
779 /// during backward pass. It only sets the gradient if gradient tracking is
780 /// enabled for this tensor.
781 #[track_caller]
782 pub fn set_grad(&mut self, grad: Tensor) {
783 if self.requires_grad {
784 self.grad = Some(Arc::new(grad));
785 }
786 }
787
788 /// Clear accumulated gradients for this tensor
789 ///
790 /// This method is used by optimizers to zero gradients before each backward pass.
791 /// It clears any accumulated gradients, allowing for fresh gradient computation.
792 ///
793 /// # Examples
794 ///
795 /// ```
796 /// use train_station::Tensor;
797 ///
798 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
799 /// tensor.set_grad(Tensor::ones(vec![2, 3]));
800 /// assert!(tensor.grad().is_some());
801 /// tensor.zero_grad();
802 /// assert!(tensor.grad().is_none());
803 /// ```
804 #[track_caller]
805 pub fn zero_grad(&mut self) {
806 self.grad = None;
807 }
808
    /// Negate all elements in the tensor in-place
    ///
    /// This is used internally for gradient computation in subtraction operations.
    /// For tensor - tensor operations, the second operand gets a negated gradient.
    ///
    /// # Implementation Details
    ///
    /// Dispatches to an AVX2 SIMD path when the CPU supports it (detected at
    /// runtime) and falls back to a scalar loop otherwise. Zero-sized tensors
    /// return immediately, so the dangling data pointer of an empty tensor is
    /// never dereferenced.
    ///
    /// NOTE(review): both paths walk `shape.size` elements linearly from the
    /// data pointer, which assumes a contiguous buffer — confirm this is
    /// never invoked on non-contiguous views.
    #[inline]
    pub(crate) fn negate_inplace(&mut self) {
        if self.shape.size == 0 {
            return;
        }

        unsafe {
            let ptr = self.data.as_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.negate_simd_avx2(ptr);
                    return;
                }
            }

            // Fallback to scalar operations
            for i in 0..self.shape.size {
                *ptr.add(i) = -*ptr.add(i);
            }
        }
    }
843
844 /// SIMD-optimized negation using AVX2 instructions
845 ///
846 /// Performs in-place negation of tensor elements using AVX2 SIMD instructions
847 /// for improved performance on x86_64 architectures.
848 ///
849 /// # Arguments
850 ///
851 /// * `ptr` - Raw pointer to the tensor data
852 ///
853 /// # Safety
854 ///
855 /// The caller must ensure:
856 /// - `ptr` is a valid pointer to tensor data
857 /// - The tensor size matches the actual data size
858 /// - The tensor is not moved or dropped during this operation
859 ///
860 /// # Implementation Details
861 ///
862 /// This method is used internally by `negate_inplace` when AVX2 is available.
863 /// It processes 8 elements per iteration using AVX2 instructions.
864 #[cfg(target_arch = "x86_64")]
865 #[inline]
866 unsafe fn negate_simd_avx2(&self, ptr: *mut f32) {
867 use std::arch::x86_64::_mm256_setzero_ps;
868
869 let size = self.shape.size;
870 let zero_vec = _mm256_setzero_ps();
871 let simd_count = size / 8; // Process 8 elements per iteration
872 let mut offset = 0;
873
874 // SIMD loop for negation
875 for _ in 0..simd_count {
876 use std::arch::x86_64::{_mm256_load_ps, _mm256_store_ps, _mm256_sub_ps};
877
878 let vec = _mm256_load_ps(ptr.add(offset));
879 let neg_vec = _mm256_sub_ps(zero_vec, vec);
880 _mm256_store_ps(ptr.add(offset), neg_vec);
881 offset += 8;
882 }
883
884 // Handle remaining elements
885 for i in offset..size {
886 *ptr.add(i) = -*ptr.add(i);
887 }
888 }
889
890 // ===== Memory Layout and Optimization API =====
891
892 /// Checks if the tensor data is stored contiguously in memory
893 ///
894 /// # Returns
895 ///
896 /// `true` if the tensor data is contiguous, enabling optimized SIMD operations
897 ///
898 /// # Examples
899 ///
900 /// ```
901 /// use train_station::Tensor;
902 ///
903 /// let tensor = Tensor::new(vec![2, 3, 4]);
904 /// assert!(tensor.is_contiguous());
905 /// ```
906 #[inline]
907 #[track_caller]
908 pub fn is_contiguous(&self) -> bool {
909 self.shape.is_contiguous()
910 }
911
912 /// Checks if this tensor is a view of another tensor
913 ///
914 /// # Returns
915 ///
916 /// `true` if this tensor is a view (non-contiguous reference)
917 #[inline]
918 #[track_caller]
919 pub fn is_view(&self) -> bool {
920 self.shape.is_view()
921 }
922
923 /// Gets the memory strides for all dimensions
924 ///
925 /// # Returns
926 ///
927 /// Reference to the stride vector for efficient memory access calculations
928 ///
929 /// # Examples
930 ///
931 /// ```
932 /// use train_station::Tensor;
933 ///
934 /// let tensor = Tensor::new(vec![2, 3, 4]);
935 /// assert_eq!(tensor.strides(), &[12, 4, 1]);
936 /// ```
937 #[inline]
938 #[track_caller]
939 pub fn strides(&self) -> &[usize] {
940 self.shape.strides()
941 }
942
943 /// Gets the memory stride for a specific dimension
944 ///
945 /// # Arguments
946 ///
947 /// * `dim` - The dimension index
948 ///
949 /// # Returns
950 ///
951 /// The memory stride for the given dimension
952 ///
953 /// # Panics
954 ///
955 /// Panics if `dim` is out of bounds
956 #[inline]
957 #[track_caller]
958 pub fn stride(&self, dim: usize) -> usize {
959 self.shape.stride(dim)
960 }
961
962 /// Gets the memory layout type for optimization decisions
963 ///
964 /// # Returns
965 ///
966 /// Reference to the memory layout information
967 #[inline]
968 #[track_caller]
969 pub fn layout(&self) -> &crate::tensor::MemoryLayout {
970 self.shape.layout()
971 }
972
973 /// Calculates the linear memory offset for given multi-dimensional indices
974 ///
975 /// # Arguments
976 ///
977 /// * `indices` - Vector of indices for each dimension
978 ///
979 /// # Returns
980 ///
981 /// Linear memory offset for direct memory access
982 ///
983 /// # Examples
984 ///
985 /// ```
986 /// use train_station::Tensor;
987 ///
988 /// let tensor = Tensor::new(vec![2, 3, 4]);
989 /// let offset = tensor.memory_offset(&[1, 2, 3]);
990 /// // offset = 1*12 + 2*4 + 3*1 = 23
991 /// ```
992 #[inline]
993 #[track_caller]
994 pub fn memory_offset(&self, indices: &[usize]) -> usize {
995 self.shape.offset(indices)
996 }
997
998 /// Broadcast this tensor with another tensor for element-wise operations
999 ///
1000 /// Returns a tuple containing:
1001 /// - Broadcasted view of self
1002 /// - Broadcasted view of other
1003 /// - Result shape for the operation
1004 ///
1005 /// # Arguments
1006 ///
1007 /// * `other` - The tensor to broadcast with
1008 ///
1009 /// # Returns
1010 ///
1011 /// A tuple `(broadcasted_self, broadcasted_other, result_shape)`
1012 ///
1013 /// # Examples
1014 ///
1015 /// ```
1016 /// use train_station::Tensor;
1017 ///
1018 /// let a = Tensor::ones(vec![2, 1, 4]);
1019 /// let b = Tensor::ones(vec![3, 1]);
1020 /// let result = a.broadcast_with(&b);
1021 /// assert!(result.is_ok());
1022 /// ```
1023 #[track_caller]
1024 pub fn broadcast_with(
1025 &self,
1026 other: &Tensor,
1027 ) -> Result<
1028 (Tensor, Tensor, crate::tensor::Shape),
1029 crate::tensor::ops::broadcasting::BroadcastError,
1030 > {
1031 crate::tensor::ops::broadcasting::broadcast_shapes(self, other)
1032 }
1033
1034 /// Checks if the tensor data is properly aligned for SIMD operations
1035 ///
1036 /// # Returns
1037 ///
1038 /// `true` if the tensor data is aligned to 32-byte boundaries for AVX2
1039 #[inline]
1040 #[track_caller]
1041 pub fn is_simd_aligned(&self) -> bool {
1042 (self.data.as_ptr() as usize) % 32 == 0
1043 }
1044
1045 /// Gets the memory alignment of the tensor data
1046 ///
1047 /// # Returns
1048 ///
1049 /// The memory alignment in bytes (typically 32 for SIMD optimization)
1050 #[inline]
1051 #[track_caller]
1052 pub fn memory_alignment(&self) -> usize {
1053 // Our tensors are allocated with 32-byte alignment for AVX2
1054 32
1055 }
1056
1057 /// Checks if this tensor is broadcastable with another tensor
1058 ///
1059 /// # Arguments
1060 ///
1061 /// * `other` - The other tensor to check broadcasting compatibility
1062 ///
1063 /// # Returns
1064 ///
1065 /// `true` if the tensors are broadcastable according to NumPy broadcasting rules
1066 ///
1067 /// # Examples
1068 ///
1069 /// ```
1070 /// use train_station::Tensor;
1071 ///
1072 /// let a = Tensor::new(vec![2, 3, 4]);
1073 /// let b = Tensor::new(vec![1, 3, 4]);
1074 /// assert!(a.is_broadcastable_with(&b));
1075 /// ```
1076 #[inline]
1077 #[track_caller]
1078 pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
1079 self.shape.is_broadcastable_with(&other.shape)
1080 }
1081
1082 /// Gets the total number of bytes allocated for this tensor
1083 ///
1084 /// # Returns
1085 ///
1086 /// Total memory footprint in bytes
1087 #[inline]
1088 #[track_caller]
1089 pub fn memory_footprint(&self) -> usize {
1090 self.shape.size * std::mem::size_of::<f32>()
1091 }
1092
1093 /// Get a single element from the tensor at the specified indices
1094 ///
1095 /// # Arguments
1096 ///
1097 /// * `indices` - Multi-dimensional indices to access the element
1098 ///
1099 /// # Returns
1100 ///
1101 /// The value at the specified position
1102 ///
1103 /// # Panics
1104 ///
1105 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1106 ///
1107 /// # Examples
1108 ///
1109 /// ```
1110 /// use train_station::Tensor;
1111 ///
1112 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1113 /// let value = tensor.get(&[0, 1]);
1114 /// assert_eq!(value, 2.0);
1115 /// ```
1116 #[track_caller]
1117 pub fn get(&self, indices: &[usize]) -> f32 {
1118 assert_eq!(
1119 indices.len(),
1120 self.shape().rank(),
1121 "Indices length must match tensor rank"
1122 );
1123
1124 // Check bounds
1125 for (i, &idx) in indices.iter().enumerate() {
1126 assert!(
1127 idx < self.shape().dims[i],
1128 "Index {} out of bounds for dimension {}",
1129 idx,
1130 i
1131 );
1132 }
1133
1134 let offset = self.memory_offset(indices);
1135 unsafe { *self.as_ptr().add(offset) }
1136 }
1137
1138 /// Set a single element in the tensor at the specified indices
1139 ///
1140 /// # Arguments
1141 ///
1142 /// * `indices` - Multi-dimensional indices to set the element
1143 /// * `value` - The value to set
1144 ///
1145 /// # Panics
1146 ///
1147 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1148 ///
1149 /// # Examples
1150 ///
1151 /// ```
1152 /// use train_station::Tensor;
1153 ///
1154 /// let mut tensor = Tensor::new(vec![2, 2]);
1155 /// tensor.set(&[0, 1], 42.0);
1156 /// assert_eq!(tensor.get(&[0, 1]), 42.0);
1157 /// ```
1158 #[track_caller]
1159 pub fn set(&mut self, indices: &[usize], value: f32) {
1160 assert_eq!(
1161 indices.len(),
1162 self.shape().rank(),
1163 "Indices length must match tensor rank"
1164 );
1165
1166 // Check bounds
1167 for (i, &idx) in indices.iter().enumerate() {
1168 assert!(
1169 idx < self.shape().dims[i],
1170 "Index {} out of bounds for dimension {}",
1171 idx,
1172 i
1173 );
1174 }
1175
1176 let offset = self.memory_offset(indices);
1177 unsafe {
1178 *self.as_mut_ptr().add(offset) = value;
1179 }
1180 }
1181
1182 /// Returns a safe slice of the tensor's underlying data
1183 ///
1184 /// Provides safe access to the tensor's data without requiring unsafe pointer operations.
1185 /// This is the preferred way to access tensor data for reading values, comparisons,
1186 /// and other operations that don't require direct pointer manipulation.
1187 ///
1188 /// # Returns
1189 ///
1190 /// A slice containing all tensor elements in row-major order
1191 ///
1192 /// # Performance
1193 ///
1194 /// - **Zero-Cost**: Direct slice creation with no copying
1195 /// - **Cache-Friendly**: Sequential memory access pattern
1196 /// - **Safe**: No unsafe code required for basic data access
1197 ///
1198 /// # Examples
1199 ///
1200 /// ```
1201 /// use train_station::Tensor;
1202 ///
1203 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1204 /// let data = tensor.data();
1205 ///
1206 /// // Safe indexing and comparisons
1207 /// assert_eq!(data[0], 1.0);
1208 /// assert_eq!(data.len(), tensor.size());
1209 /// ```
1210 #[inline]
1211 #[track_caller]
1212 pub fn data(&self) -> &[f32] {
1213 if self.size() == 0 {
1214 return &[];
1215 }
1216 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }
1217 }
1218
1219 /// Returns a mutable slice of the tensor's underlying data
1220 ///
1221 /// Provides safe mutable access to the tensor's data without requiring unsafe
1222 /// pointer operations. Use this for in-place modifications of tensor values.
1223 ///
1224 /// # Returns
1225 ///
1226 /// A mutable slice containing all tensor elements in row-major order
1227 ///
1228 /// # Performance
1229 ///
1230 /// - **Zero-Cost**: Direct slice creation with no copying
1231 /// - **Cache-Friendly**: Sequential memory access pattern
1232 /// - **Safe**: No unsafe code required for basic data modification
1233 ///
1234 /// # Examples
1235 ///
1236 /// ```
1237 /// use train_station::Tensor;
1238 ///
1239 /// let mut tensor = Tensor::new(vec![2, 2]);
1240 /// let data = tensor.data_mut();
1241 ///
1242 /// // Safe indexing for modification
1243 /// data[0] = 1.0;
1244 /// data[1] = 2.0;
1245 ///
1246 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
1247 /// assert_eq!(tensor.get(&[0, 1]), 2.0);
1248 /// ```
1249 #[inline]
1250 #[track_caller]
1251 pub fn data_mut(&mut self) -> &mut [f32] {
1252 if self.size() == 0 {
1253 return &mut [];
1254 }
1255 unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
1256 }
1257
1258 /// Extract scalar value from single-element tensor
1259 ///
1260 /// This method provides a convenient way to extract the scalar value from
1261 /// tensors that contain exactly one element. This is commonly used with
1262 /// element iterator results and scalar tensor operations.
1263 ///
1264 /// # Returns
1265 ///
1266 /// The scalar value contained in this tensor
1267 ///
1268 /// # Panics
1269 ///
1270 /// Panics if the tensor does not contain exactly one element
1271 ///
1272 /// # Examples
1273 ///
1274 /// ```
1275 /// use train_station::Tensor;
1276 ///
1277 /// // Single-element tensor
1278 /// let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
1279 /// assert_eq!(scalar.value(), 42.0);
1280 /// ```
1281 #[inline]
1282 #[track_caller]
1283 pub fn value(&self) -> f32 {
1284 assert_eq!(
1285 self.size(),
1286 1,
1287 "value() can only be called on tensors with exactly one element. \
1288 This tensor has {} elements with shape {:?}",
1289 self.size(),
1290 self.shape().dims
1291 );
1292 self.data()[0]
1293 }
1294
1295 /// Create a view with a new shape (requires contiguous memory)
1296 ///
1297 /// Behaves like PyTorch `view`: tensor must be contiguous and the total
1298 /// number of elements must remain the same. Supports -1 inference for one dimension.
1299 ///
1300 /// # Arguments
1301 ///
1302 /// * `new_shape` - New shape for the tensor (can contain -1 for inference)
1303 ///
1304 /// # Returns
1305 ///
1306 /// A tensor viewing the same data with a new shape
1307 ///
1308 /// # Examples
1309 ///
1310 /// ```
1311 /// use train_station::Tensor;
1312 ///
1313 /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1314 /// let y = x.view(vec![2, 2]);
1315 /// assert_eq!(y.shape().dims, vec![2, 2]);
1316 /// ```
1317 #[track_caller]
1318 pub fn view(&self, new_shape: Vec<i32>) -> Tensor {
1319 // Use the views module implementation
1320 use crate::tensor::transform::view::TensorViewExt;
1321 TensorViewExt::view(self, new_shape)
1322 }
1323
1324 /// Create an element view for the specified index
1325 ///
1326 /// Returns a scalar tensor (shape \[1\]) that views a single element
1327 /// of the source tensor. Maintains gradient tracking.
1328 ///
1329 /// # Arguments
1330 ///
1331 /// * `index` - Linear index of the element to view
1332 ///
1333 /// # Returns
1334 ///
1335 /// A scalar tensor viewing the specified element
1336 ///
1337 /// # Examples
1338 ///
1339 /// ```
1340 /// use train_station::Tensor;
1341 ///
1342 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1343 /// let element = tensor.element_view(1);
1344 /// assert_eq!(element.value(), 2.0);
1345 /// ```
1346 #[track_caller]
1347 pub fn element_view(&self, index: usize) -> Tensor {
1348 use crate::tensor::transform::view::TensorViewExt;
1349 TensorViewExt::element_view(self, index)
1350 }
1351
1352 /// Create a slice view of the tensor
1353 ///
1354 /// Returns a view of a contiguous or strided slice of the source tensor.
1355 ///
1356 /// # Arguments
1357 ///
1358 /// * `start` - Starting index
1359 /// * `step` - Step size (1 for contiguous)
1360 /// * `length` - Number of elements
1361 ///
1362 /// # Returns
1363 ///
1364 /// A tensor viewing the specified slice
1365 ///
1366 /// # Examples
1367 ///
1368 /// ```
1369 /// use train_station::Tensor;
1370 ///
1371 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1372 /// let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
1373 /// assert_eq!(slice.data(), &[2.0, 4.0]);
1374 /// ```
1375 #[track_caller]
1376 pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor {
1377 use crate::tensor::transform::view::TensorViewExt;
1378 TensorViewExt::slice_view(self, start, step, length)
1379 }
1380
1381 /// Create a tensor view from raw components
1382 ///
1383 /// Creates a tensor that views existing memory with the specified shape and device.
1384 /// The tensor shares memory with the original allocation through the allocation_owner.
1385 ///
1386 /// # Safety
1387 ///
1388 /// The caller must ensure:
1389 /// - `data` is valid for the number of elements specified by `shape`
1390 /// - `data` remains valid for the lifetime of the returned tensor
1391 /// - `allocation_owner` properly manages the memory lifecycle
1392 ///
1393 /// # Arguments
1394 ///
1395 /// * `data` - Raw pointer to the tensor data
1396 /// * `shape` - Shape of the tensor view
1397 /// * `device` - Device where the tensor is located
1398 /// * `allocation_owner` - Optional shared allocation owner
1399 ///
1400 /// # Returns
1401 ///
1402 /// A new tensor that views the specified memory
1403 ///
1404 /// # Implementation Details
1405 ///
1406 /// This method is used internally to create tensor views that share memory
1407 /// with other tensors. It's primarily used for view operations and memory
1408 /// management.
1409 pub(crate) fn from_raw_view(
1410 data: *const f32,
1411 shape: Shape,
1412 device: Device,
1413 allocation_owner: Option<Arc<Allocation>>,
1414 ) -> Self {
1415 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1416
1417 Self {
1418 data: NonNull::new(data as *mut f32).expect("Data pointer cannot be null"),
1419 shape,
1420 device,
1421 id,
1422 requires_grad: false,
1423 grad: None,
1424 grad_fn: GradFn::None,
1425 allocation_owner,
1426 _phantom: PhantomData,
1427 }
1428 }
1429
1430 /// Get the allocation owner for this tensor
1431 ///
1432 /// Returns the shared allocation owner if this tensor is a view,
1433 /// or None if this tensor owns its memory directly.
1434 ///
1435 /// # Returns
1436 ///
1437 /// Optional reference to the allocation owner
1438 ///
1439 /// # Implementation Details
1440 ///
1441 /// This method is used internally to manage memory lifecycle for tensor views.
1442 /// It helps determine whether a tensor shares memory with another tensor.
1443 #[track_caller]
1444 pub fn allocation_owner(&self) -> Option<&Arc<Allocation>> {
1445 self.allocation_owner.as_ref()
1446 }
1447
1448 /// Create a new tensor with uninitialized memory
1449 ///
1450 /// This method allocates memory for a tensor without initializing it to any value.
1451 /// This is useful for performance-critical operations where the memory will be
1452 /// immediately overwritten, such as matrix multiplication results.
1453 ///
1454 /// # Safety
1455 ///
1456 /// The caller must ensure that all memory is written before reading from the tensor.
1457 /// Reading from uninitialized memory is undefined behavior.
1458 ///
1459 /// # Arguments
1460 ///
1461 /// * `shape_dims` - The dimensions of the tensor
1462 ///
1463 /// # Returns
1464 ///
1465 /// A tensor with uninitialized memory
1466 ///
1467 /// # Performance
1468 ///
1469 /// - **Zero Initialization**: Skips memory initialization for maximum performance
1470 /// - **SIMD Ready**: Properly aligned for vectorized operations
1471 /// - **Memory Efficient**: Uses optimized alignment strategies
1472 ///
1473 /// # Example
1474 ///
1475 /// ```
1476 /// use train_station::Tensor;
1477 ///
1478 /// // Create uninitialized tensor for matmul result
1479 /// let mut result = Tensor::new_uninitialized(vec![100, 100]);
1480 /// // Initialize the memory before use
1481 /// for value in result.data_mut() {
1482 /// *value = 0.0;
1483 /// }
1484 /// ```
1485 #[inline]
1486 #[track_caller]
1487 pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self {
1488 let shape = Shape::new(shape_dims);
1489 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1490
1491 if shape.size == 0 {
1492 // Handle zero-sized tensors
1493 return Self {
1494 data: NonNull::dangling(),
1495 shape,
1496 device: current_device(),
1497 id,
1498 requires_grad: false,
1499 grad: None,
1500 grad_fn: GradFn::None,
1501 allocation_owner: None,
1502 _phantom: PhantomData,
1503 };
1504 }
1505
1506 // Optimized layout calculation for better cache performance
1507 let element_size = std::mem::size_of::<f32>();
1508 let total_size = shape.size * element_size;
1509
1510 // Use cache line alignment for large tensors, smaller alignment for small ones
1511 let alignment = if total_size > 4096 {
1512 64 // Cache line alignment for large tensors
1513 } else if shape.size >= 8 {
1514 32 // AVX2 alignment for medium tensors
1515 } else {
1516 16 // SSE alignment for small tensors
1517 };
1518
1519 let layout = Layout::from_size_align(total_size, alignment)
1520 .expect("Failed to create layout for tensor data");
1521
1522 // Allocate memory via shared Allocation (uninitialized)
1523 let alloc_obj = Allocation::new_uninitialized(shape.size, alignment, layout);
1524 let ptr = alloc_obj.ptr;
1525
1526 Self {
1527 data: ptr,
1528 shape,
1529 device: current_device(),
1530 id,
1531 requires_grad: false,
1532 grad: None,
1533 grad_fn: GradFn::None,
1534 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1535 _phantom: PhantomData,
1536 }
1537 }
1538}