// train_station/tensor/init/basic.rs
//! Basic tensor initialization methods
//!
//! This module provides fundamental tensor initialization operations for creating
//! tensors with specific constant values. All methods are optimized with SIMD
//! operations for maximum performance on large tensors.
//!
//! # Key Features
//!
//! - **`zeros`**: Create tensors filled with zeros
//! - **`ones`**: Create tensors filled with ones
//! - **`fill`**: Fill existing tensors with a constant value
//! - **Device-aware initialization**: Create tensors on specific devices
//! - **SIMD optimization**: Vectorized operations for large tensors
//! - **Thread safety**: All operations are thread-safe
//!
//! # Performance Characteristics
//!
//! - **Memory Allocation**: Single allocation with optimized alignment
//! - **SIMD Operations**: AVX2-optimized filling for large tensors
//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
//! - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
//! - **Zero-sized Handling**: Efficient handling of empty tensors
//!
//! # Examples
//!
//! ## Basic Initialization
//!
//! ```
//! use train_station::Tensor;
//!
//! // Create tensors with different constant values
//! let zeros = Tensor::zeros(vec![2, 3]);
//! let ones = Tensor::ones(vec![2, 3]);
//! let mut filled = Tensor::new(vec![2, 3]);
//! filled.fill(42.0);
//!
//! assert_eq!(zeros.shape().dims, vec![2, 3]);
//! assert_eq!(ones.shape().dims, vec![2, 3]);
//! assert_eq!(filled.shape().dims, vec![2, 3]);
//!
//! // Verify initialization values
//! assert_eq!(zeros.get(&[0, 0]), 0.0);
//! assert_eq!(ones.get(&[0, 0]), 1.0);
//! assert_eq!(filled.get(&[0, 0]), 42.0);
//! ```
//!
//! ## Device-Specific Initialization
//!
//! ```
//! use train_station::Tensor;
//! use train_station::Device;
//!
//! // Create tensors on specific devices
//! let cpu_zeros = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
//! let cpu_ones = Tensor::ones_on_device(vec![2, 2], Device::cpu());
//!
//! assert_eq!(cpu_zeros.device(), Device::cpu());
//! assert_eq!(cpu_ones.device(), Device::cpu());
//! assert_eq!(cpu_zeros.size(), 4);
//! assert_eq!(cpu_ones.size(), 4);
//!
//! // Verify device-specific initialization
//! assert_eq!(cpu_zeros.get(&[0, 0]), 0.0);
//! assert_eq!(cpu_ones.get(&[0, 0]), 1.0);
//! ```
//!
//! ## Fill Operations
//!
//! ```
//! use train_station::Tensor;
//!
//! // Fill existing tensors with constant values
//! let mut tensor = Tensor::new(vec![3, 3]);
//! tensor.fill(3.14159);
//!
//! // Verify all elements are filled with the specified value
//! for i in 0..tensor.size() {
//!     assert!((tensor.get(&[i / 3, i % 3]) - 3.14159).abs() < 1e-6);
//! }
//! ```
//!
//! ## Zero-Sized Tensor Handling
//!
//! ```
//! use train_station::Tensor;
//!
//! // Handle zero-sized tensors gracefully
//! let mut empty_tensor = Tensor::new(vec![0]);
//! empty_tensor.fill(42.0); // Should not panic
//! assert_eq!(empty_tensor.size(), 0);
//! ```
//!
//! # Design Principles
//!
//! - **Performance First**: SIMD-optimized operations for maximum speed
//! - **Memory Safety**: Safe operations with proper bounds checking
//! - **Device Abstraction**: Unified interface for CPU and future GPU operations
//! - **Zero-Cost Abstractions**: Minimal overhead for initialization operations
//! - **Thread Safety**: All operations are safe for concurrent access
101use crate::tensor::core::Tensor;
102
103impl Tensor {
104 /// Creates a new tensor filled with zeros
105 ///
106 /// Convenience constructor that creates a tensor and initializes all elements
107 /// to zero. Uses optimized SIMD operations for efficient zero initialization.
108 ///
109 /// # Arguments
110 ///
111 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
112 ///
113 /// # Returns
114 ///
115 /// A new tensor with all elements initialized to zero
116 ///
117 /// # Performance
118 ///
119 /// - **Memory Allocation**: Single allocation with optimized alignment
120 /// - **Initialization**: SIMD-optimized zero filling for large tensors
121 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
122 ///
123 /// # Examples
124 ///
125 /// ```
126 /// use train_station::Tensor;
127 ///
128 /// let tensor = Tensor::zeros(vec![2, 3]);
129 /// assert_eq!(tensor.size(), 6);
130 /// assert_eq!(tensor.shape().dims, vec![2, 3]);
131 ///
132 /// // Verify all elements are zero
133 /// assert_eq!(tensor.get(&[0, 0]), 0.0);
134 /// assert_eq!(tensor.get(&[1, 2]), 0.0);
135 /// ```
136 #[inline]
137 pub fn zeros(shape_dims: Vec<usize>) -> Self {
138 let mut tensor = Self::new(shape_dims);
139 tensor.fill(0.0);
140 tensor
141 }
142
143 /// Creates a new tensor filled with ones
144 ///
145 /// Convenience constructor that creates a tensor and initializes all elements
146 /// to one. Uses optimized SIMD operations for efficient initialization.
147 ///
148 /// # Arguments
149 ///
150 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
151 ///
152 /// # Returns
153 ///
154 /// A new tensor with all elements initialized to one
155 ///
156 /// # Performance
157 ///
158 /// - **Memory Allocation**: Single allocation with optimized alignment
159 /// - **Initialization**: SIMD-optimized one filling for large tensors
160 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
161 ///
162 /// # Examples
163 ///
164 /// ```
165 /// use train_station::Tensor;
166 ///
167 /// let tensor = Tensor::ones(vec![2, 3]);
168 /// assert_eq!(tensor.size(), 6);
169 /// assert_eq!(tensor.shape().dims, vec![2, 3]);
170 ///
171 /// // Verify all elements are one
172 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
173 /// assert_eq!(tensor.get(&[1, 2]), 1.0);
174 /// ```
175 #[inline]
176 pub fn ones(shape_dims: Vec<usize>) -> Self {
177 let mut tensor = Self::new(shape_dims);
178 tensor.fill(1.0);
179 tensor
180 }
181
182 /// Creates a new tensor filled with zeros on a specific device
183 ///
184 /// Convenience constructor that creates a tensor on the specified device
185 /// and initializes all elements to zero. Uses optimized SIMD operations
186 /// for efficient zero initialization.
187 ///
188 /// # Arguments
189 ///
190 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
191 /// * `device` - The device where the tensor should be allocated
192 ///
193 /// # Returns
194 ///
195 /// A new tensor with all elements initialized to zero
196 ///
197 /// # Performance
198 ///
199 /// - **Memory Allocation**: Device-specific allocation with optimized alignment
200 /// - **Initialization**: SIMD-optimized zero filling for large tensors
201 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
202 ///
203 /// # Examples
204 ///
205 /// ```
206 /// use train_station::Tensor;
207 /// use train_station::Device;
208 ///
209 /// let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
210 /// assert_eq!(tensor.device(), Device::cpu());
211 /// assert_eq!(tensor.size(), 4);
212 ///
213 /// // Verify all elements are zero
214 /// assert_eq!(tensor.get(&[0, 0]), 0.0);
215 /// assert_eq!(tensor.get(&[1, 1]), 0.0);
216 /// ```
217 #[inline]
218 pub fn zeros_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
219 let mut tensor = Self::new_on_device(shape_dims, device);
220 tensor.fill(0.0);
221 tensor
222 }
223
224 /// Creates a new tensor filled with ones on a specific device
225 ///
226 /// Convenience constructor that creates a tensor on the specified device
227 /// and initializes all elements to one. Uses optimized SIMD operations
228 /// for efficient initialization.
229 ///
230 /// # Arguments
231 ///
232 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
233 /// * `device` - The device where the tensor should be allocated
234 ///
235 /// # Returns
236 ///
237 /// A new tensor with all elements initialized to one
238 ///
239 /// # Performance
240 ///
241 /// - **Memory Allocation**: Device-specific allocation with optimized alignment
242 /// - **Initialization**: SIMD-optimized one filling for large tensors
243 /// - **Thread Safe**: Atomic ID generation for gradtrack tracking
244 ///
245 /// # Examples
246 ///
247 /// ```
248 /// use train_station::Tensor;
249 /// use train_station::Device;
250 ///
251 /// let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
252 /// assert_eq!(tensor.device(), Device::cpu());
253 /// assert_eq!(tensor.size(), 4);
254 ///
255 /// // Verify all elements are one
256 /// assert_eq!(tensor.get(&[0, 0]), 1.0);
257 /// assert_eq!(tensor.get(&[1, 1]), 1.0);
258 /// ```
259 #[inline]
260 pub fn ones_on_device(shape_dims: Vec<usize>, device: crate::device::Device) -> Self {
261 let mut tensor = Self::new_on_device(shape_dims, device);
262 tensor.fill(1.0);
263 tensor
264 }
265
266 /// Fills the tensor with a constant value using SIMD optimization
267 ///
268 /// Efficiently initializes all elements of the tensor to the specified value.
269 /// Uses SIMD operations for large tensors to maximize performance.
270 ///
271 /// # Arguments
272 ///
273 /// * `value` - The value to fill the tensor with
274 ///
275 /// # Performance
276 ///
277 /// - **SIMD Optimization**: Uses AVX2 for large tensors when available
278 /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
279 /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
280 ///
281 /// # Examples
282 ///
283 /// ```
284 /// use train_station::Tensor;
285 ///
286 /// let mut tensor = Tensor::new(vec![2, 3]);
287 /// tensor.fill(42.0);
288 ///
289 /// // Verify all elements are 42.0
290 /// assert_eq!(tensor.get(&[0, 0]), 42.0);
291 /// assert_eq!(tensor.get(&[1, 2]), 42.0);
292 /// ```
293 ///
294 /// ## Zero-Sized Tensor Handling
295 ///
296 /// ```
297 /// use train_station::Tensor;
298 ///
299 /// let mut empty_tensor = Tensor::new(vec![0]);
300 /// empty_tensor.fill(42.0); // Should not panic
301 /// assert_eq!(empty_tensor.size(), 0);
302 /// ```
303 #[inline]
304 pub fn fill(&mut self, value: f32) {
305 if self.shape().size == 0 {
306 return;
307 }
308
309 unsafe {
310 let ptr = self.as_mut_ptr();
311
312 #[cfg(target_arch = "x86_64")]
313 {
314 // Use SIMD for better performance when available
315 if is_x86_feature_detected!("avx2") {
316 self.fill_simd_avx2(ptr, value);
317 return;
318 }
319 }
320
321 // Fallback to scalar operations
322 for i in 0..self.shape().size {
323 *ptr.add(i) = value;
324 }
325 }
326 }
327
328 /// Fills the tensor with a constant value using AVX2 SIMD optimization
329 ///
330 /// Internal method that uses AVX2 instructions to efficiently fill large tensors.
331 /// Processes 32 elements per iteration with 4x unrolling for maximum memory bandwidth.
332 ///
333 /// # Arguments
334 ///
335 /// * `ptr` - Mutable pointer to the tensor data
336 /// * `value` - The value to fill the tensor with
337 ///
338 /// # Safety
339 ///
340 /// The caller must ensure:
341 /// * `ptr` is a valid pointer to tensor data
342 /// * The tensor size matches the allocated memory
343 /// * AVX2 is available on the target architecture
344 ///
345 /// # Performance
346 ///
347 /// - **SIMD Operations**: 32 elements per iteration using AVX2
348 /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
349 /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
350 /// - **Remaining Elements**: Efficient handling of non-multiple-of-32 sizes
351 ///
352 /// # Implementation Details
353 ///
354 /// This method uses AVX2 SIMD instructions to fill memory efficiently:
355 /// 1. Creates a vector of 8 identical values using `_mm256_set1_ps`
356 /// 2. Processes 32 elements per iteration (4x unrolled)
357 /// 3. Handles remaining 8-element blocks
358 /// 4. Fills final elements with scalar operations
359 #[cfg(target_arch = "x86_64")]
360 #[inline]
361 unsafe fn fill_simd_avx2(&self, ptr: *mut f32, value: f32) {
362 let mut_ptr = ptr;
363 let value_vec = _mm256_set1_ps(value);
364 let size = self.shape().size;
365 let simd_count = size / 32; // Process 32 elements per iteration
366 let mut offset = 0;
367
368 // Unrolled SIMD fill for better memory bandwidth utilization
369 for _ in 0..simd_count {
370 _mm256_store_ps(mut_ptr.add(offset), value_vec);
371 _mm256_store_ps(mut_ptr.add(offset + 8), value_vec);
372 _mm256_store_ps(mut_ptr.add(offset + 16), value_vec);
373 _mm256_store_ps(mut_ptr.add(offset + 24), value_vec);
374 offset += 32;
375 }
376
377 // Handle remaining 8-element blocks
378 let remaining_full_blocks = (size - offset) / 8;
379 for _ in 0..remaining_full_blocks {
380 _mm256_store_ps(mut_ptr.add(offset), value_vec);
381 offset += 8;
382 }
383
384 // Handle final elements
385 for i in offset..size {
386 *mut_ptr.add(i) = value;
387 }
388 }
389}
390
391// SIMD optimizations for performance-critical operations
392#[cfg(target_arch = "x86_64")]
393use std::arch::x86_64::*;
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::Device;

    /// Asserts that every element of `tensor` is exactly `expected`.
    fn assert_all_exact(tensor: &Tensor, expected: f32) {
        for i in 0..tensor.size() {
            // SAFETY: `i` is in-bounds for the tensor's contiguous buffer.
            let actual = unsafe { *tensor.as_ptr().add(i) };
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_zeros_basic() {
        let tensor = Tensor::zeros(vec![2, 3]);
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);
        assert_all_exact(&tensor, 0.0);
    }

    #[test]
    fn test_ones_basic() {
        let tensor = Tensor::ones(vec![2, 3]);
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);
        assert_all_exact(&tensor, 1.0);
    }

    #[test]
    fn test_zeros_on_device() {
        let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 4);
        assert_all_exact(&tensor, 0.0);
    }

    #[test]
    fn test_ones_on_device() {
        let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
        assert_eq!(tensor.device(), Device::cpu());
        assert_eq!(tensor.size(), 4);
        assert_all_exact(&tensor, 1.0);
    }

    #[test]
    fn test_fill_basic() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.fill(42.0);
        assert_all_exact(&tensor, 42.0);
    }

    #[test]
    fn test_fill_zero_sized() {
        let mut tensor = Tensor::new(vec![0]);
        tensor.fill(42.0); // must be a no-op, not a panic
        assert_eq!(tensor.size(), 0);
    }

    #[test]
    fn test_fill_large_tensor() {
        // Large enough to exercise the unrolled SIMD path plus the tail.
        let mut tensor = Tensor::new(vec![100, 100]);
        tensor.fill(std::f32::consts::PI);

        for i in 0..tensor.size() {
            // SAFETY: `i` is in-bounds for the tensor's contiguous buffer.
            let actual = unsafe { *tensor.as_ptr().add(i) };
            assert!((actual - std::f32::consts::PI).abs() < 1e-6);
        }
    }
}