1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4 use crate::error::DriftError;
5 use crate::{custom::CustomDrifter, psi::PsiDrifter, spc::SpcDrifter};
6 use chrono::{DateTime, Utc};
7
8 use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9 use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10 use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11 use sqlx::{Pool, Postgres};
12 use std::collections::BTreeMap;
13 use std::result::Result;
14 use std::result::Result::Ok;
15 use std::str::FromStr;
16 use tracing::{debug, error, info, span, Instrument, Level};
17
/// Runtime wrapper over the concrete drift detectors, one variant per detector type.
///
/// A value is normally obtained through `GetDrifter::get_drifter`, which picks
/// the variant matching a `DriftProfile`.
// Every variant name repeats the `Drifter` suffix of the type it wraps, which is
// what clippy's `enum_variant_names` lint complains about; silenced on purpose.
#[allow(clippy::enum_variant_names)]
pub enum Drifter {
    /// Holds an `SpcDrifter`.
    SpcDrifter(SpcDrifter),
    /// Holds a `PsiDrifter`.
    PsiDrifter(PsiDrifter),
    /// Holds a `CustomDrifter`.
    CustomDrifter(CustomDrifter),
}
24
25 impl Drifter {
26 pub async fn check_for_alerts(
27 &self,
28 db_pool: &Pool<Postgres>,
29 previous_run: DateTime<Utc>,
30 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
31 match self {
32 Drifter::SpcDrifter(drifter) => {
33 drifter.check_for_alerts(db_pool, previous_run).await
34 }
35 Drifter::PsiDrifter(drifter) => {
36 drifter.check_for_alerts(db_pool, previous_run).await
37 }
38 Drifter::CustomDrifter(drifter) => {
39 drifter.check_for_alerts(db_pool, previous_run).await
40 }
41 }
42 }
43 }
44
/// Implemented by types that can produce the [`Drifter`] able to evaluate them.
pub trait GetDrifter {
    /// Returns the `Drifter` corresponding to `self`.
    fn get_drifter(&self) -> Drifter;
}
48
49 impl GetDrifter for DriftProfile {
50 fn get_drifter(&self) -> Drifter {
62 match self {
63 DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
64 DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
65 DriftProfile::Custom(profile) => {
66 Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
67 }
68 }
69 }
70 }
71
/// Polls the database for scheduled drift tasks and executes them.
pub struct DriftExecutor {
    /// Postgres connection pool passed to every query the executor issues.
    db_pool: Pool<Postgres>,
}
75
76 impl DriftExecutor {
77 pub fn new(db_pool: &Pool<Postgres>) -> Self {
78 Self {
79 db_pool: db_pool.clone(),
80 }
81 }
82
83 pub async fn _process_task(
95 &mut self,
96 profile: DriftProfile,
97 previous_run: DateTime<Utc>,
98 ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
99 profile
102 .get_drifter()
103 .check_for_alerts(&self.db_pool, previous_run)
104 .await
105 }
106
107 async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
108 debug!("Polling for drift tasks");
109
110 let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
112
113 let Some(task) = task else {
114 return Ok(None);
115 };
116
117 let task_info = DriftTaskInfo {
118 space: task.space.clone(),
119 name: task.name.clone(),
120 version: task.version.clone(),
121 uid: task.uid.clone(),
122 drift_type: DriftType::from_str(&task.drift_type).unwrap(),
123 };
124
125 info!(
126 "Processing drift task for profile: {}/{}/{} and type {}",
127 task.space, task.name, task.version, task.drift_type
128 );
129
130 self.process_task(&task, &task_info).await?;
131
132 PostgresClient::update_drift_profile_run_dates(
134 &self.db_pool,
135 &task_info,
136 &task.schedule,
137 )
138 .instrument(span!(Level::INFO, "Update Run Dates"))
139 .await?;
140
141 Ok(Some(task))
142 }
143
144 async fn process_task(
145 &mut self,
146 task: &TaskRequest,
147 task_info: &DriftTaskInfo,
148 ) -> Result<(), DriftError> {
149 let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
151 error!("Error converting drift type: {:?}", e);
152 })?;
153
154 let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
156 .inspect_err(|e| {
157 error!(
158 "Error converting drift profile for type {:?}: {:?}",
159 drift_type, e
160 );
161 })?;
162
163 match self._process_task(profile, task.previous_run).await {
165 Ok(Some(alerts)) => {
166 info!("Drift task processed successfully with alerts");
167
168 for alert in alerts {
170 PostgresClient::insert_drift_alert(
171 &self.db_pool,
172 task_info,
173 alert.get("entity_name").unwrap_or(&"NA".to_string()),
174 &alert,
175 &drift_type,
176 )
177 .await
178 .map_err(|e| {
179 error!("Error inserting drift alert: {:?}", e);
180 DriftError::SqlError(e)
181 })?;
182 }
183 Ok(())
184 }
185 Ok(None) => {
186 info!("Drift task processed successfully with no alerts");
187 Ok(())
188 }
189 Err(e) => {
190 error!("Error processing drift task: {:?}", e);
191 Err(DriftError::AlertProcessingError(e.to_string()))
192 }
193 }
194 }
195
196 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
202 match self.do_poll().await? {
203 Some(_) => {
204 info!("Successfully processed drift task");
205 Ok(())
206 }
207 None => {
208 info!("No triggered schedules found in db. Sleeping for 10 seconds");
209 tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
210 Ok(())
211 }
212 }
213 }
214 }
215
216 #[cfg(test)]
217 mod tests {
218 use super::*;
219 use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
220 use scouter_settings::DatabaseSettings;
221 use scouter_sql::PostgresClient;
222 use scouter_types::DriftAlertRequest;
223 use sqlx::{postgres::Postgres, Pool};
224
/// Resets shared test state: empties the drift-related tables and sets up logging.
///
/// Deletes every row from `scouter.spc_drift`, `scouter.observability_metric`,
/// `scouter.custom_drift`, `scouter.drift_alert`, `scouter.drift_profile` and
/// `scouter.psi_drift`, then configures `RustyLogger` at `LogLevel::Info`.
///
/// # Panics
/// Panics if the SQL batch fails or if `setup_logging` returns an error.
// NOTE(review): every test calls this against the same database, so the tests
// presumably need to run serially (e.g. `--test-threads=1`) — confirm.
pub async fn cleanup(pool: &Pool<Postgres>) {
    sqlx::raw_sql(
        r#"
        DELETE
        FROM scouter.spc_drift;

        DELETE
        FROM scouter.observability_metric;

        DELETE
        FROM scouter.custom_drift;

        DELETE
        FROM scouter.drift_alert;

        DELETE
        FROM scouter.drift_profile;

        DELETE
        FROM scouter.psi_drift;
        "#,
    )
    .fetch_all(pool)
    .await
    .unwrap();

    // Called by every test; assumes repeated `setup_logging` calls succeed —
    // TODO confirm against rusty_logging.
    RustyLogger::setup_logging(Some(LoggingConfig::new(
        None,
        Some(LogLevel::Info),
        None,
        None,
    )))
    .unwrap();
}
259
260 #[tokio::test]
261 async fn test_drift_executor_spc() {
262 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
263 .await
264 .unwrap();
265
266 cleanup(&db_pool).await;
267
268 let mut populate_path =
269 std::env::current_dir().expect("Failed to get current directory");
270 populate_path.push("src/scripts/populate_spc.sql");
271
272 let script = std::fs::read_to_string(populate_path).unwrap();
273 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
274
275 let mut drift_executor = DriftExecutor::new(&db_pool);
276
277 drift_executor.poll_for_tasks().await.unwrap();
278
279 let request = DriftAlertRequest {
281 space: "statworld".to_string(),
282 name: "test_app".to_string(),
283 version: "0.1.0".to_string(),
284 limit_datetime: None,
285 active: None,
286 limit: None,
287 };
288 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
289 .await
290 .unwrap();
291 assert!(!alerts.is_empty());
292 }
293
294 #[tokio::test]
295 async fn test_drift_executor_spc_missing_feature_data() {
296 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
300 .await
301 .unwrap();
302 cleanup(&db_pool).await;
303
304 let mut populate_path =
305 std::env::current_dir().expect("Failed to get current directory");
306 populate_path.push("src/scripts/populate_spc_alert.sql");
307
308 let script = std::fs::read_to_string(populate_path).unwrap();
309 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
310
311 let mut drift_executor = DriftExecutor::new(&db_pool);
312
313 drift_executor.poll_for_tasks().await.unwrap();
314
315 let request = DriftAlertRequest {
317 space: "statworld".to_string(),
318 name: "test_app".to_string(),
319 version: "0.1.0".to_string(),
320 limit_datetime: None,
321 active: None,
322 limit: None,
323 };
324 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
325 .await
326 .unwrap();
327
328 assert!(!alerts.is_empty());
329 }
330
331 #[tokio::test]
332 async fn test_drift_executor_psi() {
333 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
334 .await
335 .unwrap();
336
337 cleanup(&db_pool).await;
338
339 let mut populate_path =
340 std::env::current_dir().expect("Failed to get current directory");
341 populate_path.push("src/scripts/populate_psi.sql");
342
343 let mut script = std::fs::read_to_string(populate_path).unwrap();
344 let bin_count = 1000;
345 let skew_feature = "feature_1";
346 let skew_factor = 10;
347 let apply_skew = true;
348 script = script.replace("{{bin_count}}", &bin_count.to_string());
349 script = script.replace("{{skew_feature}}", skew_feature);
350 script = script.replace("{{skew_factor}}", &skew_factor.to_string());
351 script = script.replace("{{apply_skew}}", &apply_skew.to_string());
352 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
353
354 let mut drift_executor = DriftExecutor::new(&db_pool);
355
356 drift_executor.poll_for_tasks().await.unwrap();
357
358 let request = DriftAlertRequest {
360 space: "scouter".to_string(),
361 name: "model".to_string(),
362 version: "0.1.0".to_string(),
363 limit_datetime: None,
364 active: None,
365 limit: None,
366 };
367 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
368 .await
369 .unwrap();
370
371 assert_eq!(alerts.len(), 1);
372 }
373
374 #[tokio::test]
387 async fn test_drift_executor_psi_not_enough_target_samples() {
388 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
389 .await
390 .unwrap();
391
392 cleanup(&db_pool).await;
393
394 let mut populate_path =
395 std::env::current_dir().expect("Failed to get current directory");
396 populate_path.push("src/scripts/populate_psi.sql");
397
398 let mut script = std::fs::read_to_string(populate_path).unwrap();
399 let bin_count = 2;
400 let skew_feature = "feature_1";
401 let skew_factor = 1;
402 let apply_skew = false;
403 script = script.replace("{{bin_count}}", &bin_count.to_string());
404 script = script.replace("{{skew_feature}}", skew_feature);
405 script = script.replace("{{skew_factor}}", &skew_factor.to_string());
406 script = script.replace("{{apply_skew}}", &apply_skew.to_string());
407 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
408
409 let mut drift_executor = DriftExecutor::new(&db_pool);
410
411 drift_executor.poll_for_tasks().await.unwrap();
412
413 let request = DriftAlertRequest {
415 space: "scouter".to_string(),
416 name: "model".to_string(),
417 version: "0.1.0".to_string(),
418 limit_datetime: None,
419 active: None,
420 limit: None,
421 };
422 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
423 .await
424 .unwrap();
425
426 assert!(alerts.is_empty());
427 }
428
429 #[tokio::test]
430 async fn test_drift_executor_custom() {
431 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
432 .await
433 .unwrap();
434
435 cleanup(&db_pool).await;
436
437 let mut populate_path =
438 std::env::current_dir().expect("Failed to get current directory");
439 populate_path.push("src/scripts/populate_custom.sql");
440
441 let script = std::fs::read_to_string(populate_path).unwrap();
442 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
443
444 let mut drift_executor = DriftExecutor::new(&db_pool);
445
446 drift_executor.poll_for_tasks().await.unwrap();
447
448 let request = DriftAlertRequest {
450 space: "scouter".to_string(),
451 name: "model".to_string(),
452 version: "0.1.0".to_string(),
453 limit_datetime: None,
454 active: None,
455 limit: None,
456 };
457 let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
458 .await
459 .unwrap();
460
461 assert_eq!(alerts.len(), 1);
462 }
463 }
464}