// scouter_drift/custom/drift.rs

#[cfg(feature = "sql")]
pub mod custom_drifter {
    use crate::error::DriftError;
    use chrono::{DateTime, Utc};
    use scouter_dispatch::AlertDispatcher;
    use scouter_sql::sql::traits::CustomMetricSqlLogic;
    use scouter_sql::PostgresClient;
    use scouter_types::contracts::ServiceInfo;
    use scouter_types::custom::{AlertThreshold, ComparisonMetricAlert, CustomDriftProfile};
    use sqlx::{Pool, Postgres};
    use std::collections::{BTreeMap, HashMap};
    use tracing::error;
    use tracing::info;
    use tracing::warn;

    /// Compares observed custom metric values against the training values stored in a
    /// [`CustomDriftProfile`] and dispatches alerts for metrics that breach their
    /// configured alert condition.
    pub struct CustomDrifter {
        /// Space/name/version of the monitored service, copied from the profile config.
        service_info: ServiceInfo,
        /// Profile holding the training metric values and the alert configuration.
        profile: CustomDriftProfile,
    }

    impl CustomDrifter {
        /// Creates a drifter for `profile`, deriving its [`ServiceInfo`] from the
        /// profile config's `name`, `space` and `version`.
        pub fn new(profile: CustomDriftProfile) -> Self {
            Self {
                service_info: ServiceInfo {
                    name: profile.config.name.clone(),
                    space: profile.config.space.clone(),
                    version: profile.config.version.clone(),
                },
                profile,
            }
        }

        /// `space/name/version` label used as the subject of every log message.
        fn service_label(&self) -> String {
            format!(
                "{}/{}/{}",
                self.service_info.space, self.service_info.name, self.service_info.version
            )
        }

        /// Queries the observed values of every metric named in the profile.
        ///
        /// `limit_datetime` is forwarded to `PostgresClient::get_custom_metric_values`
        /// (callers pass the previous run's timestamp); the exact time filtering is
        /// defined by that query.
        ///
        /// # Errors
        /// Returns a [`DriftError`], after logging it, when the database query fails.
        pub async fn get_observed_custom_metric_values(
            &self,
            limit_datetime: &DateTime<Utc>,
            db_pool: &Pool<Postgres>,
        ) -> Result<HashMap<String, f64>, DriftError> {
            let metrics: Vec<String> = self.profile.metrics.keys().cloned().collect();

            let values = PostgresClient::get_custom_metric_values(
                db_pool,
                &self.service_info,
                limit_datetime,
                &metrics,
            )
            .await
            .inspect_err(|e| {
                error!(
                    "Error: Unable to obtain custom metric data from DB for {}: {}",
                    self.service_label(),
                    e
                );
            })?;

            Ok(values)
        }

        /// Same as [`Self::get_observed_custom_metric_values`], but returns `Ok(None)`
        /// (logging at info level) when no metric values were found.
        ///
        /// # Errors
        /// Propagates the [`DriftError`] from the database query.
        pub async fn get_metric_map(
            &self,
            limit_datetime: &DateTime<Utc>,
            db_pool: &Pool<Postgres>,
        ) -> Result<Option<HashMap<String, f64>>, DriftError> {
            let metric_map = self
                .get_observed_custom_metric_values(limit_datetime, db_pool)
                .await?;

            if metric_map.is_empty() {
                info!(
                    "No custom metric data was found for {}. Skipping alert processing.",
                    self.service_label()
                );
                return Ok(None);
            }

            Ok(Some(metric_map))
        }

        /// Returns `true` when `observed_value` breaches `alert_condition` relative to
        /// `training_value`:
        ///
        /// * `Below`   – `observed < training - boundary`
        /// * `Above`   – `observed > training + boundary`
        /// * `Outside` – either of the above
        ///
        /// A `None` boundary means any strict move in the watched direction alerts.
        /// An observed value equal to the training value never alerts.
        fn is_out_of_bounds(
            training_value: f64,
            observed_value: f64,
            alert_condition: &AlertThreshold,
            alert_boundary: Option<f64>,
        ) -> bool {
            // Explicit guard: with a negative boundary the comparisons below would
            // otherwise flag an unchanged value.
            if observed_value == training_value {
                return false;
            }

            // `x - 0.0` / `x + 0.0` equal `x`, so a missing boundary is a zero margin.
            let margin = alert_boundary.unwrap_or(0.0);
            let below = observed_value < training_value - margin;
            let above = observed_value > training_value + margin;

            match alert_condition {
                AlertThreshold::Below => below,
                AlertThreshold::Above => above,
                AlertThreshold::Outside => below || above,
            }
        }

        /// Builds an alert for metric `name` when `observed_value` breaches its
        /// configured condition.
        ///
        /// Returns `None` when the metric is within bounds, or when the profile has no
        /// training value or no alert condition for it. The latter two cases are
        /// logged rather than panicking, since `name` comes from database rows.
        fn compare_metric(&self, name: &str, observed_value: f64) -> Option<ComparisonMetricAlert> {
            let Some(&training_value) = self.profile.metrics.get(name) else {
                warn!(
                    "Custom metric {} for {} has no training value in the drift profile. Skipping.",
                    name,
                    self.service_label()
                );
                return None;
            };

            let Some(alert_condition) = self
                .profile
                .config
                .alert_config
                .alert_conditions
                .as_ref()
                .and_then(|conditions| conditions.get(name))
            else {
                warn!(
                    "Custom metric {} for {} has no alert condition configured. Skipping.",
                    name,
                    self.service_label()
                );
                return None;
            };

            Self::is_out_of_bounds(
                training_value,
                observed_value,
                &alert_condition.alert_threshold,
                alert_condition.alert_threshold_value,
            )
            .then(|| ComparisonMetricAlert {
                metric_name: name.to_string(),
                training_metric_value: training_value,
                observed_metric_value: observed_value,
                alert_threshold_value: alert_condition.alert_threshold_value,
                alert_threshold: alert_condition.alert_threshold.clone(),
            })
        }

        /// Evaluates every observed metric and dispatches an alert for each one that is
        /// out of bounds.
        ///
        /// Metrics the profile does not describe are skipped (see `compare_metric`).
        /// Returns `Ok(None)` when no metric breached its condition; otherwise returns
        /// the alerts that were dispatched.
        ///
        /// # Errors
        /// Returns a [`DriftError`], after logging it, when the alert dispatcher cannot
        /// be created or an alert fails to dispatch. Alerts dispatched before the
        /// failure are not rolled back.
        pub async fn generate_alerts(
            &self,
            metric_map: &HashMap<String, f64>,
        ) -> Result<Option<Vec<ComparisonMetricAlert>>, DriftError> {
            let metric_alerts: Vec<ComparisonMetricAlert> = metric_map
                .iter()
                .filter_map(|(name, &observed_value)| self.compare_metric(name, observed_value))
                .collect();

            if metric_alerts.is_empty() {
                info!("No alerts to process for {}", self.service_label());
                return Ok(None);
            }

            let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
                error!(
                    "Error creating alert dispatcher for {}: {}",
                    self.service_label(),
                    e
                );
            })?;

            for alert in &metric_alerts {
                alert_dispatcher.process_alerts(alert).await.inspect_err(|e| {
                    error!(
                        "Error processing alerts for {}: {}",
                        self.service_label(),
                        e
                    );
                })?;
            }

            Ok(Some(metric_alerts))
        }

        /// Flattens alerts into string maps (keys ordered by `BTreeMap`) for reporting.
        /// A missing `alert_threshold_value` is rendered as `"None"`.
        fn organize_alerts(alerts: Vec<ComparisonMetricAlert>) -> Vec<BTreeMap<String, String>> {
            alerts
                .iter()
                .map(|alert| {
                    let threshold_value = alert
                        .alert_threshold_value
                        .map_or_else(|| "None".to_string(), |value| value.to_string());

                    BTreeMap::from([
                        ("entity_name".to_string(), alert.metric_name.clone()),
                        (
                            "training_metric_value".to_string(),
                            alert.training_metric_value.to_string(),
                        ),
                        (
                            "observed_metric_value".to_string(),
                            alert.observed_metric_value.to_string(),
                        ),
                        ("alert_threshold_value".to_string(), threshold_value),
                        (
                            "alert_threshold".to_string(),
                            alert.alert_threshold.to_string(),
                        ),
                    ])
                })
                .collect()
        }

        /// Runs one drift check. It loads metric values using `previous_run` as the
        /// query bound, dispatches alerts for out-of-bounds metrics, and returns those
        /// alerts as string maps.
        ///
        /// Returns `Ok(None)` when there is no metric data or nothing to alert on.
        ///
        /// # Errors
        /// Propagates database and dispatch failures as [`DriftError`].
        pub async fn check_for_alerts(
            &self,
            db_pool: &Pool<Postgres>,
            previous_run: DateTime<Utc>,
        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
            let Some(metric_map) = self.get_metric_map(&previous_run, db_pool).await? else {
                return Ok(None);
            };

            let alerts = self.generate_alerts(&metric_map).await.inspect_err(|e| {
                error!("Error generating alerts for {}: {}", self.service_label(), e);
            })?;

            Ok(alerts.map(Self::organize_alerts))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use scouter_types::custom::{
            CustomMetric, CustomMetricAlertConfig, CustomMetricDriftConfig,
        };

        fn get_test_drifter() -> CustomDrifter {
            let custom_metrics = vec![
                CustomMetric::new("mse", 12.02, AlertThreshold::Above, Some(1.0)).unwrap(),
                CustomMetric::new("accuracy", 0.75, AlertThreshold::Below, None).unwrap(),
            ];

            let drift_config = CustomMetricDriftConfig::new(
                "scouter",
                "model",
                "0.1.0",
                25,
                CustomMetricAlertConfig::default(),
                None,
            )
            .unwrap();

            let profile = CustomDriftProfile::new(drift_config, custom_metrics, None).unwrap();

            CustomDrifter::new(profile)
        }

        #[test]
        fn test_is_out_of_bounds() {
            // mse training value obtained during initial model training.
            let mse_training_value = 12.0;

            // observed mse metric value captured somewhere after the initial training run.
            let mse_observed_value = 14.5;

            // we want mse to be as small as possible, so we want to see if the metric has increased.
            let mse_alert_condition = AlertThreshold::Above;

            // we do not want to alert if the mse value has simply increased, but we want to alert
            // if the observed metric has increased beyond (mse_training_value + 2.0)
            let mse_alert_boundary = Some(2.0);

            let mse_is_out_of_bounds = CustomDrifter::is_out_of_bounds(
                mse_training_value,
                mse_observed_value,
                &mse_alert_condition,
                mse_alert_boundary,
            );
            assert!(mse_is_out_of_bounds);

            // test observed metric has decreased beyond threshold.

            // accuracy training value obtained during initial model training.
            let accuracy_training_value = 0.76;

            // observed accuracy metric value captured somewhere after the initial training run.
            let accuracy_observed_value = 0.67;

            // we want to alert if accuracy has decreased.
            let accuracy_alert_condition = AlertThreshold::Below;

            // no boundary: alert if accuracy has decreased by any amount
            let accuracy_alert_boundary = None;

            let accuracy_is_out_of_bounds = CustomDrifter::is_out_of_bounds(
                accuracy_training_value,
                accuracy_observed_value,
                &accuracy_alert_condition,
                accuracy_alert_boundary,
            );
            assert!(accuracy_is_out_of_bounds);

            // test observed metric has not increased.

            // mae training value obtained during initial model training.
            let mae_training_value = 13.5;

            // observed mae metric value captured somewhere after the initial training run.
            let mae_observed_value = 10.5;

            // we want to alert if mae has increased.
            let mae_alert_condition = AlertThreshold::Above;

            // no boundary: alert if mae has increased by any amount
            let mae_alert_boundary = None;

            let mae_is_out_of_bounds = CustomDrifter::is_out_of_bounds(
                mae_training_value,
                mae_observed_value,
                &mae_alert_condition,
                mae_alert_boundary,
            );
            assert!(!mae_is_out_of_bounds);
        }

        #[test]
        fn test_is_out_of_bounds_outside_and_unchanged() {
            let outside = AlertThreshold::Outside;

            // beyond the upper and lower boundaries respectively
            assert!(CustomDrifter::is_out_of_bounds(1.0, 1.6, &outside, Some(0.5)));
            assert!(CustomDrifter::is_out_of_bounds(1.0, 0.4, &outside, Some(0.5)));
            // inside the band
            assert!(!CustomDrifter::is_out_of_bounds(1.0, 1.4, &outside, Some(0.5)));
            // an unchanged value never alerts, even with a negative boundary
            assert!(!CustomDrifter::is_out_of_bounds(1.0, 1.0, &outside, None));
            assert!(!CustomDrifter::is_out_of_bounds(1.0, 1.0, &outside, Some(-1.0)));
        }

        #[tokio::test]
        async fn test_generate_alerts() {
            let drifter = get_test_drifter();

            let mut metric_map = HashMap::new();
            // mse had an initial value of 12.02 when the profile was generated
            metric_map.insert("mse".to_string(), 14.0);
            // accuracy had an initial value of 0.75 when the profile was generated
            metric_map.insert("accuracy".to_string(), 0.65);

            let alerts = drifter.generate_alerts(&metric_map).await.unwrap().unwrap();

            assert_eq!(alerts.len(), 2);
        }

        #[tokio::test]
        async fn test_generate_alerts_skips_unknown_metrics() {
            let drifter = get_test_drifter();

            let mut metric_map = HashMap::new();
            // not part of the profile: must be skipped instead of panicking
            metric_map.insert("unknown_metric".to_string(), 100.0);
            // within bounds for mse (12.02 + 1.0)
            metric_map.insert("mse".to_string(), 12.5);

            let alerts = drifter.generate_alerts(&metric_map).await.unwrap();

            assert!(alerts.is_none());
        }
    }
}