1use crate::error::DriftError;
2use crate::{custom::CustomDrifter, genai::GenAIDrifter, psi::PsiDrifter, spc::SpcDrifter};
3use chrono::{DateTime, Utc};
4
5use scouter_sql::sql::traits::{AlertSqlLogic, ProfileSqlLogic};
6use scouter_sql::{sql::schema::TaskRequest, PostgresClient};
7use scouter_types::{AlertMap, DriftProfile};
8use sqlx::{Pool, Postgres};
9use std::result::Result;
10use std::result::Result::Ok;
11
12use tracing::{debug, error, info, instrument, span, Instrument, Level};
13
/// Closed set of drift detectors the executor can run.
///
/// Each variant owns one concrete drifter so callers can run an alert check
/// through [`Drifter::check_for_alerts`] without matching on the kind themselves.
// Every variant name ends in `Drifter` to mirror the type it wraps, which is
// what the `enum_variant_names` lint would otherwise flag.
#[allow(clippy::enum_variant_names)]
pub enum Drifter {
    /// Wraps an [`SpcDrifter`].
    SpcDrifter(SpcDrifter),
    /// Wraps a [`PsiDrifter`].
    PsiDrifter(PsiDrifter),
    /// Wraps a [`CustomDrifter`].
    CustomDrifter(CustomDrifter),
    /// Wraps a [`GenAIDrifter`].
    GenAIDrifter(GenAIDrifter),
}
21
22impl Drifter {
23 pub async fn check_for_alerts(
24 &self,
25 db_pool: &Pool<Postgres>,
26 previous_run: &DateTime<Utc>,
27 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
28 match self {
29 Drifter::SpcDrifter(drifter) => drifter.check_for_alerts(db_pool, previous_run).await,
30 Drifter::PsiDrifter(drifter) => drifter.check_for_alerts(db_pool, previous_run).await,
31 Drifter::CustomDrifter(drifter) => {
32 drifter.check_for_alerts(db_pool, previous_run).await
33 }
34 Drifter::GenAIDrifter(drifter) => drifter.check_for_alerts(db_pool, previous_run).await,
35 }
36 }
37}
38
/// Conversion from a value into the [`Drifter`] that evaluates it.
pub trait GetDrifter {
    /// Builds the [`Drifter`] corresponding to `self`.
    fn get_drifter(&self) -> Drifter;
}
42
43impl GetDrifter for DriftProfile {
44 fn get_drifter(&self) -> Drifter {
56 match self {
57 DriftProfile::Spc(profile) => Drifter::SpcDrifter(SpcDrifter::new(profile.clone())),
58 DriftProfile::Psi(profile) => Drifter::PsiDrifter(PsiDrifter::new(profile.clone())),
59 DriftProfile::Custom(profile) => {
60 Drifter::CustomDrifter(CustomDrifter::new(profile.clone()))
61 }
62 DriftProfile::GenAI(profile) => {
63 Drifter::GenAIDrifter(GenAIDrifter::new(profile.clone()))
64 }
65 }
66 }
67}
68
/// Fetches due drift tasks from Postgres, runs them and stores resulting alerts.
pub struct DriftExecutor {
    /// Connection pool used for every query the executor issues.
    db_pool: Pool<Postgres>,
}
72
73impl DriftExecutor {
74 pub fn new(db_pool: &Pool<Postgres>) -> Self {
75 Self {
76 db_pool: db_pool.clone(),
77 }
78 }
79
80 pub async fn _process_task(
92 &mut self,
93 profile: DriftProfile,
94 previous_run: &DateTime<Utc>,
95 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
96 profile
99 .get_drifter()
100 .check_for_alerts(&self.db_pool, previous_run)
101 .await
102 }
103
104 async fn do_poll(&mut self) -> bool {
105 debug!("Polling for drift tasks");
106
107 let task = match PostgresClient::get_drift_profile_task(&self.db_pool).await {
109 Ok(task) => task,
110 Err(e) => {
111 error!("Error fetching drift task: {:?}", e);
112 return false;
113 }
114 };
115
116 let Some(task) = task else {
117 return false;
118 };
119
120 info!(
121 "Processing drift task for profile: {} and type {}",
122 task.uid, task.drift_type
123 );
124
125 match self.process_task(&task).await {
127 Ok(_) => info!(
128 "Successfully processed drift task for profile: {}",
129 task.uid
130 ),
131 Err(e) => error!(
132 "Error processing drift task for profile {}: {:?}",
133 task.uid, e
134 ),
135 }
136
137 match PostgresClient::update_drift_profile_run_dates(
138 &self.db_pool,
139 &task.entity_id,
140 &task.schedule,
141 &task.previous_run,
142 )
143 .instrument(span!(Level::INFO, "Update Run Dates"))
144 .await
145 {
146 Ok(_) => info!("Updated run dates for drift profile task: {}", task.uid),
147 Err(e) => error!(
148 "CRITICAL: Failed to reschedule task Error updating run dates for drift profile task {}: {:?}",
149 task.uid, e
150 ),
151 }
152
153 true
154 }
155
156 #[instrument(skip_all)]
157 async fn process_task(
158 &mut self,
159 task: &TaskRequest,
160 ) -> Result<(), DriftError> {
163 let profile = DriftProfile::from_str(&task.drift_type, &task.profile).inspect_err(|e| {
165 error!(
166 "Error converting drift profile for type {:?}: {:?}",
167 &task.drift_type, e
168 );
169 })?;
170
171 match self._process_task(profile, &task.previous_run).await {
172 Ok(Some(alerts)) => {
173 info!("Drift task processed successfully with alerts");
174
175 for alert in alerts {
177 PostgresClient::insert_drift_alert(&self.db_pool, &task.entity_id, &alert)
178 .await
179 .map_err(|e| {
180 error!("Error inserting drift alert: {:?}", e);
181 DriftError::SqlError(e)
182 })?;
183 }
184 Ok(())
185 }
186 Ok(None) => {
187 info!("Drift task processed successfully with no alerts");
188 Ok(())
189 }
190 Err(e) => {
191 error!("Error processing drift task: {:?}", e);
192 Err(DriftError::AlertProcessingError(e.to_string()))
193 }
194 }
195 }
196
197 #[instrument(skip_all)]
203 pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
204 match self.do_poll().await {
205 true => {
206 info!("Successfully processed drift task");
207 Ok(())
208 }
209 false => {
210 info!("No triggered schedules found in db. Sleeping for 10 seconds");
211 tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
212 Ok(())
213 }
214 }
215 }
216}
217
218#[cfg(test)]
219mod tests {
220 use crate::GenAIPoller;
221
222 use super::*;
223 use rusty_logging::logger::{LogLevel, LoggingConfig, RustyLogger};
224 use scouter_settings::DatabaseSettings;
225 use scouter_sql::sql::traits::{EntitySqlLogic, GenAIDriftSqlLogic, SpcSqlLogic};
226 use scouter_sql::PostgresClient;
227 use scouter_types::spc::SpcFeatureDriftProfile;
228 use scouter_types::{
229 spc::{SpcAlertConfig, SpcAlertRule, SpcDriftConfig, SpcDriftProfile},
230 AlertDispatchConfig, DriftAlertPaginationRequest,
231 };
232 use scouter_types::{BoxedGenAIEvalRecord, DriftType, ProfileArgs, SpcRecord};
233 use semver::Version;
234 use sqlx::{postgres::Postgres, Pool};
235 use std::collections::HashMap;
236
237 use potato_head::mock::{create_score_prompt, LLMTestServer};
238 use scouter_types::genai::{
239 AssertionTask, ComparisonOperator, EvaluationTaskType, EvaluationTasks, GenAIAlertConfig,
240 GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask,
241 };
242 use scouter_types::{AlertCondition, AlertThreshold, GenAIEvalRecord};
243 use serde_json::Value;
244
245 pub async fn cleanup(pool: &Pool<Postgres>) {
246 sqlx::raw_sql(
247 r#"
248 DELETE
249 FROM scouter.spc_drift;
250
251 DELETE
252 FROM scouter.drift_entities;
253
254 DELETE
255 FROM scouter.observability_metric;
256
257 DELETE
258 FROM scouter.custom_drift;
259
260 DELETE
261 FROM scouter.drift_alert;
262
263 DELETE
264 FROM scouter.drift_profile;
265
266 DELETE
267 FROM scouter.psi_drift;
268
269 DELETE
270 FROM scouter.genai_eval_workflow;
271
272 DELETE
273 FROM scouter.genai_eval_task;
274
275 DELETE
276 FROM scouter.genai_eval_record;
277 "#,
278 )
279 .fetch_all(pool)
280 .await
281 .unwrap();
282
283 RustyLogger::setup_logging(Some(LoggingConfig::new(
284 None,
285 Some(LogLevel::Info),
286 None,
287 None,
288 )))
289 .unwrap();
290 }
291
292 #[tokio::test]
293 async fn test_drift_executor_spc() {
294 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
295 .await
296 .unwrap();
297
298 cleanup(&db_pool).await;
299
300 let alert_config = SpcAlertConfig {
301 rule: SpcAlertRule::default(),
302 schedule: "* * * * * * *".to_string(),
304 features_to_monitor: vec!["col_1".to_string(), "col_3".to_string()],
305 dispatch_config: AlertDispatchConfig::default(),
306 };
307
308 let config = SpcDriftConfig::new(
309 "statworld",
310 "test_app",
311 "0.1.0",
312 Some(true),
313 Some(25),
314 Some(alert_config),
315 None,
316 )
317 .unwrap();
318
319 let col1_profile = SpcFeatureDriftProfile {
320 id: "col_1".to_string(),
321 center: -3.997113080300062,
322 one_ucl: -1.9742384896265417,
323 one_lcl: -6.019987670973582,
324 two_ucl: 0.048636101046978464,
325 two_lcl: -8.042862261647102,
326 three_ucl: 2.071510691720498,
327 three_lcl: -10.065736852320622,
328 timestamp: Utc::now(),
329 };
330
331 let col3_profile = SpcFeatureDriftProfile {
332 id: "col_3".to_string(),
333 center: -3.937652409303277,
334 one_ucl: -2.0275656995100224,
335 one_lcl: -5.8477391190965315,
336 two_ucl: -0.1174789897167674,
337 two_lcl: -7.757825828889787,
338 three_ucl: 1.7926077200764872,
339 three_lcl: -9.66791253868304,
340 timestamp: Utc::now(),
341 };
342
343 let drift_profile = DriftProfile::Spc(SpcDriftProfile {
344 config,
345 features: HashMap::from([
346 (col1_profile.id.clone(), col1_profile),
347 (col3_profile.id.clone(), col3_profile),
348 ]),
349 scouter_version: "0.1.0".to_string(),
350 });
351
352 let profile_args = ProfileArgs {
353 space: "statworld".to_string(),
354 name: "test_app".to_string(),
355 version: Some("0.1.0".to_string()),
356 schedule: "* * * * * *".to_string(),
357 scouter_version: "0.1.0".to_string(),
358 drift_type: DriftType::Spc,
359 };
360
361 let version = Version::new(0, 1, 0);
362 let uid = PostgresClient::insert_drift_profile(
363 &db_pool,
364 &drift_profile,
365 &profile_args,
366 &version,
367 &true,
368 &true,
369 )
370 .await
371 .unwrap();
372 let entity_id = PostgresClient::get_entity_id_from_uid(&db_pool, &uid)
373 .await
374 .unwrap();
375
376 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
377
378 let mut records = vec![]; for i in 0..100 {
380 let record = SpcRecord {
381 created_at: Utc::now() + chrono::Duration::seconds(i),
383 uid: uid.clone(),
384 feature: "col_1".to_string(),
385 value: 10.0 + i as f64,
386 entity_id,
387 };
388 records.push(record);
389 }
390
391 PostgresClient::insert_spc_drift_records_batch(&db_pool, &records, &entity_id)
392 .await
393 .unwrap();
394
395 let mut drift_executor = DriftExecutor::new(&db_pool);
396 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
397
398 drift_executor.poll_for_tasks().await.unwrap();
399
400 let request = DriftAlertPaginationRequest {
402 active: None,
403 limit: None,
404 uid: uid.clone(),
405 ..Default::default()
406 };
407
408 let entity_id = PostgresClient::get_entity_id_from_uid(&db_pool, &uid)
409 .await
410 .unwrap();
411
412 let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
413 .await
414 .unwrap();
415 assert!(!alerts.items.is_empty());
416 }
417
418 #[tokio::test]
419 async fn test_drift_executor_psi() {
420 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
421 .await
422 .unwrap();
423
424 cleanup(&db_pool).await;
425
426 let mut populate_path = std::env::current_dir().expect("Failed to get current directory");
427 populate_path.push("src/scripts/populate_psi.sql");
428
429 let mut script = std::fs::read_to_string(populate_path).unwrap();
430 let bin_count = 1000;
431 let skew_feature = "feature_1";
432 let skew_factor = 10;
433 let apply_skew = true;
434 script = script.replace("{{bin_count}}", &bin_count.to_string());
435 script = script.replace("{{skew_feature}}", skew_feature);
436 script = script.replace("{{skew_factor}}", &skew_factor.to_string());
437 script = script.replace("{{apply_skew}}", &apply_skew.to_string());
438 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
439 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
440
441 let mut drift_executor = DriftExecutor::new(&db_pool);
442
443 drift_executor.poll_for_tasks().await.unwrap();
444
445 let request = DriftAlertPaginationRequest {
447 uid: "019ae1b4-3003-77c1-8eed-2ec005e85963".to_string(),
448 active: None,
449 limit: None,
450 ..Default::default()
451 };
452
453 let entity_id = PostgresClient::get_entity_id_from_space_name_version_drift_type(
454 &db_pool,
455 "scouter",
456 "model",
457 "0.1.0",
458 DriftType::Psi.to_string(),
459 )
460 .await
461 .unwrap();
462
463 let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
464 .await
465 .unwrap();
466
467 assert_eq!(alerts.items.len(), 1);
468 }
469
470 #[tokio::test]
483 async fn test_drift_executor_psi_not_enough_target_samples() {
484 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
485 .await
486 .unwrap();
487
488 cleanup(&db_pool).await;
489
490 let mut populate_path = std::env::current_dir().expect("Failed to get current directory");
491 populate_path.push("src/scripts/populate_psi.sql");
492
493 let mut script = std::fs::read_to_string(populate_path).unwrap();
494 let bin_count = 2;
495 let skew_feature = "feature_1";
496 let skew_factor = 1;
497 let apply_skew = false;
498 script = script.replace("{{bin_count}}", &bin_count.to_string());
499 script = script.replace("{{skew_feature}}", skew_feature);
500 script = script.replace("{{skew_factor}}", &skew_factor.to_string());
501 script = script.replace("{{apply_skew}}", &apply_skew.to_string());
502 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
503 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
504
505 let mut drift_executor = DriftExecutor::new(&db_pool);
506
507 drift_executor.poll_for_tasks().await.unwrap();
508
509 let request = DriftAlertPaginationRequest {
511 uid: "019ae1b4-3003-77c1-8eed-2ec005e85963".to_string(),
512 active: None,
513 limit: None,
514 ..Default::default()
515 };
516
517 let entity_id = PostgresClient::get_entity_id_from_space_name_version_drift_type(
518 &db_pool,
519 "scouter",
520 "model",
521 "0.1.0",
522 DriftType::Psi.to_string(),
523 )
524 .await
525 .unwrap();
526
527 let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
528 .await
529 .unwrap();
530
531 assert!(alerts.items.is_empty());
532 }
533
534 #[tokio::test]
535 async fn test_drift_executor_custom() {
536 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
537 .await
538 .unwrap();
539
540 cleanup(&db_pool).await;
541
542 let mut populate_path = std::env::current_dir().expect("Failed to get current directory");
543 populate_path.push("src/scripts/populate_custom.sql");
544
545 let script = std::fs::read_to_string(populate_path).unwrap();
546 sqlx::raw_sql(&script).execute(&db_pool).await.unwrap();
547 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
548
549 let mut drift_executor = DriftExecutor::new(&db_pool);
550
551 drift_executor.poll_for_tasks().await.unwrap();
552
553 let request = DriftAlertPaginationRequest {
555 uid: "scouter|model|0.1.0|custom".to_string(),
556 ..Default::default()
557 };
558
559 let entity_id = PostgresClient::get_entity_id_from_space_name_version_drift_type(
560 &db_pool,
561 "scouter",
562 "model",
563 "0.1.0",
564 DriftType::Custom.to_string(),
565 )
566 .await
567 .unwrap();
568
569 let alerts = PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
570 .await
571 .unwrap();
572
573 assert_eq!(alerts.items.len(), 2);
574 }
575
576 #[test]
577 fn test_drift_executor_genai() {
578 let mut mock = LLMTestServer::new();
580 mock.start_server().unwrap();
581 let runtime = tokio::runtime::Runtime::new().unwrap();
582
583 let db_pool = runtime.block_on(async {
584 let db_pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
586 .await
587 .unwrap();
588
589 cleanup(&db_pool).await;
590 db_pool
591 });
592
593 let prompt = create_score_prompt(Some(vec!["input".to_string()]));
595
596 let assertion_level_1 = AssertionTask {
597 id: "input_check".to_string(),
598 field_path: Some("input.foo".to_string()),
599 operator: ComparisonOperator::Equals,
600 expected_value: Value::String("bar".to_string()),
601 description: Some("Check if input.foo is bar".to_string()),
602 task_type: EvaluationTaskType::Assertion,
603 depends_on: vec![],
604 result: None,
605 condition: false,
606 };
607
608 let judge_task = LLMJudgeTask::new_rs(
609 "query_relevance",
610 prompt.clone(),
611 Value::Number(1.into()),
612 Some("score".to_string()),
613 ComparisonOperator::GreaterThanOrEqual,
614 None,
615 None,
616 None,
617 None,
618 );
619
620 let assert_query_score = AssertionTask {
621 id: "assert_score".to_string(),
622 field_path: Some("query_relevance.score".to_string()),
623 operator: ComparisonOperator::IsNumeric,
624 expected_value: Value::Bool(true),
625 depends_on: vec!["query_relevance".to_string()],
626 task_type: EvaluationTaskType::Assertion,
627 description: Some("Check that score is numeric".to_string()),
628 result: None,
629 condition: false,
630 };
631
632 let tasks = EvaluationTasks::new()
633 .add_task(assertion_level_1)
634 .add_task(judge_task)
635 .add_task(assert_query_score)
636 .build();
637
638 let alert_condition = AlertCondition {
640 baseline_value: 0.8, alert_threshold: AlertThreshold::Below,
642 delta: Some(0.01), };
644
645 let alert_config = GenAIAlertConfig {
646 schedule: "* * * * * *".to_string(), dispatch_config: AlertDispatchConfig::default(),
648 alert_condition: Some(alert_condition),
649 };
650
651 let drift_config =
652 GenAIEvalConfig::new("scouter", "genai_test", "0.1.0", 1.0, alert_config, None)
653 .unwrap();
654
655 let profile = runtime
656 .block_on(async { GenAIEvalProfile::new(drift_config, tasks).await })
657 .unwrap();
658 let drift_profile = DriftProfile::GenAI(profile.clone());
659
660 let profile_args = ProfileArgs {
662 space: "scouter".to_string(),
663 name: "genai_test".to_string(),
664 version: Some("0.1.0".to_string()),
665 schedule: "* * * * * *".to_string(),
666 scouter_version: "0.1.0".to_string(),
667 drift_type: DriftType::GenAI,
668 };
669
670 let version = Version::new(0, 1, 0);
671
672 let uid = runtime.block_on(async {
673 PostgresClient::insert_drift_profile(
674 &db_pool,
675 &drift_profile,
676 &profile_args,
677 &version,
678 &true,
679 &true,
680 )
681 .await
682 .unwrap()
683 });
684
685 let entity_id = runtime.block_on(async {
686 PostgresClient::get_entity_id_from_uid(&db_pool, &uid)
687 .await
688 .unwrap()
689 });
690
691 std::thread::sleep(std::time::Duration::from_secs(1));
693
694 let mut records = vec![];
696 for i in 0..50 {
697 let context = serde_json::json!({
699 "input": {
700 "foo": if i % 4 == 0 { "bar" } else { "wrong" } }
702 });
703
704 let record = GenAIEvalRecord::new_rs(
705 context,
706 Utc::now() + chrono::Duration::seconds(i),
707 format!("UID{}", i),
708 uid.clone(),
709 None,
710 None,
711 );
712
713 records.push(BoxedGenAIEvalRecord::new(record));
714 }
715
716 let mut poller = GenAIPoller::new(&db_pool, 3);
718 for record in records {
719 runtime.block_on(async {
722 PostgresClient::insert_genai_eval_record(&db_pool, record, &entity_id)
723 .await
724 .unwrap();
725
726 poller.do_poll().await.unwrap();
727 });
728 }
729
730 let mut drift_executor = DriftExecutor::new(&db_pool);
732
733 runtime.block_on(async {
734 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
735 drift_executor.poll_for_tasks().await.unwrap();
736 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
737 });
738
739 let request = DriftAlertPaginationRequest {
741 uid: uid.clone(),
742 active: None,
743 limit: None,
744 ..Default::default()
745 };
746
747 let alerts = runtime.block_on(async {
748 PostgresClient::get_paginated_drift_alerts(&db_pool, &request, &entity_id)
749 .await
750 .unwrap()
751 });
752
753 assert!(
754 !alerts.items.is_empty(),
755 "Expected drift alerts to be generated for low pass rate"
756 );
757
758 let alert = &alerts.items[0];
760
761 assert_eq!(alert.alert.entity_name(), "genai_workflow_metric");
762
763 let observed_value: f64 = match &alert.alert {
765 AlertMap::GenAI(genai_alert) => genai_alert.observed_value,
766 _ => panic!("Expected GenAI alert map"),
767 };
768
769 assert!(
770 observed_value < 0.8, "Expected low pass rate to trigger alert"
772 );
773
774 mock.stop_server().unwrap();
776 }
777}