// train_station/tensor/init/random.rs
1//! Random tensor initialization methods
2//!
3//! This module provides methods to create tensors with random values using
4//! efficient random number generation algorithms. All methods are optimized
5//! for performance with SIMD operations and provide reproducible results
6//! when seeds are specified.
7//!
8//! # Key Features
9//!
10//! - **`randn`**: Create tensors with normally distributed random values (mean=0, std=1)
11//! - **Box-Muller Transform**: Efficient normal distribution generation
12//! - **SIMD Optimization**: Vectorized operations for large tensors
13//! - **Reproducible Results**: Optional seed-based generation for deterministic output
14//! - **Thread Safety**: Thread-local random state management
15//! - **Statistical Quality**: High-quality random number generation
16//!
17//! # Performance Characteristics
18//!
19//! - **Box-Muller Transform**: Efficient normal distribution generation
20//! - **SIMD Operations**: AVX2-optimized operations for large tensors
21//! - **Memory Efficient**: Single-pass generation with optimized allocation
22//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
23//! - **Zero Overhead**: Minimal validation overhead for correct usage
24//!
25//! # Examples
26//!
27//! ## Basic Random Generation
28//!
29//! ```
30//! use train_station::Tensor;
31//!
32//! // Create a 2x3 tensor with random normal values
33//! let tensor = Tensor::randn(vec![2, 3], None);
34//! assert_eq!(tensor.size(), 6);
35//! assert_eq!(tensor.shape().dims, vec![2, 3]);
36//!
37//! // Verify random values are generated
38//! let first_value = tensor.get(&[0, 0]);
39//! assert!(first_value != 0.0); // Should be random
40//! ```
41//!
42//! ## Reproducible Random Generation
43//!
44//! ```
45//! use train_station::Tensor;
46//!
47//! // Create with fixed seed for reproducible results
48//! let tensor1 = Tensor::randn(vec![100], Some(42));
49//! let tensor2 = Tensor::randn(vec![100], Some(42));
50//!
51//! // tensor1 and tensor2 will have identical values
52//! for i in 0..tensor1.size() {
53//! assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
54//! }
55//! ```
56//!
57//! ## Different Seeds Produce Different Results
58//!
59//! ```
60//! use train_station::Tensor;
61//!
62//! let tensor1 = Tensor::randn(vec![100], Some(1));
63//! let tensor2 = Tensor::randn(vec![100], Some(2));
64//!
65//! // Should be different with different seeds
66//! let mut different = false;
67//! for i in 0..tensor1.size() {
68//! if (tensor1.get(&[i]) - tensor2.get(&[i])).abs() > 1e-6 {
69//! different = true;
70//! break;
71//! }
72//! }
73//! assert!(different, "Tensors with different seeds should be different");
74//! ```
75//!
76//! ## Large Tensor Generation
77//!
78//! ```
79//! use train_station::Tensor;
80//!
81//! // Efficient generation of large tensors
82//! let tensor = Tensor::randn(vec![100, 100], Some(42));
83//! assert_eq!(tensor.size(), 10000);
84//!
85//! // Check statistical properties
86//! let mut sum = 0.0;
87//! for i in 0..tensor.size() {
88//! sum += tensor.get(&[i / 100, i % 100]);
89//! }
90//! let mean = sum / tensor.size() as f32;
91//!
92//! // Mean should be close to 0 for normal distribution
93//! assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
94//! ```
95//!
96//! ## Zero-Sized Tensor Handling
97//!
98//! ```
99//! use train_station::Tensor;
100//!
101//! // Handle empty tensors gracefully
102//! let tensor = Tensor::randn(vec![0], Some(42));
103//! assert_eq!(tensor.size(), 0);
104//! assert_eq!(tensor.shape().dims, vec![0]);
105//! ```
106//!
107//! # Design Principles
108//!
109//! - **Statistical Quality**: High-quality random number generation with proper distribution
110//! - **Performance First**: SIMD-optimized operations for maximum speed
111//! - **Reproducibility**: Optional seed-based generation for deterministic results
112//! - **Memory Efficiency**: Efficient memory operations with minimal overhead
113//! - **Thread Safety**: Safe concurrent access with thread-local state
114//! - **Numerical Stability**: Robust handling of edge cases and numerical issues
115
116use crate::tensor::core::Tensor;
117use std::collections::hash_map::DefaultHasher;
118use std::hash::{Hash, Hasher};
119
120impl Tensor {
121 /// Creates a tensor with normally distributed random values (mean=0, std=1)
122 ///
123 /// Similar to PyTorch's `torch.randn()`, creates a tensor filled with random
124 /// values drawn from a standard normal distribution (mean=0, standard deviation=1).
125 /// Uses Box-Muller transform for efficient normal distribution generation.
126 ///
127 /// This method provides high-quality random number generation with optional
128 /// reproducibility through seed-based generation. The generated values follow
129 /// a standard normal distribution suitable for machine learning applications.
130 ///
131 /// # Arguments
132 ///
133 /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
134 /// * `seed` - Optional seed for reproducible random generation
135 ///
136 /// # Returns
137 ///
138 /// A new tensor with normally distributed random values
139 ///
140 /// # Performance
141 ///
142 /// - **Box-Muller Transform**: Efficient normal distribution generation
143 /// - **SIMD Optimization**: Vectorized operations for large tensors
144 /// - **Memory Efficient**: Single-pass generation with optimized allocation
145 /// - **Thread Safe**: Uses thread-local random state
146 ///
147 /// # Examples
148 ///
149 /// ## Basic Usage
150 ///
151 /// ```
152 /// use train_station::Tensor;
153 ///
154 /// // Create a 2x3 tensor with random normal values
155 /// let tensor = Tensor::randn(vec![2, 3], None);
156 /// assert_eq!(tensor.size(), 6);
157 /// assert_eq!(tensor.shape().dims, vec![2, 3]);
158 ///
159 /// // Verify random values are generated
160 /// let first_value = tensor.get(&[0, 0]);
161 /// assert!(first_value != 0.0); // Should be random
162 /// ```
163 ///
164 /// ## Reproducible Generation
165 ///
166 /// ```
167 /// use train_station::Tensor;
168 ///
169 /// // Create with fixed seed for reproducible results
170 /// let tensor1 = Tensor::randn(vec![100], Some(42));
171 /// let tensor2 = Tensor::randn(vec![100], Some(42));
172 ///
173 /// // tensor1 and tensor2 will have identical values
174 /// for i in 0..tensor1.size() {
175 /// assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
176 /// }
177 /// ```
178 ///
179 /// ## Statistical Properties
180 ///
181 /// ```
182 /// use train_station::Tensor;
183 ///
184 /// // Generate large tensor for statistical analysis
185 /// let tensor = Tensor::randn(vec![1000], Some(42));
186 /// assert_eq!(tensor.size(), 1000);
187 ///
188 /// // Check that values are reasonable (within 4 standard deviations)
189 /// let mut min_val = f32::INFINITY;
190 /// let mut max_val = f32::NEG_INFINITY;
191 /// let mut sum = 0.0;
192 ///
193 /// for i in 0..tensor.size() {
194 /// let val = tensor.get(&[i]);
195 /// min_val = min_val.min(val);
196 /// max_val = max_val.max(val);
197 /// sum += val;
198 /// }
199 ///
200 /// let mean = sum / tensor.size() as f32;
201 ///
202 /// // Mean should be close to 0, values should be within reasonable bounds
203 /// assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
204 /// assert!(min_val > -4.0, "Values should not be too negative, min: {}", min_val);
205 /// assert!(max_val < 4.0, "Values should not be too positive, max: {}", max_val);
206 /// ```
207 ///
208 /// ## Zero-Sized Tensors
209 ///
210 /// ```
211 /// use train_station::Tensor;
212 ///
213 /// // Handle empty tensors gracefully
214 /// let tensor = Tensor::randn(vec![0], Some(42));
215 /// assert_eq!(tensor.size(), 0);
216 /// assert_eq!(tensor.shape().dims, vec![0]);
217 /// ```
218 ///
219 /// # Implementation Details
220 ///
221 /// This method uses the Box-Muller transform to generate normally distributed
222 /// random variables from uniform random variables. The process involves:
223 /// 1. **Random Number Generation**: Uses Xorshift algorithm for uniform random numbers
224 /// 2. **Box-Muller Transform**: Converts uniform random variables to normal distribution
225 /// 3. **SIMD Optimization**: Vectorized operations for large tensors when available
226 /// 4. **Numerical Stability**: Robust handling of edge cases and potential NaN values
227 ///
228 /// The Box-Muller transform ensures that the generated values follow a true
229 /// normal distribution with mean=0 and standard deviation=1, making it suitable
230 /// for machine learning applications requiring normally distributed random values.
231 pub fn randn(shape_dims: Vec<usize>, seed: Option<u64>) -> Self {
232 let mut tensor = Self::new(shape_dims);
233 tensor.fill_randn(seed);
234 tensor
235 }
236
237 /// Fills the tensor with normally distributed random values
238 ///
239 /// Internal method that fills an existing tensor with random values from
240 /// a standard normal distribution. Uses Box-Muller transform for efficiency
241 /// and provides SIMD optimization for large tensors.
242 ///
243 /// This method is used internally by `randn()` and provides the core
244 /// random number generation functionality with optimized performance
245 /// characteristics.
246 ///
247 /// # Arguments
248 ///
249 /// * `seed` - Optional seed for reproducible random generation
250 ///
251 /// # Performance
252 ///
253 /// - **Box-Muller Transform**: Generates pairs of normal random variables
254 /// - **SIMD Optimization**: Vectorized operations when possible
255 /// - **Memory Efficient**: Single-pass generation
256 /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
257 ///
258 /// # Implementation Details
259 ///
260 /// The method performs the following steps:
261 /// 1. **Zero-sized Check**: Returns early for empty tensors
262 /// 2. **RNG Initialization**: Creates Xorshift RNG with seed or system time
263 /// 3. **SIMD Detection**: Checks for AVX2 availability for optimized path
264 /// 4. **Generation**: Uses SIMD or scalar path based on hardware support
265 /// 5. **Completion**: Fills all tensor elements with normal random values
266 ///
267 /// The method automatically handles hardware capabilities and falls back
268 /// to scalar operations when SIMD is not available, ensuring compatibility
269 /// across different CPU architectures.
270 pub fn fill_randn(&mut self, seed: Option<u64>) {
271 if self.shape().size == 0 {
272 return;
273 }
274
275 // Initialize random number generator
276 let mut rng = if let Some(seed_val) = seed {
277 // Use provided seed for reproducible results
278 XorShiftRng::new(seed_val)
279 } else {
280 // Use system time for non-reproducible results
281 XorShiftRng::new_from_time()
282 };
283
284 unsafe {
285 let ptr = self.as_ptr();
286
287 #[cfg(target_arch = "x86_64")]
288 {
289 // Use SIMD for better performance when available
290 if is_x86_feature_detected!("avx2") {
291 self.fill_randn_simd_avx2(ptr, &mut rng);
292 return;
293 }
294 }
295
296 // Fallback to scalar operations
297 self.fill_randn_scalar(ptr, &mut rng);
298 }
299 }
300
301 /// Fills the tensor with normally distributed random values using AVX2 SIMD
302 ///
303 /// Internal method that uses AVX2 instructions to efficiently fill large tensors
304 /// with normal random values. Processes 8 elements per iteration for maximum
305 /// memory bandwidth utilization.
306 ///
307 /// # Arguments
308 ///
309 /// * `ptr` - Pointer to the tensor data
310 /// * `rng` - Random number generator instance
311 ///
312 /// # Safety
313 ///
314 /// The caller must ensure:
315 /// * `ptr` is a valid pointer to tensor data
316 /// * The tensor size matches the allocated memory
317 /// * AVX2 is available on the target architecture
318 /// * `rng` is a valid random number generator instance
319 ///
320 /// # Performance
321 ///
322 /// - **SIMD Operations**: 8 elements per iteration using AVX2
323 /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
324 /// - **Remaining Elements**: Efficient handling of non-multiple-of-8 sizes
325 /// - **Box-Muller Transform**: Integrated normal distribution generation
326 ///
327 /// # Implementation Details
328 ///
329 /// This method uses AVX2 SIMD instructions to fill memory efficiently:
330 /// 1. Generates 8 normal random values using Box-Muller transform
331 /// 2. Loads values into AVX2 vector register using `_mm256_loadu_ps`
332 /// 3. Stores vector to memory using `_mm256_storeu_ps`
333 /// 4. Processes remaining elements with scalar operations
334 ///
335 /// The method provides significant performance improvements for large tensors
336 /// by reducing the number of memory operations and leveraging vectorized
337 /// floating-point operations.
338 #[cfg(target_arch = "x86_64")]
339 #[inline]
340 unsafe fn fill_randn_simd_avx2(&self, ptr: *const f32, rng: &mut XorShiftRng) {
341 let mut_ptr = ptr as *mut f32;
342 let size = self.shape().size;
343 let simd_count = size / 8; // Process 8 elements per iteration
344 let mut offset = 0;
345
346 // SIMD loop for normal distribution generation
347 for _ in 0..simd_count {
348 let mut values = [0.0f32; 8];
349 for i in &mut values {
350 *i = rng.next_normal();
351 }
352
353 // Store 8 values using SIMD
354 let vec = _mm256_loadu_ps(values.as_ptr());
355 _mm256_storeu_ps(mut_ptr.add(offset), vec);
356 offset += 8;
357 }
358
359 // Handle remaining elements
360 for i in offset..size {
361 *mut_ptr.add(i) = rng.next_normal();
362 }
363 }
364
365 /// Fills the tensor with normally distributed random values using scalar operations
366 ///
367 /// Internal fallback method that uses scalar operations to fill tensors with
368 /// normal random values. Provides 4x unrolled loops for better instruction
369 /// throughput and serves as a fallback when SIMD is not available.
370 ///
371 /// # Arguments
372 ///
373 /// * `ptr` - Pointer to the tensor data
374 /// * `rng` - Random number generator instance
375 ///
376 /// # Safety
377 ///
378 /// The caller must ensure:
379 /// * `ptr` is a valid pointer to tensor data
380 /// * The tensor size matches the allocated memory
381 /// * `rng` is a valid random number generator instance
382 ///
383 /// # Performance
384 ///
385 /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
386 /// - **Box-Muller Transform**: Integrated normal distribution generation
387 /// - **Remaining Elements**: Efficient handling of non-multiple-of-4 sizes
388 /// - **Cross-Platform**: Works on all CPU architectures
389 ///
390 /// # Implementation Details
391 ///
392 /// This method provides optimized scalar operations:
393 /// 1. **Unrolled Generation**: Processes 4 elements per iteration
394 /// 2. **Box-Muller Transform**: Generates normal random values
395 /// 3. **Remaining Elements**: Handles final elements individually
396 /// 4. **Cross-Platform**: No architecture-specific dependencies
397 ///
398 /// The 4x unrolling reduces loop overhead and improves instruction-level
399 /// parallelism, making scalar operations more efficient than naive loops.
400 #[inline]
401 unsafe fn fill_randn_scalar(&self, ptr: *const f32, rng: &mut XorShiftRng) {
402 let mut_ptr = ptr as *mut f32;
403 let size = self.shape().size;
404 let unroll_count = size / 4;
405 let mut offset = 0;
406
407 // Unrolled scalar loop for better performance
408 for _ in 0..unroll_count {
409 *mut_ptr.add(offset) = rng.next_normal();
410 *mut_ptr.add(offset + 1) = rng.next_normal();
411 *mut_ptr.add(offset + 2) = rng.next_normal();
412 *mut_ptr.add(offset + 3) = rng.next_normal();
413 offset += 4;
414 }
415
416 // Handle remaining elements
417 for i in offset..size {
418 *mut_ptr.add(i) = rng.next_normal();
419 }
420 }
421}
422
423// SIMD optimizations for performance-critical operations
424#[cfg(target_arch = "x86_64")]
425use std::arch::x86_64::*;
426
/// Fast random number generator using the Xorshift algorithm
///
/// Provides efficient, non-cryptographic random number generation with good
/// statistical properties, plus a Box-Muller based standard-normal generator.
///
/// # Performance
///
/// - **Fast Generation**: three shift/XOR operations per 64-bit output
/// - **Long Period**: 2^64 - 1 for any non-zero state
/// - **Memory Efficient**: single u64 state variable
///
/// # Implementation Details
///
/// Uses the classic xorshift64 shift triple (13, 7, 17). A zero state is a
/// fixed point of the recurrence (it would emit zeros forever), so the
/// constructor remaps a zero seed to a fixed non-zero constant.
struct XorShiftRng {
    /// Internal generator state. Invariant: never zero.
    state: u64,
}

impl XorShiftRng {
    /// Creates a new random number generator with the specified seed
    ///
    /// The same seed always produces the same sequence, making this suitable
    /// for reproducible random number generation.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed value for reproducible generation
    ///
    /// # Implementation Details
    ///
    /// A zero state would make xorshift emit zeros forever, so a seed of 0 is
    /// deterministically remapped to a fixed non-zero constant (the 64-bit
    /// golden-ratio constant). All non-zero seeds are used verbatim.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
        Self { state }
    }

    /// Creates a new random number generator seeded from system time
    ///
    /// Each call produces a different, non-reproducible sequence.
    ///
    /// # Implementation Details
    ///
    /// Hashes `SystemTime::now()` with `DefaultHasher` and uses the hash as
    /// the seed. The result is routed through [`XorShiftRng::new`] so that a
    /// (pathological) zero hash is also handled safely.
    fn new_from_time() -> Self {
        let mut hasher = DefaultHasher::new();
        std::time::SystemTime::now().hash(&mut hasher);
        Self::new(hasher.finish())
    }

    /// Generates the next random u64 value
    ///
    /// Core xorshift64 step that drives all other generation methods.
    ///
    /// # Returns
    ///
    /// A random u64 value from the xorshift sequence
    ///
    /// # Implementation Details
    ///
    /// Applies the classic (13, 7, 17) shift/XOR triple. Deterministic for a
    /// given state; never returns to zero from a non-zero state.
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Generates the next random f32 value uniform in [0, 1)
    ///
    /// # Returns
    ///
    /// A random f32 value in the half-open interval [0, 1)
    ///
    /// # Implementation Details
    ///
    /// Builds an IEEE-754 float in [1, 2) by combining the bit pattern of
    /// 1.0f32 (biased exponent 127, mantissa 0) with 23 random mantissa bits,
    /// then subtracts 1.0 to shift the range to [0, 1). The mantissa is taken
    /// from the *upper* bits of the xorshift output, which have better
    /// statistical quality than the low bits. This bit-level construction
    /// avoids the bias of naive division-based conversion.
    ///
    /// Note: the previous implementation used biased exponent 126, which
    /// produced values in [0.5, 1) and badly skewed the Box-Muller transform
    /// (normal samples were capped at |z| <= sqrt(2 ln 2) ~= 1.18).
    fn next_f32(&mut self) -> f32 {
        let bits = self.next_u64();
        // Top 23 of the 64 output bits become the mantissa.
        let mantissa = (bits >> 41) as u32;
        // 0x3F80_0000 is 1.0f32; OR-ing the mantissa gives a float in [1, 2).
        f32::from_bits(0x3F80_0000 | mantissa) - 1.0
    }

    /// Generates the next normally distributed value via Box-Muller transform
    ///
    /// # Returns
    ///
    /// A random f32 value from the N(0, 1) distribution
    ///
    /// # Implementation Details
    ///
    /// Computes z = sqrt(-2 ln u1) * cos(2π u2) from two uniforms in [0, 1):
    /// 1. u1 is clamped into (0, 1) to avoid ln(0) = -inf
    /// 2. Any residual NaN/infinite result is replaced with 0.0
    ///
    /// These guards ensure numerical stability at the edges of the uniform
    /// range while leaving the bulk of the distribution untouched.
    fn next_normal(&mut self) -> f32 {
        let u1 = self.next_f32();
        let u2 = self.next_f32();

        // Keep u1 strictly inside (0, 1): ln(0) is -inf, ln(1) is 0.
        let u1 = u1.clamp(1e-7, 1.0 - 1e-7);

        // Box-Muller transform.
        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();

        // Guard against numerical edge cases producing NaN or infinity.
        if z0.is_finite() {
            z0
        } else {
            0.0
        }
    }
}
617
#[cfg(test)]
mod tests {
    use super::*;

    // Shape/size bookkeeping plus seed reproducibility on a small tensor.
    #[test]
    fn test_randn_basic() {
        let tensor = Tensor::randn(vec![2, 3], Some(42));
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);

        // With fixed seed, should be reproducible
        let tensor2 = Tensor::randn(vec![2, 3], Some(42));
        for i in 0..tensor.size() {
            unsafe {
                assert!((*tensor.as_ptr().add(i) - *tensor2.as_ptr().add(i)).abs() < 1e-6);
            }
        }
    }

    // Identical seeds must yield element-wise identical tensors.
    #[test]
    fn test_randn_reproducible() {
        let seed = 12345;
        let tensor1 = Tensor::randn(vec![100], Some(seed));
        let tensor2 = Tensor::randn(vec![100], Some(seed));

        // Should be identical with same seed
        for i in 0..tensor1.size() {
            unsafe {
                assert!((*tensor1.as_ptr().add(i) - *tensor2.as_ptr().add(i)).abs() < 1e-6);
            }
        }
    }

    // Distinct seeds must produce at least one differing element.
    #[test]
    fn test_randn_different_seeds() {
        let tensor1 = Tensor::randn(vec![100], Some(1));
        let tensor2 = Tensor::randn(vec![100], Some(2));

        // Should be different with different seeds
        let mut different = false;
        for i in 0..tensor1.size() {
            unsafe {
                if (*tensor1.as_ptr().add(i) - *tensor2.as_ptr().add(i)).abs() > 1e-6 {
                    different = true;
                    break;
                }
            }
        }
        assert!(
            different,
            "Tensors with different seeds should be different"
        );
    }

    // Time-seeded generation: can't pin values, but output must be non-trivial.
    #[test]
    fn test_randn_no_seed() {
        let tensor = Tensor::randn(vec![10], None);
        assert_eq!(tensor.size(), 10);
        assert_eq!(tensor.shape().dims, vec![10]);

        // Should not be all zeros
        let mut has_non_zero = false;
        for i in 0..tensor.size() {
            unsafe {
                if *tensor.as_ptr().add(i) != 0.0 {
                    has_non_zero = true;
                    break;
                }
            }
        }
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }

    // Zero-sized tensors must be created without touching any elements.
    #[test]
    fn test_randn_zero_sized() {
        let tensor = Tensor::randn(vec![0], Some(42));
        assert_eq!(tensor.size(), 0);
        assert_eq!(tensor.shape().dims, vec![0]);
    }

    // Coarse statistical sanity check on a 10k-element tensor: mean near 0
    // and all samples within (-4, 4) standard deviations.
    #[test]
    fn test_randn_large_tensor() {
        let tensor = Tensor::randn(vec![100, 100], Some(42));
        assert_eq!(tensor.size(), 10000);

        // Check that values are reasonable (within 4 standard deviations)
        let mut min_val = f32::INFINITY;
        let mut max_val = f32::NEG_INFINITY;
        let mut sum = 0.0;

        for i in 0..tensor.size() {
            unsafe {
                let val = *tensor.as_ptr().add(i);
                min_val = min_val.min(val);
                max_val = max_val.max(val);
                sum += val;
            }
        }

        let mean = sum / tensor.size() as f32;

        // Mean should be close to 0, values should be within reasonable bounds
        assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
        assert!(
            min_val > -4.0,
            "Values should not be too negative, min: {}",
            min_val
        );
        assert!(
            max_val < 4.0,
            "Values should not be too positive, max: {}",
            max_val
        );
    }

    // In-place fill on a pre-allocated tensor must produce non-zero data.
    #[test]
    fn test_fill_randn() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.fill_randn(Some(42));

        assert_eq!(tensor.size(), 6);

        // Should not be all zeros
        let mut has_non_zero = false;
        for i in 0..tensor.size() {
            unsafe {
                if *tensor.as_ptr().add(i) != 0.0 {
                    has_non_zero = true;
                    break;
                }
            }
        }
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }
}