1use crate::error::TypeError;
2use crate::{
3 dispatch::AlertDispatchType, AlertDispatchConfig, CommonCrons, DispatchAlertDescription,
4 OpsGenieDispatchConfig, ProfileFuncs, SlackDispatchConfig, ValidateAlertConfig,
5};
6use core::fmt::Debug;
7use pyo3::prelude::*;
8use pyo3::types::PyString;
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt::Display;
13use std::fmt::Formatter;
14
/// A user-defined metric value paired with the condition under which it should alert.
///
/// Built from Python as `CustomMetric(name, value, alert_threshold, alert_threshold_value=None)`.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CustomMetric {
    /// Metric name. The constructor lowercases it; the Python setter assigns it unchanged.
    #[pyo3(get, set)]
    pub name: String,

    /// Metric value. NOTE(review): presumably the baseline value observed values are
    /// compared against — confirm with callers.
    #[pyo3(get, set)]
    pub value: f64,

    /// Threshold direction and optional margin used when evaluating this metric.
    #[pyo3(get, set)]
    pub alert_condition: CustomMetricAlertCondition,
}
27
28#[pymethods]
29impl CustomMetric {
30 #[new]
31 #[pyo3(signature = (name, value, alert_threshold, alert_threshold_value=None))]
32 pub fn new(
33 name: &str,
34 value: f64,
35 alert_threshold: AlertThreshold,
36 alert_threshold_value: Option<f64>,
37 ) -> Result<Self, TypeError> {
38 let custom_condition =
39 CustomMetricAlertCondition::new(alert_threshold, alert_threshold_value);
40
41 Ok(Self {
42 name: name.to_lowercase(),
43 value,
44 alert_condition: custom_condition,
45 })
46 }
47
48 pub fn __str__(&self) -> String {
49 ProfileFuncs::__str__(self)
51 }
52
53 #[getter]
54 pub fn alert_threshold(&self) -> AlertThreshold {
55 self.alert_condition.alert_threshold.clone()
56 }
57
58 #[getter]
59 pub fn alert_threshold_value(&self) -> Option<f64> {
60 self.alert_condition.alert_threshold_value
61 }
62
63 #[getter]
64 pub fn class_id(&self) -> String {
65 self.name.clone()
66 }
67}
68
/// Direction in which a metric has to move, relative to its initial value, to raise an alert.
///
/// Judging by the alert wording produced for each variant: `Below` means the value dropped
/// under the (optionally offset) initial value, `Above` means it rose over it, and `Outside`
/// means it left the band in either direction. NOTE(review): the evaluation logic itself is
/// not in this file — confirm.
#[pyclass(eq)]
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub enum AlertThreshold {
    Below,
    Above,
    Outside,
}
76
77impl Display for AlertThreshold {
78 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
79 write!(f, "{:?}", self)
80 }
81}
82
83#[pymethods]
84impl AlertThreshold {
85 #[staticmethod]
86 pub fn from_value(value: &str) -> Option<Self> {
87 match value.to_lowercase().as_str() {
88 "below" => Some(AlertThreshold::Below),
89 "above" => Some(AlertThreshold::Above),
90 "outside" => Some(AlertThreshold::Outside),
91 _ => None,
92 }
93 }
94
95 pub fn __str__(&self) -> String {
96 match self {
97 AlertThreshold::Below => "Below".to_string(),
98 AlertThreshold::Above => "Above".to_string(),
99 AlertThreshold::Outside => "Outside".to_string(),
100 }
101 }
102}
103
/// Alert rule for a single custom metric: a threshold direction plus an optional margin.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomMetricAlertCondition {
    /// Direction that triggers the alert (below / above / outside).
    #[pyo3(get, set)]
    pub alert_threshold: AlertThreshold,

    /// Optional margin around the initial value. When `None`, alert messages refer to the
    /// initial value itself rather than to an offset threshold.
    #[pyo3(get, set)]
    pub alert_threshold_value: Option<f64>,
}
113
114#[pymethods]
115#[allow(clippy::too_many_arguments)]
116impl CustomMetricAlertCondition {
117 #[new]
118 #[pyo3(signature = (alert_threshold, alert_threshold_value=None))]
119 pub fn new(alert_threshold: AlertThreshold, alert_threshold_value: Option<f64>) -> Self {
120 Self {
121 alert_threshold,
122 alert_threshold_value,
123 }
124 }
125
126 pub fn __str__(&self) -> String {
127 ProfileFuncs::__str__(self)
129 }
130}
131
/// Alerting configuration for custom metrics: where to dispatch alerts, the cron schedule,
/// and the per-metric alert conditions.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomMetricAlertConfig {
    /// Dispatch target. Not exposed through `#[pyo3(get, set)]`; Python reads it through the
    /// `dispatch_config` / `dispatch_type` getters.
    pub dispatch_config: AlertDispatchConfig,

    /// Cron schedule string (`CommonCrons::EveryDay` when not supplied).
    #[pyo3(get, set)]
    pub schedule: String,

    /// Alert conditions keyed by metric name. `None` after construction; filled by
    /// `set_alert_conditions` or the Python setter.
    #[pyo3(get, set)]
    pub alert_conditions: Option<HashMap<String, CustomMetricAlertCondition>>,
}
143
144impl CustomMetricAlertConfig {
145 pub fn set_alert_conditions(&mut self, metrics: &[CustomMetric]) {
146 self.alert_conditions = Some(
147 metrics
148 .iter()
149 .map(|m| (m.name.clone(), m.alert_condition.clone()))
150 .collect(),
151 );
152 }
153}
154
// Opts into `ValidateAlertConfig` using only the trait's provided (default) methods.
// NOTE(review): `Self::resolve_schedule`, called in `new`, presumably comes from this trait — confirm.
impl ValidateAlertConfig for CustomMetricAlertConfig {}
156
157#[pymethods]
158impl CustomMetricAlertConfig {
159 #[new]
160 #[pyo3(signature = (schedule=None, dispatch_config=None))]
161 pub fn new(
162 schedule: Option<&Bound<'_, PyAny>>,
163 dispatch_config: Option<&Bound<'_, PyAny>>,
164 ) -> Result<Self, TypeError> {
165 let alert_dispatch_config = match dispatch_config {
166 None => AlertDispatchConfig::default(),
167 Some(config) => {
168 if config.is_instance_of::<SlackDispatchConfig>() {
169 AlertDispatchConfig::Slack(config.extract::<SlackDispatchConfig>()?)
170 } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
171 AlertDispatchConfig::OpsGenie(config.extract::<OpsGenieDispatchConfig>()?)
172 } else {
173 AlertDispatchConfig::default()
174 }
175 }
176 };
177
178 let schedule = match schedule {
179 Some(schedule) => {
180 if schedule.is_instance_of::<PyString>() {
181 schedule.to_string()
182 } else if schedule.is_instance_of::<CommonCrons>() {
183 schedule.extract::<CommonCrons>().unwrap().cron()
184 } else {
185 return Err(TypeError::InvalidScheduleError)?;
186 }
187 }
188 None => CommonCrons::EveryDay.cron(),
189 };
190
191 let schedule = Self::resolve_schedule(&schedule);
192
193 Ok(Self {
194 schedule,
195 dispatch_config: alert_dispatch_config,
196 alert_conditions: None,
197 })
198 }
199
200 #[getter]
201 pub fn dispatch_type(&self) -> AlertDispatchType {
202 self.dispatch_config.dispatch_type()
203 }
204
205 #[getter]
206 pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
207 self.dispatch_config.config(py)
208 }
209}
210
211impl Default for CustomMetricAlertConfig {
212 fn default() -> CustomMetricAlertConfig {
213 Self {
214 dispatch_config: AlertDispatchConfig::default(),
215 schedule: CommonCrons::EveryDay.cron(),
216 alert_conditions: None,
217 }
218 }
219}
220
221pub struct ComparisonMetricAlert {
222 pub metric_name: String,
223 pub training_metric_value: f64,
224 pub observed_metric_value: f64,
225 pub alert_threshold_value: Option<f64>,
226 pub alert_threshold: AlertThreshold,
227}
228
229impl ComparisonMetricAlert {
230 fn alert_description_header(&self) -> String {
231 let below_threshold = |boundary: Option<f64>| match boundary {
232 Some(b) => format!(
233 "The observed {} metric value has dropped below the threshold (initial value - {})",
234 self.metric_name, b
235 ),
236 None => format!(
237 "The {} metric value has dropped below the initial value",
238 self.metric_name
239 ),
240 };
241
242 let above_threshold = |boundary: Option<f64>| match boundary {
243 Some(b) => format!(
244 "The {} metric value has increased beyond the threshold (initial value + {})",
245 self.metric_name, b
246 ),
247 None => format!(
248 "The {} metric value has increased beyond the initial value",
249 self.metric_name
250 ),
251 };
252
253 let outside_threshold = |boundary: Option<f64>| match boundary {
254 Some(b) => format!(
255 "The {} metric value has fallen outside the threshold (initial value ± {})",
256 self.metric_name, b,
257 ),
258 None => format!(
259 "The metric value has fallen outside the initial value for {}",
260 self.metric_name
261 ),
262 };
263
264 match self.alert_threshold {
265 AlertThreshold::Below => below_threshold(self.alert_threshold_value),
266 AlertThreshold::Above => above_threshold(self.alert_threshold_value),
267 AlertThreshold::Outside => outside_threshold(self.alert_threshold_value),
268 }
269 }
270}
271
272impl DispatchAlertDescription for ComparisonMetricAlert {
273 fn create_alert_description(&self, _dispatch_type: AlertDispatchType) -> String {
275 let mut alert_description = String::new();
276 let header = format!("{}\n", self.alert_description_header());
277 alert_description.push_str(&header);
278
279 let current_metric = format!("Current Metric Value: {}\n", self.observed_metric_value);
280 let historical_metric = format!("Initial Metric Value: {}\n", self.training_metric_value);
281
282 alert_description.push_str(&historical_metric);
283 alert_description.push_str(¤t_metric);
284
285 alert_description
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an OpsGenie-backed config, then checks that `set_alert_conditions` keys each
    /// condition by metric name and preserves threshold direction and margin.
    #[test]
    fn test_alert_config() {
        let dispatch_config = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let schedule = "0 0 * * * *".to_string();
        // Struct literal bypasses `new`, so the schedule is stored as-is (no
        // `resolve_schedule`); only `alert_conditions` comes from `Default`.
        let mut alert_config = CustomMetricAlertConfig {
            dispatch_config,
            schedule,
            ..Default::default()
        };
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::OpsGenie);

        // One metric with an explicit margin and one without.
        let custom_metrics = vec![
            CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap(),
            CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap(),
        ];

        alert_config.set_alert_conditions(&custom_metrics);

        if let Some(alert_conditions) = alert_config.alert_conditions.as_ref() {
            assert_eq!(
                alert_conditions["mae"].alert_threshold,
                AlertThreshold::Above
            );
            assert_eq!(alert_conditions["mae"].alert_threshold_value, Some(2.3));
            assert_eq!(
                alert_conditions["accuracy"].alert_threshold,
                AlertThreshold::Below
            );
            assert_eq!(alert_conditions["accuracy"].alert_threshold_value, None);
        } else {
            panic!("alert_conditions should not be None");
        }
    }

    /// Checks the header wording and the initial/current value lines for each threshold
    /// direction.
    #[test]
    fn test_create_alert_description() {
        // Above with a margin: the header quotes "initial value + 1" (f64 1.0 formats as "1").
        let alert_above_threshold = ComparisonMetricAlert {
            metric_name: "mse".to_string(),
            training_metric_value: 12.5,
            observed_metric_value: 14.0,
            alert_threshold_value: Some(1.0),
            alert_threshold: AlertThreshold::Above,
        };

        let description =
            alert_above_threshold.create_alert_description(AlertDispatchType::Console);
        assert!(description.contains(
            "The mse metric value has increased beyond the threshold (initial value + 1)"
        ));
        assert!(description.contains("Initial Metric Value: 12.5"));
        assert!(description.contains("Current Metric Value: 14"));

        // Below without a margin: the header refers to the initial value itself.
        let alert_below_threshold = ComparisonMetricAlert {
            metric_name: "accuracy".to_string(),
            training_metric_value: 0.9,
            observed_metric_value: 0.7,
            alert_threshold_value: None,
            alert_threshold: AlertThreshold::Below,
        };

        let description =
            alert_below_threshold.create_alert_description(AlertDispatchType::Console);
        assert!(
            description.contains("The accuracy metric value has dropped below the initial value")
        );
        assert!(description.contains("Initial Metric Value: 0.9"));
        assert!(description.contains("Current Metric Value: 0.7"));

        // Outside with a margin: the header quotes "initial value ± 2".
        let alert_outside_threshold = ComparisonMetricAlert {
            metric_name: "mae".to_string(),
            training_metric_value: 12.5,
            observed_metric_value: 22.0,
            alert_threshold_value: Some(2.0),
            alert_threshold: AlertThreshold::Outside,
        };

        let description =
            alert_outside_threshold.create_alert_description(AlertDispatchType::Console);
        assert!(description
            .contains("The mae metric value has fallen outside the threshold (initial value ± 2)"));
        assert!(description.contains("Initial Metric Value: 12.5"));
        assert!(description.contains("Current Metric Value: 22"));
    }
}