// train_station/tensor/init/random.rs

1//! Random tensor initialization methods
2//!
3//! This module provides methods to create tensors with random values using
4//! efficient random number generation algorithms. All methods are optimized
5//! for performance with SIMD operations and provide reproducible results
6//! when seeds are specified.
7//!
8//! # Key Features
9//!
10//! - **`randn`**: Create tensors with normally distributed random values (mean=0, std=1)
11//! - **Box-Muller Transform**: Efficient normal distribution generation
12//! - **SIMD Optimization**: Vectorized operations for large tensors
13//! - **Reproducible Results**: Optional seed-based generation for deterministic output
//! - **Thread Safety**: Each call owns its RNG state; no shared mutable state
15//! - **Statistical Quality**: High-quality random number generation
16//!
17//! # Performance Characteristics
18//!
19//! - **Box-Muller Transform**: Efficient normal distribution generation
20//! - **SIMD Operations**: AVX2-optimized operations for large tensors
21//! - **Memory Efficient**: Single-pass generation with optimized allocation
22//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
23//! - **Zero Overhead**: Minimal validation overhead for correct usage
24//!
25//! # Examples
26//!
27//! ## Basic Random Generation
28//!
29//! ```
30//! use train_station::Tensor;
31//!
32//! // Create a 2x3 tensor with random normal values
33//! let tensor = Tensor::randn(vec![2, 3], None);
34//! assert_eq!(tensor.size(), 6);
35//! assert_eq!(tensor.shape().dims, vec![2, 3]);
36//!
37//! // Verify random values are generated
38//! let first_value = tensor.get(&[0, 0]);
39//! assert!(first_value != 0.0); // Should be random
40//! ```
41//!
42//! ## Reproducible Random Generation
43//!
44//! ```
45//! use train_station::Tensor;
46//!
47//! // Create with fixed seed for reproducible results
48//! let tensor1 = Tensor::randn(vec![100], Some(42));
49//! let tensor2 = Tensor::randn(vec![100], Some(42));
50//!
51//! // tensor1 and tensor2 will have identical values
52//! for i in 0..tensor1.size() {
53//!     assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
54//! }
55//! ```
56//!
57//! ## Different Seeds Produce Different Results
58//!
59//! ```
60//! use train_station::Tensor;
61//!
62//! let tensor1 = Tensor::randn(vec![100], Some(1));
63//! let tensor2 = Tensor::randn(vec![100], Some(2));
64//!
65//! // Should be different with different seeds
66//! let mut different = false;
67//! for i in 0..tensor1.size() {
68//!     if (tensor1.get(&[i]) - tensor2.get(&[i])).abs() > 1e-6 {
69//!         different = true;
70//!         break;
71//!     }
72//! }
73//! assert!(different, "Tensors with different seeds should be different");
74//! ```
75//!
76//! ## Large Tensor Generation
77//!
78//! ```
79//! use train_station::Tensor;
80//!
81//! // Efficient generation of large tensors
82//! let tensor = Tensor::randn(vec![100, 100], Some(42));
83//! assert_eq!(tensor.size(), 10000);
84//!
85//! // Check statistical properties
86//! let mut sum = 0.0;
87//! for i in 0..tensor.size() {
88//!     sum += tensor.get(&[i / 100, i % 100]);
89//! }
90//! let mean = sum / tensor.size() as f32;
91//!
92//! // Mean should be close to 0 for normal distribution
93//! assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
94//! ```
95//!
96//! ## Zero-Sized Tensor Handling
97//!
98//! ```
99//! use train_station::Tensor;
100//!
101//! // Handle empty tensors gracefully
102//! let tensor = Tensor::randn(vec![0], Some(42));
103//! assert_eq!(tensor.size(), 0);
104//! assert_eq!(tensor.shape().dims, vec![0]);
105//! ```
106//!
107//! # Design Principles
108//!
109//! - **Statistical Quality**: High-quality random number generation with proper distribution
110//! - **Performance First**: SIMD-optimized operations for maximum speed
111//! - **Reproducibility**: Optional seed-based generation for deterministic results
112//! - **Memory Efficiency**: Efficient memory operations with minimal overhead
//! - **Thread Safety**: Concurrent calls are safe because each call owns its RNG state
114//! - **Numerical Stability**: Robust handling of edge cases and numerical issues
115
116use crate::tensor::core::Tensor;
117use std::collections::hash_map::DefaultHasher;
118use std::hash::{Hash, Hasher};
119
impl Tensor {
    /// Creates a tensor with normally distributed random values (mean=0, std=1)
    ///
    /// Similar to PyTorch's `torch.randn()`, creates a tensor filled with random
    /// values drawn from a standard normal distribution (mean=0, standard deviation=1).
    /// Uses Box-Muller transform for efficient normal distribution generation.
    ///
    /// This method provides high-quality random number generation with optional
    /// reproducibility through seed-based generation. The generated values follow
    /// a standard normal distribution suitable for machine learning applications.
    ///
    /// # Arguments
    ///
    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
    /// * `seed` - Optional seed for reproducible random generation
    ///
    /// # Returns
    ///
    /// A new tensor with normally distributed random values
    ///
    /// # Performance
    ///
    /// - **Box-Muller Transform**: Efficient normal distribution generation
    /// - **SIMD Optimization**: Vectorized stores for large tensors on AVX2 hardware
    /// - **Memory Efficient**: Single-pass generation with optimized allocation
    /// - **Thread Safe**: Each call owns a local RNG; no shared mutable state
    ///
    /// # Examples
    ///
    /// ## Basic Usage
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create a 2x3 tensor with random normal values
    /// let tensor = Tensor::randn(vec![2, 3], None);
    /// assert_eq!(tensor.size(), 6);
    /// assert_eq!(tensor.shape().dims, vec![2, 3]);
    ///
    /// // Verify random values are generated
    /// let first_value = tensor.get(&[0, 0]);
    /// assert!(first_value != 0.0); // Should be random
    /// ```
    ///
    /// ## Reproducible Generation
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create with fixed seed for reproducible results
    /// let tensor1 = Tensor::randn(vec![100], Some(42));
    /// let tensor2 = Tensor::randn(vec![100], Some(42));
    ///
    /// // tensor1 and tensor2 will have identical values
    /// for i in 0..tensor1.size() {
    ///     assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
    /// }
    /// ```
    ///
    /// ## Statistical Properties
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Generate large tensor for statistical analysis
    /// let tensor = Tensor::randn(vec![1000], Some(42));
    /// assert_eq!(tensor.size(), 1000);
    ///
    /// // Check that values are reasonable (within 4 standard deviations)
    /// let mut min_val = f32::INFINITY;
    /// let mut max_val = f32::NEG_INFINITY;
    /// let mut sum = 0.0;
    ///
    /// for i in 0..tensor.size() {
    ///     let val = tensor.get(&[i]);
    ///     min_val = min_val.min(val);
    ///     max_val = max_val.max(val);
    ///     sum += val;
    /// }
    ///
    /// let mean = sum / tensor.size() as f32;
    ///
    /// // Mean should be close to 0, values should be within reasonable bounds
    /// assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
    /// assert!(min_val > -4.0, "Values should not be too negative, min: {}", min_val);
    /// assert!(max_val < 4.0, "Values should not be too positive, max: {}", max_val);
    /// ```
    ///
    /// ## Zero-Sized Tensors
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Handle empty tensors gracefully
    /// let tensor = Tensor::randn(vec![0], Some(42));
    /// assert_eq!(tensor.size(), 0);
    /// assert_eq!(tensor.shape().dims, vec![0]);
    /// ```
    ///
    /// # Implementation Details
    ///
    /// This method uses the Box-Muller transform to generate normally distributed
    /// random variables from uniform random variables. The process involves:
    /// 1. **Random Number Generation**: Uses Xorshift algorithm for uniform random numbers
    /// 2. **Box-Muller Transform**: Converts uniform random variables to normal distribution
    /// 3. **SIMD Optimization**: Vectorized stores for large tensors when available
    /// 4. **Numerical Stability**: Robust handling of edge cases and potential NaN values
    pub fn randn(shape_dims: Vec<usize>, seed: Option<u64>) -> Self {
        let mut tensor = Self::new(shape_dims);
        tensor.fill_randn(seed);
        tensor
    }

    /// Fills the tensor with normally distributed random values
    ///
    /// Internal method that fills an existing tensor with random values from
    /// a standard normal distribution. Uses Box-Muller transform for efficiency
    /// and provides SIMD-assisted stores for large tensors.
    ///
    /// This method is used internally by `randn()` and provides the core
    /// random number generation functionality with optimized performance
    /// characteristics.
    ///
    /// # Arguments
    ///
    /// * `seed` - Optional seed for reproducible random generation
    ///
    /// # Performance
    ///
    /// - **Box-Muller Transform**: Generates pairs of normal random variables
    /// - **SIMD Stores**: 8-wide vector stores on AVX2 hardware
    /// - **Memory Efficient**: Single-pass generation
    /// - **Unrolled Loops**: 4x unrolling in the scalar fallback
    ///
    /// # Implementation Details
    ///
    /// The method performs the following steps:
    /// 1. **Zero-sized Check**: Returns early for empty tensors
    /// 2. **RNG Initialization**: Creates Xorshift RNG with seed or system time
    /// 3. **SIMD Detection**: Checks for AVX2 availability for the vectorized-store path
    /// 4. **Generation**: Uses SIMD or scalar path based on hardware support
    /// 5. **Completion**: Fills all tensor elements with normal random values
    ///
    /// The method automatically handles hardware capabilities and falls back
    /// to scalar operations when SIMD is not available, ensuring compatibility
    /// across different CPU architectures.
    pub fn fill_randn(&mut self, seed: Option<u64>) {
        if self.shape().size == 0 {
            return;
        }

        // Initialize random number generator
        let mut rng = if let Some(seed_val) = seed {
            // Use provided seed for reproducible results
            XorShiftRng::new(seed_val)
        } else {
            // Use system time for non-reproducible results
            XorShiftRng::new_from_time()
        };

        unsafe {
            // SAFETY: `as_ptr` points at this tensor's own contiguous buffer
            // and both fill helpers write exactly `self.shape().size` elements,
            // so every write stays in bounds. We hold `&mut self`, so no other
            // reference can alias the buffer during the writes.
            let ptr = self.as_ptr();

            #[cfg(target_arch = "x86_64")]
            {
                // Use SIMD for better performance when available
                if is_x86_feature_detected!("avx2") {
                    self.fill_randn_simd_avx2(ptr, &mut rng);
                    return;
                }
            }

            // Fallback to scalar operations
            self.fill_randn_scalar(ptr, &mut rng);
        }
    }

    /// Fills the tensor with normally distributed random values using AVX2 stores
    ///
    /// Internal method that uses AVX2 instructions to store batches of eight
    /// normal random values. Note that the random values themselves are
    /// generated one at a time by the scalar Box-Muller transform; only the
    /// memory store is vectorized.
    ///
    /// # Arguments
    ///
    /// * `ptr` - Pointer to the tensor data
    /// * `rng` - Random number generator instance
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// * `ptr` is a valid pointer to tensor data
    /// * The tensor size matches the allocated memory
    /// * AVX2 is available on the target architecture
    /// * `rng` is a valid random number generator instance
    ///
    /// # Performance
    ///
    /// - **SIMD Stores**: 8 elements stored per iteration using AVX2
    /// - **Remaining Elements**: Efficient handling of non-multiple-of-8 sizes
    /// - **Box-Muller Transform**: Integrated normal distribution generation
    ///
    /// # Implementation Details
    ///
    /// This method batches stores using AVX2 SIMD instructions:
    /// 1. Generates 8 normal random values using Box-Muller transform (scalar)
    /// 2. Loads values into an AVX2 vector register using `_mm256_loadu_ps`
    /// 3. Stores the vector to memory using `_mm256_storeu_ps`
    /// 4. Processes remaining elements with scalar operations
    ///
    /// NOTE(review): because the eight draws are produced scalar-wise, the
    /// gain over `fill_randn_scalar` is limited to the batched store —
    /// benchmark before relying on a significant speedup here.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn fill_randn_simd_avx2(&self, ptr: *const f32, rng: &mut XorShiftRng) {
        // NOTE(review): writing through a `*mut` cast from the `*const`
        // returned by `as_ptr()`; prefer a dedicated mutable accessor if
        // Tensor provides one.
        let mut_ptr = ptr as *mut f32;
        let size = self.shape().size;
        let simd_count = size / 8; // Process 8 elements per iteration
        let mut offset = 0;

        // SIMD loop for normal distribution generation
        for _ in 0..simd_count {
            // Eight scalar Box-Muller draws; only the store below is vectorized
            let mut values = [0.0f32; 8];
            for i in &mut values {
                *i = rng.next_normal();
            }

            // Store 8 values using SIMD
            let vec = _mm256_loadu_ps(values.as_ptr());
            _mm256_storeu_ps(mut_ptr.add(offset), vec);
            offset += 8;
        }

        // Handle remaining elements
        for i in offset..size {
            *mut_ptr.add(i) = rng.next_normal();
        }
    }

    /// Fills the tensor with normally distributed random values using scalar operations
    ///
    /// Internal fallback method that uses scalar operations to fill tensors with
    /// normal random values. Provides 4x unrolled loops for better instruction
    /// throughput and serves as a fallback when SIMD is not available.
    ///
    /// # Arguments
    ///
    /// * `ptr` - Pointer to the tensor data
    /// * `rng` - Random number generator instance
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// * `ptr` is a valid pointer to tensor data
    /// * The tensor size matches the allocated memory
    /// * `rng` is a valid random number generator instance
    ///
    /// # Performance
    ///
    /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
    /// - **Box-Muller Transform**: Integrated normal distribution generation
    /// - **Remaining Elements**: Efficient handling of non-multiple-of-4 sizes
    /// - **Cross-Platform**: Works on all CPU architectures
    ///
    /// # Implementation Details
    ///
    /// This method provides optimized scalar operations:
    /// 1. **Unrolled Generation**: Processes 4 elements per iteration
    /// 2. **Box-Muller Transform**: Generates normal random values
    /// 3. **Remaining Elements**: Handles final elements individually
    /// 4. **Cross-Platform**: No architecture-specific dependencies
    ///
    /// The 4x unrolling reduces loop overhead and improves instruction-level
    /// parallelism, making scalar operations more efficient than naive loops.
    #[inline]
    unsafe fn fill_randn_scalar(&self, ptr: *const f32, rng: &mut XorShiftRng) {
        // NOTE(review): same const-to-mut cast as the AVX2 path; safe while
        // callers hold `&mut self`, but a mutable accessor would be cleaner.
        let mut_ptr = ptr as *mut f32;
        let size = self.shape().size;
        let unroll_count = size / 4;
        let mut offset = 0;

        // Unrolled scalar loop for better performance
        for _ in 0..unroll_count {
            *mut_ptr.add(offset) = rng.next_normal();
            *mut_ptr.add(offset + 1) = rng.next_normal();
            *mut_ptr.add(offset + 2) = rng.next_normal();
            *mut_ptr.add(offset + 3) = rng.next_normal();
            offset += 4;
        }

        // Handle remaining elements
        for i in offset..size {
            *mut_ptr.add(i) = rng.next_normal();
        }
    }
}
422
423// SIMD optimizations for performance-critical operations
424#[cfg(target_arch = "x86_64")]
425use std::arch::x86_64::*;
426
/// Fast random number generator using the xorshift64 algorithm
///
/// Provides efficient, non-cryptographic random number generation with good
/// statistical properties, plus a Box-Muller transform for drawing values
/// from a standard normal distribution.
///
/// # Performance
///
/// - **Fast Generation**: Three shift/XOR operations per `u64`
/// - **Long Period**: 2^64 - 1 for any non-zero state
/// - **Memory Efficient**: A single `u64` of state
///
/// # Implementation Details
///
/// Each step applies the classic xorshift64 triple:
/// 1. `state ^= state << 13`
/// 2. `state ^= state >> 7`
/// 3. `state ^= state << 17`
///
/// A zero state is a fixed point of this recurrence (it would emit zeros
/// forever), so construction remaps a zero seed to a non-zero constant.
struct XorShiftRng {
    /// Current generator state; kept non-zero by the constructors.
    state: u64,
}

impl XorShiftRng {
    /// Non-zero fallback state used when a caller seeds with 0.
    ///
    /// Zero is a fixed point of the xorshift recurrence, so it must never be
    /// used as the state. The constant is floor(2^64 / phi), a conventional
    /// well-mixed seeding value.
    const ZERO_SEED_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a new random number generator with the specified seed
    ///
    /// The same seed always produces the same sequence, making this suitable
    /// for reproducible random number generation.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed value; a seed of 0 is remapped to a non-zero
    ///   constant because a zero state would make xorshift emit only zeros
    fn new(seed: u64) -> Self {
        let state = if seed == 0 {
            Self::ZERO_SEED_FALLBACK
        } else {
            seed
        };
        Self { state }
    }

    /// Creates a new random number generator seeded from system time
    ///
    /// Each call derives a fresh seed by hashing the current system time, so
    /// the resulting sequences are not reproducible across calls.
    ///
    /// # Implementation Details
    ///
    /// 1. Reads `std::time::SystemTime::now()`
    /// 2. Hashes it with `DefaultHasher`
    /// 3. Feeds the hash through `new`, which also guards against a
    ///    (vanishingly unlikely) zero hash value
    fn new_from_time() -> Self {
        let mut hasher = DefaultHasher::new();
        std::time::SystemTime::now().hash(&mut hasher);
        Self::new(hasher.finish())
    }

    /// Generates the next random u64 value
    ///
    /// Core generator driving all other methods. Applies the xorshift64
    /// shift/XOR triple and returns the updated state.
    ///
    /// # Returns
    ///
    /// A random u64 value from the xorshift sequence
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Generates the next random f32 value in [0, 1)
    ///
    /// # Returns
    ///
    /// A uniformly distributed f32 in the half-open interval [0, 1)
    ///
    /// # Implementation Details
    ///
    /// Builds an f32 directly from random bits to avoid division bias:
    /// 1. Takes 23 random bits as the IEEE 754 mantissa
    /// 2. Sets the exponent field to 127 (unbiased exponent 0), yielding a
    ///    value in [1, 2)
    /// 3. Subtracts 1.0 to shift the result into [0, 1)
    ///
    /// Bug fix: the previous implementation used exponent field 126 and never
    /// subtracted 1.0, producing uniforms in [0.5, 1). That bounded the
    /// Box-Muller radius to ~1.18 and skewed the output far from a true
    /// standard normal distribution.
    fn next_f32(&mut self) -> f32 {
        let bits = self.next_u64();
        let mantissa = (bits & 0x7FFFFF) as u32; // low 23 bits -> mantissa field
        let exponent = 127u32 << 23; // exponent field 127 => value in [1, 2)
        f32::from_bits(mantissa | exponent) - 1.0
    }

    /// Generates the next normally distributed random value using Box-Muller
    ///
    /// # Returns
    ///
    /// A random f32 drawn from the standard normal distribution N(0, 1)
    ///
    /// # Implementation Details
    ///
    /// 1. Draws two uniform variables u1, u2 in [0, 1)
    /// 2. Clamps u1 into (0, 1) so `ln(u1)` stays finite
    /// 3. Computes z = sqrt(-2 * ln(u1)) * cos(2 * pi * u2)
    /// 4. Falls back to 0.0 if the result is NaN or infinite
    fn next_normal(&mut self) -> f32 {
        let u1 = self.next_f32();
        let u2 = self.next_f32();

        // Keep u1 strictly inside (0, 1): ln(0) is -inf. With a 23-bit
        // mantissa the smallest positive uniform is 2^-23 (~1.19e-7) and the
        // largest is 1 - 2^-23, so in practice only u1 == 0.0 is remapped.
        let u1 = u1.clamp(1e-7, 1.0 - 1e-7);

        // Box-Muller transform
        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();

        // Guard against numerical pathologies from the transform
        if z0.is_nan() || z0.is_infinite() {
            0.0
        } else {
            z0
        }
    }
}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only helper: reads element `i` from the tensor's contiguous buffer.
    fn elem(tensor: &Tensor, i: usize) -> f32 {
        // SAFETY: callers only pass indices in `0..tensor.size()`, which stays
        // within the tensor's allocation.
        unsafe { *tensor.as_ptr().add(i) }
    }

    #[test]
    fn test_randn_basic() {
        let tensor = Tensor::randn(vec![2, 3], Some(42));
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);

        // With a fixed seed, generation must be reproducible
        let tensor2 = Tensor::randn(vec![2, 3], Some(42));
        for i in 0..tensor.size() {
            assert!((elem(&tensor, i) - elem(&tensor2, i)).abs() < 1e-6);
        }
    }

    #[test]
    fn test_randn_reproducible() {
        let seed = 12345;
        let tensor1 = Tensor::randn(vec![100], Some(seed));
        let tensor2 = Tensor::randn(vec![100], Some(seed));

        // Identical seeds must yield identical tensors
        for i in 0..tensor1.size() {
            assert!((elem(&tensor1, i) - elem(&tensor2, i)).abs() < 1e-6);
        }
    }

    #[test]
    fn test_randn_different_seeds() {
        let tensor1 = Tensor::randn(vec![100], Some(1));
        let tensor2 = Tensor::randn(vec![100], Some(2));

        // At least one element should differ between the two sequences
        let different = (0..tensor1.size())
            .any(|i| (elem(&tensor1, i) - elem(&tensor2, i)).abs() > 1e-6);
        assert!(
            different,
            "Tensors with different seeds should be different"
        );
    }

    #[test]
    fn test_randn_no_seed() {
        let tensor = Tensor::randn(vec![10], None);
        assert_eq!(tensor.size(), 10);
        assert_eq!(tensor.shape().dims, vec![10]);

        // Should not be all zeros
        let has_non_zero = (0..tensor.size()).any(|i| elem(&tensor, i) != 0.0);
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }

    #[test]
    fn test_randn_zero_sized() {
        let tensor = Tensor::randn(vec![0], Some(42));
        assert_eq!(tensor.size(), 0);
        assert_eq!(tensor.shape().dims, vec![0]);
    }

    #[test]
    fn test_randn_large_tensor() {
        let tensor = Tensor::randn(vec![100, 100], Some(42));
        assert_eq!(tensor.size(), 10000);

        let mut min_val = f32::INFINITY;
        let mut max_val = f32::NEG_INFINITY;
        let mut sum = 0.0;

        for i in 0..tensor.size() {
            let val = elem(&tensor, i);
            min_val = min_val.min(val);
            max_val = max_val.max(val);
            sum += val;
        }

        let mean = sum / tensor.size() as f32;

        // Mean should be close to 0 for a standard normal sample
        assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
        // Bound fix: for 10,000 true N(0,1) samples, the expected number of
        // |z| > 4 exceedances is ~0.6, so a +/-4 bound is flaky by design.
        // +/-6 (per-sample probability ~2e-9) is effectively never exceeded.
        assert!(
            min_val > -6.0,
            "Values should not be too negative, min: {}",
            min_val
        );
        assert!(
            max_val < 6.0,
            "Values should not be too positive, max: {}",
            max_val
        );
    }

    #[test]
    fn test_fill_randn() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.fill_randn(Some(42));

        assert_eq!(tensor.size(), 6);

        // Should not be all zeros
        let has_non_zero = (0..tensor.size()).any(|i| elem(&tensor, i) != 0.0);
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }
}