1use crate::error::{ProfileError, TypeError};
2use crate::genai::alert::GenAIAlertConfig;
3use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
4use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
5use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
6use crate::genai::TraceAssertionTask;
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json};
9use crate::{scouter_version, EvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry};
10use crate::{
11 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
12 PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
13};
14use crate::{ProfileRequest, TaskResultTableEntry};
15use chrono::{DateTime, Utc};
16use core::fmt::Debug;
17use potato_head::prompt_types::Prompt;
18use potato_head::Agent;
19use potato_head::Workflow;
20use potato_head::{create_uuid7, Task};
21use pyo3::prelude::*;
22use pyo3::types::{PyDict, PyList};
23use scouter_semver::VersionType;
24use scouter_state::app_state;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::hash_map::Entry;
28use std::collections::BTreeSet;
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32
33use tracing::instrument;
34
/// Serde fallback for `sample_ratio`: returns `1.0`.
fn default_sample_ratio() -> f64 {
    1.0_f64
}
38
39fn default_space() -> String {
40 MISSING.to_string()
41}
42
43fn default_name() -> String {
44 MISSING.to_string()
45}
46
47fn default_version() -> String {
48 DEFAULT_VERSION.to_string()
49}
50
51fn default_uid() -> String {
52 create_uuid7()
53}
54
55fn default_drift_type() -> DriftType {
56 DriftType::GenAI
57}
58
59fn default_alert_config() -> GenAIAlertConfig {
60 GenAIAlertConfig::default()
61}
62
/// Configuration attached to a GenAI evaluation profile.
///
/// Every field has a serde default, so a JSON document may omit any of them
/// when the config is deserialized.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    /// Sampling ratio. The Python constructor clamps it to `[0.0, 1.0]`;
    /// deserialization falls back to `1.0` when the key is absent.
    #[pyo3(get, set)]
    #[serde(default = "default_sample_ratio")]
    pub sample_ratio: f64,

    /// Space the profile belongs to; falls back to `MISSING` when absent.
    #[pyo3(get, set)]
    #[serde(default = "default_space")]
    pub space: String,

    /// Profile name; falls back to `MISSING` when absent.
    #[pyo3(get, set)]
    #[serde(default = "default_name")]
    pub name: String,

    /// Profile version; falls back to `DEFAULT_VERSION` when absent.
    #[pyo3(get, set)]
    #[serde(default = "default_version")]
    pub version: String,

    /// Unique identifier; falls back to a fresh `create_uuid7()` value when absent.
    #[pyo3(get, set)]
    #[serde(default = "default_uid")]
    pub uid: String,

    /// Alerting configuration; falls back to `GenAIAlertConfig::default()` when absent.
    #[pyo3(get, set)]
    #[serde(default = "default_alert_config")]
    pub alert_config: GenAIAlertConfig,

    /// Drift type tag; falls back to `DriftType::GenAI` when absent.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
94
95impl ConfigExt for GenAIEvalConfig {
96 fn space(&self) -> &str {
97 &self.space
98 }
99
100 fn name(&self) -> &str {
101 &self.name
102 }
103
104 fn version(&self) -> &str {
105 &self.version
106 }
107 fn uid(&self) -> &str {
108 &self.uid
109 }
110}
111
112impl DispatchDriftConfig for GenAIEvalConfig {
113 fn get_drift_args(&self) -> DriftArgs {
114 DriftArgs {
115 name: self.name.clone(),
116 space: self.space.clone(),
117 version: self.version.clone(),
118 dispatch_config: self.alert_config.dispatch_config.clone(),
119 }
120 }
121}
122
123#[pymethods]
124#[allow(clippy::too_many_arguments)]
125impl GenAIEvalConfig {
126 #[new]
127 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
128 pub fn new(
129 space: &str,
130 name: &str,
131 version: &str,
132 sample_ratio: f64,
133 alert_config: GenAIAlertConfig,
134 config_path: Option<PathBuf>,
135 ) -> Result<Self, ProfileError> {
136 if let Some(config_path) = config_path {
137 let config = GenAIEvalConfig::load_from_json_file(config_path)?;
138 return Ok(config);
139 }
140
141 Ok(Self {
142 sample_ratio: sample_ratio.clamp(0.0, 1.0),
143 space: space.to_string(),
144 name: name.to_string(),
145 uid: create_uuid7(),
146 version: version.to_string(),
147 alert_config,
148 drift_type: DriftType::GenAI,
149 })
150 }
151
152 #[staticmethod]
153 pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
154 let file = std::fs::read_to_string(&path)?;
157
158 Ok(serde_json::from_str(&file)?)
159 }
160
161 pub fn __str__(&self) -> String {
162 PyHelperFuncs::__str__(self)
164 }
165
166 pub fn model_dump_json(&self) -> String {
167 PyHelperFuncs::__json__(self)
169 }
170
171 #[allow(clippy::too_many_arguments)]
172 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
173 pub fn update_config_args(
174 &mut self,
175 space: Option<String>,
176 name: Option<String>,
177 version: Option<String>,
178 uid: Option<String>,
179 alert_config: Option<GenAIAlertConfig>,
180 ) -> Result<(), TypeError> {
181 if name.is_some() {
182 self.name = name.ok_or(TypeError::MissingNameError)?;
183 }
184
185 if space.is_some() {
186 self.space = space.ok_or(TypeError::MissingSpaceError)?;
187 }
188
189 if version.is_some() {
190 self.version = version.ok_or(TypeError::MissingVersionError)?;
191 }
192
193 if alert_config.is_some() {
194 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
195 }
196
197 if uid.is_some() {
198 self.uid = uid.ok_or(TypeError::MissingUidError)?;
199 }
200
201 Ok(())
202 }
203}
204
205impl Default for GenAIEvalConfig {
206 fn default() -> Self {
207 Self {
208 sample_ratio: 1.0,
209 space: "default".to_string(),
210 name: "default_genai_profile".to_string(),
211 version: DEFAULT_VERSION.to_string(),
212 uid: create_uuid7(),
213 alert_config: GenAIAlertConfig::default(),
214 drift_type: DriftType::GenAI,
215 }
216 }
217}
218
219fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
235 let has_at_least_one_param = !prompt.parameters.is_empty();
236
237 if !has_at_least_one_param {
238 return Err(ProfileError::NeedAtLeastOneBoundParameterError(
239 id.to_string(),
240 ));
241 }
242
243 Ok(())
244}
245
246fn get_workflow_task<'a>(
247 workflow: &'a Workflow,
248 task_id: &'a str,
249) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
250 workflow
251 .task_list
252 .tasks
253 .get(task_id)
254 .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
255}
256
257fn validate_first_tasks(
259 workflow: &Workflow,
260 execution_order: &HashMap<i32, std::collections::HashSet<String>>,
261) -> Result<(), ProfileError> {
262 let first_tasks = execution_order
263 .get(&1)
264 .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
265
266 for task_id in first_tasks {
267 let task = get_workflow_task(workflow, task_id)?;
268 let task_guard = task
269 .read()
270 .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
271
272 validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
273 }
274
275 Ok(())
276}
277
278fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
294 let execution_order = workflow.execution_plan()?;
295
296 validate_first_tasks(workflow, &execution_order)?;
298
299 Ok(())
300}
301
/// One task inside an `ExecutionPlan`, with its position and neighbours in
/// the dependency graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ExecutionNode {
    /// Task identifier.
    pub id: String,
    /// Zero-based index of the stage the task was placed in.
    pub stage: usize,
    /// Ids this task depends on.
    pub parents: Vec<String>,
    /// Ids of tasks that depend on this task.
    pub children: Vec<String>,
}
310
/// Staged ordering of tasks together with per-task graph metadata.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    /// Task ids grouped by stage; the outer index is the stage number.
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    /// Node metadata keyed by task id.
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
319
320fn initialize_node_graphs(
321 tasks: &[impl TaskAccessor],
322 graph: &mut HashMap<String, Vec<String>>,
323 reverse_graph: &mut HashMap<String, Vec<String>>,
324 in_degree: &mut HashMap<String, usize>,
325) {
326 for task in tasks {
327 let task_id = task.id().to_string();
328 graph.entry(task_id.clone()).or_default();
329 reverse_graph.entry(task_id.clone()).or_default();
330 in_degree.entry(task_id).or_insert(0);
331 }
332}
333
334fn build_dependency_edges(
335 tasks: &[impl TaskAccessor],
336 graph: &mut HashMap<String, Vec<String>>,
337 reverse_graph: &mut HashMap<String, Vec<String>>,
338 in_degree: &mut HashMap<String, usize>,
339) {
340 for task in tasks {
341 let task_id = task.id().to_string();
342 for dep in task.depends_on() {
343 graph.entry(dep.clone()).or_default().push(task_id.clone());
344 reverse_graph
345 .entry(task_id.clone())
346 .or_default()
347 .push(dep.clone());
348 *in_degree.entry(task_id.clone()).or_insert(0) += 1;
349 }
350 }
351}
352
/// GenAI evaluation profile: a configuration plus the evaluation tasks it runs.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    /// Profile configuration (identity, alerting, sampling).
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    /// Assertion, LLM-judge and trace-assertion tasks held by the profile.
    pub tasks: AssertionTasks,

    /// Value of `scouter_version()` captured when the profile was built.
    #[pyo3(get)]
    pub scouter_version: String,

    /// Workflow built from the judge tasks; the constructors in this file
    /// leave it `None` when there are no judge tasks.
    pub workflow: Option<Workflow>,

    /// Ids returned by `AssertionTasks::collect_all_task_ids` at build time.
    pub task_ids: BTreeSet<String>,

    /// Optional alias supplied by the caller.
    pub alias: Option<String>,
}
370
371#[pymethods]
372impl GenAIEvalProfile {
373 #[new]
374 #[pyo3(signature = (tasks, config=None, alias=None))]
375 #[instrument(skip_all)]
386 pub fn new_py(
387 tasks: &Bound<'_, PyList>,
388 config: Option<GenAIEvalConfig>,
389 alias: Option<String>,
390 ) -> Result<Self, ProfileError> {
391 let tasks = extract_assertion_tasks_from_pylist(tasks)?;
392
393 let (workflow, task_ids) =
394 app_state().block_on(async { Self::build_profile(&tasks).await })?;
395
396 Ok(Self {
397 config: config.unwrap_or_default(),
398 tasks,
399 scouter_version: scouter_version(),
400 workflow,
401 task_ids,
402 alias,
403 })
404 }
405
406 pub fn __str__(&self) -> String {
407 PyHelperFuncs::__str__(self)
409 }
410
411 pub fn model_dump_json(&self) -> String {
412 PyHelperFuncs::__json__(self)
414 }
415
416 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
417 let json_value = serde_json::to_value(self)?;
418
419 let dict = PyDict::new(py);
421
422 json_to_pyobject(py, &json_value, &dict)?;
424
425 Ok(dict.into())
427 }
428
429 #[getter]
430 pub fn drift_type(&self) -> DriftType {
431 self.config.drift_type.clone()
432 }
433
434 #[getter]
435 pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
436 self.tasks.assertion.clone()
437 }
438
439 #[getter]
440 pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
441 self.tasks.judge.clone()
442 }
443
444 #[getter]
445 pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
446 self.tasks.trace.clone()
447 }
448
449 #[getter]
450 pub fn alias(&self) -> Option<String> {
451 self.alias.clone()
452 }
453
454 #[getter]
455 pub fn uid(&self) -> String {
456 self.config.uid.clone()
457 }
458
459 #[setter]
460 pub fn set_uid(&mut self, uid: String) {
461 self.config.uid = uid;
462 }
463
464 #[pyo3(signature = (path=None))]
465 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
466 Ok(PyHelperFuncs::save_to_json(
467 self,
468 path,
469 FileName::GenAIEvalProfile.to_str(),
470 )?)
471 }
472
473 #[staticmethod]
474 pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
475 let json_value = pyobject_to_json(data).unwrap();
476
477 let string = serde_json::to_string(&json_value).unwrap();
478 serde_json::from_str(&string).expect("Failed to load drift profile")
479 }
480
481 #[staticmethod]
482 pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
483 serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
485 }
486
487 #[staticmethod]
488 pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
489 let file = std::fs::read_to_string(&path)?;
490
491 Ok(serde_json::from_str(&file)?)
492 }
493
494 #[allow(clippy::too_many_arguments)]
495 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
496 pub fn update_config_args(
497 &mut self,
498 space: Option<String>,
499 name: Option<String>,
500 version: Option<String>,
501 uid: Option<String>,
502 alert_config: Option<GenAIAlertConfig>,
503 ) -> Result<(), TypeError> {
504 self.config
505 .update_config_args(space, name, version, uid, alert_config)
506 }
507
508 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
510 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
511 None
512 } else {
513 Some(self.config.version.clone())
514 };
515
516 Ok(ProfileRequest {
517 space: self.config.space.clone(),
518 profile: self.model_dump_json(),
519 drift_type: self.config.drift_type.clone(),
520 version_request: Some(VersionRequest {
521 version,
522 version_type: VersionType::Minor,
523 pre_tag: None,
524 build_tag: None,
525 }),
526 active: false,
527 deactivate_others: false,
528 })
529 }
530
531 pub fn has_llm_tasks(&self) -> bool {
532 !self.tasks.judge.is_empty()
533 }
534
535 pub fn has_assertions(&self) -> bool {
537 !self.tasks.assertion.is_empty()
538 }
539
540 pub fn has_trace_assertions(&self) -> bool {
541 !self.tasks.trace.is_empty()
542 }
543
544 pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
546 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
547 let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
548 let mut in_degree: HashMap<String, usize> = HashMap::new();
549
550 initialize_node_graphs(
551 &self.tasks.assertion,
552 &mut graph,
553 &mut reverse_graph,
554 &mut in_degree,
555 );
556 initialize_node_graphs(
557 &self.tasks.judge,
558 &mut graph,
559 &mut reverse_graph,
560 &mut in_degree,
561 );
562
563 initialize_node_graphs(
564 &self.tasks.trace,
565 &mut graph,
566 &mut reverse_graph,
567 &mut in_degree,
568 );
569
570 build_dependency_edges(
571 &self.tasks.assertion,
572 &mut graph,
573 &mut reverse_graph,
574 &mut in_degree,
575 );
576 build_dependency_edges(
577 &self.tasks.judge,
578 &mut graph,
579 &mut reverse_graph,
580 &mut in_degree,
581 );
582
583 build_dependency_edges(
584 &self.tasks.trace,
585 &mut graph,
586 &mut reverse_graph,
587 &mut in_degree,
588 );
589 let mut stages = Vec::new();
590 let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
591 let mut current_level: Vec<String> = in_degree
592 .iter()
593 .filter(|(_, °ree)| degree == 0)
594 .map(|(id, _)| id.clone())
595 .collect();
596
597 let mut stage_idx = 0;
598
599 while !current_level.is_empty() {
600 stages.push(current_level.clone());
601
602 for task_id in ¤t_level {
603 nodes.insert(
604 task_id.clone(),
605 ExecutionNode {
606 id: task_id.clone(),
607 stage: stage_idx,
608 parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
609 children: graph.get(task_id).cloned().unwrap_or_default(),
610 },
611 );
612 }
613
614 let mut next_level = Vec::new();
615 for task_id in ¤t_level {
616 if let Some(dependents) = graph.get(task_id) {
617 for dependent in dependents {
618 if let Some(degree) = in_degree.get_mut(dependent) {
619 *degree -= 1;
620 if *degree == 0 {
621 next_level.push(dependent.clone());
622 }
623 }
624 }
625 }
626 }
627
628 current_level = next_level;
629 stage_idx += 1;
630 }
631
632 let total_tasks =
633 self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len();
634 let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();
635
636 if processed_tasks != total_tasks {
637 return Err(ProfileError::CircularDependency);
638 }
639
640 Ok(ExecutionPlan { stages, nodes })
641 }
642
    /// Prints a colourised, tree-style rendering of the execution plan to stdout.
    ///
    /// Each stage lists its tasks with their type, their plain dependencies,
    /// their conditional dependencies ("gates"), and a marker for tasks that
    /// are themselves conditional. A summary line closes the report.
    ///
    /// # Errors
    /// Propagates the error from `get_execution_plan`, and returns
    /// `ProfileError::NoTasksFoundError` when a planned id cannot be resolved
    /// through `get_task_by_id`.
    pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
        use owo_colors::OwoColorize;

        let plan = self.get_execution_plan()?;

        println!("\n{}", "Evaluation Execution Plan".bold().green());
        println!("{}", "═".repeat(70).green());

        // Number of tasks whose own `condition` flag is set.
        let mut conditional_count = 0;

        for (level_idx, level) in plan.stages.iter().enumerate() {
            // Stages are displayed one-based.
            let stage_label = format!("Stage {}", level_idx + 1);
            println!("\n{}", stage_label.bold().cyan());

            for (task_idx, task_id) in level.iter().enumerate() {
                // The last entry of a stage closes the tree branch.
                let is_last = task_idx == level.len() - 1;
                let prefix = if is_last { "└─" } else { "├─" };

                let task = self.get_task_by_id(task_id).ok_or_else(|| {
                    ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
                })?;

                // Look the id up in each task category to read its `condition` flag.
                let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
                    assertion.condition
                } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
                    judge.condition
                } else if let Some(trace) = self.get_trace_assertion_by_id(task_id) {
                    trace.condition
                } else {
                    false
                };

                if is_conditional {
                    conditional_count += 1;
                }

                // Pick the label and colour by category; anything that is neither an
                // assertion nor a trace assertion is labelled as an LLM judge.
                let (task_type, color_fn): (&str, fn(&str) -> String) =
                    if self.tasks.assertion.iter().any(|t| &t.id == task_id) {
                        ("Assertion", |s: &str| s.yellow().to_string())
                    } else if self.tasks.trace.iter().any(|t| &t.id == task_id) {
                        ("Trace Assertion", |s: &str| s.bright_blue().to_string())
                    } else {
                        ("LLM Judge", |s: &str| s.purple().to_string())
                    };

                let conditional_marker = if is_conditional {
                    " [CONDITIONAL]".bright_red().to_string()
                } else {
                    String::new()
                };

                println!(
                    "{} {} ({}){}",
                    prefix,
                    task_id.bold(),
                    color_fn(task_type),
                    conditional_marker
                );

                let deps = task.depends_on();
                if !deps.is_empty() {
                    let dep_prefix = if is_last { " " } else { "│ " };

                    // Split dependencies into those that are conditional tasks and the rest;
                    // ids that match no task count as non-conditional.
                    let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
                        deps.iter().partition(|dep_id| {
                            self.get_assertion_by_id(dep_id)
                                .map(|t| t.condition)
                                .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
                                .or_else(|| {
                                    self.get_trace_assertion_by_id(dep_id).map(|t| t.condition)
                                })
                                .unwrap_or(false)
                        });

                    if !normal_deps.is_empty() {
                        println!(
                            "{} {} {}",
                            dep_prefix,
                            "depends on:".dimmed(),
                            normal_deps
                                .iter()
                                .map(|s| s.as_str())
                                .collect::<Vec<_>>()
                                .join(", ")
                                .dimmed()
                        );
                    }

                    if !conditional_deps.is_empty() {
                        println!(
                            "{} {} {}",
                            dep_prefix,
                            "▶ conditional gate:".bright_red().dimmed(),
                            conditional_deps
                                .iter()
                                .map(|d| format!("{} must pass", d))
                                .collect::<Vec<_>>()
                                .join(", ")
                                .red()
                                .dimmed()
                        );
                    }
                }

                if is_conditional {
                    let continuation = if is_last { " " } else { "│ " };
                    println!(
                        "{} {} {}",
                        continuation,
                        "▶".bright_red(),
                        "creates conditional branch".bright_red().dimmed()
                    );
                }
            }
        }

        println!("\n{}", "═".repeat(70).green());
        println!(
            "{}: {} tasks across {} stages",
            "Summary".bold(),
            self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len(),
            plan.stages.len()
        );

        // The branch summary is only shown when at least one task is conditional.
        if conditional_count > 0 {
            println!(
                "{}: {} conditional tasks that create execution branches",
                "Branches".bold().bright_red(),
                conditional_count
            );
        }

        println!();

        Ok(())
    }
779}
780
781impl Default for GenAIEvalProfile {
782 fn default() -> Self {
783 Self {
784 config: GenAIEvalConfig::default(),
785 tasks: AssertionTasks {
786 assertion: Vec::new(),
787 judge: Vec::new(),
788 trace: Vec::new(),
789 },
790 scouter_version: scouter_version(),
791 workflow: None,
792 task_ids: BTreeSet::new(),
793 alias: None,
794 }
795 }
796}
797
798impl GenAIEvalProfile {
799 pub fn build_from_parts(
801 config: GenAIEvalConfig,
802 tasks: AssertionTasks,
803 alias: Option<String>,
804 ) -> Result<GenAIEvalProfile, ProfileError> {
805 let (workflow, task_ids) =
806 app_state().block_on(async { GenAIEvalProfile::build_profile(&tasks).await })?;
807
808 Ok(GenAIEvalProfile {
809 config,
810 tasks,
811 scouter_version: scouter_version(),
812 workflow,
813 task_ids,
814 alias,
815 })
816 }
817
818 #[instrument(skip_all)]
819 pub async fn new(
820 config: GenAIEvalConfig,
821 tasks: Vec<EvaluationTask>,
822 ) -> Result<Self, ProfileError> {
823 let tasks = separate_tasks(tasks);
824 let (workflow, task_ids) = Self::build_profile(&tasks).await?;
825
826 Ok(Self {
827 config,
828 tasks,
829 scouter_version: scouter_version(),
830 workflow,
831 task_ids,
832 alias: None,
833 })
834 }
835
836 async fn build_profile(
837 tasks: &AssertionTasks,
838 ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
839 if tasks.assertion.is_empty() && tasks.judge.is_empty() && tasks.trace.is_empty() {
840 return Err(ProfileError::EmptyTaskList);
841 }
842
843 let workflow = if !tasks.judge.is_empty() {
844 let workflow = Self::build_workflow_from_judges(tasks).await?;
845 validate_workflow(&workflow)?;
846 Some(workflow)
847 } else {
848 None
849 };
850
851 for judge in &tasks.judge {
853 validate_prompt_parameters(&judge.prompt, &judge.id)?;
854 }
855
856 let task_ids = tasks.collect_all_task_ids()?;
858
859 Ok((workflow, task_ids))
860 }
861
862 async fn get_or_create_agent(
863 agents: &mut HashMap<potato_head::Provider, Agent>,
864 workflow: &mut Workflow,
865 provider: &potato_head::Provider,
866 ) -> Result<Agent, ProfileError> {
867 match agents.entry(provider.clone()) {
868 Entry::Occupied(entry) => Ok(entry.get().clone()),
869 Entry::Vacant(entry) => {
870 let agent = Agent::new(provider.clone(), None).await?;
871 workflow.add_agent(&agent);
872 Ok(entry.insert(agent).clone())
873 }
874 }
875 }
876
877 fn filter_judge_dependencies(
878 depends_on: &[String],
879 non_judge_task_ids: &BTreeSet<String>,
880 ) -> Option<Vec<String>> {
881 let filtered: Vec<String> = depends_on
882 .iter()
883 .filter(|dep_id| !non_judge_task_ids.contains(*dep_id))
884 .cloned()
885 .collect();
886
887 if filtered.is_empty() {
888 None
889 } else {
890 Some(filtered)
891 }
892 }
893
894 pub async fn build_workflow_from_judges(
900 tasks: &AssertionTasks,
901 ) -> Result<Workflow, ProfileError> {
902 let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
903 let mut agents = HashMap::new();
904 let non_judge_task_ids = tasks.collect_non_judge_task_ids();
905
906 for judge in &tasks.judge {
907 let agent =
908 Self::get_or_create_agent(&mut agents, &mut workflow, &judge.prompt.provider)
909 .await?;
910
911 let task_deps = Self::filter_judge_dependencies(&judge.depends_on, &non_judge_task_ids);
912
913 let task = Task::new(
914 &agent.id,
915 judge.prompt.clone(),
916 &judge.id,
917 task_deps,
918 judge.max_retries,
919 )?;
920
921 workflow.add_task(task)?;
922 }
923
924 Ok(workflow)
925 }
926}
927
928impl ProfileExt for GenAIEvalProfile {
929 #[inline]
930 fn id(&self) -> &str {
931 &self.config.uid
932 }
933
934 fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
935 if let Some(assertion) = self.tasks.assertion.iter().find(|t| t.id() == id) {
936 return Some(assertion);
937 }
938
939 if let Some(judge) = self.tasks.judge.iter().find(|t| t.id() == id) {
940 return Some(judge);
941 }
942
943 if let Some(trace) = self.tasks.trace.iter().find(|t| t.id() == id) {
944 return Some(trace);
945 }
946
947 None
948 }
949
950 #[inline]
951 fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
953 self.tasks.assertion.iter().find(|t| t.id() == id)
954 }
955
956 #[inline]
957 fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
958 self.tasks.judge.iter().find(|t| t.id() == id)
959 }
960
961 #[inline]
962 fn get_trace_assertion_by_id(&self, id: &str) -> Option<&TraceAssertionTask> {
963 self.tasks.trace.iter().find(|t| t.id() == id)
964 }
965
966 #[inline]
967 fn has_llm_tasks(&self) -> bool {
968 !self.tasks.judge.is_empty()
969 }
970
971 #[inline]
972 fn has_trace_assertions(&self) -> bool {
973 !self.tasks.trace.is_empty()
974 }
975}
976
977impl ProfileBaseArgs for GenAIEvalProfile {
978 type Config = GenAIEvalConfig;
979
980 fn config(&self) -> &Self::Config {
981 &self.config
982 }
983 fn get_base_args(&self) -> ProfileArgs {
984 ProfileArgs {
985 name: self.config.name.clone(),
986 space: self.config.space.clone(),
987 version: Some(self.config.version.clone()),
988 schedule: self.config.alert_config.schedule.clone(),
989 scouter_version: self.scouter_version.clone(),
990 drift_type: self.config.drift_type.clone(),
991 }
992 }
993
994 fn to_value(&self) -> Value {
995 serde_json::to_value(self).unwrap()
996 }
997}
998
/// Per-task results together with the workflow-level result they belong to.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalSet {
    /// Individual task results.
    #[pyo3(get)]
    pub records: Vec<EvalTaskResult>,
    /// Workflow-level result; its fields are exposed through the getters on
    /// this type.
    pub inner: GenAIEvalWorkflowResult,
}
1006
1007impl EvalSet {
1008 pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
1009 self.records
1012 .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
1013
1014 self.records
1015 .iter()
1016 .map(|record| record.to_table_entry(record_id))
1017 .collect()
1018 }
1019
1020 pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
1021 vec![self.inner.to_table_entry()]
1022 }
1023
1024 pub fn new(records: Vec<EvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
1025 Self { records, inner }
1026 }
1027
1028 pub fn empty() -> Self {
1029 Self {
1030 records: Vec::new(),
1031 inner: GenAIEvalWorkflowResult {
1032 created_at: Utc::now(),
1033 record_uid: String::new(),
1034 entity_id: 0,
1035 total_tasks: 0,
1036 passed_tasks: 0,
1037 failed_tasks: 0,
1038 pass_rate: 0.0,
1039 duration_ms: 0,
1040 entity_uid: String::new(),
1041 execution_plan: ExecutionPlan::default(),
1042 id: 0,
1043 },
1044 }
1045 }
1046}
1047
1048#[pymethods]
1049impl EvalSet {
1050 #[getter]
1051 pub fn created_at(&self) -> DateTime<Utc> {
1052 self.inner.created_at
1053 }
1054
1055 #[getter]
1056 pub fn record_uid(&self) -> String {
1057 self.inner.record_uid.clone()
1058 }
1059
1060 #[getter]
1061 pub fn total_tasks(&self) -> i32 {
1062 self.inner.total_tasks
1063 }
1064
1065 #[getter]
1066 pub fn passed_tasks(&self) -> i32 {
1067 self.inner.passed_tasks
1068 }
1069
1070 #[getter]
1071 pub fn failed_tasks(&self) -> i32 {
1072 self.inner.failed_tasks
1073 }
1074
1075 #[getter]
1076 pub fn pass_rate(&self) -> f64 {
1077 self.inner.pass_rate
1078 }
1079
1080 #[getter]
1081 pub fn duration_ms(&self) -> i64 {
1082 self.inner.duration_ms
1083 }
1084
1085 pub fn __str__(&self) -> String {
1086 PyHelperFuncs::__str__(self)
1088 }
1089}
1090
/// Collection of `EvalSet` values.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResultSet {
    /// The contained evaluation sets; `record` searches this list by record uid.
    #[pyo3(get)]
    pub records: Vec<EvalSet>,
}
1097
1098#[pymethods]
1099impl EvalResultSet {
1100 pub fn record(&self, id: &str) -> Option<EvalSet> {
1101 self.records.iter().find(|r| r.record_uid() == id).cloned()
1102 }
1103 pub fn __str__(&self) -> String {
1104 PyHelperFuncs::__str__(self)
1106 }
1107}
1108
// Compiled only for test builds that also enable the `mock` feature.
#[cfg(all(test, feature = "mock"))]
mod tests {

    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    /// Construction defaults and `update_config_args` on the config.
    #[test]
    fn test_genai_drift_config() {
        let mut config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let updated_alerts = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        // Only the name and the alert config are updated.
        config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(updated_alerts),
            )
            .unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(config.alert_config.schedule, "0 0 * * * *".to_string());
    }

    /// Builds a profile from two judge tasks and checks its JSON dump.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let first_judge = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let second_judge = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(first_judge)
            .add_task(second_judge)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(config, tasks).await.unwrap();

        // The dump must at least be valid JSON.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks().len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}