1#[cfg(feature = "sql")]
2pub mod custom_drifter {
3
4 use crate::error::DriftError;
5 use chrono::{DateTime, Utc};
6 use scouter_dispatch::AlertDispatcher;
7 use scouter_sql::sql::traits::CustomMetricSqlLogic;
8 use scouter_sql::PostgresClient;
9 use scouter_types::contracts::ServiceInfo;
10 use scouter_types::custom::{AlertThreshold, ComparisonMetricAlert, CustomDriftProfile};
11 use sqlx::{Pool, Postgres};
12 use std::collections::{BTreeMap, HashMap};
13 use tracing::error;
14 use tracing::info;
15
16 pub struct CustomDrifter {
17 service_info: ServiceInfo,
18 profile: CustomDriftProfile,
19 }
20
21 impl CustomDrifter {
22 pub fn new(profile: CustomDriftProfile) -> Self {
23 Self {
24 service_info: ServiceInfo {
25 name: profile.config.name.clone(),
26 space: profile.config.space.clone(),
27 version: profile.config.version.clone(),
28 },
29 profile,
30 }
31 }
32
33 pub async fn get_observed_custom_metric_values(
34 &self,
35 limit_datetime: &DateTime<Utc>,
36 db_pool: &Pool<Postgres>,
37 ) -> Result<HashMap<String, f64>, DriftError> {
38 let metrics: Vec<String> = self.profile.metrics.keys().cloned().collect();
39
40 Ok(PostgresClient::get_custom_metric_values(
41 db_pool,
42 &self.service_info,
43 limit_datetime,
44 &metrics,
45 )
46 .await
47 .inspect_err(|e| {
48 let msg = format!(
49 "Error: Unable to obtain custom metric data from DB for {}/{}/{}: {}",
50 self.service_info.space, self.service_info.name, self.service_info.version, e
51 );
52 error!(msg);
53 })?)
54 }
55
56 pub async fn get_metric_map(
57 &self,
58 limit_datetime: &DateTime<Utc>,
59 db_pool: &Pool<Postgres>,
60 ) -> Result<Option<HashMap<String, f64>>, DriftError> {
61 let metric_map = self
62 .get_observed_custom_metric_values(limit_datetime, db_pool)
63 .await?;
64
65 if metric_map.is_empty() {
66 info!(
67 "No custom metric data was found for {}/{}/{}. Skipping alert processing.",
68 self.service_info.space, self.service_info.name, self.service_info.version,
69 );
70 return Ok(None);
71 }
72
73 Ok(Some(metric_map))
74 }
75
76 fn is_out_of_bounds(
77 training_value: f64,
78 observed_value: f64,
79 alert_condition: &AlertThreshold,
80 alert_boundary: Option<f64>,
81 ) -> bool {
82 if observed_value == training_value {
83 return false;
84 }
85
86 let below_threshold = |boundary: Option<f64>| match boundary {
87 Some(b) => observed_value < training_value - b,
88 None => observed_value < training_value,
89 };
90
91 let above_threshold = |boundary: Option<f64>| match boundary {
92 Some(b) => observed_value > training_value + b,
93 None => observed_value > training_value,
94 };
95
96 match alert_condition {
97 AlertThreshold::Below => below_threshold(alert_boundary),
98 AlertThreshold::Above => above_threshold(alert_boundary),
99 AlertThreshold::Outside => {
100 below_threshold(alert_boundary) || above_threshold(alert_boundary)
101 } }
103 }
104
105 pub async fn generate_alerts(
106 &self,
107 metric_map: &HashMap<String, f64>,
108 ) -> Result<Option<Vec<ComparisonMetricAlert>>, DriftError> {
109 let metric_alerts: Vec<ComparisonMetricAlert> = metric_map
110 .iter()
111 .filter_map(|(name, observed_value)| {
112 let training_value = self.profile.metrics[name];
113 let alert_condition = &self
114 .profile
115 .config
116 .alert_config
117 .alert_conditions
118 .as_ref()
119 .unwrap()[name];
120 if Self::is_out_of_bounds(
121 training_value,
122 *observed_value,
123 &alert_condition.alert_threshold,
124 alert_condition.alert_threshold_value,
125 ) {
126 Some(ComparisonMetricAlert {
127 metric_name: name.clone(),
128 training_metric_value: training_value,
129 observed_metric_value: *observed_value,
130 alert_threshold_value: alert_condition.alert_threshold_value,
131 alert_threshold: alert_condition.alert_threshold.clone(),
132 })
133 } else {
134 None
135 }
136 })
137 .collect();
138
139 if metric_alerts.is_empty() {
140 info!(
141 "No alerts to process for {}/{}/{}",
142 self.service_info.space, self.service_info.name, self.service_info.version
143 );
144 return Ok(None);
145 }
146
147 let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
148 let msg = format!(
149 "Error creating alert dispatcher for {}/{}/{}: {}",
150 self.service_info.space, self.service_info.name, self.service_info.version, e
151 );
152 error!(msg);
153 })?;
154
155 for alert in &metric_alerts {
156 alert_dispatcher
157 .process_alerts(alert)
158 .await
159 .inspect_err(|e| {
160 let msg = format!(
161 "Error processing alerts for {}/{}/{}: {}",
162 self.service_info.space,
163 self.service_info.name,
164 self.service_info.version,
165 e
166 );
167 error!(msg);
168 })?;
169 }
170
171 Ok(Some(metric_alerts))
172 }
173
174 fn organize_alerts(
175 mut alerts: Vec<ComparisonMetricAlert>,
176 ) -> Vec<BTreeMap<String, String>> {
177 let mut alert_vec = Vec::new();
178 alerts.iter_mut().for_each(|alert| {
179 let mut alert_map = BTreeMap::new();
180 alert_map.insert("entity_name".to_string(), alert.metric_name.clone());
181 alert_map.insert(
182 "training_metric_value".to_string(),
183 alert.training_metric_value.to_string(),
184 );
185 alert_map.insert(
186 "observed_metric_value".to_string(),
187 alert.observed_metric_value.to_string(),
188 );
189 let alert_threshold_value_str = match alert.alert_threshold_value {
190 Some(value) => value.to_string(),
191 None => "None".to_string(),
192 };
193 alert_map.insert(
194 "alert_threshold_value".to_string(),
195 alert_threshold_value_str,
196 );
197 alert_map.insert(
198 "alert_threshold".to_string(),
199 alert.alert_threshold.to_string(),
200 );
201 alert_vec.push(alert_map);
202 });
203
204 alert_vec
205 }
206
207 pub async fn check_for_alerts(
208 &self,
209 db_pool: &Pool<Postgres>,
210 previous_run: DateTime<Utc>,
211 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
212 let metric_map = self.get_metric_map(&previous_run, db_pool).await?;
213
214 match metric_map {
215 Some(metric_map) => {
216 let alerts = self.generate_alerts(&metric_map).await.inspect_err(|e| {
217 let msg = format!(
218 "Error generating alerts for {}/{}/{}: {}",
219 self.service_info.space,
220 self.service_info.name,
221 self.service_info.version,
222 e
223 );
224 error!(msg);
225 })?;
226 match alerts {
227 Some(alerts) => Ok(Some(Self::organize_alerts(alerts))),
228 None => Ok(None),
229 }
230 }
231 None => Ok(None),
232 }
233 }
234 }
235
236 #[cfg(test)]
237 mod tests {
238 use super::*;
239 use scouter_types::custom::{
240 CustomMetric, CustomMetricAlertConfig, CustomMetricDriftConfig,
241 };
242
243 fn get_test_drifter() -> CustomDrifter {
244 let custom_metrics = vec![
245 CustomMetric::new("mse", 12.02, AlertThreshold::Above, Some(1.0)).unwrap(),
246 CustomMetric::new("accuracy", 0.75, AlertThreshold::Below, None).unwrap(),
247 ];
248
249 let drift_config = CustomMetricDriftConfig::new(
250 "scouter",
251 "model",
252 "0.1.0",
253 25,
254 CustomMetricAlertConfig::default(),
255 None,
256 )
257 .unwrap();
258
259 let profile = CustomDriftProfile::new(drift_config, custom_metrics, None).unwrap();
260
261 CustomDrifter::new(profile)
262 }
263
264 #[test]
265 fn test_is_out_of_bounds() {
266 let mse_training_value = 12.0;
268
269 let mse_observed_value = 14.5;
271
272 let mse_alert_condition = AlertThreshold::Above;
274
275 let mse_alert_boundary = Some(2.0);
278
279 let mse_is_out_of_bounds = CustomDrifter::is_out_of_bounds(
280 mse_training_value,
281 mse_observed_value,
282 &mse_alert_condition,
283 mse_alert_boundary,
284 );
285 assert!(mse_is_out_of_bounds);
286
287 let accuracy_training_value = 0.76;
291
292 let accuracy_observed_value = 0.67;
294
295 let accuracy_alert_condition = AlertThreshold::Below;
297
298 let accuracy_alert_boundary = None;
300
301 let accuracy_is_out_of_bounds = CustomDrifter::is_out_of_bounds(
302 accuracy_training_value,
303 accuracy_observed_value,
304 &accuracy_alert_condition,
305 accuracy_alert_boundary,
306 );
307 assert!(accuracy_is_out_of_bounds);
308
309 let mae_training_value = 13.5;
313
314 let mae_observed_value = 10.5;
316
317 let mae_alert_condition = AlertThreshold::Above;
319
320 let mae_alert_boundary = None;
322
323 let mae_is_out_of_bounds = CustomDrifter::is_out_of_bounds(
324 mae_training_value,
325 mae_observed_value,
326 &mae_alert_condition,
327 mae_alert_boundary,
328 );
329 assert!(!mae_is_out_of_bounds);
330 }
331
332 #[tokio::test]
333 async fn test_generate_alerts() {
334 let drifter = get_test_drifter();
335
336 let mut metric_map = HashMap::new();
337 metric_map.insert("mse".to_string(), 14.0);
339 metric_map.insert("accuracy".to_string(), 0.65);
341
342 let alerts = drifter.generate_alerts(&metric_map).await.unwrap().unwrap();
343
344 assert_eq!(alerts.len(), 2);
345 }
346 }
347}