1use crate::error::TypeError;
2use crate::{
3 AlertDispatchConfig, AlertDispatchType, CommonCrons, DispatchAlertDescription,
4 OpsGenieDispatchConfig, SlackDispatchConfig, ValidateAlertConfig,
5};
6use core::fmt::Debug;
7use pyo3::prelude::*;
8use pyo3::types::PyString;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use tracing::error;
12
13#[pyclass]
14#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
15pub struct PsiAlertConfig {
16 #[pyo3(get, set)]
17 pub schedule: String,
18
19 #[pyo3(get, set)]
20 pub features_to_monitor: Vec<String>,
21
22 #[pyo3(get, set)]
23 pub psi_threshold: f64,
24
25 pub dispatch_config: AlertDispatchConfig,
26}
27
28impl Default for PsiAlertConfig {
29 fn default() -> PsiAlertConfig {
30 Self {
31 schedule: CommonCrons::EveryDay.cron(),
32 features_to_monitor: Vec::new(),
33 psi_threshold: 0.25,
34 dispatch_config: AlertDispatchConfig::default(),
35 }
36 }
37}
38
39impl ValidateAlertConfig for PsiAlertConfig {}
40
41#[pymethods]
42impl PsiAlertConfig {
43 #[new]
44 #[pyo3(signature = (schedule=None, features_to_monitor=vec![], psi_threshold=0.25, dispatch_config=None))]
45 pub fn new(
46 schedule: Option<&Bound<'_, PyAny>>,
47 features_to_monitor: Vec<String>,
48 psi_threshold: f64,
49 dispatch_config: Option<&Bound<'_, PyAny>>,
50 ) -> Result<Self, TypeError> {
51 let alert_dispatch_config = match dispatch_config {
52 None => AlertDispatchConfig::default(),
53 Some(config) => {
54 if config.is_instance_of::<SlackDispatchConfig>() {
55 AlertDispatchConfig::Slack(config.extract::<SlackDispatchConfig>()?)
56 } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
57 AlertDispatchConfig::OpsGenie(config.extract::<OpsGenieDispatchConfig>()?)
58 } else {
59 AlertDispatchConfig::default()
60 }
61 }
62 };
63
64 let schedule = match schedule {
65 Some(schedule) => {
66 if schedule.is_instance_of::<PyString>() {
67 schedule.to_string()
68 } else if schedule.is_instance_of::<CommonCrons>() {
69 schedule.extract::<CommonCrons>().unwrap().cron()
70 } else {
71 error!("Invalid schedule type");
72 return Err(TypeError::InvalidScheduleError)?;
73 }
74 }
75 None => CommonCrons::EveryDay.cron(),
76 };
77
78 let schedule = Self::resolve_schedule(&schedule);
79
80 Ok(Self {
81 schedule,
82 features_to_monitor,
83 psi_threshold,
84 dispatch_config: alert_dispatch_config,
85 })
86 }
87
88 #[getter]
89 pub fn dispatch_type(&self) -> AlertDispatchType {
90 self.dispatch_config.dispatch_type()
91 }
92
93 #[getter]
94 pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
95 self.dispatch_config.config(py)
96 }
97}
98
/// Features whose PSI drift should be reported, together with the threshold they are
/// compared against; rendered into alert text via `DispatchAlertDescription`.
#[derive(Debug, Clone, PartialEq)]
pub struct PsiFeatureAlerts {
    /// Feature name -> current PSI score.
    pub features: HashMap<String, f64>,
    /// Configured PSI threshold, quoted in each feature's alert line.
    pub threshold: f64,
}
103
104impl DispatchAlertDescription for PsiFeatureAlerts {
105 fn create_alert_description(&self, dispatch_type: AlertDispatchType) -> String {
106 let mut alert_description = String::new();
107
108 for (i, (feature_name, drift_value)) in self.features.iter().enumerate() {
109 let description = format!("Feature '{}' has experienced drift, with a current PSI score of {} that exceeds the configured threshold of {}.", feature_name, drift_value, self.threshold);
110
111 if i == 0 {
112 let header = "PSI Drift has been detected for the following features:\n";
113 alert_description.push_str(header);
114 }
115
116 let feature_name = match dispatch_type {
117 AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
118 format!("{:indent$}{}: \n", "", &feature_name, indent = 4)
119 }
120 AlertDispatchType::Slack => format!("{}: \n", &feature_name),
121 };
122
123 alert_description = format!("{}{}", alert_description, feature_name);
124
125 let alert_details = match dispatch_type {
126 AlertDispatchType::Console | AlertDispatchType::OpsGenie => {
127 format!("{:indent$}Drift Value: {}\n", "", description, indent = 8)
128 }
129 AlertDispatchType::Slack => {
130 format!("{:indent$}Drift Value: {}\n", "", description, indent = 4)
131 }
132 };
133 alert_description = format!("{}{}", alert_description, alert_details);
134 }
135 alert_description
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Default, Slack and OpsGenie dispatch configs are stored as given and report the
    /// matching `AlertDispatchType`.
    #[test]
    fn test_alert_config() {
        // Default: console dispatch.
        let default_config = PsiAlertConfig::default();
        assert_eq!(default_config.dispatch_config, AlertDispatchConfig::default());
        assert_eq!(default_config.dispatch_type(), AlertDispatchType::Console);

        // Slack dispatch.
        let slack = SlackDispatchConfig {
            channel: "test".to_string(),
        };
        let slack_config = PsiAlertConfig {
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };
        assert_eq!(slack_config.dispatch_config, AlertDispatchConfig::Slack(slack));
        assert_eq!(slack_config.dispatch_type(), AlertDispatchType::Slack);

        // OpsGenie dispatch.
        let opsgenie = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let opsgenie_config = PsiAlertConfig {
            dispatch_config: opsgenie.clone(),
            ..Default::default()
        };
        assert_eq!(opsgenie_config.dispatch_config, opsgenie);
        assert_eq!(opsgenie_config.dispatch_type(), AlertDispatchType::OpsGenie);

        let team = match &opsgenie_config.dispatch_config {
            AlertDispatchConfig::OpsGenie(config) => config.team.as_str(),
            _ => panic!("Expected OpsGenie dispatch config"),
        };
        assert_eq!(team, "test-team");
    }

    /// Every dispatch target's description contains the header and one sentence per feature.
    #[test]
    fn test_create_alert_description() {
        let psi_feature_alerts = PsiFeatureAlerts {
            features: HashMap::from([
                ("feature1".to_string(), 0.35),
                ("feature2".to_string(), 0.45),
            ]),
            threshold: 0.3,
        };

        let expected_fragments = [
            "PSI Drift has been detected for the following features:",
            "Feature 'feature1' has experienced drift, with a current PSI score of 0.35 that exceeds the configured threshold of 0.3.",
            "Feature 'feature2' has experienced drift, with a current PSI score of 0.45 that exceeds the configured threshold of 0.3.",
        ];

        for dispatch_type in [
            AlertDispatchType::Console,
            AlertDispatchType::Slack,
            AlertDispatchType::OpsGenie,
        ] {
            let description = psi_feature_alerts.create_alert_description(dispatch_type);
            for fragment in expected_fragments {
                assert!(description.contains(fragment));
            }
        }
    }
}