// train_station/tensor/init/basic.rs

1//! Basic tensor initialization methods
2//!
//! This module provides fundamental tensor initialization operations for creating
//! tensors with specific constant values. Every method funnels through `fill`,
//! which uses AVX2 vector stores on x86_64 CPUs that support them and a scalar
//! loop otherwise.
6//!
7//! # Key Features
8//!
9//! - **`zeros`**: Create tensors filled with zeros
10//! - **`ones`**: Create tensors filled with ones  
11//! - **`fill`**: Fill existing tensors with a constant value
12//! - **Device-aware initialization**: Create tensors on specific devices
//! - **SIMD optimization**: Vectorized filling when AVX2 is detected at runtime
14//! - **Thread safety**: All operations are thread-safe
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Allocation**: Single allocation with optimized alignment
//! - **SIMD Operations**: AVX2 filling on x86_64 for any non-empty tensor when detected at runtime, with a scalar fallback
20//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
21//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
22//! - **Zero-sized Handling**: Efficient handling of empty tensors
23//!
24//! # Examples
25//!
26//! ## Basic Initialization
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors with different constant values
32//! let zeros = Tensor::zeros(vec![2, 3]);
33//! let ones = Tensor::ones(vec![2, 3]);
34//! let mut filled = Tensor::new(vec![2, 3]);
35//! filled.fill(42.0);
36//!
37//! assert_eq!(zeros.shape().dims, vec![2, 3]);
38//! assert_eq!(ones.shape().dims, vec![2, 3]);
39//! assert_eq!(filled.shape().dims, vec![2, 3]);
40//!
41//! // Verify initialization values
42//! assert_eq!(zeros.get(&[0, 0]), 0.0);
43//! assert_eq!(ones.get(&[0, 0]), 1.0);
44//! assert_eq!(filled.get(&[0, 0]), 42.0);
45//! ```
46//!
47//! ## Device-Specific Initialization
48//!
49//! ```
50//! use train_station::Tensor;
51//! use train_station::Device;
52//!
53//! // Create tensors on specific devices
54//! let cpu_zeros = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
55//! let cpu_ones = Tensor::ones_on_device(vec![2, 2], Device::cpu());
56//!
57//! assert_eq!(cpu_zeros.device(), Device::cpu());
58//! assert_eq!(cpu_ones.device(), Device::cpu());
59//! assert_eq!(cpu_zeros.size(), 4);
60//! assert_eq!(cpu_ones.size(), 4);
61//!
62//! // Verify device-specific initialization
63//! assert_eq!(cpu_zeros.get(&[0, 0]), 0.0);
64//! assert_eq!(cpu_ones.get(&[0, 0]), 1.0);
65//! ```
66//!
67//! ## Fill Operations
68//!
69//! ```
70//! use train_station::Tensor;
71//!
72//! // Fill existing tensors with constant values
73//! let mut tensor = Tensor::new(vec![3, 3]);
74//! tensor.fill(3.14159);
75//!
76//! // Verify all elements are filled with the specified value
77//! for i in 0..tensor.size() {
78//!     assert!((tensor.get(&[i / 3, i % 3]) - 3.14159).abs() < 1e-6);
79//! }
80//! ```
81//!
82//! ## Zero-Sized Tensor Handling
83//!
84//! ```
85//! use train_station::Tensor;
86//!
87//! // Handle zero-sized tensors gracefully
88//! let mut empty_tensor = Tensor::new(vec![0]);
89//! empty_tensor.fill(42.0); // Should not panic
90//! assert_eq!(empty_tensor.size(), 0);
91//! ```
92//!
93//! # Design Principles
94//!
95//! - **Performance First**: SIMD-optimized operations for maximum speed
//! - **Memory Safety**: Writes are confined to the tensor's own `size()` elements
97//! - **Device Abstraction**: Unified interface for CPU and future GPU operations
98//! - **Zero-Cost Abstractions**: Minimal overhead for initialization operations
99//! - **Thread Safety**: All operations are safe for concurrent access
100
101use crate::tensor::core::Tensor;
102
103impl Tensor {
104    /// Creates a new tensor filled with zeros
105    ///
106    /// Convenience constructor that creates a tensor and initializes all elements
107    /// to zero. Uses optimized SIMD operations for efficient zero initialization.
108    ///
109    /// # Arguments
110    ///
111    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
112    ///
113    /// # Returns
114    ///
115    /// A new tensor with all elements initialized to zero
116    ///
117    /// # Performance
118    ///
119    /// - **Memory Allocation**: Single allocation with optimized alignment
120    /// - **Initialization**: SIMD-optimized zero filling for large tensors
121    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use train_station::Tensor;
127    ///
128    /// let tensor = Tensor::zeros(vec![2, 3]);
129    /// assert_eq!(tensor.size(), 6);
130    /// assert_eq!(tensor.shape().dims, vec![2, 3]);
131    ///
132    /// // Verify all elements are zero
133    /// assert_eq!(tensor.get(&[0, 0]), 0.0);
134    /// assert_eq!(tensor.get(&[1, 2]), 0.0);
135    /// ```
136    #[inline]
137    #[track_caller]
138    pub fn zeros(shape_dims: Vec<usize>) -> Self {
139        let mut tensor = Self::new(shape_dims);
140        tensor.fill(0.0);
141        tensor
142    }
143
144    /// Creates a new tensor filled with ones
145    ///
146    /// Convenience constructor that creates a tensor and initializes all elements
147    /// to one. Uses optimized SIMD operations for efficient initialization.
148    ///
149    /// # Arguments
150    ///
151    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
152    ///
153    /// # Returns
154    ///
155    /// A new tensor with all elements initialized to one
156    ///
157    /// # Performance
158    ///
159    /// - **Memory Allocation**: Single allocation with optimized alignment
160    /// - **Initialization**: SIMD-optimized one filling for large tensors
161    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
162    ///
163    /// # Examples
164    ///
165    /// ```
166    /// use train_station::Tensor;
167    ///
168    /// let tensor = Tensor::ones(vec![2, 3]);
169    /// assert_eq!(tensor.size(), 6);
170    /// assert_eq!(tensor.shape().dims, vec![2, 3]);
171    ///
172    /// // Verify all elements are one
173    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
174    /// assert_eq!(tensor.get(&[1, 2]), 1.0);
175    /// ```
176    #[inline]
177    #[track_caller]
178    pub fn ones(shape_dims: Vec<usize>) -> Self {
179        let mut tensor = Self::new(shape_dims);
180        tensor.fill(1.0);
181        tensor
182    }
183
184    /// Creates a new tensor filled with zeros on a specific device
185    ///
186    /// Convenience constructor that creates a tensor on the specified device
187    /// and initializes all elements to zero. Uses optimized SIMD operations
188    /// for efficient zero initialization.
189    ///
190    /// # Arguments
191    ///
192    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
193    /// * `device` - The device where the tensor should be allocated
194    ///
195    /// # Returns
196    ///
197    /// A new tensor with all elements initialized to zero
198    ///
199    /// # Performance
200    ///
201    /// - **Memory Allocation**: Device-specific allocation with optimized alignment
202    /// - **Initialization**: SIMD-optimized zero filling for large tensors
203    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
204    ///
205    /// # Examples
206    ///
207    /// ```
208    /// use train_station::Tensor;
209    /// use train_station::Device;
210    ///
211    /// let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
212    /// assert_eq!(tensor.device(), Device::cpu());
213    /// assert_eq!(tensor.size(), 4);
214    ///
215    /// // Verify all elements are zero
216    /// assert_eq!(tensor.get(&[0, 0]), 0.0);
217    /// assert_eq!(tensor.get(&[1, 1]), 0.0);
218    /// ```
219    #[inline]
220    #[track_caller]
221    pub fn zeros_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
222        let mut tensor = Self::new_on_device(shape_dims, device);
223        tensor.fill(0.0);
224        tensor
225    }
226
227    /// Creates a new tensor filled with ones on a specific device
228    ///
229    /// Convenience constructor that creates a tensor on the specified device
230    /// and initializes all elements to one. Uses optimized SIMD operations
231    /// for efficient initialization.
232    ///
233    /// # Arguments
234    ///
235    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
236    /// * `device` - The device where the tensor should be allocated
237    ///
238    /// # Returns
239    ///
240    /// A new tensor with all elements initialized to one
241    ///
242    /// # Performance
243    ///
244    /// - **Memory Allocation**: Device-specific allocation with optimized alignment
245    /// - **Initialization**: SIMD-optimized one filling for large tensors
246    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use train_station::Tensor;
252    /// use train_station::Device;
253    ///
254    /// let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
255    /// assert_eq!(tensor.device(), Device::cpu());
256    /// assert_eq!(tensor.size(), 4);
257    ///
258    /// // Verify all elements are one
259    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
260    /// assert_eq!(tensor.get(&[1, 1]), 1.0);
261    /// ```
262    #[inline]
263    #[track_caller]
264    pub fn ones_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
265        let mut tensor = Self::new_on_device(shape_dims, device);
266        tensor.fill(1.0);
267        tensor
268    }
269
270    /// Fills the tensor with a constant value using SIMD optimization
271    ///
272    /// Efficiently initializes all elements of the tensor to the specified value.
273    /// Uses SIMD operations for large tensors to maximize performance.
274    ///
275    /// # Arguments
276    ///
277    /// * `value` - The value to fill the tensor with
278    ///
279    /// # Performance
280    ///
281    /// - **SIMD Optimization**: Uses AVX2 for large tensors when available
282    /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
283    /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
284    ///
285    /// # Examples
286    ///
287    /// ```
288    /// use train_station::Tensor;
289    ///
290    /// let mut tensor = Tensor::new(vec![2, 3]);
291    /// tensor.fill(42.0);
292    ///
293    /// // Verify all elements are 42.0
294    /// assert_eq!(tensor.get(&[0, 0]), 42.0);
295    /// assert_eq!(tensor.get(&[1, 2]), 42.0);
296    /// ```
297    ///
298    /// ## Zero-Sized Tensor Handling
299    ///
300    /// ```
301    /// use train_station::Tensor;
302    ///
303    /// let mut empty_tensor = Tensor::new(vec![0]);
304    /// empty_tensor.fill(42.0); // Should not panic
305    /// assert_eq!(empty_tensor.size(), 0);
306    /// ```
307    #[inline]
308    #[track_caller]
309    pub fn fill(&mut self, value: f32) {
310        if self.shape().size == 0 {
311            return;
312        }
313
314        unsafe {
315            let ptr = self.as_mut_ptr();
316
317            #[cfg(target_arch = "x86_64")]
318            {
319                // Use SIMD for better performance when available
320                if is_x86_feature_detected!("avx2") {
321                    self.fill_simd_avx2(ptr, value);
322                    return;
323                }
324            }
325
326            // Fallback to scalar operations
327            for i in 0..self.shape().size {
328                *ptr.add(i) = value;
329            }
330        }
331    }
332
333    /// Fills the tensor with a constant value using AVX2 SIMD optimization
334    ///
335    /// Internal method that uses AVX2 instructions to efficiently fill large tensors.
336    /// Processes 32 elements per iteration with 4x unrolling for maximum memory bandwidth.
337    ///
338    /// # Arguments
339    ///
340    /// * `ptr` - Mutable pointer to the tensor data
341    /// * `value` - The value to fill the tensor with
342    ///
343    /// # Safety
344    ///
345    /// The caller must ensure:
346    /// * `ptr` is a valid pointer to tensor data
347    /// * The tensor size matches the allocated memory
348    /// * AVX2 is available on the target architecture
349    ///
350    /// # Performance
351    ///
352    /// - **SIMD Operations**: 32 elements per iteration using AVX2
353    /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
354    /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
355    /// - **Remaining Elements**: Efficient handling of non-multiple-of-32 sizes
356    ///
357    /// # Implementation Details
358    ///
359    /// This method uses AVX2 SIMD instructions to fill memory efficiently:
360    /// 1. Creates a vector of 8 identical values using `_mm256_set1_ps`
361    /// 2. Processes 32 elements per iteration (4x unrolled)
362    /// 3. Handles remaining 8-element blocks
363    /// 4. Fills final elements with scalar operations
364    #[cfg(target_arch = "x86_64")]
365    #[inline]
366    unsafe fn fill_simd_avx2(&self, ptr: *mut f32, value: f32) {
367        let mut_ptr = ptr;
368        let value_vec = _mm256_set1_ps(value);
369        let size = self.shape().size;
370        let simd_count = size / 32; // Process 32 elements per iteration
371        let mut offset = 0;
372
373        // Unrolled SIMD fill for better memory bandwidth utilization
374        for _ in 0..simd_count {
375            _mm256_store_ps(mut_ptr.add(offset), value_vec);
376            _mm256_store_ps(mut_ptr.add(offset + 8), value_vec);
377            _mm256_store_ps(mut_ptr.add(offset + 16), value_vec);
378            _mm256_store_ps(mut_ptr.add(offset + 24), value_vec);
379            offset += 32;
380        }
381
382        // Handle remaining 8-element blocks
383        let remaining_full_blocks = (size - offset) / 8;
384        for _ in 0..remaining_full_blocks {
385            _mm256_store_ps(mut_ptr.add(offset), value_vec);
386            offset += 8;
387        }
388
389        // Handle final elements
390        for i in offset..size {
391            *mut_ptr.add(i) = value;
392        }
393    }
394}
395
396// SIMD optimizations for performance-critical operations
397#[cfg(target_arch = "x86_64")]
398use std::arch::x86_64::*;
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every element of `tensor` is exactly `expected`.
    ///
    /// Reads through the raw data pointer so it checks the storage `fill`
    /// actually wrote, independent of any index-translation logic.
    fn assert_all_eq(tensor: &Tensor, expected: f32) {
        let len = tensor.size();
        for i in 0..len {
            // SAFETY: `i < len`, the tensor's element count.
            let actual = unsafe { *tensor.as_ptr().add(i) };
            assert_eq!(actual, expected, "element {i} of {len}");
        }
    }

    #[test]
    fn test_zeros_basic() {
        let tensor = Tensor::zeros(vec![2, 3]);
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);
        assert_all_eq(&tensor, 0.0);
    }

    #[test]
    fn test_ones_basic() {
        let tensor = Tensor::ones(vec![2, 3]);
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);
        assert_all_eq(&tensor, 1.0);
    }

    #[test]
    fn test_zeros_on_device() {
        use crate::device::Device;

        let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 4);
        assert_all_eq(&tensor, 0.0);
    }

    #[test]
    fn test_ones_on_device() {
        use crate::device::Device;

        let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 4);
        assert_all_eq(&tensor, 1.0);
    }

    #[test]
    fn test_fill_basic() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.fill(42.0);
        assert_all_eq(&tensor, 42.0);
    }

    #[test]
    fn test_fill_zero_sized() {
        let mut tensor = Tensor::new(vec![0]);
        // Should not panic
        tensor.fill(42.0);
        assert_eq!(tensor.size(), 0);
    }

    #[test]
    fn test_fill_large_tensor() {
        // 10_000 elements: on the AVX2 path this runs the 32-wide unrolled loop
        // and then two trailing 8-wide stores.
        let mut tensor = Tensor::new(vec![100, 100]);
        tensor.fill(std::f32::consts::PI);
        // `fill` stores the exact f32 bit pattern, so exact equality is valid.
        assert_all_eq(&tensor, std::f32::consts::PI);
    }

    #[test]
    fn test_fill_all_tail_lengths() {
        // Lengths straddling the 8- and 32-element boundaries, so every mix of
        // unrolled loop, 8-wide block, and scalar tail is exercised.
        for len in [1, 7, 8, 9, 15, 16, 24, 31, 32, 33, 40, 45, 63, 64, 65, 77] {
            let mut tensor = Tensor::new(vec![len]);
            tensor.fill(-2.5);
            assert_eq!(tensor.size(), len);
            assert_all_eq(&tensor, -2.5);
        }
    }

    #[test]
    fn test_fill_overwrites_existing_values() {
        let mut tensor = Tensor::ones(vec![3, 5]);
        tensor.fill(7.0);
        assert_all_eq(&tensor, 7.0);
    }
}