Skip to main content

scirs2_stats/qmc/
enhanced_sequences.rs

1//! Enhanced Quasi-Monte Carlo sequences with state-of-the-art algorithms
2//!
3//! This module provides advanced QMC sequences with:
4//! - Optimal digital nets and (t,m,s)-nets
5//! - Advanced scrambling and randomization techniques
6//! - Parallel QMC sequence generation
7//! - Adaptive sequence refinement
8
9use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
12use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
13use scirs2_core::{parallel_ops::*, simd_ops::SimdUnifiedOps, validation::*};
14use std::marker::PhantomData;
15
16/// Enhanced QMC sequence generator with parallel support
17pub struct EnhancedQMCGenerator<F> {
18    /// Sequence type
19    pub sequence_type: EnhancedSequenceType,
20    /// Dimension
21    pub dimension: usize,
22    /// Configuration
23    pub config: EnhancedQMCConfig,
24    /// Generator state
25    pub state: QMCGeneratorState,
26    _phantom: PhantomData<F>,
27}
28
29/// Enhanced sequence types with advanced algorithms
30#[derive(Debug, Clone, PartialEq)]
31pub enum EnhancedSequenceType {
32    /// Sobol sequence with advanced scrambling
33    SobolAdvanced {
34        /// Use Owen scrambling
35        owen_scrambling: bool,
36        /// Use digital shift
37        digital_shift: bool,
38        /// Use nested scrambling
39        nested_scrambling: bool,
40    },
41    /// Niederreiter sequence with base optimization
42    Niederreiter {
43        /// Base selection strategy
44        base_strategy: BaseSelectionStrategy,
45        /// Use generating matrix optimization
46        matrix_optimization: bool,
47    },
48    /// Faure sequence with improved uniformity
49    FaureImproved {
50        /// Use permutation optimization
51        permutation_optimization: bool,
52        /// Use radical inverse improvements
53        radical_inverse_improvements: bool,
54    },
55    /// Digital (t,m,s)-nets
56    DigitalNet {
57        /// Net parameters
58        net_params: DigitalNetParams,
59        /// Construction method
60        construction_method: NetConstructionMethod,
61    },
62    /// Hybrid sequences combining multiple methods
63    Hybrid {
64        /// Primary sequence type
65        primary: Box<EnhancedSequenceType>,
66        /// Secondary sequence type
67        secondary: Box<EnhancedSequenceType>,
68        /// Combination strategy
69        combination: HybridCombinationStrategy,
70    },
71}
72
/// Strategy for picking the per-dimension bases of a Niederreiter sequence.
///
/// The enum carries no data, so it is cheap to copy and can be used as a
/// hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BaseSelectionStrategy {
    /// Use the first primes, in order.
    FirstPrimes,
    /// Use primes optimized for the requested dimension.
    OptimizedPrimes,
    /// Use prime powers for better uniformity.
    PrimePowers,
    /// Let the generator choose based on the dimension.
    Automatic,
}
85
/// Parameters describing a digital (t,m,s)-net.
///
/// All fields are plain `usize` values, so the struct is `Copy` and can be
/// compared and hashed exactly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DigitalNetParams {
    /// The `t` parameter of the net (labelled "strength" by this module).
    pub t: usize,
    /// The `m` parameter of the net (labelled "precision" by this module).
    pub m: usize,
    /// The `s` parameter: dimension of the net.
    pub s: usize,
    /// Base of the digit expansion (usually 2).
    pub base: usize,
}
98
/// Construction used to build a digital net.
///
/// The enum carries no data, so it is cheap to copy and can be used as a
/// hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NetConstructionMethod {
    /// Sobol construction.
    Sobol,
    /// Niederreiter-Xing construction.
    NiederreiterXing,
    /// Polynomial lattice rules.
    PolynomialLattice,
    /// Finite-field constructions.
    FiniteField,
}
111
/// Rule for merging the two sequences of a hybrid sequence type.
///
/// The only payload is an `f64`, so the enum is `Copy`; it cannot be `Eq` or
/// `Hash` because floating-point values are not.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum HybridCombinationStrategy {
    /// Interleave the two sequences.
    Interleave,
    /// Weighted combination; the payload is the weight.
    /// NOTE(review): which sequence the weight applies to is not defined in
    /// this module; confirm before relying on it.
    Weighted(f64),
    /// Alternate between the sequences dimension by dimension.
    DimensionAlternation,
    /// Choose adaptively based on uniformity.
    Adaptive,
}
124
/// Settings that control how an enhanced QMC generator produces points.
///
/// Use [`Default::default`] together with struct-update syntax to override
/// individual fields.
#[derive(Debug, Clone, PartialEq)]
pub struct EnhancedQMCConfig {
    /// Allow the parallel generation path.
    pub parallel: bool,
    /// Number of points per parallel work item; also the minimum request size
    /// at which the parallel path is considered.
    pub chunksize: usize,
    /// Seed for the randomization; `None` seeds from the thread-local RNG.
    pub seed: Option<u64>,
    /// Enable SIMD optimizations.
    pub use_simd: bool,
    /// Quality assessment threshold.
    pub quality_threshold: f64,
    /// Largest number of points for which quality is assessed after generation.
    pub max_assessment_length: usize,
    /// Enable adaptive refinement.
    pub adaptive_refinement: bool,
}

impl Default for EnhancedQMCConfig {
    /// Parallel and SIMD enabled, chunks of 1000 points, no fixed seed,
    /// quality threshold `1e-3`, assessment for up to 10000 points, and no
    /// adaptive refinement.
    fn default() -> Self {
        EnhancedQMCConfig {
            parallel: true,
            chunksize: 1000,
            seed: None,
            use_simd: true,
            quality_threshold: 1e-3,
            max_assessment_length: 10000,
            adaptive_refinement: false,
        }
    }
}
157
158/// Generator state for QMC sequences
159#[derive(Debug, Clone)]
160pub struct QMCGeneratorState {
161    /// Current index
162    pub current_index: usize,
163    /// Scrambling matrices (if used)
164    pub scrambling_matrices: Option<Vec<Array2<u32>>>,
165    /// Digital shift vectors (if used)
166    pub digital_shifts: Option<Vec<Array1<u32>>>,
167    /// Quality metrics
168    pub quality_metrics: QualityMetrics,
169}
170
/// Uniformity measures of a generated point set.
///
/// Every field defaults to `0.0`. The struct holds four `f64` values only, so
/// it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct QualityMetrics {
    /// Estimate of the star discrepancy.
    pub star_discrepancy: f64,
    /// Wrap-around discrepancy.
    pub wraparound_discrepancy: f64,
    /// Diaphony (spectral measure).
    pub diaphony: f64,
    /// Figure of merit.
    pub figure_of_merit: f64,
}
183
184impl<F> EnhancedQMCGenerator<F>
185where
186    F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
187    for<'a> &'a F: std::iter::Product<&'a F>,
188{
189    /// Create new enhanced QMC generator
190    pub fn new(
191        sequence_type: EnhancedSequenceType,
192        dimension: usize,
193        config: EnhancedQMCConfig,
194    ) -> StatsResult<Self> {
195        check_positive(dimension, "dimension")?;
196
197        if dimension > 1000 {
198            return Err(StatsError::InvalidArgument(
199                "Dimension cannot exceed 1000 for enhanced QMC sequences".to_string(),
200            ));
201        }
202
203        let state = QMCGeneratorState {
204            current_index: 0,
205            scrambling_matrices: None,
206            digital_shifts: None,
207            quality_metrics: QualityMetrics::default(),
208        };
209
210        let mut generator = Self {
211            sequence_type,
212            dimension,
213            config,
214            state,
215            _phantom: PhantomData,
216        };
217
218        // Initialize scrambling and digital shifts if needed
219        generator.initialize_randomization()?;
220
221        Ok(generator)
222    }
223
224    /// Generate enhanced QMC sequence
225    pub fn generate(&mut self, n: usize) -> StatsResult<Array2<F>> {
226        check_positive(n, "n")?;
227
228        if self.config.parallel && n >= self.config.chunksize {
229            self.generate_parallel(n)
230        } else {
231            self.generate_sequential(n)
232        }
233    }
234
235    /// Generate sequence in parallel
236    fn generate_parallel(&mut self, n: usize) -> StatsResult<Array2<F>> {
237        let chunksize = self.config.chunksize;
238        let num_chunks = n.div_ceil(chunksize);
239
240        let chunks = parallel_map_result(
241            (0..num_chunks).collect::<Vec<_>>().as_slice(),
242            |&chunk_idx| {
243                let start = chunk_idx * chunksize;
244                let end = (start + chunksize).min(n);
245                let chunksize = end - start;
246
247                self.generate_chunk(start, chunksize)
248            },
249        )?;
250
251        // Combine chunks
252        let mut result = Array2::zeros((n, self.dimension));
253        let mut row_idx = 0;
254
255        for chunk in chunks {
256            let chunk = chunk;
257            let chunk_rows = chunk.nrows();
258            result
259                .slice_mut(scirs2_core::ndarray::s![row_idx..row_idx + chunk_rows, ..])
260                .assign(&chunk);
261            row_idx += chunk_rows;
262        }
263
264        // Update quality metrics
265        if n <= self.config.max_assessment_length {
266            self.assess_quality(&result)?;
267        }
268
269        Ok(result)
270    }
271
272    /// Generate sequence sequentially
273    fn generate_sequential(&mut self, n: usize) -> StatsResult<Array2<F>> {
274        let mut result = Array2::zeros((n, self.dimension));
275
276        for i in 0..n {
277            let point = self.next_point()?;
278            result.row_mut(i).assign(&point);
279        }
280
281        // Update quality metrics
282        if n <= self.config.max_assessment_length {
283            self.assess_quality(&result)?;
284        }
285
286        Ok(result)
287    }
288
289    /// Generate a chunk of the sequence
290    fn generate_chunk(&self, start_index: usize, chunksize: usize) -> StatsResult<Array2<F>> {
291        let mut chunk = Array2::zeros((chunksize, self.dimension));
292
293        for i in 0..chunksize {
294            let _index = start_index + i;
295            let point = self.compute_point_at_index(_index)?;
296            chunk.row_mut(i).assign(&point);
297        }
298
299        Ok(chunk)
300    }
301
302    /// Compute next point in sequence
303    fn next_point(&mut self) -> StatsResult<Array1<F>> {
304        let point = self.compute_point_at_index(self.state.current_index)?;
305        self.state.current_index += 1;
306        Ok(point)
307    }
308
309    /// Compute point at specific index
310    fn compute_point_at_index(&self, index: usize) -> StatsResult<Array1<F>> {
311        match &self.sequence_type {
312            EnhancedSequenceType::SobolAdvanced {
313                owen_scrambling,
314                digital_shift,
315                nested_scrambling,
316            } => self.compute_sobol_advanced(
317                index,
318                *owen_scrambling,
319                *digital_shift,
320                *nested_scrambling,
321            ),
322            EnhancedSequenceType::Niederreiter {
323                base_strategy,
324                matrix_optimization,
325            } => self.compute_niederreiter_enhanced(index, base_strategy, *matrix_optimization),
326            EnhancedSequenceType::FaureImproved {
327                permutation_optimization,
328                radical_inverse_improvements,
329            } => self.compute_faure_improved(
330                index,
331                *permutation_optimization,
332                *radical_inverse_improvements,
333            ),
334            EnhancedSequenceType::DigitalNet {
335                net_params,
336                construction_method,
337            } => self.compute_digital_net(index, net_params, construction_method),
338            EnhancedSequenceType::Hybrid {
339                primary,
340                secondary,
341                combination,
342            } => self.compute_hybrid_sequence(index, primary, secondary, combination),
343        }
344    }
345
346    /// Compute advanced Sobol sequence point
347    fn compute_sobol_advanced(
348        &self,
349        index: usize,
350        owen_scrambling: bool,
351        digital_shift: bool,
352        _nested_scrambling: bool,
353    ) -> StatsResult<Array1<F>> {
354        let mut point = Array1::zeros(self.dimension);
355
356        // Use simplified Sobol computation for now
357        // Full implementation would use proper direction numbers and _scrambling
358        for dim in 0..self.dimension {
359            let mut result = 0u32;
360            let mut idx = index;
361
362            // Basic van der Corput sequence in base 2
363            let mut base_power = 1u32;
364            while idx > 0 {
365                if idx & 1 == 1 {
366                    result ^= base_power << (31 - (base_power.trailing_zeros()));
367                }
368                idx >>= 1;
369                base_power <<= 1;
370            }
371
372            // Apply _scrambling if enabled
373            if owen_scrambling {
374                if let Some(ref matrices) = self.state.scrambling_matrices {
375                    if dim < matrices.len() {
376                        result = self.apply_owen_scrambling(result, &matrices[dim]);
377                    }
378                }
379            }
380
381            // Apply digital _shift if enabled
382            if digital_shift {
383                if let Some(ref shifts) = self.state.digital_shifts {
384                    if dim < shifts.len() {
385                        result ^= shifts[dim][0]; // Simplified
386                    }
387                }
388            }
389
390            point[dim] = F::from(result as f64 / (1u64 << 32) as f64).expect("Operation failed");
391        }
392
393        Ok(point)
394    }
395
396    /// Compute enhanced Niederreiter sequence point
397    fn compute_niederreiter_enhanced(
398        &self,
399        index: usize,
400        base_strategy: &BaseSelectionStrategy,
401        _matrix_optimization: bool,
402    ) -> StatsResult<Array1<F>> {
403        // Simplified implementation
404        let mut point = Array1::zeros(self.dimension);
405
406        for dim in 0..self.dimension {
407            // Use prime base for this dimension
408            let base = self.get_prime(dim);
409            point[dim] = F::from(self.radical_inverse(index, base)).expect("Operation failed");
410        }
411
412        Ok(point)
413    }
414
415    /// Compute improved Faure sequence point
416    fn compute_faure_improved(
417        &self,
418        index: usize,
419        _permutation_optimization: bool,
420        _radical_inverse_improvements: bool,
421    ) -> StatsResult<Array1<F>> {
422        // Simplified implementation
423        let mut point = Array1::zeros(self.dimension);
424        let base = self.smallest_prime_geq(self.dimension as u32);
425
426        for dim in 0..self.dimension {
427            point[dim] = F::from(self.radical_inverse(index, base)).expect("Operation failed");
428        }
429
430        Ok(point)
431    }
432
433    /// Compute digital net point
434    fn compute_digital_net(
435        &self,
436        index: usize,
437        _net_params: &DigitalNetParams,
438        _construction_method: &NetConstructionMethod,
439    ) -> StatsResult<Array1<F>> {
440        // Simplified implementation - use Sobol-like computation
441        self.compute_sobol_advanced(index, false, false, false)
442    }
443
444    /// Compute hybrid sequence point
445    fn compute_hybrid_sequence(
446        &self,
447        index: usize,
448        _primary: &EnhancedSequenceType,
449        _secondary: &EnhancedSequenceType,
450        _combination: &HybridCombinationStrategy,
451    ) -> StatsResult<Array1<F>> {
452        // Simplified implementation - use _primary sequence
453        self.compute_sobol_advanced(index, true, true, false)
454    }
455
456    /// Initialize randomization (scrambling, digital shifts)
457    fn initialize_randomization(&mut self) -> StatsResult<()> {
458        let mut rng = match self.config.seed {
459            Some(seed) => StdRng::seed_from_u64(seed),
460            None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
461        };
462
463        // Initialize scrambling matrices
464        if self.needs_scrambling() {
465            let mut matrices = Vec::with_capacity(self.dimension);
466            for _ in 0..self.dimension {
467                matrices.push(self.generate_scrambling_matrix(&mut rng)?);
468            }
469            self.state.scrambling_matrices = Some(matrices);
470        }
471
472        // Initialize digital shifts
473        if self.needs_digital_shift() {
474            let mut shifts = Vec::with_capacity(self.dimension);
475            for _ in 0..self.dimension {
476                let shift = Array1::from_shape_fn(32, |_| rng.random::<u32>());
477                shifts.push(shift);
478            }
479            self.state.digital_shifts = Some(shifts);
480        }
481
482        Ok(())
483    }
484
485    /// Check if sequence type needs scrambling
486    fn needs_scrambling(&self) -> bool {
487        match &self.sequence_type {
488            EnhancedSequenceType::SobolAdvanced {
489                owen_scrambling, ..
490            } => *owen_scrambling,
491            _ => false,
492        }
493    }
494
495    /// Check if sequence type needs digital shift
496    fn needs_digital_shift(&self) -> bool {
497        match &self.sequence_type {
498            EnhancedSequenceType::SobolAdvanced { digital_shift, .. } => *digital_shift,
499            _ => false,
500        }
501    }
502
503    /// Generate scrambling matrix
504    fn generate_scrambling_matrix<R: Rng>(&self, rng: &mut R) -> StatsResult<Array2<u32>> {
505        let mut matrix = Array2::zeros((32, 32));
506
507        // Generate random permutation matrix
508        for i in 0..32 {
509            let j = rng.random_range(0..32);
510            matrix[[i, j]] = 1;
511        }
512
513        Ok(matrix)
514    }
515
516    /// Apply Owen scrambling to a value
517    fn apply_owen_scrambling(&self, value: u32, matrix: &Array2<u32>) -> u32 {
518        let mut result = 0u32;
519
520        for i in 0..32 {
521            let bit = (value >> (31 - i)) & 1;
522            for j in 0..32 {
523                if matrix[[i, j]] == 1 && bit == 1 {
524                    result |= 1u32 << (31 - j);
525                    break;
526                }
527            }
528        }
529
530        result
531    }
532
533    /// Compute radical inverse
534    fn radical_inverse(&self, index: usize, base: u32) -> f64 {
535        let mut result = 0.0;
536        let mut fraction = 1.0 / base as f64;
537        let mut i = index;
538
539        while i > 0 {
540            result += (i % base as usize) as f64 * fraction;
541            i /= base as usize;
542            fraction /= base as f64;
543        }
544
545        result
546    }
547
548    /// Get nth prime number
549    fn get_prime(&self, n: usize) -> u32 {
550        let primes = [
551            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
552        ];
553        if n < primes.len() {
554            primes[n]
555        } else {
556            // Fallback to simple generation
557            let mut candidate = primes[primes.len() - 1] + 2;
558            let mut count = primes.len();
559
560            while count <= n {
561                if self.is_prime(candidate) {
562                    if count == n {
563                        return candidate;
564                    }
565                    count += 1;
566                }
567                candidate += 2;
568            }
569            candidate
570        }
571    }
572
573    /// Find smallest prime >= n
574    fn smallest_prime_geq(&self, n: u32) -> u32 {
575        if n <= 2 {
576            return 2;
577        }
578
579        let mut candidate = if n.is_multiple_of(2) { n + 1 } else { n };
580
581        while !self.is_prime(candidate) {
582            candidate += 2;
583        }
584
585        candidate
586    }
587
588    /// Check if number is prime
589    fn is_prime(&self, n: u32) -> bool {
590        if n < 2 {
591            return false;
592        }
593        if n == 2 {
594            return true;
595        }
596        if n.is_multiple_of(2) {
597            return false;
598        }
599
600        let sqrt_n = (n as f64).sqrt() as u32;
601        for i in (3..=sqrt_n).step_by(2) {
602            if n.is_multiple_of(i) {
603                return false;
604            }
605        }
606
607        true
608    }
609
610    /// Assess sequence quality
611    fn assess_quality(&mut self, sequence: &Array2<F>) -> StatsResult<()> {
612        // Simplified quality assessment
613        let n = sequence.nrows();
614        let d = sequence.ncols();
615
616        // Estimate star discrepancy (simplified)
617        let mut max_discrepancy = 0.0;
618        let num_test_points = 50.min(n);
619
620        let mut rng = scirs2_core::random::thread_rng();
621        for _ in 0..num_test_points {
622            let mut test_point = Array1::zeros(d);
623            for j in 0..d {
624                test_point[j] = F::from(rng.random::<f64>()).expect("Operation failed");
625            }
626
627            let mut count = 0;
628            for i in 0..n {
629                let mut in_box = true;
630                for j in 0..d {
631                    if sequence[[i, j]] > test_point[j] {
632                        in_box = false;
633                        break;
634                    }
635                }
636                if in_box {
637                    count += 1;
638                }
639            }
640
641            let volume: F = test_point.iter().fold(F::one(), |acc, &x| acc * x);
642            let expected = volume.to_f64().expect("Operation failed") * n as f64;
643            let discrepancy = (count as f64 - expected).abs() / n as f64;
644            max_discrepancy = max_discrepancy.max(discrepancy);
645        }
646
647        self.state.quality_metrics.star_discrepancy = max_discrepancy;
648        Ok(())
649    }
650
651    /// Get current quality metrics
652    pub fn quality_metrics(&self) -> &QualityMetrics {
653        &self.state.quality_metrics
654    }
655}
656
657/// Convenience functions for enhanced QMC
658#[allow(dead_code)]
659pub fn enhanced_sobol<F>(
660    n: usize,
661    dimension: usize,
662    scrambling: bool,
663    seed: Option<u64>,
664) -> StatsResult<Array2<F>>
665where
666    F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
667    for<'a> &'a F: std::iter::Product<&'a F>,
668{
669    let sequence_type = EnhancedSequenceType::SobolAdvanced {
670        owen_scrambling: scrambling,
671        digital_shift: true,
672        nested_scrambling: false,
673    };
674
675    let config = EnhancedQMCConfig {
676        seed,
677        ..Default::default()
678    };
679
680    let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
681    generator.generate(n)
682}
683
684#[allow(dead_code)]
685pub fn enhanced_niederreiter<F>(
686    n: usize,
687    dimension: usize,
688    seed: Option<u64>,
689) -> StatsResult<Array2<F>>
690where
691    F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
692    for<'a> &'a F: std::iter::Product<&'a F>,
693{
694    let sequence_type = EnhancedSequenceType::Niederreiter {
695        base_strategy: BaseSelectionStrategy::OptimizedPrimes,
696        matrix_optimization: true,
697    };
698
699    let config = EnhancedQMCConfig {
700        seed,
701        ..Default::default()
702    };
703
704    let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
705    generator.generate(n)
706}
707
708#[allow(dead_code)]
709pub fn enhanced_digital_net<F>(
710    n: usize,
711    dimension: usize,
712    t: usize,
713    seed: Option<u64>,
714) -> StatsResult<Array2<F>>
715where
716    F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
717    for<'a> &'a F: std::iter::Product<&'a F>,
718{
719    let net_params = DigitalNetParams {
720        t,
721        m: 32,
722        s: dimension,
723        base: 2,
724    };
725
726    let sequence_type = EnhancedSequenceType::DigitalNet {
727        net_params,
728        construction_method: NetConstructionMethod::Sobol,
729    };
730
731    let config = EnhancedQMCConfig {
732        seed,
733        ..Default::default()
734    };
735
736    let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
737    generator.generate(n)
738}