// train_station/tensor/init/basic.rs

1//! Basic tensor initialization methods
2//!
3//! This module provides fundamental tensor initialization operations for creating
4//! tensors with specific constant values. All methods are optimized with SIMD
5//! operations for maximum performance on large tensors.
6//!
7//! # Key Features
8//!
9//! - **`zeros`**: Create tensors filled with zeros
10//! - **`ones`**: Create tensors filled with ones  
11//! - **`fill`**: Fill existing tensors with a constant value
12//! - **Device-aware initialization**: Create tensors on specific devices
13//! - **SIMD optimization**: Vectorized operations for large tensors
14//! - **Thread safety**: All operations are thread-safe
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Allocation**: Single allocation with optimized alignment
19//! - **SIMD Operations**: AVX2-optimized filling for large tensors
20//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
21//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
22//! - **Zero-sized Handling**: Efficient handling of empty tensors
23//!
24//! # Examples
25//!
26//! ## Basic Initialization
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors with different constant values
32//! let zeros = Tensor::zeros(vec![2, 3]);
33//! let ones = Tensor::ones(vec![2, 3]);
34//! let mut filled = Tensor::new(vec![2, 3]);
35//! filled.fill(42.0);
36//!
37//! assert_eq!(zeros.shape().dims, vec![2, 3]);
38//! assert_eq!(ones.shape().dims, vec![2, 3]);
39//! assert_eq!(filled.shape().dims, vec![2, 3]);
40//!
41//! // Verify initialization values
42//! assert_eq!(zeros.get(&[0, 0]), 0.0);
43//! assert_eq!(ones.get(&[0, 0]), 1.0);
44//! assert_eq!(filled.get(&[0, 0]), 42.0);
45//! ```
46//!
47//! ## Device-Specific Initialization
48//!
49//! ```
50//! use train_station::Tensor;
51//! use train_station::Device;
52//!
53//! // Create tensors on specific devices
54//! let cpu_zeros = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
55//! let cpu_ones = Tensor::ones_on_device(vec![2, 2], Device::cpu());
56//!
57//! assert_eq!(cpu_zeros.device(), Device::cpu());
58//! assert_eq!(cpu_ones.device(), Device::cpu());
59//! assert_eq!(cpu_zeros.size(), 4);
60//! assert_eq!(cpu_ones.size(), 4);
61//!
62//! // Verify device-specific initialization
63//! assert_eq!(cpu_zeros.get(&[0, 0]), 0.0);
64//! assert_eq!(cpu_ones.get(&[0, 0]), 1.0);
65//! ```
66//!
67//! ## Fill Operations
68//!
69//! ```
70//! use train_station::Tensor;
71//!
72//! // Fill existing tensors with constant values
73//! let mut tensor = Tensor::new(vec![3, 3]);
74//! tensor.fill(3.14159);
75//!
76//! // Verify all elements are filled with the specified value
77//! for i in 0..tensor.size() {
78//!     assert!((tensor.get(&[i / 3, i % 3]) - 3.14159).abs() < 1e-6);
79//! }
80//! ```
81//!
82//! ## Zero-Sized Tensor Handling
83//!
84//! ```
85//! use train_station::Tensor;
86//!
87//! // Handle zero-sized tensors gracefully
88//! let mut empty_tensor = Tensor::new(vec![0]);
89//! empty_tensor.fill(42.0); // Should not panic
90//! assert_eq!(empty_tensor.size(), 0);
91//! ```
92//!
93//! # Design Principles
94//!
95//! - **Performance First**: SIMD-optimized operations for maximum speed
96//! - **Memory Safety**: Safe operations with proper bounds checking
97//! - **Device Abstraction**: Unified interface for CPU and future GPU operations
98//! - **Zero-Cost Abstractions**: Minimal overhead for initialization operations
99//! - **Thread Safety**: All operations are safe for concurrent access
100
101use crate::tensor::core::Tensor;
102
103impl Tensor {
104    /// Creates a new tensor filled with zeros
105    ///
106    /// Convenience constructor that creates a tensor and initializes all elements
107    /// to zero. Uses optimized SIMD operations for efficient zero initialization.
108    ///
109    /// # Arguments
110    ///
111    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
112    ///
113    /// # Returns
114    ///
115    /// A new tensor with all elements initialized to zero
116    ///
117    /// # Performance
118    ///
119    /// - **Memory Allocation**: Single allocation with optimized alignment
120    /// - **Initialization**: SIMD-optimized zero filling for large tensors
121    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use train_station::Tensor;
127    ///
128    /// let tensor = Tensor::zeros(vec![2, 3]);
129    /// assert_eq!(tensor.size(), 6);
130    /// assert_eq!(tensor.shape().dims, vec![2, 3]);
131    ///
132    /// // Verify all elements are zero
133    /// assert_eq!(tensor.get(&[0, 0]), 0.0);
134    /// assert_eq!(tensor.get(&[1, 2]), 0.0);
135    /// ```
136    #[inline]
137    pub fn zeros(shape_dims: Vec<usize>) -> Self {
138        let mut tensor = Self::new(shape_dims);
139        tensor.fill(0.0);
140        tensor
141    }
142
143    /// Creates a new tensor filled with ones
144    ///
145    /// Convenience constructor that creates a tensor and initializes all elements
146    /// to one. Uses optimized SIMD operations for efficient initialization.
147    ///
148    /// # Arguments
149    ///
150    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
151    ///
152    /// # Returns
153    ///
154    /// A new tensor with all elements initialized to one
155    ///
156    /// # Performance
157    ///
158    /// - **Memory Allocation**: Single allocation with optimized alignment
159    /// - **Initialization**: SIMD-optimized one filling for large tensors
160    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
161    ///
162    /// # Examples
163    ///
164    /// ```
165    /// use train_station::Tensor;
166    ///
167    /// let tensor = Tensor::ones(vec![2, 3]);
168    /// assert_eq!(tensor.size(), 6);
169    /// assert_eq!(tensor.shape().dims, vec![2, 3]);
170    ///
171    /// // Verify all elements are one
172    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
173    /// assert_eq!(tensor.get(&[1, 2]), 1.0);
174    /// ```
175    #[inline]
176    pub fn ones(shape_dims: Vec<usize>) -> Self {
177        let mut tensor = Self::new(shape_dims);
178        tensor.fill(1.0);
179        tensor
180    }
181
182    /// Creates a new tensor filled with zeros on a specific device
183    ///
184    /// Convenience constructor that creates a tensor on the specified device
185    /// and initializes all elements to zero. Uses optimized SIMD operations
186    /// for efficient zero initialization.
187    ///
188    /// # Arguments
189    ///
190    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
191    /// * `device` - The device where the tensor should be allocated
192    ///
193    /// # Returns
194    ///
195    /// A new tensor with all elements initialized to zero
196    ///
197    /// # Performance
198    ///
199    /// - **Memory Allocation**: Device-specific allocation with optimized alignment
200    /// - **Initialization**: SIMD-optimized zero filling for large tensors
201    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
202    ///
203    /// # Examples
204    ///
205    /// ```
206    /// use train_station::Tensor;
207    /// use train_station::Device;
208    ///
209    /// let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
210    /// assert_eq!(tensor.device(), Device::cpu());
211    /// assert_eq!(tensor.size(), 4);
212    ///
213    /// // Verify all elements are zero
214    /// assert_eq!(tensor.get(&[0, 0]), 0.0);
215    /// assert_eq!(tensor.get(&[1, 1]), 0.0);
216    /// ```
217    #[inline]
218    pub fn zeros_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
219        let mut tensor = Self::new_on_device(shape_dims, device);
220        tensor.fill(0.0);
221        tensor
222    }
223
224    /// Creates a new tensor filled with ones on a specific device
225    ///
226    /// Convenience constructor that creates a tensor on the specified device
227    /// and initializes all elements to one. Uses optimized SIMD operations
228    /// for efficient initialization.
229    ///
230    /// # Arguments
231    ///
232    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
233    /// * `device` - The device where the tensor should be allocated
234    ///
235    /// # Returns
236    ///
237    /// A new tensor with all elements initialized to one
238    ///
239    /// # Performance
240    ///
241    /// - **Memory Allocation**: Device-specific allocation with optimized alignment
242    /// - **Initialization**: SIMD-optimized one filling for large tensors
243    /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
244    ///
245    /// # Examples
246    ///
247    /// ```
248    /// use train_station::Tensor;
249    /// use train_station::Device;
250    ///
251    /// let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
252    /// assert_eq!(tensor.device(), Device::cpu());
253    /// assert_eq!(tensor.size(), 4);
254    ///
255    /// // Verify all elements are one
256    /// assert_eq!(tensor.get(&[0, 0]), 1.0);
257    /// assert_eq!(tensor.get(&[1, 1]), 1.0);
258    /// ```
259    #[inline]
260    pub fn ones_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
261        let mut tensor = Self::new_on_device(shape_dims, device);
262        tensor.fill(1.0);
263        tensor
264    }
265
266    /// Fills the tensor with a constant value using SIMD optimization
267    ///
268    /// Efficiently initializes all elements of the tensor to the specified value.
269    /// Uses SIMD operations for large tensors to maximize performance.
270    ///
271    /// # Arguments
272    ///
273    /// * `value` - The value to fill the tensor with
274    ///
275    /// # Performance
276    ///
277    /// - **SIMD Optimization**: Uses AVX2 for large tensors when available
278    /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
279    /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
280    ///
281    /// # Examples
282    ///
283    /// ```
284    /// use train_station::Tensor;
285    ///
286    /// let mut tensor = Tensor::new(vec![2, 3]);
287    /// tensor.fill(42.0);
288    ///
289    /// // Verify all elements are 42.0
290    /// assert_eq!(tensor.get(&[0, 0]), 42.0);
291    /// assert_eq!(tensor.get(&[1, 2]), 42.0);
292    /// ```
293    ///
294    /// ## Zero-Sized Tensor Handling
295    ///
296    /// ```
297    /// use train_station::Tensor;
298    ///
299    /// let mut empty_tensor = Tensor::new(vec![0]);
300    /// empty_tensor.fill(42.0); // Should not panic
301    /// assert_eq!(empty_tensor.size(), 0);
302    /// ```
303    #[inline]
304    pub fn fill(&mut self, value: f32) {
305        if self.shape().size == 0 {
306            return;
307        }
308
309        unsafe {
310            let ptr = self.as_mut_ptr();
311
312            #[cfg(target_arch = "x86_64")]
313            {
314                // Use SIMD for better performance when available
315                if is_x86_feature_detected!("avx2") {
316                    self.fill_simd_avx2(ptr, value);
317                    return;
318                }
319            }
320
321            // Fallback to scalar operations
322            for i in 0..self.shape().size {
323                *ptr.add(i) = value;
324            }
325        }
326    }
327
328    /// Fills the tensor with a constant value using AVX2 SIMD optimization
329    ///
330    /// Internal method that uses AVX2 instructions to efficiently fill large tensors.
331    /// Processes 32 elements per iteration with 4x unrolling for maximum memory bandwidth.
332    ///
333    /// # Arguments
334    ///
335    /// * `ptr` - Mutable pointer to the tensor data
336    /// * `value` - The value to fill the tensor with
337    ///
338    /// # Safety
339    ///
340    /// The caller must ensure:
341    /// * `ptr` is a valid pointer to tensor data
342    /// * The tensor size matches the allocated memory
343    /// * AVX2 is available on the target architecture
344    ///
345    /// # Performance
346    ///
347    /// - **SIMD Operations**: 32 elements per iteration using AVX2
348    /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
349    /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
350    /// - **Remaining Elements**: Efficient handling of non-multiple-of-32 sizes
351    ///
352    /// # Implementation Details
353    ///
354    /// This method uses AVX2 SIMD instructions to fill memory efficiently:
355    /// 1. Creates a vector of 8 identical values using `_mm256_set1_ps`
356    /// 2. Processes 32 elements per iteration (4x unrolled)
357    /// 3. Handles remaining 8-element blocks
358    /// 4. Fills final elements with scalar operations
359    #[cfg(target_arch = "x86_64")]
360    #[inline]
361    unsafe fn fill_simd_avx2(&self, ptr: *mut f32, value: f32) {
362        let mut_ptr = ptr;
363        let value_vec = _mm256_set1_ps(value);
364        let size = self.shape().size;
365        let simd_count = size / 32; // Process 32 elements per iteration
366        let mut offset = 0;
367
368        // Unrolled SIMD fill for better memory bandwidth utilization
369        for _ in 0..simd_count {
370            _mm256_store_ps(mut_ptr.add(offset), value_vec);
371            _mm256_store_ps(mut_ptr.add(offset + 8), value_vec);
372            _mm256_store_ps(mut_ptr.add(offset + 16), value_vec);
373            _mm256_store_ps(mut_ptr.add(offset + 24), value_vec);
374            offset += 32;
375        }
376
377        // Handle remaining 8-element blocks
378        let remaining_full_blocks = (size - offset) / 8;
379        for _ in 0..remaining_full_blocks {
380            _mm256_store_ps(mut_ptr.add(offset), value_vec);
381            offset += 8;
382        }
383
384        // Handle final elements
385        for i in offset..size {
386            *mut_ptr.add(i) = value;
387        }
388    }
389}
390
391// SIMD optimizations for performance-critical operations
392#[cfg(target_arch = "x86_64")]
393use std::arch::x86_64::*;
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::Device;

    /// Asserts that every element of `tensor` lies within `tol` of `expected`.
    /// A `tol` of 0.0 demands exact equality.
    fn assert_all_eq(tensor: &Tensor, expected: f32, tol: f32) {
        for i in 0..tensor.size() {
            // SAFETY: `i` is bounded by `tensor.size()`, so the read stays
            // inside the tensor's allocation.
            let actual = unsafe { *tensor.as_ptr().add(i) };
            assert!(
                (actual - expected).abs() <= tol,
                "element {} was {}, expected {}",
                i,
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_zeros_basic() {
        let tensor = Tensor::zeros(vec![2, 3]);
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);
        assert_all_eq(&tensor, 0.0, 0.0);
    }

    #[test]
    fn test_ones_basic() {
        let tensor = Tensor::ones(vec![2, 3]);
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);
        assert_all_eq(&tensor, 1.0, 0.0);
    }

    #[test]
    fn test_zeros_on_device() {
        let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 4);
        assert_all_eq(&tensor, 0.0, 0.0);
    }

    #[test]
    fn test_ones_on_device() {
        let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 4);
        assert_all_eq(&tensor, 1.0, 0.0);
    }

    #[test]
    fn test_fill_basic() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.fill(42.0);
        assert_all_eq(&tensor, 42.0, 0.0);
    }

    #[test]
    fn test_fill_zero_sized() {
        let mut tensor = Tensor::new(vec![0]);
        tensor.fill(42.0); // must be a no-op rather than a panic
        assert_eq!(tensor.size(), 0);
    }

    #[test]
    fn test_fill_large_tensor() {
        // 10_000 elements — large enough to exercise the unrolled SIMD path.
        let mut tensor = Tensor::new(vec![100, 100]);
        tensor.fill(std::f32::consts::PI);
        assert_all_eq(&tensor, std::f32::consts::PI, 1e-6);
    }
}