scouter_types/psi/
alert.rs

1use crate::error::TypeError;
2use crate::{
3    AlertDispatchConfig, AlertDispatchType, CommonCrons, DispatchAlertDescription,
4    OpsGenieDispatchConfig, SlackDispatchConfig, ValidateAlertConfig,
5};
6use core::fmt::Debug;
7use pyo3::prelude::*;
8use pyo3::types::PyString;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};
12
13#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
14pub enum PsiThreshold {
15    Normal(PsiNormalThreshold),
16    ChiSquare(PsiChiSquareThreshold),
17    Fixed(PsiFixedThreshold),
18}
19
20impl PsiThreshold {
21    pub fn config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
22        match self {
23            PsiThreshold::Normal(config) => config.clone().into_bound_py_any(py),
24            PsiThreshold::ChiSquare(config) => config.clone().into_bound_py_any(py),
25            PsiThreshold::Fixed(config) => config.clone().into_bound_py_any(py),
26        }
27    }
28
29    pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
30        match self {
31            PsiThreshold::Normal(normal) => normal.compute_threshold(target_sample_size, bin_count),
32            PsiThreshold::ChiSquare(chi) => chi.compute_threshold(target_sample_size, bin_count),
33            PsiThreshold::Fixed(fixed) => fixed.compute_threshold(),
34        }
35    }
36}
37
38impl Default for PsiThreshold {
39    // Default threshold is ChiSquare with alpha = 0.05
40    fn default() -> Self {
41        PsiThreshold::ChiSquare(PsiChiSquareThreshold { alpha: 0.05 })
42    }
43}
44
45#[pyclass]
46#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
47pub struct PsiNormalThreshold {
48    #[pyo3(get, set)]
49    pub alpha: f64,
50}
51
52impl PsiNormalThreshold {
53    /// Based on Yurdakul (2018) "Statistical Properties of Population Stability Index"
54    /// Method I (Section 3.1.1): Normal approximation for one-sample case (fixed base)
55    ///
56    /// Paper: https://scholarworks.wmich.edu/dissertations/3208
57    ///
58    /// Formula: PSI > (B-1)/M + z_α × √[2(B-1)]/M
59    /// where the base population is treated as fixed and only target sample is random
60    #[allow(non_snake_case)]
61    pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
62        let M = target_sample_size as f64;
63        let B = bin_count as f64;
64
65        let normal = Normal::new(0.0, 1.0).unwrap();
66        let z_alpha = normal.inverse_cdf(1.0 - self.alpha);
67
68        let exp_val = (B - 1.0) / M;
69        let std_dev = (2.0 * (B - 1.0)).sqrt() / M;
70
71        exp_val + z_alpha * std_dev
72    }
73}
74
75#[pymethods]
76impl PsiNormalThreshold {
77    #[new]
78    #[pyo3(signature = (alpha=0.05))]
79    pub fn new(alpha: f64) -> PyResult<Self> {
80        if !(0.0..1.0).contains(&alpha) {
81            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
82                "alpha must be between 0.0 and 1.0 (exclusive)",
83            ));
84        }
85        Ok(Self { alpha })
86    }
87}
88
89#[pyclass]
90#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
91pub struct PsiChiSquareThreshold {
92    #[pyo3(get, set)]
93    pub alpha: f64,
94}
95
96impl PsiChiSquareThreshold {
97    /// Based on Yurdakul (2018) "Statistical Properties of Population Stability Index"
98    /// Method II (Section 3.1.2): Chi-square approximation for one-sample case (fixed base)
99    ///
100    /// Paper: https://scholarworks.wmich.edu/dissertations/3208
101    ///
102    /// Formula: PSI > χ²_{α,B-1} × (1/M)
103    /// where the base population is treated as fixed and only target sample is random
104    #[allow(non_snake_case)]
105    pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
106        let M = target_sample_size as f64;
107        let B = bin_count as f64;
108        let chi2 = ChiSquared::new(B - 1.0).unwrap();
109        let chi2_critical = chi2.inverse_cdf(1.0 - self.alpha);
110
111        chi2_critical / M
112    }
113}
114
115#[pymethods]
116impl PsiChiSquareThreshold {
117    #[new]
118    #[pyo3(signature = (alpha=0.05))]
119    pub fn new(alpha: f64) -> PyResult<Self> {
120        if !(0.0..1.0).contains(&alpha) {
121            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
122                "alpha must be between 0.0 and 1.0 (exclusive)",
123            ));
124        }
125        Ok(Self { alpha })
126    }
127}
128
129#[pyclass]
130#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
131pub struct PsiFixedThreshold {
132    #[pyo3(get, set)]
133    pub threshold: f64,
134}
135
136impl PsiFixedThreshold {
137    pub fn compute_threshold(&self) -> f64 {
138        self.threshold
139    }
140}
141
142#[pymethods]
143impl PsiFixedThreshold {
144    #[new]
145    #[pyo3(signature = (threshold=0.25))]
146    pub fn new(threshold: f64) -> PyResult<Self> {
147        if threshold < 0.0 {
148            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
149                "Threshold values must be non-zero",
150            ));
151        }
152        Ok(Self { threshold })
153    }
154}
155
156#[pyclass]
157#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
158pub struct PsiAlertConfig {
159    #[pyo3(get, set)]
160    pub schedule: String,
161
162    #[pyo3(get, set)]
163    pub features_to_monitor: Vec<String>,
164
165    pub dispatch_config: AlertDispatchConfig,
166
167    pub threshold: PsiThreshold,
168}
169
170impl Default for PsiAlertConfig {
171    fn default() -> PsiAlertConfig {
172        Self {
173            schedule: CommonCrons::EveryDay.cron(),
174            features_to_monitor: Vec::new(),
175            dispatch_config: AlertDispatchConfig::default(),
176            threshold: PsiThreshold::default(),
177        }
178    }
179}
180
181impl ValidateAlertConfig for PsiAlertConfig {}
182
183#[pymethods]
184impl PsiAlertConfig {
185    #[new]
186    #[pyo3(signature = (schedule=None, features_to_monitor=vec![], dispatch_config=None, threshold=None))]
187    pub fn new(
188        schedule: Option<&Bound<'_, PyAny>>,
189        features_to_monitor: Vec<String>,
190        dispatch_config: Option<&Bound<'_, PyAny>>,
191        threshold: Option<&Bound<'_, PyAny>>,
192    ) -> Result<Self, TypeError> {
193        let dispatch_config = match dispatch_config {
194            None => AlertDispatchConfig::default(),
195            Some(config) => {
196                if config.is_instance_of::<SlackDispatchConfig>() {
197                    AlertDispatchConfig::Slack(config.extract()?)
198                } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
199                    AlertDispatchConfig::OpsGenie(config.extract()?)
200                } else {
201                    return Err(TypeError::InvalidDispatchConfigError);
202                }
203            }
204        };
205
206        let threshold = match threshold {
207            None => PsiThreshold::default(),
208            Some(config) => {
209                if config.is_instance_of::<PsiNormalThreshold>() {
210                    PsiThreshold::Normal(config.extract()?)
211                } else if config.is_instance_of::<PsiChiSquareThreshold>() {
212                    PsiThreshold::ChiSquare(config.extract()?)
213                } else if config.is_instance_of::<PsiFixedThreshold>() {
214                    // ← Fixed bug
215                    PsiThreshold::Fixed(config.extract()?)
216                } else {
217                    return Err(TypeError::InvalidPsiThresholdError);
218                }
219            }
220        };
221
222        let schedule = match schedule {
223            Some(schedule) => {
224                if schedule.is_instance_of::<PyString>() {
225                    schedule.to_string()
226                } else if schedule.is_instance_of::<CommonCrons>() {
227                    schedule.extract::<CommonCrons>()?.cron()
228                } else {
229                    return Err(TypeError::InvalidScheduleError);
230                }
231            }
232            None => CommonCrons::EveryDay.cron(),
233        };
234
235        let schedule = Self::resolve_schedule(&schedule);
236
237        Ok(Self {
238            schedule,
239            features_to_monitor,
240            dispatch_config,
241            threshold,
242        })
243    }
244    #[getter]
245    pub fn dispatch_type(&self) -> AlertDispatchType {
246        self.dispatch_config.dispatch_type()
247    }
248
249    #[getter]
250    pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
251        self.dispatch_config.config(py)
252    }
253
254    #[getter]
255    pub fn threshold<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
256        self.threshold.config(py)
257    }
258}
259
/// A single feature whose PSI drift exceeded its threshold.
#[derive(Debug, Clone)]
pub struct PsiFeatureAlert {
    /// Name of the drifting feature.
    pub feature: String,
    /// PSI score reported for the feature.
    pub drift: f64,
    /// Threshold the PSI score is reported against.
    pub threshold: f64,
}
266
267pub struct PsiFeatureAlerts {
268    pub alerts: Vec<PsiFeatureAlert>,
269}
270
271impl DispatchAlertDescription for PsiFeatureAlerts {
272    fn create_alert_description(&self, dispatch_type: AlertDispatchType) -> String {
273        let mut alert_description = String::new();
274
275        for (i, alert) in self.alerts.iter().enumerate() {
276            let description = format!("Feature '{}' has experienced drift, with a current PSI score of {} that exceeds the configured threshold of {}.", alert.feature, alert.drift, alert.threshold);
277
278            if i == 0 {
279                let header = "PSI Drift has been detected for the following features:\n";
280                alert_description.push_str(header);
281            }
282
283            let feature_name = match dispatch_type {
284                AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
285                    format!("{:indent$}{}: \n", "", alert.feature, indent = 4)
286                }
287                AlertDispatchType::Slack => format!("{}: \n", alert.feature),
288            };
289
290            alert_description = format!("{alert_description}{feature_name}");
291
292            let alert_details = match dispatch_type {
293                AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
294                    format!("{:indent$}Drift Value: {}\n", "", description, indent = 8)
295                }
296                AlertDispatchType::Slack => {
297                    format!("{:indent$}Drift Value: {}\n", "", description, indent = 4)
298                }
299            };
300            alert_description = format!("{alert_description}{alert_details}");
301        }
302        alert_description
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Method I (normal approximation) against the value quoted from Table 3.1.
    #[test]
    fn test_compute_threshold_method_i_paper_validation() {
        // Test based on Yurdakul (2018) Method I: Normal approximation for fixed base population
        //
        // Test case from Table 3.1 in the paper
        // B = 10 bins, M = 400 sample size, α = 0.05 (95th percentile)
        let threshold = PsiNormalThreshold { alpha: 0.05 };
        let result = threshold.compute_threshold(400, 10);

        // From Table 3.1: Expected ~4.0% for N=∞, M=400, B=10 using normal approximation
        // Expected value = 9 / 400 = 0.0225
        // Std dev = sqrt(2 * 9) / 400 ≈ 4.24 / 400 ≈ 0.0106
        // z_α for 95th percentile ≈ 1.645
        // Threshold ≈ 0.0225 + 1.645 * 0.0106 ≈ 0.0400
        assert_relative_eq!(result, 0.0400, epsilon = 0.002);
    }

    /// Method II (chi-square approximation) against hand-computed critical values.
    #[test]
    fn test_compute_threshold_method_ii_paper_validation() {
        // Test based on Yurdakul (2018) Method II: PSI > χ²_{α,B-1} × (1/M)

        // Test case from Tables 3.2 and 3.4 in the paper
        // B=10 bins, M=400 sample size, α=0.05 (95th percentile)
        let threshold = PsiChiSquareThreshold { alpha: 0.05 };
        let result = threshold.compute_threshold(400, 10);

        // From Table 3.2: Expected ~4.2% for N=∞, M=400, B=10
        // Chi-square with 9 df at 95th percentile ≈ 16.919
        // Expected: 16.919 / 400 ≈ 0.0423 (4.23%)
        assert_relative_eq!(result, 0.0423, epsilon = 0.002);

        // Test case: B=20 bins, M=1000 sample size, α=0.05
        let result_20_bins = threshold.compute_threshold(1000, 20);
        // Chi-square with 19 df at 95th percentile ≈ 30.144
        // Expected: 30.144 / 1000 ≈ 0.0301 (3.01%)
        assert_relative_eq!(result_20_bins, 0.0301, epsilon = 0.002);
    }

    /// Method II against several sample sizes from Table 3.2 (B=10, P95).
    #[test]
    fn test_compute_threshold_paper_table_values() {
        // Validate against Table 3.2 from the paper
        // Method II: P95 of χ²_{B-1}, B=10

        let threshold = PsiChiSquareThreshold { alpha: 0.05 };

        // Sample sizes from the paper's table
        let test_cases = [
            (100, 0.169),  // M=100 → ~16.9%
            (200, 0.085),  // M=200 → ~8.5%
            (400, 0.042),  // M=400 → ~4.2%
            (1000, 0.017), // M=1000 → ~1.7%
        ];

        for (sample_size, expected_approx) in test_cases {
            let result = threshold.compute_threshold(sample_size, 10);
            let diff = (result - expected_approx).abs();

            // Asserting `diff < tolerance` (rather than panicking on
            // `diff >= tolerance`) also fails the test when the result is NaN.
            assert!(
                diff < 0.005,
                "Failed for sample size {sample_size}: expected ~{expected_approx}, got {result}, diff={diff}"
            );
        }
    }

    /// The chi-square threshold grows with the number of bins (B-1 degrees of freedom).
    #[test]
    fn test_degrees_of_freedom_relationship_chi() {
        // Test that B-1 degrees of freedom is correctly applied
        let threshold = PsiChiSquareThreshold { alpha: 0.05 };

        // More bins (higher df) should give larger chi-square critical values
        let bins_5 = threshold.compute_threshold(1000, 5); // 4 df
        let bins_10 = threshold.compute_threshold(1000, 10); // 9 df
        let bins_20 = threshold.compute_threshold(1000, 20); // 19 df

        assert!(
            bins_5 < bins_10,
            "5 bins should give smaller threshold than 10 bins"
        );
        assert!(
            bins_10 < bins_20,
            "10 bins should give smaller threshold than 20 bins"
        );
    }

    /// The normal-approximation threshold grows with the number of bins.
    #[test]
    fn test_degrees_of_freedom_relationship_normal() {
        let threshold = PsiNormalThreshold { alpha: 0.05 };

        let t_5 = threshold.compute_threshold(1000, 5);
        let t_10 = threshold.compute_threshold(1000, 10);
        let t_20 = threshold.compute_threshold(1000, 20);

        assert!(t_5 < t_10 && t_10 < t_20);
    }

    /// A lower alpha yields a higher chi-square threshold.
    #[test]
    fn test_alpha_significance_levels_chi() {
        // Test different alpha values (significance levels)
        let sample_size = 1000;
        let bin_count = 10;

        let alpha_01 = PsiChiSquareThreshold { alpha: 0.01 }; // 99th percentile
        let alpha_05 = PsiChiSquareThreshold { alpha: 0.05 }; // 95th percentile
        let alpha_10 = PsiChiSquareThreshold { alpha: 0.10 }; // 90th percentile

        let threshold_99 = alpha_01.compute_threshold(sample_size, bin_count);
        let threshold_95 = alpha_05.compute_threshold(sample_size, bin_count);
        let threshold_90 = alpha_10.compute_threshold(sample_size, bin_count);

        // More conservative (lower alpha) should give higher thresholds
        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    /// A lower alpha yields a higher normal-approximation threshold.
    #[test]
    fn test_alpha_significance_levels_normal() {
        // Test different alpha values (significance levels)
        let sample_size = 1000;
        let bin_count = 10;

        let alpha_01 = PsiNormalThreshold { alpha: 0.01 }; // 99th percentile
        let alpha_05 = PsiNormalThreshold { alpha: 0.05 }; // 95th percentile
        let alpha_10 = PsiNormalThreshold { alpha: 0.10 }; // 90th percentile

        let threshold_99 = alpha_01.compute_threshold(sample_size, bin_count);
        let threshold_95 = alpha_05.compute_threshold(sample_size, bin_count);
        let threshold_90 = alpha_10.compute_threshold(sample_size, bin_count);

        // More conservative (lower alpha) should give higher thresholds
        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    /// The default threshold is chi-square at alpha = 0.05, and the fixed
    /// variant returns its constant regardless of sample size and bin count.
    #[test]
    fn test_threshold_default_and_fixed_dispatch() {
        assert_eq!(
            PsiThreshold::default(),
            PsiThreshold::ChiSquare(PsiChiSquareThreshold { alpha: 0.05 })
        );

        let fixed = PsiThreshold::Fixed(PsiFixedThreshold { threshold: 0.25 });
        assert_relative_eq!(fixed.compute_threshold(1000, 10), 0.25);
        assert_relative_eq!(fixed.compute_threshold(1, 2), 0.25);
    }

    /// Dispatch configuration and dispatch type stay in sync for each target.
    #[test]
    fn test_alert_config() {
        //test console alert config
        let alert_config = PsiAlertConfig::default();
        assert_eq!(alert_config.dispatch_config, AlertDispatchConfig::default());
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::Console);

        //test slack alert config
        let slack_alert_dispatch_config = SlackDispatchConfig {
            channel: "test".to_string(),
        };
        let alert_config = PsiAlertConfig {
            dispatch_config: AlertDispatchConfig::Slack(slack_alert_dispatch_config.clone()),
            ..Default::default()
        };
        assert_eq!(
            alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack_alert_dispatch_config)
        );
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::Slack);

        //test opsgenie alert config
        let opsgenie_dispatch_config = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let alert_config = PsiAlertConfig {
            dispatch_config: opsgenie_dispatch_config.clone(),
            ..Default::default()
        };

        assert_eq!(alert_config.dispatch_config, opsgenie_dispatch_config);
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::OpsGenie);
        assert_eq!(
            match &alert_config.dispatch_config {
                AlertDispatchConfig::OpsGenie(config) => &config.team,
                _ => panic!("Expected OpsGenie dispatch config"),
            },
            "test-team"
        );
    }

    /// Every dispatch type renders the header and one sentence per alert.
    #[test]
    fn test_create_alert_description() {
        let alerts = vec![
            PsiFeatureAlert {
                feature: "feature1".to_string(),
                drift: 0.35,
                threshold: 0.3,
            },
            PsiFeatureAlert {
                feature: "feature2".to_string(),
                drift: 0.45,
                threshold: 0.3,
            },
        ];
        let psi_feature_alerts = PsiFeatureAlerts { alerts };

        for dispatch_type in [
            AlertDispatchType::Console,
            AlertDispatchType::Slack,
            AlertDispatchType::OpsGenie,
        ] {
            let description = psi_feature_alerts.create_alert_description(dispatch_type);
            assert!(description.contains("PSI Drift has been detected for the following features:"));
            assert!(description.contains("Feature 'feature1' has experienced drift, with a current PSI score of 0.35 that exceeds the configured threshold of 0.3."));
            assert!(description.contains("Feature 'feature2' has experienced drift, with a current PSI score of 0.45 that exceeds the configured threshold of 0.3."));
        }
    }

    /// With no alerts there is nothing to report, not even the header.
    #[test]
    fn test_create_alert_description_empty() {
        let psi_feature_alerts = PsiFeatureAlerts { alerts: Vec::new() };
        let description = psi_feature_alerts.create_alert_description(AlertDispatchType::Console);
        assert!(description.is_empty());
    }
}