1use scirs2_core::ndarray::ArrayView2;
9use scirs2_core::random::{seeded_rng, Rng, RngExt};
10
11use crate::error::{SpatialError, SpatialResult};
12
/// Probability model used to score candidate scan windows.
///
/// The variant selects how a window's baseline is computed and how null data
/// sets are simulated for the Monte Carlo significance test.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScanModel {
    /// Binary outcome model: the population inside a window is used directly
    /// as its baseline and the case count is treated as a proportion of it.
    Bernoulli,
    /// Count model: a window's expected case count is its share of the total
    /// population multiplied by the total number of cases.
    Poisson,
}
21
22#[derive(Debug, Clone)]
24pub struct ScanStatisticConfig {
25 pub model: ScanModel,
27 pub max_population_fraction: f64,
30 pub n_monte_carlo: usize,
32 pub seed: u64,
34 pub max_secondary_clusters: usize,
37}
38
39impl Default for ScanStatisticConfig {
40 fn default() -> Self {
41 Self {
42 model: ScanModel::Poisson,
43 max_population_fraction: 0.5,
44 n_monte_carlo: 999,
45 seed: 12345,
46 max_secondary_clusters: 5,
47 }
48 }
49}
50
/// A single circular cluster reported by the scan.
#[derive(Debug, Clone, PartialEq)]
pub struct ScanCluster {
    /// Index of the location the window is centred on.
    pub center_index: usize,
    /// Distance from the centre to the last location added to the window.
    pub radius: f64,
    /// Log-likelihood ratio of the window (never negative).
    pub llr: f64,
    /// P-value attached to the cluster.
    pub p_value: f64,
    /// Indices of the locations inside the window, nearest to the centre
    /// first.
    pub member_indices: Vec<usize>,
    /// Sum of the case counts of the member locations.
    pub observed_inside: f64,
    /// Baseline the observed count was compared against: the expected case
    /// count under the Poisson model, the population inside the window under
    /// the Bernoulli model.
    pub expected_inside: f64,
}
69
70#[derive(Debug, Clone)]
72pub struct ScanResult {
73 pub primary_cluster: ScanCluster,
75 pub secondary_clusters: Vec<ScanCluster>,
77}
78
79pub fn kulldorff_scan(
98 coordinates: &ArrayView2<f64>,
99 cases: &[f64],
100 population: &[f64],
101 config: &ScanStatisticConfig,
102) -> SpatialResult<ScanResult> {
103 let n = coordinates.nrows();
104 if n < 3 {
105 return Err(SpatialError::ValueError(
106 "Need at least 3 locations".to_string(),
107 ));
108 }
109 if coordinates.ncols() < 2 {
110 return Err(SpatialError::DimensionError(
111 "Coordinates must be 2D".to_string(),
112 ));
113 }
114 if cases.len() != n || population.len() != n {
115 return Err(SpatialError::DimensionError(
116 "cases and population must have length n".to_string(),
117 ));
118 }
119 if config.max_population_fraction <= 0.0 || config.max_population_fraction > 1.0 {
120 return Err(SpatialError::ValueError(
121 "max_population_fraction must be in (0, 1]".to_string(),
122 ));
123 }
124
125 let total_cases: f64 = cases.iter().sum();
126 let total_population: f64 = population.iter().sum();
127
128 if total_cases <= 0.0 || total_population <= 0.0 {
129 return Err(SpatialError::ValueError(
130 "Total cases and population must be positive".to_string(),
131 ));
132 }
133
134 let max_pop = config.max_population_fraction * total_population;
135
136 let distances = precompute_distances(coordinates, n);
138
139 let sorted_neighbours = build_sorted_neighbours(&distances, n);
141
142 let (best_llr, best_center, best_radius, best_members, best_obs, best_exp) = find_best_window(
144 &sorted_neighbours,
145 &distances,
146 cases,
147 population,
148 total_cases,
149 total_population,
150 max_pop,
151 config.model,
152 n,
153 );
154
155 let mc_p = monte_carlo_p_value(
157 &sorted_neighbours,
158 &distances,
159 population,
160 total_cases,
161 total_population,
162 max_pop,
163 config.model,
164 best_llr,
165 config.n_monte_carlo,
166 config.seed,
167 n,
168 );
169
170 let primary = ScanCluster {
171 center_index: best_center,
172 radius: best_radius,
173 llr: best_llr,
174 p_value: mc_p,
175 member_indices: best_members.clone(),
176 observed_inside: best_obs,
177 expected_inside: best_exp,
178 };
179
180 let secondary = find_secondary_clusters(
182 &sorted_neighbours,
183 &distances,
184 cases,
185 population,
186 total_cases,
187 total_population,
188 max_pop,
189 config.model,
190 &best_members,
191 config.max_secondary_clusters,
192 n,
193 );
194
195 Ok(ScanResult {
196 primary_cluster: primary,
197 secondary_clusters: secondary,
198 })
199}
200
201fn precompute_distances(coordinates: &ArrayView2<f64>, n: usize) -> Vec<Vec<f64>> {
206 let mut dists = vec![vec![0.0; n]; n];
207 for i in 0..n {
208 for j in (i + 1)..n {
209 let dx = coordinates[[i, 0]] - coordinates[[j, 0]];
210 let dy = coordinates[[i, 1]] - coordinates[[j, 1]];
211 let d = (dx * dx + dy * dy).sqrt();
212 dists[i][j] = d;
213 dists[j][i] = d;
214 }
215 }
216 dists
217}
218
/// For every location `i < n`, lists all location indices ordered by
/// increasing distance from `i`.
///
/// The sort is stable, so equidistant locations keep ascending index order.
/// A NaN distance is ranked after every real distance; this keeps the
/// comparator a consistent total order, which `sort_by` requires.
fn build_sorted_neighbours(distances: &[Vec<f64>], n: usize) -> Vec<Vec<usize>> {
    (0..n)
        .map(|i| {
            let row = &distances[i];
            let mut neighbours: Vec<usize> = (0..n).collect();
            neighbours.sort_by(|&a, &b| {
                let (da, db) = (row[a], row[b]);
                // `partial_cmp` is `None` only when a NaN is involved; in
                // that case the non-NaN value comes first.
                da.partial_cmp(&db)
                    .unwrap_or_else(|| da.is_nan().cmp(&db.is_nan()))
            });
            neighbours
        })
        .collect()
}
232
233fn compute_llr(
235 obs_in: f64,
236 exp_in: f64,
237 total_cases: f64,
238 total_expected: f64,
239 model: ScanModel,
240) -> f64 {
241 let obs_out = total_cases - obs_in;
242 let exp_out = total_expected - exp_in;
243
244 match model {
245 ScanModel::Poisson => {
246 if exp_in <= 0.0 || exp_out <= 0.0 || obs_in <= 0.0 {
249 return 0.0;
250 }
251 let rate_in = obs_in / exp_in;
252 let rate_out = if obs_out > 0.0 && exp_out > 0.0 {
253 obs_out / exp_out
254 } else {
255 0.0
256 };
257
258 if rate_in <= rate_out {
260 return 0.0;
261 }
262
263 let mut llr = obs_in * (obs_in / exp_in).ln();
264 if obs_out > 0.0 && exp_out > 0.0 {
265 llr += obs_out * (obs_out / exp_out).ln();
266 }
267 if total_cases > 0.0 && total_expected > 0.0 {
269 llr -= total_cases * (total_cases / total_expected).ln();
270 }
271 llr.max(0.0)
272 }
273 ScanModel::Bernoulli => {
274 if exp_in <= 0.0 || exp_out <= 0.0 {
276 return 0.0;
277 }
278
279 let p_in = obs_in / exp_in;
280 let p_out = if exp_out > 0.0 {
281 obs_out / exp_out
282 } else {
283 0.0
284 };
285
286 if p_in <= p_out || p_in <= 0.0 || p_in >= 1.0 {
287 return 0.0;
288 }
289
290 let p_in_c = p_in.clamp(1e-15, 1.0 - 1e-15);
292 let p_out_c = p_out.clamp(1e-15, 1.0 - 1e-15);
293 let p_total = total_cases / total_expected;
294 let p_total_c = p_total.clamp(1e-15, 1.0 - 1e-15);
295
296 let llr = obs_in * (p_in_c / p_total_c).ln()
297 + (exp_in - obs_in) * ((1.0 - p_in_c) / (1.0 - p_total_c)).ln()
298 + obs_out * (p_out_c / p_total_c).ln()
299 + (exp_out - obs_out) * ((1.0 - p_out_c) / (1.0 - p_total_c)).ln();
300
301 llr.max(0.0)
302 }
303 }
304}
305
306#[allow(clippy::too_many_arguments)]
307fn find_best_window(
308 sorted_neighbours: &[Vec<usize>],
309 distances: &[Vec<f64>],
310 cases: &[f64],
311 population: &[f64],
312 total_cases: f64,
313 total_population: f64,
314 max_pop: f64,
315 model: ScanModel,
316 n: usize,
317) -> (f64, usize, f64, Vec<usize>, f64, f64) {
318 let mut best_llr = 0.0;
319 let mut best_center = 0;
320 let mut best_radius = 0.0;
321 let mut best_members: Vec<usize> = Vec::new();
322 let mut best_obs = 0.0;
323 let mut best_exp = 0.0;
324
325 for i in 0..n {
326 let mut cum_cases = 0.0;
327 let mut cum_pop = 0.0;
328 let mut members: Vec<usize> = Vec::new();
329
330 for &j in &sorted_neighbours[i] {
331 cum_cases += cases[j];
332 cum_pop += population[j];
333 members.push(j);
334
335 if cum_pop > max_pop {
336 break;
337 }
338
339 let exp_in = match model {
340 ScanModel::Poisson => cum_pop * total_cases / total_population,
341 ScanModel::Bernoulli => cum_pop,
342 };
343
344 let total_expected = match model {
345 ScanModel::Poisson => total_population * total_cases / total_population,
346 ScanModel::Bernoulli => total_population,
347 };
348
349 let llr = compute_llr(cum_cases, exp_in, total_cases, total_expected, model);
350
351 if llr > best_llr {
352 best_llr = llr;
353 best_center = i;
354 best_radius = distances[i][j];
355 best_members = members.clone();
356 best_obs = cum_cases;
357 best_exp = exp_in;
358 }
359 }
360 }
361
362 (
363 best_llr,
364 best_center,
365 best_radius,
366 best_members,
367 best_obs,
368 best_exp,
369 )
370}
371
372#[allow(clippy::too_many_arguments)]
373fn monte_carlo_p_value(
374 sorted_neighbours: &[Vec<usize>],
375 distances: &[Vec<f64>],
376 population: &[f64],
377 total_cases: f64,
378 total_population: f64,
379 max_pop: f64,
380 model: ScanModel,
381 observed_llr: f64,
382 n_simulations: usize,
383 seed: u64,
384 n: usize,
385) -> f64 {
386 let mut rng = seeded_rng(seed);
387 let mut count_ge = 0usize;
388
389 for _sim in 0..n_simulations {
390 let sim_cases = generate_null_cases(&mut rng, population, total_cases, model, n);
392
393 let (sim_best_llr, _, _, _, _, _) = find_best_window(
395 sorted_neighbours,
396 distances,
397 &sim_cases,
398 population,
399 total_cases,
400 total_population,
401 max_pop,
402 model,
403 n,
404 );
405
406 if sim_best_llr >= observed_llr {
407 count_ge += 1;
408 }
409 }
410
411 (count_ge as f64 + 1.0) / (n_simulations as f64 + 1.0)
412}
413
414fn generate_null_cases<R: Rng + ?Sized>(
415 rng: &mut R,
416 population: &[f64],
417 total_cases: f64,
418 model: ScanModel,
419 n: usize,
420) -> Vec<f64> {
421 match model {
422 ScanModel::Poisson => {
423 let total_pop: f64 = population.iter().sum();
425 let mut sim_cases = vec![0.0; n];
426 let mut remaining = total_cases as usize;
427
428 for i in 0..n {
430 if remaining == 0 {
431 break;
432 }
433 let prob = population[i] / total_pop;
434 let expected = remaining as f64 * prob;
436 let allocated = poisson_sample(rng, expected).min(remaining as f64);
438 sim_cases[i] = allocated;
439 remaining = remaining.saturating_sub(allocated as usize);
440 }
441 while remaining > 0 {
443 let idx = rng.random_range(0..n);
444 sim_cases[idx] += 1.0;
445 remaining -= 1;
446 }
447 sim_cases
448 }
449 ScanModel::Bernoulli => {
450 let total_pop: f64 = population.iter().sum();
452 let p = total_cases / total_pop;
453 let mut sim_cases = vec![0.0; n];
454 for i in 0..n {
455 let trials = population[i] as usize;
457 let mut count = 0.0;
458 for _ in 0..trials {
459 if rng.random::<f64>() < p {
460 count += 1.0;
461 }
462 }
463 sim_cases[i] = count;
464 }
465 sim_cases
466 }
467 }
468}
469
470fn poisson_sample<R: Rng + ?Sized>(rng: &mut R, lambda: f64) -> f64 {
472 if lambda <= 0.0 {
473 return 0.0;
474 }
475 let l = (-lambda).exp();
476 let mut k: f64 = 0.0;
477 let mut p: f64 = 1.0;
478 loop {
479 k += 1.0;
480 let u: f64 = rng.random::<f64>();
481 p *= u;
482 if p < l {
483 break;
484 }
485 }
486 if k - 1.0 > 0.0 {
487 k - 1.0
488 } else {
489 0.0
490 }
491}
492
493#[allow(clippy::too_many_arguments)]
494fn find_secondary_clusters(
495 sorted_neighbours: &[Vec<usize>],
496 distances: &[Vec<f64>],
497 cases: &[f64],
498 population: &[f64],
499 total_cases: f64,
500 total_population: f64,
501 max_pop: f64,
502 model: ScanModel,
503 primary_members: &[usize],
504 max_secondary: usize,
505 n: usize,
506) -> Vec<ScanCluster> {
507 let mut candidates: Vec<(f64, usize, f64, Vec<usize>, f64, f64)> = Vec::new();
509
510 for i in 0..n {
511 if primary_members.contains(&i) {
513 continue;
514 }
515
516 let mut cum_cases = 0.0;
517 let mut cum_pop = 0.0;
518 let mut members: Vec<usize> = Vec::new();
519
520 for &j in &sorted_neighbours[i] {
521 if primary_members.contains(&j) {
523 continue;
524 }
525
526 cum_cases += cases[j];
527 cum_pop += population[j];
528 members.push(j);
529
530 if cum_pop > max_pop {
531 break;
532 }
533
534 let exp_in = match model {
535 ScanModel::Poisson => cum_pop * total_cases / total_population,
536 ScanModel::Bernoulli => cum_pop,
537 };
538
539 let total_expected = match model {
540 ScanModel::Poisson => total_cases,
541 ScanModel::Bernoulli => total_population,
542 };
543
544 let llr = compute_llr(cum_cases, exp_in, total_cases, total_expected, model);
545
546 if llr > 0.0 {
547 candidates.push((llr, i, distances[i][j], members.clone(), cum_cases, exp_in));
548 }
549 }
550 }
551
552 candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
554
555 let mut result = Vec::new();
557 let mut used_indices: Vec<usize> = primary_members.to_vec();
558
559 for (llr, center, radius, members, obs, exp) in candidates {
560 if result.len() >= max_secondary {
561 break;
562 }
563
564 let overlaps = members.iter().any(|m| used_indices.contains(m));
566 if overlaps {
567 continue;
568 }
569
570 used_indices.extend_from_slice(&members);
571 result.push(ScanCluster {
572 center_index: center,
573 radius,
574 llr,
575 p_value: 1.0, member_indices: members,
577 observed_inside: obs,
578 expected_inside: exp,
579 });
580 }
581
582 result
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A block of four high-rate locations must be picked up as the primary
    /// cluster under the Poisson model.
    #[test]
    fn test_planted_poisson_cluster() {
        let coords = array![
            // High-rate block.
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            // Background locations.
            [5.0, 0.0],
            [6.0, 0.0],
            [5.0, 1.0],
            [6.0, 1.0],
            [3.0, 3.0],
            [4.0, 4.0],
        ];

        let cases = [20.0, 18.0, 22.0, 19.0, 2.0, 3.0, 1.0, 2.0, 1.0, 1.0];
        let population = [100.0; 10];

        let config = ScanStatisticConfig {
            model: ScanModel::Poisson,
            max_population_fraction: 0.5,
            n_monte_carlo: 99,
            seed: 42,
            max_secondary_clusters: 3,
        };

        let result =
            kulldorff_scan(&coords.view(), &cases, &population, &config).expect("scan failed");

        assert!(
            result.primary_cluster.llr > 0.0,
            "Primary cluster LLR should be positive"
        );
        assert!(
            result.primary_cluster.center_index < 4,
            "Primary cluster should be centred in the high-rate area, got index {}",
            result.primary_cluster.center_index
        );

        assert!(
            result.primary_cluster.p_value < 0.5,
            "p-value should be < 0.5, got {}",
            result.primary_cluster.p_value
        );
    }

    /// The same idea under the Bernoulli model: the first three locations
    /// carry almost all cases.
    #[test]
    fn test_planted_bernoulli_cluster() {
        let coords = array![
            [0.0, 0.0],
            [0.5, 0.0],
            [0.0, 0.5],
            [5.0, 5.0],
            [5.5, 5.0],
            [5.0, 5.5],
        ];

        let cases = [9.0, 8.0, 10.0, 1.0, 2.0, 1.0];
        let population = [10.0, 10.0, 10.0, 10.0, 10.0, 10.0];

        let config = ScanStatisticConfig {
            model: ScanModel::Bernoulli,
            max_population_fraction: 0.5,
            n_monte_carlo: 99,
            seed: 123,
            max_secondary_clusters: 2,
        };

        let result =
            kulldorff_scan(&coords.view(), &cases, &population, &config).expect("bernoulli scan");

        assert!(result.primary_cluster.llr > 0.0);
        let centre = result.primary_cluster.center_index;
        assert!(
            centre < 3,
            "Expected cluster centre in [0,3), got {}",
            centre
        );
    }

    /// Identical rates everywhere must not produce a significant cluster.
    #[test]
    fn test_no_cluster_uniform() {
        let coords = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [2.0, 1.0],
        ];

        let cases = [5.0, 5.0, 5.0, 5.0, 5.0, 5.0];
        let population = [100.0; 6];

        let config = ScanStatisticConfig {
            model: ScanModel::Poisson,
            n_monte_carlo: 99,
            seed: 77,
            ..Default::default()
        };

        let result =
            kulldorff_scan(&coords.view(), &cases, &population, &config).expect("uniform scan");

        assert!(
            result.primary_cluster.p_value > 0.05,
            "p-value should be > 0.05 for uniform data, got {}",
            result.primary_cluster.p_value
        );
    }

    /// Invalid inputs are reported as errors instead of being scanned.
    #[test]
    fn test_scan_errors() {
        // Fewer than three locations.
        let coords = array![[0.0, 0.0], [1.0, 0.0]];
        let cases = [1.0, 1.0];
        let population = [10.0, 10.0];
        let config = ScanStatisticConfig::default();
        assert!(kulldorff_scan(&coords.view(), &cases, &population, &config).is_err());

        let coords = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let cases = [1.0, 1.0, 1.0];
        let population = [10.0, 10.0, 10.0];

        // One count per location is required.
        assert!(kulldorff_scan(&coords.view(), &cases[..2], &population, &config).is_err());
        assert!(kulldorff_scan(&coords.view(), &cases, &population[..2], &config).is_err());

        // The population fraction must lie in (0, 1].
        for fraction in [0.0, -0.1, 1.5] {
            let bad = ScanStatisticConfig {
                max_population_fraction: fraction,
                ..Default::default()
            };
            assert!(
                kulldorff_scan(&coords.view(), &cases, &population, &bad).is_err(),
                "fraction {} should be rejected",
                fraction
            );
        }

        // No cases at all.
        assert!(kulldorff_scan(&coords.view(), &[0.0; 3], &population, &config).is_err());
    }

    /// Two separate hot spots: secondary clusters must respect the limit and
    /// stay clear of the primary cluster.
    #[test]
    fn test_scan_secondary_clusters() {
        let coords = array![
            // First hot spot.
            [0.0, 0.0],
            [0.1, 0.0],
            [0.0, 0.1],
            // Second hot spot.
            [10.0, 10.0],
            [10.1, 10.0],
            [10.0, 10.1],
            // Low-rate background.
            [5.0, 5.0],
            [5.1, 5.0],
            [5.0, 5.1],
        ];

        let cases = [15.0, 14.0, 16.0, 12.0, 13.0, 14.0, 1.0, 1.0, 1.0];
        let population = [20.0; 9];

        let config = ScanStatisticConfig {
            model: ScanModel::Poisson,
            n_monte_carlo: 49,
            seed: 999,
            max_secondary_clusters: 3,
            ..Default::default()
        };

        let result =
            kulldorff_scan(&coords.view(), &cases, &population, &config).expect("secondary scan");

        assert!(result.primary_cluster.llr > 0.0);
        assert!(result.secondary_clusters.len() <= config.max_secondary_clusters);

        let primary_members = &result.primary_cluster.member_indices;
        for cluster in &result.secondary_clusters {
            assert!(cluster.llr > 0.0);
            assert!(
                cluster
                    .member_indices
                    .iter()
                    .all(|member| !primary_members.contains(member)),
                "secondary cluster centred at {} overlaps the primary cluster",
                cluster.center_index
            );
        }
    }
}