1use crate::error::{ProfileError, TypeError};
2use crate::genai::alert::GenAIAlertConfig;
3use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
4use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
5use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
6use crate::genai::TraceAssertionTask;
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json};
9use crate::{
10 scouter_version, GenAIEvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry,
11};
12use crate::{
13 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
14 PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
15};
16use crate::{ProfileRequest, TaskResultTableEntry};
17use chrono::{DateTime, Utc};
18use core::fmt::Debug;
19use potato_head::prompt_types::Prompt;
20use potato_head::Agent;
21use potato_head::Workflow;
22use potato_head::{create_uuid7, Task};
23use pyo3::prelude::*;
24use pyo3::types::{PyDict, PyList};
25use scouter_semver::VersionType;
26use scouter_state::app_state;
27use serde::{Deserialize, Serialize};
28use serde_json::Value;
29use std::collections::hash_map::Entry;
30use std::collections::BTreeSet;
31use std::collections::HashMap;
32use std::path::PathBuf;
33use std::sync::Arc;
34
35use tracing::instrument;
36
/// Serde fallback for `GenAIEvalConfig::sample_ratio` when the field is absent.
///
/// Returns `1.0`, the upper bound that `GenAIEvalConfig::new` clamps the ratio to.
fn default_sample_ratio() -> f64 {
    1.0_f64
}
40
41fn default_space() -> String {
42 MISSING.to_string()
43}
44
45fn default_name() -> String {
46 MISSING.to_string()
47}
48
49fn default_version() -> String {
50 DEFAULT_VERSION.to_string()
51}
52
53fn default_uid() -> String {
54 create_uuid7()
55}
56
57fn default_drift_type() -> DriftType {
58 DriftType::GenAI
59}
60
61fn default_alert_config() -> GenAIAlertConfig {
62 GenAIAlertConfig::default()
63}
64
/// Configuration for a GenAI evaluation profile.
///
/// Each field has a serde default (see the `default_*` helpers above), so partial JSON
/// documents deserialize. Field order is kept as-is because it determines the key order
/// of serialized output.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    /// Fraction of records to sample. `GenAIEvalConfig::new` clamps it to `[0.0, 1.0]`.
    #[pyo3(get, set)]
    #[serde(default = "default_sample_ratio")]
    pub sample_ratio: f64,

    /// Space (namespace) the profile belongs to; defaults to `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_space")]
    pub space: String,

    /// Profile name; defaults to `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_name")]
    pub name: String,

    /// Profile version; defaults to `DEFAULT_VERSION`.
    #[pyo3(get, set)]
    #[serde(default = "default_version")]
    pub version: String,

    /// Unique id; a new UUIDv7 is generated when it is absent.
    #[pyo3(get, set)]
    #[serde(default = "default_uid")]
    pub uid: String,

    /// Alerting settings (schedule and dispatch target).
    #[pyo3(get, set)]
    #[serde(default = "default_alert_config")]
    pub alert_config: GenAIAlertConfig,

    /// Drift type tag; always `DriftType::GenAI` for this config.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
96
97impl ConfigExt for GenAIEvalConfig {
98 fn space(&self) -> &str {
99 &self.space
100 }
101
102 fn name(&self) -> &str {
103 &self.name
104 }
105
106 fn version(&self) -> &str {
107 &self.version
108 }
109 fn uid(&self) -> &str {
110 &self.uid
111 }
112}
113
114impl DispatchDriftConfig for GenAIEvalConfig {
115 fn get_drift_args(&self) -> DriftArgs {
116 DriftArgs {
117 name: self.name.clone(),
118 space: self.space.clone(),
119 version: self.version.clone(),
120 dispatch_config: self.alert_config.dispatch_config.clone(),
121 }
122 }
123}
124
125#[pymethods]
126#[allow(clippy::too_many_arguments)]
127impl GenAIEvalConfig {
128 #[new]
129 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
130 pub fn new(
131 space: &str,
132 name: &str,
133 version: &str,
134 sample_ratio: f64,
135 alert_config: GenAIAlertConfig,
136 config_path: Option<PathBuf>,
137 ) -> Result<Self, ProfileError> {
138 if let Some(config_path) = config_path {
139 let config = GenAIEvalConfig::load_from_json_file(config_path)?;
140 return Ok(config);
141 }
142
143 Ok(Self {
144 sample_ratio: sample_ratio.clamp(0.0, 1.0),
145 space: space.to_string(),
146 name: name.to_string(),
147 uid: create_uuid7(),
148 version: version.to_string(),
149 alert_config,
150 drift_type: DriftType::GenAI,
151 })
152 }
153
154 #[staticmethod]
155 pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
156 let file = std::fs::read_to_string(&path)?;
159
160 Ok(serde_json::from_str(&file)?)
161 }
162
163 pub fn __str__(&self) -> String {
164 PyHelperFuncs::__str__(self)
166 }
167
168 pub fn model_dump_json(&self) -> String {
169 PyHelperFuncs::__json__(self)
171 }
172
173 #[allow(clippy::too_many_arguments)]
174 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
175 pub fn update_config_args(
176 &mut self,
177 space: Option<String>,
178 name: Option<String>,
179 version: Option<String>,
180 uid: Option<String>,
181 alert_config: Option<GenAIAlertConfig>,
182 ) -> Result<(), TypeError> {
183 if name.is_some() {
184 self.name = name.ok_or(TypeError::MissingNameError)?;
185 }
186
187 if space.is_some() {
188 self.space = space.ok_or(TypeError::MissingSpaceError)?;
189 }
190
191 if version.is_some() {
192 self.version = version.ok_or(TypeError::MissingVersionError)?;
193 }
194
195 if alert_config.is_some() {
196 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
197 }
198
199 if uid.is_some() {
200 self.uid = uid.ok_or(TypeError::MissingUidError)?;
201 }
202
203 Ok(())
204 }
205}
206
207impl Default for GenAIEvalConfig {
208 fn default() -> Self {
209 Self {
210 sample_ratio: 1.0,
211 space: "default".to_string(),
212 name: "default_genai_profile".to_string(),
213 version: DEFAULT_VERSION.to_string(),
214 uid: create_uuid7(),
215 alert_config: GenAIAlertConfig::default(),
216 drift_type: DriftType::GenAI,
217 }
218 }
219}
220
221fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
237 let has_at_least_one_param = !prompt.parameters.is_empty();
238
239 if !has_at_least_one_param {
240 return Err(ProfileError::NeedAtLeastOneBoundParameterError(
241 id.to_string(),
242 ));
243 }
244
245 Ok(())
246}
247
248fn get_workflow_task<'a>(
249 workflow: &'a Workflow,
250 task_id: &'a str,
251) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
252 workflow
253 .task_list
254 .tasks
255 .get(task_id)
256 .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
257}
258
259fn validate_first_tasks(
261 workflow: &Workflow,
262 execution_order: &HashMap<i32, std::collections::HashSet<String>>,
263) -> Result<(), ProfileError> {
264 let first_tasks = execution_order
265 .get(&1)
266 .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
267
268 for task_id in first_tasks {
269 let task = get_workflow_task(workflow, task_id)?;
270 let task_guard = task
271 .read()
272 .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
273
274 validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
275 }
276
277 Ok(())
278}
279
280fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
296 let execution_order = workflow.execution_plan()?;
297
298 validate_first_tasks(workflow, &execution_order)?;
300
301 Ok(())
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
305#[pyclass]
306pub struct ExecutionNode {
307 pub id: String,
308 pub stage: usize,
309 pub parents: Vec<String>,
310 pub children: Vec<String>,
311}
312
/// Staged execution order for a profile's tasks, as built by
/// `GenAIEvalProfile::get_execution_plan`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    /// Task ids grouped by stage; every dependency of a task sits in an earlier stage.
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    /// Per-task node information, keyed by task id.
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
321
322fn initialize_node_graphs(
323 tasks: &[impl TaskAccessor],
324 graph: &mut HashMap<String, Vec<String>>,
325 reverse_graph: &mut HashMap<String, Vec<String>>,
326 in_degree: &mut HashMap<String, usize>,
327) {
328 for task in tasks {
329 let task_id = task.id().to_string();
330 graph.entry(task_id.clone()).or_default();
331 reverse_graph.entry(task_id.clone()).or_default();
332 in_degree.entry(task_id).or_insert(0);
333 }
334}
335
336fn build_dependency_edges(
337 tasks: &[impl TaskAccessor],
338 graph: &mut HashMap<String, Vec<String>>,
339 reverse_graph: &mut HashMap<String, Vec<String>>,
340 in_degree: &mut HashMap<String, usize>,
341) {
342 for task in tasks {
343 let task_id = task.id().to_string();
344 for dep in task.depends_on() {
345 graph.entry(dep.clone()).or_default().push(task_id.clone());
346 reverse_graph
347 .entry(task_id.clone())
348 .or_default()
349 .push(dep.clone());
350 *in_degree.entry(task_id.clone()).or_insert(0) += 1;
351 }
352 }
353}
354
/// A GenAI evaluation profile: its configuration plus the evaluation tasks to run.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    /// Profile configuration (identity, sampling, alerting).
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    /// Evaluation tasks, split by kind (`assertion`, `judge`, `trace`).
    pub tasks: AssertionTasks,

    /// Scouter version recorded at construction (`scouter_version()`), or the value read
    /// back from serialized JSON.
    #[pyo3(get)]
    pub scouter_version: String,

    /// Workflow for the LLM-judge tasks. `build_profile` leaves it `None` when there are
    /// no judge tasks.
    pub workflow: Option<Workflow>,

    /// Ids of all tasks, as returned by `AssertionTasks::collect_all_task_ids`.
    pub task_ids: BTreeSet<String>,

    /// Optional caller-supplied alias.
    pub alias: Option<String>,
}
372
373#[pymethods]
374impl GenAIEvalProfile {
375 #[new]
376 #[pyo3(signature = (tasks, config=None, alias=None))]
377 #[instrument(skip_all)]
388 pub fn new_py(
389 tasks: &Bound<'_, PyList>,
390 config: Option<GenAIEvalConfig>,
391 alias: Option<String>,
392 ) -> Result<Self, ProfileError> {
393 let tasks = extract_assertion_tasks_from_pylist(tasks)?;
394
395 let (workflow, task_ids) =
396 app_state().block_on(async { Self::build_profile(&tasks).await })?;
397
398 Ok(Self {
399 config: config.unwrap_or_default(),
400 tasks,
401 scouter_version: scouter_version(),
402 workflow,
403 task_ids,
404 alias,
405 })
406 }
407
408 pub fn __str__(&self) -> String {
409 PyHelperFuncs::__str__(self)
411 }
412
413 pub fn model_dump_json(&self) -> String {
414 PyHelperFuncs::__json__(self)
416 }
417
418 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
419 let json_value = serde_json::to_value(self)?;
420
421 let dict = PyDict::new(py);
423
424 json_to_pyobject(py, &json_value, &dict)?;
426
427 Ok(dict.into())
429 }
430
431 #[getter]
432 pub fn drift_type(&self) -> DriftType {
433 self.config.drift_type.clone()
434 }
435
436 #[getter]
437 pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
438 self.tasks.assertion.clone()
439 }
440
441 #[getter]
442 pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
443 self.tasks.judge.clone()
444 }
445
446 #[getter]
447 pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
448 self.tasks.trace.clone()
449 }
450
451 #[getter]
452 pub fn alias(&self) -> Option<String> {
453 self.alias.clone()
454 }
455
456 #[getter]
457 pub fn uid(&self) -> String {
458 self.config.uid.clone()
459 }
460
461 #[setter]
462 pub fn set_uid(&mut self, uid: String) {
463 self.config.uid = uid;
464 }
465
466 #[pyo3(signature = (path=None))]
467 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
468 Ok(PyHelperFuncs::save_to_json(
469 self,
470 path,
471 FileName::GenAIEvalProfile.to_str(),
472 )?)
473 }
474
475 #[staticmethod]
476 pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
477 let json_value = pyobject_to_json(data).unwrap();
478
479 let string = serde_json::to_string(&json_value).unwrap();
480 serde_json::from_str(&string).expect("Failed to load drift profile")
481 }
482
483 #[staticmethod]
484 pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
485 serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
487 }
488
489 #[staticmethod]
490 pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
491 let file = std::fs::read_to_string(&path)?;
492
493 Ok(serde_json::from_str(&file)?)
494 }
495
496 #[allow(clippy::too_many_arguments)]
497 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
498 pub fn update_config_args(
499 &mut self,
500 space: Option<String>,
501 name: Option<String>,
502 version: Option<String>,
503 uid: Option<String>,
504 alert_config: Option<GenAIAlertConfig>,
505 ) -> Result<(), TypeError> {
506 self.config
507 .update_config_args(space, name, version, uid, alert_config)
508 }
509
510 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
512 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
513 None
514 } else {
515 Some(self.config.version.clone())
516 };
517
518 Ok(ProfileRequest {
519 space: self.config.space.clone(),
520 profile: self.model_dump_json(),
521 drift_type: self.config.drift_type.clone(),
522 version_request: Some(VersionRequest {
523 version,
524 version_type: VersionType::Minor,
525 pre_tag: None,
526 build_tag: None,
527 }),
528 active: false,
529 deactivate_others: false,
530 })
531 }
532
533 pub fn has_llm_tasks(&self) -> bool {
534 !self.tasks.judge.is_empty()
535 }
536
537 pub fn has_assertions(&self) -> bool {
539 !self.tasks.assertion.is_empty()
540 }
541
542 pub fn has_trace_assertions(&self) -> bool {
543 !self.tasks.trace.is_empty()
544 }
545
546 pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
548 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
549 let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
550 let mut in_degree: HashMap<String, usize> = HashMap::new();
551
552 initialize_node_graphs(
553 &self.tasks.assertion,
554 &mut graph,
555 &mut reverse_graph,
556 &mut in_degree,
557 );
558 initialize_node_graphs(
559 &self.tasks.judge,
560 &mut graph,
561 &mut reverse_graph,
562 &mut in_degree,
563 );
564
565 initialize_node_graphs(
566 &self.tasks.trace,
567 &mut graph,
568 &mut reverse_graph,
569 &mut in_degree,
570 );
571
572 build_dependency_edges(
573 &self.tasks.assertion,
574 &mut graph,
575 &mut reverse_graph,
576 &mut in_degree,
577 );
578 build_dependency_edges(
579 &self.tasks.judge,
580 &mut graph,
581 &mut reverse_graph,
582 &mut in_degree,
583 );
584
585 build_dependency_edges(
586 &self.tasks.trace,
587 &mut graph,
588 &mut reverse_graph,
589 &mut in_degree,
590 );
591 let mut stages = Vec::new();
592 let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
593 let mut current_level: Vec<String> = in_degree
594 .iter()
595 .filter(|(_, °ree)| degree == 0)
596 .map(|(id, _)| id.clone())
597 .collect();
598
599 let mut stage_idx = 0;
600
601 while !current_level.is_empty() {
602 stages.push(current_level.clone());
603
604 for task_id in ¤t_level {
605 nodes.insert(
606 task_id.clone(),
607 ExecutionNode {
608 id: task_id.clone(),
609 stage: stage_idx,
610 parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
611 children: graph.get(task_id).cloned().unwrap_or_default(),
612 },
613 );
614 }
615
616 let mut next_level = Vec::new();
617 for task_id in ¤t_level {
618 if let Some(dependents) = graph.get(task_id) {
619 for dependent in dependents {
620 if let Some(degree) = in_degree.get_mut(dependent) {
621 *degree -= 1;
622 if *degree == 0 {
623 next_level.push(dependent.clone());
624 }
625 }
626 }
627 }
628 }
629
630 current_level = next_level;
631 stage_idx += 1;
632 }
633
634 let total_tasks =
635 self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len();
636 let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();
637
638 if processed_tasks != total_tasks {
639 return Err(ProfileError::CircularDependency);
640 }
641
642 Ok(ExecutionPlan { stages, nodes })
643 }
644
645 pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
646 use owo_colors::OwoColorize;
647
648 let plan = self.get_execution_plan()?;
649
650 println!("\n{}", "Evaluation Execution Plan".bold().green());
651 println!("{}", "═".repeat(70).green());
652
653 let mut conditional_count = 0;
654
655 for (level_idx, level) in plan.stages.iter().enumerate() {
656 let stage_label = format!("Stage {}", level_idx + 1);
657 println!("\n{}", stage_label.bold().cyan());
658
659 for (task_idx, task_id) in level.iter().enumerate() {
660 let is_last = task_idx == level.len() - 1;
661 let prefix = if is_last { "└─" } else { "├─" };
662
663 let task = self.get_task_by_id(task_id).ok_or_else(|| {
664 ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
665 })?;
666
667 let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
668 assertion.condition
669 } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
670 judge.condition
671 } else if let Some(trace) = self.get_trace_assertion_by_id(task_id) {
672 trace.condition
673 } else {
674 false
675 };
676
677 if is_conditional {
678 conditional_count += 1;
679 }
680
681 let (task_type, color_fn): (&str, fn(&str) -> String) =
682 if self.tasks.assertion.iter().any(|t| &t.id == task_id) {
683 ("Assertion", |s: &str| s.yellow().to_string())
684 } else if self.tasks.trace.iter().any(|t| &t.id == task_id) {
685 ("Trace Assertion", |s: &str| s.bright_blue().to_string())
686 } else {
687 ("LLM Judge", |s: &str| s.purple().to_string())
688 };
689
690 let conditional_marker = if is_conditional {
691 " [CONDITIONAL]".bright_red().to_string()
692 } else {
693 String::new()
694 };
695
696 println!(
697 "{} {} ({}){}",
698 prefix,
699 task_id.bold(),
700 color_fn(task_type),
701 conditional_marker
702 );
703
704 let deps = task.depends_on();
705 if !deps.is_empty() {
706 let dep_prefix = if is_last { " " } else { "│ " };
707
708 let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
709 deps.iter().partition(|dep_id| {
710 self.get_assertion_by_id(dep_id)
711 .map(|t| t.condition)
712 .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
713 .or_else(|| {
714 self.get_trace_assertion_by_id(dep_id).map(|t| t.condition)
715 })
716 .unwrap_or(false)
717 });
718
719 if !normal_deps.is_empty() {
720 println!(
721 "{} {} {}",
722 dep_prefix,
723 "depends on:".dimmed(),
724 normal_deps
725 .iter()
726 .map(|s| s.as_str())
727 .collect::<Vec<_>>()
728 .join(", ")
729 .dimmed()
730 );
731 }
732
733 if !conditional_deps.is_empty() {
734 println!(
735 "{} {} {}",
736 dep_prefix,
737 "▶ conditional gate:".bright_red().dimmed(),
738 conditional_deps
739 .iter()
740 .map(|d| format!("{} must pass", d))
741 .collect::<Vec<_>>()
742 .join(", ")
743 .red()
744 .dimmed()
745 );
746 }
747 }
748
749 if is_conditional {
750 let continuation = if is_last { " " } else { "│ " };
751 println!(
752 "{} {} {}",
753 continuation,
754 "▶".bright_red(),
755 "creates conditional branch".bright_red().dimmed()
756 );
757 }
758 }
759 }
760
761 println!("\n{}", "═".repeat(70).green());
762 println!(
763 "{}: {} tasks across {} stages",
764 "Summary".bold(),
765 self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len(),
766 plan.stages.len()
767 );
768
769 if conditional_count > 0 {
770 println!(
771 "{}: {} conditional tasks that create execution branches",
772 "Branches".bold().bright_red(),
773 conditional_count
774 );
775 }
776
777 println!();
778
779 Ok(())
780 }
781}
782
783impl Default for GenAIEvalProfile {
784 fn default() -> Self {
785 Self {
786 config: GenAIEvalConfig::default(),
787 tasks: AssertionTasks {
788 assertion: Vec::new(),
789 judge: Vec::new(),
790 trace: Vec::new(),
791 },
792 scouter_version: scouter_version(),
793 workflow: None,
794 task_ids: BTreeSet::new(),
795 alias: None,
796 }
797 }
798}
799
800impl GenAIEvalProfile {
801 pub fn build_from_parts(
803 config: GenAIEvalConfig,
804 tasks: AssertionTasks,
805 alias: Option<String>,
806 ) -> Result<GenAIEvalProfile, ProfileError> {
807 let (workflow, task_ids) =
808 app_state().block_on(async { GenAIEvalProfile::build_profile(&tasks).await })?;
809
810 Ok(GenAIEvalProfile {
811 config,
812 tasks,
813 scouter_version: scouter_version(),
814 workflow,
815 task_ids,
816 alias,
817 })
818 }
819
820 #[instrument(skip_all)]
821 pub async fn new(
822 config: GenAIEvalConfig,
823 tasks: Vec<EvaluationTask>,
824 ) -> Result<Self, ProfileError> {
825 let tasks = separate_tasks(tasks);
826 let (workflow, task_ids) = Self::build_profile(&tasks).await?;
827
828 Ok(Self {
829 config,
830 tasks,
831 scouter_version: scouter_version(),
832 workflow,
833 task_ids,
834 alias: None,
835 })
836 }
837
838 async fn build_profile(
839 tasks: &AssertionTasks,
840 ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
841 if tasks.assertion.is_empty() && tasks.judge.is_empty() && tasks.trace.is_empty() {
842 return Err(ProfileError::EmptyTaskList);
843 }
844
845 let workflow = if !tasks.judge.is_empty() {
846 let workflow = Self::build_workflow_from_judges(tasks).await?;
847 validate_workflow(&workflow)?;
848 Some(workflow)
849 } else {
850 None
851 };
852
853 for judge in &tasks.judge {
855 validate_prompt_parameters(&judge.prompt, &judge.id)?;
856 }
857
858 let task_ids = tasks.collect_all_task_ids()?;
860
861 Ok((workflow, task_ids))
862 }
863
864 async fn get_or_create_agent(
865 agents: &mut HashMap<potato_head::Provider, Agent>,
866 workflow: &mut Workflow,
867 provider: &potato_head::Provider,
868 ) -> Result<Agent, ProfileError> {
869 match agents.entry(provider.clone()) {
870 Entry::Occupied(entry) => Ok(entry.get().clone()),
871 Entry::Vacant(entry) => {
872 let agent = Agent::new(provider.clone(), None).await?;
873 workflow.add_agent(&agent);
874 Ok(entry.insert(agent).clone())
875 }
876 }
877 }
878
879 fn filter_judge_dependencies(
880 depends_on: &[String],
881 non_judge_task_ids: &BTreeSet<String>,
882 ) -> Option<Vec<String>> {
883 let filtered: Vec<String> = depends_on
884 .iter()
885 .filter(|dep_id| !non_judge_task_ids.contains(*dep_id))
886 .cloned()
887 .collect();
888
889 if filtered.is_empty() {
890 None
891 } else {
892 Some(filtered)
893 }
894 }
895
896 pub async fn build_workflow_from_judges(
902 tasks: &AssertionTasks,
903 ) -> Result<Workflow, ProfileError> {
904 let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
905 let mut agents = HashMap::new();
906 let non_judge_task_ids = tasks.collect_non_judge_task_ids();
907
908 for judge in &tasks.judge {
909 let agent =
910 Self::get_or_create_agent(&mut agents, &mut workflow, &judge.prompt.provider)
911 .await?;
912
913 let task_deps = Self::filter_judge_dependencies(&judge.depends_on, &non_judge_task_ids);
914
915 let task = Task::new(
916 &agent.id,
917 judge.prompt.clone(),
918 &judge.id,
919 task_deps,
920 judge.max_retries,
921 )?;
922
923 workflow.add_task(task)?;
924 }
925
926 Ok(workflow)
927 }
928}
929
930impl ProfileExt for GenAIEvalProfile {
931 #[inline]
932 fn id(&self) -> &str {
933 &self.config.uid
934 }
935
936 fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
937 if let Some(assertion) = self.tasks.assertion.iter().find(|t| t.id() == id) {
938 return Some(assertion);
939 }
940
941 if let Some(judge) = self.tasks.judge.iter().find(|t| t.id() == id) {
942 return Some(judge);
943 }
944
945 if let Some(trace) = self.tasks.trace.iter().find(|t| t.id() == id) {
946 return Some(trace);
947 }
948
949 None
950 }
951
952 #[inline]
953 fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
955 self.tasks.assertion.iter().find(|t| t.id() == id)
956 }
957
958 #[inline]
959 fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
960 self.tasks.judge.iter().find(|t| t.id() == id)
961 }
962
963 #[inline]
964 fn get_trace_assertion_by_id(&self, id: &str) -> Option<&TraceAssertionTask> {
965 self.tasks.trace.iter().find(|t| t.id() == id)
966 }
967
968 #[inline]
969 fn has_llm_tasks(&self) -> bool {
970 !self.tasks.judge.is_empty()
971 }
972
973 #[inline]
974 fn has_trace_assertions(&self) -> bool {
975 !self.tasks.trace.is_empty()
976 }
977}
978
979impl ProfileBaseArgs for GenAIEvalProfile {
980 type Config = GenAIEvalConfig;
981
982 fn config(&self) -> &Self::Config {
983 &self.config
984 }
985 fn get_base_args(&self) -> ProfileArgs {
986 ProfileArgs {
987 name: self.config.name.clone(),
988 space: self.config.space.clone(),
989 version: Some(self.config.version.clone()),
990 schedule: self.config.alert_config.schedule.clone(),
991 scouter_version: self.scouter_version.clone(),
992 drift_type: self.config.drift_type.clone(),
993 }
994 }
995
996 fn to_value(&self) -> Value {
997 serde_json::to_value(self).unwrap()
998 }
999}
1000
/// Results of one evaluation run: per-task records plus the workflow-level summary.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalSet {
    /// Per-task results.
    #[pyo3(get)]
    pub records: Vec<GenAIEvalTaskResult>,
    /// Workflow-level summary (counts, pass rate, duration, execution plan).
    pub inner: GenAIEvalWorkflowResult,
}
1008
1009impl GenAIEvalSet {
1010 pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
1011 self.records
1014 .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
1015
1016 self.records
1017 .iter()
1018 .map(|record| record.to_table_entry(record_id))
1019 .collect()
1020 }
1021
1022 pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
1023 vec![self.inner.to_table_entry()]
1024 }
1025
1026 pub fn new(records: Vec<GenAIEvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
1027 Self { records, inner }
1028 }
1029
1030 pub fn empty() -> Self {
1031 Self {
1032 records: Vec::new(),
1033 inner: GenAIEvalWorkflowResult {
1034 created_at: Utc::now(),
1035 record_uid: String::new(),
1036 entity_id: 0,
1037 total_tasks: 0,
1038 passed_tasks: 0,
1039 failed_tasks: 0,
1040 pass_rate: 0.0,
1041 duration_ms: 0,
1042 entity_uid: String::new(),
1043 execution_plan: ExecutionPlan::default(),
1044 id: 0,
1045 },
1046 }
1047 }
1048}
1049
1050#[pymethods]
1051impl GenAIEvalSet {
1052 #[getter]
1053 pub fn created_at(&self) -> DateTime<Utc> {
1054 self.inner.created_at
1055 }
1056
1057 #[getter]
1058 pub fn record_uid(&self) -> String {
1059 self.inner.record_uid.clone()
1060 }
1061
1062 #[getter]
1063 pub fn total_tasks(&self) -> i32 {
1064 self.inner.total_tasks
1065 }
1066
1067 #[getter]
1068 pub fn passed_tasks(&self) -> i32 {
1069 self.inner.passed_tasks
1070 }
1071
1072 #[getter]
1073 pub fn failed_tasks(&self) -> i32 {
1074 self.inner.failed_tasks
1075 }
1076
1077 #[getter]
1078 pub fn pass_rate(&self) -> f64 {
1079 self.inner.pass_rate
1080 }
1081
1082 #[getter]
1083 pub fn duration_ms(&self) -> i64 {
1084 self.inner.duration_ms
1085 }
1086
1087 pub fn __str__(&self) -> String {
1088 PyHelperFuncs::__str__(self)
1090 }
1091}
1092
/// A collection of [`GenAIEvalSet`]s, one per evaluated record.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalResultSet {
    /// The individual evaluation sets.
    #[pyo3(get)]
    pub records: Vec<GenAIEvalSet>,
}
1099
1100#[pymethods]
1101impl GenAIEvalResultSet {
1102 pub fn record(&self, id: &str) -> Option<GenAIEvalSet> {
1103 self.records.iter().find(|r| r.record_uid() == id).cloned()
1104 }
1105 pub fn __str__(&self) -> String {
1106 PyHelperFuncs::__str__(self)
1108 }
1109}
1110
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {
    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    /// `new` keeps the supplied identifiers, and `update_config_args` overwrites only the
    /// fields that are passed.
    #[test]
    fn test_genai_drift_config() {
        let mut config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let schedule = "0 0 * * * *".to_string();
        let updated_alerts = GenAIAlertConfig {
            schedule: schedule.clone(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        let new_name = Some("test".to_string());
        config
            .update_config_args(None, new_name, None, None, Some(updated_alerts))
            .unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(config.alert_config.schedule, schedule);
    }

    /// A profile with two LLM-judge tasks serializes to valid JSON and records the crate
    /// version.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let high_score = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );
        let low_score = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(high_score)
            .add_task(low_score)
            .build();

        let opsgenie = OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        };
        let alerts = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(opsgenie),
            ..Default::default()
        };

        let config = GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alerts, None).unwrap();
        let profile = GenAIEvalProfile::new(config, tasks).await.unwrap();

        let dumped = profile.model_dump_json();
        let _: Value = serde_json::from_str(&dumped).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks().len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}