Skip to main content

scirs2_transform/drift/
mod.rs

1//! Data drift and distribution shift detection
2//!
3//! This module provides multivariate drift detection for monitoring feature distributions
4//! over time between a reference dataset and a current (production) dataset.
5//!
6//! ## Methods
7//!
8//! | Method | Description |
9//! |--------|-------------|
10//! | [`DriftMethod::KolmogorovSmirnov`] | 2-sample KS test per feature with asymptotic p-value |
11//! | [`DriftMethod::PopulationStabilityIndex`] | Binning-based PSI score |
12//! | [`DriftMethod::Wasserstein`] | W1 (Earth mover's) distance per feature |
//! | [`DriftMethod::MaximumMeanDiscrepancy`] | MMD² (unbiased U-statistic) with RBF kernel, computed per feature |
14//!
15//! ## Example
16//!
17//! ```rust
18//! use scirs2_transform::drift::{DriftDetector, DriftDetectorConfig, DriftMethod};
19//! use scirs2_core::ndarray::Array2;
20//!
21//! let reference = Array2::<f64>::zeros((100, 3));
22//! let config = DriftDetectorConfig::default();
23//! let detector = DriftDetector::fit(&reference, config);
24//!
25//! let current = Array2::<f64>::zeros((80, 3));
26//! let report = detector.detect(&current).expect("detection should succeed");
27//! assert!(!report.drifted, "identical distributions should not drift");
28//! ```
29
30use scirs2_core::ndarray::{Array1, Array2, Axis};
31
32use crate::error::{Result, TransformError};
33
34// ---------------------------------------------------------------------------
35// DriftMethod enum
36// ---------------------------------------------------------------------------
37
/// Statistic used to compare the reference and current distributions.
///
/// Every method is evaluated independently on each feature column; the
/// resulting per-feature values are reported in [`DriftReport::feature_scores`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftMethod {
    /// Two-sample Kolmogorov–Smirnov test per feature.
    ///
    /// The score is the KS statistic `D = sup_x |F_ref(x) − F_cur(x)|`; a feature
    /// is flagged when the asymptotic (Kolmogorov-distribution) p-value is below
    /// [`DriftDetectorConfig::significance_level`].
    KolmogorovSmirnov,
    /// Population Stability Index (PSI) over bins cut at reference quantiles.
    ///
    /// `PSI = Σ (ref_pct − cur_pct) · ln(ref_pct / cur_pct)`.
    /// Rule of thumb: PSI < 0.1 stable, 0.1–0.2 moderate, > 0.2 severe; a feature
    /// is flagged when PSI exceeds [`DriftDetectorConfig::psi_threshold`].
    PopulationStabilityIndex,
    /// Wasserstein-1 (earth mover's) distance per feature.
    ///
    /// Computed exactly in 1-D by integrating the gap between the empirical CDFs;
    /// flagged when it exceeds [`DriftDetectorConfig::wasserstein_threshold`].
    Wasserstein,
    /// Maximum Mean Discrepancy with an RBF kernel, evaluated per feature.
    ///
    /// `MMD² = E[k(x,x')] + E[k(y,y')] − 2E[k(x,y)]`, estimated with the unbiased
    /// U-statistic and clamped at 0. The kernel bandwidth is
    /// [`DriftDetectorConfig::mmd_bandwidth`], or the median pairwise distance
    /// between reference rows when that is `None`.
    MaximumMeanDiscrepancy,
}
62
63// ---------------------------------------------------------------------------
64// DriftDetectorConfig
65// ---------------------------------------------------------------------------
66
67/// Configuration for a [`DriftDetector`].
68#[derive(Debug, Clone)]
69pub struct DriftDetectorConfig {
70    /// Detection method.
71    pub method: DriftMethod,
72    /// Significance level for hypothesis-test-based methods (KS). Default: 0.05.
73    pub significance_level: f64,
74    /// Number of bins for PSI. Default: 10.
75    pub n_bins: usize,
76    /// RBF kernel bandwidth for MMD. `None` uses the median pairwise distance heuristic.
77    pub mmd_bandwidth: Option<f64>,
78    /// W1 distance threshold for Wasserstein method. Default: 0.1.
79    pub wasserstein_threshold: f64,
80    /// PSI threshold above which drift is flagged as severe. Default: 0.2.
81    pub psi_threshold: f64,
82}
83
84impl Default for DriftDetectorConfig {
85    fn default() -> Self {
86        Self {
87            method: DriftMethod::KolmogorovSmirnov,
88            significance_level: 0.05,
89            n_bins: 10,
90            mmd_bandwidth: None,
91            wasserstein_threshold: 0.1,
92            psi_threshold: 0.2,
93        }
94    }
95}
96
97// ---------------------------------------------------------------------------
98// DriftReport
99// ---------------------------------------------------------------------------
100
101/// Report produced by [`DriftDetector::detect`].
102#[derive(Debug, Clone)]
103pub struct DriftReport {
104    /// Per-feature drift scores (interpretation depends on method).
105    pub feature_scores: Vec<f64>,
106    /// Per-feature drift flags.
107    pub feature_drifted: Vec<bool>,
108    /// Aggregate drift score (mean of per-feature scores).
109    pub overall_score: f64,
110    /// Whether overall drift is detected.
111    pub drifted: bool,
112    /// Method used for detection.
113    pub method: DriftMethod,
114}
115
116// ---------------------------------------------------------------------------
117// DriftDetector
118// ---------------------------------------------------------------------------
119
120/// Multivariate drift detector that compares a reference distribution against
121/// a current (test) distribution feature by feature.
122pub struct DriftDetector {
123    config: DriftDetectorConfig,
124    /// Reference distribution: shape (n_ref × n_features).
125    reference: Array2<f64>,
126}
127
128impl DriftDetector {
129    /// Fit the detector by storing the reference distribution.
130    ///
131    /// # Arguments
132    /// * `reference` – Reference dataset with shape (n_samples × n_features).
133    /// * `config`    – Detection configuration.
134    pub fn fit(reference: &Array2<f64>, config: DriftDetectorConfig) -> Self {
135        Self {
136            config,
137            reference: reference.to_owned(),
138        }
139    }
140
141    /// Detect drift between the stored reference and a new current dataset.
142    ///
143    /// # Arguments
144    /// * `current` – Current dataset with shape (n_samples × n_features).
145    ///
146    /// # Errors
147    /// Returns [`TransformError::InvalidInput`] if the number of features differs
148    /// or if either dataset is empty.
149    pub fn detect(&self, current: &Array2<f64>) -> Result<DriftReport> {
150        let n_ref = self.reference.nrows();
151        let n_cur = current.nrows();
152        let n_features = self.reference.ncols();
153
154        if n_ref == 0 {
155            return Err(TransformError::InvalidInput(
156                "Reference dataset is empty".to_string(),
157            ));
158        }
159        if n_cur == 0 {
160            return Err(TransformError::InvalidInput(
161                "Current dataset is empty".to_string(),
162            ));
163        }
164        if current.ncols() != n_features {
165            return Err(TransformError::InvalidInput(format!(
166                "Feature dimension mismatch: reference has {n_features} features, current has {}",
167                current.ncols()
168            )));
169        }
170
171        match &self.config.method {
172            DriftMethod::KolmogorovSmirnov => self.detect_ks(current),
173            DriftMethod::PopulationStabilityIndex => self.detect_psi(current),
174            DriftMethod::Wasserstein => self.detect_wasserstein(current),
175            DriftMethod::MaximumMeanDiscrepancy => self.detect_mmd(current),
176        }
177    }
178
179    /// Replace the reference distribution (e.g. for sliding window monitoring).
180    pub fn update_reference(&mut self, new_reference: &Array2<f64>) {
181        self.reference = new_reference.to_owned();
182    }
183
184    // -----------------------------------------------------------------------
185    // KS per-feature detection
186    // -----------------------------------------------------------------------
187
188    fn detect_ks(&self, current: &Array2<f64>) -> Result<DriftReport> {
189        let n_features = self.reference.ncols();
190        let mut scores = Vec::with_capacity(n_features);
191        let mut drifted_flags = Vec::with_capacity(n_features);
192
193        for f in 0..n_features {
194            let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
195            let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
196
197            let (ks_stat, p_value) = ks_2samp(&ref_col, &cur_col)?;
198            scores.push(ks_stat);
199            drifted_flags.push(p_value < self.config.significance_level);
200        }
201
202        let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
203        let drifted = drifted_flags.iter().any(|&d| d);
204
205        Ok(DriftReport {
206            feature_scores: scores,
207            feature_drifted: drifted_flags,
208            overall_score,
209            drifted,
210            method: DriftMethod::KolmogorovSmirnov,
211        })
212    }
213
214    // -----------------------------------------------------------------------
215    // PSI per-feature detection
216    // -----------------------------------------------------------------------
217
218    fn detect_psi(&self, current: &Array2<f64>) -> Result<DriftReport> {
219        let n_features = self.reference.ncols();
220        let mut scores = Vec::with_capacity(n_features);
221        let mut drifted_flags = Vec::with_capacity(n_features);
222
223        for f in 0..n_features {
224            let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
225            let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
226
227            let psi = compute_psi(&ref_col, &cur_col, self.config.n_bins)?;
228            scores.push(psi);
229            drifted_flags.push(psi > self.config.psi_threshold);
230        }
231
232        let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
233        let drifted = drifted_flags.iter().any(|&d| d);
234
235        Ok(DriftReport {
236            feature_scores: scores,
237            feature_drifted: drifted_flags,
238            overall_score,
239            drifted,
240            method: DriftMethod::PopulationStabilityIndex,
241        })
242    }
243
244    // -----------------------------------------------------------------------
245    // Wasserstein-1 per-feature detection
246    // -----------------------------------------------------------------------
247
248    fn detect_wasserstein(&self, current: &Array2<f64>) -> Result<DriftReport> {
249        let n_features = self.reference.ncols();
250        let mut scores = Vec::with_capacity(n_features);
251        let mut drifted_flags = Vec::with_capacity(n_features);
252
253        for f in 0..n_features {
254            let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
255            let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
256
257            let w1 = wasserstein_1d_distance(&ref_col, &cur_col)?;
258            scores.push(w1);
259            drifted_flags.push(w1 > self.config.wasserstein_threshold);
260        }
261
262        let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
263        let drifted = drifted_flags.iter().any(|&d| d);
264
265        Ok(DriftReport {
266            feature_scores: scores,
267            feature_drifted: drifted_flags,
268            overall_score,
269            drifted,
270            method: DriftMethod::Wasserstein,
271        })
272    }
273
274    // -----------------------------------------------------------------------
275    // MMD multivariate detection
276    // -----------------------------------------------------------------------
277
278    fn detect_mmd(&self, current: &Array2<f64>) -> Result<DriftReport> {
279        let n_features = self.reference.ncols();
280
281        // Determine bandwidth: median heuristic on reference if not specified
282        let bandwidth = match self.config.mmd_bandwidth {
283            Some(bw) => {
284                if bw <= 0.0 {
285                    return Err(TransformError::InvalidInput(
286                        "mmd_bandwidth must be positive".to_string(),
287                    ));
288                }
289                bw
290            }
291            None => median_heuristic_bandwidth(&self.reference)?,
292        };
293
294        // Per-feature MMD scores
295        let mut scores = Vec::with_capacity(n_features);
296        let mut drifted_flags = Vec::with_capacity(n_features);
297
298        for f in 0..n_features {
299            let ref_col: Vec<f64> = self.reference.column(f).iter().copied().collect();
300            let cur_col: Vec<f64> = current.column(f).iter().copied().collect();
301
302            let mmd2 = mmd_u_statistic_1d(&ref_col, &cur_col, bandwidth)?;
303            scores.push(mmd2);
304            // Threshold: mmd2 > 2 * bandwidth_scale (simple heuristic)
305            // A positive MMD² >> 0 indicates drift; threshold chosen as function of n.
306            let n_eff = (ref_col.len().min(cur_col.len()) as f64).max(1.0);
307            let threshold = 4.0 / n_eff.sqrt();
308            drifted_flags.push(mmd2 > threshold);
309        }
310
311        let overall_score = scores.iter().copied().sum::<f64>() / scores.len() as f64;
312        let drifted = drifted_flags.iter().any(|&d| d);
313
314        Ok(DriftReport {
315            feature_scores: scores,
316            feature_drifted: drifted_flags,
317            overall_score,
318            drifted,
319            method: DriftMethod::MaximumMeanDiscrepancy,
320        })
321    }
322}
323
324// ---------------------------------------------------------------------------
325// KS 2-sample statistic and asymptotic p-value
326// ---------------------------------------------------------------------------
327
328/// Compute the two-sample KS statistic D and an asymptotic p-value.
329///
330/// Uses the Kolmogorov asymptotic approximation:
331/// p ≈ 2 Σ_{j=1}^{∞} (−1)^{j+1} exp(−2 j² λ²)  where  λ = D √(n1 n2 / (n1+n2))
332fn ks_2samp(x: &[f64], y: &[f64]) -> Result<(f64, f64)> {
333    if x.is_empty() || y.is_empty() {
334        return Err(TransformError::InvalidInput(
335            "KS samples must be non-empty".to_string(),
336        ));
337    }
338
339    let mut xs = x.to_vec();
340    let mut ys = y.to_vec();
341    xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
342    ys.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
343
344    let n1 = xs.len();
345    let n2 = ys.len();
346    let n1f = n1 as f64;
347    let n2f = n2 as f64;
348
349    // Two-pointer sweep over merged sorted values
350    let mut i = 0usize;
351    let mut j = 0usize;
352    let mut d_max: f64 = 0.0;
353
354    while i < n1 || j < n2 {
355        let xv = if i < n1 { xs[i] } else { f64::INFINITY };
356        let yv = if j < n2 { ys[j] } else { f64::INFINITY };
357
358        // Advance both pointers past the current value
359        let cur = xv.min(yv);
360        while i < n1 && xs[i] <= cur {
361            i += 1;
362        }
363        while j < n2 && ys[j] <= cur {
364            j += 1;
365        }
366
367        let cdf1 = i as f64 / n1f;
368        let cdf2 = j as f64 / n2f;
369        d_max = d_max.max((cdf1 - cdf2).abs());
370    }
371
372    // Asymptotic Kolmogorov distribution
373    let lambda = d_max * ((n1f * n2f / (n1f + n2f)).sqrt());
374    let p_value = kolmogorov_p_value(lambda);
375
376    Ok((d_max, p_value))
377}
378
/// Asymptotic survival function of the Kolmogorov distribution,
/// `Q(λ) = P(K > λ) = 2 Σ_{j≥1} (−1)^{j+1} exp(−2 j² λ²)`.
///
/// The alternating series converges very slowly for small λ, and a truncated
/// version of it collapses toward 0 as λ → 0 (wrongly signalling significance
/// for nearly identical samples). For λ < 1.18 the equivalent Jacobi-theta form
/// `Q(λ) = 1 − (√(2π)/λ) Σ_{k≥1} exp(−(2k−1)² π² / (8 λ²))` is therefore used.
/// Both series are cut after a few terms; the omitted terms are below 1e-15 in
/// the range where each form is used.
///
/// Returns 1.0 for λ ≤ 0 or NaN (no evidence against equality). Returns 0.0 for
/// λ > 4, where the true value is below 3e-14.
fn kolmogorov_p_value(lambda: f64) -> f64 {
    use std::f64::consts::PI;

    if lambda.is_nan() || lambda <= 0.0 {
        return 1.0;
    }
    if lambda > 4.0 {
        return 0.0;
    }

    let p = if lambda < 1.18 {
        // 1 − CDF, with the CDF from the theta representation.
        let y = (-PI * PI / (8.0 * lambda * lambda)).exp();
        let cdf = (2.0 * PI).sqrt() / lambda * (y + y.powi(9) + y.powi(25) + y.powi(49));
        1.0 - cdf
    } else {
        let x = (-2.0 * lambda * lambda).exp();
        2.0 * (x - x.powi(4) + x.powi(9))
    };
    p.clamp(0.0, 1.0)
}
404
405// ---------------------------------------------------------------------------
406// PSI
407// ---------------------------------------------------------------------------
408
409/// Compute the Population Stability Index between two 1D distributions.
410///
411/// PSI = Σ (ref_pct − curr_pct) × ln(ref_pct / curr_pct)
412///
413/// Bins are derived from the reference distribution quantiles.
414fn compute_psi(reference: &[f64], current: &[f64], n_bins: usize) -> Result<f64> {
415    if reference.is_empty() || current.is_empty() {
416        return Err(TransformError::InvalidInput(
417            "PSI samples must be non-empty".to_string(),
418        ));
419    }
420    if n_bins == 0 {
421        return Err(TransformError::InvalidInput(
422            "n_bins must be at least 1".to_string(),
423        ));
424    }
425
426    let mut ref_sorted: Vec<f64> = reference.to_vec();
427    ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
428
429    // Build bin edges from reference quantiles
430    let mut edges: Vec<f64> = Vec::with_capacity(n_bins + 1);
431    edges.push(f64::NEG_INFINITY);
432    for i in 1..n_bins {
433        let q = i as f64 / n_bins as f64;
434        let idx = ((q * reference.len() as f64) as usize).min(reference.len() - 1);
435        edges.push(ref_sorted[idx]);
436    }
437    edges.push(f64::INFINITY);
438
439    // Deduplicate edges (handles constant features)
440    edges.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
441
442    let actual_bins = edges.len() - 1;
443    if actual_bins == 0 {
444        // Perfectly constant feature — no drift possible
445        return Ok(0.0);
446    }
447
448    let ref_n = reference.len() as f64;
449    let cur_n = current.len() as f64;
450    let epsilon = 1e-8; // smoothing to avoid log(0)
451
452    let mut ref_counts = vec![0u64; actual_bins];
453    let mut cur_counts = vec![0u64; actual_bins];
454
455    for &v in reference {
456        let bin = find_bin(v, &edges);
457        ref_counts[bin] += 1;
458    }
459    for &v in current {
460        let bin = find_bin(v, &edges);
461        cur_counts[bin] += 1;
462    }
463
464    let mut psi = 0.0_f64;
465    for b in 0..actual_bins {
466        let ref_pct = (ref_counts[b] as f64 / ref_n + epsilon).min(1.0);
467        let cur_pct = (cur_counts[b] as f64 / cur_n + epsilon).min(1.0);
468        psi += (ref_pct - cur_pct) * (ref_pct / cur_pct).ln();
469    }
470
471    Ok(psi.max(0.0))
472}
473
/// Index of the bin containing `value`, where bin `b` is the half-open interval
/// `[edges[b], edges[b + 1])`.
///
/// `edges` must be sorted and hold at least two entries. Entry 0 is treated as
/// −∞ and never read. Values below `edges[1]`, and NaN, map to bin 0; values at
/// or above the last interior edge map to the last bin.
fn find_bin(value: f64, edges: &[f64]) -> usize {
    let bin_count = edges.len() - 1;

    // Binary search over the interior edges for the first one strictly above
    // `value`; `value >= edge` is false for NaN, which keeps NaN in bin 0.
    let mut left = 1usize;
    let mut right = bin_count;
    while left < right {
        let probe = left + (right - left) / 2;
        if value >= edges[probe] {
            left = probe + 1;
        } else {
            right = probe;
        }
    }

    (left - 1).min(bin_count - 1)
}
490
491// ---------------------------------------------------------------------------
492// Wasserstein-1D
493// ---------------------------------------------------------------------------
494
495/// Exact 1D Wasserstein-1 distance via sorted CDF sweep.
496///
497/// W₁(p, q) = ∫|F_p(x) − F_q(x)| dx  ≡  (1/n) Σ |x_i − y_i| after sorting
498/// (when both have equal weight; otherwise we do the general CDF area integral).
499fn wasserstein_1d_distance(x: &[f64], y: &[f64]) -> Result<f64> {
500    if x.is_empty() || y.is_empty() {
501        return Err(TransformError::InvalidInput(
502            "Wasserstein samples must be non-empty".to_string(),
503        ));
504    }
505
506    let mut xs: Vec<f64> = x.to_vec();
507    let mut ys: Vec<f64> = y.to_vec();
508    xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
509    ys.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
510
511    let n = xs.len();
512    let m = ys.len();
513    let nf = n as f64;
514    let mf = m as f64;
515
516    // Build merged sorted event points
517    let mut events: Vec<f64> = xs.iter().chain(ys.iter()).copied().collect();
518    events.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
519    events.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON * a.abs().max(1.0));
520
521    let mut dist = 0.0_f64;
522    let mut ix = 0usize;
523    let mut iy = 0usize;
524
525    for w in events.windows(2) {
526        let lo = w[0];
527        let hi = w[1];
528        let dx = hi - lo;
529
530        // Count elements <= lo in xs and ys
531        while ix < n && xs[ix] <= lo {
532            ix += 1;
533        }
534        while iy < m && ys[iy] <= lo {
535            iy += 1;
536        }
537
538        let cdf_x = ix as f64 / nf;
539        let cdf_y = iy as f64 / mf;
540        dist += (cdf_x - cdf_y).abs() * dx;
541    }
542
543    Ok(dist)
544}
545
546// ---------------------------------------------------------------------------
547// MMD U-statistic (1D)
548// ---------------------------------------------------------------------------
549
550/// U-statistic estimator of MMD² for 1D samples with RBF kernel.
551///
552/// MMD²_u = (1/(n(n-1))) Σ_{i≠j} k(x_i,x_j)
553///         + (1/(m(m-1))) Σ_{i≠j} k(y_i,y_j)
554///         − (2/(nm))     Σ_{i,j} k(x_i,y_j)
555///
556/// where k(a,b) = exp(−(a−b)² / (2σ²)).
557fn mmd_u_statistic_1d(x: &[f64], y: &[f64], bandwidth: f64) -> Result<f64> {
558    if x.is_empty() || y.is_empty() {
559        return Err(TransformError::InvalidInput(
560            "MMD samples must be non-empty".to_string(),
561        ));
562    }
563    if bandwidth <= 0.0 {
564        return Err(TransformError::InvalidInput(
565            "MMD bandwidth must be positive".to_string(),
566        ));
567    }
568
569    let n = x.len();
570    let m = y.len();
571    let gamma = 1.0 / (2.0 * bandwidth * bandwidth);
572
573    // Term 1: (1/(n(n-1))) Σ_{i≠j} k(x_i, x_j)
574    let kxx = if n > 1 {
575        let mut sum = 0.0_f64;
576        for i in 0..n {
577            for j in (i + 1)..n {
578                let d = x[i] - x[j];
579                sum += (-gamma * d * d).exp();
580            }
581        }
582        2.0 * sum / (n * (n - 1)) as f64
583    } else {
584        0.0
585    };
586
587    // Term 2: (1/(m(m-1))) Σ_{i≠j} k(y_i, y_j)
588    let kyy = if m > 1 {
589        let mut sum = 0.0_f64;
590        for i in 0..m {
591            for j in (i + 1)..m {
592                let d = y[i] - y[j];
593                sum += (-gamma * d * d).exp();
594            }
595        }
596        2.0 * sum / (m * (m - 1)) as f64
597    } else {
598        0.0
599    };
600
601    // Term 3: (2/(nm)) Σ_{i,j} k(x_i, y_j)
602    let mut kxy_sum = 0.0_f64;
603    for &xi in x {
604        for &yi in y {
605            let d = xi - yi;
606            kxy_sum += (-gamma * d * d).exp();
607        }
608    }
609    let kxy = 2.0 * kxy_sum / (n * m) as f64;
610
611    Ok((kxx + kyy - kxy).max(0.0))
612}
613
614// ---------------------------------------------------------------------------
615// Median heuristic for MMD bandwidth
616// ---------------------------------------------------------------------------
617
618/// Estimate RBF bandwidth via the median of pairwise Euclidean distances in the
619/// reference dataset (subsampled to at most 500 rows for efficiency).
620fn median_heuristic_bandwidth(data: &Array2<f64>) -> Result<f64> {
621    let n = data.nrows();
622    if n == 0 {
623        return Err(TransformError::InvalidInput(
624            "Cannot compute bandwidth on empty data".to_string(),
625        ));
626    }
627
628    // Subsample to limit O(n²) cost
629    let max_samples = 500usize;
630    let step = if n > max_samples { n / max_samples } else { 1 };
631    let indices: Vec<usize> = (0..n).step_by(step).collect();
632    let k = indices.len();
633
634    let mut dists: Vec<f64> = Vec::with_capacity(k * (k - 1) / 2);
635    for i in 0..k {
636        for j in (i + 1)..k {
637            let row_i = data.row(indices[i]);
638            let row_j = data.row(indices[j]);
639            let sq_dist: f64 = row_i
640                .iter()
641                .zip(row_j.iter())
642                .map(|(a, b)| (a - b) * (a - b))
643                .sum();
644            dists.push(sq_dist.sqrt());
645        }
646    }
647
648    if dists.is_empty() {
649        return Ok(1.0); // single sample: use unit bandwidth
650    }
651
652    dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
653    let median = dists[dists.len() / 2];
654    Ok(if median > 0.0 { median } else { 1.0 })
655}
656
657// ---------------------------------------------------------------------------
658// Public convenience functions
659// ---------------------------------------------------------------------------
660
661/// Compute the 2-sample KS statistic and asymptotic p-value for two 1D samples.
662///
663/// Returns `(ks_statistic, p_value)`.
664pub fn ks_test(x: &[f64], y: &[f64]) -> Result<(f64, f64)> {
665    ks_2samp(x, y)
666}
667
668/// Compute PSI between a reference and current 1D distribution.
669pub fn psi(reference: &[f64], current: &[f64], n_bins: usize) -> Result<f64> {
670    compute_psi(reference, current, n_bins)
671}
672
673/// Compute the W1 (Wasserstein-1) distance between two 1D empirical distributions.
674pub fn wasserstein_distance_1d(x: &[f64], y: &[f64]) -> Result<f64> {
675    wasserstein_1d_distance(x, y)
676}
677
678/// Compute MMD² with RBF kernel for two 1D sample arrays.
679pub fn mmd_rbf(x: &[f64], y: &[f64], bandwidth: f64) -> Result<f64> {
680    mmd_u_statistic_1d(x, y, bandwidth)
681}
682
683// ---------------------------------------------------------------------------
684// Tests
685// ---------------------------------------------------------------------------
686
687#[cfg(test)]
688mod tests {
689    use super::*;
690    use scirs2_core::ndarray::Array2;
691
692    fn zeros_matrix(rows: usize, cols: usize) -> Array2<f64> {
693        Array2::<f64>::zeros((rows, cols))
694    }
695
696    fn linspace_col(start: f64, end: f64, n: usize) -> Vec<f64> {
697        (0..n)
698            .map(|i| start + (end - start) * (i as f64) / ((n - 1) as f64))
699            .collect()
700    }
701
702    // ------------------------------------------------------------------
703    // KS tests
704    // ------------------------------------------------------------------
705
706    #[test]
707    fn test_ks_no_drift() {
708        // Same distribution → KS statistic ≈ 0, high p-value, no drift
709        let data: Vec<f64> = linspace_col(0.0, 1.0, 100);
710        let config = DriftDetectorConfig {
711            method: DriftMethod::KolmogorovSmirnov,
712            ..Default::default()
713        };
714        let ref_mat = Array2::from_shape_vec((100, 1), data.clone()).expect("shape ok");
715        let cur_mat = Array2::from_shape_vec((100, 1), data).expect("shape ok");
716
717        let detector = DriftDetector::fit(&ref_mat, config);
718        let report = detector.detect(&cur_mat).expect("detect ok");
719        assert!(!report.drifted, "identical distributions should not drift");
720        assert!(
721            report.overall_score < 0.01,
722            "KS statistic should be near 0, got {}",
723            report.overall_score
724        );
725    }
726
727    #[test]
728    fn test_ks_drift_detected() {
729        // Reference: [0, 1], Current: [10, 11] — clearly different
730        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
731        let cur_data: Vec<f64> = linspace_col(10.0, 11.0, 100);
732        let config = DriftDetectorConfig {
733            method: DriftMethod::KolmogorovSmirnov,
734            ..Default::default()
735        };
736        let ref_mat = Array2::from_shape_vec((100, 1), ref_data).expect("shape ok");
737        let cur_mat = Array2::from_shape_vec((100, 1), cur_data).expect("shape ok");
738
739        let detector = DriftDetector::fit(&ref_mat, config);
740        let report = detector.detect(&cur_mat).expect("detect ok");
741        assert!(report.drifted, "clearly shifted distributions should drift");
742        assert!(
743            report.overall_score > 0.9,
744            "KS stat should be close to 1.0, got {}",
745            report.overall_score
746        );
747    }
748
749    #[test]
750    fn test_ks_drift_multifeature() {
751        // 3-feature dataset; only feature 0 drifts
752        let n = 100usize;
753        let mut ref_data = vec![0.0f64; n * 3];
754        let mut cur_data = vec![0.0f64; n * 3];
755
756        for i in 0..n {
757            // Feature 0: reference [0,1], current [5,6]
758            ref_data[i * 3] = i as f64 / n as f64;
759            cur_data[i * 3] = 5.0 + i as f64 / n as f64;
760            // Feature 1 & 2: same
761            ref_data[i * 3 + 1] = i as f64 / n as f64;
762            cur_data[i * 3 + 1] = i as f64 / n as f64;
763            ref_data[i * 3 + 2] = i as f64 / n as f64;
764            cur_data[i * 3 + 2] = i as f64 / n as f64;
765        }
766
767        let ref_mat = Array2::from_shape_vec((n, 3), ref_data).expect("shape ok");
768        let cur_mat = Array2::from_shape_vec((n, 3), cur_data).expect("shape ok");
769        let config = DriftDetectorConfig {
770            method: DriftMethod::KolmogorovSmirnov,
771            ..Default::default()
772        };
773
774        let detector = DriftDetector::fit(&ref_mat, config);
775        let report = detector.detect(&cur_mat).expect("detect ok");
776        assert!(report.drifted, "overall should be drifted");
777        assert!(report.feature_drifted[0], "feature 0 should drift");
778        assert!(!report.feature_drifted[1], "feature 1 should not drift");
779        assert!(!report.feature_drifted[2], "feature 2 should not drift");
780    }
781
782    // ------------------------------------------------------------------
783    // PSI tests
784    // ------------------------------------------------------------------
785
786    #[test]
787    fn test_psi_no_drift() {
788        // Same distribution → PSI ≈ 0
789        let data: Vec<f64> = linspace_col(0.0, 1.0, 100);
790        let config = DriftDetectorConfig {
791            method: DriftMethod::PopulationStabilityIndex,
792            ..Default::default()
793        };
794        let ref_mat = Array2::from_shape_vec((100, 1), data.clone()).expect("shape ok");
795        let cur_mat = Array2::from_shape_vec((100, 1), data).expect("shape ok");
796
797        let detector = DriftDetector::fit(&ref_mat, config);
798        let report = detector.detect(&cur_mat).expect("detect ok");
799        assert!(
800            !report.drifted,
801            "same distribution should not trigger PSI drift"
802        );
803    }
804
805    #[test]
806    fn test_psi_severe_drift() {
807        // PSI > 0.2 (severe drift)
808        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 200);
809        let cur_data: Vec<f64> = linspace_col(5.0, 6.0, 200);
810        let config = DriftDetectorConfig {
811            method: DriftMethod::PopulationStabilityIndex,
812            psi_threshold: 0.2,
813            ..Default::default()
814        };
815        let ref_mat = Array2::from_shape_vec((200, 1), ref_data).expect("shape ok");
816        let cur_mat = Array2::from_shape_vec((200, 1), cur_data).expect("shape ok");
817
818        let detector = DriftDetector::fit(&ref_mat, config);
819        let report = detector.detect(&cur_mat).expect("detect ok");
820        assert!(report.drifted, "severe shift should trigger PSI drift");
821        assert!(
822            report.overall_score > 0.2,
823            "PSI should exceed threshold, got {}",
824            report.overall_score
825        );
826    }
827
828    // ------------------------------------------------------------------
829    // Wasserstein tests
830    // ------------------------------------------------------------------
831
832    #[test]
833    fn test_wasserstein_drift() {
834        // Reference: [0,1], Current: [0.5, 1.5]
835        // W1 ≈ 0.5 > threshold 0.1
836        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
837        let cur_data: Vec<f64> = linspace_col(0.5, 1.5, 100);
838        let config = DriftDetectorConfig {
839            method: DriftMethod::Wasserstein,
840            wasserstein_threshold: 0.1,
841            ..Default::default()
842        };
843        let ref_mat = Array2::from_shape_vec((100, 1), ref_data).expect("shape ok");
844        let cur_mat = Array2::from_shape_vec((100, 1), cur_data).expect("shape ok");
845
846        let detector = DriftDetector::fit(&ref_mat, config);
847        let report = detector.detect(&cur_mat).expect("detect ok");
848        assert!(
849            report.drifted,
850            "shifted distribution should trigger W1 drift"
851        );
852        assert!(
853            (report.overall_score - 0.5).abs() < 0.05,
854            "W1 distance should be ~0.5, got {}",
855            report.overall_score
856        );
857    }
858
859    #[test]
860    fn test_wasserstein_no_drift() {
861        // Same distribution → W1 = 0
862        let data: Vec<f64> = linspace_col(0.0, 1.0, 100);
863        let config = DriftDetectorConfig {
864            method: DriftMethod::Wasserstein,
865            wasserstein_threshold: 0.1,
866            ..Default::default()
867        };
868        let ref_mat = Array2::from_shape_vec((100, 1), data.clone()).expect("shape ok");
869        let cur_mat = Array2::from_shape_vec((100, 1), data).expect("shape ok");
870
871        let detector = DriftDetector::fit(&ref_mat, config);
872        let report = detector.detect(&cur_mat).expect("detect ok");
873        assert!(!report.drifted, "identical distributions should not drift");
874        assert!(
875            report.overall_score < 1e-10,
876            "W1 should be 0, got {}",
877            report.overall_score
878        );
879    }
880
881    // ------------------------------------------------------------------
882    // MMD tests
883    // ------------------------------------------------------------------
884
885    #[test]
886    fn test_mmd_identical() {
887        // MMD(P, P) should be ≈ 0
888        let data: Vec<f64> = linspace_col(0.0, 1.0, 50);
889        let ref_mat = Array2::from_shape_vec((50, 1), data.clone()).expect("shape ok");
890        let cur_mat = Array2::from_shape_vec((50, 1), data).expect("shape ok");
891        let config = DriftDetectorConfig {
892            method: DriftMethod::MaximumMeanDiscrepancy,
893            mmd_bandwidth: Some(0.5),
894            ..Default::default()
895        };
896
897        let detector = DriftDetector::fit(&ref_mat, config);
898        let report = detector.detect(&cur_mat).expect("detect ok");
899        assert!(
900            report.overall_score < 1e-6,
901            "MMD of identical distributions should be near 0, got {}",
902            report.overall_score
903        );
904    }
905
906    #[test]
907    fn test_mmd_different() {
908        // N(0,1) vs N(5,1) → large MMD
909        let ref_data: Vec<f64> = linspace_col(-3.0, 3.0, 60);
910        let cur_data: Vec<f64> = linspace_col(2.0, 8.0, 60);
911        let ref_mat = Array2::from_shape_vec((60, 1), ref_data).expect("shape ok");
912        let cur_mat = Array2::from_shape_vec((60, 1), cur_data).expect("shape ok");
913        let config = DriftDetectorConfig {
914            method: DriftMethod::MaximumMeanDiscrepancy,
915            mmd_bandwidth: Some(1.0),
916            ..Default::default()
917        };
918
919        let detector = DriftDetector::fit(&ref_mat, config);
920        let report = detector.detect(&cur_mat).expect("detect ok");
921        assert!(
922            report.overall_score > 0.01,
923            "MMD between N(0,1) and N(5,1) should be positive, got {}",
924            report.overall_score
925        );
926    }
927
928    #[test]
929    fn test_mmd_median_heuristic() {
930        // Should work with automatic bandwidth selection
931        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 40);
932        let cur_data: Vec<f64> = linspace_col(5.0, 6.0, 40);
933        let ref_mat = Array2::from_shape_vec((40, 1), ref_data).expect("shape ok");
934        let cur_mat = Array2::from_shape_vec((40, 1), cur_data).expect("shape ok");
935        let config = DriftDetectorConfig {
936            method: DriftMethod::MaximumMeanDiscrepancy,
937            mmd_bandwidth: None,
938            ..Default::default()
939        };
940
941        let detector = DriftDetector::fit(&ref_mat, config);
942        let report = detector.detect(&cur_mat).expect("detect ok");
943        // Just verify it runs without error and produces a non-negative score
944        assert!(report.overall_score >= 0.0, "MMD² must be non-negative");
945    }
946
947    // ------------------------------------------------------------------
948    // update_reference
949    // ------------------------------------------------------------------
950
951    #[test]
952    fn test_update_reference() {
953        let initial_ref = zeros_matrix(50, 2);
954        let config = DriftDetectorConfig::default();
955        let mut detector = DriftDetector::fit(&initial_ref, config);
956
957        // Update reference to drifted data
958        let shifted_ref = Array2::from_elem((50, 2), 5.0);
959        detector.update_reference(&shifted_ref);
960
961        // Current is same as new reference → no drift
962        let current = Array2::from_elem((50, 2), 5.0);
963        let report = detector.detect(&current).expect("detect ok");
964        assert!(!report.drifted, "after update, same data should not drift");
965    }
966
967    // ------------------------------------------------------------------
968    // Edge cases / error handling
969    // ------------------------------------------------------------------
970
971    #[test]
972    fn test_dimension_mismatch_error() {
973        let ref_mat = zeros_matrix(50, 3);
974        let cur_mat = zeros_matrix(50, 2);
975        let config = DriftDetectorConfig::default();
976        let detector = DriftDetector::fit(&ref_mat, config);
977        let result = detector.detect(&cur_mat);
978        assert!(result.is_err(), "should error on dimension mismatch");
979    }
980
981    #[test]
982    fn test_ks_test_function() {
983        let x: Vec<f64> = linspace_col(0.0, 1.0, 50);
984        let y: Vec<f64> = linspace_col(0.0, 1.0, 50);
985        let (d, p) = ks_test(&x, &y).expect("ks test ok");
986        assert!(d < 0.05, "KS stat should be near 0");
987        assert!(p > 0.5, "p-value should be high");
988    }
989
990    #[test]
991    fn test_psi_function() {
992        let ref_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
993        let cur_data: Vec<f64> = linspace_col(0.0, 1.0, 100);
994        let score = psi(&ref_data, &cur_data, 10).expect("psi ok");
995        assert!(score < 0.01, "PSI should be near 0 for same distribution");
996    }
997
998    #[test]
999    fn test_wasserstein_distance_1d_function() {
1000        let x = vec![0.0, 1.0, 2.0, 3.0];
1001        let y = vec![1.0, 2.0, 3.0, 4.0];
1002        let d = wasserstein_distance_1d(&x, &y).expect("w1 ok");
1003        assert!((d - 1.0).abs() < 0.05, "W1 should be ~1.0, got {d}");
1004    }
1005
1006    #[test]
1007    fn test_mmd_rbf_function() {
1008        let x = vec![0.0f64; 10];
1009        let y = vec![0.0f64; 10];
1010        let mmd2 = mmd_rbf(&x, &y, 1.0).expect("mmd ok");
1011        assert!(mmd2 < 1e-8, "MMD of identical samples should be 0");
1012    }
1013}