1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4 use crate::error::DriftError;
5 use crate::{custom::CustomDrifter, llm::LLMDrifter, psi::PsiDrifter, spc::SpcDrifter};
6 use chrono::{DateTime, Utc};
7
8 use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9 use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10 use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11 use sqlx::{Pool, Postgres};
12 use std::collections::BTreeMap;
13 use std::result::Result;
14 use std::result::Result::Ok;
15 use std::str::FromStr;
16 use tracing::{debug, error, info, instrument, span, Instrument, Level};
17
18 #[allow(clippy::enum_variant_names)]
19 pub enum Drifter {
20 SpcDrifter(SpcDrifter),
21 PsiDrifter(PsiDrifter),
22 CustomDrifter(CustomDrifter),
23 LLMDrifter(LLMDrifter),
24 }
25
26 impl Drifter {
27 pub async fn check_for_alerts(
28 &self,
29 db_pool: &Pool<Postgres>,
30 previous_run: DateTime<Utc>,
31 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
32 match self {
33 Drifter::SpcDrifter(drifter) => {
34 drifter.check_for_alerts(db_pool, previous_run).await
35 }
36 Drifter::PsiDrifter(drifter) => {
37 drifter.check_for_alerts(db_pool, previous_run).await
38 }
39 Drifter::CustomDrifter(drifter) => {
40 drifter.check_for_alerts(db_pool, previous_run).await
41 }
42 Drifter::LLMDrifter(drifter) => {
43 drifter.check_for_alerts(db_pool, previous_run).await
44 }
45 }
46 }
47 }
48
49 pub trait GetDrifter {
50 fn get_drifter(&self) -> Drifter;
51 }
52
53 impl GetDrifter for DriftProfile {
54 fn get_drifter(&self) -> Drifter {
66 match self {
67 DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
68 DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
69 DriftProfile::Custom(profile) => {
70 Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
71 }
72 DriftProfile::LLM(profile) => Drifter::LLMDrifter(LLMDrifter::new(profile.clone())),
73 }
74 }
75 }
76
77 pub struct DriftExecutor {
78 db_pool: Pool<Postgres>,
79 }
80
81 impl DriftExecutor {
82 pub fn new(db_pool: &Pool<Postgres>) -> Self {
83 Self {
84 db_pool: db_pool.clone(),
85 }
86 }
87
88 pub async fn _process_task(
100 &mut self,
101 profile: DriftProfile,
102 previous_run: DateTime<Utc>,
103 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
104 profile
107 .get_drifter()
108 .check_for_alerts(&self.db_pool, previous_run)
109 .await
110 }
111
112 async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
113 debug!("Polling for drift tasks");
114
115 let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
117
118 let Some(task) = task else {
119 return Ok(None);
120 };
121
122 let task_info = DriftTaskInfo {
123 space: task.space.clone(),
124 name: task.name.clone(),
125 version: task.version.clone(),
126 uid: task.uid.clone(),
127 drift_type: DriftType::from_str(&task.drift_type).unwrap(),
128 };
129
130 info!(
131 "Processing drift task for profile: {}/{}/{} and type {}",
132 task.space, task.name, task.version, task.drift_type
133 );
134
135 self.process_task(&task, &task_info).await?;
136
137 PostgresClient::update_drift_profile_run_dates(
139 &self.db_pool,
140 &task_info,
141 &task.schedule,
142 )
143 .instrument(span!(Level::INFO, "Update Run Dates"))
144 .await?;
145
146 Ok(Some(task))
147 }
148
149 #[instrument(skip_all)]
150 async fn process_task(
151 &mut self,
152 task: &TaskRequest,
153 task_info: &DriftTaskInfo,
154 ) -> Result<(), DriftError> {
155 let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
157 error!("Error converting drift type: {:?}", e);
158 })?;
159
160 let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
162 .inspect_err(|e| {
163 error!(
164 "Error converting drift profile for type {:?}: {:?}",
165 drift_type, e
166 );
167 })?;
168
169 match self._process_task(profile, task.previous_run).await {
171 Ok(Some(alerts)) => {
172 info!("Drift task processed successfully with alerts");
173
174 for alert in alerts {
176 PostgresClient::insert_drift_alert(
177 &self.db_pool,
178 task_info,
179 alert.get("entity_name").unwrap_or(&"NA".to_string()),
180 &alert,
181 &drift_type,
182 )
183 .await
184 .map_err(|e| {
185 error!("Error inserting drift alert: {:?}", e);
186 DriftError::SqlError(e)
187 })?;
188 }
189 Ok(())
190 }
191 Ok(None) => {
192 info!("Drift task processed successfully with no alerts");
193 Ok(())
194 }
195 Err(e) => {
196 error!("Error processing drift task: {:?}", e);
197 Err(DriftError::AlertProcessingError(e.to_string()))
198 }
199 }
200 }
201
202 #[instrument(skip_all)]
208 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
209 match self.do_poll().await? {
210 Some(_) => {
211 info!("Successfully processed drift task");
212 Ok(())
213 }
214 None => {
215 info!("No triggered schedules found in db. Sleeping for 10 seconds");
216 tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
217 Ok(())
218 }
219 }
220 }
221 }
222
#[cfg(test)]
mod tests {
    use super::*;
    use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
    use scouter_settings::DatabaseSettings;
    use scouter_sql::PostgresClient;
    use scouter_types::spc::SpcFeatureDriftProfile;
    use scouter_types::{
        spc::{SpcAlertConfig, SpcAlertRule, SpcDriftConfig, SpcDriftProfile},
        AlertDispatchConfig, DriftAlertRequest,
    };
    use scouter_types::{CommonCrons, ProfileArgs};
    use semver::Version;
    use sqlx::{postgres::Postgres, Pool};
    use std::collections::HashMap;

    /// Empties the drift-related tables and initialises logging at `Info` level.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;
            "#,
        )
        .fetch_all(pool)
        .await
        .unwrap();

        let logging = LoggingConfig::new(None, Some(LogLevel::Info), None, None);
        RustyLogger::setup_logging(Some(logging)).unwrap();
    }

    /// Connects to the default test database and wipes it with [`cleanup`].
    async fn fresh_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();
        cleanup(&pool).await;
        pool
    }

    /// Reads a SQL fixture from `src/scripts/`, resolved against the working directory.
    fn read_script(file_name: &str) -> String {
        let mut path = std::env::current_dir().expect("Failed to get current directory");
        path.push(format!("src/scripts/{file_name}"));
        std::fs::read_to_string(path).unwrap()
    }

    /// Renders the PSI fixture template by filling in its `{{...}}` placeholders.
    fn psi_script(bin_count: i32, skew_feature: &str, skew_factor: i32, apply_skew: bool) -> String {
        read_script("populate_psi.sql")
            .replace("{{bin_count}}", &bin_count.to_string())
            .replace("{{skew_feature}}", skew_feature)
            .replace("{{skew_factor}}", &skew_factor.to_string())
            .replace("{{apply_skew}}", &apply_skew.to_string())
    }

    /// Executes a (possibly multi-statement) SQL script, panicking on failure.
    async fn run_script(pool: &Pool<Postgres>, script: &str) {
        sqlx::raw_sql(script).execute(pool).await.unwrap();
    }

    /// Builds an unfiltered alert query for `space`/`name` at version `0.1.0`.
    fn alert_request(space: &str, name: &str) -> DriftAlertRequest {
        DriftAlertRequest {
            space: space.to_string(),
            name: name.to_string(),
            version: "0.1.0".to_string(),
            limit_datetime: None,
            active: None,
            limit: None,
        }
    }

    #[tokio::test]
    async fn test_drift_executor_spc() {
        let db_pool = fresh_pool().await;

        let alert_config = SpcAlertConfig {
            rule: SpcAlertRule::default(),
            schedule: CommonCrons::EveryDay.cron().to_string(),
            features_to_monitor: vec!["col_1".to_string(), "col_3".to_string()],
            dispatch_config: AlertDispatchConfig::default(),
        };

        let config = SpcDriftConfig::new(
            Some("statworld".to_string()),
            Some("test_app".to_string()),
            Some("0.1.0".to_string()),
            Some(true),
            Some(25),
            Some(alert_config),
            None,
        )
        .unwrap();

        let feature_profiles = [
            SpcFeatureDriftProfile {
                id: "col_1".to_string(),
                center: -3.997113080300062,
                one_ucl: -1.9742384896265417,
                one_lcl: -6.019987670973582,
                two_ucl: 0.048636101046978464,
                two_lcl: -8.042862261647102,
                three_ucl: 2.071510691720498,
                three_lcl: -10.065736852320622,
                timestamp: Utc::now(),
            },
            SpcFeatureDriftProfile {
                id: "col_3".to_string(),
                center: -3.937652409303277,
                one_ucl: -2.0275656995100224,
                one_lcl: -5.8477391190965315,
                two_ucl: -0.1174789897167674,
                two_lcl: -7.757825828889787,
                three_ucl: 1.7926077200764872,
                three_lcl: -9.66791253868304,
                timestamp: Utc::now(),
            },
        ];

        let drift_profile = DriftProfile::Spc(SpcDriftProfile {
            config,
            features: HashMap::from(
                feature_profiles.map(|feature| (feature.id.clone(), feature)),
            ),
            scouter_version: "0.1.0".to_string(),
        });

        let profile_args = ProfileArgs {
            space: "statworld".to_string(),
            name: "test_app".to_string(),
            version: Some("0.1.0".to_string()),
            schedule: "* * * * * *".to_string(),
            scouter_version: "0.1.0".to_string(),
            drift_type: DriftType::Spc,
        };

        PostgresClient::insert_drift_profile(
            &db_pool,
            &drift_profile,
            &profile_args,
            &Version::new(0, 1, 0),
            &true,
            &true,
        )
        .await
        .unwrap();

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        run_script(&db_pool, &read_script("populate_spc.sql")).await;
        let mut drift_executor = DriftExecutor::new(&db_pool);
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        drift_executor.poll_for_tasks().await.unwrap();

        let alerts =
            PostgresClient::get_drift_alerts(&db_pool, &alert_request("statworld", "test_app"))
                .await
                .unwrap();
        assert!(!alerts.is_empty());
    }

    #[tokio::test]
    async fn test_drift_executor_spc_missing_feature_data() {
        let db_pool = fresh_pool().await;

        run_script(&db_pool, &read_script("populate_spc_alert.sql")).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let mut drift_executor = DriftExecutor::new(&db_pool);
        drift_executor.poll_for_tasks().await.unwrap();

        let alerts =
            PostgresClient::get_drift_alerts(&db_pool, &alert_request("statworld", "test_app"))
                .await
                .unwrap();
        assert!(!alerts.is_empty());
    }

    #[tokio::test]
    async fn test_drift_executor_psi() {
        let db_pool = fresh_pool().await;

        run_script(&db_pool, &psi_script(1000, "feature_1", 10, true)).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let mut drift_executor = DriftExecutor::new(&db_pool);
        drift_executor.poll_for_tasks().await.unwrap();

        let alerts = PostgresClient::get_drift_alerts(&db_pool, &alert_request("scouter", "model"))
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);
    }

    #[tokio::test]
    async fn test_drift_executor_psi_not_enough_target_samples() {
        let db_pool = fresh_pool().await;

        run_script(&db_pool, &psi_script(2, "feature_1", 1, false)).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let mut drift_executor = DriftExecutor::new(&db_pool);
        drift_executor.poll_for_tasks().await.unwrap();

        let alerts = PostgresClient::get_drift_alerts(&db_pool, &alert_request("scouter", "model"))
            .await
            .unwrap();
        assert!(alerts.is_empty());
    }

    #[tokio::test]
    async fn test_drift_executor_custom() {
        let db_pool = fresh_pool().await;

        run_script(&db_pool, &read_script("populate_custom.sql")).await;
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let mut drift_executor = DriftExecutor::new(&db_pool);
        drift_executor.poll_for_tasks().await.unwrap();

        let alerts = PostgresClient::get_drift_alerts(&db_pool, &alert_request("scouter", "model"))
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);
    }
}
556}