1use crate::error::DriftError;
2use chrono::{DateTime, Utc};
3use scouter_dispatch::AlertDispatcher;
4use scouter_sql::sql::traits::GenAIDriftSqlLogic;
5use scouter_sql::{sql::cache::entity_cache, PostgresClient};
6use scouter_types::{custom::ComparisonMetricAlert, genai::GenAIEvalProfile};
7use scouter_types::{AlertMap, ProfileBaseArgs};
8use sqlx::{Pool, Postgres};
9use tracing::error;
10use tracing::info;
11
12pub struct GenAIDrifter {
13 profile: GenAIEvalProfile,
14}
15
16impl GenAIDrifter {
17 pub fn new(profile: GenAIEvalProfile) -> Self {
18 Self { profile }
19 }
20
21 fn profile_id(&self) -> String {
23 format!(
24 "{}/{}/{}",
25 self.profile.space(),
26 self.profile.name(),
27 self.profile.version()
28 )
29 }
30
31 pub async fn get_workflow_value(
33 &self,
34 limit_datetime: &DateTime<Utc>,
35 db_pool: &Pool<Postgres>,
36 ) -> Result<Option<f64>, DriftError> {
37 let entity_id = entity_cache()
38 .get_entity_id_from_uid(db_pool, &self.profile.config.uid)
39 .await?;
40
41 PostgresClient::get_genai_workflow_value(db_pool, limit_datetime, &entity_id)
42 .await
43 .inspect_err(|e| {
44 error!(
45 "Unable to obtain genai metric data from DB for {}: {}",
46 self.profile_id(),
47 e
48 );
49 })
50 .map_err(Into::into)
51 }
52
53 pub async fn get_metric_value(
55 &self,
56 limit_datetime: &DateTime<Utc>,
57 db_pool: &Pool<Postgres>,
58 ) -> Result<Option<f64>, DriftError> {
59 let value = self.get_workflow_value(limit_datetime, db_pool).await?;
60
61 if value.is_none() {
62 info!(
63 "No genai metric data found for {}. Skipping alert processing.",
64 self.profile_id()
65 );
66 }
67
68 Ok(value)
69 }
70
71 pub async fn generate_alerts(
73 &self,
74 observed_value: f64,
75 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
76 let Some(alert_condition) = &self.profile.config.alert_config.alert_condition else {
78 info!(
79 "No alert condition configured for {}. Skipping alert processing.",
80 self.profile_id()
81 );
82 return Ok(None);
83 };
84
85 if !alert_condition.should_alert(observed_value) {
87 info!(
88 "No alerts to process for {} (observed: {}, baseline: {})",
89 self.profile_id(),
90 observed_value,
91 alert_condition.baseline_value
92 );
93 return Ok(None);
94 }
95
96 let metric_name = "genai_workflow_metric".to_string();
98 let comparison_alert = ComparisonMetricAlert {
99 metric_name: metric_name.clone(),
100 baseline_value: alert_condition.baseline_value,
101 observed_value,
102 delta: alert_condition.delta,
103 alert_threshold: alert_condition.alert_threshold.clone(),
104 };
105
106 let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
108 error!(
109 "Error creating alert dispatcher for {}: {}",
110 self.profile_id(),
111 e
112 );
113 })?;
114
115 alert_dispatcher
116 .process_alerts(&comparison_alert)
117 .await
118 .inspect_err(|e| {
119 error!("Error processing alerts for {}: {}", self.profile_id(), e);
120 })?;
121
122 Ok(Some(vec![AlertMap::GenAI(comparison_alert)]))
124 }
125
126 pub async fn check_for_alerts(
128 &self,
129 db_pool: &Pool<Postgres>,
130 previous_run: &DateTime<Utc>,
131 ) -> Result<Option<Vec<AlertMap>>, DriftError> {
132 let Some(metric_value) = self.get_metric_value(previous_run, db_pool).await? else {
133 return Ok(None);
134 };
135
136 self.generate_alerts(metric_value).await.inspect_err(|e| {
137 error!("Error generating alerts for {}: {}", self.profile_id(), e);
138 })
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use potato_head::mock::create_score_prompt;
    use scouter_types::genai::{ComparisonOperator, EvaluationTasks};
    use scouter_types::genai::{GenAIAlertConfig, GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask};
    use scouter_types::{
        AlertCondition, AlertDispatchConfig, AlertThreshold, ConsoleDispatchConfig,
    };
    use serde_json::Value;

    /// Baseline value configured on the test alert condition.
    const BASELINE: f64 = 5.0;
    /// Delta configured on the test alert condition.
    const DELTA: f64 = 1.0;

    /// Builds a drifter whose profile has two LLM-judge tasks and an alert condition.
    ///
    /// The condition uses `AlertThreshold::Below`, `BASELINE` and `DELTA`, and dispatches
    /// to the console.
    async fn get_test_drifter() -> GenAIDrifter {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_condition = AlertCondition {
            baseline_value: BASELINE,
            alert_threshold: AlertThreshold::Below,
            delta: Some(DELTA),
        };
        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Console(ConsoleDispatchConfig { enabled: true }),
            alert_condition: Some(alert_condition),
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        GenAIDrifter::new(profile)
    }

    #[tokio::test]
    async fn test_generate_alerts_triggers_when_threshold_exceeded() {
        let drifter = get_test_drifter().await;

        // 3.0 is below BASELINE - DELTA (= 4.0), so the Below condition should fire.
        let observed_value = 3.0;
        let alerts = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully")
            .expect("Should generate alerts for out-of-bounds value");

        // Guard the indexing below and pin that only one alert is produced.
        assert_eq!(alerts.len(), 1, "Expected exactly one alert");

        match &alerts[0] {
            AlertMap::GenAI(alert) => {
                assert_eq!(alert.metric_name, "genai_workflow_metric");
                assert_eq!(alert.observed_value, observed_value);
                assert_eq!(alert.baseline_value, BASELINE);
                assert_eq!(alert.delta, Some(DELTA));
            }
            _ => panic!("Expected GenAI alert map"),
        }
    }

    #[tokio::test]
    async fn test_generate_alerts_no_trigger_within_threshold() {
        let drifter = get_test_drifter().await;

        // Observing exactly the baseline must not trigger an alert.
        let observed_value = BASELINE;
        let alerts = drifter
            .generate_alerts(observed_value)
            .await
            .expect("Should generate alerts successfully");

        assert!(
            alerts.is_none(),
            "Should not generate alerts for value within threshold"
        );
    }
}