1use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
12use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
13use scirs2_core::{parallel_ops::*, simd_ops::SimdUnifiedOps, validation::*};
14use std::marker::PhantomData;
15
/// Quasi-Monte Carlo point generator, generic over the floating-point element
/// type `F` of the points it produces.
pub struct EnhancedQMCGenerator<F> {
    /// Which low-discrepancy construction is used to compute each point.
    pub sequence_type: EnhancedSequenceType,
    /// Number of coordinates per generated point.
    pub dimension: usize,
    /// Generation options (parallelism, chunk size, seed, ...).
    pub config: EnhancedQMCConfig,
    /// Mutable generator state: sequence position, randomization data and
    /// the most recent quality metrics.
    pub state: QMCGeneratorState,
    /// Marker tying the generator to `F` without storing a value of that type.
    _phantom: PhantomData<F>,
}
28
/// Selects the low-discrepancy construction used by the generator.
#[derive(Debug, Clone, PartialEq)]
pub enum EnhancedSequenceType {
    /// Base-2 (Sobol-style) sequence with optional randomization.
    SobolAdvanced {
        /// Generate one scrambling matrix per dimension and apply it to every point.
        owen_scrambling: bool,
        /// Generate one random shift per dimension and XOR it into every point.
        digital_shift: bool,
        /// Accepted for configuration purposes.
        /// NOTE(review): the visible point computation does not consult this flag.
        nested_scrambling: bool,
    },
    /// Radical-inverse sequence using a distinct prime base per dimension.
    Niederreiter {
        /// Requested base selection strategy.
        /// NOTE(review): the visible point computation does not consult this value.
        base_strategy: BaseSelectionStrategy,
        /// NOTE(review): the visible point computation does not consult this flag.
        matrix_optimization: bool,
    },
    /// Faure-style sequence in the smallest prime base not below the dimension.
    FaureImproved {
        /// NOTE(review): the visible point computation does not consult this flag.
        permutation_optimization: bool,
        /// NOTE(review): the visible point computation does not consult this flag.
        radical_inverse_improvements: bool,
    },
    /// Digital net described by explicit parameters.
    /// NOTE(review): the visible implementation delegates to the Sobol routine
    /// and does not read either field.
    DigitalNet {
        /// Net parameters.
        net_params: DigitalNetParams,
        /// Requested construction method.
        construction_method: NetConstructionMethod,
    },
    /// Combination of two sequence types.
    /// NOTE(review): the visible implementation delegates to the Sobol routine
    /// and does not read any of the fields.
    Hybrid {
        /// First component sequence.
        primary: Box<EnhancedSequenceType>,
        /// Second component sequence.
        secondary: Box<EnhancedSequenceType>,
        /// How the two components are meant to be combined.
        combination: HybridCombinationStrategy,
    },
}
72
/// Strategy for choosing the per-dimension bases of a Niederreiter sequence.
///
/// NOTE(review): in the visible code the strategy is stored but never
/// inspected; the meaning of each variant below is inferred from its name
/// only and should be confirmed.
#[derive(Debug, Clone, PartialEq)]
pub enum BaseSelectionStrategy {
    /// Presumably: use the first primes in increasing order.
    FirstPrimes,
    /// Presumably: use a tuned selection of primes.
    OptimizedPrimes,
    /// Presumably: allow prime-power bases.
    PrimePowers,
    /// Presumably: let the implementation pick.
    Automatic,
}
85
/// Parameters describing a digital net.
///
/// NOTE(review): the field names follow the usual (t, m, s)-net notation, but
/// the visible code never reads them, so the meanings below are assumptions.
#[derive(Debug, Clone, PartialEq)]
pub struct DigitalNetParams {
    /// Presumably the quality parameter `t` of the net.
    pub t: usize,
    /// Presumably the exponent `m` (the net has `base^m` points).
    pub m: usize,
    /// Presumably the dimension `s` of the net.
    pub s: usize,
    /// Presumably the base of the digit expansion.
    pub base: usize,
}
98
/// Requested construction method for a digital net.
///
/// NOTE(review): the visible code stores this value but never branches on it.
#[derive(Debug, Clone, PartialEq)]
pub enum NetConstructionMethod {
    /// Sobol construction.
    Sobol,
    /// Niederreiter-Xing construction.
    NiederreiterXing,
    /// Polynomial lattice construction.
    PolynomialLattice,
    /// Finite-field based construction.
    FiniteField,
}
111
/// Requested way of combining the two components of a hybrid sequence.
///
/// NOTE(review): the visible code stores this value but never branches on it;
/// variant meanings are inferred from their names and should be confirmed.
#[derive(Debug, Clone, PartialEq)]
pub enum HybridCombinationStrategy {
    /// Presumably: alternate points between the two sequences.
    Interleave,
    /// Presumably: blend the sequences using the given weight.
    Weighted(f64),
    /// Presumably: alternate the source sequence per dimension.
    DimensionAlternation,
    /// Presumably: choose the combination adaptively.
    Adaptive,
}
124
/// Options controlling how an `EnhancedQMCGenerator` produces points.
#[derive(Debug, Clone)]
pub struct EnhancedQMCConfig {
    /// Allow chunked parallel generation for requests of at least `chunksize` points.
    pub parallel: bool,
    /// Number of points per parallel chunk; also the threshold for going parallel.
    pub chunksize: usize,
    /// Seed for the randomization state; `None` seeds from the thread RNG.
    pub seed: Option<u64>,
    /// SIMD preference (not read by the code visible in this file).
    pub use_simd: bool,
    /// Quality threshold (not read by the code visible in this file).
    pub quality_threshold: f64,
    /// Largest request size for which the quality assessment is run.
    pub max_assessment_length: usize,
    /// Adaptive refinement switch (not read by the code visible in this file).
    pub adaptive_refinement: bool,
}

impl Default for EnhancedQMCConfig {
    /// Parallel, SIMD-enabled configuration with 1000-point chunks, no fixed
    /// seed, a `1e-3` quality threshold and assessment for up to 10000 points.
    fn default() -> Self {
        const DEFAULT_CHUNK_SIZE: usize = 1000;
        const DEFAULT_QUALITY_THRESHOLD: f64 = 1e-3;
        const DEFAULT_MAX_ASSESSMENT_LENGTH: usize = 10000;

        Self {
            seed: None,
            parallel: true,
            use_simd: true,
            adaptive_refinement: false,
            chunksize: DEFAULT_CHUNK_SIZE,
            quality_threshold: DEFAULT_QUALITY_THRESHOLD,
            max_assessment_length: DEFAULT_MAX_ASSESSMENT_LENGTH,
        }
    }
}
157
/// Mutable state carried by an `EnhancedQMCGenerator`.
#[derive(Debug, Clone)]
pub struct QMCGeneratorState {
    /// Index of the next point produced by sequential generation.
    pub current_index: usize,
    /// One 32x32 scrambling matrix per dimension; populated only when the
    /// sequence type requests Owen scrambling.
    pub scrambling_matrices: Option<Vec<Array2<u32>>>,
    /// One length-32 vector of random words per dimension; populated only when
    /// the sequence type requests a digital shift.
    pub digital_shifts: Option<Vec<Array1<u32>>>,
    /// Metrics recorded by the most recent quality assessment.
    pub quality_metrics: QualityMetrics,
}
170
/// Uniformity metrics for a generated point set. All fields default to `0.0`.
#[derive(Debug, Clone, Default)]
pub struct QualityMetrics {
    /// Estimated star discrepancy, written by the generator's quality assessment.
    pub star_discrepancy: f64,
    /// Wrap-around discrepancy (not written by the code visible in this file).
    pub wraparound_discrepancy: f64,
    /// Diaphony (not written by the code visible in this file).
    pub diaphony: f64,
    /// Figure of merit (not written by the code visible in this file).
    pub figure_of_merit: f64,
}
183
184impl<F> EnhancedQMCGenerator<F>
185where
186 F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
187 for<'a> &'a F: std::iter::Product<&'a F>,
188{
189 pub fn new(
191 sequence_type: EnhancedSequenceType,
192 dimension: usize,
193 config: EnhancedQMCConfig,
194 ) -> StatsResult<Self> {
195 check_positive(dimension, "dimension")?;
196
197 if dimension > 1000 {
198 return Err(StatsError::InvalidArgument(
199 "Dimension cannot exceed 1000 for enhanced QMC sequences".to_string(),
200 ));
201 }
202
203 let state = QMCGeneratorState {
204 current_index: 0,
205 scrambling_matrices: None,
206 digital_shifts: None,
207 quality_metrics: QualityMetrics::default(),
208 };
209
210 let mut generator = Self {
211 sequence_type,
212 dimension,
213 config,
214 state,
215 _phantom: PhantomData,
216 };
217
218 generator.initialize_randomization()?;
220
221 Ok(generator)
222 }
223
224 pub fn generate(&mut self, n: usize) -> StatsResult<Array2<F>> {
226 check_positive(n, "n")?;
227
228 if self.config.parallel && n >= self.config.chunksize {
229 self.generate_parallel(n)
230 } else {
231 self.generate_sequential(n)
232 }
233 }
234
235 fn generate_parallel(&mut self, n: usize) -> StatsResult<Array2<F>> {
237 let chunksize = self.config.chunksize;
238 let num_chunks = n.div_ceil(chunksize);
239
240 let chunks = parallel_map_result(
241 (0..num_chunks).collect::<Vec<_>>().as_slice(),
242 |&chunk_idx| {
243 let start = chunk_idx * chunksize;
244 let end = (start + chunksize).min(n);
245 let chunksize = end - start;
246
247 self.generate_chunk(start, chunksize)
248 },
249 )?;
250
251 let mut result = Array2::zeros((n, self.dimension));
253 let mut row_idx = 0;
254
255 for chunk in chunks {
256 let chunk = chunk;
257 let chunk_rows = chunk.nrows();
258 result
259 .slice_mut(scirs2_core::ndarray::s![row_idx..row_idx + chunk_rows, ..])
260 .assign(&chunk);
261 row_idx += chunk_rows;
262 }
263
264 if n <= self.config.max_assessment_length {
266 self.assess_quality(&result)?;
267 }
268
269 Ok(result)
270 }
271
272 fn generate_sequential(&mut self, n: usize) -> StatsResult<Array2<F>> {
274 let mut result = Array2::zeros((n, self.dimension));
275
276 for i in 0..n {
277 let point = self.next_point()?;
278 result.row_mut(i).assign(&point);
279 }
280
281 if n <= self.config.max_assessment_length {
283 self.assess_quality(&result)?;
284 }
285
286 Ok(result)
287 }
288
289 fn generate_chunk(&self, start_index: usize, chunksize: usize) -> StatsResult<Array2<F>> {
291 let mut chunk = Array2::zeros((chunksize, self.dimension));
292
293 for i in 0..chunksize {
294 let _index = start_index + i;
295 let point = self.compute_point_at_index(_index)?;
296 chunk.row_mut(i).assign(&point);
297 }
298
299 Ok(chunk)
300 }
301
302 fn next_point(&mut self) -> StatsResult<Array1<F>> {
304 let point = self.compute_point_at_index(self.state.current_index)?;
305 self.state.current_index += 1;
306 Ok(point)
307 }
308
309 fn compute_point_at_index(&self, index: usize) -> StatsResult<Array1<F>> {
311 match &self.sequence_type {
312 EnhancedSequenceType::SobolAdvanced {
313 owen_scrambling,
314 digital_shift,
315 nested_scrambling,
316 } => self.compute_sobol_advanced(
317 index,
318 *owen_scrambling,
319 *digital_shift,
320 *nested_scrambling,
321 ),
322 EnhancedSequenceType::Niederreiter {
323 base_strategy,
324 matrix_optimization,
325 } => self.compute_niederreiter_enhanced(index, base_strategy, *matrix_optimization),
326 EnhancedSequenceType::FaureImproved {
327 permutation_optimization,
328 radical_inverse_improvements,
329 } => self.compute_faure_improved(
330 index,
331 *permutation_optimization,
332 *radical_inverse_improvements,
333 ),
334 EnhancedSequenceType::DigitalNet {
335 net_params,
336 construction_method,
337 } => self.compute_digital_net(index, net_params, construction_method),
338 EnhancedSequenceType::Hybrid {
339 primary,
340 secondary,
341 combination,
342 } => self.compute_hybrid_sequence(index, primary, secondary, combination),
343 }
344 }
345
346 fn compute_sobol_advanced(
348 &self,
349 index: usize,
350 owen_scrambling: bool,
351 digital_shift: bool,
352 _nested_scrambling: bool,
353 ) -> StatsResult<Array1<F>> {
354 let mut point = Array1::zeros(self.dimension);
355
356 for dim in 0..self.dimension {
359 let mut result = 0u32;
360 let mut idx = index;
361
362 let mut base_power = 1u32;
364 while idx > 0 {
365 if idx & 1 == 1 {
366 result ^= base_power << (31 - (base_power.trailing_zeros()));
367 }
368 idx >>= 1;
369 base_power <<= 1;
370 }
371
372 if owen_scrambling {
374 if let Some(ref matrices) = self.state.scrambling_matrices {
375 if dim < matrices.len() {
376 result = self.apply_owen_scrambling(result, &matrices[dim]);
377 }
378 }
379 }
380
381 if digital_shift {
383 if let Some(ref shifts) = self.state.digital_shifts {
384 if dim < shifts.len() {
385 result ^= shifts[dim][0]; }
387 }
388 }
389
390 point[dim] = F::from(result as f64 / (1u64 << 32) as f64).expect("Operation failed");
391 }
392
393 Ok(point)
394 }
395
396 fn compute_niederreiter_enhanced(
398 &self,
399 index: usize,
400 base_strategy: &BaseSelectionStrategy,
401 _matrix_optimization: bool,
402 ) -> StatsResult<Array1<F>> {
403 let mut point = Array1::zeros(self.dimension);
405
406 for dim in 0..self.dimension {
407 let base = self.get_prime(dim);
409 point[dim] = F::from(self.radical_inverse(index, base)).expect("Operation failed");
410 }
411
412 Ok(point)
413 }
414
415 fn compute_faure_improved(
417 &self,
418 index: usize,
419 _permutation_optimization: bool,
420 _radical_inverse_improvements: bool,
421 ) -> StatsResult<Array1<F>> {
422 let mut point = Array1::zeros(self.dimension);
424 let base = self.smallest_prime_geq(self.dimension as u32);
425
426 for dim in 0..self.dimension {
427 point[dim] = F::from(self.radical_inverse(index, base)).expect("Operation failed");
428 }
429
430 Ok(point)
431 }
432
433 fn compute_digital_net(
435 &self,
436 index: usize,
437 _net_params: &DigitalNetParams,
438 _construction_method: &NetConstructionMethod,
439 ) -> StatsResult<Array1<F>> {
440 self.compute_sobol_advanced(index, false, false, false)
442 }
443
444 fn compute_hybrid_sequence(
446 &self,
447 index: usize,
448 _primary: &EnhancedSequenceType,
449 _secondary: &EnhancedSequenceType,
450 _combination: &HybridCombinationStrategy,
451 ) -> StatsResult<Array1<F>> {
452 self.compute_sobol_advanced(index, true, true, false)
454 }
455
456 fn initialize_randomization(&mut self) -> StatsResult<()> {
458 let mut rng = match self.config.seed {
459 Some(seed) => StdRng::seed_from_u64(seed),
460 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
461 };
462
463 if self.needs_scrambling() {
465 let mut matrices = Vec::with_capacity(self.dimension);
466 for _ in 0..self.dimension {
467 matrices.push(self.generate_scrambling_matrix(&mut rng)?);
468 }
469 self.state.scrambling_matrices = Some(matrices);
470 }
471
472 if self.needs_digital_shift() {
474 let mut shifts = Vec::with_capacity(self.dimension);
475 for _ in 0..self.dimension {
476 let shift = Array1::from_shape_fn(32, |_| rng.random::<u32>());
477 shifts.push(shift);
478 }
479 self.state.digital_shifts = Some(shifts);
480 }
481
482 Ok(())
483 }
484
485 fn needs_scrambling(&self) -> bool {
487 match &self.sequence_type {
488 EnhancedSequenceType::SobolAdvanced {
489 owen_scrambling, ..
490 } => *owen_scrambling,
491 _ => false,
492 }
493 }
494
495 fn needs_digital_shift(&self) -> bool {
497 match &self.sequence_type {
498 EnhancedSequenceType::SobolAdvanced { digital_shift, .. } => *digital_shift,
499 _ => false,
500 }
501 }
502
503 fn generate_scrambling_matrix<R: Rng>(&self, rng: &mut R) -> StatsResult<Array2<u32>> {
505 let mut matrix = Array2::zeros((32, 32));
506
507 for i in 0..32 {
509 let j = rng.random_range(0..32);
510 matrix[[i, j]] = 1;
511 }
512
513 Ok(matrix)
514 }
515
516 fn apply_owen_scrambling(&self, value: u32, matrix: &Array2<u32>) -> u32 {
518 let mut result = 0u32;
519
520 for i in 0..32 {
521 let bit = (value >> (31 - i)) & 1;
522 for j in 0..32 {
523 if matrix[[i, j]] == 1 && bit == 1 {
524 result |= 1u32 << (31 - j);
525 break;
526 }
527 }
528 }
529
530 result
531 }
532
533 fn radical_inverse(&self, index: usize, base: u32) -> f64 {
535 let mut result = 0.0;
536 let mut fraction = 1.0 / base as f64;
537 let mut i = index;
538
539 while i > 0 {
540 result += (i % base as usize) as f64 * fraction;
541 i /= base as usize;
542 fraction /= base as f64;
543 }
544
545 result
546 }
547
548 fn get_prime(&self, n: usize) -> u32 {
550 let primes = [
551 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
552 ];
553 if n < primes.len() {
554 primes[n]
555 } else {
556 let mut candidate = primes[primes.len() - 1] + 2;
558 let mut count = primes.len();
559
560 while count <= n {
561 if self.is_prime(candidate) {
562 if count == n {
563 return candidate;
564 }
565 count += 1;
566 }
567 candidate += 2;
568 }
569 candidate
570 }
571 }
572
573 fn smallest_prime_geq(&self, n: u32) -> u32 {
575 if n <= 2 {
576 return 2;
577 }
578
579 let mut candidate = if n.is_multiple_of(2) { n + 1 } else { n };
580
581 while !self.is_prime(candidate) {
582 candidate += 2;
583 }
584
585 candidate
586 }
587
588 fn is_prime(&self, n: u32) -> bool {
590 if n < 2 {
591 return false;
592 }
593 if n == 2 {
594 return true;
595 }
596 if n.is_multiple_of(2) {
597 return false;
598 }
599
600 let sqrt_n = (n as f64).sqrt() as u32;
601 for i in (3..=sqrt_n).step_by(2) {
602 if n.is_multiple_of(i) {
603 return false;
604 }
605 }
606
607 true
608 }
609
610 fn assess_quality(&mut self, sequence: &Array2<F>) -> StatsResult<()> {
612 let n = sequence.nrows();
614 let d = sequence.ncols();
615
616 let mut max_discrepancy = 0.0;
618 let num_test_points = 50.min(n);
619
620 let mut rng = scirs2_core::random::thread_rng();
621 for _ in 0..num_test_points {
622 let mut test_point = Array1::zeros(d);
623 for j in 0..d {
624 test_point[j] = F::from(rng.random::<f64>()).expect("Operation failed");
625 }
626
627 let mut count = 0;
628 for i in 0..n {
629 let mut in_box = true;
630 for j in 0..d {
631 if sequence[[i, j]] > test_point[j] {
632 in_box = false;
633 break;
634 }
635 }
636 if in_box {
637 count += 1;
638 }
639 }
640
641 let volume: F = test_point.iter().fold(F::one(), |acc, &x| acc * x);
642 let expected = volume.to_f64().expect("Operation failed") * n as f64;
643 let discrepancy = (count as f64 - expected).abs() / n as f64;
644 max_discrepancy = max_discrepancy.max(discrepancy);
645 }
646
647 self.state.quality_metrics.star_discrepancy = max_discrepancy;
648 Ok(())
649 }
650
651 pub fn quality_metrics(&self) -> &QualityMetrics {
653 &self.state.quality_metrics
654 }
655}
656
657#[allow(dead_code)]
659pub fn enhanced_sobol<F>(
660 n: usize,
661 dimension: usize,
662 scrambling: bool,
663 seed: Option<u64>,
664) -> StatsResult<Array2<F>>
665where
666 F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
667 for<'a> &'a F: std::iter::Product<&'a F>,
668{
669 let sequence_type = EnhancedSequenceType::SobolAdvanced {
670 owen_scrambling: scrambling,
671 digital_shift: true,
672 nested_scrambling: false,
673 };
674
675 let config = EnhancedQMCConfig {
676 seed,
677 ..Default::default()
678 };
679
680 let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
681 generator.generate(n)
682}
683
684#[allow(dead_code)]
685pub fn enhanced_niederreiter<F>(
686 n: usize,
687 dimension: usize,
688 seed: Option<u64>,
689) -> StatsResult<Array2<F>>
690where
691 F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
692 for<'a> &'a F: std::iter::Product<&'a F>,
693{
694 let sequence_type = EnhancedSequenceType::Niederreiter {
695 base_strategy: BaseSelectionStrategy::OptimizedPrimes,
696 matrix_optimization: true,
697 };
698
699 let config = EnhancedQMCConfig {
700 seed,
701 ..Default::default()
702 };
703
704 let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
705 generator.generate(n)
706}
707
708#[allow(dead_code)]
709pub fn enhanced_digital_net<F>(
710 n: usize,
711 dimension: usize,
712 t: usize,
713 seed: Option<u64>,
714) -> StatsResult<Array2<F>>
715where
716 F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
717 for<'a> &'a F: std::iter::Product<&'a F>,
718{
719 let net_params = DigitalNetParams {
720 t,
721 m: 32,
722 s: dimension,
723 base: 2,
724 };
725
726 let sequence_type = EnhancedSequenceType::DigitalNet {
727 net_params,
728 construction_method: NetConstructionMethod::Sobol,
729 };
730
731 let config = EnhancedQMCConfig {
732 seed,
733 ..Default::default()
734 };
735
736 let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
737 generator.generate(n)
738}