// scouter_drift/drifter.rs

1#[cfg(feature = "sql")]
2pub mod drift_executor {
3
4    use crate::error::DriftError;
5    use crate::{custom::CustomDrifter, psi::PsiDrifter, spc::SpcDrifter};
6    use chrono::{DateTime, Utc};
7
8    use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
9    use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
10    use scouter_types::{DriftProfile, DriftTaskInfo, DriftType};
11    use sqlx::{Pool, Postgres};
12    use std::collections::BTreeMap;
13    use std::result::Result;
14    use std::result::Result::Ok;
15    use std::str::FromStr;
16    use tracing::{debug, error, info, instrument, span, Instrument, Level};
17
    /// Dispatch enum wrapping the concrete drifter for each supported drift type.
    // All variants share the `Drifter` suffix, which clippy flags; the suffix is
    // kept intentionally to mirror the wrapped type names.
    #[allow(clippy::enum_variant_names)]
    pub enum Drifter {
        SpcDrifter(SpcDrifter),
        PsiDrifter(PsiDrifter),
        CustomDrifter(CustomDrifter),
    }
24
25    impl Drifter {
26        pub async fn check_for_alerts(
27            &self,
28            db_pool: &Pool<Postgres>,
29            previous_run: DateTime<Utc>,
30        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
31            match self {
32                Drifter::SpcDrifter(drifter) => {
33                    drifter.check_for_alerts(db_pool, previous_run).await
34                }
35                Drifter::PsiDrifter(drifter) => {
36                    drifter.check_for_alerts(db_pool, previous_run).await
37                }
38                Drifter::CustomDrifter(drifter) => {
39                    drifter.check_for_alerts(db_pool, previous_run).await
40                }
41            }
42        }
43    }
44
    /// Conversion from a drift profile into the matching [`Drifter`] variant.
    pub trait GetDrifter {
        /// Build a `Drifter` capable of processing this profile's drift tasks.
        fn get_drifter(&self) -> Drifter;
    }
48
49    impl GetDrifter for DriftProfile {
50        /// Get a Drifter for processing drift profile tasks
51        ///
52        /// # Arguments
53        ///
54        /// * `name` - Name of the drift profile
55        /// * `space` - Space of the drift profile
56        /// * `version` - Version of the drift profile
57        ///
58        /// # Returns
59        ///
60        /// * `Drifter` - Drifter enum
61        fn get_drifter(&self) -> Drifter {
62            match self {
63                DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
64                DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
65                DriftProfile::Custom(profile) => {
66                    Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
67                }
68            }
69        }
70    }
71
    /// Polls the database for scheduled drift tasks and processes them.
    pub struct DriftExecutor {
        // Postgres connection pool, cloned from the caller's pool in `new`.
        db_pool: Pool<Postgres>,
    }
75
76    impl DriftExecutor {
77        pub fn new(db_pool: &Pool<Postgres>) -> Self {
78            Self {
79                db_pool: db_pool.clone(),
80            }
81        }
82
83        /// Process a single drift computation task
84        ///
85        /// # Arguments
86        ///
87        /// * `drift_profile` - Drift profile to compute drift for
88        /// * `previous_run` - Previous run timestamp
89        /// * `schedule` - Schedule for drift computation
90        /// * `transaction` - Postgres transaction
91        ///
92        /// # Returns
93        ///
94        pub async fn _process_task(
95            &mut self,
96            profile: DriftProfile,
97            previous_run: DateTime<Utc>,
98        ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
99            // match Drifter enum
100
101            profile
102                .get_drifter()
103                .check_for_alerts(&self.db_pool, previous_run)
104                .await
105        }
106
107        async fn do_poll(&mut self) -> Result<Option<TaskRequest>, DriftError> {
108            debug!("Polling for drift tasks");
109
110            // Get task from the database (query uses skip lock to pull task and update to processing)
111            let task = PostgresClient::get_drift_profile_task(&self.db_pool).await?;
112
113            let Some(task) = task else {
114                return Ok(None);
115            };
116
117            let task_info = DriftTaskInfo {
118                space: task.space.clone(),
119                name: task.name.clone(),
120                version: task.version.clone(),
121                uid: task.uid.clone(),
122                drift_type: DriftType::from_str(&task.drift_type).unwrap(),
123            };
124
125            info!(
126                "Processing drift task for profile: {}/{}/{} and type {}",
127                task.space, task.name, task.version, task.drift_type
128            );
129
130            self.process_task(&task, &task_info).await?;
131
132            // Update the run dates while still holding the lock
133            PostgresClient::update_drift_profile_run_dates(
134                &self.db_pool,
135                &task_info,
136                &task.schedule,
137            )
138            .instrument(span!(Level::INFO, "Update Run Dates"))
139            .await?;
140
141            Ok(Some(task))
142        }
143
144        #[instrument(skip_all)]
145        async fn process_task(
146            &mut self,
147            task: &TaskRequest,
148            task_info: &DriftTaskInfo,
149        ) -> Result<(), DriftError> {
150            // get the drift type
151            let drift_type = DriftType::from_str(&task.drift_type).inspect_err(|e| {
152                error!("Error converting drift type: {:?}", e);
153            })?;
154
155            // get the drift profile
156            let profile = DriftProfile::from_str(drift_type.clone(), task.profile.clone())
157                .inspect_err(|e| {
158                    error!(
159                        "Error converting drift profile for type {:?}: {:?}",
160                        drift_type, e
161                    );
162                })?;
163
164            // check for alerts
165            match self._process_task(profile, task.previous_run).await {
166                Ok(Some(alerts)) => {
167                    info!("Drift task processed successfully with alerts");
168
169                    // Insert alerts atomically within the same transaction
170                    for alert in alerts {
171                        PostgresClient::insert_drift_alert(
172                            &self.db_pool,
173                            task_info,
174                            alert.get("entity_name").unwrap_or(&"NA".to_string()),
175                            &alert,
176                            &drift_type,
177                        )
178                        .await
179                        .map_err(|e| {
180                            error!("Error inserting drift alert: {:?}", e);
181                            DriftError::SqlError(e)
182                        })?;
183                    }
184                    Ok(())
185                }
186                Ok(None) => {
187                    info!("Drift task processed successfully with no alerts");
188                    Ok(())
189                }
190                Err(e) => {
191                    error!("Error processing drift task: {:?}", e);
192                    Err(DriftError::AlertProcessingError(e.to_string()))
193                }
194            }
195        }
196
197        /// Execute single drift computation and alerting
198        ///
199        /// # Returns
200        ///
201        /// * `Result<()>` - Result of drift computation and alerting
202        #[instrument(skip_all)]
203        pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
204            match self.do_poll().await? {
205                Some(_) => {
206                    info!("Successfully processed drift task");
207                    Ok(())
208                }
209                None => {
210                    info!("No triggered schedules found in db. Sleeping for 10 seconds");
211                    tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
212                    Ok(())
213                }
214            }
215        }
216    }
217
    #[cfg(test)]
    mod tests {
        use super::*;
        use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
        use scouter_settings::DatabaseSettings;
        use scouter_sql::PostgresClient;
        use scouter_types::DriftAlertRequest;
        use sqlx::{postgres::Postgres, Pool};

        // Wipe all scouter tables so each test starts from a clean slate, then
        // (re)initialize logging for test output.
        // NOTE(review): `fetch_all` is used to run a multi-statement DELETE
        // script; `execute` (as the tests below use for their populate scripts)
        // would be the conventional call — confirm `fetch_all` is intentional.
        pub async fn cleanup(pool: &Pool<Postgres>) {
            sqlx::raw_sql(
                r#"
                DELETE 
                FROM scouter.spc_drift;

                DELETE 
                FROM scouter.observability_metric;

                DELETE
                FROM scouter.custom_drift;

                DELETE
                FROM scouter.drift_alert;

                DELETE
                FROM scouter.drift_profile;

                DELETE
                FROM scouter.psi_drift;
                "#,
            )
            .fetch_all(pool)
            .await
            .unwrap();

            RustyLogger::setup_logging(Some(LoggingConfig::new(
                None,
                Some(LogLevel::Info),
                None,
                None,
            )))
            .unwrap();
        }

        // End-to-end SPC flow: seed profile + drift data via SQL script, poll
        // once, and expect at least one alert to be persisted.
        #[tokio::test]
        async fn test_drift_executor_spc() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            // Populate scripts are resolved relative to the crate's working dir.
            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_spc.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();
            assert!(!alerts.is_empty());
        }

        #[tokio::test]
        async fn test_drift_executor_spc_missing_feature_data() {
            // this tests the scenario where only 1 of 2 features has data in the db when polling
            // for tasks. Need to ensure this does not fail and the present feature and data are
            // still processed
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();
            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_spc_alert.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "statworld".to_string(),
                name: "test_app".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert!(!alerts.is_empty());
        }

        // PSI flow with a skewed feature: the template placeholders in the SQL
        // script control sample volume and skew; with skew applied exactly one
        // alert is expected.
        #[tokio::test]
        async fn test_drift_executor_psi() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_psi.sql");

            let mut script = std::fs::read_to_string(populate_path).unwrap();
            // Substitute the script's template placeholders before executing.
            let bin_count = 1000;
            let skew_feature = "feature_1";
            let skew_factor = 10;
            let apply_skew = true;
            script = script.replace("{{bin_count}}", &bin_count.to_string());
            script = script.replace("{{skew_feature}}", skew_feature);
            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert_eq!(alerts.len(), 1);
        }

        /// This test verifies that the PSI drift executor does **not** generate any drift alerts
        /// when there are **not enough target samples** to meet the minimum threshold required
        /// for PSI calculation.
        ///
        /// This arg determines how many bin counts to simulate for a production environment.
        /// In the script there are 3 features, each with 10 bins.
        /// `bin_count = 2` means we simulate 2 observations per bin.
        /// So for each feature: 10 bins * 2 samples = 20 samples inserted PER insert.
        /// Since the script inserts each feature's data 3 times (simulating 3 production batches),
        /// each feature ends up with: 20 samples * 3 = 60 samples total.
        /// This is below the required threshold of >100 samples per feature for PSI calculation,
        /// so no drift alert should be generated.
        #[tokio::test]
        async fn test_drift_executor_psi_not_enough_target_samples() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_psi.sql");

            let mut script = std::fs::read_to_string(populate_path).unwrap();
            // Low bin_count + no skew keeps sample volume under the PSI threshold.
            let bin_count = 2;
            let skew_feature = "feature_1";
            let skew_factor = 1;
            let apply_skew = false;
            script = script.replace("{{bin_count}}", &bin_count.to_string());
            script = script.replace("{{skew_feature}}", skew_feature);
            script = script.replace("{{skew_factor}}", &skew_factor.to_string());
            script = script.replace("{{apply_skew}}", &apply_skew.to_string());
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert!(alerts.is_empty());
        }

        // Custom-metric flow: seed via SQL script, poll once, expect exactly
        // one alert to be persisted.
        #[tokio::test]
        async fn test_drift_executor_custom() {
            let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
                .await
                .unwrap();

            cleanup(&db_pool).await;

            let mut populate_path =
                std::env::current_dir().expect("Failed to get current directory");
            populate_path.push("src/scripts/populate_custom.sql");

            let script = std::fs::read_to_string(populate_path).unwrap();
            sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();

            let mut drift_executor = DriftExecutor::new(&db_pool);

            drift_executor.poll_for_tasks().await.unwrap();

            // get alerts from db
            let request = DriftAlertRequest {
                space: "scouter".to_string(),
                name: "model".to_string(),
                version: "0.1.0".to_string(),
                limit_datetime: None,
                active: None,
                limit: None,
            };
            let alerts = PostgresClient::get_drift_alerts(&db_pool, &request)
                .await
                .unwrap();

            assert_eq!(alerts.len(), 1);
        }
    }
466}