train_station/tensor/init/random.rs
1//! Random tensor initialization methods
2//!
3//! This module provides methods to create tensors with random values using
4//! efficient random number generation algorithms. All methods are optimized
5//! for performance with SIMD operations and provide reproducible results
6//! when seeds are specified.
7//!
8//! # Key Features
9//!
10//! - **`randn`**: Create tensors with normally distributed random values (mean=0, std=1)
11//! - **Box-Muller Transform**: Efficient normal distribution generation
12//! - **SIMD Optimization**: Vectorized operations for large tensors
13//! - **Reproducible Results**: Optional seed-based generation for deterministic output
14//! - **Thread Safety**: Each call constructs its own local RNG; no shared or global random state
15//! - **Statistical Quality**: High-quality random number generation
16//!
17//! # Performance Characteristics
18//!
19//! - **Box-Muller Transform**: Efficient normal distribution generation
20//! - **SIMD Operations**: AVX2-optimized operations for large tensors
21//! - **Memory Efficient**: Single-pass generation with optimized allocation
22//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
23//! - **Zero Overhead**: Minimal validation overhead for correct usage
24//!
25//! # Examples
26//!
27//! ## Basic Random Generation
28//!
29//! ```
30//! use train_station::Tensor;
31//!
32//! // Create a 2x3 tensor with random normal values
33//! let tensor = Tensor::randn(vec![2, 3], None);
34//! assert_eq!(tensor.size(), 6);
35//! assert_eq!(tensor.shape().dims, vec![2, 3]);
36//!
37//! // Verify random values are generated
38//! let first_value = tensor.get(&[0, 0]);
39//! assert!(first_value != 0.0); // Should be random
40//! ```
41//!
42//! ## Reproducible Random Generation
43//!
44//! ```
45//! use train_station::Tensor;
46//!
47//! // Create with fixed seed for reproducible results
48//! let tensor1 = Tensor::randn(vec![100], Some(42));
49//! let tensor2 = Tensor::randn(vec![100], Some(42));
50//!
51//! // tensor1 and tensor2 will have identical values
52//! for i in 0..tensor1.size() {
53//! assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
54//! }
55//! ```
56//!
57//! ## Different Seeds Produce Different Results
58//!
59//! ```
60//! use train_station::Tensor;
61//!
62//! let tensor1 = Tensor::randn(vec![100], Some(1));
63//! let tensor2 = Tensor::randn(vec![100], Some(2));
64//!
65//! // Should be different with different seeds
66//! let mut different = false;
67//! for i in 0..tensor1.size() {
68//! if (tensor1.get(&[i]) - tensor2.get(&[i])).abs() > 1e-6 {
69//! different = true;
70//! break;
71//! }
72//! }
73//! assert!(different, "Tensors with different seeds should be different");
74//! ```
75//!
76//! ## Large Tensor Generation
77//!
78//! ```
79//! use train_station::Tensor;
80//!
81//! // Efficient generation of large tensors
82//! let tensor = Tensor::randn(vec![100, 100], Some(42));
83//! assert_eq!(tensor.size(), 10000);
84//!
85//! // Check statistical properties
86//! let mut sum = 0.0;
87//! for i in 0..tensor.size() {
88//! sum += tensor.get(&[i / 100, i % 100]);
89//! }
90//! let mean = sum / tensor.size() as f32;
91//!
92//! // Mean should be close to 0 for normal distribution
93//! assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
94//! ```
95//!
96//! ## Zero-Sized Tensor Handling
97//!
98//! ```
99//! use train_station::Tensor;
100//!
101//! // Handle empty tensors gracefully
102//! let tensor = Tensor::randn(vec![0], Some(42));
103//! assert_eq!(tensor.size(), 0);
104//! assert_eq!(tensor.shape().dims, vec![0]);
105//! ```
106//!
107//! # Design Principles
108//!
109//! - **Statistical Quality**: High-quality random number generation with proper distribution
110//! - **Performance First**: SIMD-optimized operations for maximum speed
111//! - **Reproducibility**: Optional seed-based generation for deterministic results
112//! - **Memory Efficiency**: Efficient memory operations with minimal overhead
113//! - **Thread Safety**: Safe concurrent use; every call owns an independent local RNG
114//! - **Numerical Stability**: Robust handling of edge cases and numerical issues
115
116use crate::tensor::core::Tensor;
117use std::collections::hash_map::DefaultHasher;
118use std::hash::{Hash, Hasher};
119
impl Tensor {
    /// Creates a tensor with normally distributed random values (mean=0, std=1)
    ///
    /// Analogous to PyTorch's `torch.randn()`. Uniform random numbers come from
    /// a Xorshift generator and are converted to standard normal samples via the
    /// Box-Muller transform.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    /// * `seed` - `Some(seed)` for reproducible output (same seed, same values);
    ///   `None` seeds from the system clock (non-reproducible)
    ///
    /// # Returns
    ///
    /// A new tensor of the requested shape filled with N(0, 1) samples
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Shape and size are taken from `shape_dims`
    /// let tensor = Tensor::randn(vec![2, 3], None);
    /// assert_eq!(tensor.size(), 6);
    /// assert_eq!(tensor.shape().dims, vec![2, 3]);
    ///
    /// // Same seed reproduces the same values
    /// let a = Tensor::randn(vec![100], Some(42));
    /// let b = Tensor::randn(vec![100], Some(42));
    /// for i in 0..a.size() {
    ///     assert!((a.get(&[i]) - b.get(&[i])).abs() < 1e-6);
    /// }
    ///
    /// // Zero-sized tensors are handled gracefully
    /// let empty = Tensor::randn(vec![0], Some(42));
    /// assert_eq!(empty.size(), 0);
    /// ```
    #[track_caller]
    pub fn randn(shape_dims: Vec<usize>, seed: Option<u64>) -> Self {
        let mut tensor = Self::new(shape_dims);
        tensor.fill_randn(seed);
        tensor
    }

    /// Fills this tensor in place with samples from the standard normal distribution
    ///
    /// Core generation routine behind [`Tensor::randn`]. Builds a per-call
    /// Xorshift RNG (seeded explicitly or from the system clock), then
    /// dispatches to an AVX2 store path on x86_64 when the CPU supports it,
    /// falling back to a portable unrolled scalar path otherwise. Empty
    /// tensors are left untouched.
    ///
    /// # Arguments
    ///
    /// * `seed` - `Some(seed)` for deterministic output; `None` seeds from the
    ///   system clock
    #[track_caller]
    pub fn fill_randn(&mut self, seed: Option<u64>) {
        // Nothing to generate for zero-sized tensors.
        if self.shape().size == 0 {
            return;
        }

        // Initialize random number generator
        let mut rng = if let Some(seed_val) = seed {
            // Use provided seed for reproducible results
            XorShiftRng::new(seed_val)
        } else {
            // Use system time for non-reproducible results
            XorShiftRng::new_from_time()
        };

        // SAFETY: `&mut self` guarantees exclusive access to the tensor's
        // buffer, and both fill paths write only indices < shape().size.
        // NOTE(review): `as_ptr()` returns *const f32 which the fill helpers
        // cast to *mut f32 — confirm this is the intended mutable accessor
        // (vs. a dedicated `as_mut_ptr`), since the buffer is mutated through it.
        unsafe {
            let ptr = self.as_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.fill_randn_simd_avx2(ptr, &mut rng);
                    return;
                }
            }

            // Fallback to scalar operations
            self.fill_randn_scalar(ptr, &mut rng);
        }
    }

    /// Fills the tensor with normal random values using AVX2 256-bit stores
    ///
    /// Values are still generated one at a time by the scalar Box-Muller
    /// transform; AVX2 is used only to write each 8-value batch with a single
    /// unaligned 256-bit store. The tail (`size % 8` elements) is written with
    /// scalar stores. Any speedup over the scalar path therefore comes from
    /// the store pattern, not from vectorized value generation.
    ///
    /// # Arguments
    ///
    /// * `ptr` - Pointer to the tensor's data buffer
    /// * `rng` - Random number generator instance
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// * `ptr` points to this tensor's buffer with at least `shape().size`
    ///   writable `f32` elements, and no other reference aliases it
    /// * AVX2 is available on the running CPU (the caller checks this via
    ///   `is_x86_feature_detected!`)
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn fill_randn_simd_avx2(&self, ptr: *const f32, rng: &mut XorShiftRng) {
        let mut_ptr = ptr as *mut f32;
        let size = self.shape().size;
        let simd_count = size / 8; // Process 8 elements per iteration
        let mut offset = 0;

        // SIMD loop for normal distribution generation
        for _ in 0..simd_count {
            // Generate one 8-lane batch on the stack...
            let mut values = [0.0f32; 8];
            for i in &mut values {
                *i = rng.next_normal();
            }

            // ...then commit it with a single 256-bit unaligned store.
            let vec = _mm256_loadu_ps(values.as_ptr());
            _mm256_storeu_ps(mut_ptr.add(offset), vec);
            offset += 8;
        }

        // Handle remaining elements (size not a multiple of 8)
        for i in offset..size {
            *mut_ptr.add(i) = rng.next_normal();
        }
    }

    /// Fills the tensor with normal random values using scalar stores
    ///
    /// Portable fallback used when AVX2 is unavailable (and on all
    /// non-x86_64 targets). The main loop is unrolled 4x to reduce loop
    /// overhead; the tail (`size % 4` elements) is handled individually.
    ///
    /// # Arguments
    ///
    /// * `ptr` - Pointer to the tensor's data buffer
    /// * `rng` - Random number generator instance
    ///
    /// # Safety
    ///
    /// The caller must ensure `ptr` points to this tensor's buffer with at
    /// least `shape().size` writable `f32` elements, with no aliasing
    /// references during the call.
    #[inline]
    unsafe fn fill_randn_scalar(&self, ptr: *const f32, rng: &mut XorShiftRng) {
        let mut_ptr = ptr as *mut f32;
        let size = self.shape().size;
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for better performance
        for _ in 0..unroll_count {
            *mut_ptr.add(offset) = rng.next_normal();
            *mut_ptr.add(offset + 1) = rng.next_normal();
            *mut_ptr.add(offset + 2) = rng.next_normal();
            *mut_ptr.add(offset + 3) = rng.next_normal();
            offset += 4;
        }

        // Handle remaining elements (size not a multiple of 4)
        for i in offset..size {
            *mut_ptr.add(i) = rng.next_normal();
        }
    }
}
424
425// SIMD optimizations for performance-critical operations
426#[cfg(target_arch = "x86_64")]
427use std::arch::x86_64::*;
428
/// Fast random number generator using the xorshift64 algorithm
///
/// Provides efficient uniform random number generation with good statistical
/// properties for non-cryptographic use, plus Box-Muller conversion to the
/// standard normal distribution. The generator keeps a single `u64` state and
/// has period 2^64 - 1 over all non-zero states.
///
/// # Performance
///
/// - **Fast Generation**: Three shift/XOR operations per `u64`
/// - **Memory Efficient**: Single `u64` state variable
/// - **Deterministic**: Same seed always produces the same sequence
///
/// # Implementation Details
///
/// Each step applies the classic xorshift64 recurrence:
/// 1. `state ^= state << 13`
/// 2. `state ^= state >> 7`
/// 3. `state ^= state << 17`
///
/// Zero is a fixed point of this recurrence (it maps to itself forever), so
/// the constructor guarantees the state is never zero.
struct XorShiftRng {
    /// Internal 64-bit state; invariant: never zero (see [`XorShiftRng::new`]).
    state: u64,
}

impl XorShiftRng {
    /// Creates a new generator with the specified seed
    ///
    /// The same seed always produces the same sequence, making this suitable
    /// for reproducible generation.
    ///
    /// A seed of `0` is remapped to a fixed non-zero constant: zero is a
    /// fixed point of the xorshift recurrence, so a zero state would make
    /// `next_u64` return 0 forever and every "random" value degenerate.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed value for reproducible generation
    fn new(seed: u64) -> Self {
        // 0x9E3779B9... is 2^64 / golden ratio — an arbitrary, well-mixed
        // non-zero constant used as the stand-in state for seed 0.
        let state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
        Self { state }
    }

    /// Creates a new generator seeded from the system clock
    ///
    /// Each call hashes `SystemTime::now()` with `DefaultHasher` and uses the
    /// digest as the seed, giving non-reproducible sequences. Delegates to
    /// [`XorShiftRng::new`] so the zero-state guard also applies here.
    fn new_from_time() -> Self {
        let mut hasher = DefaultHasher::new();
        std::time::SystemTime::now().hash(&mut hasher);
        Self::new(hasher.finish())
    }

    /// Generates the next random `u64` value
    ///
    /// Core xorshift64 step: three shift/XOR operations, no branches. Since
    /// the state is never zero (constructor invariant), the output is never
    /// stuck at zero and the full 2^64 - 1 period is available.
    ///
    /// # Returns
    ///
    /// A random `u64` from the xorshift sequence
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Generates the next random `f32` uniformly distributed in [0, 1)
    ///
    /// # Returns
    ///
    /// A random `f32` in the half-open interval [0, 1)
    ///
    /// # Implementation Details
    ///
    /// Builds the float by bit manipulation rather than division, which avoids
    /// rounding bias:
    /// 1. Takes the *top* 23 bits of a random `u64` as the mantissa
    ///    (xorshift's low bits are statistically weaker than its high bits)
    /// 2. ORs in exponent 127 (the IEEE 754 binary32 bias), yielding a float
    ///    in [1, 2)
    /// 3. Subtracts 1.0 to shift the result into [0, 1)
    fn next_f32(&mut self) -> f32 {
        let bits = self.next_u64();
        let mantissa = (bits >> 41) as u32; // top 23 bits of the u64
        let float_bits = 0x3F80_0000u32 | mantissa; // exponent 127 => value in [1, 2)
        f32::from_bits(float_bits) - 1.0
    }

    /// Generates the next normally distributed value via the Box-Muller transform
    ///
    /// # Returns
    ///
    /// A random `f32` drawn from N(0, 1)
    ///
    /// # Implementation Details
    ///
    /// 1. Draws two uniforms `u1`, `u2` in [0, 1)
    /// 2. Clamps `u1` into (0, 1) so `ln(u1)` is finite and non-zero-safe
    /// 3. Applies `z = sqrt(-2 ln u1) * cos(2π u2)`
    /// 4. Returns 0.0 if the result is somehow NaN/infinite (defensive;
    ///    should not occur after the clamp)
    ///
    /// The sine counterpart of the Box-Muller pair is intentionally discarded
    /// for simplicity, at the cost of one extra uniform per sample.
    fn next_normal(&mut self) -> f32 {
        let u1 = self.next_f32();
        let u2 = self.next_f32();

        // Keep u1 strictly inside (0, 1) to avoid ln(0) = -inf.
        let u1 = u1.clamp(1e-7, 1.0 - 1e-7);

        // Box-Muller transform
        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();

        // Defensive guard against NaN/infinite results
        if z0.is_finite() {
            z0
        } else {
            0.0
        }
    }
}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads element `i` of the tensor's flat buffer through its raw pointer.
    fn at(t: &Tensor, i: usize) -> f32 {
        unsafe { *t.as_ptr().add(i) }
    }

    /// True when `a` and `b` agree element-wise within `tol`.
    fn all_close(a: &Tensor, b: &Tensor, tol: f32) -> bool {
        (0..a.size()).all(|i| (at(a, i) - at(b, i)).abs() < tol)
    }

    #[test]
    fn test_randn_basic() {
        let tensor = Tensor::randn(vec![2, 3], Some(42));
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);

        // With a fixed seed, generation must be reproducible
        let tensor2 = Tensor::randn(vec![2, 3], Some(42));
        assert!(all_close(&tensor, &tensor2, 1e-6));
    }

    #[test]
    fn test_randn_reproducible() {
        let seed = 12345;
        let tensor1 = Tensor::randn(vec![100], Some(seed));
        let tensor2 = Tensor::randn(vec![100], Some(seed));

        // Should be identical with the same seed
        assert!(all_close(&tensor1, &tensor2, 1e-6));
    }

    #[test]
    fn test_randn_different_seeds() {
        let tensor1 = Tensor::randn(vec![100], Some(1));
        let tensor2 = Tensor::randn(vec![100], Some(2));

        // Should differ somewhere with different seeds
        let different = (0..tensor1.size()).any(|i| (at(&tensor1, i) - at(&tensor2, i)).abs() > 1e-6);
        assert!(different, "Tensors with different seeds should be different");
    }

    #[test]
    fn test_randn_no_seed() {
        let tensor = Tensor::randn(vec![10], None);
        assert_eq!(tensor.size(), 10);
        assert_eq!(tensor.shape().dims, vec![10]);

        // Should not be all zeros
        let has_non_zero = (0..tensor.size()).any(|i| at(&tensor, i) != 0.0);
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }

    #[test]
    fn test_randn_zero_sized() {
        let tensor = Tensor::randn(vec![0], Some(42));
        assert_eq!(tensor.size(), 0);
        assert_eq!(tensor.shape().dims, vec![0]);
    }

    #[test]
    fn test_randn_large_tensor() {
        let tensor = Tensor::randn(vec![100, 100], Some(42));
        assert_eq!(tensor.size(), 10000);

        let mut min_val = f32::INFINITY;
        let mut max_val = f32::NEG_INFINITY;
        let mut sum = 0.0;

        for i in 0..tensor.size() {
            let val = at(&tensor, i);
            min_val = min_val.min(val);
            max_val = max_val.max(val);
            sum += val;
        }

        let mean = sum / tensor.size() as f32;

        // Mean should be close to 0. Use a 6-sigma bound for the extremes:
        // for a correct N(0,1) at n = 10000 the expected count beyond |4| is
        // ~0.6 (a 4-sigma assertion would fail ~half the time), whereas
        // P(|z| > 6) ~= 2e-9 makes 6 a safe deterministic bound.
        assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
        assert!(
            min_val > -6.0,
            "Values should not be too negative, min: {}",
            min_val
        );
        assert!(
            max_val < 6.0,
            "Values should not be too positive, max: {}",
            max_val
        );
    }

    #[test]
    fn test_fill_randn() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.fill_randn(Some(42));

        assert_eq!(tensor.size(), 6);

        // Should not be all zeros
        let has_non_zero = (0..tensor.size()).any(|i| at(&tensor, i) != 0.0);
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }
}
754}