1use crate::error::TypeError;
2use crate::{
3 dispatch::AlertDispatchType, AlertDispatchConfig, AlertThreshold, CommonCrons,
4 DispatchAlertDescription, OpsGenieDispatchConfig, PyHelperFuncs, SlackDispatchConfig,
5 ValidateAlertConfig,
6};
7use core::fmt::Debug;
8use potato_head::prompt::ResponseType;
9use potato_head::Prompt;
10use pyo3::prelude::*;
11use pyo3::types::PyString;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15#[pyclass]
16#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
17pub struct LLMDriftMetric {
18 #[pyo3(get, set)]
19 pub name: String,
20
21 #[pyo3(get, set)]
22 pub value: f64,
23
24 #[pyo3(get)]
25 pub prompt: Option<Prompt>,
26
27 #[pyo3(get, set)]
28 pub alert_condition: LLMMetricAlertCondition,
29}
30
31#[pymethods]
32impl LLMDriftMetric {
33 #[new]
34 #[pyo3(signature = (name, value, alert_threshold, alert_threshold_value=None, prompt=None))]
35 pub fn new(
36 name: &str,
37 value: f64,
38 alert_threshold: AlertThreshold,
39 alert_threshold_value: Option<f64>,
40 prompt: Option<Prompt>,
41 ) -> Result<Self, TypeError> {
42 if let Some(ref prompt) = prompt {
44 if prompt.response_type != ResponseType::Score {
45 return Err(TypeError::InvalidResponseType);
46 }
47 }
48
49 let prompt_condition = LLMMetricAlertCondition::new(alert_threshold, alert_threshold_value);
50
51 Ok(Self {
52 name: name.to_lowercase(),
53 value,
54 prompt,
55 alert_condition: prompt_condition,
56 })
57 }
58
59 pub fn __str__(&self) -> String {
60 PyHelperFuncs::__str__(self)
62 }
63
64 #[getter]
65 pub fn alert_threshold(&self) -> AlertThreshold {
66 self.alert_condition.alert_threshold.clone()
67 }
68
69 #[getter]
70 pub fn alert_threshold_value(&self) -> Option<f64> {
71 self.alert_condition.alert_threshold_value
72 }
73}
74
75#[pyclass]
76#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
77pub struct LLMMetricAlertCondition {
78 #[pyo3(get, set)]
79 pub alert_threshold: AlertThreshold,
80
81 #[pyo3(get, set)]
82 pub alert_threshold_value: Option<f64>,
83}
84
85#[pymethods]
86#[allow(clippy::too_many_arguments)]
87impl LLMMetricAlertCondition {
88 #[new]
89 #[pyo3(signature = (alert_threshold, alert_threshold_value=None))]
90 pub fn new(alert_threshold: AlertThreshold, alert_threshold_value: Option<f64>) -> Self {
91 Self {
92 alert_threshold,
93 alert_threshold_value,
94 }
95 }
96
97 pub fn __str__(&self) -> String {
98 PyHelperFuncs::__str__(self)
100 }
101}
102
103#[pyclass]
104#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
105pub struct LLMAlertConfig {
106 pub dispatch_config: AlertDispatchConfig,
107
108 #[pyo3(get, set)]
109 pub schedule: String,
110
111 #[pyo3(get, set)]
112 pub alert_conditions: Option<HashMap<String, LLMMetricAlertCondition>>,
113}
114
115impl LLMAlertConfig {
116 pub fn set_alert_conditions(&mut self, metrics: &[LLMDriftMetric]) {
117 self.alert_conditions = Some(
118 metrics
119 .iter()
120 .map(|m| (m.name.clone(), m.alert_condition.clone()))
121 .collect(),
122 );
123 }
124}
125
126impl ValidateAlertConfig for LLMAlertConfig {}
127
128#[pymethods]
129impl LLMAlertConfig {
130 #[new]
131 #[pyo3(signature = (schedule=None, dispatch_config=None))]
132 pub fn new(
133 schedule: Option<&Bound<'_, PyAny>>,
134 dispatch_config: Option<&Bound<'_, PyAny>>,
135 ) -> Result<Self, TypeError> {
136 let alert_dispatch_config = match dispatch_config {
137 None => AlertDispatchConfig::default(),
138 Some(config) => {
139 if config.is_instance_of::<SlackDispatchConfig>() {
140 AlertDispatchConfig::Slack(config.extract::<SlackDispatchConfig>()?)
141 } else if config.is_instance_of::<OpsGenieDispatchConfig>() {
142 AlertDispatchConfig::OpsGenie(config.extract::<OpsGenieDispatchConfig>()?)
143 } else {
144 AlertDispatchConfig::default()
145 }
146 }
147 };
148
149 let schedule = match schedule {
150 Some(schedule) => {
151 if schedule.is_instance_of::<PyString>() {
152 schedule.to_string()
153 } else if schedule.is_instance_of::<CommonCrons>() {
154 schedule.extract::<CommonCrons>().unwrap().cron()
155 } else {
156 return Err(TypeError::InvalidScheduleError)?;
157 }
158 }
159 None => CommonCrons::EveryDay.cron(),
160 };
161
162 let schedule = Self::resolve_schedule(&schedule);
163
164 Ok(Self {
165 schedule,
166 dispatch_config: alert_dispatch_config,
167 alert_conditions: None,
168 })
169 }
170
171 #[getter]
172 pub fn dispatch_type(&self) -> AlertDispatchType {
173 self.dispatch_config.dispatch_type()
174 }
175
176 #[getter]
177 pub fn dispatch_config<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
178 self.dispatch_config.config(py)
179 }
180}
181
182impl Default for LLMAlertConfig {
183 fn default() -> LLMAlertConfig {
184 Self {
185 dispatch_config: AlertDispatchConfig::default(),
186 schedule: CommonCrons::EveryDay.cron(),
187 alert_conditions: None,
188 }
189 }
190}
191
192pub struct PromptComparisonMetricAlert {
193 pub metric_name: String,
194 pub training_metric_value: f64,
195 pub observed_metric_value: f64,
196 pub alert_threshold_value: Option<f64>,
197 pub alert_threshold: AlertThreshold,
198}
199
200impl PromptComparisonMetricAlert {
201 fn alert_description_header(&self) -> String {
202 let below_threshold = |boundary: Option<f64>| match boundary {
203 Some(b) => format!(
204 "The observed {} metric value has dropped below the threshold (initial value - {})",
205 self.metric_name, b
206 ),
207 None => format!(
208 "The {} metric value has dropped below the initial value",
209 self.metric_name
210 ),
211 };
212
213 let above_threshold = |boundary: Option<f64>| match boundary {
214 Some(b) => format!(
215 "The {} metric value has increased beyond the threshold (initial value + {})",
216 self.metric_name, b
217 ),
218 None => format!(
219 "The {} metric value has increased beyond the initial value",
220 self.metric_name
221 ),
222 };
223
224 let outside_threshold = |boundary: Option<f64>| match boundary {
225 Some(b) => format!(
226 "The {} metric value has fallen outside the threshold (initial value ± {})",
227 self.metric_name, b,
228 ),
229 None => format!(
230 "The metric value has fallen outside the initial value for {}",
231 self.metric_name
232 ),
233 };
234
235 match self.alert_threshold {
236 AlertThreshold::Below => below_threshold(self.alert_threshold_value),
237 AlertThreshold::Above => above_threshold(self.alert_threshold_value),
238 AlertThreshold::Outside => outside_threshold(self.alert_threshold_value),
239 }
240 }
241}
242
243impl DispatchAlertDescription for PromptComparisonMetricAlert {
244 fn create_alert_description(&self, _dispatch_type: AlertDispatchType) -> String {
246 let mut alert_description = String::new();
247 let header = format!("{}\n", self.alert_description_header());
248 alert_description.push_str(&header);
249
250 let current_metric = format!("Current Metric Value: {}\n", self.observed_metric_value);
251 let historical_metric = format!("Initial Metric Value: {}\n", self.training_metric_value);
252
253 alert_description.push_str(&historical_metric);
254 alert_description.push_str(¤t_metric);
255
256 alert_description
257 }
258}
259
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
    use super::*;
    use potato_head::create_score_prompt;

    #[test]
    fn test_alert_config() {
        let mut alert_config = LLMAlertConfig {
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            schedule: "0 0 * * * *".to_string(),
            ..Default::default()
        };
        assert_eq!(alert_config.dispatch_type(), AlertDispatchType::OpsGenie);

        // Both metrics share one score prompt; only their alert settings differ.
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));
        let llm_metrics = vec![
            LLMDriftMetric::new(
                "mae",
                12.4,
                AlertThreshold::Above,
                Some(2.3),
                Some(prompt.clone()),
            )
            .unwrap(),
            LLMDriftMetric::new(
                "accuracy",
                0.85,
                AlertThreshold::Below,
                None,
                Some(prompt.clone()),
            )
            .unwrap(),
        ];

        alert_config.set_alert_conditions(&llm_metrics);

        let conditions = alert_config
            .alert_conditions
            .as_ref()
            .expect("alert_conditions should not be None");

        // (metric name, expected threshold direction, expected threshold value)
        let expected = [
            ("mae", AlertThreshold::Above, Some(2.3)),
            ("accuracy", AlertThreshold::Below, None),
        ];
        for (name, threshold, threshold_value) in expected.iter() {
            let condition = &conditions[*name];
            assert_eq!(condition.alert_threshold, *threshold);
            assert_eq!(condition.alert_threshold_value, *threshold_value);
        }
    }
}