// train_station/tensor/init/random.rs

1//! Random tensor initialization methods
2//!
3//! This module provides methods to create tensors with random values using
4//! efficient random number generation algorithms. All methods are optimized
5//! for performance with SIMD operations and provide reproducible results
6//! when seeds are specified.
7//!
8//! # Key Features
9//!
10//! - **`randn`**: Create tensors with normally distributed random values (mean=0, std=1)
11//! - **Box-Muller Transform**: Efficient normal distribution generation
12//! - **SIMD Optimization**: Vectorized operations for large tensors
13//! - **Reproducible Results**: Optional seed-based generation for deterministic output
//! - **Thread Safety**: Each call owns its RNG state as a local variable; no shared or thread-local state
15//! - **Statistical Quality**: High-quality random number generation
16//!
17//! # Performance Characteristics
18//!
19//! - **Box-Muller Transform**: Efficient normal distribution generation
20//! - **SIMD Operations**: AVX2-optimized operations for large tensors
21//! - **Memory Efficient**: Single-pass generation with optimized allocation
22//! - **Unrolled Loops**: 4x unrolling for better instruction throughput
23//! - **Zero Overhead**: Minimal validation overhead for correct usage
24//!
25//! # Examples
26//!
27//! ## Basic Random Generation
28//!
29//! ```
30//! use train_station::Tensor;
31//!
32//! // Create a 2x3 tensor with random normal values
33//! let tensor = Tensor::randn(vec![2, 3], None);
34//! assert_eq!(tensor.size(), 6);
35//! assert_eq!(tensor.shape().dims, vec![2, 3]);
36//!
37//! // Verify random values are generated
38//! let first_value = tensor.get(&[0, 0]);
39//! assert!(first_value != 0.0); // Should be random
40//! ```
41//!
42//! ## Reproducible Random Generation
43//!
44//! ```
45//! use train_station::Tensor;
46//!
47//! // Create with fixed seed for reproducible results
48//! let tensor1 = Tensor::randn(vec![100], Some(42));
49//! let tensor2 = Tensor::randn(vec![100], Some(42));
50//!
51//! // tensor1 and tensor2 will have identical values
52//! for i in 0..tensor1.size() {
53//!     assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
54//! }
55//! ```
56//!
57//! ## Different Seeds Produce Different Results
58//!
59//! ```
60//! use train_station::Tensor;
61//!
62//! let tensor1 = Tensor::randn(vec![100], Some(1));
63//! let tensor2 = Tensor::randn(vec![100], Some(2));
64//!
65//! // Should be different with different seeds
66//! let mut different = false;
67//! for i in 0..tensor1.size() {
68//!     if (tensor1.get(&[i]) - tensor2.get(&[i])).abs() > 1e-6 {
69//!         different = true;
70//!         break;
71//!     }
72//! }
73//! assert!(different, "Tensors with different seeds should be different");
74//! ```
75//!
76//! ## Large Tensor Generation
77//!
78//! ```
79//! use train_station::Tensor;
80//!
81//! // Efficient generation of large tensors
82//! let tensor = Tensor::randn(vec![100, 100], Some(42));
83//! assert_eq!(tensor.size(), 10000);
84//!
85//! // Check statistical properties
86//! let mut sum = 0.0;
87//! for i in 0..tensor.size() {
88//!     sum += tensor.get(&[i / 100, i % 100]);
89//! }
90//! let mean = sum / tensor.size() as f32;
91//!
92//! // Mean should be close to 0 for normal distribution
93//! assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
94//! ```
95//!
96//! ## Zero-Sized Tensor Handling
97//!
98//! ```
99//! use train_station::Tensor;
100//!
101//! // Handle empty tensors gracefully
102//! let tensor = Tensor::randn(vec![0], Some(42));
103//! assert_eq!(tensor.size(), 0);
104//! assert_eq!(tensor.shape().dims, vec![0]);
105//! ```
106//!
107//! # Design Principles
108//!
109//! - **Statistical Quality**: High-quality random number generation with proper distribution
110//! - **Performance First**: SIMD-optimized operations for maximum speed
111//! - **Reproducibility**: Optional seed-based generation for deterministic results
112//! - **Memory Efficiency**: Efficient memory operations with minimal overhead
//! - **Thread Safety**: Safe concurrent use; every call constructs and owns its generator state
114//! - **Numerical Stability**: Robust handling of edge cases and numerical issues
115
116use crate::tensor::core::Tensor;
117use std::collections::hash_map::DefaultHasher;
118use std::hash::{Hash, Hasher};
119
120impl Tensor {
121    /// Creates a tensor with normally distributed random values (mean=0, std=1)
122    ///
123    /// Similar to PyTorch's `torch.randn()`, creates a tensor filled with random
124    /// values drawn from a standard normal distribution (mean=0, standard deviation=1).
125    /// Uses Box-Muller transform for efficient normal distribution generation.
126    ///
127    /// This method provides high-quality random number generation with optional
128    /// reproducibility through seed-based generation. The generated values follow
129    /// a standard normal distribution suitable for machine learning applications.
130    ///
131    /// # Arguments
132    ///
133    /// * `shape_dims` - Vector of dimension sizes defining the tensor shape
134    /// * `seed` - Optional seed for reproducible random generation
135    ///
136    /// # Returns
137    ///
138    /// A new tensor with normally distributed random values
139    ///
140    /// # Performance
141    ///
142    /// - **Box-Muller Transform**: Efficient normal distribution generation
143    /// - **SIMD Optimization**: Vectorized operations for large tensors
144    /// - **Memory Efficient**: Single-pass generation with optimized allocation
145    /// - **Thread Safe**: Uses thread-local random state
146    ///
147    /// # Examples
148    ///
149    /// ## Basic Usage
150    ///
151    /// ```
152    /// use train_station::Tensor;
153    ///
154    /// // Create a 2x3 tensor with random normal values
155    /// let tensor = Tensor::randn(vec![2, 3], None);
156    /// assert_eq!(tensor.size(), 6);
157    /// assert_eq!(tensor.shape().dims, vec![2, 3]);
158    ///
159    /// // Verify random values are generated
160    /// let first_value = tensor.get(&[0, 0]);
161    /// assert!(first_value != 0.0); // Should be random
162    /// ```
163    ///
164    /// ## Reproducible Generation
165    ///
166    /// ```
167    /// use train_station::Tensor;
168    ///
169    /// // Create with fixed seed for reproducible results
170    /// let tensor1 = Tensor::randn(vec![100], Some(42));
171    /// let tensor2 = Tensor::randn(vec![100], Some(42));
172    ///
173    /// // tensor1 and tensor2 will have identical values
174    /// for i in 0..tensor1.size() {
175    ///     assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
176    /// }
177    /// ```
178    ///
179    /// ## Statistical Properties
180    ///
181    /// ```
182    /// use train_station::Tensor;
183    ///
184    /// // Generate large tensor for statistical analysis
185    /// let tensor = Tensor::randn(vec![1000], Some(42));
186    /// assert_eq!(tensor.size(), 1000);
187    ///
188    /// // Check that values are reasonable (within 4 standard deviations)
189    /// let mut min_val = f32::INFINITY;
190    /// let mut max_val = f32::NEG_INFINITY;
191    /// let mut sum = 0.0;
192    ///
193    /// for i in 0..tensor.size() {
194    ///     let val = tensor.get(&[i]);
195    ///     min_val = min_val.min(val);
196    ///     max_val = max_val.max(val);
197    ///     sum += val;
198    /// }
199    ///
200    /// let mean = sum / tensor.size() as f32;
201    ///
202    /// // Mean should be close to 0, values should be within reasonable bounds
203    /// assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
204    /// assert!(min_val > -4.0, "Values should not be too negative, min: {}", min_val);
205    /// assert!(max_val < 4.0, "Values should not be too positive, max: {}", max_val);
206    /// ```
207    ///
208    /// ## Zero-Sized Tensors
209    ///
210    /// ```
211    /// use train_station::Tensor;
212    ///
213    /// // Handle empty tensors gracefully
214    /// let tensor = Tensor::randn(vec![0], Some(42));
215    /// assert_eq!(tensor.size(), 0);
216    /// assert_eq!(tensor.shape().dims, vec![0]);
217    /// ```
218    ///
219    /// # Implementation Details
220    ///
221    /// This method uses the Box-Muller transform to generate normally distributed
222    /// random variables from uniform random variables. The process involves:
223    /// 1. **Random Number Generation**: Uses Xorshift algorithm for uniform random numbers
224    /// 2. **Box-Muller Transform**: Converts uniform random variables to normal distribution
225    /// 3. **SIMD Optimization**: Vectorized operations for large tensors when available
226    /// 4. **Numerical Stability**: Robust handling of edge cases and potential NaN values
227    ///
228    /// The Box-Muller transform ensures that the generated values follow a true
229    /// normal distribution with mean=0 and standard deviation=1, making it suitable
230    /// for machine learning applications requiring normally distributed random values.
231    #[track_caller]
232    pub fn randn(shape_dims: Vec<usize>, seed: Option<u64>) -> Self {
233        let mut tensor = Self::new(shape_dims);
234        tensor.fill_randn(seed);
235        tensor
236    }
237
238    /// Fills the tensor with normally distributed random values
239    ///
240    /// Internal method that fills an existing tensor with random values from
241    /// a standard normal distribution. Uses Box-Muller transform for efficiency
242    /// and provides SIMD optimization for large tensors.
243    ///
244    /// This method is used internally by `randn()` and provides the core
245    /// random number generation functionality with optimized performance
246    /// characteristics.
247    ///
248    /// # Arguments
249    ///
250    /// * `seed` - Optional seed for reproducible random generation
251    ///
252    /// # Performance
253    ///
254    /// - **Box-Muller Transform**: Generates pairs of normal random variables
255    /// - **SIMD Optimization**: Vectorized operations when possible
256    /// - **Memory Efficient**: Single-pass generation
257    /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
258    ///
259    /// # Implementation Details
260    ///
261    /// The method performs the following steps:
262    /// 1. **Zero-sized Check**: Returns early for empty tensors
263    /// 2. **RNG Initialization**: Creates Xorshift RNG with seed or system time
264    /// 3. **SIMD Detection**: Checks for AVX2 availability for optimized path
265    /// 4. **Generation**: Uses SIMD or scalar path based on hardware support
266    /// 5. **Completion**: Fills all tensor elements with normal random values
267    ///
268    /// The method automatically handles hardware capabilities and falls back
269    /// to scalar operations when SIMD is not available, ensuring compatibility
270    /// across different CPU architectures.
271    #[track_caller]
272    pub fn fill_randn(&mut self, seed: Option<u64>) {
273        if self.shape().size == 0 {
274            return;
275        }
276
277        // Initialize random number generator
278        let mut rng = if let Some(seed_val) = seed {
279            // Use provided seed for reproducible results
280            XorShiftRng::new(seed_val)
281        } else {
282            // Use system time for non-reproducible results
283            XorShiftRng::new_from_time()
284        };
285
286        unsafe {
287            let ptr = self.as_ptr();
288
289            #[cfg(target_arch = "x86_64")]
290            {
291                // Use SIMD for better performance when available
292                if is_x86_feature_detected!("avx2") {
293                    self.fill_randn_simd_avx2(ptr, &mut rng);
294                    return;
295                }
296            }
297
298            // Fallback to scalar operations
299            self.fill_randn_scalar(ptr, &mut rng);
300        }
301    }
302
303    /// Fills the tensor with normally distributed random values using AVX2 SIMD
304    ///
305    /// Internal method that uses AVX2 instructions to efficiently fill large tensors
306    /// with normal random values. Processes 8 elements per iteration for maximum
307    /// memory bandwidth utilization.
308    ///
309    /// # Arguments
310    ///
311    /// * `ptr` - Pointer to the tensor data
312    /// * `rng` - Random number generator instance
313    ///
314    /// # Safety
315    ///
316    /// The caller must ensure:
317    /// * `ptr` is a valid pointer to tensor data
318    /// * The tensor size matches the allocated memory
319    /// * AVX2 is available on the target architecture
320    /// * `rng` is a valid random number generator instance
321    ///
322    /// # Performance
323    ///
324    /// - **SIMD Operations**: 8 elements per iteration using AVX2
325    /// - **Memory Bandwidth**: Optimized for maximum memory bandwidth utilization
326    /// - **Remaining Elements**: Efficient handling of non-multiple-of-8 sizes
327    /// - **Box-Muller Transform**: Integrated normal distribution generation
328    ///
329    /// # Implementation Details
330    ///
331    /// This method uses AVX2 SIMD instructions to fill memory efficiently:
332    /// 1. Generates 8 normal random values using Box-Muller transform
333    /// 2. Loads values into AVX2 vector register using `_mm256_loadu_ps`
334    /// 3. Stores vector to memory using `_mm256_storeu_ps`
335    /// 4. Processes remaining elements with scalar operations
336    ///
337    /// The method provides significant performance improvements for large tensors
338    /// by reducing the number of memory operations and leveraging vectorized
339    /// floating-point operations.
340    #[cfg(target_arch = "x86_64")]
341    #[inline]
342    unsafe fn fill_randn_simd_avx2(&self, ptr: *const f32, rng: &mut XorShiftRng) {
343        let mut_ptr = ptr as *mut f32;
344        let size = self.shape().size;
345        let simd_count = size / 8; // Process 8 elements per iteration
346        let mut offset = 0;
347
348        // SIMD loop for normal distribution generation
349        for _ in 0..simd_count {
350            let mut values = [0.0f32; 8];
351            for i in &mut values {
352                *i = rng.next_normal();
353            }
354
355            // Store 8 values using SIMD
356            let vec = _mm256_loadu_ps(values.as_ptr());
357            _mm256_storeu_ps(mut_ptr.add(offset), vec);
358            offset += 8;
359        }
360
361        // Handle remaining elements
362        for i in offset..size {
363            *mut_ptr.add(i) = rng.next_normal();
364        }
365    }
366
367    /// Fills the tensor with normally distributed random values using scalar operations
368    ///
369    /// Internal fallback method that uses scalar operations to fill tensors with
370    /// normal random values. Provides 4x unrolled loops for better instruction
371    /// throughput and serves as a fallback when SIMD is not available.
372    ///
373    /// # Arguments
374    ///
375    /// * `ptr` - Pointer to the tensor data
376    /// * `rng` - Random number generator instance
377    ///
378    /// # Safety
379    ///
380    /// The caller must ensure:
381    /// * `ptr` is a valid pointer to tensor data
382    /// * The tensor size matches the allocated memory
383    /// * `rng` is a valid random number generator instance
384    ///
385    /// # Performance
386    ///
387    /// - **Unrolled Loops**: 4x unrolling for better instruction throughput
388    /// - **Box-Muller Transform**: Integrated normal distribution generation
389    /// - **Remaining Elements**: Efficient handling of non-multiple-of-4 sizes
390    /// - **Cross-Platform**: Works on all CPU architectures
391    ///
392    /// # Implementation Details
393    ///
394    /// This method provides optimized scalar operations:
395    /// 1. **Unrolled Generation**: Processes 4 elements per iteration
396    /// 2. **Box-Muller Transform**: Generates normal random values
397    /// 3. **Remaining Elements**: Handles final elements individually
398    /// 4. **Cross-Platform**: No architecture-specific dependencies
399    ///
400    /// The 4x unrolling reduces loop overhead and improves instruction-level
401    /// parallelism, making scalar operations more efficient than naive loops.
402    #[inline]
403    unsafe fn fill_randn_scalar(&self, ptr: *const f32, rng: &mut XorShiftRng) {
404        let mut_ptr = ptr as *mut f32;
405        let size = self.shape().size;
406        let unroll_count = size / 4;
407        let mut offset = 0;
408
409        // Unrolled scalar loop for better performance
410        for _ in 0..unroll_count {
411            *mut_ptr.add(offset) = rng.next_normal();
412            *mut_ptr.add(offset + 1) = rng.next_normal();
413            *mut_ptr.add(offset + 2) = rng.next_normal();
414            *mut_ptr.add(offset + 3) = rng.next_normal();
415            offset += 4;
416        }
417
418        // Handle remaining elements
419        for i in offset..size {
420            *mut_ptr.add(i) = rng.next_normal();
421        }
422    }
423}
424
425// SIMD optimizations for performance-critical operations
426#[cfg(target_arch = "x86_64")]
427use std::arch::x86_64::*;
428
/// Fast non-cryptographic random number generator (xorshift64)
///
/// Single-`u64`-state generator with good statistical quality for machine
/// learning workloads. The xorshift64 update (shift constants 13, 7, 17)
/// has period 2^64 - 1 over the nonzero states. Uniform f32 samples are
/// derived by direct IEEE-754 bit construction, and normal samples via the
/// Box-Muller transform.
///
/// State 0 is a fixed point of the xorshift update (it would emit zeros
/// forever), so the constructors guarantee a nonzero initial state.
struct XorShiftRng {
    // Current generator state. Invariant: never zero.
    state: u64,
}

impl XorShiftRng {
    /// Replacement state used when a seed of 0 is supplied.
    ///
    /// 0 is the degenerate fixed point of xorshift; any nonzero constant
    /// works here — this is the 64-bit golden-ratio constant.
    const ZERO_SEED_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator from an explicit seed.
    ///
    /// The same seed always produces the same sequence. A seed of 0 is
    /// remapped to [`Self::ZERO_SEED_FALLBACK`], because state 0 would make
    /// `next_u64` return 0 forever. All nonzero seeds are used verbatim,
    /// preserving existing seeded sequences.
    ///
    /// # Arguments
    ///
    /// * `seed` - Seed value for reproducible generation
    fn new(seed: u64) -> Self {
        let state = if seed == 0 {
            Self::ZERO_SEED_FALLBACK
        } else {
            seed
        };
        Self { state }
    }

    /// Creates a generator seeded from the current system time.
    ///
    /// Hashes `SystemTime::now()` with `DefaultHasher` to spread the time's
    /// bits across the full 64-bit state, giving non-reproducible output.
    /// Delegates to [`Self::new`] so the (astronomically unlikely) all-zero
    /// hash is remapped and the nonzero-state invariant holds.
    fn new_from_time() -> Self {
        let mut hasher = DefaultHasher::new();
        std::time::SystemTime::now().hash(&mut hasher);
        Self::new(hasher.finish())
    }

    /// Advances the generator and returns the next 64-bit value.
    ///
    /// Classic xorshift64 step: three shift-XOR operations, no branches.
    /// Deterministic — the same state always yields the same output.
    ///
    /// # Returns
    ///
    /// The next random u64 in the xorshift sequence (never 0, given the
    /// nonzero-state invariant)
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Returns a uniform random f32 in [0, 1).
    ///
    /// Builds an IEEE-754 float in [1, 2) by pairing a random 23-bit
    /// mantissa with the biased exponent of 1.0 (127), then subtracts 1.0.
    /// This yields an unbiased uniform sample on [0, 1) without a division.
    /// The top 23 bits of the xorshift output are used because the low bits
    /// of xorshift generators are of lower statistical quality.
    ///
    /// # Returns
    ///
    /// A random f32 value in the half-open interval [0, 1)
    fn next_f32(&mut self) -> f32 {
        let bits = self.next_u64();
        let mantissa = (bits >> 41) as u32; // top 23 bits of the u64
        let one_to_two = f32::from_bits((127u32 << 23) | mantissa); // [1, 2)
        one_to_two - 1.0 // shift down to [0, 1)
    }

    /// Returns a sample from the standard normal distribution N(0, 1).
    ///
    /// Applies the Box-Muller transform to two uniform variates:
    /// `z = sqrt(-2 * ln(u1)) * cos(2 * pi * u2)`. Only the cosine branch
    /// is used; the paired sine value is discarded for simplicity.
    ///
    /// `u1` is clamped strictly inside (0, 1) so `ln(u1)` stays finite, and
    /// a final guard maps any NaN or infinite result to 0.0 for robustness.
    ///
    /// # Returns
    ///
    /// A random f32 value drawn from N(0, 1)
    fn next_normal(&mut self) -> f32 {
        let u1 = self.next_f32();
        let u2 = self.next_f32();

        // Keep u1 strictly inside (0, 1) so ln(u1) is finite and nonzero.
        let u1 = u1.clamp(1e-7, 1.0 - 1e-7);

        // Box-Muller transform (cosine branch).
        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();

        // Guard against NaN/infinite results from numerical edge cases.
        if z0.is_finite() {
            z0
        } else {
            0.0
        }
    }
}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the `i`-th element of the tensor's flat element buffer.
    ///
    /// Callers must pass `i < t.size()`; all uses below iterate `0..size()`.
    fn flat(t: &Tensor, i: usize) -> f32 {
        unsafe { *t.as_ptr().add(i) }
    }

    #[test]
    fn test_randn_basic() {
        // Shape and size are as requested.
        let tensor = Tensor::randn(vec![2, 3], Some(42));
        assert_eq!(tensor.size(), 6);
        assert_eq!(tensor.shape().dims, vec![2, 3]);

        // With a fixed seed, generation is reproducible.
        let tensor2 = Tensor::randn(vec![2, 3], Some(42));
        for i in 0..tensor.size() {
            assert!((flat(&tensor, i) - flat(&tensor2, i)).abs() < 1e-6);
        }
    }

    #[test]
    fn test_randn_reproducible() {
        let seed = 12345;
        let tensor1 = Tensor::randn(vec![100], Some(seed));
        let tensor2 = Tensor::randn(vec![100], Some(seed));

        // Identical seeds must produce element-wise identical tensors.
        for i in 0..tensor1.size() {
            assert!((flat(&tensor1, i) - flat(&tensor2, i)).abs() < 1e-6);
        }
    }

    #[test]
    fn test_randn_different_seeds() {
        let tensor1 = Tensor::randn(vec![100], Some(1));
        let tensor2 = Tensor::randn(vec![100], Some(2));

        // Different seeds should diverge in at least one element.
        let different = (0..tensor1.size())
            .any(|i| (flat(&tensor1, i) - flat(&tensor2, i)).abs() > 1e-6);
        assert!(
            different,
            "Tensors with different seeds should be different"
        );
    }

    #[test]
    fn test_randn_no_seed() {
        let tensor = Tensor::randn(vec![10], None);
        assert_eq!(tensor.size(), 10);
        assert_eq!(tensor.shape().dims, vec![10]);

        // Clock-seeded output should not be all zeros.
        let has_non_zero = (0..tensor.size()).any(|i| flat(&tensor, i) != 0.0);
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }

    #[test]
    fn test_randn_zero_sized() {
        // Empty shapes are handled without panicking.
        let tensor = Tensor::randn(vec![0], Some(42));
        assert_eq!(tensor.size(), 0);
        assert_eq!(tensor.shape().dims, vec![0]);
    }

    #[test]
    fn test_randn_large_tensor() {
        let tensor = Tensor::randn(vec![100, 100], Some(42));
        assert_eq!(tensor.size(), 10000);

        let mut min_val = f32::INFINITY;
        let mut max_val = f32::NEG_INFINITY;
        let mut sum = 0.0f32;
        for i in 0..tensor.size() {
            let val = flat(&tensor, i);
            min_val = min_val.min(val);
            max_val = max_val.max(val);
            sum += val;
        }
        let mean = sum / tensor.size() as f32;

        // Mean of 10k N(0,1) samples concentrates near 0 (std err ~ 0.01).
        assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
        // NOTE: a true N(0,1) exceeds |4| among 10k samples roughly half
        // the time (P(|z|>4) ~ 6.3e-5 per sample), so the extreme-value
        // bound must be looser than 4 sigma; |z| > 6 has probability ~2e-5
        // across the whole tensor.
        assert!(min_val > -6.0, "min unreasonably extreme: {}", min_val);
        assert!(max_val < 6.0, "max unreasonably extreme: {}", max_val);
    }

    #[test]
    fn test_fill_randn() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.fill_randn(Some(42));

        assert_eq!(tensor.size(), 6);

        // In-place fill should produce at least one nonzero element.
        let has_non_zero = (0..tensor.size()).any(|i| flat(&tensor, i) != 0.0);
        assert!(has_non_zero, "Random tensor should not be all zeros");
    }
}