1use crate::error::TypeError;
2use crate::{
3 AlertDispatchConfig, AlertDispatchType, AlertMap, CommonCrons, DispatchAlertDescription,
4 OpsGenieDispatchConfig, SlackDispatchConfig, ValidateAlertConfig,
5};
6use core::fmt::Debug;
7use pyo3::prelude::*;
8use pyo3::types::PyString;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};
12
13#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
14pub enum PsiThreshold {
15 Normal(PsiNormalThreshold),
16 ChiSquare(PsiChiSquareThreshold),
17 Fixed(PsiFixedThreshold),
18}
19
20impl PsiThreshold {
21 pub fn config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
22 match self {
23 PsiThreshold::Normal(config) => config.clone().into_bound_py_any(py),
24 PsiThreshold::ChiSquare(config) => config.clone().into_bound_py_any(py),
25 PsiThreshold::Fixed(config) => config.clone().into_bound_py_any(py),
26 }
27 }
28
29 pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
30 match self {
31 PsiThreshold::Normal(normal) => normal.compute_threshold(target_sample_size, bin_count),
32 PsiThreshold::ChiSquare(chi) => chi.compute_threshold(target_sample_size, bin_count),
33 PsiThreshold::Fixed(fixed) => fixed.compute_threshold(),
34 }
35 }
36}
37
38impl Default for PsiThreshold {
39 fn default() -> Self {
41 PsiThreshold::ChiSquare(PsiChiSquareThreshold { alpha: 0.05 })
42 }
43}
44
45#[pyclass]
46#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
47pub struct PsiNormalThreshold {
48 #[pyo3(get, set)]
49 pub alpha: f64,
50}
51
52impl PsiNormalThreshold {
53 #[allow(non_snake_case)]
61 pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
62 let M = target_sample_size as f64;
63 let B = bin_count as f64;
64
65 let normal = Normal::new(0.0, 1.0).unwrap();
66 let z_alpha = normal.inverse_cdf(1.0 - self.alpha);
67
68 let exp_val = (B - 1.0) / M;
69 let std_dev = (2.0 * (B - 1.0)).sqrt() / M;
70
71 exp_val + z_alpha * std_dev
72 }
73}
74
75#[pymethods]
76impl PsiNormalThreshold {
77 #[new]
78 #[pyo3(signature = (alpha=0.05))]
79 pub fn new(alpha: f64) -> PyResult<Self> {
80 if !(0.0..1.0).contains(&alpha) {
81 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
82 "alpha must be between 0.0 and 1.0 (exclusive)",
83 ));
84 }
85 Ok(Self { alpha })
86 }
87}
88
89#[pyclass]
90#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
91pub struct PsiChiSquareThreshold {
92 #[pyo3(get, set)]
93 pub alpha: f64,
94}
95
96impl PsiChiSquareThreshold {
97 #[allow(non_snake_case)]
105 pub fn compute_threshold(&self, target_sample_size: u64, bin_count: u64) -> f64 {
106 let M = target_sample_size as f64;
107 let B = bin_count as f64;
108 let chi2 = ChiSquared::new(B - 1.0).unwrap();
109 let chi2_critical = chi2.inverse_cdf(1.0 - self.alpha);
110
111 chi2_critical / M
112 }
113}
114
115#[pymethods]
116impl PsiChiSquareThreshold {
117 #[new]
118 #[pyo3(signature = (alpha=0.05))]
119 pub fn new(alpha: f64) -> PyResult<Self> {
120 if !(0.0..1.0).contains(&alpha) {
121 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
122 "alpha must be between 0.0 and 1.0 (exclusive)",
123 ));
124 }
125 Ok(Self { alpha })
126 }
127}
128
129#[pyclass]
130#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
131pub struct PsiFixedThreshold {
132 #[pyo3(get, set)]
133 pub threshold: f64,
134}
135
136impl PsiFixedThreshold {
137 pub fn compute_threshold(&self) -> f64 {
138 self.threshold
139 }
140}
141
142#[pymethods]
143impl PsiFixedThreshold {
144 #[new]
145 #[pyo3(signature = (threshold=0.25))]
146 pub fn new(threshold: f64) -> PyResult<Self> {
147 if threshold < 0.0 {
148 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
149 "Threshold values must be non-zero",
150 ));
151 }
152 Ok(Self { threshold })
153 }
154}
155
156#[pyclass]
157#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
158pub struct PsiAlertConfig {
159 #[pyo3(get, set)]
160 pub schedule: String,
161
162 #[pyo3(get, set)]
163 pub features_to_monitor: Vec<String>,
164
165 pub dispatch_config: AlertDispatchConfig,
166
167 pub threshold: PsiThreshold,
168}
169
170impl Default for PsiAlertConfig {
171 fn default() -> PsiAlertConfig {
172 Self {
173 schedule: CommonCrons::EveryDay.cron(),
174 features_to_monitor: Vec::new(),
175 dispatch_config: AlertDispatchConfig::default(),
176 threshold: PsiThreshold::default(),
177 }
178 }
179}
180
181impl ValidateAlertConfig for PsiAlertConfig {}
182
183#[pymethods]
184impl PsiAlertConfig {
185 #[new]
186 #[pyo3(signature = (schedule=None, features_to_monitor=vec![], dispatch_config=None, threshold=None))]
187 pub fn new(
188 schedule: Option<&Bound<'_, PyAny>>,
189 features_to_monitor: Vec<String>,
190 dispatch_config: Option<&Bound<'_, PyAny>>,
191 threshold: Option<&Bound<'_, PyAny>>,
192 ) -> Result<Self, TypeError> {
193 let dispatch_config = match dispatch_config {
194 None => AlertDispatchConfig::default(),
195 Some(config) => {
196 if config.is_instance_of::<SlackDispatchConfig>() {
197 AlertDispatchConfig::Slack(config.extract()?)
198 } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
199 AlertDispatchConfig::OpsGenie(config.extract()?)
200 } else {
201 return Err(TypeError::InvalidDispatchConfigError);
202 }
203 }
204 };
205
206 let threshold = match threshold {
207 None => PsiThreshold::default(),
208 Some(config) => {
209 if config.is_instance_of::<PsiNormalThreshold>() {
210 PsiThreshold::Normal(config.extract()?)
211 } else if config.is_instance_of::<PsiChiSquareThreshold>() {
212 PsiThreshold::ChiSquare(config.extract()?)
213 } else if config.is_instance_of::<PsiFixedThreshold>() {
214 PsiThreshold::Fixed(config.extract()?)
216 } else {
217 return Err(TypeError::InvalidPsiThresholdError);
218 }
219 }
220 };
221
222 let schedule = match schedule {
223 Some(schedule) => {
224 if schedule.is_instance_of::<PyString>() {
225 schedule.to_string()
226 } else if schedule.is_instance_of::<CommonCrons>() {
227 schedule.extract::<CommonCrons>()?.cron()
228 } else {
229 return Err(TypeError::InvalidScheduleError);
230 }
231 }
232 None => CommonCrons::EveryDay.cron(),
233 };
234
235 let schedule = Self::resolve_schedule(&schedule);
236
237 Ok(Self {
238 schedule,
239 features_to_monitor,
240 dispatch_config,
241 threshold,
242 })
243 }
244 #[getter]
245 pub fn dispatch_type(&self) -> AlertDispatchType {
246 self.dispatch_config.dispatch_type()
247 }
248
249 #[getter]
250 pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
251 self.dispatch_config.config(py)
252 }
253
254 #[getter]
255 pub fn threshold<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
256 self.threshold.config(py)
257 }
258}
259
260#[derive(Serialize, Deserialize, Debug, Default, Clone)]
261pub struct PsiFeatureAlert {
262 pub feature: String,
263 pub drift: f64,
264 pub threshold: f64,
265}
266impl From<PsiFeatureAlert> for AlertMap {
267 fn from(val: PsiFeatureAlert) -> Self {
268 AlertMap::Psi(val)
269 }
270}
271
272pub struct PsiFeatureAlerts {
273 pub alerts: Vec<PsiFeatureAlert>,
274}
275
276impl DispatchAlertDescription for PsiFeatureAlerts {
277 fn create_alert_description(&self, dispatch_type: AlertDispatchType) -> String {
278 let mut alert_description = String::new();
279
280 for (i, alert) in self.alerts.iter().enumerate() {
281 let description = format!("Feature '{}' has experienced drift, with a current PSI score of {} that exceeds the configured threshold of {}.", alert.feature, alert.drift, alert.threshold);
282
283 if i == 0 {
284 let header = "PSI Drift has been detected for the following features:\n";
285 alert_description.push_str(header);
286 }
287
288 let feature_name = match dispatch_type {
289 AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
290 format!("{:indent$}{}: \n", "", alert.feature, indent = 4)
291 }
292 AlertDispatchType::Slack => format!("{}: \n", alert.feature),
293 };
294
295 alert_description = format!("{alert_description}{feature_name}");
296
297 let alert_details = match dispatch_type {
298 AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
299 format!("{:indent$}Drift Value: {}\n", "", description, indent = 8)
300 }
301 AlertDispatchType::Slack => {
302 format!("{:indent$}Drift Value: {}\n", "", description, indent = 4)
303 }
304 };
305 alert_description = format!("{alert_description}{alert_details}");
306 }
307 alert_description
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Normal strategy: 400 samples, 10 bins, alpha 0.05 should give about 0.04.
    #[test]
    fn test_compute_threshold_method_i_paper_validation() {
        let normal = PsiNormalThreshold { alpha: 0.05 };

        let result = normal.compute_threshold(400, 10);

        assert_relative_eq!(result, 0.0400, epsilon = 0.002);
    }

    // Chi-square strategy checked against two reference values.
    #[test]
    fn test_compute_threshold_method_ii_paper_validation() {
        let chi_square = PsiChiSquareThreshold { alpha: 0.05 };

        let result = chi_square.compute_threshold(400, 10);
        assert_relative_eq!(result, 0.0423, epsilon = 0.002);

        let result_20_bins = chi_square.compute_threshold(1000, 20);
        assert_relative_eq!(result_20_bins, 0.0301, epsilon = 0.002);
    }

    // Chi-square strategy with 10 bins across several sample sizes.
    #[test]
    fn test_compute_threshold_paper_table_values() {
        let chi_square = PsiChiSquareThreshold { alpha: 0.05 };

        let expectations = [(100, 0.169), (200, 0.085), (400, 0.042), (1000, 0.017)];

        for (sample_size, expected_approx) in expectations {
            let result = chi_square.compute_threshold(sample_size, 10);
            let diff = (result - expected_approx).abs();

            if diff >= 0.005 {
                panic!(
                    "Failed for sample size {sample_size}: expected ~{expected_approx}, got {result}, diff={diff}"
                );
            }
        }
    }

    // More bins must raise the chi-square threshold.
    #[test]
    fn test_degrees_of_freedom_relationship_chi() {
        let chi_square = PsiChiSquareThreshold { alpha: 0.05 };

        let [bins_5, bins_10, bins_20] =
            [5, 10, 20].map(|bin_count| chi_square.compute_threshold(1000, bin_count));

        assert!(
            bins_5 < bins_10,
            "5 bins should give smaller threshold than 10 bins"
        );
        assert!(
            bins_10 < bins_20,
            "10 bins should give smaller threshold than 20 bins"
        );
    }

    // More bins must raise the normal threshold.
    #[test]
    fn test_degrees_of_freedom_relationship_normal() {
        let normal = PsiNormalThreshold { alpha: 0.05 };

        let [t_5, t_10, t_20] =
            [5, 10, 20].map(|bin_count| normal.compute_threshold(1000, bin_count));

        assert!(t_5 < t_10 && t_10 < t_20);
    }

    // A smaller alpha must give a larger chi-square threshold.
    #[test]
    fn test_alpha_significance_levels_chi() {
        let sample_size = 1000;
        let bin_count = 10;
        let threshold_for =
            |alpha: f64| PsiChiSquareThreshold { alpha }.compute_threshold(sample_size, bin_count);

        let threshold_99 = threshold_for(0.01);
        let threshold_95 = threshold_for(0.05);
        let threshold_90 = threshold_for(0.10);

        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    // A smaller alpha must give a larger normal threshold.
    #[test]
    fn test_alpha_significance_levels_normal() {
        let sample_size = 1000;
        let bin_count = 10;
        let threshold_for =
            |alpha: f64| PsiNormalThreshold { alpha }.compute_threshold(sample_size, bin_count);

        let threshold_99 = threshold_for(0.01);
        let threshold_95 = threshold_for(0.05);
        let threshold_90 = threshold_for(0.10);

        assert!(
            threshold_99 > threshold_95,
            "99th percentile should be higher than 95th: {threshold_99} > {threshold_95}"
        );
        assert!(
            threshold_95 > threshold_90,
            "95th percentile should be higher than 90th: {threshold_95} > {threshold_90}"
        );
    }

    // Default, Slack and OpsGenie dispatch configurations report the right type.
    #[test]
    fn test_alert_config() {
        let default_config = PsiAlertConfig::default();
        assert_eq!(
            default_config.dispatch_config,
            AlertDispatchConfig::default()
        );
        assert_eq!(default_config.dispatch_type(), AlertDispatchType::Console);

        let slack = SlackDispatchConfig {
            channel: "test".to_string(),
        };
        let slack_config = PsiAlertConfig {
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };
        assert_eq!(
            slack_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(slack_config.dispatch_type(), AlertDispatchType::Slack);

        let opsgenie = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let opsgenie_config = PsiAlertConfig {
            dispatch_config: opsgenie.clone(),
            ..Default::default()
        };

        assert_eq!(opsgenie_config.dispatch_config, opsgenie);
        assert_eq!(opsgenie_config.dispatch_type(), AlertDispatchType::OpsGenie);

        let team = match &opsgenie_config.dispatch_config {
            AlertDispatchConfig::OpsGenie(config) => &config.team,
            _ => panic!("Expected OpsGenie dispatch config"),
        };
        assert_eq!(team, "test-team");
    }

    // Every dispatch type must include the header and one sentence per alert.
    #[test]
    fn test_create_alert_description() {
        let psi_feature_alerts = PsiFeatureAlerts {
            alerts: vec![
                PsiFeatureAlert {
                    feature: "feature1".to_string(),
                    drift: 0.35,
                    threshold: 0.3,
                },
                PsiFeatureAlert {
                    feature: "feature2".to_string(),
                    drift: 0.45,
                    threshold: 0.3,
                },
            ],
        };

        let expected_fragments = [
            "PSI Drift has been detected for the following features:",
            "Feature 'feature1' has experienced drift, with a current PSI score of 0.35 that exceeds the configured threshold of 0.3.",
            "Feature 'feature2' has experienced drift, with a current PSI score of 0.45 that exceeds the configured threshold of 0.3.",
        ];

        for dispatch_type in [
            AlertDispatchType::Console,
            AlertDispatchType::Slack,
            AlertDispatchType::OpsGenie,
        ] {
            let description = psi_feature_alerts.create_alert_description(dispatch_type);

            for fragment in expected_fragments {
                assert!(description.contains(fragment));
            }
        }
    }
}