scouter_drift/
drifter.rs

1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4    use crate::error::DriftError;
5    use crate::{custom::CustomDrifter, psi::PsiDrifter, spc::SpcDrifter};
6    use chrono::{DateTime, Utc};
7
8    use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9    use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10    use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11    use sqlx::{Pool, Postgres};
12    use std::collections::BTreeMap;
13    use std::result::Result;
14    use std::result::Result::Ok;
15    use std::str::FromStr;
16    use tracing::{debug, error, info, span, Instrument, Level};
17
    /// Wrapper enum that dispatches drift-alert checks to the
    /// profile-type-specific drifter implementations (SPC, PSI, custom).
    #[allow(clippy::enum_variant_names)] // variants intentionally share the `Drifter` suffix
    pub enum Drifter {
        SpcDrifter(SpcDrifter),
        PsiDrifter(PsiDrifter),
        CustomDrifter(CustomDrifter),
    }
24
25    impl Drifter {
26        pub async fn check_for_alerts(
27            &self,
28            db_pool: &Pool<Postgres>,
29            previous_run: DateTime<Utc>,
30        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
31            match self {
32                Drifter::SpcDrifter(drifter) => {
33                    drifter.check_for_alerts(db_pool, previous_run).await
34                }
35                Drifter::PsiDrifter(drifter) => {
36                    drifter.check_for_alerts(db_pool, previous_run).await
37                }
38                Drifter::CustomDrifter(drifter) => {
39                    drifter.check_for_alerts(db_pool, previous_run).await
40                }
41            }
42        }
43    }
44
    /// Conversion from a drift profile into the matching `Drifter` variant.
    pub trait GetDrifter {
        // Build a `Drifter` wrapping the profile-specific drifter.
        fn get_drifter(&self) -> Drifter;
    }
48
    impl GetDrifter for DriftProfile {
        /// Get a Drifter for processing drift profile tasks
        ///
        /// Maps each `DriftProfile` variant to the corresponding `Drifter`
        /// variant, cloning the inner profile into the new drifter.
        ///
        /// # Returns
        ///
        /// * `Drifter` - Drifter enum wrapping the profile-specific drifter
        fn get_drifter(&self) -> Drifter {
            match self {
                DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
                DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
                DriftProfile::Custom(profile) => {
                    Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
                }
            }
        }
    }
71
    /// Polls the database for scheduled drift tasks and processes them.
    pub struct DriftExecutor {
        // Shared Postgres connection pool used for all task/alert queries.
        db_pool: Pool<Postgres>,
    }
75
76    impl DriftExecutor {
77        pub fn new(db_pool: &Pool<Postgres>) -> Self {
78            Self {
79                db_pool: db_pool.clone(),
80            }
81        }
82
83        /// Process a single drift computation task
84        ///
85        /// # Arguments
86        ///
87        /// * `drift_profile` - Drift profile to compute drift for
88        /// * `previous_run` - Previous run timestamp
89        /// * `schedule` - Schedule for drift computation
90        /// * `transaction` - Postgres transaction
91        ///
92        /// # Returns
93        ///
94        pub async fn _process_task(
95            &mut self,
96            profile: DriftProfile,
97            previous_run: DateTime<Utc>,
98        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
99            // match Drifter enum
100
101            profile
102                .get_drifter()
103                .check_for_alerts(&self.db_pool, previous_run)
104                .await
105        }
106
107        async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
108            debug!("Polling for drift tasks");
109
110            // Get task from the database (query uses skip lock to pull task and update to processing)
111            let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
112
113            let Some(task) = task else {
114                return Ok(None);
115            };
116
117            let task_info = DriftTaskInfo {
118                space: task.space.clone(),
119                name: task.name.clone(),
120                version: task.version.clone(),
121                uid: task.uid.clone(),
122                drift_type: DriftType::from_str(&task.drift_type).unwrap(),
123            };
124
125            info!(
126                "Processing drift task for profile: {}/{}/{} and type {}",
127                task.space, task.name, task.version, task.drift_type
128            );
129
130            self.process_task(&task, &task_info).await?;
131
132            // Update the run dates while still holding the lock
133            PostgresClient::update_drift_profile_run_dates(
134                &self.db_pool,
135                &task_info,
136                &task.schedule,
137            )
138            .instrument(span!(Level::INFO, "Update Run Dates"))
139            .await?;
140
141            Ok(Some(task))
142        }
143
144        async fn process_task(
145            &mut self,
146            task: &TaskRequest,
147            task_info: &DriftTaskInfo,
148        ) -> Result<(), DriftError> {
149            // get the drift type
150            let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
151                error!("Error converting drift type: {:?}", e);
152            })?;
153
154            // get the drift profile
155            let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
156                .inspect_err(|e| {
157                    error!("Error converting drift profile: {:?}", e);
158                })?;
159
160            // check for alerts
161            match self._process_task(profile, task.previous_run).await {
162                Ok(Some(alerts)) => {
163                    info!("Drift task processed successfully with alerts");
164
165                    // Insert alerts atomically within the same transaction
166                    for alert in alerts {
167                        PostgresClient::insert_drift_alert(
168                            &self.db_pool,
169                            task_info,
170                            alert.get("entity_name").unwrap_or(&"NA".to_string()),
171                            &alert,
172                            &drift_type,
173                        )
174                        .await
175                        .map_err(|e| {
176                            error!("Error inserting drift alert: {:?}", e);
177                            DriftError::SqlError(e)
178                        })?;
179                    }
180                    Ok(())
181                }
182                Ok(None) => {
183                    info!("Drift task processed successfully with no alerts");
184                    Ok(())
185                }
186                Err(e) => {
187                    error!("Error processing drift task: {:?}", e);
188                    Err(DriftError::AlertProcessingError(e.to_string()))
189                }
190            }
191        }
192
193        /// Execute single drift computation and alerting
194        ///
195        /// # Returns
196        ///
197        /// * `Result<()>` - Result of drift computation and alerting
198        pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
199            match self.do_poll().await? {
200                Some(_) => {
201                    info!("Successfully processed drift task");
202                    Ok(())
203                }
204                None => {
205                    info!("No triggered schedules found in db. Sleeping for 10 seconds");
206                    tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
207                    Ok(())
208                }
209            }
210        }
211    }
212
213    #[cfg(test)]
214    mod tests {
215        use super::*;
216        use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
217        use scouter_settings::DatabaseSettings;
218        use scouter_sql::PostgresClient;
219        use scouter_types::DriftAlertRequest;
220        use sqlx::{postgres::Postgres, Pool};
221
222        pub async fn cleanup(pool: &Pool<Postgres>) {
223            sqlx::raw_sql(
224                r#"
225                DELETE 
226                FROM scouter.spc_drift;
227
228                DELETE 
229                FROM scouter.observability_metric;
230
231                DELETE
232                FROM scouter.custom_drift;
233
234                DELETE
235                FROM scouter.drift_alert;
236
237                DELETE
238                FROM scouter.drift_profile;
239
240                DELETE
241                FROM scouter.psi_drift;
242                "#,
243            )
244            .fetch_all(pool)
245            .await
246            .unwrap();
247
248            RustyLogger::setup_logging(Some(LoggingConfig::new(
249                None,
250                Some(LogLevel::Info),
251                None,
252                None,
253            )))
254            .unwrap();
255        }
256
257        #[tokio::test]
258        async fn test_drift_executor_spc() {
259            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
260                .await
261                .unwrap();
262
263            cleanup(&db_pool).await;
264
265            let mut populate_path =
266                std::env::current_dir().expect("Failed to get current directory");
267            populate_path.push("src/scripts/populate_spc.sql");
268
269            let script = std::fs::read_to_string(populate_path).unwrap();
270            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
271
272            let mut drift_executor = DriftExecutor::new(&db_pool);
273
274            drift_executor.poll_for_tasks().await.unwrap();
275
276            // get alerts from db
277            let request = DriftAlertRequest {
278                space: "statworld".to_string(),
279                name: "test_app".to_string(),
280                version: "0.1.0".to_string(),
281                limit_datetime: None,
282                active: None,
283                limit: None,
284            };
285            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
286                .await
287                .unwrap();
288            assert!(!alerts.is_empty());
289        }
290
291        #[tokio::test]
292        async fn test_drift_executor_spc_missing_feature_data() {
293            // this tests the scenario where only 1 of 2 features has data in the db when polling
294            // for tasks. Need to ensure this does not fail and the present feature and data are
295            // still processed
296            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
297                .await
298                .unwrap();
299            cleanup(&db_pool).await;
300
301            let mut populate_path =
302                std::env::current_dir().expect("Failed to get current directory");
303            populate_path.push("src/scripts/populate_spc_alert.sql");
304
305            let script = std::fs::read_to_string(populate_path).unwrap();
306            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
307
308            let mut drift_executor = DriftExecutor::new(&db_pool);
309
310            drift_executor.poll_for_tasks().await.unwrap();
311
312            // get alerts from db
313            let request = DriftAlertRequest {
314                space: "statworld".to_string(),
315                name: "test_app".to_string(),
316                version: "0.1.0".to_string(),
317                limit_datetime: None,
318                active: None,
319                limit: None,
320            };
321            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
322                .await
323                .unwrap();
324
325            assert!(!alerts.is_empty());
326        }
327
328        #[tokio::test]
329        async fn test_drift_executor_psi() {
330            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
331                .await
332                .unwrap();
333
334            cleanup(&db_pool).await;
335
336            let mut populate_path =
337                std::env::current_dir().expect("Failed to get current directory");
338            populate_path.push("src/scripts/populate_psi.sql");
339
340            let mut script = std::fs::read_to_string(populate_path).unwrap();
341            let bin_count = 1000;
342            let skew_feature = "feature_1";
343            let skew_factor = 10;
344            let apply_skew = true;
345            script = script.replace("{{bin_count}}", &bin_count.to_string());
346            script = script.replace("{{skew_feature}}", skew_feature);
347            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
348            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
349            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
350
351            let mut drift_executor = DriftExecutor::new(&db_pool);
352
353            drift_executor.poll_for_tasks().await.unwrap();
354
355            // get alerts from db
356            let request = DriftAlertRequest {
357                space: "scouter".to_string(),
358                name: "model".to_string(),
359                version: "0.1.0".to_string(),
360                limit_datetime: None,
361                active: None,
362                limit: None,
363            };
364            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
365                .await
366                .unwrap();
367
368            assert_eq!(alerts.len(), 1);
369        }
370
371        /// This test verifies that the PSI drift executor does **not** generate any drift alerts
372        /// when there are **not enough target samples** to meet the minimum threshold required
373        /// for PSI calculation.
374        ///
375        /// This arg determines how many bin counts to simulate for a production environment.
376        /// In the script there are 3 features, each with 10 bins.
377        /// `bin_count = 2` means we simulate 2 observations per bin.
378        /// So for each feature: 10 bins * 2 samples = 20 samples inserted PER insert.
379        /// Since the script inserts each feature's data 3 times (simulating 3 production batches),
380        /// each feature ends up with: 20 samples * 3 = 60 samples total.
381        /// This is below the required threshold of >100 samples per feature for PSI calculation,
382        /// so no drift alert should be generated.
383        #[tokio::test]
384        async fn test_drift_executor_psi_not_enough_target_samples() {
385            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
386                .await
387                .unwrap();
388
389            cleanup(&db_pool).await;
390
391            let mut populate_path =
392                std::env::current_dir().expect("Failed to get current directory");
393            populate_path.push("src/scripts/populate_psi.sql");
394
395            let mut script = std::fs::read_to_string(populate_path).unwrap();
396            let bin_count = 2;
397            let skew_feature = "feature_1";
398            let skew_factor = 1;
399            let apply_skew = false;
400            script = script.replace("{{bin_count}}", &bin_count.to_string());
401            script = script.replace("{{skew_feature}}", skew_feature);
402            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
403            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
404            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
405
406            let mut drift_executor = DriftExecutor::new(&db_pool);
407
408            drift_executor.poll_for_tasks().await.unwrap();
409
410            // get alerts from db
411            let request = DriftAlertRequest {
412                space: "scouter".to_string(),
413                name: "model".to_string(),
414                version: "0.1.0".to_string(),
415                limit_datetime: None,
416                active: None,
417                limit: None,
418            };
419            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
420                .await
421                .unwrap();
422
423            assert!(alerts.is_empty());
424        }
425
426        #[tokio::test]
427        async fn test_drift_executor_custom() {
428            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
429                .await
430                .unwrap();
431
432            cleanup(&db_pool).await;
433
434            let mut populate_path =
435                std::env::current_dir().expect("Failed to get current directory");
436            populate_path.push("src/scripts/populate_custom.sql");
437
438            let script = std::fs::read_to_string(populate_path).unwrap();
439            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
440
441            let mut drift_executor = DriftExecutor::new(&db_pool);
442
443            drift_executor.poll_for_tasks().await.unwrap();
444
445            // get alerts from db
446            let request = DriftAlertRequest {
447                space: "scouter".to_string(),
448                name: "model".to_string(),
449                version: "0.1.0".to_string(),
450                limit_datetime: None,
451                active: None,
452                limit: None,
453            };
454            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
455                .await
456                .unwrap();
457
458            assert_eq!(alerts.len(), 1);
459        }
460    }
461}