1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4 use crate::error::DriftError;
5 use crate::{custom::CustomDrifter, llm::LLMDrifter, psi::PsiDrifter, spc::SpcDrifter};
6 use chrono::{DateTime, Utc};
7
8 use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9 use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10 use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11 use sqlx::{Pool, Postgres};
12 use std::collections::BTreeMap;
13 use std::result::Result;
14 use std::result::Result::Ok;
15 use std::str::FromStr;
16 use tracing::{debug, error, info, instrument, span, Instrument, Level};
17
18 #[allow(clippy::enum_variant_names)]
19 pub enum Drifter {
20 SpcDrifter(SpcDrifter),
21 PsiDrifter(PsiDrifter),
22 CustomDrifter(CustomDrifter),
23 LLMDrifter(LLMDrifter),
24 }
25
26 impl Drifter {
27 pub async fn check_for_alerts(
28 &self,
29 db_pool: &Pool<Postgres>,
30 previous_run: DateTime<Utc>,
31 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
32 match self {
33 Drifter::SpcDrifter(drifter) => {
34 drifter.check_for_alerts(db_pool, previous_run).await
35 }
36 Drifter::PsiDrifter(drifter) => {
37 drifter.check_for_alerts(db_pool, previous_run).await
38 }
39 Drifter::CustomDrifter(drifter) => {
40 drifter.check_for_alerts(db_pool, previous_run).await
41 }
42 Drifter::LLMDrifter(drifter) => {
43 drifter.check_for_alerts(db_pool, previous_run).await
44 }
45 }
46 }
47 }
48
49 pub trait GetDrifter {
50 fn get_drifter(&self) -> Drifter;
51 }
52
53 impl GetDrifter for DriftProfile {
54 fn get_drifter(&self) -> Drifter {
66 match self {
67 DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
68 DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
69 DriftProfile::Custom(profile) => {
70 Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
71 }
72 DriftProfile::LLM(profile) => Drifter::LLMDrifter(LLMDrifter::new(profile.clone())),
73 }
74 }
75 }
76
77 pub struct DriftExecutor {
78 db_pool: Pool<Postgres>,
79 }
80
81 impl DriftExecutor {
82 pub fn new(db_pool: &Pool<Postgres>) -> Self {
83 Self {
84 db_pool: db_pool.clone(),
85 }
86 }
87
88 pub async fn _process_task(
100 &mut self,
101 profile: DriftProfile,
102 previous_run: DateTime<Utc>,
103 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
104 profile
107 .get_drifter()
108 .check_for_alerts(&self.db_pool, previous_run)
109 .await
110 }
111
112 async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
113 debug!("Polling for drift tasks");
114
115 let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
117
118 let Some(task) = task else {
119 return Ok(None);
120 };
121
122 let task_info = DriftTaskInfo {
123 space: task.space.clone(),
124 name: task.name.clone(),
125 version: task.version.clone(),
126 uid: task.uid.clone(),
127 drift_type: DriftType::from_str(&task.drift_type).unwrap(),
128 };
129
130 info!(
131 "Processing drift task for profile: {}/{}/{} and type {}",
132 task.space, task.name, task.version, task.drift_type
133 );
134
135 self.process_task(&task, &task_info).await?;
136
137 PostgresClient::update_drift_profile_run_dates(
139 &self.db_pool,
140 &task_info,
141 &task.schedule,
142 )
143 .instrument(span!(Level::INFO, "Update Run Dates"))
144 .await?;
145
146 Ok(Some(task))
147 }
148
149 #[instrument(skip_all)]
150 async fn process_task(
151 &mut self,
152 task: &TaskRequest,
153 task_info: &DriftTaskInfo,
154 ) -> Result<(), DriftError> {
155 let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
157 error!("Error converting drift type: {:?}", e);
158 })?;
159
160 let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
162 .inspect_err(|e| {
163 error!(
164 "Error converting drift profile for type {:?}: {:?}",
165 drift_type, e
166 );
167 })?;
168
169 match self._process_task(profile, task.previous_run).await {
171 Ok(Some(alerts)) => {
172 info!("Drift task processed successfully with alerts");
173
174 for alert in alerts {
176 PostgresClient::insert_drift_alert(
177 &self.db_pool,
178 task_info,
179 alert.get("entity_name").unwrap_or(&"NA".to_string()),
180 &alert,
181 &drift_type,
182 )
183 .await
184 .map_err(|e| {
185 error!("Error inserting drift alert: {:?}", e);
186 DriftError::SqlError(e)
187 })?;
188 }
189 Ok(())
190 }
191 Ok(None) => {
192 info!("Drift task processed successfully with no alerts");
193 Ok(())
194 }
195 Err(e) => {
196 error!("Error processing drift task: {:?}", e);
197 Err(DriftError::AlertProcessingError(e.to_string()))
198 }
199 }
200 }
201
202 #[instrument(skip_all)]
208 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
209 match self.do_poll().await? {
210 Some(_) => {
211 info!("Successfully processed drift task");
212 Ok(())
213 }
214 None => {
215 info!("No triggered schedules found in db. Sleeping for 10 seconds");
216 tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
217 Ok(())
218 }
219 }
220 }
221 }
222
223 #[cfg(test)]
224 mod tests {
225 use super::*;
226 use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
227 use scouter_settings::DatabaseSettings;
228 use scouter_sql::PostgresClient;
229 use scouter_types::DriftAlertRequest;
230 use sqlx::{postgres::Postgres, Pool};
231
232 pub async fn cleanup(pool: &Pool<Postgres>) {
233 sqlx::raw_sql(
234 r#"
235 DELETE
236 FROM scouter.spc_drift;
237
238 DELETE
239 FROM scouter.observability_metric;
240
241 DELETE
242 FROM scouter.custom_drift;
243
244 DELETE
245 FROM scouter.drift_alert;
246
247 DELETE
248 FROM scouter.drift_profile;
249
250 DELETE
251 FROM scouter.psi_drift;
252 "#,
253 )
254 .fetch_all(pool)
255 .await
256 .unwrap();
257
258 RustyLogger::setup_logging(Some(LoggingConfig::new(
259 None,
260 Some(LogLevel::Info),
261 None,
262 None,
263 )))
264 .unwrap();
265 }
266
267 #[tokio::test]
268 async fn test_drift_executor_spc() {
269 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
270 .await
271 .unwrap();
272
273 cleanup(&db_pool).await;
274
275 let mut populate_path =
276 std::env::current_dir().expect("Failed to get current directory");
277 populate_path.push("src/scripts/populate_spc.sql");
278
279 let script = std::fs::read_to_string(populate_path).unwrap();
280 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
281
282 let mut drift_executor = DriftExecutor::new(&db_pool);
283
284 drift_executor.poll_for_tasks().await.unwrap();
285
286 let request = DriftAlertRequest {
288 space: "statworld".to_string(),
289 name: "test_app".to_string(),
290 version: "0.1.0".to_string(),
291 limit_datetime: None,
292 active: None,
293 limit: None,
294 };
295 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
296 .await
297 .unwrap();
298 assert!(!alerts.is_empty());
299 }
300
301 #[tokio::test]
302 async fn test_drift_executor_spc_missing_feature_data() {
303 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
307 .await
308 .unwrap();
309 cleanup(&db_pool).await;
310
311 let mut populate_path =
312 std::env::current_dir().expect("Failed to get current directory");
313 populate_path.push("src/scripts/populate_spc_alert.sql");
314
315 let script = std::fs::read_to_string(populate_path).unwrap();
316 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
317
318 let mut drift_executor = DriftExecutor::new(&db_pool);
319
320 drift_executor.poll_for_tasks().await.unwrap();
321
322 let request = DriftAlertRequest {
324 space: "statworld".to_string(),
325 name: "test_app".to_string(),
326 version: "0.1.0".to_string(),
327 limit_datetime: None,
328 active: None,
329 limit: None,
330 };
331 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
332 .await
333 .unwrap();
334
335 assert!(!alerts.is_empty());
336 }
337
338 #[tokio::test]
339 async fn test_drift_executor_psi() {
340 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
341 .await
342 .unwrap();
343
344 cleanup(&db_pool).await;
345
346 let mut populate_path =
347 std::env::current_dir().expect("Failed to get current directory");
348 populate_path.push("src/scripts/populate_psi.sql");
349
350 let mut script = std::fs::read_to_string(populate_path).unwrap();
351 let bin_count = 1000;
352 let skew_feature = "feature_1";
353 let skew_factor = 10;
354 let apply_skew = true;
355 script = script.replace("{{bin_count}}", &bin_count.to_string());
356 script = script.replace("{{skew_feature}}", skew_feature);
357 script = script.replace("{{skew_factor}}", &skew_factor.to_string());
358 script = script.replace("{{apply_skew}}", &apply_skew.to_string());
359 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
360
361 let mut drift_executor = DriftExecutor::new(&db_pool);
362
363 drift_executor.poll_for_tasks().await.unwrap();
364
365 let request = DriftAlertRequest {
367 space: "scouter".to_string(),
368 name: "model".to_string(),
369 version: "0.1.0".to_string(),
370 limit_datetime: None,
371 active: None,
372 limit: None,
373 };
374 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
375 .await
376 .unwrap();
377
378 assert_eq!(alerts.len(), 1);
379 }
380
381 #[tokio::test]
394 async fn test_drift_executor_psi_not_enough_target_samples() {
395 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
396 .await
397 .unwrap();
398
399 cleanup(&db_pool).await;
400
401 let mut populate_path =
402 std::env::current_dir().expect("Failed to get current directory");
403 populate_path.push("src/scripts/populate_psi.sql");
404
405 let mut script = std::fs::read_to_string(populate_path).unwrap();
406 let bin_count = 2;
407 let skew_feature = "feature_1";
408 let skew_factor = 1;
409 let apply_skew = false;
410 script = script.replace("{{bin_count}}", &bin_count.to_string());
411 script = script.replace("{{skew_feature}}", skew_feature);
412 script = script.replace("{{skew_factor}}", &skew_factor.to_string());
413 script = script.replace("{{apply_skew}}", &apply_skew.to_string());
414 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
415
416 let mut drift_executor = DriftExecutor::new(&db_pool);
417
418 drift_executor.poll_for_tasks().await.unwrap();
419
420 let request = DriftAlertRequest {
422 space: "scouter".to_string(),
423 name: "model".to_string(),
424 version: "0.1.0".to_string(),
425 limit_datetime: None,
426 active: None,
427 limit: None,
428 };
429 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
430 .await
431 .unwrap();
432
433 assert!(alerts.is_empty());
434 }
435
436 #[tokio::test]
437 async fn test_drift_executor_custom() {
438 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
439 .await
440 .unwrap();
441
442 cleanup(&db_pool).await;
443
444 let mut populate_path =
445 std::env::current_dir().expect("Failed to get current directory");
446 populate_path.push("src/scripts/populate_custom.sql");
447
448 let script = std::fs::read_to_string(populate_path).unwrap();
449 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
450
451 let mut drift_executor = DriftExecutor::new(&db_pool);
452
453 drift_executor.poll_for_tasks().await.unwrap();
454
455 let request = DriftAlertRequest {
457 space: "scouter".to_string(),
458 name: "model".to_string(),
459 version: "0.1.0".to_string(),
460 limit_datetime: None,
461 active: None,
462 limit: None,
463 };
464 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
465 .await
466 .unwrap();
467
468 assert_eq!(alerts.len(), 1);
469 }
470 }
471}