1use crate::error::{IntegrateError, IntegrateResult};
17use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
18use scirs2_core::numeric::Complex64;
19use scirs2_core::random::prelude::{Normal, Rng, StdRng};
20use scirs2_core::Distribution;
21use scirs2_fft::{fft, ifft};
22
/// Stationary, isotropic correlation models `C(r)` of a Gaussian random field.
///
/// Each variant carries a `length_scale` `ℓ` that controls how quickly correlation
/// decays with the separation distance `r`; all models evaluate to `C(0) = 1`.
#[derive(Clone, Debug)]
pub enum CorrelationFunction {
    /// Exponential model, `C(r) = exp(-r / ℓ)`.
    Exponential { length_scale: f64 },
    /// Squared-exponential model, `C(r) = exp(-½ (r / ℓ)²)`.
    Gaussian { length_scale: f64 },
    /// Matérn model with smoothness `nu`; `nu = 0.5` coincides with the exponential model.
    Matern { nu: f64, length_scale: f64 },
    /// Powered-exponential model, `C(r) = exp(-(r / ℓ)^exponent)`.
    Powered { exponent: f64, length_scale: f64 },
}
36
37impl CorrelationFunction {
38 pub fn evaluate(&self, r: f64) -> f64 {
40 match self {
41 CorrelationFunction::Exponential { length_scale } => (-r / length_scale).exp(),
42 CorrelationFunction::Gaussian { length_scale } => {
43 let z = r / length_scale;
44 (-0.5 * z * z).exp()
45 }
46 CorrelationFunction::Matern { nu, length_scale } => {
47 evaluate_matern(r, *nu, *length_scale)
48 }
49 CorrelationFunction::Powered {
50 exponent,
51 length_scale,
52 } => {
53 let z = r / length_scale;
54 (-(z.powf(*exponent))).exp()
55 }
56 }
57 }
58
59 pub fn spectral_density_1d(&self, omega: f64) -> f64 {
61 let omega2 = omega * omega;
62 match self {
63 CorrelationFunction::Exponential { length_scale } => {
64 let ell = length_scale;
65 2.0 * ell / (1.0 + ell * ell * omega2)
66 }
67 CorrelationFunction::Gaussian { length_scale } => {
68 let ell = length_scale;
69 ell * std::f64::consts::TAU.sqrt() * (-0.5 * ell * ell * omega2).exp()
70 }
71 CorrelationFunction::Matern { nu, length_scale } => {
72 let ell = length_scale;
73 let lambda = (2.0 * nu).sqrt() / ell;
74 match (nu * 2.0).round() as i32 {
75 1 => 2.0 / (lambda * lambda + omega2),
76 3 => {
77 let d = lambda * lambda + omega2;
78 4.0 * lambda * lambda / (d * d)
79 }
80 5 => {
81 let d = lambda * lambda + omega2;
82 (8.0 / 3.0) * lambda.powi(4) / d.powi(3)
83 }
84 _ => {
85 let eff_ell = ell * (*nu).sqrt();
86 eff_ell
87 * std::f64::consts::TAU.sqrt()
88 * (-0.5 * eff_ell * eff_ell * omega2).exp()
89 }
90 }
91 }
92 CorrelationFunction::Powered {
93 exponent: _,
94 length_scale,
95 } => {
96 let ell = length_scale;
97 ell * std::f64::consts::TAU.sqrt() * (-0.5 * ell * ell * omega2).exp()
98 }
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
105pub struct RandomField {
106 pub grid: Array2<f64>,
108 pub covariance: CorrelationFunction,
110}
111
112impl RandomField {
113 pub fn sample_circulant_embedding(
122 grid_x: ArrayView1<f64>,
123 grid_y: ArrayView1<f64>,
124 cov: CorrelationFunction,
125 rng: &mut StdRng,
126 ) -> IntegrateResult<Array2<f64>> {
127 let nx = grid_x.len();
128 let ny = grid_y.len();
129 if nx == 0 || ny == 0 {
130 return Err(IntegrateError::InvalidInput(
131 "Grid dimensions must be positive".to_string(),
132 ));
133 }
134
135 let m = 2 * nx;
136 let n = 2 * ny;
137 let normal = Normal::new(0.0_f64, 1.0).map_err(|e| {
138 IntegrateError::ComputationError(format!("Normal distribution error: {e}"))
139 })?;
140
141 let mut cov_flat = vec![0.0_f64; m * n];
143 for i in 0..m {
144 let xi = if i < nx {
145 grid_x[i] - grid_x[0]
146 } else if i == nx {
147 grid_x[nx - 1] - grid_x[0]
149 } else {
150 grid_x[m - i] - grid_x[0]
152 };
153 for j in 0..n {
154 let yj = if j < ny {
155 grid_y[j] - grid_y[0]
156 } else if j == ny {
157 grid_y[ny - 1] - grid_y[0]
159 } else {
160 grid_y[n - j] - grid_y[0]
162 };
163 let r = (xi * xi + yj * yj).sqrt();
164 cov_flat[i * n + j] = cov.evaluate(r);
165 }
166 }
167
168 let eigenvalues = fft2d_real_from_flat(&cov_flat, m, n)?;
170
171 let scale = 1.0 / ((m * n) as f64).sqrt();
173 let mut noise_complex: Vec<Complex64> = (0..m * n)
174 .map(|_| {
175 let re = rng.sample(normal);
176 let im = rng.sample(normal);
177 Complex64::new(re, im)
178 })
179 .collect();
180
181 for (idx, c) in noise_complex.iter_mut().enumerate() {
182 let lambda = eigenvalues[idx].max(0.0).sqrt() * scale;
183 c.re *= lambda;
184 c.im *= lambda;
185 }
186
187 let field_full = ifft2d_complex_flat(&noise_complex, m, n)?;
189
190 let mut result = Array2::<f64>::zeros((nx, ny));
191 for i in 0..nx {
192 for j in 0..ny {
193 result[[i, j]] = field_full[i * n + j];
194 }
195 }
196 Ok(result)
197 }
198
199 pub fn sample_kl_expansion(
211 grid_x: ArrayView1<f64>,
212 grid_y: ArrayView1<f64>,
213 cov: CorrelationFunction,
214 n_terms: usize,
215 rng: &mut StdRng,
216 ) -> IntegrateResult<Array2<f64>> {
217 let nx = grid_x.len();
218 let ny = grid_y.len();
219 let n_pts = nx * ny;
220
221 if n_pts == 0 {
222 return Err(IntegrateError::InvalidInput(
223 "Grid must be non-empty".to_string(),
224 ));
225 }
226 if n_terms == 0 {
227 return Err(IntegrateError::InvalidInput(
228 "n_terms must be at least 1".to_string(),
229 ));
230 }
231 let n_terms = n_terms.min(n_pts);
232
233 let mut coords_x = vec![0.0_f64; n_pts];
235 let mut coords_y = vec![0.0_f64; n_pts];
236 for i in 0..nx {
237 for j in 0..ny {
238 coords_x[i * ny + j] = grid_x[i];
239 coords_y[i * ny + j] = grid_y[j];
240 }
241 }
242
243 let mut cov_mat = vec![0.0_f64; n_pts * n_pts];
245 for p in 0..n_pts {
246 for q in p..n_pts {
247 let dx = coords_x[p] - coords_x[q];
248 let dy = coords_y[p] - coords_y[q];
249 let r = (dx * dx + dy * dy).sqrt();
250 let c = cov.evaluate(r);
251 cov_mat[p * n_pts + q] = c;
252 cov_mat[q * n_pts + p] = c;
253 }
254 }
255
256 let (eigenvalues, eigenvectors) = power_iteration_eigenpairs(&cov_mat, n_pts, n_terms)?;
258
259 let normal = Normal::new(0.0_f64, 1.0).map_err(|e| {
260 IntegrateError::ComputationError(format!("Normal distribution error: {e}"))
261 })?;
262
263 let xi: Vec<f64> = (0..n_terms).map(|_| rng.sample(normal)).collect();
264
265 let mut u_flat = vec![0.0_f64; n_pts];
267 for k in 0..n_terms {
268 let lambda_k = eigenvalues[k].max(0.0).sqrt() * xi[k];
269 for p in 0..n_pts {
270 u_flat[p] += lambda_k * eigenvectors[k][p];
271 }
272 }
273
274 let mut result = Array2::<f64>::zeros((nx, ny));
275 for i in 0..nx {
276 for j in 0..ny {
277 result[[i, j]] = u_flat[i * ny + j];
278 }
279 }
280 Ok(result)
281 }
282
283 pub fn sample_fourier(
290 nx: usize,
291 ny: usize,
292 lx: f64,
293 ly: f64,
294 cov: CorrelationFunction,
295 rng: &mut StdRng,
296 ) -> IntegrateResult<Array2<f64>> {
297 if nx == 0 || ny == 0 {
298 return Err(IntegrateError::InvalidInput(
299 "Grid dimensions must be positive".to_string(),
300 ));
301 }
302
303 let normal = Normal::new(0.0_f64, 1.0).map_err(|e| {
304 IntegrateError::ComputationError(format!("Normal distribution error: {e}"))
305 })?;
306
307 let dx = lx / nx as f64;
308 let dy = ly / ny as f64;
309 let two_pi_over_lx = std::f64::consts::TAU / lx;
310 let two_pi_over_ly = std::f64::consts::TAU / ly;
311
312 let mut z_complex: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); nx * ny];
314 for k in 0..nx {
315 let omega_x = (if k <= nx / 2 {
316 k as f64
317 } else {
318 k as f64 - nx as f64
319 }) * two_pi_over_lx;
320 for l in 0..ny {
321 let omega_y = (if l <= ny / 2 {
322 l as f64
323 } else {
324 l as f64 - ny as f64
325 }) * two_pi_over_ly;
326 let s_x = cov.spectral_density_1d(omega_x);
327 let s_y = cov.spectral_density_1d(omega_y);
328 let amplitude = (s_x * s_y / (dx * dy)).max(0.0).sqrt();
329 let re = rng.sample(normal);
330 let im = rng.sample(normal);
331 z_complex[k * ny + l] = Complex64::new(amplitude * re, amplitude * im);
332 }
333 }
334
335 let field_full = ifft2d_complex_flat(&z_complex, nx, ny)?;
337
338 let mut result = Array2::<f64>::zeros((nx, ny));
339 for i in 0..nx {
340 for j in 0..ny {
341 result[[i, j]] = field_full[i * ny + j];
342 }
343 }
344 Ok(result)
345 }
346}
347
348fn evaluate_matern(r: f64, nu: f64, length_scale: f64) -> f64 {
354 if r < 1e-14 {
355 return 1.0;
356 }
357 let sqrt2nu_r_over_ell = (2.0 * nu).sqrt() * r / length_scale;
358 let nu2 = (nu * 2.0).round() as i32;
359 match nu2 {
360 1 => (-sqrt2nu_r_over_ell).exp(),
361 3 => {
362 let x = sqrt2nu_r_over_ell;
363 (1.0 + x) * (-x).exp()
364 }
365 5 => {
366 let x = sqrt2nu_r_over_ell;
367 (1.0 + x + x * x / 3.0) * (-x).exp()
368 }
369 _ => {
370 if nu > 50.0 {
371 let z = r / length_scale;
372 return (-0.5 * z * z).exp();
373 }
374 let x = sqrt2nu_r_over_ell;
375 let bk = bessel_k_approx(nu, x);
376 if bk <= 0.0 || !bk.is_finite() {
377 return 0.0;
378 }
379 let log_val = nu * x.ln() - log_gamma(nu) + (1.0 - nu) * 2.0_f64.ln() + bk.ln();
380 log_val.exp().min(1.0).max(0.0)
381 }
382 }
383}
384
/// Natural logarithm of the Gamma function, `ln Γ(x)`, for `x > 0`.
///
/// Uses the Lanczos approximation (`g = 7`, 9 coefficients) for `x ≥ 0.5`. For
/// `0 < x < 0.5`, where Lanczos is less accurate, it applies the reflection formula
/// `Γ(x) Γ(1 − x) = π / sin(πx)`. Returns `+∞` for `x ≤ 0`.
fn log_gamma(x: f64) -> f64 {
    const G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.999_999_999_999_810,
        676.520_368_121_885,
        -1_259.139_216_722_403,
        771.323_428_777_653,
        -176.615_029_162_141,
        12.507_343_278_686_9,
        -0.138_571_095_265_720,
        9.984_369_578_019_57e-6,
        1.505_632_735_149_31e-7,
    ];

    if x <= 0.0 {
        return f64::INFINITY;
    }
    if x < 0.5 {
        // sin(πx) > 0 on (0, 0.5), and 1 − x > 0.5, so the recursion terminates.
        let pi = std::f64::consts::PI;
        return (pi / (pi * x).sin()).ln() - log_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let t = z + G + 0.5;
    let series = COEFFS[1..]
        .iter()
        .enumerate()
        .fold(COEFFS[0], |acc, (i, &c)| acc + c / (z + i as f64 + 1.0));
    0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
410
/// Modified Bessel function of the second kind, `K_ν(x)`, for real order `ν` and `x > 0`.
///
/// Evaluates the integral representation
/// `K_ν(x) = ∫₀^∞ exp(-x cosh t) cosh(ν t) dt`
/// with the trapezoidal rule. The integrand is smooth, even in `t`, and decays doubly
/// exponentially, so the rule converges exponentially fast and gives near machine
/// precision for every `x > 0`. The previous one-term large-`x` asymptotic expansion
/// was inaccurate by a factor of several for moderate `x`, and went negative for
/// `ν < ½` at small `x`.
///
/// The factor `e^{-x}` is taken outside the sum so that the summation never
/// underflows. The returned value still underflows to `0` for very large `x` (≳ 745),
/// and overflows to `+∞` when `K_ν(x)` exceeds the `f64` range (tiny `x`, large `ν`).
///
/// Returns `+∞` for `x ≤ 0`, where `K_ν` diverges, and NaN if either argument is NaN.
fn bessel_k_approx(nu: f64, x: f64) -> f64 {
    const MAX_STEPS: usize = 1_000_000;

    if x.is_nan() || nu.is_nan() {
        return f64::NAN;
    }
    if x <= 0.0 {
        return f64::INFINITY;
    }
    // K_{-ν} = K_ν.
    let nu = nu.abs();
    // Near t = 0 the integrand behaves like exp(-x t²/2), a bump of width ~1/√x.
    // Shrink the step for large x so the bump stays well resolved.
    let h = 0.05 / x.sqrt().max(1.0);
    // exp(ν t − x cosh t) peaks where sinh t = ν/x; only truncate beyond the peak.
    let t_peak = (nu / x).asinh();
    // Integrand with e^{-x} factored out. Writing cosh t − 1 = 2 sinh²(t/2) avoids
    // cancellation for small t.
    let scaled = |t: f64| {
        let s = (0.5 * t).sinh();
        let damp = -2.0 * x * s * s;
        0.5 * ((nu * t + damp).exp() + (damp - nu * t).exp())
    };

    let mut sum = 0.5 * scaled(0.0);
    for k in 1..=MAX_STEPS {
        let t = k as f64 * h;
        let term = scaled(t);
        sum += term;
        if !sum.is_finite() || (t > t_peak && term <= 1e-17 * sum) {
            break;
        }
    }
    h * sum * (-x).exp()
}
421
422fn fft2d_real_from_flat(a: &[f64], m: usize, n: usize) -> IntegrateResult<Vec<f64>> {
424 let complex_in: Vec<Complex64> = a.iter().map(|&v| Complex64::new(v, 0.0)).collect();
425 let (real_out, _) = fft2d_complex_transform_flat(&complex_in, m, n, false)?;
426 Ok(real_out)
427}
428
429fn ifft2d_complex_flat(z: &[Complex64], m: usize, n: usize) -> IntegrateResult<Vec<f64>> {
431 let (real_out, _) = fft2d_complex_transform_flat(z, m, n, true)?;
432 Ok(real_out)
433}
434
435fn fft2d_complex_transform_flat(
438 z_in: &[Complex64],
439 m: usize,
440 n: usize,
441 inverse: bool,
442) -> IntegrateResult<(Vec<f64>, Vec<f64>)> {
443 let total = m * n;
444 let mut buf: Vec<Complex64> = z_in.to_vec();
445
446 for j in 0..n {
448 let col: Vec<Complex64> = (0..m).map(|i| buf[i * n + j]).collect();
449 let transformed = fft1d_complex_transform(&col, inverse)?;
450 for i in 0..m {
451 buf[i * n + j] = transformed[i];
452 }
453 }
454
455 for i in 0..m {
457 let row: Vec<Complex64> = (0..n).map(|j| buf[i * n + j]).collect();
458 let transformed = fft1d_complex_transform(&row, inverse)?;
459 for j in 0..n {
460 buf[i * n + j] = transformed[j];
461 }
462 }
463
464 if inverse {
465 let scale = 1.0 / total as f64;
466 for c in buf.iter_mut() {
467 c.re *= scale;
468 c.im *= scale;
469 }
470 }
471
472 let real_out: Vec<f64> = buf.iter().map(|c| c.re).collect();
473 let imag_out: Vec<f64> = buf.iter().map(|c| c.im).collect();
474 Ok((real_out, imag_out))
475}
476
477fn fft1d_complex_transform(input: &[Complex64], inverse: bool) -> IntegrateResult<Vec<Complex64>> {
479 if inverse {
480 ifft(input, None).map_err(|e| IntegrateError::ComputationError(format!("IFFT error: {e}")))
481 } else {
482 fft(input, None).map_err(|e| IntegrateError::ComputationError(format!("FFT error: {e}")))
483 }
484}
485
486fn power_iteration_eigenpairs(
489 a: &[f64],
490 n: usize,
491 n_terms: usize,
492) -> IntegrateResult<(Vec<f64>, Vec<Vec<f64>>)> {
493 let mut eigenvalues = Vec::with_capacity(n_terms);
494 let mut eigenvectors: Vec<Vec<f64>> = Vec::with_capacity(n_terms);
495
496 let mut mat = a.to_vec();
497 let max_iter = 1000;
498 let tol = 1e-10;
499
500 for _k in 0..n_terms {
501 let mut v: Vec<f64> = (0..n).map(|i| ((i + 1) as f64).sin()).collect();
503 normalize_vec(&mut v);
504
505 let mut lambda_prev = 0.0_f64;
506
507 for _iter in 0..max_iter {
508 let mut w = vec![0.0_f64; n];
510 for i in 0..n {
511 let mut s = 0.0_f64;
512 for j in 0..n {
513 s += mat[i * n + j] * v[j];
514 }
515 w[i] = s;
516 }
517
518 let lambda: f64 = v.iter().zip(w.iter()).map(|(&vi, &wi)| vi * wi).sum();
519 normalize_vec(&mut w);
520 v = w;
521
522 if (lambda - lambda_prev).abs() < tol {
523 break;
524 }
525 lambda_prev = lambda;
526 }
527
528 let mut av = vec![0.0_f64; n];
530 for i in 0..n {
531 let mut s = 0.0_f64;
532 for j in 0..n {
533 s += mat[i * n + j] * v[j];
534 }
535 av[i] = s;
536 }
537 let lambda: f64 = v.iter().zip(av.iter()).map(|(&vi, &avi)| vi * avi).sum();
538
539 eigenvalues.push(lambda.max(0.0));
540 eigenvectors.push(v.clone());
541
542 for i in 0..n {
544 for j in 0..n {
545 mat[i * n + j] -= lambda * v[i] * v[j];
546 }
547 }
548 }
549
550 Ok((eigenvalues, eigenvectors))
551}
552
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not greater than `1e-15` (including all-zero and empty
/// vectors) are left unchanged.
fn normalize_vec(v: &mut [f64]) {
    let squared: f64 = v.iter().map(|x| x * x).sum();
    let norm = squared.sqrt();
    if norm > 1e-15 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
562
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use scirs2_core::random::prelude::*;

    // A fixed seed keeps every sampling test reproducible.
    fn make_rng() -> StdRng {
        seeded_rng(42)
    }

    #[test]
    fn test_correlation_function_at_zero() {
        let models = [
            CorrelationFunction::Exponential { length_scale: 1.0 },
            CorrelationFunction::Gaussian { length_scale: 1.0 },
            CorrelationFunction::Matern {
                nu: 1.5,
                length_scale: 1.0,
            },
            CorrelationFunction::Powered {
                exponent: 1.5,
                length_scale: 1.0,
            },
        ];
        models
            .iter()
            .map(|model| model.evaluate(0.0))
            .for_each(|c0| {
                assert!((c0 - 1.0).abs() < 1e-10, "C(0) must be 1, got {c0}");
            });
    }

    #[test]
    fn test_correlation_function_decreasing() {
        let cov = CorrelationFunction::Gaussian { length_scale: 1.0 };
        let [c1, c2, c3] = [0.5, 1.0, 2.0].map(|r| cov.evaluate(r));
        assert!(
            c1 > c2 && c2 > c3,
            "Correlation should decrease with distance"
        );
    }

    #[test]
    fn test_circulant_embedding_shape() {
        let mut rng = make_rng();
        let axis = Array1::linspace(0.0, 1.0, 8);
        let field = RandomField::sample_circulant_embedding(
            axis.view(),
            axis.view(),
            CorrelationFunction::Exponential { length_scale: 0.3 },
            &mut rng,
        )
        .expect("Circulant embedding failed");
        assert_eq!(field.dim(), (8, 8));
    }

    #[test]
    fn test_kl_expansion_shape() {
        let mut rng = make_rng();
        let axis = Array1::linspace(0.0, 1.0, 6);
        let field = RandomField::sample_kl_expansion(
            axis.view(),
            axis.view(),
            CorrelationFunction::Gaussian { length_scale: 0.3 },
            10,
            &mut rng,
        )
        .expect("KL expansion failed");
        assert_eq!(field.dim(), (6, 6));
    }

    #[test]
    fn test_fourier_sampling_shape() {
        let mut rng = make_rng();
        let model = CorrelationFunction::Gaussian { length_scale: 0.3 };
        let field = RandomField::sample_fourier(8, 8, 1.0, 1.0, model, &mut rng)
            .expect("Fourier sampling failed");
        assert_eq!(field.dim(), (8, 8));
    }

    #[test]
    fn test_matern_various_nu() {
        for nu in [0.5, 1.5, 2.5, 5.0] {
            let c = CorrelationFunction::Matern {
                nu,
                length_scale: 1.0,
            }
            .evaluate(1.0);
            assert!(
                c > 0.0 && c < 1.0,
                "Matérn({nu}) at r=1 should be in (0,1): got {c}"
            );
        }
    }

    #[test]
    fn test_fourier_field_finite() {
        let mut rng = make_rng();
        let model = CorrelationFunction::Exponential { length_scale: 0.5 };
        let field = RandomField::sample_fourier(16, 16, 2.0, 2.0, model, &mut rng)
            .expect("sample_fourier should succeed with valid params");
        let all_finite = field.iter().all(|v| v.is_finite());
        assert!(all_finite, "Fourier field contains non-finite values");
    }
}