// train_station/tensor/init/basic.rs
1//! Basic tensor initialization methods
2//!
3//! This module provides fundamental tensor initialization operations for creating
4//! tensors with specific constant values. All methods are optimized with SIMD
5//! operations for maximum performance on large tensors.
6//!
7//! # Key Features
8//!
9//! - **`zeros`**: Create tensors filled with zeros
10//! - **`ones`**: Create tensors filled with ones
11//! - **`fill`**: Fill existing tensors with a constant value
12//! - **Device-aware initialization**: Create tensors on specific devices
13//! - **SIMD optimization**: Vectorized operations for large tensors
14//! - **Thread safety**: All operations are thread-safe
15//!
16//! # Performance Characteristics
17//!
18//! - **Memory Allocation**: Single allocation with optimized alignment
19//! - **SIMD Operations**: AVX2-optimized filling for large tensors
20//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
21//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
22//! - **Zero-sized Handling**: Efficient handling of empty tensors
23//!
24//! # Examples
25//!
26//! ## Basic Initialization
27//!
28//! ```
29//! use train_station::Tensor;
30//!
31//! // Create tensors with different constant values
32//! let zeros = Tensor::zeros(vec![2, 3]);
33//! let ones = Tensor::ones(vec![2, 3]);
34//! let mut filled = Tensor::new(vec![2, 3]);
35//! filled.fill(42.0);
36//!
37//! assert_eq!(zeros.shape().dims, vec![2, 3]);
38//! assert_eq!(ones.shape().dims, vec![2, 3]);
39//! assert_eq!(filled.shape().dims, vec![2, 3]);
40//!
41//! // Verify initialization values
42//! assert_eq!(zeros.get(&[0, 0]), 0.0);
43//! assert_eq!(ones.get(&[0, 0]), 1.0);
44//! assert_eq!(filled.get(&[0, 0]), 42.0);
45//! ```
46//!
47//! ## Device-Specific Initialization
48//!
49//! ```
50//! use train_station::Tensor;
51//! use train_station::Device;
52//!
53//! // Create tensors on specific devices
54//! let cpu_zeros = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
55//! let cpu_ones = Tensor::ones_on_device(vec![2, 2], Device::cpu());
56//!
57//! assert_eq!(cpu_zeros.device(), Device::cpu());
58//! assert_eq!(cpu_ones.device(), Device::cpu());
59//! assert_eq!(cpu_zeros.size(), 4);
60//! assert_eq!(cpu_ones.size(), 4);
61//!
62//! // Verify device-specific initialization
63//! assert_eq!(cpu_zeros.get(&[0, 0]), 0.0);
64//! assert_eq!(cpu_ones.get(&[0, 0]), 1.0);
65//! ```
66//!
67//! ## Fill Operations
68//!
69//! ```
70//! use train_station::Tensor;
71//!
72//! // Fill existing tensors with constant values
73//! let mut tensor = Tensor::new(vec![3, 3]);
74//! tensor.fill(3.14159);
75//!
76//! // Verify all elements are filled with the specified value
77//! for i in 0..tensor.size() {
78//! assert!((tensor.get(&[i / 3, i % 3]) - 3.14159).abs() < 1e-6);
79//! }
80//! ```
81//!
82//! ## Zero-Sized Tensor Handling
83//!
84//! ```
85//! use train_station::Tensor;
86//!
87//! // Handle zero-sized tensors gracefully
88//! let mut empty_tensor = Tensor::new(vec![0]);
89//! empty_tensor.fill(42.0); // Should not panic
90//! assert_eq!(empty_tensor.size(), 0);
91//! ```
92//!
93//! # Design Principles
94//!
95//! - **Performance First**: SIMD-optimized operations for maximum speed
96//! - **Memory Safety**: Safe operations with proper bounds checking
97//! - **Device Abstraction**: Unified interface for CPU and future GPU operations
98//! - **Zero-Cost Abstractions**: Minimal overhead for initialization operations
99//! - **Thread Safety**: All operations are safe for concurrent access
100
101use crate::tensor::core::Tensor;
102
103impl Tensor {
104 /// Creates a new tensor filled with zeros
105 ///
106 /// Convenience constructor that creates a tensor and initializes all elements
107 /// to zero. Uses optimized SIMD operations for efficient zero initialization.
108 ///
109 /// # Arguments
110 ///
111 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
112 ///
113 /// # Returns
114 ///
115 /// A new tensor with all elements initialized to zero
116 ///
117 /// # Performance
118 ///
119 /// - **Memory Allocation**: Single allocation with optimized alignment
120 /// - **Initialization**: SIMD-optimized zero filling for large tensors
121 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
122 ///
123 /// # Examples
124 ///
125 /// ```
126 /// use train_station::Tensor;
127 ///
128 /// let tensor = Tensor::zeros(vec![2, 3]);
129 /// assert_eq!(tensor.size(), 6);
130 /// assert_eq!(tensor.shape().dims, vec![2, 3]);
131 ///
132 /// // Verify all elements are zero
133 /// assert_eq!(tensor.get(&[0, 0]), 0.0);
134 /// assert_eq!(tensor.get(&[1, 2]), 0.0);
135 /// ```
136 #[inline]
137 #[track_caller]
138 pub fn zeros(shape_dims: Vec<usize>) -> Self {
139 let mut tensor = Self::new(shape_dims);
140 tensor.fill(0.0);
141 tensor
142 }
143
144 /// Creates a new tensor filled with ones
145 ///
146 /// Convenience constructor that creates a tensor and initializes all elements
147 /// to one. Uses optimized SIMD operations for efficient initialization.
148 ///
149 /// # Arguments
150 ///
151 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
152 ///
153 /// # Returns
154 ///
155 /// A new tensor with all elements initialized to one
156 ///
157 /// # Performance
158 ///
159 /// - **Memory Allocation**: Single allocation with optimized alignment
160 /// - **Initialization**: SIMD-optimized one filling for large tensors
161 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
162 ///
163 /// # Examples
164 ///
165 /// ```
166 /// use train_station::Tensor;
167 ///
168 /// let tensor = Tensor::ones(vec![2, 3]);
169 /// assert_eq!(tensor.size(), 6);
170 /// assert_eq!(tensor.shape().dims, vec![2, 3]);
171 ///
172 /// // Verify all elements are one
173 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
174 /// assert_eq!(tensor.get(&[1, 2]), 1.0);
175 /// ```
176 #[inline]
177 #[track_caller]
178 pub fn ones(shape_dims: Vec<usize>) -> Self {
179 let mut tensor = Self::new(shape_dims);
180 tensor.fill(1.0);
181 tensor
182 }
183
184 /// Creates a new tensor filled with zeros on a specific device
185 ///
186 /// Convenience constructor that creates a tensor on the specified device
187 /// and initializes all elements to zero. Uses optimized SIMD operations
188 /// for efficient zero initialization.
189 ///
190 /// # Arguments
191 ///
192 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
193 /// * `device` - The device where the tensor should be allocated
194 ///
195 /// # Returns
196 ///
197 /// A new tensor with all elements initialized to zero
198 ///
199 /// # Performance
200 ///
201 /// - **Memory Allocation**: Device-specific allocation with optimized alignment
202 /// - **Initialization**: SIMD-optimized zero filling for large tensors
203 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
204 ///
205 /// # Examples
206 ///
207 /// ```
208 /// use train_station::Tensor;
209 /// use train_station::Device;
210 ///
211 /// let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
212 /// assert_eq!(tensor.device(), Device::cpu());
213 /// assert_eq!(tensor.size(), 4);
214 ///
215 /// // Verify all elements are zero
216 /// assert_eq!(tensor.get(&[0, 0]), 0.0);
217 /// assert_eq!(tensor.get(&[1, 1]), 0.0);
218 /// ```
219 #[inline]
220 #[track_caller]
221 pub fn zeros_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
222 let mut tensor = Self::new_on_device(shape_dims, device);
223 tensor.fill(0.0);
224 tensor
225 }
226
227 /// Creates a new tensor filled with ones on a specific device
228 ///
229 /// Convenience constructor that creates a tensor on the specified device
230 /// and initializes all elements to one. Uses optimized SIMD operations
231 /// for efficient initialization.
232 ///
233 /// # Arguments
234 ///
235 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
236 /// * `device` - The device where the tensor should be allocated
237 ///
238 /// # Returns
239 ///
240 /// A new tensor with all elements initialized to one
241 ///
242 /// # Performance
243 ///
244 /// - **Memory Allocation**: Device-specific allocation with optimized alignment
245 /// - **Initialization**: SIMD-optimized one filling for large tensors
246 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
247 ///
248 /// # Examples
249 ///
250 /// ```
251 /// use train_station::Tensor;
252 /// use train_station::Device;
253 ///
254 /// let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
255 /// assert_eq!(tensor.device(), Device::cpu());
256 /// assert_eq!(tensor.size(), 4);
257 ///
258 /// // Verify all elements are one
259 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
260 /// assert_eq!(tensor.get(&[1, 1]), 1.0);
261 /// ```
262 #[inline]
263 #[track_caller]
264 pub fn ones_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
265 let mut tensor = Self::new_on_device(shape_dims, device);
266 tensor.fill(1.0);
267 tensor
268 }
269
270 /// Fills the tensor with a constant value using SIMD optimization
271 ///
272 /// Efficiently initializes all elements of the tensor to the specified value.
273 /// Uses SIMD operations for large tensors to maximize performance.
274 ///
275 /// # Arguments
276 ///
277 /// * `value` - The value to fill the tensor with
278 ///
279 /// # Performance
280 ///
281 /// - **SIMD Optimization**: Uses AVX2 for large tensors when available
282 /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
283 /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
284 ///
285 /// # Examples
286 ///
287 /// ```
288 /// use train_station::Tensor;
289 ///
290 /// let mut tensor = Tensor::new(vec![2, 3]);
291 /// tensor.fill(42.0);
292 ///
293 /// // Verify all elements are 42.0
294 /// assert_eq!(tensor.get(&[0, 0]), 42.0);
295 /// assert_eq!(tensor.get(&[1, 2]), 42.0);
296 /// ```
297 ///
298 /// ## Zero-Sized Tensor Handling
299 ///
300 /// ```
301 /// use train_station::Tensor;
302 ///
303 /// let mut empty_tensor = Tensor::new(vec![0]);
304 /// empty_tensor.fill(42.0); // Should not panic
305 /// assert_eq!(empty_tensor.size(), 0);
306 /// ```
307 #[inline]
308 #[track_caller]
309 pub fn fill(&mut self, value: f32) {
310 if self.shape().size == 0 {
311 return;
312 }
313
314 unsafe {
315 let ptr = self.as_mut_ptr();
316
317 #[cfg(target_arch = "x86_64")]
318 {
319 // Use SIMD for better performance when available
320 if is_x86_feature_detected!("avx2") {
321 self.fill_simd_avx2(ptr, value);
322 return;
323 }
324 }
325
326 // Fallback to scalar operations
327 for i in 0..self.shape().size {
328 *ptr.add(i) = value;
329 }
330 }
331 }
332
333 /// Fills the tensor with a constant value using AVX2 SIMD optimization
334 ///
335 /// Internal method that uses AVX2 instructions to efficiently fill large tensors.
336 /// Processes 32 elements per iteration with 4x unrolling for maximum memory bandwidth.
337 ///
338 /// # Arguments
339 ///
340 /// * `ptr` - Mutable pointer to the tensor data
341 /// * `value` - The value to fill the tensor with
342 ///
343 /// # Safety
344 ///
345 /// The caller must ensure:
346 /// * `ptr` is a valid pointer to tensor data
347 /// * The tensor size matches the allocated memory
348 /// * AVX2 is available on the target architecture
349 ///
350 /// # Performance
351 ///
352 /// - **SIMD Operations**: 32 elements per iteration using AVX2
353 /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
354 /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
355 /// - **Remaining Elements**: Efficient handling of non-multiple-of-32 sizes
356 ///
357 /// # Implementation Details
358 ///
359 /// This method uses AVX2 SIMD instructions to fill memory efficiently:
360 /// 1. Creates a vector of 8 identical values using `_mm256_set1_ps`
361 /// 2. Processes 32 elements per iteration (4x unrolled)
362 /// 3. Handles remaining 8-element blocks
363 /// 4. Fills final elements with scalar operations
364 #[cfg(target_arch = "x86_64")]
365 #[inline]
366 unsafe fn fill_simd_avx2(&self, ptr: *mut f32, value: f32) {
367 let mut_ptr = ptr;
368 let value_vec = _mm256_set1_ps(value);
369 let size = self.shape().size;
370 let simd_count = size / 32; // Process 32 elements per iteration
371 let mut offset = 0;
372
373 // Unrolled SIMD fill for better memory bandwidth utilization
374 for _ in 0..simd_count {
375 _mm256_store_ps(mut_ptr.add(offset), value_vec);
376 _mm256_store_ps(mut_ptr.add(offset + 8), value_vec);
377 _mm256_store_ps(mut_ptr.add(offset + 16), value_vec);
378 _mm256_store_ps(mut_ptr.add(offset + 24), value_vec);
379 offset += 32;
380 }
381
382 // Handle remaining 8-element blocks
383 let remaining_full_blocks = (size - offset) / 8;
384 for _ in 0..remaining_full_blocks {
385 _mm256_store_ps(mut_ptr.add(offset), value_vec);
386 offset += 8;
387 }
388
389 // Handle final elements
390 for i in offset..size {
391 *mut_ptr.add(i) = value;
392 }
393 }
394}
395
396// SIMD optimizations for performance-critical operations
397#[cfg(target_arch = "x86_64")]
398use std::arch::x86_64::*;
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads element `idx` of the tensor's contiguous backing storage.
    fn elem(t: &Tensor, idx: usize) -> f32 {
        // SAFETY: every caller keeps `idx` strictly below `t.size()`.
        unsafe { *t.as_ptr().add(idx) }
    }

    /// Asserts that every element of `t` equals `expected` exactly.
    fn assert_all_eq(t: &Tensor, expected: f32) {
        for idx in 0..t.size() {
            assert_eq!(elem(t, idx), expected);
        }
    }

    #[test]
    fn test_zeros_basic() {
        let t = Tensor::zeros(vec![2, 3]);
        assert_eq!(t.size(), 6);
        assert_eq!(t.shape().dims, vec![2, 3]);
        assert_all_eq(&t, 0.0);
    }

    #[test]
    fn test_ones_basic() {
        let t = Tensor::ones(vec![2, 3]);
        assert_eq!(t.size(), 6);
        assert_eq!(t.shape().dims, vec![2, 3]);
        assert_all_eq(&t, 1.0);
    }

    #[test]
    fn test_zeros_on_device() {
        use crate::device::Device;

        let t = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
        assert_eq!(t.device(), Device::cpu());
        assert_eq!(t.size(), 4);
        assert_all_eq(&t, 0.0);
    }

    #[test]
    fn test_ones_on_device() {
        use crate::device::Device;

        let t = Tensor::ones_on_device(vec![2, 2], Device::cpu());
        assert_eq!(t.device(), Device::cpu());
        assert_eq!(t.size(), 4);
        assert_all_eq(&t, 1.0);
    }

    #[test]
    fn test_fill_basic() {
        let mut t = Tensor::new(vec![2, 3]);
        t.fill(42.0);
        assert_all_eq(&t, 42.0);
    }

    #[test]
    fn test_fill_zero_sized() {
        let mut t = Tensor::new(vec![0]);
        t.fill(42.0); // must be a no-op, not a panic
        assert_eq!(t.size(), 0);
    }

    #[test]
    fn test_fill_large_tensor() {
        // Large enough to exercise the unrolled SIMD path plus the tail.
        let mut t = Tensor::new(vec![100, 100]);
        t.fill(std::f32::consts::PI);
        for idx in 0..t.size() {
            assert!((elem(&t, idx) - std::f32::consts::PI).abs() < 1e-6);
        }
    }
}