// scouter_types/psi/alert.rs

1use crate::error::TypeError;
2use crate::{
3    AlertDispatchConfig, AlertDispatchType, AlertMap, CommonCrons, DispatchAlertDescription,
4    OpsGenieDispatchConfig, SlackDispatchConfig, ValidateAlertConfig,
5};
6use core::fmt::Debug;
7use pyo3::prelude::*;
8use pyo3::types::PyString;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};
12
13#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
14pub enum PsiThreshold {
15    Normal(PsiNormalThreshold),
16    ChiSquare(PsiChiSquareThreshold),
17    Fixed(PsiFixedThreshold),
18}
19
20impl PsiThreshold {
21    pub fn config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
22        match self {
23            PsiThreshold::Normal(config) => config.clone().into_bound_py_any(py),
24            PsiThreshold::ChiSquare(config) => config.clone().into_bound_py_any(py),
25            PsiThreshold::Fixed(config) => config.clone().into_bound_py_any(py),
26        }
27    }
28
29    pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
30        match self {
31            PsiThreshold::Normal(normal) => normal.compute_threshold(target_sample_size, bin_count),
32            PsiThreshold::ChiSquare(chi) => chi.compute_threshold(target_sample_size, bin_count),
33            PsiThreshold::Fixed(fixed) => fixed.compute_threshold(),
34        }
35    }
36}
37
38impl Default for PsiThreshold {
39    // Default threshold is ChiSquare with alpha = 0.05
40    fn default() -> Self {
41        PsiThreshold::ChiSquare(PsiChiSquareThreshold { alpha: 0.05 })
42    }
43}
44
45#[pyclass]
46#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
47pub struct PsiNormalThreshold {
48    #[pyo3(get, set)]
49    pub alpha: f64,
50}
51
52impl PsiNormalThreshold {
53    /// Based on Yurdakul (2018) "Statistical Properties of Population Stability Index"
54    /// Method I (Section 3.1.1): Normal approximation for one-sample case (fixed base)
55    ///
56    /// Paper: https://scholarworks.wmich.edu/dissertations/3208
57    ///
58    /// Formula: PSI > (B-1)/M + z_α × √[2(B-1)]/M
59    /// where the base population is treated as fixed and only target sample is random
60    #[allow(non_snake_case)]
61    pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
62        let M = target_sample_size as f64;
63        let B = bin_count as f64;
64
65        let normal = Normal::new(0.0, 1.0).unwrap();
66        let z_alpha = normal.inverse_cdf(1.0 - self.alpha);
67
68        let exp_val = (B - 1.0) / M;
69        let std_dev = (2.0 * (B - 1.0)).sqrt() / M;
70
71        exp_val + z_alpha * std_dev
72    }
73}
74
75#[pymethods]
76impl PsiNormalThreshold {
77    #[new]
78    #[pyo3(signature = (alpha=0.05))]
79    pub fn new(alpha: f64) -> PyResult<Self> {
80        if !(0.0..1.0).contains(&alpha) {
81            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
82                "alpha must be between 0.0 and 1.0 (exclusive)",
83            ));
84        }
85        Ok(Self { alpha })
86    }
87}
88
89#[pyclass]
90#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
91pub struct PsiChiSquareThreshold {
92    #[pyo3(get, set)]
93    pub alpha: f64,
94}
95
96impl PsiChiSquareThreshold {
97    /// Based on Yurdakul (2018) "Statistical Properties of Population Stability Index"
98    /// Method II (Section 3.1.2): Chi-square approximation for one-sample case (fixed base)
99    ///
100    /// Paper: https://scholarworks.wmich.edu/dissertations/3208
101    ///
102    /// Formula: PSI > χ²_{α,B-1} × (1/M)
103    /// where the base population is treated as fixed and only target sample is random
104    #[allow(non_snake_case)]
105    pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
106        let M = target_sample_size as f64;
107        let B = bin_count as f64;
108        let chi2 = ChiSquared::new(B - 1.0).unwrap();
109        let chi2_critical = chi2.inverse_cdf(1.0 - self.alpha);
110
111        chi2_critical / M
112    }
113}
114
115#[pymethods]
116impl PsiChiSquareThreshold {
117    #[new]
118    #[pyo3(signature = (alpha=0.05))]
119    pub fn new(alpha: f64) -> PyResult<Self> {
120        if !(0.0..1.0).contains(&alpha) {
121            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
122                "alpha must be between 0.0 and 1.0 (exclusive)",
123            ));
124        }
125        Ok(Self { alpha })
126    }
127}
128
129#[pyclass]
130#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
131pub struct PsiFixedThreshold {
132    #[pyo3(get, set)]
133    pub threshold: f64,
134}
135
136impl PsiFixedThreshold {
137    pub fn compute_threshold(&self) -> f64 {
138        self.threshold
139    }
140}
141
142#[pymethods]
143impl PsiFixedThreshold {
144    #[new]
145    #[pyo3(signature = (threshold=0.25))]
146    pub fn new(threshold: f64) -> PyResult<Self> {
147        if threshold < 0.0 {
148            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
149                "Threshold values must be non-zero",
150            ));
151        }
152        Ok(Self { threshold })
153    }
154}
155
156#[pyclass]
157#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
158pub struct PsiAlertConfig {
159    #[pyo3(get, set)]
160    pub schedule: String,
161
162    #[pyo3(get, set)]
163    pub features_to_monitor: Vec<String>,
164
165    pub dispatch_config: AlertDispatchConfig,
166
167    pub threshold: PsiThreshold,
168}
169
170impl Default for PsiAlertConfig {
171    fn default() -> PsiAlertConfig {
172        Self {
173            schedule: CommonCrons::EveryDay.cron(),
174            features_to_monitor: Vec::new(),
175            dispatch_config: AlertDispatchConfig::default(),
176            threshold: PsiThreshold::default(),
177        }
178    }
179}
180
181impl ValidateAlertConfig for PsiAlertConfig {}
182
183#[pymethods]
184impl PsiAlertConfig {
185    #[new]
186    #[pyo3(signature = (schedule=None, features_to_monitor=vec![], dispatch_config=None, threshold=None))]
187    pub fn new(
188        schedule: Option<&Bound<'_, PyAny>>,
189        features_to_monitor: Vec<String>,
190        dispatch_config: Option<&Bound<'_, PyAny>>,
191        threshold: Option<&Bound<'_, PyAny>>,
192    ) -> Result<Self, TypeError> {
193        let dispatch_config = match dispatch_config {
194            None => AlertDispatchConfig::default(),
195            Some(config) => {
196                if config.is_instance_of::<SlackDispatchConfig>() {
197                    AlertDispatchConfig::Slack(config.extract()?)
198                } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
199                    AlertDispatchConfig::OpsGenie(config.extract()?)
200                } else {
201                    return Err(TypeError::InvalidDispatchConfigError);
202                }
203            }
204        };
205
206        let threshold = match threshold {
207            None => PsiThreshold::default(),
208            Some(config) => {
209                if config.is_instance_of::<PsiNormalThreshold>() {
210                    PsiThreshold::Normal(config.extract()?)
211                } else if config.is_instance_of::<PsiChiSquareThreshold>() {
212                    PsiThreshold::ChiSquare(config.extract()?)
213                } else if config.is_instance_of::<PsiFixedThreshold>() {
214                    // ← Fixed bug
215                    PsiThreshold::Fixed(config.extract()?)
216                } else {
217                    return Err(TypeError::InvalidPsiThresholdError);
218                }
219            }
220        };
221
222        let schedule = match schedule {
223            Some(schedule) => {
224                if schedule.is_instance_of::<PyString>() {
225                    schedule.to_string()
226                } else if schedule.is_instance_of::<CommonCrons>() {
227                    schedule.extract::<CommonCrons>()?.cron()
228                } else {
229                    return Err(TypeError::InvalidScheduleError);
230                }
231            }
232            None => CommonCrons::EveryDay.cron(),
233        };
234
235        let schedule = Self::resolve_schedule(&schedule);
236
237        Ok(Self {
238            schedule,
239            features_to_monitor,
240            dispatch_config,
241            threshold,
242        })
243    }
244    #[getter]
245    pub fn dispatch_type(&self) -> AlertDispatchType {
246        self.dispatch_config.dispatch_type()
247    }
248
249    #[getter]
250    pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
251        self.dispatch_config.config(py)
252    }
253
254    #[getter]
255    pub fn threshold<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
256        self.threshold.config(py)
257    }
258}
259
260#[derive(Serialize, Deserialize, Debug, Default, Clone)]
261pub struct PsiFeatureAlert {
262    pub feature: String,
263    pub drift: f64,
264    pub threshold: f64,
265}
266impl From<PsiFeatureAlert> for AlertMap {
267    fn from(val: PsiFeatureAlert) -> Self {
268        AlertMap::Psi(val)
269    }
270}
271
272pub struct PsiFeatureAlerts {
273    pub alerts: Vec<PsiFeatureAlert>,
274}
275
276impl DispatchAlertDescription for PsiFeatureAlerts {
277    fn create_alert_description(&self, dispatch_type: AlertDispatchType) -> String {
278        let mut alert_description = String::new();
279
280        for (i, alert) in self.alerts.iter().enumerate() {
281            let description = format!("Feature '{}' has experienced drift, with a current PSI score of {} that exceeds the configured threshold of {}.", alert.feature, alert.drift, alert.threshold);
282
283            if i == 0 {
284                let header = "PSI Drift has been detected for the following features:\n";
285                alert_description.push_str(header);
286            }
287
288            let feature_name = match dispatch_type {
289                AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
290                    format!("{:indent$}{}: \n", "", alert.feature, indent = 4)
291                }
292                AlertDispatchType::Slack => format!("{}: \n", alert.feature),
293            };
294
295            alert_description = format!("{alert_description}{feature_name}");
296
297            let alert_details = match dispatch_type {
298                AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
299                    format!("{:indent$}Drift Value: {}\n", "", description, indent = 8)
300                }
301                AlertDispatchType::Slack => {
302                    format!("{:indent$}Drift Value: {}\n", "", description, indent = 4)
303                }
304            };
305            alert_description = format!("{alert_description}{alert_details}");
306        }
307        alert_description
308    }
309}
310
// Unit tests for the PSI threshold computations, the alert configuration
// defaults, and the rendered alert descriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_compute_threshold_method_i_paper_validation() {
        // Test based on Yurdakul (2018) Method I: Normal approximation for fixed base population
        //
        // Test case from Table 3.1 in the paper
        // B = 10 bins, M = 400 sample size, α = 0.05 (95th percentile)
        let threshold = PsiNormalThreshold { alpha: 0.05 };
        let result = threshold.compute_threshold(400, 10);

        // From Table 3.1: Expected ~4.0% for N=∞, M=400, B=10 using normal approximation
        // Expected value = 9 / 400 = 0.0225
        // Std dev = sqrt(2 * 9) / 400 ≈ 4.24 / 400 ≈ 0.0106
        // z_α for 95th percentile ≈ 1.645
        // Threshold ≈ 0.0225 + 1.645 * 0.0106 ≈ 0.0400
        assert_relative_eq!(result, 0.0400, epsilon = 0.002);
    }

    #[test]
    fn test_compute_threshold_method_ii_paper_validation() {
        // Test based on Yurdakul (2018) Method II: PSI > χ²_{α,B-1} × (1/M)

        // Test case from Tables 3.2 and 3.4 in the paper
        // B=10 bins, M=400 sample size, α=0.05 (95th percentile)
        let threshold = PsiChiSquareThreshold { alpha: 0.05 };
        let result = threshold.compute_threshold(400, 10);

        // Expected ~4.2% for N=∞, M=400, B=10 (~8.5% is the M=200 value, see
        // the table test below)
        // Chi-square with 9 df at 95th percentile ≈ 16.919
        // Expected: 16.919 / 400 ≈ 0.0423 (4.23%)
        assert_relative_eq!(result, 0.0423, epsilon = 0.002);

        // Test case: B=20 bins, M=1000 sample size, α=0.05
        let result_20_bins = threshold.compute_threshold(1000, 20);
        // Chi-square with 19 df at 95th percentile ≈ 30.144
        // Expected: 30.144 / 1000 ≈ 0.0301 (3.01%)
        assert_relative_eq!(result_20_bins, 0.0301, epsilon = 0.002);
    }

    #[test]
    fn test_compute_threshold_paper_table_values() {
        // Validate against Table 3.2 from the paper
        // Method II: P95 of χ²_{B-1}, B=10

        let threshold = PsiChiSquareThreshold { alpha: 0.05 };

        // Sample sizes from the paper's table
        let test_cases = [
            (100, 0.169),  // M=100 → ~16.9%
            (200, 0.085),  // M=200 → ~8.5%
            (400, 0.042),  // M=400 → ~4.2%
            (1000, 0.017), // M=1000 → ~1.7%
        ];

        for (sample_size, expected_approx) in test_cases {
            let result = threshold.compute_threshold(sample_size, 10);
            let diff = (result - expected_approx).abs();

            // `diff < 0.005` is false for NaN, so a NaN result fails the test too.
            assert!(
                diff < 0.005,
                "Failed for sample size {sample_size}: expected ~{expected_approx}, got {result}, diff={diff}"
            );
        }
    }

    #[test]
    fn test_degrees_of_freedom_relationship_chi() {
        // Test that B-1 degrees of freedom is correctly applied
        let threshold = PsiChiSquareThreshold { alpha: 0.05 };

        // More bins (higher df) should give larger chi-square critical values
        let bins_5 = threshold.compute_threshold(1000, 5); // 4 df
        let bins_10 = threshold.compute_threshold(1000, 10); // 9 df
        let bins_20 = threshold.compute_threshold(1000, 20); // 19 df

        assert!(
            bins_5 < bins_10,
            "5 bins should give smaller threshold than 10 bins"
        );
        assert!(
            bins_10 < bins_20,
            "10 bins should give smaller threshold than 20 bins"
        );
    }

    #[test]
    fn test_degrees_of_freedom_relationship_normal() {
        // More bins should raise the normal-approximation threshold as well
        let threshold = PsiNormalThreshold { alpha: 0.05 };

        let t_5 = threshold.compute_threshold(1000, 5);
        let t_10 = threshold.compute_threshold(1000, 10);
        let t_20 = threshold.compute_threshold(1000, 20);

        assert!(t_5 < t_10 && t_10 < t_20);
    }

    #[test]
    fn test_alpha_significance_levels_chi() {
        // Test different alpha values (significance levels)
        let sample_size = 1000;
        let bin_count = 10;

        let alpha_01 = PsiChiSquareThreshold { alpha: 0.01 }; // 99th percentile
        let alpha_05 = PsiChiSquareThreshold { alpha: 0.05 }; // 95th percentile
        let alpha_10 = PsiChiSquareThreshold { alpha: 0.10 }; // 90th percentile

        let threshold_99 = alpha_01.compute_threshold(sample_size, bin_count);
        let threshold_95 = alpha_05.compute_threshold(sample_size, bin_count);
        let threshold_90 = alpha_10.compute_threshold(sample_size, bin_count);

        // More conservative (lower alpha) should give higher thresholds
        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    #[test]
    fn test_alpha_significance_levels_normal() {
        // Test different alpha values (significance levels)
        let sample_size = 1000;
        let bin_count = 10;

        let alpha_01 = PsiNormalThreshold { alpha: 0.01 }; // 99th percentile
        let alpha_05 = PsiNormalThreshold { alpha: 0.05 }; // 95th percentile
        let alpha_10 = PsiNormalThreshold { alpha: 0.10 }; // 90th percentile

        let threshold_99 = alpha_01.compute_threshold(sample_size, bin_count);
        let threshold_95 = alpha_05.compute_threshold(sample_size, bin_count);
        let threshold_90 = alpha_10.compute_threshold(sample_size, bin_count);

        // More conservative (lower alpha) should give higher thresholds
        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    #[test]
    fn test_alert_config() {
        //test console alert config
        let alert_config = PsiAlertConfig::default();
        assert_eq!(alert_config.dispatch_config, AlertDispatchConfig::default());
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::Console);

        //test slack alert config
        let slack_alert_dispatch_config = SlackDispatchConfig {
            channel: "test".to_string(),
        };
        let alert_config = PsiAlertConfig {
            dispatch_config: AlertDispatchConfig::Slack(slack_alert_dispatch_config.clone()),
            ..Default::default()
        };
        assert_eq!(
            alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack_alert_dispatch_config)
        );
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::Slack);

        //test opsgenie alert config
        let opsgenie_dispatch_config = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let alert_config = PsiAlertConfig {
            dispatch_config: opsgenie_dispatch_config.clone(),
            ..Default::default()
        };

        // `assert_eq!` only borrows its operands, so no extra clone is needed.
        assert_eq!(alert_config.dispatch_config, opsgenie_dispatch_config);
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::OpsGenie);
        assert_eq!(
            match &alert_config.dispatch_config {
                AlertDispatchConfig::OpsGenie(config) => &config.team,
                _ => panic!("Expected OpsGenie dispatch config"),
            },
            "test-team"
        );
    }

    #[test]
    fn test_create_alert_description() {
        let alerts = vec![
            PsiFeatureAlert {
                feature: "feature1".to_string(),
                drift: 0.35,
                threshold: 0.3,
            },
            PsiFeatureAlert {
                feature: "feature2".to_string(),
                drift: 0.45,
                threshold: 0.3,
            },
        ];
        let psi_feature_alerts = PsiFeatureAlerts { alerts };

        // Test for Console dispatch type
        let description = psi_feature_alerts.create_alert_description(AlertDispatchType::Console);
        assert!(description.contains("PSI Drift has been detected for the following features:"));
        assert!(description.contains("Feature 'feature1' has experienced drift, with a current PSI score of 0.35 that exceeds the configured threshold of 0.3."));
        assert!(description.contains("Feature 'feature2' has experienced drift, with a current PSI score of 0.45 that exceeds the configured threshold of 0.3."));
        // Console nests feature names by 4 spaces and details by 8
        assert!(description.contains("\n    feature1: \n        Drift Value: "));

        // Test for Slack dispatch type
        let description = psi_feature_alerts.create_alert_description(AlertDispatchType::Slack);
        assert!(description.contains("PSI Drift has been detected for the following features:"));
        assert!(description.contains("Feature 'feature1' has experienced drift, with a current PSI score of 0.35 that exceeds the configured threshold of 0.3."));
        assert!(description.contains("Feature 'feature2' has experienced drift, with a current PSI score of 0.45 that exceeds the configured threshold of 0.3."));
        // Slack keeps feature names flush left and indents details by 4
        assert!(description.contains("\nfeature1: \n    Drift Value: "));

        // Test for OpsGenie dispatch type
        let description = psi_feature_alerts.create_alert_description(AlertDispatchType::OpsGenie);
        assert!(description.contains("PSI Drift has been detected for the following features:"));
        assert!(description.contains("Feature 'feature1' has experienced drift, with a current PSI score of 0.35 that exceeds the configured threshold of 0.3."));
        assert!(description.contains("Feature 'feature2' has experienced drift, with a current PSI score of 0.45 that exceeds the configured threshold of 0.3."));
        // OpsGenie uses the same layout as Console
        assert!(description.contains("\n    feature1: \n        Drift Value: "));
    }

    #[test]
    fn test_create_alert_description_without_alerts() {
        // No alerts means no header either
        let psi_feature_alerts = PsiFeatureAlerts { alerts: Vec::new() };
        let description = psi_feature_alerts.create_alert_description(AlertDispatchType::Console);
        assert!(description.is_empty());
    }
}