1use crate::error::TypeError;
2use crate::{
3 AlertDispatchConfig, AlertDispatchType, CommonCrons, DispatchAlertDescription,
4 OpsGenieDispatchConfig, SlackDispatchConfig, ValidateAlertConfig,
5};
6use core::fmt::Debug;
7use pyo3::prelude::*;
8use pyo3::types::PyString;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};
12
13#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
14pub enum PsiThreshold {
15 Normal(PsiNormalThreshold),
16 ChiSquare(PsiChiSquareThreshold),
17 Fixed(PsiFixedThreshold),
18}
19
20impl PsiThreshold {
21 pub fn config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
22 match self {
23 PsiThreshold::Normal(config) => config.clone().into_bound_py_any(py),
24 PsiThreshold::ChiSquare(config) => config.clone().into_bound_py_any(py),
25 PsiThreshold::Fixed(config) => config.clone().into_bound_py_any(py),
26 }
27 }
28
29 pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
30 match self {
31 PsiThreshold::Normal(normal) => normal.compute_threshold(target_sample_size, bin_count),
32 PsiThreshold::ChiSquare(chi) => chi.compute_threshold(target_sample_size, bin_count),
33 PsiThreshold::Fixed(fixed) => fixed.compute_threshold(),
34 }
35 }
36}
37
38impl Default for PsiThreshold {
39 fn default() -> Self {
41 PsiThreshold::ChiSquare(PsiChiSquareThreshold { alpha: 0.05 })
42 }
43}
44
45#[pyclass]
46#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
47pub struct PsiNormalThreshold {
48 #[pyo3(get, set)]
49 pub alpha: f64,
50}
51
52impl PsiNormalThreshold {
53 #[allow(non_snake_case)]
61 pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
62 let M = target_sample_size as f64;
63 let B = bin_count as f64;
64
65 let normal = Normal::new(0.0, 1.0).unwrap();
66 let z_alpha = normal.inverse_cdf(1.0 - self.alpha);
67
68 let exp_val = (B - 1.0) / M;
69 let std_dev = (2.0 * (B - 1.0)).sqrt() / M;
70
71 exp_val + z_alpha * std_dev
72 }
73}
74
75#[pymethods]
76impl PsiNormalThreshold {
77 #[new]
78 #[pyo3(signature = (alpha=0.05))]
79 pub fn new(alpha: f64) -> PyResult<Self> {
80 if !(0.0..1.0).contains(&alpha) {
81 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
82 "alpha must be between 0.0 and 1.0 (exclusive)",
83 ));
84 }
85 Ok(Self { alpha })
86 }
87}
88
89#[pyclass]
90#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
91pub struct PsiChiSquareThreshold {
92 #[pyo3(get, set)]
93 pub alpha: f64,
94}
95
96impl PsiChiSquareThreshold {
97 #[allow(non_snake_case)]
105 pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
106 let M = target_sample_size as f64;
107 let B = bin_count as f64;
108 let chi2 = ChiSquared::new(B - 1.0).unwrap();
109 let chi2_critical = chi2.inverse_cdf(1.0 - self.alpha);
110
111 chi2_critical / M
112 }
113}
114
115#[pymethods]
116impl PsiChiSquareThreshold {
117 #[new]
118 #[pyo3(signature = (alpha=0.05))]
119 pub fn new(alpha: f64) -> PyResult<Self> {
120 if !(0.0..1.0).contains(&alpha) {
121 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
122 "alpha must be between 0.0 and 1.0 (exclusive)",
123 ));
124 }
125 Ok(Self { alpha })
126 }
127}
128
129#[pyclass]
130#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
131pub struct PsiFixedThreshold {
132 #[pyo3(get, set)]
133 pub threshold: f64,
134}
135
136impl PsiFixedThreshold {
137 pub fn compute_threshold(&self) -> f64 {
138 self.threshold
139 }
140}
141
142#[pymethods]
143impl PsiFixedThreshold {
144 #[new]
145 #[pyo3(signature = (threshold=0.25))]
146 pub fn new(threshold: f64) -> PyResult<Self> {
147 if threshold < 0.0 {
148 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
149 "Threshold values must be non-zero",
150 ));
151 }
152 Ok(Self { threshold })
153 }
154}
155
156#[pyclass]
157#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
158pub struct PsiAlertConfig {
159 #[pyo3(get, set)]
160 pub schedule: String,
161
162 #[pyo3(get, set)]
163 pub features_to_monitor: Vec<String>,
164
165 pub dispatch_config: AlertDispatchConfig,
166
167 pub threshold: PsiThreshold,
168}
169
170impl Default for PsiAlertConfig {
171 fn default() -> PsiAlertConfig {
172 Self {
173 schedule: CommonCrons::EveryDay.cron(),
174 features_to_monitor: Vec::new(),
175 dispatch_config: AlertDispatchConfig::default(),
176 threshold: PsiThreshold::default(),
177 }
178 }
179}
180
181impl ValidateAlertConfig for PsiAlertConfig {}
182
183#[pymethods]
184impl PsiAlertConfig {
185 #[new]
186 #[pyo3(signature = (schedule=None, features_to_monitor=vec![], dispatch_config=None, threshold=None))]
187 pub fn new(
188 schedule: Option<&Bound<'_, PyAny>>,
189 features_to_monitor: Vec<String>,
190 dispatch_config: Option<&Bound<'_, PyAny>>,
191 threshold: Option<&Bound<'_, PyAny>>,
192 ) -> Result<Self, TypeError> {
193 let dispatch_config = match dispatch_config {
194 None => AlertDispatchConfig::default(),
195 Some(config) => {
196 if config.is_instance_of::<SlackDispatchConfig>() {
197 AlertDispatchConfig::Slack(config.extract()?)
198 } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
199 AlertDispatchConfig::OpsGenie(config.extract()?)
200 } else {
201 return Err(TypeError::InvalidDispatchConfigError);
202 }
203 }
204 };
205
206 let threshold = match threshold {
207 None => PsiThreshold::default(),
208 Some(config) => {
209 if config.is_instance_of::<PsiNormalThreshold>() {
210 PsiThreshold::Normal(config.extract()?)
211 } else if config.is_instance_of::<PsiChiSquareThreshold>() {
212 PsiThreshold::ChiSquare(config.extract()?)
213 } else if config.is_instance_of::<PsiFixedThreshold>() {
214 PsiThreshold::Fixed(config.extract()?)
216 } else {
217 return Err(TypeError::InvalidPsiThresholdError);
218 }
219 }
220 };
221
222 let schedule = match schedule {
223 Some(schedule) => {
224 if schedule.is_instance_of::<PyString>() {
225 schedule.to_string()
226 } else if schedule.is_instance_of::<CommonCrons>() {
227 schedule.extract::<CommonCrons>()?.cron()
228 } else {
229 return Err(TypeError::InvalidScheduleError);
230 }
231 }
232 None => CommonCrons::EveryDay.cron(),
233 };
234
235 let schedule = Self::resolve_schedule(&schedule);
236
237 Ok(Self {
238 schedule,
239 features_to_monitor,
240 dispatch_config,
241 threshold,
242 })
243 }
244 #[getter]
245 pub fn dispatch_type(&self) -> AlertDispatchType {
246 self.dispatch_config.dispatch_type()
247 }
248
249 #[getter]
250 pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
251 self.dispatch_config.config(py)
252 }
253
254 #[getter]
255 pub fn threshold<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
256 self.threshold.config(py)
257 }
258}
259
/// A single feature whose PSI drift exceeded its threshold.
#[derive(Clone, Debug)]
pub struct PsiFeatureAlert {
    /// Name of the drifted feature.
    pub feature: String,
    /// Current PSI score reported for the feature.
    pub drift: f64,
    /// Configured threshold the score is reported as exceeding.
    pub threshold: f64,
}

/// A batch of per-feature PSI alerts rendered together into one description.
pub struct PsiFeatureAlerts {
    /// The alerts, rendered in order.
    pub alerts: Vec<PsiFeatureAlert>,
}
270
271impl DispatchAlertDescription for PsiFeatureAlerts {
272 fn create_alert_description(&self, dispatch_type: AlertDispatchType) -> String {
273 let mut alert_description = String::new();
274
275 for (i, alert) in self.alerts.iter().enumerate() {
276 let description = format!("Feature '{}' has experienced drift, with a current PSI score of {} that exceeds the configured threshold of {}.", alert.feature, alert.drift, alert.threshold);
277
278 if i == 0 {
279 let header = "PSI Drift has been detected for the following features:\n";
280 alert_description.push_str(header);
281 }
282
283 let feature_name = match dispatch_type {
284 AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
285 format!("{:indent$}{}: \n", "", alert.feature, indent = 4)
286 }
287 AlertDispatchType::Slack => format!("{}: \n", alert.feature),
288 };
289
290 alert_description = format!("{alert_description}{feature_name}");
291
292 let alert_details = match dispatch_type {
293 AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
294 format!("{:indent$}Drift Value: {}\n", "", description, indent = 8)
295 }
296 AlertDispatchType::Slack => {
297 format!("{:indent$}Drift Value: {}\n", "", description, indent = 4)
298 }
299 };
300 alert_description = format!("{alert_description}{alert_details}");
301 }
302 alert_description
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_compute_threshold_method_i_paper_validation() {
        // Normal approximation: M = 400 samples, B = 10 bins, alpha = 0.05.
        let normal = PsiNormalThreshold { alpha: 0.05 };
        assert_relative_eq!(normal.compute_threshold(400, 10), 0.0400, epsilon = 0.002);
    }

    #[test]
    fn test_compute_threshold_method_ii_paper_validation() {
        let chi_square = PsiChiSquareThreshold { alpha: 0.05 };
        // (sample size, bin count, expected threshold)
        for (samples, bins, expected) in [(400, 10, 0.0423), (1000, 20, 0.0301)] {
            assert_relative_eq!(
                chi_square.compute_threshold(samples, bins),
                expected,
                epsilon = 0.002
            );
        }
    }

    #[test]
    fn test_compute_threshold_paper_table_values() {
        let chi_square = PsiChiSquareThreshold { alpha: 0.05 };
        // (target sample size, approximate expected threshold) for 10 bins.
        let test_cases = [(100, 0.169), (200, 0.085), (400, 0.042), (1000, 0.017)];

        test_cases.iter().for_each(|&(sample_size, expected_approx)| {
            let result = chi_square.compute_threshold(sample_size, 10);
            let diff = (result - expected_approx).abs();
            if diff >= 0.005 {
                panic!(
                    "Failed for sample size {sample_size}: expected ~{expected_approx}, got {result}, diff={diff}"
                );
            }
        });
    }

    #[test]
    fn test_degrees_of_freedom_relationship_chi() {
        let chi_square = PsiChiSquareThreshold { alpha: 0.05 };
        let [bins_5, bins_10, bins_20] =
            [5, 10, 20].map(|bins| chi_square.compute_threshold(1000, bins));

        assert!(
            bins_5 < bins_10,
            "5 bins should give smaller threshold than 10 bins"
        );
        assert!(
            bins_10 < bins_20,
            "10 bins should give smaller threshold than 20 bins"
        );
    }

    #[test]
    fn test_degrees_of_freedom_relationship_normal() {
        let normal = PsiNormalThreshold { alpha: 0.05 };
        let [t_5, t_10, t_20] = [5, 10, 20].map(|bins| normal.compute_threshold(1000, bins));

        assert!(t_5 < t_10 && t_10 < t_20);
    }

    #[test]
    fn test_alpha_significance_levels_chi() {
        let sample_size = 1000;
        let bin_count = 10;
        let [threshold_99, threshold_95, threshold_90] = [0.01, 0.05, 0.10]
            .map(|alpha| PsiChiSquareThreshold { alpha }.compute_threshold(sample_size, bin_count));

        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    #[test]
    fn test_alpha_significance_levels_normal() {
        let sample_size = 1000;
        let bin_count = 10;
        let [threshold_99, threshold_95, threshold_90] = [0.01, 0.05, 0.10]
            .map(|alpha| PsiNormalThreshold { alpha }.compute_threshold(sample_size, bin_count));

        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    #[test]
    fn test_alert_config() {
        // The default configuration dispatches to the console.
        let default_config = PsiAlertConfig::default();
        assert_eq!(default_config.dispatch_config, AlertDispatchConfig::default());
        assert_eq!(default_config.dispatch_type(), AlertDispatchType::Console);

        // Slack override via struct-update syntax.
        let slack = SlackDispatchConfig {
            channel: "test".to_string(),
        };
        let slack_config = PsiAlertConfig {
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };
        assert_eq!(slack_config.dispatch_config, AlertDispatchConfig::Slack(slack));
        assert_eq!(slack_config.dispatch_type(), AlertDispatchType::Slack);

        // OpsGenie override.
        let opsgenie = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let opsgenie_config = PsiAlertConfig {
            dispatch_config: opsgenie.clone(),
            ..Default::default()
        };
        assert_eq!(opsgenie_config.dispatch_config, opsgenie);
        assert_eq!(opsgenie_config.dispatch_type(), AlertDispatchType::OpsGenie);
        match &opsgenie_config.dispatch_config {
            AlertDispatchConfig::OpsGenie(config) => assert_eq!(config.team, "test-team"),
            _ => panic!("Expected OpsGenie dispatch config"),
        }
    }

    #[test]
    fn test_create_alert_description() {
        let psi_feature_alerts = PsiFeatureAlerts {
            alerts: vec![
                PsiFeatureAlert {
                    feature: "feature1".to_string(),
                    drift: 0.35,
                    threshold: 0.3,
                },
                PsiFeatureAlert {
                    feature: "feature2".to_string(),
                    drift: 0.45,
                    threshold: 0.3,
                },
            ],
        };

        // Every dispatch target must carry the header and both feature sentences.
        for dispatch_type in [
            AlertDispatchType::Console,
            AlertDispatchType::Slack,
            AlertDispatchType::OpsGenie,
        ] {
            let description = psi_feature_alerts.create_alert_description(dispatch_type);
            assert!(description.contains("PSI Drift has been detected for the following features:"));
            assert!(description.contains("Feature 'feature1' has experienced drift, with a current PSI score of 0.35 that exceeds the configured threshold of 0.3."));
            assert!(description.contains("Feature 'feature2' has experienced drift, with a current PSI score of 0.45 that exceeds the configured threshold of 0.3."));
        }
    }
}