// scouter_types/llm/alert.rs

1use crate::error::TypeError;
2use crate::{
3    dispatch::AlertDispatchType, AlertDispatchConfig, AlertThreshold, CommonCrons,
4    DispatchAlertDescription, OpsGenieDispatchConfig, PyHelperFuncs, SlackDispatchConfig,
5    ValidateAlertConfig,
6};
7use core::fmt::Debug;
8use potato_head::prompt::ResponseType;
9use potato_head::Prompt;
10use pyo3::prelude::*;
11use pyo3::types::PyString;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15#[pyclass]
16#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
17pub struct LLMDriftMetric {
18    #[pyo3(get, set)]
19    pub name: String,
20
21    #[pyo3(get, set)]
22    pub value: f64,
23
24    #[pyo3(get)]
25    pub prompt: Option<Prompt>,
26
27    #[pyo3(get, set)]
28    pub alert_condition: LLMMetricAlertCondition,
29}
30
31#[pymethods]
32impl LLMDriftMetric {
33    #[new]
34    #[pyo3(signature = (name, value, alert_threshold, alert_threshold_value=None, prompt=None))]
35    pub fn new(
36        name: &str,
37        value: f64,
38        alert_threshold: AlertThreshold,
39        alert_threshold_value: Option<f64>,
40        prompt: Option<Prompt>,
41    ) -> Result<Self, TypeError> {
42        // assert that the prompt is a scoring prompt
43        if let Some(ref prompt) = prompt {
44            if prompt.response_type != ResponseType::Score {
45                return Err(TypeError::InvalidResponseType);
46            }
47        }
48
49        let prompt_condition = LLMMetricAlertCondition::new(alert_threshold, alert_threshold_value);
50
51        Ok(Self {
52            name: name.to_lowercase(),
53            value,
54            prompt,
55            alert_condition: prompt_condition,
56        })
57    }
58
59    pub fn __str__(&self) -> String {
60        // serialize the struct to a string
61        PyHelperFuncs::__str__(self)
62    }
63
64    #[getter]
65    pub fn alert_threshold(&self) -> AlertThreshold {
66        self.alert_condition.alert_threshold.clone()
67    }
68
69    #[getter]
70    pub fn alert_threshold_value(&self) -> Option<f64> {
71        self.alert_condition.alert_threshold_value
72    }
73}
74
75#[pyclass]
76#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
77pub struct LLMMetricAlertCondition {
78    #[pyo3(get, set)]
79    pub alert_threshold: AlertThreshold,
80
81    #[pyo3(get, set)]
82    pub alert_threshold_value: Option<f64>,
83}
84
85#[pymethods]
86#[allow(clippy::too_many_arguments)]
87impl LLMMetricAlertCondition {
88    #[new]
89    #[pyo3(signature = (alert_threshold, alert_threshold_value=None))]
90    pub fn new(alert_threshold: AlertThreshold, alert_threshold_value: Option<f64>) -> Self {
91        Self {
92            alert_threshold,
93            alert_threshold_value,
94        }
95    }
96
97    pub fn __str__(&self) -> String {
98        // serialize the struct to a string
99        PyHelperFuncs::__str__(self)
100    }
101}
102
103#[pyclass]
104#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
105pub struct LLMAlertConfig {
106    pub dispatch_config: AlertDispatchConfig,
107
108    #[pyo3(get, set)]
109    pub schedule: String,
110
111    #[pyo3(get, set)]
112    pub alert_conditions: Option<HashMap<String, LLMMetricAlertCondition>>,
113}
114
115impl LLMAlertConfig {
116    pub fn set_alert_conditions(&mut self, metrics: &[LLMDriftMetric]) {
117        self.alert_conditions = Some(
118            metrics
119                .iter()
120                .map(|m| (m.name.clone(), m.alert_condition.clone()))
121                .collect(),
122        );
123    }
124}
125
126impl ValidateAlertConfig for LLMAlertConfig {}
127
128#[pymethods]
129impl LLMAlertConfig {
130    #[new]
131    #[pyo3(signature = (schedule=None, dispatch_config=None))]
132    pub fn new(
133        schedule: Option<&Bound<'_, PyAny>>,
134        dispatch_config: Option<&Bound<'_, PyAny>>,
135    ) -> Result<Self, TypeError> {
136        let alert_dispatch_config = match dispatch_config {
137            None => AlertDispatchConfig::default(),
138            Some(config) => {
139                if config.is_instance_of::<SlackDispatchConfig>() {
140                    AlertDispatchConfig::Slack(config.extract::<SlackDispatchConfig>()?)
141                } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
142                    AlertDispatchConfig::OpsGenie(config.extract::<OpsGenieDispatchConfig>()?)
143                } else {
144                    AlertDispatchConfig::default()
145                }
146            }
147        };
148
149        let schedule = match schedule {
150            Some(schedule) => {
151                if schedule.is_instance_of::<PyString>() {
152                    schedule.to_string()
153                } else if schedule.is_instance_of::<CommonCrons>() {
154                    schedule.extract::<CommonCrons>().unwrap().cron()
155                } else {
156                    return Err(TypeError::InvalidScheduleError)?;
157                }
158            }
159            None => CommonCrons::EveryDay.cron(),
160        };
161
162        let schedule = Self::resolve_schedule(&schedule);
163
164        Ok(Self {
165            schedule,
166            dispatch_config: alert_dispatch_config,
167            alert_conditions: None,
168        })
169    }
170
171    #[getter]
172    pub fn dispatch_type(&self) -> AlertDispatchType {
173        self.dispatch_config.dispatch_type()
174    }
175
176    #[getter]
177    pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
178        self.dispatch_config.config(py)
179    }
180}
181
182impl Default for LLMAlertConfig {
183    fn default() -> LLMAlertConfig {
184        Self {
185            dispatch_config: AlertDispatchConfig::default(),
186            schedule: CommonCrons::EveryDay.cron(),
187            alert_conditions: None,
188        }
189    }
190}
191
192pub struct PromptComparisonMetricAlert {
193    pub metric_name: String,
194    pub training_metric_value: f64,
195    pub observed_metric_value: f64,
196    pub alert_threshold_value: Option<f64>,
197    pub alert_threshold: AlertThreshold,
198}
199
200impl PromptComparisonMetricAlert {
201    fn alert_description_header(&self) -> String {
202        let below_threshold = |boundary: Option<f64>| match boundary {
203            Some(b) => format!(
204                "The observed {} metric value has dropped below the threshold (initial value - {})",
205                self.metric_name, b
206            ),
207            None => format!(
208                "The {} metric value has dropped below the initial value",
209                self.metric_name
210            ),
211        };
212
213        let above_threshold = |boundary: Option<f64>| match boundary {
214            Some(b) => format!(
215                "The {} metric value has increased beyond the threshold (initial value + {})",
216                self.metric_name, b
217            ),
218            None => format!(
219                "The {} metric value has increased beyond the initial value",
220                self.metric_name
221            ),
222        };
223
224        let outside_threshold = |boundary: Option<f64>| match boundary {
225            Some(b) => format!(
226                "The {} metric value has fallen outside the threshold (initial value ± {})",
227                self.metric_name, b,
228            ),
229            None => format!(
230                "The metric value has fallen outside the initial value for {}",
231                self.metric_name
232            ),
233        };
234
235        match self.alert_threshold {
236            AlertThreshold::Below => below_threshold(self.alert_threshold_value),
237            AlertThreshold::Above => above_threshold(self.alert_threshold_value),
238            AlertThreshold::Outside => outside_threshold(self.alert_threshold_value),
239        }
240    }
241}
242
243impl DispatchAlertDescription for PromptComparisonMetricAlert {
244    // TODO make pretty per dispatch type
245    fn create_alert_description(&self, _dispatch_type: AlertDispatchType) -> String {
246        let mut alert_description = String::new();
247        let header = format!("{}\n", self.alert_description_header());
248        alert_description.push_str(&header);
249
250        let current_metric = format!("Current Metric Value: {}\n", self.observed_metric_value);
251        let historical_metric = format!("Initial Metric Value: {}\n", self.training_metric_value);
252
253        alert_description.push_str(&historical_metric);
254        alert_description.push_str(&current_metric);
255
256        alert_description
257    }
258}
259
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
    use super::*;
    use potato_head::create_score_prompt;

    /// Checks the dispatch type of an OpsGenie-backed config and that
    /// `set_alert_conditions` records one condition per metric name.
    #[test]
    fn test_alert_config() {
        // Build an OpsGenie alert config; remaining fields use the defaults.
        let mut alert_config = LLMAlertConfig {
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            schedule: "0 0 * * * *".to_string(),
            ..Default::default()
        };
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::OpsGenie);

        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let mae = LLMDriftMetric::new(
            "mae",
            12.4,
            AlertThreshold::Above,
            Some(2.3),
            Some(prompt.clone()),
        )
        .unwrap();
        let accuracy = LLMDriftMetric::new(
            "accuracy",
            0.85,
            AlertThreshold::Below,
            None,
            Some(prompt.clone()),
        )
        .unwrap();
        let llm_metrics = vec![mae, accuracy];

        alert_config.set_alert_conditions(&llm_metrics);

        let alert_conditions = alert_config
            .alert_conditions
            .as_ref()
            .expect("alert_conditions should not be None");

        let mae_condition = &alert_conditions["mae"];
        assert_eq!(mae_condition.alert_threshold, AlertThreshold::Above);
        assert_eq!(mae_condition.alert_threshold_value, Some(2.3));

        let accuracy_condition = &alert_conditions["accuracy"];
        assert_eq!(accuracy_condition.alert_threshold, AlertThreshold::Below);
        assert_eq!(accuracy_condition.alert_threshold_value, None);
    }
}