1use crate::error::TypeError;
2use crate::{
3 dispatch::AlertDispatchType, AlertDispatchConfig, AlertThreshold, CommonCrons,
4 DispatchAlertDescription, OpsGenieDispatchConfig, PyHelperFuncs, SlackDispatchConfig,
5 ValidateAlertConfig,
6};
7use core::fmt::Debug;
8use pyo3::prelude::*;
9use pyo3::types::PyString;
10
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14#[pyclass]
15#[derive(Debug, Serialize, Deserialize, Clone)]
16pub struct CustomMetric {
17 #[pyo3(get, set)]
18 pub name: String,
19
20 #[pyo3(get, set)]
21 pub value: f64,
22
23 #[pyo3(get, set)]
24 pub alert_condition: CustomMetricAlertCondition,
25}
26
27#[pymethods]
28impl CustomMetric {
29 #[new]
30 #[pyo3(signature = (name, value, alert_threshold, alert_threshold_value=None))]
31 pub fn new(
32 name: &str,
33 value: f64,
34 alert_threshold: AlertThreshold,
35 alert_threshold_value: Option<f64>,
36 ) -> Result<Self, TypeError> {
37 let custom_condition =
38 CustomMetricAlertCondition::new(alert_threshold, alert_threshold_value);
39
40 Ok(Self {
41 name: name.to_lowercase(),
42 value,
43 alert_condition: custom_condition,
44 })
45 }
46
47 pub fn __str__(&self) -> String {
48 PyHelperFuncs::__str__(self)
50 }
51
52 #[getter]
53 pub fn alert_threshold(&self) -> AlertThreshold {
54 self.alert_condition.alert_threshold.clone()
55 }
56
57 #[getter]
58 pub fn alert_threshold_value(&self) -> Option<f64> {
59 self.alert_condition.alert_threshold_value
60 }
61
62 #[getter]
63 pub fn class_id(&self) -> String {
64 self.name.clone()
65 }
66}
67
68#[pyclass]
69#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
70pub struct CustomMetricAlertCondition {
71 #[pyo3(get, set)]
72 pub alert_threshold: AlertThreshold,
73
74 #[pyo3(get, set)]
75 pub alert_threshold_value: Option<f64>,
76}
77
78#[pymethods]
79#[allow(clippy::too_many_arguments)]
80impl CustomMetricAlertCondition {
81 #[new]
82 #[pyo3(signature = (alert_threshold, alert_threshold_value=None))]
83 pub fn new(alert_threshold: AlertThreshold, alert_threshold_value: Option<f64>) -> Self {
84 Self {
85 alert_threshold,
86 alert_threshold_value,
87 }
88 }
89
90 pub fn __str__(&self) -> String {
91 PyHelperFuncs::__str__(self)
93 }
94}
95
96#[pyclass]
97#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
98pub struct CustomMetricAlertConfig {
99 pub dispatch_config: AlertDispatchConfig,
100
101 #[pyo3(get, set)]
102 pub schedule: String,
103
104 #[pyo3(get, set)]
105 pub alert_conditions: Option<HashMap<String, CustomMetricAlertCondition>>,
106}
107
108impl CustomMetricAlertConfig {
109 pub fn set_alert_conditions(&mut self, metrics: &[CustomMetric]) {
110 self.alert_conditions = Some(
111 metrics
112 .iter()
113 .map(|m| (m.name.clone(), m.alert_condition.clone()))
114 .collect(),
115 );
116 }
117}
118
119impl ValidateAlertConfig for CustomMetricAlertConfig {}
120
121#[pymethods]
122impl CustomMetricAlertConfig {
123 #[new]
124 #[pyo3(signature = (schedule=None, dispatch_config=None))]
125 pub fn new(
126 schedule: Option<&Bound<'_, PyAny>>,
127 dispatch_config: Option<&Bound<'_, PyAny>>,
128 ) -> Result<Self, TypeError> {
129 let alert_dispatch_config = match dispatch_config {
130 None => AlertDispatchConfig::default(),
131 Some(config) => {
132 if config.is_instance_of::<SlackDispatchConfig>() {
133 AlertDispatchConfig::Slack(config.extract::<SlackDispatchConfig>()?)
134 } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
135 AlertDispatchConfig::OpsGenie(config.extract::<OpsGenieDispatchConfig>()?)
136 } else {
137 AlertDispatchConfig::default()
138 }
139 }
140 };
141
142 let schedule = match schedule {
143 Some(schedule) => {
144 if schedule.is_instance_of::<PyString>() {
145 schedule.to_string()
146 } else if schedule.is_instance_of::<CommonCrons>() {
147 schedule.extract::<CommonCrons>().unwrap().cron()
148 } else {
149 return Err(TypeError::InvalidScheduleError)?;
150 }
151 }
152 None => CommonCrons::EveryDay.cron(),
153 };
154
155 let schedule = Self::resolve_schedule(&schedule);
156
157 Ok(Self {
158 schedule,
159 dispatch_config: alert_dispatch_config,
160 alert_conditions: None,
161 })
162 }
163
164 #[getter]
165 pub fn dispatch_type(&self) -> AlertDispatchType {
166 self.dispatch_config.dispatch_type()
167 }
168
169 #[getter]
170 pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
171 self.dispatch_config.config(py)
172 }
173}
174
175impl Default for CustomMetricAlertConfig {
176 fn default() -> CustomMetricAlertConfig {
177 Self {
178 dispatch_config: AlertDispatchConfig::default(),
179 schedule: CommonCrons::EveryDay.cron(),
180 alert_conditions: None,
181 }
182 }
183}
184
185pub struct ComparisonMetricAlert {
186 pub metric_name: String,
187 pub training_metric_value: f64,
188 pub observed_metric_value: f64,
189 pub alert_threshold_value: Option<f64>,
190 pub alert_threshold: AlertThreshold,
191}
192
193impl ComparisonMetricAlert {
194 fn alert_description_header(&self) -> String {
195 let below_threshold = |boundary: Option<f64>| match boundary {
196 Some(b) => format!(
197 "The observed {} metric value has dropped below the threshold (initial value - {})",
198 self.metric_name, b
199 ),
200 None => format!(
201 "The {} metric value has dropped below the initial value",
202 self.metric_name
203 ),
204 };
205
206 let above_threshold = |boundary: Option<f64>| match boundary {
207 Some(b) => format!(
208 "The {} metric value has increased beyond the threshold (initial value + {})",
209 self.metric_name, b
210 ),
211 None => format!(
212 "The {} metric value has increased beyond the initial value",
213 self.metric_name
214 ),
215 };
216
217 let outside_threshold = |boundary: Option<f64>| match boundary {
218 Some(b) => format!(
219 "The {} metric value has fallen outside the threshold (initial value ± {})",
220 self.metric_name, b,
221 ),
222 None => format!(
223 "The metric value has fallen outside the initial value for {}",
224 self.metric_name
225 ),
226 };
227
228 match self.alert_threshold {
229 AlertThreshold::Below => below_threshold(self.alert_threshold_value),
230 AlertThreshold::Above => above_threshold(self.alert_threshold_value),
231 AlertThreshold::Outside => outside_threshold(self.alert_threshold_value),
232 }
233 }
234}
235
236impl DispatchAlertDescription for ComparisonMetricAlert {
237 fn create_alert_description(&self, _dispatch_type: AlertDispatchType) -> String {
239 let mut alert_description = String::new();
240 let header = format!("{}\n", self.alert_description_header());
241 alert_description.push_str(&header);
242
243 let current_metric = format!("Current Metric Value: {}\n", self.observed_metric_value);
244 let historical_metric = format!("Initial Metric Value: {}\n", self.training_metric_value);
245
246 alert_description.push_str(&historical_metric);
247 alert_description.push_str(¤t_metric);
248
249 alert_description
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `alert` for the console dispatch type and checks that every
    /// expected fragment appears in the description.
    fn assert_description_contains(alert: &ComparisonMetricAlert, fragments: &[&str]) {
        let description = alert.create_alert_description(AlertDispatchType::Console);
        for fragment in fragments.iter().copied() {
            assert!(description.contains(fragment));
        }
    }

    /// Conditions set from a list of metrics are stored under each metric's name.
    #[test]
    fn test_alert_config() {
        let mut alert_config = CustomMetricAlertConfig {
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            schedule: "0 0 * * * *".to_string(),
            ..Default::default()
        };
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::OpsGenie);

        let custom_metrics = [
            CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap(),
            CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap(),
        ];
        alert_config.set_alert_conditions(&custom_metrics);

        let alert_conditions = alert_config
            .alert_conditions
            .as_ref()
            .expect("alert_conditions should not be None");

        let expected = [
            ("mae", AlertThreshold::Above, Some(2.3)),
            ("accuracy", AlertThreshold::Below, None),
        ];
        for (name, threshold, boundary) in expected.iter() {
            let condition = &alert_conditions[*name];
            assert_eq!(condition.alert_threshold, *threshold);
            assert_eq!(condition.alert_threshold_value, *boundary);
        }
    }

    /// The description contains the direction-specific header and both values.
    #[test]
    fn test_create_alert_description() {
        let alert_above_threshold = ComparisonMetricAlert {
            metric_name: "mse".to_string(),
            training_metric_value: 12.5,
            observed_metric_value: 14.0,
            alert_threshold_value: Some(1.0),
            alert_threshold: AlertThreshold::Above,
        };
        assert_description_contains(
            &alert_above_threshold,
            &[
                "The mse metric value has increased beyond the threshold (initial value + 1)",
                "Initial Metric Value: 12.5",
                "Current Metric Value: 14",
            ],
        );

        let alert_below_threshold = ComparisonMetricAlert {
            metric_name: "accuracy".to_string(),
            training_metric_value: 0.9,
            observed_metric_value: 0.7,
            alert_threshold_value: None,
            alert_threshold: AlertThreshold::Below,
        };
        assert_description_contains(
            &alert_below_threshold,
            &[
                "The accuracy metric value has dropped below the initial value",
                "Initial Metric Value: 0.9",
                "Current Metric Value: 0.7",
            ],
        );

        let alert_outside_threshold = ComparisonMetricAlert {
            metric_name: "mae".to_string(),
            training_metric_value: 12.5,
            observed_metric_value: 22.0,
            alert_threshold_value: Some(2.0),
            alert_threshold: AlertThreshold::Outside,
        };
        assert_description_contains(
            &alert_outside_threshold,
            &[
                "The mae metric value has fallen outside the threshold (initial value ± 2)",
                "Initial Metric Value: 12.5",
                "Current Metric Value: 22",
            ],
        );
    }
}