1use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::rngs::StdRng as RealStdRng;
9use scirs2_core::random::{thread_rng, Rng, SeedableRng};
10use sklears_core::{
11 error::{Result, SklearsError},
12 prelude::{Fit, Transform},
13 traits::{Trained, Untrained},
14 types::Float,
15};
16use std::f64::consts::PI;
17use std::marker::PhantomData;
18
/// Low-discrepancy sequence family used to generate the projection weights
/// and phase offsets of a [`QuasiRandomRBFSampler`].
///
/// The variant stored on the sampler selects which `generate_*_gaussian` /
/// `generate_*_uniform` helper pair is called during `fit`.
///
/// The type is a plain tag, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuasiRandomSequence {
    /// Points produced by `sobol_point` (Gray-code bit reflection plus a
    /// per-dimension shift).
    Sobol,
    /// Radical-inverse points using a distinct pair of prime bases per feature.
    Halton,
    /// Van der Corput radical-inverse points
    /// (see `generate_van_der_corput_gaussian`).
    VanDerCorput,
    /// Handled by the Halton helpers in this file
    /// (see `generate_faure_gaussian` / `generate_faure_uniform`).
    Faure,
}
32
33#[derive(Debug, Clone)]
63pub struct QuasiRandomRBFSampler<State = Untrained> {
65 pub gamma: Float,
67 pub n_components: usize,
69 pub sequence_type: QuasiRandomSequence,
71 pub random_state: Option<u64>,
73
74 random_weights_: Option<Array2<Float>>,
76 random_offset_: Option<Array1<Float>>,
77
78 _state: PhantomData<State>,
80}
81
82impl QuasiRandomRBFSampler<Untrained> {
83 pub fn new(n_components: usize) -> Self {
88 Self {
89 gamma: 1.0,
90 n_components,
91 sequence_type: QuasiRandomSequence::Sobol,
92 random_state: None,
93 random_weights_: None,
94 random_offset_: None,
95 _state: PhantomData,
96 }
97 }
98
99 pub fn gamma(mut self, gamma: Float) -> Self {
101 self.gamma = gamma;
102 self
103 }
104
105 pub fn sequence_type(mut self, sequence_type: QuasiRandomSequence) -> Self {
107 self.sequence_type = sequence_type;
108 self
109 }
110
111 pub fn random_state(mut self, seed: u64) -> Self {
113 self.random_state = Some(seed);
114 self
115 }
116}
117
118impl Fit<Array2<Float>, ()> for QuasiRandomRBFSampler<Untrained> {
119 type Fitted = QuasiRandomRBFSampler<Trained>;
120
121 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
122 let (n_samples, n_features) = x.dim();
123
124 if n_samples == 0 || n_features == 0 {
125 return Err(SklearsError::InvalidInput(
126 "Input array is empty".to_string(),
127 ));
128 }
129
130 let mut rng = match self.random_state {
131 Some(seed) => RealStdRng::seed_from_u64(seed),
132 None => RealStdRng::from_seed(thread_rng().gen()),
133 };
134
135 let random_weights = match self.sequence_type {
137 QuasiRandomSequence::Sobol => {
138 generate_sobol_gaussian(self.n_components, n_features, self.gamma, &mut rng)
139 }
140 QuasiRandomSequence::Halton => {
141 generate_halton_gaussian(self.n_components, n_features, self.gamma, &mut rng)
142 }
143 QuasiRandomSequence::VanDerCorput => generate_van_der_corput_gaussian(
144 self.n_components,
145 n_features,
146 self.gamma,
147 &mut rng,
148 ),
149 QuasiRandomSequence::Faure => {
150 generate_faure_gaussian(self.n_components, n_features, self.gamma, &mut rng)
151 }
152 }?;
153
154 let random_offset = match self.sequence_type {
156 QuasiRandomSequence::Sobol => generate_sobol_uniform(self.n_components, &mut rng),
157 QuasiRandomSequence::Halton => generate_halton_uniform(self.n_components, &mut rng),
158 QuasiRandomSequence::VanDerCorput => {
159 generate_van_der_corput_uniform(self.n_components, &mut rng)
160 }
161 QuasiRandomSequence::Faure => generate_faure_uniform(self.n_components, &mut rng),
162 }
163 .mapv(|x| x * 2.0 * PI);
164
165 Ok(QuasiRandomRBFSampler {
166 gamma: self.gamma,
167 n_components: self.n_components,
168 sequence_type: self.sequence_type,
169 random_state: self.random_state,
170 random_weights_: Some(random_weights),
171 random_offset_: Some(random_offset),
172 _state: PhantomData,
173 })
174 }
175}
176
177impl Transform<Array2<Float>> for QuasiRandomRBFSampler<Trained> {
178 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
179 let random_weights =
180 self.random_weights_
181 .as_ref()
182 .ok_or_else(|| SklearsError::NotFitted {
183 operation: "transform".to_string(),
184 })?;
185
186 let random_offset =
187 self.random_offset_
188 .as_ref()
189 .ok_or_else(|| SklearsError::NotFitted {
190 operation: "transform".to_string(),
191 })?;
192
193 let (_n_samples, n_features) = x.dim();
194
195 if n_features != random_weights.ncols() {
196 return Err(SklearsError::InvalidInput(format!(
197 "Input has {} features, expected {}",
198 n_features,
199 random_weights.ncols()
200 )));
201 }
202
203 let projection = x.dot(&random_weights.t()) + random_offset;
205
206 let normalization = (2.0 / random_weights.nrows() as Float).sqrt();
208 Ok(projection.mapv(|x| x.cos() * normalization))
209 }
210}
211
212fn generate_sobol_gaussian<R: Rng>(
214 n_components: usize,
215 n_features: usize,
216 gamma: Float,
217 _rng: &mut R,
218) -> Result<Array2<Float>> {
219 let mut weights = Array2::zeros((n_components, n_features));
222 let std_dev = (2.0 * gamma).sqrt();
223
224 for i in 0..n_components {
226 for j in 0..n_features {
227 let sobol_u1 = sobol_point(2 * i, 2 * j);
229 let sobol_u2 = sobol_point(2 * i + 1, 2 * j + 1);
230
231 let gaussian = box_muller_transform(sobol_u1, sobol_u2).0;
233 weights[[i, j]] = gaussian * std_dev;
234 }
235 }
236
237 Ok(weights)
238}
239
240fn generate_halton_gaussian<R: Rng>(
242 n_components: usize,
243 n_features: usize,
244 gamma: Float,
245 _rng: &mut R,
246) -> Result<Array2<Float>> {
247 let mut weights = Array2::zeros((n_components, n_features));
248 let std_dev = (2.0 * gamma).sqrt();
249
250 let primes = get_first_primes(2 * n_features);
251
252 for i in 0..n_components {
253 for j in 0..n_features {
254 let u1 = halton_sequence(i + 1, primes[2 * j]);
256 let u2 = halton_sequence(i + 1, primes[2 * j + 1]);
257
258 let gaussian = box_muller_transform(u1, u2).0;
260 weights[[i, j]] = gaussian * std_dev;
261 }
262 }
263
264 Ok(weights)
265}
266
267fn generate_van_der_corput_gaussian<R: Rng>(
269 n_components: usize,
270 n_features: usize,
271 gamma: Float,
272 _rng: &mut R,
273) -> Result<Array2<Float>> {
274 let mut weights = Array2::zeros((n_components, n_features));
275 let std_dev = (2.0 * gamma).sqrt();
276
277 for i in 0..n_components {
278 for j in 0..n_features {
279 let u1 = van_der_corput_sequence(i + 1, 2);
281 let u2 = van_der_corput_sequence(i + 1, 3);
282
283 let gaussian = box_muller_transform(u1, u2).0;
285 weights[[i, j]] = gaussian * std_dev;
286 }
287 }
288
289 Ok(weights)
290}
291
292fn generate_faure_gaussian<R: Rng>(
294 n_components: usize,
295 n_features: usize,
296 gamma: Float,
297 _rng: &mut R,
298) -> Result<Array2<Float>> {
299 generate_halton_gaussian(n_components, n_features, gamma, _rng)
301}
302
303fn generate_sobol_uniform<R: Rng>(n_components: usize, _rng: &mut R) -> Array1<Float> {
305 let mut uniform = Array1::zeros(n_components);
306 for i in 0..n_components {
307 uniform[i] = sobol_point(i, 0);
308 }
309 uniform
310}
311
312fn generate_halton_uniform<R: Rng>(n_components: usize, _rng: &mut R) -> Array1<Float> {
314 let mut uniform = Array1::zeros(n_components);
315 for i in 0..n_components {
316 uniform[i] = halton_sequence(i + 1, 2);
317 }
318 uniform
319}
320
321fn generate_van_der_corput_uniform<R: Rng>(n_components: usize, _rng: &mut R) -> Array1<Float> {
323 let mut uniform = Array1::zeros(n_components);
324 for i in 0..n_components {
325 uniform[i] = van_der_corput_sequence(i + 1, 2);
326 }
327 uniform
328}
329
330fn generate_faure_uniform<R: Rng>(n_components: usize, rng: &mut R) -> Array1<Float> {
332 generate_halton_uniform(n_components, rng)
334}
335
336fn sobol_point(i: usize, dim: usize) -> Float {
338 let mut n = i;
341 let mut result = 0.0;
342 let mut weight = 0.5;
343
344 let gray_code = n ^ (n >> 1);
346 n = gray_code;
347
348 while n > 0 {
349 if n & 1 == 1 {
350 result += weight;
351 }
352 weight *= 0.5;
353 n >>= 1;
354 }
355
356 result = (result + dim as Float * 0.123456789) % 1.0;
358 result
359}
360
361fn halton_sequence(i: usize, base: usize) -> Float {
363 let mut result = 0.0;
364 let mut f = 1.0 / base as Float;
365 let mut i = i;
366
367 while i > 0 {
368 result += f * (i % base) as Float;
369 i /= base;
370 f /= base as Float;
371 }
372
373 result
374}
375
376fn van_der_corput_sequence(i: usize, base: usize) -> Float {
378 halton_sequence(i, base)
379}
380
381fn box_muller_transform(u1: Float, u2: Float) -> (Float, Float) {
383 let u1 = u1.max(1e-10).min(1.0 - 1e-10); let u2 = u2.max(1e-10).min(1.0 - 1e-10);
385
386 let r = (-2.0 * u1.ln()).sqrt();
387 let theta = 2.0 * PI * u2;
388
389 (r * theta.cos(), r * theta.sin())
390}
391
/// Returns the first `n` prime numbers in ascending order.
///
/// An empty vector is returned for `n == 0`. Candidates are tested by trial
/// division against the primes already found, up to the candidate's square
/// root.
fn get_first_primes(n: usize) -> Vec<usize> {
    if n == 0 {
        return Vec::new();
    }

    let mut found: Vec<usize> = vec![2];

    // 2 is already present, so only odd candidates need checking.
    for candidate in (3usize..).step_by(2) {
        if found.len() >= n {
            break;
        }

        let limit = (candidate as f64).sqrt() as usize;
        let has_divisor = found
            .iter()
            .take_while(|&&prime| prime <= limit)
            .any(|&prime| candidate % prime == 0);

        if !has_divisor {
            found.push(candidate);
        }
    }

    found
}
423
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Fit + transform yields one row per sample, one column per component,
    /// and every value inside `[-2, 2]`.
    #[test]
    fn test_quasi_random_rbf_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = QuasiRandomRBFSampler::new(10)
            .gamma(1.0)
            .sequence_type(QuasiRandomSequence::Sobol)
            .random_state(42)
            .fit(&x, &())
            .unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 10]);
        assert!(features.iter().all(|val| (-2.0..=2.0).contains(val)));
    }

    /// Every sequence variant fits and transforms to the expected shape.
    #[test]
    fn test_different_sequence_types() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let variants = [
            QuasiRandomSequence::Sobol,
            QuasiRandomSequence::Halton,
            QuasiRandomSequence::VanDerCorput,
            QuasiRandomSequence::Faure,
        ];

        for &variant in &variants {
            let fitted = QuasiRandomRBFSampler::new(20)
                .sequence_type(variant)
                .random_state(42)
                .fit(&x, &())
                .unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[2, 20]);
        }
    }

    /// Two identically configured samplers produce the same features.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let build = || {
            QuasiRandomRBFSampler::new(50)
                .gamma(0.5)
                .sequence_type(QuasiRandomSequence::Halton)
                .random_state(123)
        };

        let features1 = build().fit(&x, &()).unwrap().transform(&x).unwrap();
        let features2 = build().fit(&x, &()).unwrap().transform(&x).unwrap();

        for (f1, f2) in features1.iter().zip(features2.iter()) {
            assert_abs_diff_eq!(f1, f2, epsilon = 1e-10);
        }
    }

    /// Spot-checks the radical inverse in bases 2 and 3.
    #[test]
    fn test_halton_sequence() {
        let base_two = [(1, 0.5), (2, 0.25), (3, 0.75), (4, 0.125)];
        for &(index, expected) in &base_two {
            assert_abs_diff_eq!(halton_sequence(index, 2), expected, epsilon = 1e-10);
        }

        let base_three = [(1, 1.0 / 3.0), (2, 2.0 / 3.0), (3, 1.0 / 9.0)];
        for &(index, expected) in &base_three {
            assert_abs_diff_eq!(halton_sequence(index, 3), expected, epsilon = 1e-10);
        }
    }

    /// The van der Corput helper agrees with the base-2 Halton helper.
    #[test]
    fn test_van_der_corput_sequence() {
        for index in 1..10 {
            assert_abs_diff_eq!(
                van_der_corput_sequence(index, 2),
                halton_sequence(index, 2),
                epsilon = 1e-10
            );
        }
    }

    /// Box-Muller outputs are bounded for a central input and finite near
    /// the edges of the unit interval.
    #[test]
    fn test_box_muller_transform() {
        let (z1, z2) = box_muller_transform(0.5, 0.5);
        assert!(z1.abs() < 10.0);
        assert!(z2.abs() < 10.0);

        let (z1, z2) = box_muller_transform(0.999, 0.001);
        assert!(z1.is_finite());
        assert!(z2.is_finite());
    }

    /// Checks the first primes, including the empty and single-prime cases.
    #[test]
    fn test_get_first_primes() {
        let expected: Vec<usize> = vec![2, 3, 5, 7, 11, 13, 17, 19, 23, 29];
        assert_eq!(get_first_primes(10), expected);

        let expected_empty: Vec<usize> = Vec::new();
        assert_eq!(get_first_primes(0), expected_empty);

        let expected_first: Vec<usize> = vec![2];
        assert_eq!(get_first_primes(1), expected_first);
    }

    /// Different gamma values lead to different feature matrices.
    #[test]
    fn test_gamma_parameter() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let features_for = |gamma: Float| {
            QuasiRandomRBFSampler::new(100)
                .gamma(gamma)
                .random_state(42)
                .fit(&x, &())
                .unwrap()
                .transform(&x)
                .unwrap()
        };

        let features_low = features_for(0.1);
        let features_high = features_for(10.0);

        assert!(features_low != features_high);
    }

    /// Empty training input and mismatched feature counts are rejected.
    #[test]
    fn test_error_handling() {
        let sampler = QuasiRandomRBFSampler::new(10);

        let empty = Array2::<Float>::zeros((0, 0));
        assert!(sampler.clone().fit(&empty, &()).is_err());

        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];
        let fitted = sampler.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }

    /// Sobol-style points stay in the unit interval and vary with the
    /// dimension.
    #[test]
    fn test_sobol_point_properties() {
        for index in 0..100 {
            let point = sobol_point(index, 0);
            assert!((0.0..=1.0).contains(&point));
        }

        let point_dim0 = sobol_point(5, 0);
        let point_dim1 = sobol_point(5, 1);
        assert_ne!(point_dim0, point_dim1);
    }
}