use crate::error::{ProfileError, TypeError};
use crate::llm::alert::{LLMAlertConfig, LLMDriftMetric};
use crate::util::{json_to_pyobject, pyobject_to_json};
use crate::ProfileRequest;
use crate::{scouter_version, LLMMetricRecord};
use crate::{
    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
};
use core::fmt::Debug;
use potato_head::prompt::ResponseType;
use potato_head::{Agent, Prompt, Task, Workflow};
use pyo3::prelude::*;
use pyo3::types::PyDict;
use scouter_semver::VersionType;
use scouter_state::app_state;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::{debug, error, instrument};

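/// Configuration for LLM drift detection. Identifies the profile by
/// `space`/`name`/`version`, controls how often records are sampled via
/// `sample_rate`, and carries the alerting setup in `alert_config`.
///
/// A minimal construction sketch (marked `ignore`; the literal values are
/// illustrative placeholders, not crate defaults):
///
/// ```ignore
/// let config = LLMDriftConfig::new(
///     "my_space",
///     "my_model",
///     "0.1.0",
///     10, // sample rate
///     LLMAlertConfig::default(),
///     None, // or Some(path) to load the full config from a JSON file
/// )?;
/// ```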
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftConfig {
    #[pyo3(get, set)]
    pub sample_rate: usize,

    #[pyo3(get, set)]
    pub space: String,

    #[pyo3(get, set)]
    pub name: String,

    #[pyo3(get, set)]
    pub version: String,

    #[pyo3(get, set)]
    pub alert_config: LLMAlertConfig,

    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}

fn default_drift_type() -> DriftType {
    DriftType::LLM
}

impl DispatchDriftConfig for LLMDriftConfig {
    fn get_drift_args(&self) -> DriftArgs {
        DriftArgs {
            name: self.name.clone(),
            space: self.space.clone(),
            version: self.version.clone(),
            dispatch_config: self.alert_config.dispatch_config.clone(),
        }
    }
}

#[pymethods]
#[allow(clippy::too_many_arguments)]
impl LLMDriftConfig {
    #[new]
    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_rate=5, alert_config=LLMAlertConfig::default(), config_path=None))]
    pub fn new(
        space: &str,
        name: &str,
        version: &str,
        sample_rate: usize,
        alert_config: LLMAlertConfig,
        config_path: Option<PathBuf>,
    ) -> Result<Self, ProfileError> {
        if let Some(config_path) = config_path {
            let config = LLMDriftConfig::load_from_json_file(config_path)?;
            return Ok(config);
        }

        Ok(Self {
            sample_rate,
            space: space.to_string(),
            name: name.to_string(),
            version: version.to_string(),
            alert_config,
            drift_type: DriftType::LLM,
        })
    }

    #[staticmethod]
    pub fn load_from_json_file(path: PathBuf) -> Result<LLMDriftConfig, ProfileError> {
        let file = std::fs::read_to_string(&path)?;

        Ok(serde_json::from_str(&file)?)
    }

    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    pub fn model_dump_json(&self) -> String {
        PyHelperFuncs::__json__(self)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
    pub fn update_config_args(
        &mut self,
        space: Option<String>,
        name: Option<String>,
        version: Option<String>,
        alert_config: Option<LLMAlertConfig>,
    ) -> Result<(), TypeError> {
        if let Some(name) = name {
            self.name = name;
        }

        if let Some(space) = space {
            self.space = space;
        }

        if let Some(version) = version {
            self.version = version;
        }

        if let Some(alert_config) = alert_config {
            self.alert_config = alert_config;
        }

        Ok(())
    }
}

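/// Ensures a prompt declares at least one bound parameter. Entry-point tasks
/// in a drift workflow must accept input (e.g. `${input}`); otherwise there is
/// nothing for the evaluation to bind sampled records to.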
fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
    if prompt.parameters.is_empty() {
        return Err(ProfileError::NeedAtLeastOneBoundParameterError(
            id.to_string(),
        ));
    }

    Ok(())
}

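/// Ensures a prompt is configured with the `Score` response type. Terminal
/// tasks must produce a score so their output can be recorded as a metric
/// value.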
fn validate_prompt_response_type(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
    if prompt.response_type != ResponseType::Score {
        return Err(ProfileError::InvalidResponseType(id.to_string()));
    }

    Ok(())
}

fn get_workflow_task<'a>(
    workflow: &'a Workflow,
    task_id: &'a str,
) -> Result<&'a Arc<std::sync::RwLock<Task>>, ProfileError> {
    workflow
        .task_list
        .tasks
        .get(task_id)
        .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
}

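/// Validates every task in the first step of the execution plan: each initial
/// prompt must declare at least one bound parameter.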
fn validate_first_tasks(
    workflow: &Workflow,
    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
) -> Result<(), ProfileError> {
    let first_tasks = execution_order
        .get(&1)
        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;

    for task_id in first_tasks {
        let task = get_workflow_task(workflow, task_id)?;
        let task_guard = task
            .read()
            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;

        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
    }

    Ok(())
}

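/// Validates every task in the final step of the execution plan: each
/// terminal task must have a matching `LLMDriftMetric` (by name), and its
/// prompt must use the `Score` response type.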
fn validate_last_tasks(
    workflow: &Workflow,
    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
    metrics: &[LLMDriftMetric],
) -> Result<(), ProfileError> {
    let last_step = execution_order.len() as i32;
    let last_tasks = execution_order
        .get(&last_step)
        .ok_or_else(|| ProfileError::NoTasksFoundError("No final tasks found".to_string()))?;

    for task_id in last_tasks {
        if !metrics.iter().any(|m| m.name == *task_id) {
            return Err(ProfileError::MetricNotFoundForOutputTask(task_id.clone()));
        }

        let task = get_workflow_task(workflow, task_id)?;
        let task_guard = task
            .read()
            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;

        validate_prompt_response_type(&task_guard.prompt, &task_guard.id)?;
    }

    Ok(())
}

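/// Runs the structural checks on a workflow before it is accepted into a
/// profile: initial tasks take parameters, and final tasks map onto metrics
/// and return scores.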
fn validate_workflow(workflow: &Workflow, metrics: &[LLMDriftMetric]) -> Result<(), ProfileError> {
    let execution_order = workflow.execution_plan()?;

    validate_first_tasks(workflow, &execution_order)?;
    validate_last_tasks(workflow, &execution_order, metrics)?;

    Ok(())
}

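/// A drift profile for LLM evaluation. Wraps the evaluation `Workflow`, the
/// metrics to alert on, and the `LLMDriftConfig` used to register the profile
/// with the server.
///
/// A profile is built either from metrics alone (a single-step workflow is
/// assembled from each metric's prompt) or from an explicit workflow whose
/// final tasks correspond to the metrics by name. A sketch of the
/// metrics-only path (marked `ignore`; metric values and the prompt are
/// placeholders):
///
/// ```ignore
/// let metric = LLMDriftMetric::new(
///     "relevance",
///     4.0, // baseline value
///     AlertThreshold::Below,
///     None,
///     Some(score_prompt), // a Prompt with ResponseType::Score
/// )?;
/// let profile = LLMDriftProfile::from_metrics(config, vec![metric]).await?;
/// ```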
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftProfile {
    #[pyo3(get)]
    pub config: LLMDriftConfig,

    #[pyo3(get)]
    pub metrics: Vec<LLMDriftMetric>,

    #[pyo3(get)]
    pub scouter_version: String,

    pub workflow: Workflow,

    #[pyo3(get)]
    pub metric_names: Vec<String>,
}

#[pymethods]
impl LLMDriftProfile {
    #[new]
    #[pyo3(signature = (config, metrics, workflow=None))]
    #[instrument(skip_all)]
    pub fn new(
        config: LLMDriftConfig,
        metrics: Vec<LLMDriftMetric>,
        workflow: Option<Bound<'_, PyAny>>,
    ) -> Result<Self, ProfileError> {
        match workflow {
            Some(py_workflow) => {
                let workflow = Self::extract_workflow(&py_workflow).map_err(|e| {
                    error!("Failed to extract workflow: {}", e);
                    e
                })?;
                validate_workflow(&workflow, &metrics)?;
                Self::from_workflow(config, workflow, metrics)
            }
            None => {
                if metrics.is_empty() {
                    return Err(ProfileError::EmptyMetricsList);
                }
                app_state().block_on(async { Self::from_metrics(config, metrics).await })
            }
        }
    }

    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    pub fn model_dump_json(&self) -> String {
        PyHelperFuncs::__json__(self)
    }

    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
        let json_str = serde_json::to_string(&self)?;
        let json_value: Value = serde_json::from_str(&json_str)?;

        let dict = PyDict::new(py);
        json_to_pyobject(py, &json_value, &dict)?;

        Ok(dict.into())
    }

    #[pyo3(signature = (path=None))]
    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
        Ok(PyHelperFuncs::save_to_json(
            self,
            path,
            FileName::LLMDriftProfile.to_str(),
        )?)
    }

    #[staticmethod]
    pub fn model_validate(data: &Bound<'_, PyDict>) -> LLMDriftProfile {
        let json_value = pyobject_to_json(data).unwrap();
        let string = serde_json::to_string(&json_value).unwrap();
        serde_json::from_str(&string).expect("Failed to load drift profile")
    }

    #[staticmethod]
    pub fn model_validate_json(json_string: String) -> LLMDriftProfile {
        serde_json::from_str(&json_string).expect("Failed to load LLM drift profile")
    }

    #[staticmethod]
    pub fn from_file(path: PathBuf) -> Result<LLMDriftProfile, ProfileError> {
        let file = std::fs::read_to_string(&path)?;

        Ok(serde_json::from_str(&file)?)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
    pub fn update_config_args(
        &mut self,
        space: Option<String>,
        name: Option<String>,
        version: Option<String>,
        alert_config: Option<LLMAlertConfig>,
    ) -> Result<(), TypeError> {
        self.config
            .update_config_args(space, name, version, alert_config)
    }

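    /// Builds the `ProfileRequest` used to register this profile with the
    /// server. When the config still carries `DEFAULT_VERSION`, no version is
    /// sent so the server assigns the next minor version. The profile is
    /// registered as inactive.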
    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
            None
        } else {
            Some(self.config.version.clone())
        };

        Ok(ProfileRequest {
            space: self.config.space.clone(),
            profile: self.model_dump_json(),
            drift_type: self.config.drift_type.clone(),
            version_request: VersionRequest {
                version,
                version_type: VersionType::Minor,
                pre_tag: None,
                build_tag: None,
            },
            active: false,
            deactivate_others: false,
        })
    }
}

impl LLMDriftProfile {
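    /// Builds a profile directly from metrics. Each metric must carry its own
    /// evaluation prompt; a single-step workflow is assembled with one task
    /// per metric, and one agent is created (and reused) per model provider.
    ///
    /// Async because agent construction may await provider setup. Outside an
    /// async context it can be driven with a runtime, as in this sketch
    /// (marked `ignore`):
    ///
    /// ```ignore
    /// let runtime = tokio::runtime::Runtime::new()?;
    /// let profile =
    ///     runtime.block_on(LLMDriftProfile::from_metrics(config, metrics))?;
    /// ```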
    pub async fn from_metrics(
        mut config: LLMDriftConfig,
        metrics: Vec<LLMDriftMetric>,
    ) -> Result<Self, ProfileError> {
        let mut workflow = Workflow::new("llm_drift_workflow");
        let mut agents = HashMap::new();
        let mut metric_names = Vec::new();

        for metric in &metrics {
            let prompt = metric
                .prompt
                .as_ref()
                .ok_or_else(|| ProfileError::MissingPromptError(metric.name.clone()))?;

            let provider = prompt.model_settings.provider();

            // Reuse one agent per provider across all metric tasks.
            let agent = match agents.entry(provider) {
                Entry::Occupied(entry) => entry.into_mut(),
                Entry::Vacant(entry) => {
                    let agent = Agent::from_model_settings(&prompt.model_settings).await?;
                    workflow.add_agent(&agent);
                    entry.insert(agent)
                }
            };

            validate_prompt_parameters(prompt, &metric.name)?;

            let task = Task::new(&agent.id, prompt.clone(), &metric.name, None, None);
            workflow.add_task(task)?;
            metric_names.push(metric.name.clone());
        }

        config.alert_config.set_alert_conditions(&metrics);

        Ok(Self {
            config,
            metrics,
            scouter_version: scouter_version(),
            workflow,
            metric_names,
        })
    }

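    /// Builds a profile from an existing workflow. The workflow is validated
    /// (initial tasks take parameters; final tasks match `metrics` by name
    /// and return scores) before alert conditions are derived from the
    /// metrics.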
    pub fn from_workflow(
        mut config: LLMDriftConfig,
        workflow: Workflow,
        metrics: Vec<LLMDriftMetric>,
    ) -> Result<Self, ProfileError> {
        validate_workflow(&workflow, &metrics)?;

        config.alert_config.set_alert_conditions(&metrics);

        let metric_names = metrics.iter().map(|m| m.name.clone()).collect();

        Ok(Self {
            config,
            metrics,
            scouter_version: scouter_version(),
            workflow,
            metric_names,
        })
    }

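    /// Extracts a `Workflow` from a Python object. The object is expected to
    /// expose a `__workflow__` attribute holding the JSON-serialized
    /// workflow, which is deserialized into the Rust type.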
    #[instrument(skip_all)]
    fn extract_workflow(py_workflow: &Bound<'_, PyAny>) -> Result<Workflow, ProfileError> {
        debug!("Extracting workflow from Python object");

        if !py_workflow.hasattr("__workflow__")? {
            error!("Invalid workflow type provided. Expected an object with a __workflow__ attribute.");
            return Err(ProfileError::InvalidWorkflowType);
        }

        let workflow_string = py_workflow
            .getattr("__workflow__")
            .map_err(|e| {
                error!("Failed to access __workflow__ attribute: {}", e);
                ProfileError::InvalidWorkflowType
            })?
            .extract::<String>()
            .inspect_err(|e| {
                error!(
                    "Failed to extract workflow string from Python object: {}",
                    e
                );
            })?;

        serde_json::from_str(&workflow_string).map_err(|e| {
            error!("Failed to deserialize workflow: {}", e);
            ProfileError::InvalidWorkflowType
        })
    }

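    /// Returns the baseline value of the named metric, or
    /// `ProfileError::MetricNotFound` if no metric with that name exists.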
    pub fn get_metric_value(&self, metric_name: &str) -> Result<f64, ProfileError> {
        self.metrics
            .iter()
            .find(|m| m.name == metric_name)
            .map(|m| m.value)
            .ok_or_else(|| ProfileError::MetricNotFound(metric_name.to_string()))
    }

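    /// Like [`LLMDriftProfile::new`], but drives the async metrics path on a
    /// caller-supplied Tokio runtime rather than the shared application
    /// state.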
    pub fn new_with_runtime(
        config: LLMDriftConfig,
        metrics: Vec<LLMDriftMetric>,
        workflow: Option<Bound<'_, PyAny>>,
        runtime: Arc<tokio::runtime::Runtime>,
    ) -> Result<Self, ProfileError> {
        match workflow {
            Some(py_workflow) => {
                let workflow = Self::extract_workflow(&py_workflow).map_err(|e| {
                    error!("Failed to extract workflow: {}", e);
                    e
                })?;
                validate_workflow(&workflow, &metrics)?;
                Self::from_workflow(config, workflow, metrics)
            }
            None => {
                if metrics.is_empty() {
                    return Err(ProfileError::EmptyMetricsList);
                }
                runtime.block_on(async { Self::from_metrics(config, metrics).await })
            }
        }
    }
}

impl ProfileBaseArgs for LLMDriftProfile {
    fn get_base_args(&self) -> ProfileArgs {
        ProfileArgs {
            name: self.config.name.clone(),
            space: self.config.space.clone(),
            version: Some(self.config.version.clone()),
            schedule: self.config.alert_config.schedule.clone(),
            scouter_version: self.scouter_version.clone(),
            drift_type: self.config.drift_type.clone(),
        }
    }

    fn to_value(&self) -> Value {
        serde_json::to_value(self).unwrap()
    }
}

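/// A collection of `LLMMetricRecord`s produced when a drift profile's
/// metrics are computed.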
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMDriftMap {
    #[pyo3(get)]
    pub records: Vec<LLMMetricRecord>,
}

#[pymethods]
impl LLMDriftMap {
    #[new]
    pub fn new(records: Vec<LLMMetricRecord>) -> Self {
        Self { records }
    }

    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}

#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};
    use potato_head::create_score_prompt;
    use potato_head::prompt::ResponseType;
    use potato_head::Provider;
    use potato_head::{LLMTestServer, Message, PromptContent};

    pub fn create_parameterized_prompt() -> Prompt {
        let user_content =
            PromptContent::Str("What is ${input} + ${response} + ${context}?".to_string());
        let system_content = PromptContent::Str("You are a helpful assistant.".to_string());
        Prompt::new_rs(
            vec![Message::new_rs(user_content)],
            "gpt-4o",
            Provider::OpenAI,
            vec![Message::new_rs(system_content)],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap()
    }

    #[test]
    fn test_llm_drift_config() {
        let mut drift_config = LLMDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            LLMAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        drift_config
            .update_config_args(None, Some("test".to_string()), None, Some(new_alert_config))
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    #[test]
    fn test_llm_drift_profile_metric() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let metric1 = LLMDriftMetric::new(
            "metric1",
            5.0,
            AlertThreshold::Above,
            None,
            Some(prompt.clone()),
        )
        .unwrap();
        let metric2 = LLMDriftMetric::new(
            "metric2",
            3.0,
            AlertThreshold::Below,
            Some(1.0),
            Some(prompt.clone()),
        )
        .unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = runtime
            .block_on(async { LLMDriftProfile::from_metrics(drift_config, llm_metrics).await })
            .unwrap();
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        mock.stop_server().unwrap();
    }

    #[test]
    fn test_llm_drift_profile_workflow() {
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let runtime = tokio::runtime::Runtime::new().unwrap();

        let mut workflow = Workflow::new("My eval Workflow");

        let initial_prompt = create_parameterized_prompt();
        let final_prompt1 = create_score_prompt(None);
        let final_prompt2 = create_score_prompt(None);

        let agent1 = runtime
            .block_on(async { Agent::new(Provider::OpenAI, None).await })
            .unwrap();

        workflow.add_agent(&agent1);

        workflow
            .add_task(Task::new(
                &agent1.id,
                initial_prompt.clone(),
                "task1",
                None,
                None,
            ))
            .unwrap();

        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt1.clone(),
                "task2",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt2.clone(),
                "task3",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        let metric1 =
            LLMDriftMetric::new("task2", 3.0, AlertThreshold::Below, Some(1.0), None).unwrap();
        let metric2 =
            LLMDriftMetric::new("task3", 4.0, AlertThreshold::Above, Some(2.0), None).unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = LLMDriftProfile::from_workflow(drift_config, workflow, llm_metrics).unwrap();

        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let plan = profile.workflow.execution_plan().unwrap();

        assert_eq!(plan.len(), 2);
        assert_eq!(plan.get(&1).unwrap().len(), 1);
        assert_eq!(plan.get(&(plan.len() as i32)).unwrap().len(), 2);

        mock.stop_server().unwrap();
    }
}