scouter_drift/
drifter.rs

#[cfg(feature = "sql")]
pub mod drift_executor {

    use crate::error::DriftError;
    use crate::{custom::CustomDrifter, psi::PsiDrifter, spc::SpcDrifter};
    use chrono::{DateTime, Utc};

    use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
    use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
    use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
    use sqlx::{Pool, Postgres};
    use std::collections::BTreeMap;
    use std::result::Result;
    use std::result::Result::Ok;
    use std::str::FromStr;
    use tracing::{debug, error, info, span, Instrument, Level};

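    /// Dispatch wrapper over the supported drifter implementations.
    ///
    /// Each variant owns the concrete drifter for its drift type (SPC, PSI, or
    /// custom metric) and is produced from a [`DriftProfile`] via
    /// [`GetDrifter::get_drifter`].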
    #[allow(clippy::enum_variant_names)]
    pub enum Drifter {
        SpcDrifter(SpcDrifter),
        PsiDrifter(PsiDrifter),
        CustomDrifter(CustomDrifter),
    }

    impl Drifter {
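        /// Check for alerts since `previous_run`, delegating to the wrapped
        /// drifter implementation for this profile's drift type.
        ///
        /// Returns `Ok(Some(alerts))` when alerts were generated and `Ok(None)`
        /// when none were found.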
        pub async fn check_for_alerts(
            &self,
            db_pool: &Pool<Postgres>,
            previous_run: DateTime<Utc>,
        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
            match self {
                Drifter::SpcDrifter(drifter) => {
                    drifter.check_for_alerts(db_pool, previous_run).await
                }
                Drifter::PsiDrifter(drifter) => {
                    drifter.check_for_alerts(db_pool, previous_run).await
                }
                Drifter::CustomDrifter(drifter) => {
                    drifter.check_for_alerts(db_pool, previous_run).await
                }
            }
        }
    }

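    /// Conversion from a concrete drift profile into the matching [`Drifter`] variant.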
    pub trait GetDrifter {
        fn get_drifter(&self) -> Drifter;
    }

    impl GetDrifter for DriftProfile {
        /// Get a `Drifter` for processing drift profile tasks.
        ///
        /// # Returns
        ///
        /// * `Drifter` - Drifter variant matching this profile's drift type
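        ///
        /// # Example
        ///
        /// Illustrative sketch only; assumes an existing `db_pool: Pool<Postgres>`,
        /// a `previous_run: DateTime<Utc>`, and a `profile: DriftProfile` loaded
        /// elsewhere:
        ///
        /// ```ignore
        /// let alerts = profile
        ///     .get_drifter()
        ///     .check_for_alerts(&db_pool, previous_run)
        ///     .await?;
        /// ```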
        fn get_drifter(&self) -> Drifter {
            match self {
                DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
                DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
                DriftProfile::Custom(profile) => {
                    Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
                }
            }
        }
    }

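    /// Polls the database for scheduled drift tasks, processes them, and inserts
    /// any resulting alerts.
    ///
    /// # Example
    ///
    /// An illustrative polling loop (sketch only; mirrors the tests below and
    /// assumes `scouter_settings::DatabaseSettings` is available for building the
    /// pool):
    ///
    /// ```ignore
    /// let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default()).await?;
    /// let mut executor = DriftExecutor::new(&db_pool);
    /// loop {
    ///     // Processes one task per call, or sleeps briefly when none are due.
    ///     executor.poll_for_tasks().await?;
    /// }
    /// ```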
    pub struct DriftExecutor {
        db_pool: Pool<Postgres>,
    }

    impl DriftExecutor {
        pub fn new(db_pool: &Pool<Postgres>) -> Self {
            Self {
                db_pool: db_pool.clone(),
            }
        }

        /// Process a single drift computation task
        ///
        /// # Arguments
        ///
        /// * `profile` - Drift profile to compute drift for
        /// * `previous_run` - Previous run timestamp
        ///
        /// # Returns
        ///
        /// * `Option<Vec<BTreeMap<String, String>>>` - Alerts generated by the drift check, if any
        pub async fn _process_task(
            &mut self,
            profile: DriftProfile,
            previous_run: DateTime<Utc>,
        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
            // Resolve the matching Drifter variant for this profile and check for alerts
            profile
                .get_drifter()
                .check_for_alerts(&self.db_pool, previous_run)
                .await
        }

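        /// Pull the next due drift task from the database, process it, and
        /// update the profile's run dates.
        ///
        /// Returns `Ok(Some(task))` when a task was claimed and processed, or
        /// `Ok(None)` when no task is currently due.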
        async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
            debug!("Polling for drift tasks");

            // Get a task from the database (the query uses skip-locked semantics to claim the task and mark it as processing)
            let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;

            let Some(task) = task else {
                return Ok(None);
            };

            let task_info = DriftTaskInfo {
                space: task.space.clone(),
                name: task.name.clone(),
                version: task.version.clone(),
                uid: task.uid.clone(),
                // Propagate a parse failure instead of panicking on an unknown drift type
                drift_type: DriftType::from_str(&task.drift_type)?,
            };

            info!(
                "Processing drift task for profile: {}/{}/{} and type {}",
                task.space, task.name, task.version, task.drift_type
            );

            self.process_task(&task, &task_info).await?;

            // Update the run dates while still holding the lock
            PostgresClient::update_drift_profile_run_dates(
                &self.db_pool,
                &task_info,
                &task.schedule,
            )
            .instrument(span!(Level::INFO, "Update Run Dates"))
            .await?;

            Ok(Some(task))
        }

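        /// Parse the task's drift type and profile, run the drift check, and
        /// persist any alerts that were generated.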
        async fn process_task(
            &mut self,
            task: &TaskRequest,
            task_info: &DriftTaskInfo,
        ) -> Result<(), DriftError> {
            // get the drift type
            let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
                error!("Error converting drift type: {:?}", e);
            })?;

            // get the drift profile
            let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
                .inspect_err(|e| {
                    error!(
                        "Error converting drift profile for type {:?}: {:?}",
                        drift_type, e
                    );
                })?;

            // check for alerts
            match self._process_task(profile, task.previous_run).await {
                Ok(Some(alerts)) => {
                    info!("Drift task processed successfully with alerts");

                    // Persist each generated alert
                    for alert in alerts {
                        PostgresClient::insert_drift_alert(
                            &self.db_pool,
                            task_info,
                            alert.get("entity_name").unwrap_or(&"NA".to_string()),
                            &alert,
                            &drift_type,
                        )
                        .await
                        .map_err(|e| {
                            error!("Error inserting drift alert: {:?}", e);
                            DriftError::SqlError(e)
                        })?;
                    }
                    Ok(())
                }
                Ok(None) => {
                    info!("Drift task processed successfully with no alerts");
                    Ok(())
                }
                Err(e) => {
                    error!("Error processing drift task: {:?}", e);
                    Err(DriftError::AlertProcessingError(e.to_string()))
                }
            }
        }

        /// Poll for a single drift task and, if one is due, run its drift
        /// computation and alerting. Sleeps for 10 seconds when no task is found.
        ///
        /// # Returns
        ///
        /// * `Result<()>` - Result of drift computation and alerting
        pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
            match self.do_poll().await? {
                Some(_) => {
                    info!("Successfully processed drift task");
                    Ok(())
                }
                None => {
                    info!("No triggered schedules found in db. Sleeping for 10 seconds");
                    tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
                    Ok(())
                }
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
        use scouter_settings::DatabaseSettings;
        use scouter_sql::PostgresClient;
        use scouter_types::DriftAlertRequest;
        use sqlx::{postgres::Postgres, Pool};

        pub async fn cleanup(pool: &Pool<Postgres>) {
            sqlx::raw_sql(
                r#"
                DELETE FROM scouter.spc_drift;
                DELETE FROM scouter.observability_metric;
                DELETE FROM scouter.custom_drift;
                DELETE FROM scouter.drift_alert;
                DELETE FROM scouter.drift_profile;
                DELETE FROM scouter.psi_drift;
                "#,
            )
            .fetch_all(pool)
            .await
            .unwrap();

            RustyLogger::setup_logging(Some(LoggingConfig::new(
                None,
                Some(LogLevel::Info),
                None,
                None,
            )))
            .unwrap();
        }

        #[tokio::test]
        async fn test_drift_executor_spc() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_spc.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();
            assert!(!alerts.is_empty());
        }

        #[tokio::test]
        async fn test_drift_executor_spc_missing_feature_data() {
            // This tests the scenario where only 1 of 2 features has data in the db when polling
            // for tasks. We need to ensure this does not fail and that the feature with data is
            // still processed.
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();
            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_spc_alert.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert!(!alerts.is_empty());
        }

        #[tokio::test]
        async fn test_drift_executor_psi() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_psi.sql");

            let mut script = std::fs::read_to_string(populate_path).unwrap();
            let bin_count = 1000;
            let skew_feature = "feature_1";
            let skew_factor = 10;
            let apply_skew = true;
            script = script.replace("{{bin_count}}", &bin_count.to_string());
            script = script.replace("{{skew_feature}}", skew_feature);
            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert_eq!(alerts.len(), 1);
        }

        /// This test verifies that the PSI drift executor does **not** generate any drift alerts
        /// when there are **not enough target samples** to meet the minimum threshold required
        /// for PSI calculation.
        ///
        /// The `bin_count` argument determines how many observations per bin to simulate for a
        /// production environment. In the script there are 3 features, each with 10 bins.
        /// `bin_count = 2` means we simulate 2 observations per bin.
        /// So for each feature: 10 bins * 2 samples = 20 samples inserted PER insert.
        /// Since the script inserts each feature's data 3 times (simulating 3 production batches),
        /// each feature ends up with: 20 samples * 3 = 60 samples total.
        /// This is below the required threshold of >100 samples per feature for PSI calculation,
        /// so no drift alert should be generated.
        #[tokio::test]
        async fn test_drift_executor_psi_not_enough_target_samples() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_psi.sql");

            let mut script = std::fs::read_to_string(populate_path).unwrap();
            let bin_count = 2;
            let skew_feature = "feature_1";
            let skew_factor = 1;
            let apply_skew = false;
            script = script.replace("{{bin_count}}", &bin_count.to_string());
            script = script.replace("{{skew_feature}}", skew_feature);
            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert!(alerts.is_empty());
        }

        #[tokio::test]
        async fn test_drift_executor_custom() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_custom.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert_eq!(alerts.len(), 1);
        }
    }
}