scouter_drift/llm/
drift.rs

1use crate::error::DriftError;
2use chrono::{DateTime, Utc};
3use scouter_dispatch::AlertDispatcher;
4use scouter_sql::sql::traits::LLMDriftSqlLogic;
5use scouter_sql::PostgresClient;
6use scouter_types::contracts::ServiceInfo;
7use scouter_types::{custom::ComparisonMetricAlert, llm::LLMDriftProfile, AlertThreshold};
8use sqlx::{Pool, Postgres};
9use std::collections::{BTreeMap, HashMap};
10use tracing::error;
11use tracing::info;
12
13pub struct LLMDrifter {
14    service_info: ServiceInfo,
15    profile: LLMDriftProfile,
16}
17
18impl LLMDrifter {
19    pub fn new(profile: LLMDriftProfile) -> Self {
20        Self {
21            service_info: ServiceInfo {
22                name: profile.config.name.clone(),
23                space: profile.config.space.clone(),
24                version: profile.config.version.clone(),
25            },
26            profile,
27        }
28    }
29
30    pub async fn get_observed_llm_metric_values(
31        &self,
32        limit_datetime: &DateTime<Utc>,
33        db_pool: &Pool<Postgres>,
34    ) -> Result<HashMap<String, f64>, DriftError> {
35        let metrics: Vec<String> = self
36            .profile
37            .metrics
38            .iter()
39            .map(|metric| metric.name.clone())
40            .collect();
41
42        Ok(PostgresClient::get_llm_metric_values(
43            db_pool,
44            &self.service_info,
45            limit_datetime,
46            &metrics,
47        )
48        .await
49        .inspect_err(|e| {
50            let msg = format!(
51                "Error: Unable to obtain llm metric data from DB for {}/{}/{}: {}",
52                self.service_info.space, self.service_info.name, self.service_info.version, e
53            );
54            error!(msg);
55        })?)
56    }
57
58    pub async fn get_metric_map(
59        &self,
60        limit_datetime: &DateTime<Utc>,
61        db_pool: &Pool<Postgres>,
62    ) -> Result<Option<HashMap<String, f64>>, DriftError> {
63        let metric_map = self
64            .get_observed_llm_metric_values(limit_datetime, db_pool)
65            .await?;
66
67        if metric_map.is_empty() {
68            info!(
69                "No llm metric data was found for {}/{}/{}. Skipping alert processing.",
70                self.service_info.space, self.service_info.name, self.service_info.version,
71            );
72            return Ok(None);
73        }
74
75        Ok(Some(metric_map))
76    }
77
78    fn is_out_of_bounds(
79        training_value: f64,
80        observed_value: f64,
81        alert_condition: &AlertThreshold,
82        alert_boundary: Option<f64>,
83    ) -> bool {
84        if observed_value == training_value {
85            return false;
86        }
87
88        let below_threshold = |boundary: Option<f64>| match boundary {
89            Some(b) => observed_value < training_value - b,
90            None => observed_value < training_value,
91        };
92
93        let above_threshold = |boundary: Option<f64>| match boundary {
94            Some(b) => observed_value > training_value + b,
95            None => observed_value > training_value,
96        };
97
98        match alert_condition {
99            AlertThreshold::Below => below_threshold(alert_boundary),
100            AlertThreshold::Above => above_threshold(alert_boundary),
101            AlertThreshold::Outside => {
102                below_threshold(alert_boundary) || above_threshold(alert_boundary)
103            } // Handled by early equality check
104        }
105    }
106
107    pub async fn generate_alerts(
108        &self,
109        metric_map: &HashMap<String, f64>,
110    ) -> Result<Option<Vec<ComparisonMetricAlert>>, DriftError> {
111        let metric_alerts: Vec<ComparisonMetricAlert> = metric_map
112            .iter()
113            .filter_map(|(name, observed_value)| {
114                let training_value = self
115                    .profile
116                    .get_metric_value(name)
117                    .inspect_err(|e| {
118                        let msg = format!("Error getting training value for metric {name}: {e}");
119                        error!(msg);
120                    })
121                    .ok()?;
122                let alert_condition = &self
123                    .profile
124                    .config
125                    .alert_config
126                    .alert_conditions
127                    .as_ref()
128                    .unwrap()[name];
129                if Self::is_out_of_bounds(
130                    training_value,
131                    *observed_value,
132                    &alert_condition.alert_threshold,
133                    alert_condition.alert_threshold_value,
134                ) {
135                    Some(ComparisonMetricAlert {
136                        metric_name: name.clone(),
137                        training_metric_value: training_value,
138                        observed_metric_value: *observed_value,
139                        alert_threshold_value: alert_condition.alert_threshold_value,
140                        alert_threshold: alert_condition.alert_threshold.clone(),
141                    })
142                } else {
143                    None
144                }
145            })
146            .collect();
147
148        if metric_alerts.is_empty() {
149            info!(
150                "No alerts to process for {}/{}/{}",
151                self.service_info.space, self.service_info.name, self.service_info.version
152            );
153            return Ok(None);
154        }
155
156        let alert_dispatcher = AlertDispatcher::new(&self.profile.config).inspect_err(|e| {
157            let msg = format!(
158                "Error creating alert dispatcher for {}/{}/{}: {}",
159                self.service_info.space, self.service_info.name, self.service_info.version, e
160            );
161            error!(msg);
162        })?;
163
164        for alert in &metric_alerts {
165            alert_dispatcher
166                .process_alerts(alert)
167                .await
168                .inspect_err(|e| {
169                    let msg = format!(
170                        "Error processing alerts for {}/{}/{}: {}",
171                        self.service_info.space,
172                        self.service_info.name,
173                        self.service_info.version,
174                        e
175                    );
176                    error!(msg);
177                })?;
178        }
179
180        Ok(Some(metric_alerts))
181    }
182
183    fn organize_alerts(mut alerts: Vec<ComparisonMetricAlert>) -> Vec<BTreeMap<String, String>> {
184        let mut alert_vec = Vec::new();
185        alerts.iter_mut().for_each(|alert| {
186            let mut alert_map = BTreeMap::new();
187            alert_map.insert("entity_name".to_string(), alert.metric_name.clone());
188            alert_map.insert(
189                "training_metric_value".to_string(),
190                alert.training_metric_value.to_string(),
191            );
192            alert_map.insert(
193                "observed_metric_value".to_string(),
194                alert.observed_metric_value.to_string(),
195            );
196            let alert_threshold_value_str = match alert.alert_threshold_value {
197                Some(value) => value.to_string(),
198                None => "None".to_string(),
199            };
200            alert_map.insert(
201                "alert_threshold_value".to_string(),
202                alert_threshold_value_str,
203            );
204            alert_map.insert(
205                "alert_threshold".to_string(),
206                alert.alert_threshold.to_string(),
207            );
208            alert_vec.push(alert_map);
209        });
210
211        alert_vec
212    }
213
214    pub async fn check_for_alerts(
215        &self,
216        db_pool: &Pool<Postgres>,
217        previous_run: DateTime<Utc>,
218    ) -> Result<Option<Vec<BTreeMap<String, String>>>, DriftError> {
219        let metric_map = self.get_metric_map(&previous_run, db_pool).await?;
220
221        match metric_map {
222            Some(metric_map) => {
223                let alerts = self.generate_alerts(&metric_map).await.inspect_err(|e| {
224                    let msg = format!(
225                        "Error generating alerts for {}/{}/{}: {}",
226                        self.service_info.space,
227                        self.service_info.name,
228                        self.service_info.version,
229                        e
230                    );
231                    error!(msg);
232                })?;
233                match alerts {
234                    Some(alerts) => Ok(Some(Self::organize_alerts(alerts))),
235                    None => Ok(None),
236                }
237            }
238            None => Ok(None),
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use potato_head::{create_score_prompt, LLMTestServer};
    use scouter_types::llm::{LLMAlertConfig, LLMDriftConfig, LLMDriftMetric, LLMDriftProfile};

    /// Builds a drifter with two metrics, both created with value 5.0:
    /// `coherence` (`Below`, boundary 0.5) and `relevancy` (`Below`, no boundary).
    async fn get_test_drifter() -> LLMDrifter {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));
        let metric1 = LLMDriftMetric::new(
            "coherence",
            5.0,
            AlertThreshold::Below,
            Some(0.5),
            Some(prompt.clone()),
        )
        .unwrap();

        let metric2 = LLMDriftMetric::new(
            "relevancy",
            5.0,
            AlertThreshold::Below,
            None,
            Some(prompt.clone()),
        )
        .unwrap();

        let alert_config = LLMAlertConfig::default();
        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = LLMDriftProfile::from_metrics(drift_config, vec![metric1, metric2])
            .await
            .unwrap();

        LLMDrifter::new(profile)
    }

    #[test]
    fn test_is_out_of_bounds() {
        // Below with a boundary: alert only once the observed value drops more than
        // the boundary under the training value (here under 5.0 - 0.5 = 4.5).
        assert!(LLMDrifter::is_out_of_bounds(
            5.0,
            4.0,
            &AlertThreshold::Below,
            Some(0.5)
        ));
        assert!(!LLMDrifter::is_out_of_bounds(
            5.0,
            4.75,
            &AlertThreshold::Below,
            Some(0.5)
        ));

        // Below without a boundary: any decrease alerts, an increase never does.
        assert!(LLMDrifter::is_out_of_bounds(
            0.76,
            0.67,
            &AlertThreshold::Below,
            None
        ));
        assert!(!LLMDrifter::is_out_of_bounds(
            0.76,
            0.9,
            &AlertThreshold::Below,
            None
        ));

        // Above mirrors Below: alert only past 5.0 + 0.5 = 5.5.
        assert!(LLMDrifter::is_out_of_bounds(
            5.0,
            6.0,
            &AlertThreshold::Above,
            Some(0.5)
        ));
        assert!(!LLMDrifter::is_out_of_bounds(
            5.0,
            5.25,
            &AlertThreshold::Above,
            Some(0.5)
        ));

        // Outside alerts on a large enough move in either direction.
        assert!(LLMDrifter::is_out_of_bounds(
            5.0,
            4.0,
            &AlertThreshold::Outside,
            Some(0.5)
        ));
        assert!(LLMDrifter::is_out_of_bounds(
            5.0,
            6.0,
            &AlertThreshold::Outside,
            Some(0.5)
        ));
        assert!(!LLMDrifter::is_out_of_bounds(
            5.0,
            5.25,
            &AlertThreshold::Outside,
            Some(0.5)
        ));

        // An unchanged value never alerts.
        assert!(!LLMDrifter::is_out_of_bounds(
            5.0,
            5.0,
            &AlertThreshold::Outside,
            None
        ));
    }

    #[test]
    fn test_generate_llm_alerts() {
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let runtime = tokio::runtime::Runtime::new().unwrap();

        let mut metric_map = HashMap::new();
        // coherence: 4.0 is below 5.0 - 0.5 = 4.5, so it alerts.
        metric_map.insert("coherence".to_string(), 4.0);
        // relevancy: no boundary, so any drop below 5.0 alerts.
        metric_map.insert("relevancy".to_string(), 4.5);

        let alerts = runtime.block_on(async {
            let drifter = get_test_drifter().await;
            drifter.generate_alerts(&metric_map).await.unwrap().unwrap()
        });

        // Stop the mock before asserting so a failed assertion doesn't leave it running.
        mock.stop_server().unwrap();

        assert_eq!(alerts.len(), 2);
        let mut names: Vec<&str> = alerts.iter().map(|a| a.metric_name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["coherence", "relevancy"]);
    }
}