1use crate::error::{ProfileError, TypeError};
2use crate::llm::alert::LLMAlertConfig;
3use crate::llm::alert::LLMMetric;
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{scouter_version, LLMMetricRecord};
7use crate::{
8 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
9 ProfileFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
10};
11use core::fmt::Debug;
12use potato_head::prompt::ResponseType;
13use potato_head::Task;
14use potato_head::{Agent, Prompt};
15use potato_head::{Provider, Workflow};
16use pyo3::prelude::*;
17use pyo3::types::PyDict;
18use scouter_semver::VersionType;
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use std::collections::hash_map::Entry;
22use std::collections::HashMap;
23use std::path::PathBuf;
24use std::sync::Arc;
25use tracing::{debug, error, instrument};
26
/// Configuration for an LLM drift profile: identity coordinates
/// (space/name/version), a sampling rate, and alerting behavior.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftConfig {
    /// Sampling rate for drift evaluation (exact sampling semantics are
    /// applied by the consumer of this config — TODO confirm).
    #[pyo3(get, set)]
    pub sample_rate: usize,

    /// Space (namespace) the profile belongs to.
    #[pyo3(get, set)]
    pub space: String,

    /// Profile name.
    #[pyo3(get, set)]
    pub name: String,

    /// Semantic version string for the profile.
    #[pyo3(get, set)]
    pub version: String,

    /// Alerting configuration (schedule, dispatch target, conditions).
    #[pyo3(get, set)]
    pub alert_config: LLMAlertConfig,

    /// Drift type tag; defaults to `DriftType::LLM` when absent from
    /// serialized data (see `default_drift_type`).
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
49
/// Serde default used when `drift_type` is missing from serialized data.
fn default_drift_type() -> DriftType {
    DriftType::LLM
}
53
impl DispatchDriftConfig for LLMDriftConfig {
    /// Assemble the identity and dispatch settings used when routing drift alerts.
    fn get_drift_args(&self) -> DriftArgs {
        DriftArgs {
            name: self.name.clone(),
            space: self.space.clone(),
            version: self.version.clone(),
            dispatch_config: self.alert_config.dispatch_config.clone(),
        }
    }
}
64
65#[pymethods]
66#[allow(clippy::too_many_arguments)]
67impl LLMDriftConfig {
68 #[new]
69 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_rate=5, alert_config=LLMAlertConfig::default(), config_path=None))]
70 pub fn new(
71 space: &str,
72 name: &str,
73 version: &str,
74 sample_rate: usize,
75 alert_config: LLMAlertConfig,
76 config_path: Option<PathBuf>,
77 ) -> Result<Self, ProfileError> {
78 if let Some(config_path) = config_path {
79 let config = LLMDriftConfig::load_from_json_file(config_path)?;
80 return Ok(config);
81 }
82
83 Ok(Self {
84 sample_rate,
85 space: space.to_string(),
86 name: name.to_string(),
87 version: version.to_string(),
88 alert_config,
89 drift_type: DriftType::LLM,
90 })
91 }
92
93 #[staticmethod]
94 pub fn load_from_json_file(path: PathBuf) -> Result<LLMDriftConfig, ProfileError> {
95 let file = std::fs::read_to_string(&path)?;
98
99 Ok(serde_json::from_str(&file)?)
100 }
101
102 pub fn __str__(&self) -> String {
103 ProfileFuncs::__str__(self)
105 }
106
107 pub fn model_dump_json(&self) -> String {
108 ProfileFuncs::__json__(self)
110 }
111
112 #[allow(clippy::too_many_arguments)]
113 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
114 pub fn update_config_args(
115 &mut self,
116 space: Option<String>,
117 name: Option<String>,
118 version: Option<String>,
119 alert_config: Option<LLMAlertConfig>,
120 ) -> Result<(), TypeError> {
121 if name.is_some() {
122 self.name = name.ok_or(TypeError::MissingNameError)?;
123 }
124
125 if space.is_some() {
126 self.space = space.ok_or(TypeError::MissingSpaceError)?;
127 }
128
129 if version.is_some() {
130 self.version = version.ok_or(TypeError::MissingVersionError)?;
131 }
132
133 if alert_config.is_some() {
134 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
135 }
136
137 Ok(())
138 }
139}
140
141fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
157 let has_at_least_one_param = !prompt.parameters.is_empty();
158
159 if !has_at_least_one_param {
160 return Err(ProfileError::NeedAtLeastOneBoundParameterError(
161 id.to_string(),
162 ));
163 }
164
165 Ok(())
166}
167
168fn validate_prompt_response_type(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
180 if prompt.response_type != ResponseType::Score {
181 return Err(ProfileError::InvalidResponseType(id.to_string()));
182 }
183
184 Ok(())
185}
186
187fn get_workflow_task<'a>(
189 workflow: &'a Workflow,
190 task_id: &'a str,
191) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
192 workflow
193 .task_list
194 .tasks
195 .get(task_id)
196 .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
197}
198
199fn validate_first_tasks(
201 workflow: &Workflow,
202 execution_order: &HashMap<i32, std::collections::HashSet<String>>,
203) -> Result<(), ProfileError> {
204 let first_tasks = execution_order
205 .get(&1)
206 .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
207
208 for task_id in first_tasks {
209 let task = get_workflow_task(workflow, task_id)?;
210 let task_guard = task
211 .read()
212 .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
213
214 validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
215 }
216
217 Ok(())
218}
219
220fn validate_last_tasks(
222 workflow: &Workflow,
223 execution_order: &HashMap<i32, std::collections::HashSet<String>>,
224 metrics: &[LLMMetric],
225) -> Result<(), ProfileError> {
226 let last_step = execution_order.len() as i32;
227 let last_tasks = execution_order
228 .get(&last_step)
229 .ok_or_else(|| ProfileError::NoTasksFoundError("No final tasks found".to_string()))?;
230
231 for task_id in last_tasks {
232 if !metrics.iter().any(|m| m.name == *task_id) {
234 return Err(ProfileError::MetricNotFoundForOutputTask(task_id.clone()));
235 }
236
237 let task = get_workflow_task(workflow, task_id)?;
238 let task_guard = task
239 .read()
240 .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
241
242 validate_prompt_response_type(&task_guard.prompt, &task_guard.id)?;
243 }
244
245 Ok(())
246}
247
248fn validate_workflow(workflow: &Workflow, metrics: &[LLMMetric]) -> Result<(), ProfileError> {
264 let execution_order = workflow.execution_plan()?;
265
266 validate_first_tasks(workflow, &execution_order)?;
268
269 validate_last_tasks(workflow, &execution_order, metrics)?;
271
272 Ok(())
273}
274
/// A drift profile for LLM evaluation: the drift config, the metrics with
/// their alert thresholds, and the workflow that produces metric scores.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMDriftProfile {
    /// Drift configuration (identity + alerting).
    #[pyo3(get)]
    pub config: LLMDriftConfig,

    /// Metrics evaluated by this profile.
    #[pyo3(get)]
    pub metrics: Vec<LLMMetric>,

    /// Version of scouter that created this profile.
    #[pyo3(get)]
    pub scouter_version: String,

    /// Evaluation workflow; not exposed to Python as an attribute.
    pub workflow: Workflow,

    /// Cached metric names (one per entry in `metrics`).
    #[pyo3(get)]
    pub metric_names: Vec<String>,
}
292#[pymethods]
293impl LLMDriftProfile {
294 #[new]
295 #[pyo3(signature = (config, metrics, workflow=None))]
296 #[instrument(skip_all)]
316 pub fn new(
317 config: LLMDriftConfig,
318 metrics: Vec<LLMMetric>,
319 workflow: Option<Bound<'_, PyAny>>,
320 ) -> Result<Self, ProfileError> {
321 match workflow {
322 Some(py_workflow) => {
323 let workflow = Self::extract_workflow(&py_workflow).map_err(|e| {
325 error!("Failed to extract workflow: {}", e);
326 e
327 })?;
328 validate_workflow(&workflow, &metrics)?;
329 Self::from_workflow(config, workflow, metrics)
330 }
331 None => {
332 if metrics.is_empty() {
334 return Err(ProfileError::EmptyMetricsList);
335 }
336 Self::from_metrics(config, metrics)
337 }
338 }
339 }
340
341 pub fn __str__(&self) -> String {
342 ProfileFuncs::__str__(self)
344 }
345
346 pub fn model_dump_json(&self) -> String {
347 ProfileFuncs::__json__(self)
349 }
350
351 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
352 let json_str = serde_json::to_string(&self)?;
353
354 let json_value: Value = serde_json::from_str(&json_str)?;
355
356 let dict = PyDict::new(py);
358
359 json_to_pyobject(py, &json_value, &dict)?;
361
362 Ok(dict.into())
364 }
365
366 #[pyo3(signature = (path=None))]
367 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
368 Ok(ProfileFuncs::save_to_json(
369 self,
370 path,
371 FileName::LLMDriftProfile.to_str(),
372 )?)
373 }
374
375 #[staticmethod]
376 pub fn model_validate(data: &Bound<'_, PyDict>) -> LLMDriftProfile {
377 let json_value = pyobject_to_json(data).unwrap();
378
379 let string = serde_json::to_string(&json_value).unwrap();
380 serde_json::from_str(&string).expect("Failed to load drift profile")
381 }
382
383 #[staticmethod]
384 pub fn model_validate_json(json_string: String) -> LLMDriftProfile {
385 serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
387 }
388
389 #[staticmethod]
390 pub fn from_file(path: PathBuf) -> Result<LLMDriftProfile, ProfileError> {
391 let file = std::fs::read_to_string(&path)?;
392
393 Ok(serde_json::from_str(&file)?)
394 }
395
396 #[allow(clippy::too_many_arguments)]
397 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
398 pub fn update_config_args(
399 &mut self,
400 space: Option<String>,
401 name: Option<String>,
402 version: Option<String>,
403 alert_config: Option<LLMAlertConfig>,
404 ) -> Result<(), TypeError> {
405 self.config
406 .update_config_args(space, name, version, alert_config)
407 }
408
409 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
411 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
412 None
413 } else {
414 Some(self.config.version.clone())
415 };
416
417 Ok(ProfileRequest {
418 space: self.config.space.clone(),
419 profile: self.model_dump_json(),
420 drift_type: self.config.drift_type.clone(),
421 version_request: VersionRequest {
422 version,
423 version_type: VersionType::Minor,
424 pre_tag: None,
425 build_tag: None,
426 },
427 })
428 }
429}
430
431impl LLMDriftProfile {
432 pub fn from_metrics(
441 mut config: LLMDriftConfig,
442 metrics: Vec<LLMMetric>,
443 ) -> Result<Self, ProfileError> {
444 let mut workflow = Workflow::new("llm_drift_workflow");
446 let mut agents = HashMap::new();
447 let mut metric_names = Vec::new();
448
449 for metric in &metrics {
452 let prompt = metric
454 .prompt
455 .as_ref()
456 .ok_or_else(|| ProfileError::MissingPromptError(metric.name.clone()))?;
457
458 let provider = Provider::from_string(&prompt.model_settings.provider)?;
459
460 let agent = match agents.entry(provider) {
461 Entry::Occupied(entry) => entry.into_mut(),
462 Entry::Vacant(entry) => {
463 let agent = Agent::from_model_settings(&prompt.model_settings)?;
464 workflow.add_agent(&agent);
465 entry.insert(agent)
466 }
467 };
468
469 let task = Task::new(&agent.id, prompt.clone(), &metric.name, None, None);
470 validate_prompt_parameters(prompt, &metric.name)?;
471 workflow.add_task(task)?;
472 metric_names.push(metric.name.clone());
473 }
474
475 config.alert_config.set_alert_conditions(&metrics);
476
477 Ok(Self {
478 config,
479 metrics,
480 scouter_version: scouter_version(),
481 workflow,
482 metric_names,
483 })
484 }
485
486 pub fn from_workflow(
504 mut config: LLMDriftConfig,
505 workflow: Workflow,
506 metrics: Vec<LLMMetric>,
507 ) -> Result<Self, ProfileError> {
508 validate_workflow(&workflow, &metrics)?;
509
510 config.alert_config.set_alert_conditions(&metrics);
511
512 let metric_names = metrics.iter().map(|m| m.name.clone()).collect();
513
514 Ok(Self {
515 config,
516 metrics,
517 scouter_version: scouter_version(),
518 workflow,
519 metric_names,
520 })
521 }
522
523 #[instrument(skip_all)]
534 fn extract_workflow(py_workflow: &Bound<'_, PyAny>) -> Result<Workflow, ProfileError> {
535 debug!("Extracting workflow from Python object");
536
537 if !py_workflow.hasattr("__workflow__")? {
538 error!("Invalid workflow type provided. Expected object with __workflow__ method.");
539 return Err(ProfileError::InvalidWorkflowType);
540 }
541
542 let workflow_string = py_workflow
543 .getattr("__workflow__")
544 .map_err(|e| {
545 error!("Failed to call __workflow__ property: {}", e);
546 ProfileError::InvalidWorkflowType
547 })?
548 .extract::<String>()
549 .inspect_err(|e| {
550 error!(
551 "Failed to extract workflow string from Python object: {}",
552 e
553 );
554 })?;
555
556 serde_json::from_str(&workflow_string).map_err(|e| {
557 error!("Failed to deserialize workflow: {}", e);
558 ProfileError::InvalidWorkflowType
559 })
560 }
561
562 pub fn get_metric_value(&self, metric_name: &str) -> Result<f64, ProfileError> {
563 self.metrics
564 .iter()
565 .find(|m| m.name == metric_name)
566 .map(|m| m.value)
567 .ok_or_else(|| ProfileError::MetricNotFound(metric_name.to_string()))
568 }
569}
570
impl ProfileBaseArgs for LLMDriftProfile {
    /// Common profile arguments consumed by generic profile-handling code.
    fn get_base_args(&self) -> ProfileArgs {
        ProfileArgs {
            name: self.config.name.clone(),
            space: self.config.space.clone(),
            version: Some(self.config.version.clone()),
            schedule: self.config.alert_config.schedule.clone(),
            scouter_version: self.scouter_version.clone(),
            drift_type: self.config.drift_type.clone(),
        }
    }

    /// Serialize the full profile to a `serde_json::Value`.
    /// Panics only if serialization fails, which would indicate a bug in the
    /// `Serialize` derive rather than a recoverable condition.
    fn to_value(&self) -> Value {
        serde_json::to_value(self).unwrap()
    }
}
587
/// A collection of LLM metric records produced by drift computation.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LLMDriftMap {
    /// Individual metric records.
    #[pyo3(get)]
    pub records: Vec<LLMMetricRecord>,
}
594
595#[pymethods]
596impl LLMDriftMap {
597 #[new]
598 pub fn new(records: Vec<LLMMetricRecord>) -> Self {
599 Self { records }
600 }
601
602 pub fn __str__(&self) -> String {
603 ProfileFuncs::__str__(self)
605 }
606}
607
// Tests require the "mock" feature for the in-process LLM test server.
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};
    use potato_head::create_score_prompt;
    use potato_head::prompt::ResponseType;

    use potato_head::{LLMTestServer, Message, PromptContent};

    // Builds a prompt with three bound parameters (${input}, ${response},
    // ${context}) so it passes first-task parameter validation.
    pub fn create_parameterized_prompt() -> Prompt {
        let user_content =
            PromptContent::Str("What is ${input} + ${response} + ${context}?".to_string());
        let system_content = PromptContent::Str("You are a helpful assistant.".to_string());
        Prompt::new_rs(
            vec![Message::new_rs(user_content)],
            Some("gpt-4o"),
            Some("openai"),
            vec![Message::new_rs(system_content)],
            None,
            None,
            ResponseType::Null,
        )
        .unwrap()
    }

    // Defaults (MISSING placeholders) and in-place updates via
    // `update_config_args`.
    #[test]
    fn test_llm_drift_config() {
        let mut drift_config = LLMDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            LLMAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // Only name and alert_config are updated; space/version keep their values.
        drift_config
            .update_config_args(None, Some("test".to_string()), None, Some(new_alert_config))
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    // Profile built from metrics alone: the workflow is auto-generated and
    // the profile round-trips through JSON.
    #[test]
    fn test_llm_drift_profile_metric() {
        let _runtime = tokio::runtime::Runtime::new().unwrap();
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let metric1 = LLMMetric::new(
            "metric1",
            5.0,
            AlertThreshold::Above,
            None,
            Some(prompt.clone()),
        )
        .unwrap();
        let metric2 = LLMMetric::new(
            "metric2",
            3.0,
            AlertThreshold::Below,
            Some(1.0),
            Some(prompt.clone()),
        )
        .unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = LLMDriftProfile::from_metrics(drift_config, llm_metrics).unwrap();
        // Serialized profile must be valid JSON.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        mock.stop_server().unwrap();
    }

    // Profile built from an explicit two-step workflow (one parameterized
    // initial task fanning out to two score tasks); checks the resulting
    // execution plan shape.
    #[test]
    fn test_llm_drift_profile_workflow() {
        let mut mock = LLMTestServer::new();
        mock.start_server().unwrap();

        let mut workflow = Workflow::new("My eval Workflow");

        let initial_prompt = create_parameterized_prompt();
        let final_prompt1 = create_score_prompt(None);
        let final_prompt2 = create_score_prompt(None);

        let agent1 = Agent::new(Provider::OpenAI, None).unwrap();
        workflow.add_agent(&agent1);

        workflow
            .add_task(Task::new(
                &agent1.id,
                initial_prompt.clone(),
                "task1",
                None,
                None,
            ))
            .unwrap();

        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt1.clone(),
                "task2",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        workflow
            .add_task(Task::new(
                &agent1.id,
                final_prompt2.clone(),
                "task3",
                Some(vec!["task1".to_string()]),
                None,
            ))
            .unwrap();

        // Metric names must match the terminal task ids (task2/task3).
        let metric1 = LLMMetric::new("task2", 3.0, AlertThreshold::Below, Some(1.0), None).unwrap();
        let metric2 = LLMMetric::new("task3", 4.0, AlertThreshold::Above, Some(2.0), None).unwrap();

        let llm_metrics = vec![metric1, metric2];

        let alert_config = LLMAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let profile = LLMDriftProfile::from_workflow(drift_config, workflow, llm_metrics).unwrap();

        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let plan = profile.workflow.execution_plan().unwrap();

        // Two steps: one initial task, two terminal tasks.
        assert_eq!(plan.len(), 2);
        assert_eq!(plan.get(&1).unwrap().len(), 1);
        assert_eq!(plan.get(&(plan.len() as i32)).unwrap().len(), 2);

        mock.stop_server().unwrap();
    }
}