train_station/tensor/core/utils.rs
1//! Tensor utility functions and core implementation methods
2//!
3//! This module provides essential utility functions for tensor creation, memory management,
4//! gradient tracking, and optimization. It contains the core implementation methods
5//! that enable efficient tensor operations and gradtrack functionality.
6//!
7//! # Key Features
8//!
9//! - **Tensor Creation**: Optimized constructors with memory alignment
10//! - **Memory Management**: Safe memory access and allocation utilities
11//! - **Gradient Tracking**: GradTrack system integration and gradient management
12//! - **Performance Optimization**: SIMD-ready memory layout and alignment
13//! - **Device Management**: CPU and future CUDA device support
14//! - **Memory Layout**: Contiguous, strided, and view memory access patterns
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Alignment**: 16-byte SSE, 32-byte AVX2, 64-byte cache-line alignment
19//! - **SIMD Optimization**: Properly aligned memory for vectorized operations
20//! - **Zero-Cost Abstractions**: Minimal overhead for utility operations
21//! - **Thread Safety**: Atomic operations for gradient tracking and ID generation
22//! - **Memory Efficiency**: Optimized allocation strategies for different tensor sizes
23//!
24//! # Examples
25//!
26//! ## Basic Tensor Creation
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors of different sizes
32//! let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
33//! let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
34//! let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
35//!
36//! // Initialize data before use
37//! let mut tensor = Tensor::new(vec![2, 3]);
38//! tensor.fill(0.0); // Initialize with zeros
39//! ```
40//!
41//! ## Gradient Tracking
42//!
43//! ```
44//! use train_station::Tensor;
45//!
46//! let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
47//! assert!(tensor.requires_grad());
48//! ```
49//!
50//! ## Memory Access
51//!
52//! ```
53//! use train_station::Tensor;
54//!
55//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
56//! let value = tensor.get(&[0, 1]);
57//! assert_eq!(value, 2.0);
58//!
59//! let mut tensor = Tensor::new(vec![2, 2]);
60//! tensor.set(&[0, 1], 42.0);
61//! assert_eq!(tensor.get(&[0, 1]), 42.0);
62//! ```
63
64use std::alloc::Layout;
65use std::marker::PhantomData;
66use std::ptr::NonNull;
67use std::sync::atomic::Ordering;
68use std::sync::Arc;
69
70use crate::device::current_device;
71use crate::gradtrack::engine::ensure_local_group_for_tensor;
72use crate::gradtrack::{self, GradEngine, GradFn};
73#[cfg(target_arch = "x86_64")]
74#[cfg(test)]
75use crate::tensor::core::memory::SimdLevel;
76use crate::tensor::core::memory::{compute_allocation_params, use_pool_alloc_enabled};
77use crate::tensor::core::{Allocation, Device, TENSOR_ID_COUNTER};
78use crate::tensor::Shape;
79
80use super::Tensor;
81
82impl Tensor {
    /// Ensures that this tensor has unique ownership of its allocation before mutation.
    ///
    /// If the underlying allocation is shared (i.e., multiple views), this performs
    /// a copy-on-write: allocates a new buffer and copies data so that mutating this tensor
    /// does not affect existing views.
    ///
    /// Zero-sized tensors are a no-op since they own no allocation.
    ///
    /// NOTE(review): the copy below transfers `self.size()` contiguous elements
    /// starting at `self.as_ptr()` — this assumes the tensor's data is contiguous
    /// from its base pointer. Confirm behavior for strided/offset views.
    ///
    /// NOTE(review): this path uses `Allocation::new_uninitialized` for the
    /// non-pooled case while `new()` uses `Allocation::new` — confirm the
    /// difference is intentional (here every element is overwritten by the copy,
    /// so uninitialized memory is safe as long as `padded_elems >= size()`).
    fn ensure_unique_allocation(&mut self) {
        if self.size() == 0 {
            return;
        }
        if let Some(owner) = self.allocation_owner.as_ref() {
            // strong_count > 1 means other tensors/views share this buffer.
            if std::sync::Arc::strong_count(owner) > 1 {
                // Allocate a new buffer with the same size/alignment policy as constructors
                let (alignment, padded_elems) = compute_allocation_params(self.shape.size());
                let total_size = padded_elems * std::mem::size_of::<f32>();
                let layout = Layout::from_size_align(total_size, alignment)
                    .expect("Failed to create layout for tensor data");

                let alloc_obj = if use_pool_alloc_enabled() {
                    Allocation::new_pooled(padded_elems, alignment, layout)
                } else {
                    Allocation::new_uninitialized(padded_elems, alignment, layout)
                };

                let new_ptr = alloc_obj.ptr;
                unsafe {
                    std::ptr::copy_nonoverlapping(self.as_ptr(), new_ptr.as_ptr(), self.size());
                }

                // Replace pointer and owner with the new unique allocation
                self.data = new_ptr;
                self.allocation_owner = Some(std::sync::Arc::new(alloc_obj));
            }
        }
    }
117 // ===== Runtime SIMD capability helpers for ops selection =====
118
119 /// Returns the highest runtime-detected SIMD level on this CPU
120 #[inline]
121 #[cfg(test)]
122 #[cfg(target_arch = "x86_64")]
123 pub(crate) fn simd_runtime_level() -> SimdLevel {
124 use crate::tensor::core::memory::detect_runtime_simd;
125
126 detect_runtime_simd()
127 }
128
129 /// Returns the SIMD lane width (elements per vector) for f32 at runtime
130 #[inline]
131 #[cfg(test)]
132 #[cfg(target_arch = "x86_64")]
133 pub(crate) fn simd_lane_width_elems_runtime() -> usize {
134 use crate::tensor::core::memory::simd_lane_width_elems;
135
136 simd_lane_width_elems(Self::simd_runtime_level())
137 }
138
139 /// Checks whether this tensor's data pointer is aligned for the specified SIMD level
140 #[inline]
141 #[cfg(test)]
142 #[cfg(target_arch = "x86_64")]
143 pub(crate) fn is_aligned_for_level(&self, level: SimdLevel) -> bool {
144 use crate::tensor::core::memory::simd_alignment_bytes;
145
146 let required = simd_alignment_bytes(level);
147 (self.data.as_ptr() as usize).is_multiple_of(required)
148 }
149
150 /// Returns true if this tensor is aligned for the current runtime SIMD level
151 #[inline]
152 #[cfg(test)]
153 #[cfg(target_arch = "x86_64")]
154 pub(crate) fn is_aligned_for_runtime_level(&self) -> bool {
155 self.is_aligned_for_level(Self::simd_runtime_level())
156 }
157
158 /// Returns true if ops should use aligned SIMD loads/stores without a tail branch
159 /// (i.e., pointer is aligned for the runtime SIMD level and length is a multiple of lane)
160 #[inline]
161 #[cfg(test)]
162 #[cfg(target_arch = "x86_64")]
163 pub(crate) fn prefer_aligned_simd_ops(&self) -> bool {
164 let lane = Self::simd_lane_width_elems_runtime();
165 self.is_aligned_for_runtime_level() && self.size().is_multiple_of(lane)
166 }
167
168 /// Returns true if ops should use unaligned SIMD loads/stores (mm_loadu)
169 /// because pointer alignment or length does not meet aligned requirements.
170 /// Returns false when no SIMD is available.
171 #[inline]
172 #[cfg(target_arch = "x86_64")]
173 #[cfg(test)]
174 pub(crate) fn should_use_unaligned_simd_ops(&self) -> bool {
175 match Self::simd_runtime_level() {
176 SimdLevel::Scalar => false,
177 _ => !self.prefer_aligned_simd_ops(),
178 }
179 }
180
181 /// Returns the allocated capacity in elements, which may be padded beyond logical size
182 #[inline]
183 pub fn capacity_elems(&self) -> usize {
184 if let Some(owner) = self.allocation_owner() {
185 owner.capacity_elems()
186 } else {
187 self.size()
188 }
189 }
    /// Creates a new tensor with the specified shape and optimized memory layout
    ///
    /// Allocates memory with size-dependent alignment for optimal performance:
    /// - Small tensors (≤8 elements): 16-byte SSE alignment
    /// - Medium tensors (8-1024 elements): 32-byte AVX2 alignment
    /// - Large tensors (>1024 elements): 64-byte cache-line alignment
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    ///
    /// # Returns
    ///
    /// A new tensor with uninitialized data. The data must be initialized
    /// before use to avoid undefined behavior.
    ///
    /// # Performance
    ///
    /// - **Memory Allocation**: Single allocation with optimized alignment
    /// - **SIMD Ready**: Properly aligned for vectorized operations
    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
    ///
    /// # Safety
    ///
    /// The returned tensor contains uninitialized memory. You must initialize
    /// the data before performing any operations that read from it.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create tensors of different sizes
    /// let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
    /// let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
    /// let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
    ///
    /// // Initialize data before use
    /// let mut tensor = Tensor::new(vec![2, 3]);
    /// tensor.fill(0.0); // Initialize with zeros
    /// ```
    #[inline]
    #[track_caller]
    pub fn new(shape_dims: Vec<usize>) -> Self {
        let shape = Shape::new(shape_dims);
        // Relaxed ordering suffices: tensor IDs only need to be unique, not ordered.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size() == 0 {
            // Handle zero-sized tensors
            // A zero-sized tensor never dereferences `data`, so a dangling
            // (non-null, aligned) pointer is safe and no allocation is owned.
            return Self {
                data: NonNull::dangling(),
                shape,
                device: current_device(),
                id,
                requires_grad: false,
                retain_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                graph_group: None,
                _phantom: PhantomData,
            };
        }

        // Compute alignment and padded element count based on runtime SIMD
        let (alignment, padded_elems) = compute_allocation_params(shape.size());
        let total_size = padded_elems * std::mem::size_of::<f32>();

        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Allocation policy: prefer pool for tiny tensors and medium/large classes.
        // Route certain small sizes to system allocator to keep pool stats semantics used by tests.
        let alloc_obj = if use_pool_alloc_enabled() {
            Allocation::new_pooled(padded_elems, alignment, layout)
        } else {
            Allocation::new(padded_elems, alignment, layout)
        };
        let ptr = alloc_obj.ptr;

        // Sanity-check that the allocator honored the padded capacity request.
        debug_assert!(
            alloc_obj.capacity_elems() >= padded_elems,
            "Allocation capacity ({}) smaller than padded elements ({})",
            alloc_obj.capacity_elems(),
            padded_elems
        );

        Self {
            data: ptr,
            shape,
            device: current_device(),
            id,
            requires_grad: false,
            retain_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            // The Arc enables shared ownership with views; see ensure_unique_allocation.
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            graph_group: None,
            _phantom: PhantomData,
        }
    }
292
293 /// Returns the shape and dimensional information of the tensor
294 ///
295 /// Provides access to the tensor's dimensions, size, strides, and memory
296 /// layout information. This is used for shape validation, memory access
297 /// calculations, and optimization decisions.
298 ///
299 /// # Returns
300 ///
301 /// Reference to the tensor's shape information containing dimensions,
302 /// size, strides, and memory layout type.
303 ///
304 /// # Performance
305 ///
306 /// - **Time Complexity**: O(1) - direct field access
307 /// - **Memory**: No allocation - returns reference to existing data
308 ///
309 /// # Examples
310 ///
311 /// ```
312 /// use train_station::Tensor;
313 ///
314 /// let tensor = Tensor::new(vec![2, 3, 4]);
315 /// let shape = tensor.shape();
316 /// assert_eq!(shape.dims(), vec![2, 3, 4]);
317 /// assert_eq!(shape.size(), 24);
318 /// assert_eq!(shape.rank(), 3);
319 /// ```
320 #[inline]
321 #[track_caller]
322 pub fn shape(&self) -> &Shape {
323 &self.shape
324 }
325
326 /// Returns the total number of elements in the tensor
327 ///
328 /// Provides the total count of elements across all dimensions. This is
329 /// used for memory allocation, iteration bounds, and performance optimization.
330 ///
331 /// # Returns
332 ///
333 /// Total number of elements as `usize`
334 ///
335 /// # Performance
336 ///
337 /// - **Time Complexity**: O(1) - direct field access
338 /// - **Memory**: No allocation - returns stored value
339 ///
340 /// # Examples
341 ///
342 /// ```
343 /// use train_station::Tensor;
344 ///
345 /// let tensor = Tensor::new(vec![2, 3, 4]);
346 /// assert_eq!(tensor.size(), 24); // 2 * 3 * 4
347 ///
348 /// let scalar = Tensor::new(vec![1]);
349 /// assert_eq!(scalar.size(), 1);
350 ///
351 /// let empty = Tensor::new(vec![0]);
352 /// assert_eq!(empty.size(), 0);
353 /// ```
354 #[inline]
355 #[track_caller]
356 pub fn size(&self) -> usize {
357 self.shape().size()
358 }
359
360 /// Returns the device where this tensor is located
361 ///
362 /// Provides the physical location of the tensor data (CPU/GPU). This
363 /// determines which operations can be performed on the tensor and where
364 /// computations will be executed.
365 ///
366 /// # Returns
367 ///
368 /// Device enum indicating the tensor's physical location
369 ///
370 /// # Performance
371 ///
372 /// - **Time Complexity**: O(1) - direct field access
373 /// - **Memory**: No allocation - returns stored value
374 ///
375 /// # Examples
376 ///
377 /// ```
378 /// use train_station::Tensor;
379 ///
380 /// let tensor = Tensor::new(vec![2, 3]);
381 /// assert!(tensor.device().is_cpu());
382 /// assert!(!tensor.device().is_cuda());
383 /// ```
384 #[inline]
385 #[track_caller]
386 pub fn device(&self) -> Device {
387 self.device
388 }
389
    /// Creates a new tensor with the specified shape on a specific device
    ///
    /// Allocates memory on the specified device with the same optimized alignment
    /// strategy as `new()`. Currently supports CPU device with future CUDA support.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    /// * `device` - The device where the tensor should be allocated
    ///
    /// # Returns
    ///
    /// A new tensor with uninitialized data on the specified device
    ///
    /// # Panics
    ///
    /// Panics if the specified device is not supported (currently only CPU is supported)
    ///
    /// # Performance
    ///
    /// - **Memory Allocation**: Device-specific allocation with optimized alignment
    /// - **SIMD Ready**: Properly aligned for vectorized operations on target device
    /// - **Cache Friendly**: Optimized for CPU cache hierarchies
    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, Device};
    ///
    /// // Create tensor on CPU device
    /// let tensor = Tensor::new_on_device(vec![2, 3], Device::cpu());
    /// assert_eq!(tensor.device(), Device::cpu());
    /// assert_eq!(tensor.size(), 6);
    /// ```
    #[track_caller]
    pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self {
        // For now, only CPU is supported
        if !device.is_cpu() {
            panic!("Only CPU device is currently supported. CUDA support is planned for future releases.");
        }

        let shape = Shape::new(shape_dims);
        // Relaxed ordering suffices: tensor IDs only need to be unique, not ordered.
        let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);

        if shape.size() == 0 {
            // Handle zero-sized tensors
            // Dangling (non-null, aligned) pointer is safe: data is never dereferenced.
            return Self {
                data: NonNull::dangling(),
                shape,
                device,
                id,
                requires_grad: false,
                retain_grad: false,
                grad: None,
                grad_fn: GradFn::None,
                allocation_owner: None,
                graph_group: None,
                _phantom: PhantomData,
            };
        }

        let (alignment, padded_elems) = compute_allocation_params(shape.size());
        let total_size = padded_elems * std::mem::size_of::<f32>();
        let layout = Layout::from_size_align(total_size, alignment)
            .expect("Failed to create layout for tensor data");

        // Same allocation policy as new(): prefer pool for tiny and medium/large classes
        let alloc_obj = if use_pool_alloc_enabled() {
            Allocation::new_pooled(padded_elems, alignment, layout)
        } else {
            Allocation::new(padded_elems, alignment, layout)
        };
        let ptr = alloc_obj.ptr;

        debug_assert!(
            alloc_obj.capacity_elems() >= padded_elems,
            "Allocation capacity ({}) smaller than padded elements ({})",
            alloc_obj.capacity_elems(),
            padded_elems
        );

        Self {
            data: ptr,
            shape,
            device,
            id,
            requires_grad: false,
            retain_grad: false,
            grad: None,
            grad_fn: GradFn::None,
            allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
            graph_group: None,
            _phantom: PhantomData,
        }
    }
515
516 /// Enable gradient computation for this tensor
517 ///
518 /// Builder method that enables automatic gradient tracking for this tensor.
519 /// When enabled, all operations involving this tensor will be recorded in
520 /// the computation graph for gradient computation during backward pass.
521 ///
522 /// # Returns
523 ///
524 /// `self` with gradient tracking enabled
525 ///
526 /// # Performance
527 ///
528 /// - **Time Complexity**: O(1) - simple field assignment
529 /// - **Memory**: No additional allocation
530 /// - **Overhead**: Minimal gradtrack tracking overhead when gradients computed
531 ///
532 /// # Examples
533 ///
534 /// ```
535 /// use train_station::Tensor;
536 ///
537 /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
538 /// assert!(tensor.requires_grad());
539 /// ```
540 #[track_caller]
541 pub fn with_requires_grad(mut self) -> Self {
542 self.requires_grad = true;
543 if self.graph_group.is_none() {
544 self.graph_group = Some(ensure_local_group_for_tensor(self.id()));
545 }
546 self
547 }
548
549 /// Set gradient tracking for this tensor
550 ///
551 /// Controls whether the gradtrack system tracks operations on this tensor
552 /// and computes gradients during backward pass. When disabled, clears
553 /// any existing gradients and gradient functions.
554 ///
555 /// # Arguments
556 ///
557 /// * `requires_grad` - Whether to track gradients for this tensor
558 ///
559 /// # Performance
560 ///
561 /// - **Time Complexity**: O(1) - simple field assignment
562 /// - **Memory**: May free gradient storage when disabled
563 /// - **Overhead**: Zero overhead when gradients disabled
564 ///
565 /// # Examples
566 ///
567 /// ```
568 /// use train_station::Tensor;
569 ///
570 /// let mut tensor = Tensor::ones(vec![2, 3]);
571 /// tensor.set_requires_grad(true);
572 /// assert!(tensor.requires_grad());
573 ///
574 /// // Disable gradient tracking
575 /// tensor.set_requires_grad(false);
576 /// assert!(!tensor.requires_grad());
577 /// ```
578 #[track_caller]
579 pub fn set_requires_grad(&mut self, requires_grad: bool) {
580 self.requires_grad = requires_grad;
581 if !requires_grad {
582 self.grad = None;
583 self.grad_fn = GradFn::None;
584 self.graph_group = None;
585 } else {
586 // Lazily bind to a local graph group for this tensor id
587 if self.graph_group.is_none() {
588 self.graph_group = Some(ensure_local_group_for_tensor(self.id()));
589 }
590 }
591 }
592
593 /// Mark this tensor to retain gradients after backward, even if it is non-leaf.
594 ///
595 /// Builder-style API: returns self with `retain_grad=true`.
596 /// Call `materialize_grad()` or `grad_or_fetch()` after backward to copy the
597 /// accumulated gradient from the GradGraph into `self.grad` so `grad()` works.
598 #[track_caller]
599 pub fn retain_grad(mut self) -> Self {
600 self.retain_grad = true;
601 // Register intent with grad engine to keep gradient available after backward
602 crate::gradtrack::engine::mark_retain_grad(self.id);
603 self
604 }
605
606 /// In-place variant to enable or disable gradient retention for non-leaf tensors
607 #[track_caller]
608 pub fn retain_grad_(&mut self, enable: bool) {
609 self.retain_grad = enable;
610 if enable {
611 crate::gradtrack::engine::mark_retain_grad(self.id);
612 }
613 }
614
615 /// Check if this tensor requires gradients
616 ///
617 /// # Returns
618 ///
619 /// `true` if gradient tracking is enabled for this tensor
620 ///
621 /// # Examples
622 ///
623 /// ```
624 /// use train_station::Tensor;
625 ///
626 /// let tensor = Tensor::new(vec![2, 3]);
627 /// assert!(!tensor.requires_grad());
628 ///
629 /// let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
630 /// assert!(grad_tensor.requires_grad());
631 /// ```
632 #[track_caller]
633 pub fn requires_grad(&self) -> bool {
634 self.requires_grad
635 }
636
637 /// Get a reference to this tensor's locally cached gradient (if any)
638 ///
639 /// This accessor returns only the gradient cached on this tensor's `grad` field.
640 /// It does NOT query the global/autograd gradient storage. For leaf tensors,
641 /// gradients are accumulated and tracked by the grad engine and are not
642 /// automatically written back to the local `grad` field.
643 ///
644 /// To make `grad()` return `Some(&Tensor)`:
645 /// - Enable retention on this tensor (typically for non-leaf tensors) with
646 /// `retain_grad_(&mut tensor, true)` or `tensor.retain_grad()`
647 /// - After `backward()`, call `tensor.materialize_grad()` or
648 /// `tensor.grad_or_fetch()` to copy the accumulated gradient from the
649 /// autograd engine into `self.grad`
650 ///
651 /// If you want to read gradients without caching them locally, prefer
652 /// [`grad_owned`](#method.grad_owned), which consults the global gradient
653 /// storage.
654 ///
655 /// # Returns
656 ///
657 /// Optional reference to the locally cached gradient tensor, or `None` if
658 /// not materialized on this tensor.
659 ///
660 /// # Examples
661 ///
662 /// ```
663 /// use train_station::Tensor;
664 ///
665 /// let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
666 /// let mut loss = x.sum();
667 /// loss.backward(None);
668 /// // Without materialization, grad() typically returns None for leaves
669 /// assert!(x.grad().is_none());
670 /// // Materialize to cache locally so grad() works
671 /// x.retain_grad_ (true);
672 /// x.materialize_grad();
673 /// assert!(x.grad().is_some());
674 /// ```
675 #[track_caller]
676 pub fn grad(&self) -> Option<&Tensor> {
677 // First check if we have a gradient stored directly
678 if let Some(grad) = self.grad.as_ref() {
679 return Some(grad.as_ref());
680 }
681
682 None
683 }
684
685 /// Fetch the accumulated gradient after backward and cache it on this tensor if `retain_grad` is enabled.
686 ///
687 /// Returns true if a gradient was found and cached. After a successful call,
688 /// `grad()` will return `Some(&Tensor)` even for non-leaf tensors.
689 #[track_caller]
690 pub fn materialize_grad(&mut self) -> bool {
691 if self.grad.is_some() {
692 return true;
693 }
694 if !self.retain_grad {
695 return false;
696 }
697 if let Some(g) = self.grad_owned() {
698 self.set_grad(g);
699 return true;
700 }
701 false
702 }
703
704 /// Convenience accessor: if `retain_grad` is enabled, fetch and cache the gradient
705 /// on first access so callers can immediately get a reference.
706 #[track_caller]
707 pub fn grad_or_fetch(&mut self) -> Option<&Tensor> {
708 if self.grad.is_none() && self.retain_grad {
709 if let Some(g) = self.grad_owned() {
710 self.set_grad(g);
711 }
712 }
713 self.grad.as_deref()
714 }
715
716 /// Get the accumulated gradient from the autograd engine as an owned tensor
717 ///
718 /// This accessor queries the global/shared gradient storage for this tensor's
719 /// ID and returns the current accumulated gradient by value. It complements
720 /// [`grad`](#method.grad), which only returns a locally cached reference.
721 ///
722 /// - Works for leaf tensors without any prior materialization step
723 /// - Does not modify or clear internal gradient state
724 /// - Suitable for optimizers and logging that need the latest gradients
725 ///
726 /// # Returns
727 ///
728 /// `Some(Tensor)` containing the current accumulated gradient when available,
729 /// otherwise `None`.
730 ///
731 /// # Examples
732 ///
733 /// ```
734 /// use train_station::Tensor;
735 ///
736 /// let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
737 /// let mut loss = x.sum();
738 /// loss.backward(None);
739 ///
740 /// // Fetch directly from autograd storage (no materialization required)
741 /// let g = x.grad_owned().unwrap();
742 /// assert_eq!(g.shape().dims(), vec![2, 3]);
743 /// ```
744 #[track_caller]
745 pub fn grad_owned(&self) -> Option<Tensor> {
746 // First check if we have a gradient stored directly
747 if let Some(grad) = self.grad.as_ref() {
748 return Some((**grad).clone());
749 }
750
751 // Always consult the global/shared gradient storage so gradients accumulated
752 // in a promoted/merged shared group are visible regardless of the calling thread.
753 // This is critical for cross-thread forward/backward validation and parity with LibTorch.
754 use crate::gradtrack;
755 gradtrack::get_accumulated_gradient(self.id)
756 }
757
758 /// Get the unique ID of this tensor
759 ///
760 /// Returns the unique identifier assigned to this tensor during creation.
761 /// This ID is used for gradtrack tracking and tensor identification.
762 ///
763 /// # Returns
764 ///
765 /// Unique tensor ID as `usize`
766 ///
767 /// # Examples
768 ///
769 /// ```
770 /// use train_station::Tensor;
771 ///
772 /// let tensor1 = Tensor::new(vec![2, 3]);
773 /// let tensor2 = Tensor::new(vec![2, 3]);
774 /// assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
775 /// ```
776 #[track_caller]
777 pub fn id(&self) -> usize {
778 self.id
779 }
780
    /// Detach this tensor from the computation graph
    ///
    /// Returns a new tensor with the same data but no gradient tracking.
    /// This is useful when you want to use a tensor in inference without
    /// affecting the computation graph.
    ///
    /// NOTE(review): the copy transfers `self.size()` contiguous elements from
    /// the base pointer, and the result is allocated via `Self::new` on the
    /// current device — assumes this tensor's data is contiguous and that
    /// `current_device()` matches `self.device`; confirm for strided views and
    /// once non-CPU devices exist.
    ///
    /// # Returns
    ///
    /// A new tensor with the same data but gradient tracking disabled
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// let detached = tensor.detach();
    /// assert!(!detached.requires_grad());
    /// assert_eq!(tensor.size(), detached.size());
    /// ```
    #[track_caller]
    pub fn detach(&self) -> Self {
        // Fresh allocation with default (disabled) gradient state.
        let mut detached = Self::new(self.shape().dims().to_vec());

        // Copy data
        unsafe {
            let src = self.as_ptr();
            let dst = detached.as_mut_ptr();
            std::ptr::copy_nonoverlapping(src, dst, self.size());
        }

        detached
    }
814
    /// Create a new tensor that doesn't track gradients from this one
    ///
    /// Similar to detach() but modifies this tensor in place. This is useful
    /// when you want to disable gradient tracking for the current tensor
    /// without creating a copy.
    ///
    /// NOTE(review): unlike `set_requires_grad(false)`, this does not clear
    /// `graph_group` — confirm whether keeping the group binding after an
    /// in-place detach is intentional.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
    /// assert!(tensor.requires_grad());
    /// tensor.detach_();
    /// assert!(!tensor.requires_grad());
    /// ```
    #[track_caller]
    pub fn detach_(&mut self) {
        self.requires_grad = false;
        self.grad = None;
        self.grad_fn = GradFn::None;
    }
837
838 /// Entry point for backward pass on this tensor
839 ///
840 /// Computes gradients for all tensors in the computation graph that have
841 /// `requires_grad` set to true. This is the main entry point for automatic
842 /// differentiation.
843 ///
844 /// # Arguments
845 ///
846 /// * `grad_output` - Optional gradient tensor for the output. If None, assumes
847 /// the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
848 ///
849 /// # Examples
850 ///
851 /// ```
852 /// use train_station::Tensor;
853 ///
854 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
855 /// let mut result = tensor.add_scalar(5.0);
856 /// result.backward(None);
857 /// // Note: Gradient computation depends on the gradtrack system implementation
858 /// ```
859 #[track_caller]
860 pub fn backward(&mut self, grad_output: Option<Tensor>) {
861 GradEngine::backward(self, grad_output);
862 }
863
    /// Returns a raw pointer to the tensor data for unsafe operations
    ///
    /// For zero-sized tensors the pointer is dangling (non-null, aligned) and
    /// must never be dereferenced.
    ///
    /// # Safety
    ///
    /// This is unsafe because it provides direct access to the underlying memory.
    /// The caller must ensure:
    /// - The tensor is not dropped while the pointer is used
    /// - No concurrent mutable access occurs
    /// - Bounds are respected
    #[inline]
    pub unsafe fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }
877
    /// Returns a mutable raw pointer to the tensor data for unsafe operations
    ///
    /// For zero-sized tensors the pointer is dangling (non-null, aligned) and
    /// must never be dereferenced.
    ///
    /// # Safety
    ///
    /// This is unsafe because it provides direct mutable access to the underlying memory.
    /// The caller must ensure:
    /// - The tensor is not dropped while the pointer is used
    /// - No concurrent access occurs
    /// - Bounds are respected
    #[inline]
    pub unsafe fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_ptr()
    }
891
892 /// Internal method to set gradient function (used by operations)
893 ///
894 /// Sets the gradient function for this tensor. This is used internally
895 /// by tensor operations to record the computation graph for gradtrack.
896 ///
897 /// # Arguments
898 ///
899 /// * `grad_fn` - The gradient function to set
900 ///
901 /// # Implementation Details
902 ///
903 /// This method is called by tensor operations to register the gradient
904 /// computation function. It only sets the gradient function if gradient
905 /// tracking is enabled for this tensor.
906 pub(crate) fn set_grad_fn(&mut self, grad_fn: GradFn) {
907 if self.requires_grad {
908 self.grad_fn = grad_fn;
909 }
910 }
911
912 /// Get a reference to the gradient function (for gradtrack)
913 ///
914 /// Returns a reference to the gradient function associated with this tensor.
915 /// This is used internally by the gradtrack system to compute gradients.
916 ///
917 /// # Returns
918 ///
919 /// Reference to the gradient function
920 ///
921 /// # Implementation Details
922 ///
923 /// This method is used by the gradtrack engine to access the gradient
924 /// computation function during backward pass.
925 #[track_caller]
926 pub fn grad_fn(&self) -> &GradFn {
927 &self.grad_fn
928 }
929
930 /// Internal method to set requires_grad (used by gradtrack operations)
931 ///
932 /// Sets the gradient tracking flag for this tensor. This is used internally
933 /// by gradtrack operations to control gradient computation.
934 ///
935 /// # Arguments
936 ///
937 /// * `requires_grad` - Whether to enable gradient tracking
938 ///
939 /// # Implementation Details
940 ///
941 /// This method is used internally by the gradtrack system to control
942 /// gradient tracking without triggering additional side effects.
943 pub(crate) fn set_requires_grad_internal(&mut self, requires_grad: bool) {
944 self.requires_grad = requires_grad;
945 }
946
947 /// Internal method to accumulate gradients with optimized operations
948 ///
949 /// Accumulates a gradient tensor with any existing gradients for this tensor.
950 /// This is used internally by the gradtrack system to handle gradient accumulation
951 /// during backward pass.
952 ///
953 /// # Arguments
954 ///
955 /// * `grad` - The gradient tensor to accumulate
956 ///
957 /// # Implementation Details
958 ///
959 /// This method is used internally by the gradtrack engine to accumulate
960 /// gradients from multiple backward passes or operations. It only accumulates
961 /// gradients if gradient tracking is enabled for this tensor.
962 pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
963 if !self.requires_grad {
964 return;
965 }
966
967 match &self.grad {
968 Some(existing_grad) => {
969 // Use optimized tensor addition but create new tensor for safety
970 let accumulated = existing_grad.add_tensor_optimized(&grad);
971 self.grad = Some(Arc::new(accumulated));
972 }
973 None => {
974 self.grad = Some(Arc::new(grad));
975 }
976 }
977 }
978
979 /// Set gradient from external source
980 ///
981 /// Sets the gradient tensor for this tensor. This is used internally by the
982 /// gradtrack system to set gradients during backward pass.
983 ///
984 /// # Arguments
985 ///
986 /// * `grad` - The gradient tensor to set
987 ///
988 /// # Implementation Details
989 ///
990 /// This method is used internally by the gradtrack engine to set gradients
991 /// during backward pass. It only sets the gradient if gradient tracking is
992 /// enabled for this tensor.
993 #[track_caller]
994 pub fn set_grad(&mut self, grad: Tensor) {
995 if self.requires_grad {
996 self.grad = Some(Arc::new(grad));
997 }
998 }
999
1000 /// Clear accumulated gradients for this tensor
1001 ///
1002 /// This method is used by optimizers to zero gradients before each backward pass.
1003 /// It clears any accumulated gradients, allowing for fresh gradient computation.
1004 ///
1005 /// # Examples
1006 ///
1007 /// ```
1008 /// use train_station::Tensor;
1009 ///
1010 /// let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
1011 /// tensor.set_grad(Tensor::ones(vec![2, 3]));
1012 /// assert!(tensor.grad().is_some());
1013 /// tensor.zero_grad();
1014 /// assert!(tensor.grad().is_none());
1015 /// ```
1016 #[track_caller]
1017 pub fn zero_grad(&mut self) {
1018 // Clear locally cached gradient
1019 self.grad = None;
1020 // Ensure gradient is also cleared from autograd storage so future backward runs
1021 // don't observe stale gradients when mixing TLS and shared groups.
1022 crate::gradtrack::engine::clear_gradient_for_tensor(self.id);
1023 }
1024
1025 /// Negate all elements in the tensor in-place
1026 ///
1027 /// This is used internally for gradient computation in subtraction operations.
1028 /// For tensor - tensor operations, the second operand gets a negated gradient.
1029 ///
1030 /// # Implementation Details
1031 ///
1032 /// This method is used internally by the gradtrack system to compute gradients
1033 /// for subtraction operations. It uses SIMD optimization when available for
1034 /// better performance.
1035 #[inline]
1036 pub(crate) fn negate_inplace(&mut self) {
1037 if self.shape.size() == 0 {
1038 return;
1039 }
1040
1041 unsafe {
1042 let ptr = self.data.as_ptr();
1043
1044 #[cfg(target_arch = "x86_64")]
1045 {
1046 // Use SIMD for better performance when available
1047 if is_x86_feature_detected!("avx2") {
1048 self.negate_simd_avx2(ptr);
1049 return;
1050 }
1051 }
1052
1053 // Fallback to scalar operations
1054 for i in 0..self.shape.size() {
1055 *ptr.add(i) = -*ptr.add(i);
1056 }
1057 }
1058 }
1059
1060 /// SIMD-optimized negation using AVX2 instructions
1061 ///
1062 /// Performs in-place negation of tensor elements using AVX2 SIMD instructions
1063 /// for improved performance on x86_64 architectures.
1064 ///
1065 /// # Arguments
1066 ///
1067 /// * `ptr` - Raw pointer to the tensor data
1068 ///
1069 /// # Safety
1070 ///
1071 /// The caller must ensure:
1072 /// - `ptr` is a valid pointer to tensor data
1073 /// - The tensor size matches the actual data size
1074 /// - The tensor is not moved or dropped during this operation
1075 ///
1076 /// # Implementation Details
1077 ///
1078 /// This method is used internally by `negate_inplace` when AVX2 is available.
1079 /// It processes 8 elements per iteration using AVX2 instructions.
1080 #[cfg(target_arch = "x86_64")]
1081 #[inline]
1082 unsafe fn negate_simd_avx2(&self, ptr: *mut f32) {
1083 use std::arch::x86_64::_mm256_setzero_ps;
1084
1085 let size = self.shape().size();
1086 let zero_vec = _mm256_setzero_ps();
1087 let simd_count = size / 8; // Process 8 elements per iteration
1088 let mut offset = 0;
1089
1090 // SIMD loop for negation
1091 for _ in 0..simd_count {
1092 use std::arch::x86_64::{_mm256_load_ps, _mm256_store_ps, _mm256_sub_ps};
1093
1094 let vec = _mm256_load_ps(ptr.add(offset));
1095 let neg_vec = _mm256_sub_ps(zero_vec, vec);
1096 _mm256_store_ps(ptr.add(offset), neg_vec);
1097 offset += 8;
1098 }
1099
1100 // Handle remaining elements
1101 for i in offset..size {
1102 *ptr.add(i) = -*ptr.add(i);
1103 }
1104 }
1105
1106 // ===== Memory Layout and Optimization API =====
1107
1108 /// Checks if the tensor data is stored contiguously in memory
1109 ///
1110 /// # Returns
1111 ///
1112 /// `true` if the tensor data is contiguous, enabling optimized SIMD operations
1113 ///
1114 /// # Examples
1115 ///
1116 /// ```
1117 /// use train_station::Tensor;
1118 ///
1119 /// let tensor = Tensor::new(vec![2, 3, 4]);
1120 /// assert!(tensor.is_contiguous());
1121 /// ```
1122 #[inline]
1123 #[track_caller]
1124 pub fn is_contiguous(&self) -> bool {
1125 self.shape().is_contiguous()
1126 }
1127
1128 /// Gets the memory strides for all dimensions
1129 ///
1130 /// # Returns
1131 ///
1132 /// Reference to the stride vector for efficient memory access calculations
1133 ///
1134 /// # Examples
1135 ///
1136 /// ```
1137 /// use train_station::Tensor;
1138 ///
1139 /// let tensor = Tensor::new(vec![2, 3, 4]);
1140 /// assert_eq!(tensor.strides(), &[12, 4, 1]);
1141 /// ```
1142 #[inline]
1143 #[track_caller]
1144 pub fn strides(&self) -> &[usize] {
1145 self.shape.strides()
1146 }
1147
1148 /// Gets the memory stride for a specific dimension
1149 ///
1150 /// # Arguments
1151 ///
1152 /// * `dim` - The dimension index
1153 ///
1154 /// # Returns
1155 ///
1156 /// The memory stride for the given dimension
1157 ///
1158 /// # Panics
1159 ///
1160 /// Panics if `dim` is out of bounds
1161 #[inline]
1162 #[track_caller]
1163 pub fn stride(&self, dim: usize) -> usize {
1164 self.shape.stride(dim)
1165 }
1166
1167 /// Calculates the linear memory offset for given multi-dimensional indices
1168 ///
1169 /// # Arguments
1170 ///
1171 /// * `indices` - Vector of indices for each dimension
1172 ///
1173 /// # Returns
1174 ///
1175 /// Linear memory offset for direct memory access
1176 ///
1177 /// # Examples
1178 ///
1179 /// ```
1180 /// use train_station::Tensor;
1181 ///
1182 /// let tensor = Tensor::new(vec![2, 3, 4]);
1183 /// let offset = tensor.memory_offset(&[1, 2, 3]);
1184 /// // offset = 1*12 + 2*4 + 3*1 = 23
1185 /// ```
1186 #[inline]
1187 #[track_caller]
1188 pub fn memory_offset(&self, indices: &[usize]) -> usize {
1189 self.shape.offset(indices)
1190 }
1191
1192 /// Broadcast this tensor with another tensor for element-wise operations
1193 ///
1194 /// Returns a tuple containing:
1195 /// - Broadcasted view of self
1196 /// - Broadcasted view of other
1197 /// - Result shape for the operation
1198 ///
1199 /// # Arguments
1200 ///
1201 /// * `other` - The tensor to broadcast with
1202 ///
1203 /// # Returns
1204 ///
1205 /// A tuple `(broadcasted_self, broadcasted_other, result_shape)`
1206 #[track_caller]
1207 pub(crate) fn broadcast_with(
1208 &self,
1209 other: &Tensor,
1210 ) -> Result<
1211 (Tensor, Tensor, crate::tensor::Shape),
1212 crate::tensor::ops::broadcasting::BroadcastError,
1213 > {
1214 crate::tensor::ops::broadcasting::broadcast_shapes(self, other)
1215 }
1216
1217 /// Checks if the tensor data is properly aligned for SIMD operations
1218 ///
1219 /// # Returns
1220 ///
1221 /// `true` if the tensor data is aligned to 32-byte boundaries for AVX2
1222 #[inline]
1223 #[track_caller]
1224 #[cfg(target_arch = "x86_64")]
1225 pub(crate) fn is_simd_aligned(&self) -> bool {
1226 // Maintain AVX2 compatibility for existing SIMD paths
1227 (self.data.as_ptr() as usize).is_multiple_of(32)
1228 }
1229
1230 /// Gets the memory alignment of the tensor data
1231 ///
1232 /// # Returns
1233 ///
1234 /// The memory alignment in bytes (typically 32 for SIMD optimization)
1235 #[inline]
1236 #[track_caller]
1237 pub fn memory_alignment(&self) -> usize {
1238 if let Some(owner) = self.allocation_owner() {
1239 owner.alignment()
1240 } else {
1241 // Zero-sized or non-owned views default to conservative 32
1242 32
1243 }
1244 }
1245
1246 /// Checks if this tensor is broadcastable with another tensor
1247 ///
1248 /// # Arguments
1249 ///
1250 /// * `other` - The other tensor to check broadcasting compatibility
1251 ///
1252 /// # Returns
1253 ///
1254 /// `true` if the tensors are broadcastable according to NumPy broadcasting rules
1255 ///
1256 /// # Examples
1257 ///
1258 /// ```
1259 /// use train_station::Tensor;
1260 ///
1261 /// let a = Tensor::new(vec![2, 3, 4]);
1262 /// let b = Tensor::new(vec![1, 3, 4]);
1263 /// assert!(a.is_broadcastable_with(&b));
1264 /// ```
1265 #[inline]
1266 #[track_caller]
1267 pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
1268 self.shape.is_broadcastable_with(&other.shape)
1269 }
1270
1271 /// Gets the total number of bytes allocated for this tensor
1272 ///
1273 /// # Returns
1274 ///
1275 /// Total memory footprint in bytes
1276 #[inline]
1277 #[track_caller]
1278 pub fn memory_footprint(&self) -> usize {
1279 self.shape.size() * std::mem::size_of::<f32>()
1280 }
1281
1282 /// Get a single element from the tensor at the specified indices
1283 ///
1284 /// # Arguments
1285 ///
1286 /// * `indices` - Multi-dimensional indices to access the element
1287 ///
1288 /// # Returns
1289 ///
1290 /// The value at the specified position
1291 ///
1292 /// # Panics
1293 ///
1294 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1295 ///
1296 /// # Examples
1297 ///
1298 /// ```
1299 /// use train_station::Tensor;
1300 ///
1301 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1302 /// let value = tensor.get(&[0, 1]);
1303 /// assert_eq!(value, 2.0);
1304 /// ```
1305 #[track_caller]
1306 pub fn get(&self, indices: &[usize]) -> f32 {
1307 assert_eq!(
1308 indices.len(),
1309 self.shape().rank(),
1310 "Indices length must match tensor rank"
1311 );
1312
1313 // Check bounds
1314 for (i, &idx) in indices.iter().enumerate() {
1315 assert!(
1316 idx < self.shape().dims()[i],
1317 "Index {} out of bounds for dimension {}",
1318 idx,
1319 i
1320 );
1321 }
1322
1323 let offset = self.memory_offset(indices);
1324 unsafe { *self.as_ptr().add(offset) }
1325 }
1326
1327 /// Set a single element in the tensor at the specified indices
1328 ///
1329 /// # Arguments
1330 ///
1331 /// * `indices` - Multi-dimensional indices to set the element
1332 /// * `value` - The value to set
1333 ///
1334 /// # Panics
1335 ///
1336 /// Panics if indices are out of bounds or indices length doesn't match tensor rank
1337 ///
1338 /// # Examples
1339 ///
1340 /// ```
1341 /// use train_station::Tensor;
1342 ///
1343 /// let mut tensor = Tensor::new(vec![2, 2]);
1344 /// tensor.set(&[0, 1], 42.0);
1345 /// assert_eq!(tensor.get(&[0, 1]), 42.0);
1346 /// ```
1347 #[track_caller]
1348 pub fn set(&mut self, indices: &[usize], value: f32) {
1349 assert_eq!(
1350 indices.len(),
1351 self.shape().rank(),
1352 "Indices length must match tensor rank"
1353 );
1354
1355 // Check bounds
1356 for (i, &idx) in indices.iter().enumerate() {
1357 assert!(
1358 idx < self.shape().dims()[i],
1359 "Index {} out of bounds for dimension {}",
1360 idx,
1361 i
1362 );
1363 }
1364
1365 let offset = self.memory_offset(indices);
1366 unsafe {
1367 *self.as_mut_ptr().add(offset) = value;
1368 }
1369 }
1370
1371 /// Returns a safe slice of the tensor's underlying data
1372 ///
1373 /// Provides safe access to the tensor's data without requiring unsafe pointer operations.
1374 /// This is the preferred way to access tensor data for reading values, comparisons,
1375 /// and other operations that don't require direct pointer manipulation.
1376 ///
1377 /// # Returns
1378 ///
1379 /// A slice containing all tensor elements in row-major order
1380 ///
1381 /// # Performance
1382 ///
1383 /// - **Zero-Cost**: Direct slice creation with no copying
1384 /// - **Cache-Friendly**: Sequential memory access pattern
1385 /// - **Safe**: No unsafe code required for basic data access
1386 ///
1387 /// # Examples
1388 ///
1389 /// ```
1390 /// use train_station::Tensor;
1391 ///
1392 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
1393 /// let data = tensor.data();
1394 ///
1395 /// // Safe indexing and comparisons
1396 /// assert_eq!(data[0], 1.0);
1397 /// assert_eq!(data.len(), tensor.size());
1398 /// ```
1399 #[inline]
1400 #[track_caller]
1401 pub fn data(&self) -> &[f32] {
1402 if self.size() == 0 {
1403 return &[];
1404 }
1405 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }
1406 }
1407
1408 /// Returns a mutable slice of the tensor's underlying data
1409 ///
1410 /// Provides safe mutable access to the tensor's data without requiring unsafe
1411 /// pointer operations. Use this for in-place modifications of tensor values.
1412 ///
1413 /// # Returns
1414 ///
1415 /// A mutable slice containing all tensor elements in row-major order
1416 ///
1417 /// # Performance
1418 ///
1419 /// - **Zero-Cost**: Direct slice creation with no copying
1420 /// - **Cache-Friendly**: Sequential memory access pattern
1421 /// - **Safe**: No unsafe code required for basic data modification
1422 ///
1423 /// # Examples
1424 ///
1425 /// ```
1426 /// use train_station::Tensor;
1427 ///
1428 /// let mut tensor = Tensor::new(vec![2, 2]);
1429 /// let data = tensor.data_mut();
1430 ///
1431 /// // Safe indexing for modification
1432 /// data[0] = 1.0;
1433 /// data[1] = 2.0;
1434 ///
1435 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
1436 /// assert_eq!(tensor.get(&[0, 1]), 2.0);
1437 /// ```
1438 #[inline]
1439 #[track_caller]
1440 pub fn data_mut(&mut self) -> &mut [f32] {
1441 if self.size() == 0 {
1442 return &mut [];
1443 }
1444 // Copy-on-write to protect existing views
1445 self.ensure_unique_allocation();
1446 unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }
1447 }
1448
1449 /// Extract scalar value from single-element tensor
1450 ///
1451 /// This method provides a convenient way to extract the scalar value from
1452 /// tensors that contain exactly one element. This is commonly used with
1453 /// element iterator results and scalar tensor operations.
1454 ///
1455 /// # Returns
1456 ///
1457 /// The scalar value contained in this tensor
1458 ///
1459 /// # Panics
1460 ///
1461 /// Panics if the tensor does not contain exactly one element
1462 ///
1463 /// # Examples
1464 ///
1465 /// ```
1466 /// use train_station::Tensor;
1467 ///
1468 /// // Single-element tensor
1469 /// let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
1470 /// assert_eq!(scalar.value(), 42.0);
1471 /// ```
1472 #[inline]
1473 #[track_caller]
1474 pub fn value(&self) -> f32 {
1475 assert_eq!(
1476 self.size(),
1477 1,
1478 "value() can only be called on tensors with exactly one element. \
1479 This tensor has {} elements with shape {:?}",
1480 self.size(),
1481 self.shape().dims()
1482 );
1483 self.data()[0]
1484 }
1485
1486 /// Create a view with a new shape (requires contiguous memory)
1487 ///
1488 /// Behaves like PyTorch `view`: tensor must be contiguous and the total
1489 /// number of elements must remain the same. Supports -1 inference for one dimension.
1490 ///
1491 /// # Arguments
1492 ///
1493 /// * `new_shape` - New shape for the tensor (can contain -1 for inference)
1494 ///
1495 /// # Returns
1496 ///
1497 /// A tensor viewing the same data with a new shape
1498 ///
1499 /// # Examples
1500 ///
1501 /// ```
1502 /// use train_station::Tensor;
1503 ///
1504 /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1505 /// let y = x.view(vec![2, 2]);
1506 /// assert_eq!(y.shape().dims(), vec![2, 2]);
1507 /// ```
1508 #[track_caller]
1509 pub fn view(&self, new_shape: Vec<i32>) -> Tensor {
1510 // PyTorch-like view with single -1 inference; requires contiguity
1511 let size = self.size();
1512 let mut infer_idx: Option<usize> = None;
1513 let mut product: usize = 1;
1514 let mut dims: Vec<usize> = Vec::with_capacity(new_shape.len());
1515 for (i, d) in new_shape.iter().enumerate() {
1516 if *d == -1 {
1517 assert!(infer_idx.is_none(), "Only one -1 is allowed in view shape");
1518 infer_idx = Some(i);
1519 dims.push(1);
1520 } else {
1521 assert!(*d > 0, "Negative dims not supported in view shape");
1522 let du = *d as usize;
1523 product = product.saturating_mul(du);
1524 dims.push(du);
1525 }
1526 }
1527 if let Some(pos) = infer_idx {
1528 assert!(
1529 product > 0 && size.is_multiple_of(product),
1530 "View shape incompatible with tensor size"
1531 );
1532 dims[pos] = size / product;
1533 } else {
1534 assert!(
1535 product == size,
1536 "View shape has different number of elements"
1537 );
1538 }
1539
1540 let mut v = match crate::tensor::core::view::reshape_view(self, &dims) {
1541 Ok(v) => v,
1542 Err(e) => panic!("view reshape error: {:?}", e),
1543 };
1544 if self.requires_grad() && gradtrack::is_grad_enabled() {
1545 v.set_requires_grad(true);
1546 let grad_fn = GradFn::Reshape {
1547 original_shape: self.shape().dims().to_vec(),
1548 };
1549 v.set_grad_fn(grad_fn.clone());
1550 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1551 }
1552 v
1553 }
1554
1555 /// Create an element view for the specified index
1556 ///
1557 /// Returns a scalar tensor (shape \[1\]) that views a single element
1558 /// of the source tensor. Maintains gradient tracking.
1559 ///
1560 /// # Arguments
1561 ///
1562 /// * `index` - Linear index of the element to view
1563 ///
1564 /// # Returns
1565 ///
1566 /// A scalar tensor viewing the specified element
1567 ///
1568 /// # Examples
1569 ///
1570 /// ```
1571 /// use train_station::Tensor;
1572 ///
1573 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1574 /// let element = tensor.element_view(1);
1575 /// assert_eq!(element.value(), 2.0);
1576 /// ```
1577 #[track_caller]
1578 pub fn element_view(&self, index: usize) -> Tensor {
1579 let mut v = match crate::tensor::core::view::element_view_linear(self, index) {
1580 Ok(v) => v,
1581 Err(e) => panic!("element_view error: {:?}", e),
1582 };
1583 if self.requires_grad() && gradtrack::is_grad_enabled() {
1584 v.set_requires_grad(true);
1585 // Reuse SliceView grad for single-element selection to propagate back to source
1586 let grad_fn = GradFn::View {
1587 mapping: crate::gradtrack::grad_fn::ViewMapping::LinearRange {
1588 start: index,
1589 step: 1,
1590 length: 1,
1591 },
1592 input_shape: self.shape().dims().to_vec(),
1593 };
1594 v.set_grad_fn(grad_fn.clone());
1595 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1596 }
1597 v
1598 }
1599
1600 /// Create a slice view of the tensor
1601 ///
1602 /// Returns a view of a contiguous or strided slice of the source tensor.
1603 ///
1604 /// # Arguments
1605 ///
1606 /// * `start` - Starting index
1607 /// * `step` - Step size (1 for contiguous)
1608 /// * `length` - Number of elements
1609 ///
1610 /// # Returns
1611 ///
1612 /// A tensor viewing the specified slice
1613 ///
1614 /// # Examples
1615 ///
1616 /// ```
1617 /// use train_station::Tensor;
1618 ///
1619 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1620 /// let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
1621 /// assert_eq!(slice.get(&[0]), 2.0);
1622 /// assert_eq!(slice.get(&[1]), 4.0);
1623 /// ```
1624 #[track_caller]
1625 pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor {
1626 let mut v = match crate::tensor::core::view::slice_view_linear(self, start, step, length) {
1627 Ok(v) => v,
1628 Err(e) => panic!("slice_view error: {:?}", e),
1629 };
1630 // Ensure correct data() representation for stepped views
1631 // Offset-only (step==1, start>0) is already contiguous via base pointer
1632 if step != 1 {
1633 v = v.contiguous();
1634 }
1635 if self.requires_grad() && gradtrack::is_grad_enabled() {
1636 v.set_requires_grad(true);
1637 let grad_fn = GradFn::View {
1638 mapping: crate::gradtrack::grad_fn::ViewMapping::LinearRange {
1639 start,
1640 step,
1641 length,
1642 },
1643 input_shape: self.shape().dims().to_vec(),
1644 };
1645 v.set_grad_fn(grad_fn.clone());
1646 GradEngine::register_operation(v.id(), vec![self.id()], grad_fn);
1647 }
1648 v
1649 }
1650
1651 /// Create a tensor view from raw components
1652 ///
1653 /// Creates a tensor that views existing memory with the specified shape and device.
1654 /// The tensor shares memory with the original allocation through the allocation_owner.
1655 ///
1656 /// # Safety
1657 ///
1658 /// The caller must ensure:
1659 /// - `data` is valid for the number of elements specified by `shape`
1660 /// - `data` remains valid for the lifetime of the returned tensor
1661 /// - `allocation_owner` properly manages the memory lifecycle
1662 ///
1663 /// # Arguments
1664 ///
1665 /// * `data` - Raw pointer to the tensor data
1666 /// * `shape` - Shape of the tensor view
1667 /// * `device` - Device where the tensor is located
1668 /// * `allocation_owner` - Optional shared allocation owner
1669 ///
1670 /// # Returns
1671 ///
1672 /// A new tensor that views the specified memory
1673 ///
1674 /// # Implementation Details
1675 ///
1676 /// This method is used internally to create tensor views that share memory
1677 /// with other tensors. It's primarily used for view operations and memory
1678 /// management.
1679 pub(crate) fn from_raw_view(
1680 data: *const f32,
1681 shape: Shape,
1682 device: Device,
1683 allocation_owner: Option<Arc<Allocation>>,
1684 ) -> Self {
1685 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1686
1687 Self {
1688 data: NonNull::new(data as *mut f32).expect("Data pointer cannot be null"),
1689 shape,
1690 device,
1691 id,
1692 requires_grad: false,
1693 retain_grad: false,
1694 grad: None,
1695 grad_fn: GradFn::None,
1696 allocation_owner,
1697 graph_group: None,
1698 _phantom: PhantomData,
1699 }
1700 }
1701
1702 /// Get the allocation owner for this tensor
1703 ///
1704 /// Returns the shared allocation owner if this tensor is a view,
1705 /// or None if this tensor owns its memory directly.
1706 ///
1707 /// # Returns
1708 ///
1709 /// Optional reference to the allocation owner
1710 ///
1711 /// # Implementation Details
1712 ///
1713 /// This method is used internally to manage memory lifecycle for tensor views.
1714 /// It helps determine whether a tensor shares memory with another tensor.
1715 #[track_caller]
1716 pub fn allocation_owner(&self) -> Option<&Arc<Allocation>> {
1717 self.allocation_owner.as_ref()
1718 }
1719
1720 /// Create a new tensor with uninitialized memory
1721 ///
1722 /// This method allocates memory for a tensor without initializing it to any value.
1723 /// This is useful for performance-critical operations where the memory will be
1724 /// immediately overwritten, such as matrix multiplication results.
1725 ///
1726 /// # Safety
1727 ///
1728 /// The caller must ensure that all memory is written before reading from the tensor.
1729 /// Reading from uninitialized memory is undefined behavior.
1730 ///
1731 /// # Arguments
1732 ///
1733 /// * `shape_dims` - The dimensions of the tensor
1734 ///
1735 /// # Returns
1736 ///
1737 /// A tensor with uninitialized memory
1738 ///
1739 /// # Performance
1740 ///
1741 /// - **Zero Initialization**: Skips memory initialization for maximum performance
1742 /// - **SIMD Ready**: Properly aligned for vectorized operations
1743 /// - **Memory Efficient**: Uses optimized alignment strategies
1744 ///
1745 /// # Example
1746 ///
1747 /// ```
1748 /// use train_station::Tensor;
1749 ///
1750 /// // Create uninitialized tensor for matmul result
1751 /// let mut result = Tensor::new_uninitialized(vec![100, 100]);
1752 /// // Initialize the memory before use
1753 /// for value in result.data_mut() {
1754 /// *value = 0.0;
1755 /// }
1756 /// ```
1757 #[inline]
1758 #[track_caller]
1759 pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self {
1760 let shape = Shape::new(shape_dims);
1761 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1762
1763 if shape.size() == 0 {
1764 // Handle zero-sized tensors
1765 return Self {
1766 data: NonNull::dangling(),
1767 shape,
1768 device: current_device(),
1769 id,
1770 requires_grad: false,
1771 retain_grad: false,
1772 grad: None,
1773 grad_fn: GradFn::None,
1774 allocation_owner: None,
1775 graph_group: None,
1776 _phantom: PhantomData,
1777 };
1778 }
1779
1780 let (alignment, padded_elems) = compute_allocation_params(shape.size());
1781 let total_size = padded_elems * std::mem::size_of::<f32>();
1782 let layout = Layout::from_size_align(total_size, alignment)
1783 .expect("Failed to create layout for tensor data");
1784
1785 // Allocate memory via shared Allocation (uninitialized)
1786 let alloc_obj = if use_pool_alloc_enabled() {
1787 Allocation::new_pooled(padded_elems, alignment, layout)
1788 } else {
1789 Allocation::new_uninitialized(padded_elems, alignment, layout)
1790 };
1791 let ptr = alloc_obj.ptr;
1792
1793 debug_assert!(
1794 alloc_obj.capacity_elems() >= padded_elems,
1795 "Allocation capacity ({}) smaller than padded elements ({})",
1796 alloc_obj.capacity_elems(),
1797 padded_elems
1798 );
1799
1800 Self {
1801 data: ptr,
1802 shape,
1803 device: current_device(),
1804 id,
1805 requires_grad: false,
1806 retain_grad: false,
1807 grad: None,
1808 grad_fn: GradFn::None,
1809 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1810 graph_group: None,
1811 _phantom: PhantomData,
1812 }
1813 }
1814
1815 /// Create a new uninitialized tensor with an explicit alignment request (in bytes)
1816 ///
1817 /// This is intended for internal high-performance paths (e.g., packed GEMM panels)
1818 /// where stronger alignment such as 64 bytes is desired even on AVX2 systems.
1819 #[inline]
1820 #[track_caller]
1821 pub fn new_uninitialized_aligned(shape_dims: Vec<usize>, alignment_bytes: usize) -> Self {
1822 let shape = Shape::new(shape_dims);
1823 let id = TENSOR_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
1824
1825 if shape.size() == 0 {
1826 return Self {
1827 data: NonNull::dangling(),
1828 shape,
1829 device: current_device(),
1830 id,
1831 requires_grad: false,
1832 retain_grad: false,
1833 grad: None,
1834 grad_fn: GradFn::None,
1835 allocation_owner: None,
1836 graph_group: None,
1837 _phantom: PhantomData,
1838 };
1839 }
1840
1841 // Preserve the usual element padding policy
1842 let (_default_align, padded_elems) = compute_allocation_params(shape.size());
1843 // Honor explicit alignment request (at least 16)
1844 let alignment = alignment_bytes.max(16);
1845 let total_size = padded_elems * std::mem::size_of::<f32>();
1846 let layout = Layout::from_size_align(total_size, alignment)
1847 .expect("Failed to create layout for tensor data (aligned)");
1848
1849 let alloc_obj = if use_pool_alloc_enabled() {
1850 Allocation::new_pooled(padded_elems, alignment, layout)
1851 } else {
1852 Allocation::new_uninitialized(padded_elems, alignment, layout)
1853 };
1854 let ptr = alloc_obj.ptr;
1855
1856 debug_assert!(alloc_obj.capacity_elems() >= padded_elems);
1857
1858 Self {
1859 data: ptr,
1860 shape,
1861 device: current_device(),
1862 id,
1863 requires_grad: false,
1864 retain_grad: false,
1865 grad: None,
1866 grad_fn: GradFn::None,
1867 allocation_owner: Some(std::sync::Arc::new(alloc_obj)),
1868 graph_group: None,
1869 _phantom: PhantomData,
1870 }
1871 }
1872}
1873
1874// -------------------------------------------------------------------------------------------------
1875// Fast-path routing helper for iterator and collect APIs
1876// -------------------------------------------------------------------------------------------------
1877
1878/// Determine whether to use the no-grad fast path for iterator/collect operations.
1879///
1880/// Fast path is selected when either:
1881/// - No input requires gradients, or
1882/// - Global grad tracking is disabled for the current thread
1883#[inline]
1884pub(crate) fn should_use_fast_path(inputs: &[&Tensor]) -> bool {
1885 let any_requires_grad = inputs.iter().any(|t| t.requires_grad());
1886 !any_requires_grad || !crate::gradtrack::is_grad_enabled()
1887}
1888
#[cfg(test)]
mod memory_alloc_tests {
    use super::*;
    use crate::tensor::core::memory::{
        detect_runtime_simd, simd_alignment_bytes, simd_lane_width_elems, with_no_mem_padding,
        TensorMemoryPool,
    };

    /// Allocations are padded to a SIMD-lane multiple and aligned for the
    /// runtime SIMD level; a non-lane-multiple logical size should not prefer
    /// aligned SIMD ops.
    #[test]
    fn test_padding_and_alignment_pool_enabled() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let align = simd_alignment_bytes(detect_runtime_simd());

        let req = lane * 3 + 1; // force padding
        let t = Tensor::new(vec![req]);

        // Capacity is rounded up to a whole number of SIMD lanes.
        assert_eq!(t.capacity_elems() % lane, 0);
        unsafe {
            assert_eq!((t.as_ptr() as usize) % align, 0);
        }
        // Logical size is not a multiple of lane, so aligned SIMD ops are not preferred
        #[cfg(target_arch = "x86_64")]
        assert!(!t.prefer_aligned_simd_ops());

        drop(t);
    }

    /// Under the no-padding guard, logical size is allocated exactly and the
    /// aligned-SIMD preference is off for non-lane-multiple sizes.
    #[test]
    fn test_no_padding_with_guard() {
        with_no_mem_padding(|| {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            let req = lane * 3 + 1; // non-multiple
            let t = Tensor::new(vec![req]);

            assert_eq!(t.size(), req);
            #[cfg(target_arch = "x86_64")]
            assert!(!t.prefer_aligned_simd_ops());
        });
    }

    /// Pool counters are per-thread; this relies on the test running on a
    /// fresh thread so `before.allocations` is 0 (the default harness runs
    /// each test on its own thread — NOTE(review): would break under
    /// `--test-threads=1` style reuse; confirm if the harness changes).
    #[test]
    fn test_system_alloc_no_pool_counters_unchanged() {
        let before = TensorMemoryPool::thread_stats();
        let _t1 = Tensor::new(vec![128]);
        let _t2 = Tensor::new(vec![257]);
        let after = TensorMemoryPool::thread_stats();
        assert_eq!(before.allocations, 0);
        assert_eq!(after.allocations, 2);
        assert_eq!(before.deallocations, 0);
    }

    /// Three tensors created and dropped inside a scope should register at
    /// least three allocations and three deallocations in the pool stats.
    #[test]
    fn test_pool_alloc_and_dealloc_counters_match() {
        let before = TensorMemoryPool::thread_stats();
        {
            let _t1 = Tensor::new(vec![64]);
            let _t2 = Tensor::new(vec![2048]);
            let _t3 = Tensor::new(vec![131072]);
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(after.allocations >= before.allocations + 3);
        assert!(after.deallocations >= before.deallocations + 3);
    }

    /// Every allocation made inside the scope must be matched by a
    /// deallocation once the scope ends (no leak in the counters).
    /// NOTE(review): the second `_c` shadows the first; both tensors still
    /// live to end of scope, contributing two alloc/dealloc pairs.
    #[test]
    fn test_mixed_modes_no_leak_in_pool_stats_scope() {
        let before = TensorMemoryPool::thread_stats();
        {
            let _a = Tensor::new(vec![1000]);
            let _b = Tensor::new(vec![1000]);
            let _c = Tensor::new(vec![2000]);
            let _c = Tensor::new(vec![2000]);
        }

        let after = TensorMemoryPool::thread_stats();
        assert_eq!(
            after.allocations - before.allocations,
            after.deallocations - before.deallocations
        );
    }

    /// Exact lane-multiple sizes should report runtime-level alignment and
    /// prefer aligned SIMD ops; with padding disabled and a non-multiple
    /// size, the unaligned path should be chosen instead.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_alignment_helpers_and_hints() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let req = lane * 4; // exact multiple
        let t = Tensor::new(vec![req]);
        assert!(t.is_aligned_for_runtime_level());
        assert!(t.prefer_aligned_simd_ops());
        assert!(!t.should_use_unaligned_simd_ops());

        with_no_mem_padding(|| {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            let req = lane * 4 + 1; // not multiple
            let t = Tensor::new(vec![req]);
            assert!(!t.prefer_aligned_simd_ops());
            assert!(t.should_use_unaligned_simd_ops());
        });
    }
}
1989
1990#[cfg(test)]
1991mod memory_alloc_additional_tests {
1992 use super::*;
1993 use crate::tensor::core::memory::{
1994 detect_runtime_simd, simd_lane_width_elems, with_no_mem_padding, with_no_mem_pool,
1995 TensorMemoryPool,
1996 };
1997 use std::time::Instant;
1998
    /// Zero-element tensors allocate nothing: size, capacity, and data slice
    /// are all empty. NOTE(review): the two identical constructions presumably
    /// exercise pool and system allocation modes in turn — confirm against the
    /// allocation-mode toggling this module relies on.
    #[test]
    fn test_zero_size_tensors_pool_and_system() {
        let t = Tensor::new(vec![0]);
        assert_eq!(t.size(), 0);
        assert_eq!(t.capacity_elems(), 0);
        assert_eq!(t.data().len(), 0);
        let t = Tensor::new(vec![0]);
        assert_eq!(t.size(), 0);
        assert_eq!(t.capacity_elems(), 0);
        assert_eq!(t.data().len(), 0);
    }
2010
2011 #[test]
2012 fn test_capacity_across_various_sizes_padding_modes() {
2013 let lane = simd_lane_width_elems(detect_runtime_simd());
2014 let sizes = [
2015 0,
2016 1,
2017 lane - 1,
2018 lane,
2019 lane + 1,
2020 2 * lane - 1,
2021 2 * lane + 1,
2022 1000,
2023 1001,
2024 ];
2025
2026 for &n in &sizes {
2027 let t = Tensor::new(vec![n]);
2028 let expected_padded = if n == 0 { 0 } else { n.div_ceil(lane) * lane };
2029 // Pool may round up further than lane padding; ensure at least padded and lane-aligned
2030 assert!(
2031 t.capacity_elems() >= expected_padded,
2032 "n={} lane={} cap={}",
2033 n,
2034 lane,
2035 t.capacity_elems()
2036 );
2037 if t.capacity_elems() > 0 {
2038 assert_eq!(
2039 t.capacity_elems() % lane,
2040 0,
2041 "capacity not lane-multiple: {}",
2042 t.capacity_elems()
2043 );
2044 }
2045 }
2046 with_no_mem_padding(|| {
2047 for &n in &sizes {
2048 let t = Tensor::new(vec![n]);
2049 assert_eq!(t.size(), n);
2050 }
2051 });
2052 }
2053
2054 #[test]
2055 fn test_class_boundary_planned_capacity_pool() {
2056 let boundaries = [
2057 crate::tensor::core::memory::SMALL_BUFFER_SIZE,
2058 crate::tensor::core::memory::SMALL_BUFFER_SIZE + 1,
2059 crate::tensor::core::memory::MEDIUM_BUFFER_SIZE,
2060 crate::tensor::core::memory::MEDIUM_BUFFER_SIZE + 1,
2061 crate::tensor::core::memory::LARGE_BUFFER_SIZE,
2062 crate::tensor::core::memory::LARGE_BUFFER_SIZE + 1,
2063 ];
2064 let lane = simd_lane_width_elems(detect_runtime_simd());
2065 for &n in &boundaries {
2066 let padded = if n == 0 { 0 } else { n.div_ceil(lane) * lane };
2067 let planned = TensorMemoryPool::planned_capacity_elems(padded);
2068 let t = Tensor::new(vec![n]);
2069 assert_eq!(t.capacity_elems(), planned, "boundary n={}", n);
2070 }
2071 }
2072
2073 #[test]
2074 fn test_view_does_not_allocate_additional_memory() {
2075 let before = TensorMemoryPool::thread_stats();
2076 let t = Tensor::new(vec![256]);
2077 let _view = t.view(vec![16, 16]);
2078 let after = TensorMemoryPool::thread_stats();
2079 // Exactly one allocation happened (for t), view didn't allocate
2080 assert_eq!(after.allocations, before.allocations + 1);
2081 }
2082
2083 #[test]
2084 fn test_performance_pool_vs_system_small_allocs() {
2085 let iters = 1000;
2086 let _ = Tensor::new(vec![64]);
2087 let _ = Tensor::new(vec![64]);
2088 let _ = Tensor::new(vec![64]);
2089
2090 let pool_time = {
2091 let start = Instant::now();
2092 for _ in 0..iters {
2093 let _t = Tensor::new(vec![64]);
2094 }
2095 start.elapsed()
2096 };
2097 let sys_time = with_no_mem_pool(|| {
2098 let start = Instant::now();
2099 for _ in 0..iters {
2100 let _t = Tensor::new(vec![64]);
2101 }
2102 start.elapsed()
2103 });
2104 assert!(
2105 pool_time <= sys_time * 10,
2106 "pool {:?} vs system {:?}",
2107 pool_time,
2108 sys_time
2109 );
2110 }
2111
2112 #[test]
2113 fn test_performance_padded_vs_unpadded_fill() {
2114 let lane = simd_lane_width_elems(detect_runtime_simd());
2115 let n = lane * 2048;
2116
2117 let padded_time = {
2118 let t = Tensor::new(vec![n]);
2119 let mut x = t.clone();
2120 let start = Instant::now();
2121 x.fill(1.2345);
2122 start.elapsed()
2123 };
2124 let unpadded_time = with_no_mem_padding(|| {
2125 let t = Tensor::new(vec![n + 1]);
2126 let mut x = t.clone();
2127 let start = Instant::now();
2128 x.fill(1.2345);
2129 start.elapsed()
2130 });
2131 assert!(
2132 padded_time <= unpadded_time * 100,
2133 "padded {:?} vs unpadded {:?}",
2134 padded_time,
2135 unpadded_time
2136 );
2137 }
2138}