1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4 use crate::error::DriftError;
5 use crate::{custom::CustomDrifter, psi::PsiDrifter, spc::SpcDrifter};
6 use chrono::{DateTime, Utc};
7
8 use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9 use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10 use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11 use sqlx::{Pool, Postgres};
12 use std::collections::BTreeMap;
13 use std::result::Result;
14 use std::result::Result::Ok;
15 use std::str::FromStr;
16 use tracing::{debug, error, info, instrument, span, Instrument, Level};
17
18 #[allow(clippy::enum_variant_names)]
19 pub enum Drifter {
20 SpcDrifter(SpcDrifter),
21 PsiDrifter(PsiDrifter),
22 CustomDrifter(CustomDrifter),
23 }
24
25 impl Drifter {
26 pub async fn check_for_alerts(
27 &self,
28 db_pool: &Pool<Postgres>,
29 previous_run: DateTime<Utc>,
30 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
31 match self {
32 Drifter::SpcDrifter(drifter) => {
33 drifter.check_for_alerts(db_pool, previous_run).await
34 }
35 Drifter::PsiDrifter(drifter) => {
36 drifter.check_for_alerts(db_pool, previous_run).await
37 }
38 Drifter::CustomDrifter(drifter) => {
39 drifter.check_for_alerts(db_pool, previous_run).await
40 }
41 }
42 }
43 }
44
45 pub trait GetDrifter {
46 fn get_drifter(&self) -> Drifter;
47 }
48
49 impl GetDrifter for DriftProfile {
50 fn get_drifter(&self) -> Drifter {
62 match self {
63 DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
64 DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
65 DriftProfile::Custom(profile) => {
66 Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
67 }
68 }
69 }
70 }
71
72 pub struct DriftExecutor {
73 db_pool: Pool<Postgres>,
74 }
75
76 impl DriftExecutor {
77 pub fn new(db_pool: &Pool<Postgres>) -> Self {
78 Self {
79 db_pool: db_pool.clone(),
80 }
81 }
82
83 pub async fn _process_task(
95 &mut self,
96 profile: DriftProfile,
97 previous_run: DateTime<Utc>,
98 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
99 profile
102 .get_drifter()
103 .check_for_alerts(&self.db_pool, previous_run)
104 .await
105 }
106
107 async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
108 debug!("Polling for drift tasks");
109
110 let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
112
113 let Some(task) = task else {
114 return Ok(None);
115 };
116
117 let task_info = DriftTaskInfo {
118 space: task.space.clone(),
119 name: task.name.clone(),
120 version: task.version.clone(),
121 uid: task.uid.clone(),
122 drift_type: DriftType::from_str(&task.drift_type).unwrap(),
123 };
124
125 info!(
126 "Processing drift task for profile: {}/{}/{} and type {}",
127 task.space, task.name, task.version, task.drift_type
128 );
129
130 self.process_task(&task, &task_info).await?;
131
132 PostgresClient::update_drift_profile_run_dates(
134 &self.db_pool,
135 &task_info,
136 &task.schedule,
137 )
138 .instrument(span!(Level::INFO, "Update Run Dates"))
139 .await?;
140
141 Ok(Some(task))
142 }
143
144 #[instrument(skip_all)]
145 async fn process_task(
146 &mut self,
147 task: &TaskRequest,
148 task_info: &DriftTaskInfo,
149 ) -> Result<(), DriftError> {
150 let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
152 error!("Error converting drift type: {:?}", e);
153 })?;
154
155 let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
157 .inspect_err(|e| {
158 error!(
159 "Error converting drift profile for type {:?}: {:?}",
160 drift_type, e
161 );
162 })?;
163
164 match self._process_task(profile, task.previous_run).await {
166 Ok(Some(alerts)) => {
167 info!("Drift task processed successfully with alerts");
168
169 for alert in alerts {
171 PostgresClient::insert_drift_alert(
172 &self.db_pool,
173 task_info,
174 alert.get("entity_name").unwrap_or(&"NA".to_string()),
175 &alert,
176 &drift_type,
177 )
178 .await
179 .map_err(|e| {
180 error!("Error inserting drift alert: {:?}", e);
181 DriftError::SqlError(e)
182 })?;
183 }
184 Ok(())
185 }
186 Ok(None) => {
187 info!("Drift task processed successfully with no alerts");
188 Ok(())
189 }
190 Err(e) => {
191 error!("Error processing drift task: {:?}", e);
192 Err(DriftError::AlertProcessingError(e.to_string()))
193 }
194 }
195 }
196
197 #[instrument(skip_all)]
203 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
204 match self.do_poll().await? {
205 Some(_) => {
206 info!("Successfully processed drift task");
207 Ok(())
208 }
209 None => {
210 info!("No triggered schedules found in db. Sleeping for 10 seconds");
211 tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
212 Ok(())
213 }
214 }
215 }
216 }
217
#[cfg(test)]
mod tests {
    use super::*;
    use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
    use scouter_settings::DatabaseSettings;
    use scouter_sql::PostgresClient;
    use scouter_types::DriftAlertRequest;
    use sqlx::{postgres::Postgres, Pool};

    /// Deletes every row from the scouter tables these tests touch, then
    /// configures logging at `Info` level.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        let wipe_tables = r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;
            "#;
        sqlx::raw_sql(wipe_tables).fetch_all(pool).await.unwrap();

        let logging = LoggingConfig::new(None, Some(LogLevel::Info), None, None);
        RustyLogger::setup_logging(Some(logging)).unwrap();
    }

    /// Opens a pool with the default database settings and runs `cleanup` on it.
    async fn clean_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();
        cleanup(&pool).await;
        pool
    }

    /// Reads a SQL script addressed relative to the current working directory.
    fn read_script(relative_path: &str) -> String {
        let mut script_path = std::env::current_dir().expect("Failed to get current directory");
        script_path.push(relative_path);
        std::fs::read_to_string(script_path).unwrap()
    }

    /// Loads the PSI populate script and fills in its `{{...}}` placeholders.
    fn render_psi_script(
        bin_count: i32,
        skew_feature: &str,
        skew_factor: i32,
        apply_skew: bool,
    ) -> String {
        read_script("src/scripts/populate_psi.sql")
            .replace("{{bin_count}}", &bin_count.to_string())
            .replace("{{skew_feature}}", skew_feature)
            .replace("{{skew_factor}}", &skew_factor.to_string())
            .replace("{{apply_skew}}", &apply_skew.to_string())
    }

    /// Executes `script` against the pool, then runs one executor poll cycle.
    async fn seed_and_poll(pool: &Pool<Postgres>, script: &str) {
        sqlx::raw_sql(script).execute(pool).await.unwrap();

        let mut drift_executor = DriftExecutor::new(pool);
        drift_executor.poll_for_tasks().await.unwrap();
    }

    /// Builds an unfiltered alert request for version `0.1.0` of `space`/`name`.
    fn alert_request(space: &str, name: &str) -> DriftAlertRequest {
        DriftAlertRequest {
            space: space.to_string(),
            name: name.to_string(),
            version: "0.1.0".to_string(),
            limit_datetime: None,
            active: None,
            limit: None,
        }
    }

    /// Polling after the SPC populate script must leave at least one alert.
    #[tokio::test]
    async fn test_drift_executor_spc() {
        let db_pool = clean_pool().await;
        let script = read_script("src/scripts/populate_spc.sql");
        seed_and_poll(&db_pool, &script).await;

        let request = alert_request("statworld", "test_app");
        let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
            .await
            .unwrap();
        assert!(!alerts.is_empty());
    }

    /// Polling after the SPC alert populate script must leave at least one alert.
    #[tokio::test]
    async fn test_drift_executor_spc_missing_feature_data() {
        let db_pool = clean_pool().await;
        let script = read_script("src/scripts/populate_spc_alert.sql");
        seed_and_poll(&db_pool, &script).await;

        let request = alert_request("statworld", "test_app");
        let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
            .await
            .unwrap();

        assert!(!alerts.is_empty());
    }

    /// PSI data with 1000 bins and a 10x skew on `feature_1` must yield exactly one alert.
    #[tokio::test]
    async fn test_drift_executor_psi() {
        let db_pool = clean_pool().await;
        let script = render_psi_script(1000, "feature_1", 10, true);
        seed_and_poll(&db_pool, &script).await;

        let request = alert_request("scouter", "model");
        let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
            .await
            .unwrap();

        assert_eq!(alerts.len(), 1);
    }

    /// PSI data with only 2 bins and no skew must not raise any alert.
    #[tokio::test]
    async fn test_drift_executor_psi_not_enough_target_samples() {
        let db_pool = clean_pool().await;
        let script = render_psi_script(2, "feature_1", 1, false);
        seed_and_poll(&db_pool, &script).await;

        let request = alert_request("scouter", "model");
        let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
            .await
            .unwrap();

        assert!(alerts.is_empty());
    }

    /// Polling after the custom populate script must leave exactly one alert.
    #[tokio::test]
    async fn test_drift_executor_custom() {
        let db_pool = clean_pool().await;
        let script = read_script("src/scripts/populate_custom.sql");
        seed_and_poll(&db_pool, &script).await;

        let request = alert_request("scouter", "model");
        let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
            .await
            .unwrap();

        assert_eq!(alerts.len(), 1);
    }
}
466}