train_station/tensor/core/utils.rs
1//! Tensor utility functions and core implementation methods
2//!
3//! This module provides essential utility functions for tensor creation, memory management,
4//! gradient tracking, and optimization. It contains the core implementation methods
5//! that enable efficient tensor operations and gradtrack functionality.
6//!
7//! # Key Features
8//!
9//! - **Tensor Creation**: Optimized constructors with memory alignment
10//! - **Memory Management**: Safe memory access and allocation utilities
11//! - **Gradient Tracking**: GradTrack system integration and gradient management
12//! - **Performance Optimization**: SIMD-ready memory layout and alignment
13//! - **Device Management**: CPU and future CUDA device support
14//! - **Memory Layout**: Contiguous, strided, and view memory access patterns
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Alignment**: 16-byte SSE, 32-byte AVX2, 64-byte cache-line alignment
19//! - **SIMD Optimization**: Properly aligned memory for vectorized operations
20//! - **Zero-Cost Abstractions**: Minimal overhead for utility operations
21//! - **Thread Safety**: Atomic operations for gradient tracking and ID generation
22//! - **Memory Efficiency**: Optimized allocation strategies for different tensor sizes
23//!
24//! # Examples
25//!
26//! ## Basic Tensor Creation
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors of different sizes
32//! let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
33//! let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
34//! let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
35//!
36//! // Initialize data before use
37//! let mut tensor = Tensor::new(vec![2, 3]);
38//! tensor.fill(0.0); // Initialize with zeros
39//! ```
40//!
41//! ## Gradient Tracking
42//!
43//! ```
44//! use train_station::Tensor;
45//!
46//! let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
47//! assert!(tensor.requires_grad());
48//! ```
49//!
50//! ## Memory Access
51//!
52//! ```
53//! use train_station::Tensor;
54//!
55//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
56//! let value = tensor.get(&[0, 1]);
57//! assert_eq!(value, 2.0);
58//!
59//! let mut tensor = Tensor::new(vec![2, 2]);
60//! tensor.set(&[0, 1], 42.0);
61//! assert_eq!(tensor.get(&[0, 1]), 42.0);
62//! ```
63
64use std::alloc::Layout;
65use std::marker::PhantomData;
66use std::ptr::NonNull;
67use std::sync::atomic::Ordering;
68use std::sync::Arc;
69
70use crate::device::current_device;
71use crate::gradtrack::{self, GradEngine, GradFn};
72use crate::tensor::core::{Allocation, Device, TENSOR_ID_COUNTER};
73use crate::tensor::Shape;
74
75use super::Tensor;
76
77impl Tensor {
    /// Creates a new tensor with the specified shape and optimized memory layout
    ///
    /// Allocates memory with size-dependent alignment for optimal performance:
    /// - Small tensors (fewer than 8 elements): 16-byte SSE alignment
    /// - Medium tensors (8 to 1024 elements): 32-byte AVX2 alignment
    /// - Large tensors (more than 1024 elements, i.e. > 4096 bytes): 64-byte cache-line alignment
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    ///
    /// # Returns
    ///
    /// A new tensor with uninitialized data. The data must be initialized
    /// before use to avoid undefined behavior.
    ///
    /// # Performance
    ///
    /// - **Memory Allocation**: Single allocation with optimized alignment
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
    ///
    /// # Safety
    ///
    /// The returned tensor contains uninitialized memory. You must initialize
    /// the data before performing any operations that read from it.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create tensors of different sizes
    /// let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment (6 elements)
    /// let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment (1024 elements)
    /// let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
    ///
    /// // Initialize data before use
    /// let mut tensor = Tensor::new(vec![2, 3]);
    /// tensor.fill(0.0); // Initialize with zeros
    /// ```
    #[inline]
    pub fn new(shape_dims: Vec<usize>) -> Self {
        let shape = Shape::new(shape_dims);
        // Every tensor gets a process-unique id for gradtrack bookkeeping.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size == 0 {
            // Handle zero-sized tensors: no allocation is made, so the data
            // pointer is a dangling (but well-aligned, non-null) sentinel and
            // there is no allocation owner.
            return Self {
                data: NonNull::dangling(),
                shape,
                device: current_device(),
                id,
                requires_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                _phantom: PhantomData,
            };
        }

        // Optimized layout calculation for better cache performance
        let element_size = std::mem::size_of::<f32>();
        let total_size = shape.size * element_size;

        // Use cache line alignment for large tensors, smaller alignment for small ones
        let alignment = if total_size > 4096 {
            64 // Cache line alignment for large tensors (> 1024 elements)
        } else if shape.size >= 8 {
            32 // AVX2 alignment for medium tensors (8..=1024 elements)
        } else {
            16 // SSE alignment for small tensors (< 8 elements)
        };

        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Allocate memory via shared Allocation; the Arc below lets views
        // share ownership of the same underlying buffer.
        let alloc_obj = Allocation::new(shape.size, alignment, layout);
        let ptr = alloc_obj.ptr;

        Self {
            data: ptr,
            shape,
            device: current_device(),
            id,
            requires_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            _phantom: PhantomData,
        }
    }
172
173 /// Returns the shape and dimensional information of the tensor
174 ///
175 /// Provides access to the tensor's dimensions, size, strides, and memory
176 /// layout information. This is used for shape validation, memory access
177 /// calculations, and optimization decisions.
178 ///
179 /// # Returns
180 ///
181 /// Reference to the tensor's shape information containing dimensions,
182 /// size, strides, and memory layout type.
183 ///
184 /// # Performance
185 ///
186 /// - **Time Complexity**: O(1) - direct field access
187 /// - **Memory**: No allocation - returns reference to existing data
188 ///
189 /// # Examples
190 ///
191 /// ```
192 /// use train_station::Tensor;
193 ///
194 /// let tensor = Tensor::new(vec![2, 3, 4]);
195 /// let shape = tensor.shape();
196 /// assert_eq!(shape.dims, vec![2, 3, 4]);
197 /// assert_eq!(shape.size, 24);
198 /// assert_eq!(shape.rank(), 3);
199 /// ```
200 #[inline]
201 pub fn shape(&self) -> &Shape {
202 &self.shape
203 }
204
205 /// Returns the total number of elements in the tensor
206 ///
207 /// Provides the total count of elements across all dimensions. This is
208 /// used for memory allocation, iteration bounds, and performance optimization.
209 ///
210 /// # Returns
211 ///
212 /// Total number of elements as `usize`
213 ///
214 /// # Performance
215 ///
216 /// - **Time Complexity**: O(1) - direct field access
217 /// - **Memory**: No allocation - returns stored value
218 ///
219 /// # Examples
220 ///
221 /// ```
222 /// use train_station::Tensor;
223 ///
224 /// let tensor = Tensor::new(vec![2, 3, 4]);
225 /// assert_eq!(tensor.size(), 24); // 2 * 3 * 4
226 ///
227 /// let scalar = Tensor::new(vec![1]);
228 /// assert_eq!(scalar.size(), 1);
229 ///
230 /// let empty = Tensor::new(vec![0]);
231 /// assert_eq!(empty.size(), 0);
232 /// ```
233 #[inline]
234 pub fn size(&self) -> usize {
235 self.shape.size
236 }
237
238 /// Returns the device where this tensor is located
239 ///
240 /// Provides the physical location of the tensor data (CPU/GPU). This
241 /// determines which operations can be performed on the tensor and where
242 /// computations will be executed.
243 ///
244 /// # Returns
245 ///
246 /// Device enum indicating the tensor's physical location
247 ///
248 /// # Performance
249 ///
250 /// - **Time Complexity**: O(1) - direct field access
251 /// - **Memory**: No allocation - returns stored value
252 ///
253 /// # Examples
254 ///
255 /// ```
256 /// use train_station::Tensor;
257 ///
258 /// let tensor = Tensor::new(vec![2, 3]);
259 /// assert!(tensor.device().is_cpu());
260 /// assert!(!tensor.device().is_cuda());
261 /// ```
262 #[inline]
263 pub fn device(&self) -> Device {
264 self.device
265 }
266
267 /// Creates a new tensor with the specified shape on a specific device
268 ///
269 /// Allocates memory on the specified device with the same optimized alignment
270 /// strategy as `new()`. Currently supports CPU device with future CUDA support.
271 ///
272 /// # Arguments
273 ///
274 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
275 /// * `device` - The device where the tensor should be allocated
276 ///
277 /// # Returns
278 ///
279 /// A new tensor with uninitialized data on the specified device
280 ///
281 /// # Performance
282 ///
283 /// - **Memory Allocation**: Device-specific allocation with optimized alignment
284 /// - **SIMD Ready**: Properly aligned for vectorized operations on target device
285 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
286 ///
287 /// # Panics
288 ///
289 /// Panics if the specified device is not supported (e.g., CUDA without feature flag)
290 ///
291 /// # Examples
292 ///
293 /// ```
294 /// use train_station::Tensor;
295 ///
296 /// let tensor = Tensor::new_on_device(vec![2, 3], train_station::Device::cpu());
297 /// assert!(tensor.device().is_cpu());
298 /// assert_eq!(tensor.size(), 6);
299 /// ```
300 ///
301 /// # Arguments
302 ///
303 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
304 /// * `device` - Device where the tensor should be allocated
305 ///
306 /// # Returns
307 ///
308 /// A new tensor with uninitialized data on the specified device
309 ///
310 /// # Panics
311 ///
312 /// Panics if the device is not supported (currently only CPU is supported)
313 ///
314 /// # Performance
315 ///
316 /// - **Memory Allocation**: Single allocation with optimized alignment
317 /// - **SIMD Ready**: Properly aligned for vectorized operations
318 /// - **Cache Friendly**: Optimized for CPU cache hierarchies
319 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
320 ///
321 /// # Examples
322 ///
323 /// ```
324 /// use train_station::{Tensor, Device};
325 ///
326 /// // Create tensor on CPU device
327 /// let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
328 /// assert_eq!(tensor.device(), Device::cpu());
329 /// assert_eq!(tensor.size(), 6);
330 /// ```
331 pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self {
332 // For now, only CPU is supported
333 if !device.is_cpu() {
334 panic!("Only CPU device is currently supported. CUDA support is planned for future releases.");
335 }
336
337 let shape = Shape::new(shape_dims);
338 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
339
340 if shape.size == 0 {
341 // Handle zero-sized tensors
342 return Self {
343 data: NonNull::dangling(),
344 shape,
345 device,
346 id,
347 requires_grad: false,
348 grad: None,
349 grad_fn: GradFn::None,
350 allocation_owner: None,
351 _phantom: PhantomData,
352 };
353 }
354
355 // Optimized layout calculation for better cache performance
356 let element_size = std::mem::size_of::<f32>();
357 let total_size = shape.size * element_size;
358
359 // Use cache line alignment for large tensors, smaller alignment for small ones
360 let alignment = if total_size > 4096 {
361 64 // Cache line alignment for large tensors
362 } else if shape.size >= 8 {
363 32 // AVX2 alignment for medium tensors
364 } else {
365 16 // SSE alignment for small tensors
366 };
367
368 let layout = Layout::from_size_align(total_size, alignment)
369 .expect("Failed to create layout for tensor data");
370
371 // Allocate memory via shared Allocation
372 let alloc_obj = Allocation::new(shape.size, alignment, layout);
373 let ptr = alloc_obj.ptr;
374
375 Self {
376 data: ptr,
377 shape,
378 device,
379 id,
380 requires_grad: false,
381 grad: None,
382 grad_fn: GradFn::None,
383 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
384 _phantom: PhantomData,
385 }
386 }
387
388 /// Enable gradient computation for this tensor
389 ///
390 /// Builder method that enables automatic gradient tracking for this tensor.
391 /// When enabled, all operations involving this tensor will be recorded in
392 /// the computation graph for gradient computation during backward pass.
393 ///
394 /// # Returns
395 ///
396 /// `self` with gradient tracking enabled
397 ///
398 /// # Performance
399 ///
400 /// - **Time Complexity**: O(1) - simple field assignment
401 /// - **Memory**: No additional allocation
402 /// - **Overhead**: Minimal gradtrack tracking overhead when gradients computed
403 ///
404 /// # Examples
405 ///
406 /// ```
407 /// use train_station::Tensor;
408 ///
409 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
410 /// assert!(tensor.requires_grad());
411 /// ```
412 pub fn with_requires_grad(mut self) -> Self {
413 self.requires_grad = true;
414 self
415 }
416
417 /// Set gradient tracking for this tensor
418 ///
419 /// Controls whether the gradtrack system tracks operations on this tensor
420 /// and computes gradients during backward pass. When disabled, clears
421 /// any existing gradients and gradient functions.
422 ///
423 /// # Arguments
424 ///
425 /// * `requires_grad` - Whether to track gradients for this tensor
426 ///
427 /// # Performance
428 ///
429 /// - **Time Complexity**: O(1) - simple field assignment
430 /// - **Memory**: May free gradient storage when disabled
431 /// - **Overhead**: Zero overhead when gradients disabled
432 ///
433 /// # Examples
434 ///
435 /// ```
436 /// use train_station::Tensor;
437 ///
438 /// let mut tensor = Tensor::ones(vec![2, 3]);
439 /// tensor.set_requires_grad(true);
440 /// assert!(tensor.requires_grad());
441 ///
442 /// // Disable gradient tracking
443 /// tensor.set_requires_grad(false);
444 /// assert!(!tensor.requires_grad());
445 /// ```
446 pub fn set_requires_grad(&mut self, requires_grad: bool) {
447 self.requires_grad = requires_grad;
448 if !requires_grad {
449 self.grad = None;
450 self.grad_fn = GradFn::None;
451 }
452 }
453
454 /// Check if this tensor requires gradients
455 ///
456 /// # Returns
457 ///
458 /// `true` if gradient tracking is enabled for this tensor
459 ///
460 /// # Examples
461 ///
462 /// ```
463 /// use train_station::Tensor;
464 ///
465 /// let tensor = Tensor::new(vec![2, 3]);
466 /// assert!(!tensor.requires_grad());
467 ///
468 /// let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
469 /// assert!(grad_tensor.requires_grad());
470 /// ```
471 pub fn requires_grad(&self) -> bool {
472 self.requires_grad
473 }
474
475 /// Get the accumulated gradients (if any)
476 ///
477 /// Returns a reference to the gradient tensor if gradients have been computed
478 /// and this tensor has gradient tracking enabled.
479 ///
480 /// # Returns
481 ///
482 /// Optional reference to the gradient tensor, or `None` if no gradients exist
483 ///
484 /// # Examples
485 ///
486 /// ```
487 /// use train_station::Tensor;
488 ///
489 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
490 /// assert!(tensor.grad().is_none()); // No gradients computed yet
491 /// ```
492 pub fn grad(&self) -> Option<&Tensor> {
493 // First check if we have a gradient stored directly
494 if let Some(grad) = self.grad.as_ref() {
495 return Some(grad.as_ref());
496 }
497
498 // If not, check the gradient map for accumulated gradients
499 if let Some(_grad) = gradtrack::get_accumulated_gradient(self.id) {
500 // For simplicity, we'll return None here since we can't return a reference
501 // to a temporary value. In a full implementation, we'd store it in self.grad
502 return None;
503 }
504
505 None
506 }
507
508 /// Get accumulated gradient by value (helper for testing)
509 ///
510 /// Returns the gradient tensor by value, which is useful for testing and
511 /// when you need to own the gradient data.
512 ///
513 /// # Returns
514 ///
515 /// Optional gradient tensor, or `None` if no gradients exist
516 ///
517 /// # Examples
518 ///
519 /// ```
520 /// use train_station::Tensor;
521 ///
522 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
523 /// assert!(tensor.grad_by_value().is_none()); // No gradients computed yet
524 /// ```
525 pub fn grad_by_value(&self) -> Option<Tensor> {
526 // First check if we have a gradient stored directly
527 if let Some(grad) = self.grad.as_ref() {
528 return Some((**grad).clone());
529 }
530
531 // If not, check the gradient map for accumulated gradients
532 use crate::gradtrack;
533 gradtrack::get_accumulated_gradient(self.id)
534 }
535
536 /// Get the unique ID of this tensor
537 ///
538 /// Returns the unique identifier assigned to this tensor during creation.
539 /// This ID is used for gradtrack tracking and tensor identification.
540 ///
541 /// # Returns
542 ///
543 /// Unique tensor ID as `usize`
544 ///
545 /// # Examples
546 ///
547 /// ```
548 /// use train_station::Tensor;
549 ///
550 /// let tensor1 = Tensor::new(vec![2, 3]);
551 /// let tensor2 = Tensor::new(vec![2, 3]);
552 /// assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
553 /// ```
554 pub fn id(&self) -> usize {
555 self.id
556 }
557
    /// Detach this tensor from the computation graph
    ///
    /// Returns a new tensor with the same data but no gradient tracking.
    /// This is useful when you want to use a tensor in inference without
    /// affecting the computation graph.
    ///
    /// # Returns
    ///
    /// A new tensor with the same data but gradient tracking disabled
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// let detached = tensor.detach();
    /// assert!(!detached.requires_grad());
    /// assert_eq!(tensor.size(), detached.size());
    /// ```
    pub fn detach(&self) -> Self {
        // Fresh allocation with the same dimensions; `Self::new` defaults
        // `requires_grad` to false, so the copy is untracked.
        // NOTE(review): the copy is allocated via `Self::new`, which uses
        // `current_device()` rather than `self.device` — confirm this is
        // intended once non-CPU devices exist.
        let mut detached = Self::new(self.shape.dims.clone());

        // Copy data
        // NOTE(review): this copies `size()` elements linearly from the data
        // pointer, which assumes `self` is contiguous. If `self` is a
        // non-contiguous view, the element order of the copy may not match
        // the logical order — verify against the view-creation code paths.
        unsafe {
            let src = self.as_ptr();
            let dst = detached.as_mut_ptr();
            std::ptr::copy_nonoverlapping(src, dst, self.size());
        }

        detached
    }
590
591 /// Create a new tensor that doesn't track gradients from this one
592 ///
593 /// Similar to detach() but modifies this tensor in place. This is useful
594 /// when you want to disable gradient tracking for the current tensor
595 /// without creating a copy.
596 ///
597 /// # Examples
598 ///
599 /// ```
600 /// use train_station::Tensor;
601 ///
602 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
603 /// assert!(tensor.requires_grad());
604 /// tensor.detach_();
605 /// assert!(!tensor.requires_grad());
606 /// ```
607 pub fn detach_(&mut self) {
608 self.requires_grad = false;
609 self.grad = None;
610 self.grad_fn = GradFn::None;
611 }
612
613 /// Entry point for backward pass on this tensor
614 ///
615 /// Computes gradients for all tensors in the computation graph that have
616 /// `requires_grad` set to true. This is the main entry point for automatic
617 /// differentiation.
618 ///
619 /// # Arguments
620 ///
621 /// * `grad_output` - Optional gradient tensor for the output. If None, assumes
622 /// the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
623 ///
624 /// # Examples
625 ///
626 /// ```
627 /// use train_station::Tensor;
628 ///
629 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
630 /// let mut result = tensor.add_scalar(5.0);
631 /// result.backward(None);
632 /// // Note: Gradient computation depends on the gradtrack system implementation
633 /// ```
634 pub fn backward(&mut self, grad_output: Option<Tensor>) {
635 GradEngine::backward(self, grad_output);
636 }
637
638 /// Returns a raw pointer to the tensor data for unsafe operations
639 ///
640 /// # Safety
641 ///
642 /// This is unsafe because it provides direct access to the underlying memory.
643 /// The caller must ensure:
644 /// - The tensor is not dropped while the pointer is used
645 /// - No concurrent mutable access occurs
646 /// - Bounds are respected
647 #[inline]
648 pub unsafe fn as_ptr(&self) -> *const f32 {
649 self.data.as_ptr()
650 }
651
652 /// Returns a mutable raw pointer to the tensor data for unsafe operations
653 ///
654 /// # Safety
655 ///
656 /// This is unsafe because it provides direct mutable access to the underlying memory.
657 /// The caller must ensure:
658 /// - The tensor is not dropped while the pointer is used
659 /// - No concurrent access occurs
660 /// - Bounds are respected
661 #[inline]
662 pub unsafe fn as_mut_ptr(&mut self) -> *mut f32 {
663 self.data.as_ptr()
664 }
665
666 /// Internal method to set gradient function (used by operations)
667 ///
668 /// Sets the gradient function for this tensor. This is used internally
669 /// by tensor operations to record the computation graph for gradtrack.
670 ///
671 /// # Arguments
672 ///
673 /// * `grad_fn` - The gradient function to set
674 ///
675 /// # Implementation Details
676 ///
677 /// This method is called by tensor operations to register the gradient
678 /// computation function. It only sets the gradient function if gradient
679 /// tracking is enabled for this tensor.
680 pub(crate) fn set_grad_fn(&mut self, grad_fn: GradFn) {
681 if self.requires_grad {
682 self.grad_fn = grad_fn;
683 }
684 }
685
686 /// Get a reference to the gradient function (for gradtrack)
687 ///
688 /// Returns a reference to the gradient function associated with this tensor.
689 /// This is used internally by the gradtrack system to compute gradients.
690 ///
691 /// # Returns
692 ///
693 /// Reference to the gradient function
694 ///
695 /// # Implementation Details
696 ///
697 /// This method is used by the gradtrack engine to access the gradient
698 /// computation function during backward pass.
699 pub fn grad_fn(&self) -> &GradFn {
700 &self.grad_fn
701 }
702
703 /// Internal method to set requires_grad (used by gradtrack operations)
704 ///
705 /// Sets the gradient tracking flag for this tensor. This is used internally
706 /// by gradtrack operations to control gradient computation.
707 ///
708 /// # Arguments
709 ///
710 /// * `requires_grad` - Whether to enable gradient tracking
711 ///
712 /// # Implementation Details
713 ///
714 /// This method is used internally by the gradtrack system to control
715 /// gradient tracking without triggering additional side effects.
716 pub(crate) fn set_requires_grad_internal(&mut self, requires_grad: bool) {
717 self.requires_grad = requires_grad;
718 }
719
720 /// Internal method to accumulate gradients with optimized operations
721 ///
722 /// Accumulates a gradient tensor with any existing gradients for this tensor.
723 /// This is used internally by the gradtrack system to handle gradient accumulation
724 /// during backward pass.
725 ///
726 /// # Arguments
727 ///
728 /// * `grad` - The gradient tensor to accumulate
729 ///
730 /// # Implementation Details
731 ///
732 /// This method is used internally by the gradtrack engine to accumulate
733 /// gradients from multiple backward passes or operations. It only accumulates
734 /// gradients if gradient tracking is enabled for this tensor.
735 pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
736 if !self.requires_grad {
737 return;
738 }
739
740 match &self.grad {
741 Some(existing_grad) => {
742 // Use optimized tensor addition but create new tensor for safety
743 let accumulated = existing_grad.add_tensor_optimized(&grad);
744 self.grad = Some(Arc::new(accumulated));
745 }
746 None => {
747 self.grad = Some(Arc::new(grad));
748 }
749 }
750 }
751
752 /// Set gradient from external source
753 ///
754 /// Sets the gradient tensor for this tensor. This is used internally by the
755 /// gradtrack system to set gradients during backward pass.
756 ///
757 /// # Arguments
758 ///
759 /// * `grad` - The gradient tensor to set
760 ///
761 /// # Implementation Details
762 ///
763 /// This method is used internally by the gradtrack engine to set gradients
764 /// during backward pass. It only sets the gradient if gradient tracking is
765 /// enabled for this tensor.
766 pub fn set_grad(&mut self, grad: Tensor) {
767 if self.requires_grad {
768 self.grad = Some(Arc::new(grad));
769 }
770 }
771
772 /// Clear accumulated gradients for this tensor
773 ///
774 /// This method is used by optimizers to zero gradients before each backward pass.
775 /// It clears any accumulated gradients, allowing for fresh gradient computation.
776 ///
777 /// # Examples
778 ///
779 /// ```
780 /// use train_station::Tensor;
781 ///
782 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
783 /// tensor.set_grad(Tensor::ones(vec![2, 3]));
784 /// assert!(tensor.grad().is_some());
785 /// tensor.zero_grad();
786 /// assert!(tensor.grad().is_none());
787 /// ```
788 pub fn zero_grad(&mut self) {
789 self.grad = None;
790 }
791
    /// Negate all elements in the tensor in-place
    ///
    /// This is used internally for gradient computation in subtraction operations.
    /// For tensor - tensor operations, the second operand gets a negated gradient.
    ///
    /// # Implementation Details
    ///
    /// Dispatches to an AVX2 SIMD path on x86_64 when the CPU supports it
    /// (runtime-detected), otherwise falls back to a scalar loop. Zero-sized
    /// tensors return immediately (their data pointer is dangling).
    ///
    /// NOTE(review): both paths negate `shape.size` consecutive elements
    /// starting at the data pointer, which assumes the tensor is contiguous;
    /// confirm callers never invoke this on a strided view.
    #[inline]
    pub(crate) fn negate_inplace(&mut self) {
        // Zero-size: nothing to do, and `data` is a dangling sentinel.
        if self.shape.size == 0 {
            return;
        }

        unsafe {
            let ptr = self.data.as_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.negate_simd_avx2(ptr);
                    return;
                }
            }

            // Fallback to scalar operations
            for i in 0..self.shape.size {
                *ptr.add(i) = -*ptr.add(i);
            }
        }
    }
826
    /// SIMD-optimized negation using AVX2 instructions
    ///
    /// Performs in-place negation by subtracting each 8-lane vector from a
    /// zero vector (0.0 - x == -x for these values), then handles the
    /// remaining `size % 8` elements with a scalar loop.
    ///
    /// # Arguments
    ///
    /// * `ptr` - Raw pointer to the tensor data
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// - `ptr` is a valid pointer to at least `self.shape.size` elements
    /// - `ptr` is 32-byte aligned: `_mm256_load_ps`/`_mm256_store_ps` are
    ///   aligned accesses. Tensors with >= 8 elements are allocated with
    ///   32- or 64-byte alignment, and smaller tensors never enter the SIMD
    ///   loop (`simd_count == 0`), so base allocations are safe.
    ///   NOTE(review): an offset view could break this invariant — confirm
    ///   views are never passed here.
    /// - AVX2 is available (checked by the caller via feature detection)
    /// - The tensor is not moved or dropped during this operation
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn negate_simd_avx2(&self, ptr: *mut f32) {
        use std::arch::x86_64::_mm256_setzero_ps;

        let size = self.shape.size;
        let zero_vec = _mm256_setzero_ps();
        let simd_count = size / 8; // Process 8 elements per iteration
        let mut offset = 0;

        // SIMD loop for negation
        for _ in 0..simd_count {
            use std::arch::x86_64::{_mm256_load_ps, _mm256_store_ps, _mm256_sub_ps};

            let vec = _mm256_load_ps(ptr.add(offset));
            let neg_vec = _mm256_sub_ps(zero_vec, vec);
            _mm256_store_ps(ptr.add(offset), neg_vec);
            offset += 8;
        }

        // Handle remaining elements (scalar tail, fewer than 8)
        for i in offset..size {
            *ptr.add(i) = -*ptr.add(i);
        }
    }
872
873 // ===== Memory Layout and Optimization API =====
874
875 /// Checks if the tensor data is stored contiguously in memory
876 ///
877 /// # Returns
878 ///
879 /// `true` if the tensor data is contiguous, enabling optimized SIMD operations
880 ///
881 /// # Examples
882 ///
883 /// ```
884 /// use train_station::Tensor;
885 ///
886 /// let tensor = Tensor::new(vec![2, 3, 4]);
887 /// assert!(tensor.is_contiguous());
888 /// ```
889 #[inline]
890 pub fn is_contiguous(&self) -> bool {
891 self.shape.is_contiguous()
892 }
893
894 /// Checks if this tensor is a view of another tensor
895 ///
896 /// # Returns
897 ///
898 /// `true` if this tensor is a view (non-contiguous reference)
899 #[inline]
900 pub fn is_view(&self) -> bool {
901 self.shape.is_view()
902 }
903
904 /// Gets the memory strides for all dimensions
905 ///
906 /// # Returns
907 ///
908 /// Reference to the stride vector for efficient memory access calculations
909 ///
910 /// # Examples
911 ///
912 /// ```
913 /// use train_station::Tensor;
914 ///
915 /// let tensor = Tensor::new(vec![2, 3, 4]);
916 /// assert_eq!(tensor.strides(), &[12, 4, 1]);
917 /// ```
918 #[inline]
919 pub fn strides(&self) -> &[usize] {
920 self.shape.strides()
921 }
922
923 /// Gets the memory stride for a specific dimension
924 ///
925 /// # Arguments
926 ///
927 /// * `dim` - The dimension index
928 ///
929 /// # Returns
930 ///
931 /// The memory stride for the given dimension
932 ///
933 /// # Panics
934 ///
935 /// Panics if `dim` is out of bounds
936 #[inline]
937 pub fn stride(&self, dim: usize) -> usize {
938 self.shape.stride(dim)
939 }
940
941 /// Gets the memory layout type for optimization decisions
942 ///
943 /// # Returns
944 ///
945 /// Reference to the memory layout information
946 #[inline]
947 pub fn layout(&self) -> &crate::tensor::MemoryLayout {
948 self.shape.layout()
949 }
950
951 /// Calculates the linear memory offset for given multi-dimensional indices
952 ///
953 /// # Arguments
954 ///
955 /// * `indices` - Vector of indices for each dimension
956 ///
957 /// # Returns
958 ///
959 /// Linear memory offset for direct memory access
960 ///
961 /// # Examples
962 ///
963 /// ```
964 /// use train_station::Tensor;
965 ///
966 /// let tensor = Tensor::new(vec![2, 3, 4]);
967 /// let offset = tensor.memory_offset(&[1, 2, 3]);
968 /// // offset = 1*12 + 2*4 + 3*1 = 23
969 /// ```
970 #[inline]
971 pub fn memory_offset(&self, indices: &[usize]) -> usize {
972 self.shape.offset(indices)
973 }
974
975 /// Broadcast this tensor with another tensor for element-wise operations
976 ///
977 /// Returns a tuple containing:
978 /// - Broadcasted view of self
979 /// - Broadcasted view of other
980 /// - Result shape for the operation
981 ///
982 /// # Arguments
983 ///
984 /// * `other` - The tensor to broadcast with
985 ///
986 /// # Returns
987 ///
988 /// A tuple `(broadcasted_self, broadcasted_other, result_shape)`
989 ///
990 /// # Examples
991 ///
992 /// ```
993 /// use train_station::Tensor;
994 ///
995 /// let a = Tensor::ones(vec![2, 1, 4]);
996 /// let b = Tensor::ones(vec![3, 1]);
997 /// let result = a.broadcast_with(&b);
998 /// assert!(result.is_ok());
999 /// ```
1000 pub fn broadcast_with(
1001 &self,
1002 other: &Tensor,
1003 ) -> Result<
1004 (Tensor, Tensor, crate::tensor::Shape),
1005 crate::tensor::ops::broadcasting::BroadcastError,
1006 > {
1007 crate::tensor::ops::broadcasting::broadcast_shapes(self, other)
1008 }
1009
1010 /// Checks if the tensor data is properly aligned for SIMD operations
1011 ///
1012 /// # Returns
1013 ///
1014 /// `true` if the tensor data is aligned to 32-byte boundaries for AVX2
1015 #[inline]
1016 pub fn is_simd_aligned(&self) -> bool {
1017 (self.data.as_ptr() as usize) % 32 == 0
1018 }
1019
1020 /// Gets the memory alignment of the tensor data
1021 ///
1022 /// # Returns
1023 ///
1024 /// The memory alignment in bytes (typically 32 for SIMD optimization)
1025 #[inline]
1026 pub fn memory_alignment(&self) -> usize {
1027 // Our tensors are allocated with 32-byte alignment for AVX2
1028 32
1029 }
1030
1031 /// Checks if this tensor is broadcastable with another tensor
1032 ///
1033 /// # Arguments
1034 ///
1035 /// * `other` - The other tensor to check broadcasting compatibility
1036 ///
1037 /// # Returns
1038 ///
1039 /// `true` if the tensors are broadcastable according to NumPy broadcasting rules
1040 ///
1041 /// # Examples
1042 ///
1043 /// ```
1044 /// use train_station::Tensor;
1045 ///
1046 /// let a = Tensor::new(vec![2, 3, 4]);
1047 /// let b = Tensor::new(vec![1, 3, 4]);
1048 /// assert!(a.is_broadcastable_with(&b));
1049 /// ```
1050 #[inline]
1051 pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
1052 self.shape.is_broadcastable_with(&other.shape)
1053 }
1054
1055 /// Gets the total number of bytes allocated for this tensor
1056 ///
1057 /// # Returns
1058 ///
1059 /// Total memory footprint in bytes
1060 #[inline]
1061 pub fn memory_footprint(&self) -> usize {
1062 self.shape.size * std::mem::size_of::<f32>()
1063 }
1064
1065 /// Get a single element from the tensor at the specified indices
1066 ///
1067 /// # Arguments
1068 ///
1069 /// * `indices` - Multi-dimensional indices to access the element
1070 ///
1071 /// # Returns
1072 ///
1073 /// The value at the specified position
1074 ///
1075 /// # Panics
1076 ///
1077 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1078 ///
1079 /// # Examples
1080 ///
1081 /// ```
1082 /// use train_station::Tensor;
1083 ///
1084 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1085 /// let value = tensor.get(&[0, 1]);
1086 /// assert_eq!(value, 2.0);
1087 /// ```
1088 pub fn get(&self, indices: &[usize]) -> f32 {
1089 assert_eq!(
1090 indices.len(),
1091 self.shape().rank(),
1092 "Indices length must match tensor rank"
1093 );
1094
1095 // Check bounds
1096 for (i, &idx) in indices.iter().enumerate() {
1097 assert!(
1098 idx < self.shape().dims[i],
1099 "Index {} out of bounds for dimension {}",
1100 idx,
1101 i
1102 );
1103 }
1104
1105 let offset = self.memory_offset(indices);
1106 unsafe { *self.as_ptr().add(offset) }
1107 }
1108
1109 /// Set a single element in the tensor at the specified indices
1110 ///
1111 /// # Arguments
1112 ///
1113 /// * `indices` - Multi-dimensional indices to set the element
1114 /// * `value` - The value to set
1115 ///
1116 /// # Panics
1117 ///
1118 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1119 ///
1120 /// # Examples
1121 ///
1122 /// ```
1123 /// use train_station::Tensor;
1124 ///
1125 /// let mut tensor = Tensor::new(vec![2, 2]);
1126 /// tensor.set(&[0, 1], 42.0);
1127 /// assert_eq!(tensor.get(&[0, 1]), 42.0);
1128 /// ```
1129 pub fn set(&mut self, indices: &[usize], value: f32) {
1130 assert_eq!(
1131 indices.len(),
1132 self.shape().rank(),
1133 "Indices length must match tensor rank"
1134 );
1135
1136 // Check bounds
1137 for (i, &idx) in indices.iter().enumerate() {
1138 assert!(
1139 idx < self.shape().dims[i],
1140 "Index {} out of bounds for dimension {}",
1141 idx,
1142 i
1143 );
1144 }
1145
1146 let offset = self.memory_offset(indices);
1147 unsafe {
1148 *self.as_mut_ptr().add(offset) = value;
1149 }
1150 }
1151
1152 /// Returns a safe slice of the tensor's underlying data
1153 ///
1154 /// Provides safe access to the tensor's data without requiring unsafe pointer operations.
1155 /// This is the preferred way to access tensor data for reading values, comparisons,
1156 /// and other operations that don't require direct pointer manipulation.
1157 ///
1158 /// # Returns
1159 ///
1160 /// A slice containing all tensor elements in row-major order
1161 ///
1162 /// # Performance
1163 ///
1164 /// - **Zero-Cost**: Direct slice creation with no copying
1165 /// - **Cache-Friendly**: Sequential memory access pattern
1166 /// - **Safe**: No unsafe code required for basic data access
1167 ///
1168 /// # Examples
1169 ///
1170 /// ```
1171 /// use train_station::Tensor;
1172 ///
1173 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1174 /// let data = tensor.data();
1175 ///
1176 /// // Safe indexing and comparisons
1177 /// assert_eq!(data[0], 1.0);
1178 /// assert_eq!(data.len(), tensor.size());
1179 /// ```
1180 #[inline]
1181 pub fn data(&self) -> &[f32] {
1182 if self.size() == 0 {
1183 return &[];
1184 }
1185 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }
1186 }
1187
1188 /// Returns a mutable slice of the tensor's underlying data
1189 ///
1190 /// Provides safe mutable access to the tensor's data without requiring unsafe
1191 /// pointer operations. Use this for in-place modifications of tensor values.
1192 ///
1193 /// # Returns
1194 ///
1195 /// A mutable slice containing all tensor elements in row-major order
1196 ///
1197 /// # Performance
1198 ///
1199 /// - **Zero-Cost**: Direct slice creation with no copying
1200 /// - **Cache-Friendly**: Sequential memory access pattern
1201 /// - **Safe**: No unsafe code required for basic data modification
1202 ///
1203 /// # Examples
1204 ///
1205 /// ```
1206 /// use train_station::Tensor;
1207 ///
1208 /// let mut tensor = Tensor::new(vec![2, 2]);
1209 /// let data = tensor.data_mut();
1210 ///
1211 /// // Safe indexing for modification
1212 /// data[0] = 1.0;
1213 /// data[1] = 2.0;
1214 ///
1215 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
1216 /// assert_eq!(tensor.get(&[0, 1]), 2.0);
1217 /// ```
1218 #[inline]
1219 pub fn data_mut(&mut self) -> &mut [f32] {
1220 if self.size() == 0 {
1221 return &mut [];
1222 }
1223 unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
1224 }
1225
1226 /// Extract scalar value from single-element tensor
1227 ///
1228 /// This method provides a convenient way to extract the scalar value from
1229 /// tensors that contain exactly one element. This is commonly used with
1230 /// element iterator results and scalar tensor operations.
1231 ///
1232 /// # Returns
1233 ///
1234 /// The scalar value contained in this tensor
1235 ///
1236 /// # Panics
1237 ///
1238 /// Panics if the tensor does not contain exactly one element
1239 ///
1240 /// # Examples
1241 ///
1242 /// ```
1243 /// use train_station::Tensor;
1244 ///
1245 /// // Single-element tensor
1246 /// let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
1247 /// assert_eq!(scalar.value(), 42.0);
1248 /// ```
1249 #[inline]
1250 #[track_caller]
1251 pub fn value(&self) -> f32 {
1252 assert_eq!(
1253 self.size(),
1254 1,
1255 "value() can only be called on tensors with exactly one element. \
1256 This tensor has {} elements with shape {:?}",
1257 self.size(),
1258 self.shape().dims
1259 );
1260 self.data()[0]
1261 }
1262
1263 /// Create a view with a new shape (requires contiguous memory)
1264 ///
1265 /// Behaves like PyTorch `view`: tensor must be contiguous and the total
1266 /// number of elements must remain the same. Supports -1 inference for one dimension.
1267 ///
1268 /// # Arguments
1269 ///
1270 /// * `new_shape` - New shape for the tensor (can contain -1 for inference)
1271 ///
1272 /// # Returns
1273 ///
1274 /// A tensor viewing the same data with a new shape
1275 ///
1276 /// # Examples
1277 ///
1278 /// ```
1279 /// use train_station::Tensor;
1280 ///
1281 /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1282 /// let y = x.view(vec![2, 2]);
1283 /// assert_eq!(y.shape().dims, vec![2, 2]);
1284 /// ```
1285 pub fn view(&self, new_shape: Vec<i32>) -> Tensor {
1286 // Use the views module implementation
1287 use crate::tensor::transform::view::TensorViewExt;
1288 TensorViewExt::view(self, new_shape)
1289 }
1290
1291 /// Create an element view for the specified index
1292 ///
1293 /// Returns a scalar tensor (shape \[1\]) that views a single element
1294 /// of the source tensor. Maintains gradient tracking.
1295 ///
1296 /// # Arguments
1297 ///
1298 /// * `index` - Linear index of the element to view
1299 ///
1300 /// # Returns
1301 ///
1302 /// A scalar tensor viewing the specified element
1303 ///
1304 /// # Examples
1305 ///
1306 /// ```
1307 /// use train_station::Tensor;
1308 ///
1309 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1310 /// let element = tensor.element_view(1);
1311 /// assert_eq!(element.value(), 2.0);
1312 /// ```
1313 pub fn element_view(&self, index: usize) -> Tensor {
1314 use crate::tensor::transform::view::TensorViewExt;
1315 TensorViewExt::element_view(self, index)
1316 }
1317
1318 /// Create a slice view of the tensor
1319 ///
1320 /// Returns a view of a contiguous or strided slice of the source tensor.
1321 ///
1322 /// # Arguments
1323 ///
1324 /// * `start` - Starting index
1325 /// * `step` - Step size (1 for contiguous)
1326 /// * `length` - Number of elements
1327 ///
1328 /// # Returns
1329 ///
1330 /// A tensor viewing the specified slice
1331 ///
1332 /// # Examples
1333 ///
1334 /// ```
1335 /// use train_station::Tensor;
1336 ///
1337 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1338 /// let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
1339 /// assert_eq!(slice.data(), &[2.0, 4.0]);
1340 /// ```
1341 pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor {
1342 use crate::tensor::transform::view::TensorViewExt;
1343 TensorViewExt::slice_view(self, start, step, length)
1344 }
1345
1346 /// Create a tensor view from raw components
1347 ///
1348 /// Creates a tensor that views existing memory with the specified shape and device.
1349 /// The tensor shares memory with the original allocation through the allocation_owner.
1350 ///
1351 /// # Safety
1352 ///
1353 /// The caller must ensure:
1354 /// - `data` is valid for the number of elements specified by `shape`
1355 /// - `data` remains valid for the lifetime of the returned tensor
1356 /// - `allocation_owner` properly manages the memory lifecycle
1357 ///
1358 /// # Arguments
1359 ///
1360 /// * `data` - Raw pointer to the tensor data
1361 /// * `shape` - Shape of the tensor view
1362 /// * `device` - Device where the tensor is located
1363 /// * `allocation_owner` - Optional shared allocation owner
1364 ///
1365 /// # Returns
1366 ///
1367 /// A new tensor that views the specified memory
1368 ///
1369 /// # Implementation Details
1370 ///
1371 /// This method is used internally to create tensor views that share memory
1372 /// with other tensors. It's primarily used for view operations and memory
1373 /// management.
1374 pub(crate) fn from_raw_view(
1375 data: *const f32,
1376 shape: Shape,
1377 device: Device,
1378 allocation_owner: Option<Arc<Allocation>>,
1379 ) -> Self {
1380 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1381
1382 Self {
1383 data: NonNull::new(data as *mut f32).expect("Data pointer cannot be null"),
1384 shape,
1385 device,
1386 id,
1387 requires_grad: false,
1388 grad: None,
1389 grad_fn: GradFn::None,
1390 allocation_owner,
1391 _phantom: PhantomData,
1392 }
1393 }
1394
1395 /// Get the allocation owner for this tensor
1396 ///
1397 /// Returns the shared allocation owner if this tensor is a view,
1398 /// or None if this tensor owns its memory directly.
1399 ///
1400 /// # Returns
1401 ///
1402 /// Optional reference to the allocation owner
1403 ///
1404 /// # Implementation Details
1405 ///
1406 /// This method is used internally to manage memory lifecycle for tensor views.
1407 /// It helps determine whether a tensor shares memory with another tensor.
1408 pub fn allocation_owner(&self) -> Option<&Arc<Allocation>> {
1409 self.allocation_owner.as_ref()
1410 }
1411
1412 /// Create a new tensor with uninitialized memory
1413 ///
1414 /// This method allocates memory for a tensor without initializing it to any value.
1415 /// This is useful for performance-critical operations where the memory will be
1416 /// immediately overwritten, such as matrix multiplication results.
1417 ///
1418 /// # Safety
1419 ///
1420 /// The caller must ensure that all memory is written before reading from the tensor.
1421 /// Reading from uninitialized memory is undefined behavior.
1422 ///
1423 /// # Arguments
1424 ///
1425 /// * `shape_dims` - The dimensions of the tensor
1426 ///
1427 /// # Returns
1428 ///
1429 /// A tensor with uninitialized memory
1430 ///
1431 /// # Performance
1432 ///
1433 /// - **Zero Initialization**: Skips memory initialization for maximum performance
1434 /// - **SIMD Ready**: Properly aligned for vectorized operations
1435 /// - **Memory Efficient**: Uses optimized alignment strategies
1436 ///
1437 /// # Example
1438 ///
1439 /// ```
1440 /// use train_station::Tensor;
1441 ///
1442 /// // Create uninitialized tensor for matmul result
1443 /// let mut result = Tensor::new_uninitialized(vec![100, 100]);
1444 /// // Initialize the memory before use
1445 /// for value in result.data_mut() {
1446 /// *value = 0.0;
1447 /// }
1448 /// ```
1449 #[inline]
1450 pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self {
1451 let shape = Shape::new(shape_dims);
1452 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1453
1454 if shape.size == 0 {
1455 // Handle zero-sized tensors
1456 return Self {
1457 data: NonNull::dangling(),
1458 shape,
1459 device: current_device(),
1460 id,
1461 requires_grad: false,
1462 grad: None,
1463 grad_fn: GradFn::None,
1464 allocation_owner: None,
1465 _phantom: PhantomData,
1466 };
1467 }
1468
1469 // Optimized layout calculation for better cache performance
1470 let element_size = std::mem::size_of::<f32>();
1471 let total_size = shape.size * element_size;
1472
1473 // Use cache line alignment for large tensors, smaller alignment for small ones
1474 let alignment = if total_size > 4096 {
1475 64 // Cache line alignment for large tensors
1476 } else if shape.size >= 8 {
1477 32 // AVX2 alignment for medium tensors
1478 } else {
1479 16 // SSE alignment for small tensors
1480 };
1481
1482 let layout = Layout::from_size_align(total_size, alignment)
1483 .expect("Failed to create layout for tensor data");
1484
1485 // Allocate memory via shared Allocation (uninitialized)
1486 let alloc_obj = Allocation::new_uninitialized(shape.size, alignment, layout);
1487 let ptr = alloc_obj.ptr;
1488
1489 Self {
1490 data: ptr,
1491 shape,
1492 device: current_device(),
1493 id,
1494 requires_grad: false,
1495 grad: None,
1496 grad_fn: GradFn::None,
1497 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1498 _phantom: PhantomData,
1499 }
1500 }
1501}