1use crate::error::{ProfileError, TypeError};
2use crate::genai::alert::GenAIAlertConfig;
3use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
4use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
5use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
6use crate::genai::{AgentAssertionTask, TraceAssertionTask};
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json};
9use crate::{scouter_version, EvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry};
10use crate::{
11 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
12 PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
13};
14use crate::{ProfileRequest, TaskResultTableEntry};
15use chrono::{DateTime, Utc};
16use core::fmt::Debug;
17use potato_head::prompt_types::Prompt;
18use potato_head::Agent;
19use potato_head::Workflow;
20use potato_head::{create_uuid7, Task};
21use pyo3::prelude::*;
22use pyo3::types::{PyDict, PyList};
23use scouter_semver::VersionType;
24use scouter_state::app_state;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::hash_map::Entry;
28use std::collections::BTreeSet;
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32
33use tracing::instrument;
34
/// Serde default for `GenAIEvalConfig::sample_ratio`; always yields `1.0`.
fn default_sample_ratio() -> f64 {
    1.0_f64
}
38
39fn default_space() -> String {
40 MISSING.to_string()
41}
42
43fn default_name() -> String {
44 MISSING.to_string()
45}
46
47fn default_version() -> String {
48 DEFAULT_VERSION.to_string()
49}
50
/// Serde default for `GenAIEvalConfig::uid`: a fresh value from `create_uuid7`.
fn default_uid() -> String {
    create_uuid7()
}
54
/// Serde default for `GenAIEvalConfig::drift_type`: always `DriftType::GenAI`.
fn default_drift_type() -> DriftType {
    DriftType::GenAI
}
58
59fn default_alert_config() -> GenAIAlertConfig {
60 GenAIAlertConfig::default()
61}
62
/// Configuration attached to a GenAI evaluation profile.
///
/// Every field has a serde default, so a JSON document that omits any of them
/// still deserializes.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    /// Sampling ratio. `new` clamps it to `0.0..=1.0`; the serde default is `1.0`.
    #[pyo3(get, set)]
    #[serde(default = "default_sample_ratio")]
    pub sample_ratio: f64,

    /// Space the profile belongs to; serde default is `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_space")]
    pub space: String,

    /// Profile name; serde default is `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_name")]
    pub name: String,

    /// Profile version; serde default is `DEFAULT_VERSION`.
    #[pyo3(get, set)]
    #[serde(default = "default_version")]
    pub version: String,

    /// Unique identifier; generated with `create_uuid7` when not supplied.
    #[pyo3(get, set)]
    #[serde(default = "default_uid")]
    pub uid: String,

    /// Alerting settings (its `dispatch_config` and `schedule` are read elsewhere in this file).
    #[pyo3(get, set)]
    #[serde(default = "default_alert_config")]
    pub alert_config: GenAIAlertConfig,

    /// Drift type tag; constructors in this file always set `DriftType::GenAI`.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
94
95impl ConfigExt for GenAIEvalConfig {
96 fn space(&self) -> &str {
97 &self.space
98 }
99
100 fn name(&self) -> &str {
101 &self.name
102 }
103
104 fn version(&self) -> &str {
105 &self.version
106 }
107 fn uid(&self) -> &str {
108 &self.uid
109 }
110}
111
112impl DispatchDriftConfig for GenAIEvalConfig {
113 fn get_drift_args(&self) -> DriftArgs {
114 DriftArgs {
115 name: self.name.clone(),
116 space: self.space.clone(),
117 version: self.version.clone(),
118 dispatch_config: self.alert_config.dispatch_config.clone(),
119 }
120 }
121}
122
123#[pymethods]
124#[allow(clippy::too_many_arguments)]
125impl GenAIEvalConfig {
126 #[new]
127 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
128 pub fn new(
129 space: &str,
130 name: &str,
131 version: &str,
132 sample_ratio: f64,
133 alert_config: GenAIAlertConfig,
134 config_path: Option<PathBuf>,
135 ) -> Result<Self, ProfileError> {
136 if let Some(config_path) = config_path {
137 let config = GenAIEvalConfig::load_from_json_file(config_path)?;
138 return Ok(config);
139 }
140
141 Ok(Self {
142 sample_ratio: sample_ratio.clamp(0.0, 1.0),
143 space: space.to_string(),
144 name: name.to_string(),
145 uid: create_uuid7(),
146 version: version.to_string(),
147 alert_config,
148 drift_type: DriftType::GenAI,
149 })
150 }
151
152 #[staticmethod]
153 pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
154 let file = std::fs::read_to_string(&path)?;
157
158 Ok(serde_json::from_str(&file)?)
159 }
160
161 pub fn __str__(&self) -> String {
162 PyHelperFuncs::__str__(self)
164 }
165
166 pub fn model_dump_json(&self) -> String {
167 PyHelperFuncs::__json__(self)
169 }
170
171 #[allow(clippy::too_many_arguments)]
172 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
173 pub fn update_config_args(
174 &mut self,
175 space: Option<String>,
176 name: Option<String>,
177 version: Option<String>,
178 uid: Option<String>,
179 alert_config: Option<GenAIAlertConfig>,
180 ) -> Result<(), TypeError> {
181 if name.is_some() {
182 self.name = name.ok_or(TypeError::MissingNameError)?;
183 }
184
185 if space.is_some() {
186 self.space = space.ok_or(TypeError::MissingSpaceError)?;
187 }
188
189 if version.is_some() {
190 self.version = version.ok_or(TypeError::MissingVersionError)?;
191 }
192
193 if alert_config.is_some() {
194 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
195 }
196
197 if uid.is_some() {
198 self.uid = uid.ok_or(TypeError::MissingUidError)?;
199 }
200
201 Ok(())
202 }
203}
204
205impl Default for GenAIEvalConfig {
206 fn default() -> Self {
207 Self {
208 sample_ratio: 1.0,
209 space: "default".to_string(),
210 name: "default_genai_profile".to_string(),
211 version: DEFAULT_VERSION.to_string(),
212 uid: create_uuid7(),
213 alert_config: GenAIAlertConfig::default(),
214 drift_type: DriftType::GenAI,
215 }
216 }
217}
218
219fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
235 let has_at_least_one_param = !prompt.parameters.is_empty();
236
237 if !has_at_least_one_param {
238 return Err(ProfileError::NeedAtLeastOneBoundParameterError(
239 id.to_string(),
240 ));
241 }
242
243 Ok(())
244}
245
246fn get_workflow_task<'a>(
247 workflow: &'a Workflow,
248 task_id: &'a str,
249) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
250 workflow
251 .task_list
252 .tasks
253 .get(task_id)
254 .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
255}
256
257fn validate_first_tasks(
259 workflow: &Workflow,
260 execution_order: &HashMap<i32, std::collections::HashSet<String>>,
261) -> Result<(), ProfileError> {
262 let first_tasks = execution_order
263 .get(&1)
264 .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
265
266 for task_id in first_tasks {
267 let task = get_workflow_task(workflow, task_id)?;
268 let task_guard = task
269 .read()
270 .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
271
272 validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
273 }
274
275 Ok(())
276}
277
278fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
294 let execution_order = workflow.execution_plan()?;
295
296 validate_first_tasks(workflow, &execution_order)?;
298
299 Ok(())
300}
301
/// One task in an `ExecutionPlan`, with its position and neighbours in the
/// dependency graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ExecutionNode {
    /// Task identifier.
    pub id: String,
    /// Zero-based index of the stage the task was placed in.
    pub stage: usize,
    /// Ids of the tasks this task depends on.
    pub parents: Vec<String>,
    /// Ids of the tasks that depend on this task.
    pub children: Vec<String>,
}
310
/// Staged ordering of evaluation tasks produced by
/// `GenAIEvalProfile::get_execution_plan`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    /// Task ids grouped by stage; a task appears only after all of its
    /// dependencies have appeared in earlier stages.
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    /// Per-task graph details, keyed by task id.
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
319
320fn initialize_node_graphs(
321 tasks: &[impl TaskAccessor],
322 graph: &mut HashMap<String, Vec<String>>,
323 reverse_graph: &mut HashMap<String, Vec<String>>,
324 in_degree: &mut HashMap<String, usize>,
325) {
326 for task in tasks {
327 let task_id = task.id().to_string();
328 graph.entry(task_id.clone()).or_default();
329 reverse_graph.entry(task_id.clone()).or_default();
330 in_degree.entry(task_id).or_insert(0);
331 }
332}
333
334fn build_dependency_edges(
335 tasks: &[impl TaskAccessor],
336 graph: &mut HashMap<String, Vec<String>>,
337 reverse_graph: &mut HashMap<String, Vec<String>>,
338 in_degree: &mut HashMap<String, usize>,
339) {
340 for task in tasks {
341 let task_id = task.id().to_string();
342 for dep in task.depends_on() {
343 graph.entry(dep.clone()).or_default().push(task_id.clone());
344 reverse_graph
345 .entry(task_id.clone())
346 .or_default()
347 .push(dep.clone());
348 *in_degree.entry(task_id.clone()).or_insert(0) += 1;
349 }
350 }
351}
352
/// A GenAI evaluation profile: configuration plus the evaluation tasks and the
/// workflow derived from them.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    /// Profile configuration (space, name, version, uid, alerting).
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    /// Evaluation tasks, grouped as `assertion`, `judge`, `trace` and `agent`.
    pub tasks: AssertionTasks,

    /// Value of `scouter_version()` at the time the profile was built.
    #[pyo3(get)]
    pub scouter_version: String,

    /// Workflow built from the judge tasks; `build_profile` leaves it `None`
    /// when there are no judge tasks.
    pub workflow: Option<Workflow>,

    /// Ids returned by `AssertionTasks::collect_all_task_ids` at build time.
    pub task_ids: BTreeSet<String>,

    /// Optional alias supplied by the caller.
    pub alias: Option<String>,
}
370
371#[pymethods]
372impl GenAIEvalProfile {
373 #[new]
374 #[pyo3(signature = (tasks, config=None, alias=None))]
375 #[instrument(skip_all)]
386 pub fn new_py(
387 tasks: &Bound<'_, PyList>,
388 config: Option<GenAIEvalConfig>,
389 alias: Option<String>,
390 ) -> Result<Self, ProfileError> {
391 let tasks = extract_assertion_tasks_from_pylist(tasks)?;
392
393 let (workflow, task_ids) =
394 app_state().block_on(async { Self::build_profile(&tasks).await })?;
395
396 Ok(Self {
397 config: config.unwrap_or_default(),
398 tasks,
399 scouter_version: scouter_version(),
400 workflow,
401 task_ids,
402 alias,
403 })
404 }
405
    /// String representation produced by `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
410
    /// JSON representation produced by `PyHelperFuncs::__json__`.
    pub fn model_dump_json(&self) -> String {
        PyHelperFuncs::__json__(self)
    }
415
416 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
417 let json_value = serde_json::to_value(self)?;
418
419 let dict = PyDict::new(py);
421
422 json_to_pyobject(py, &json_value, &dict)?;
424
425 Ok(dict.into())
427 }
428
429 #[getter]
430 pub fn drift_type(&self) -> DriftType {
431 self.config.drift_type.clone()
432 }
433
434 #[getter]
435 pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
436 self.tasks.assertion.clone()
437 }
438
439 #[getter]
440 pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
441 self.tasks.judge.clone()
442 }
443
444 #[getter]
445 pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
446 self.tasks.trace.clone()
447 }
448
449 #[getter]
450 pub fn agent_assertion_tasks(&self) -> Vec<AgentAssertionTask> {
451 self.tasks.agent.clone()
452 }
453
454 #[getter]
455 pub fn alias(&self) -> Option<String> {
456 self.alias.clone()
457 }
458
459 #[getter]
460 pub fn uid(&self) -> String {
461 self.config.uid.clone()
462 }
463
    /// Replaces the uid stored in the config.
    #[setter]
    pub fn set_uid(&mut self, uid: String) {
        self.config.uid = uid;
    }
468
469 #[pyo3(signature = (path=None))]
470 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
471 Ok(PyHelperFuncs::save_to_json(
472 self,
473 path,
474 FileName::GenAIEvalProfile.to_str(),
475 )?)
476 }
477
    /// Builds a profile from a Python dict by converting it to JSON and
    /// deserializing the result.
    ///
    /// # Panics
    /// Panics when the dict cannot be converted to JSON, when the JSON value
    /// cannot be rendered as a string, or when the string does not
    /// deserialize into a `GenAIEvalProfile`.
    // NOTE(review): failures here panic instead of returning an error;
    // changing that would alter the return type, so it is left as-is.
    #[staticmethod]
    pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
        let json_value = pyobject_to_json(data).unwrap();

        let string = serde_json::to_string(&json_value).unwrap();
        serde_json::from_str(&string).expect("Failed to load drift profile")
    }
485
    /// Deserializes a profile from a JSON string.
    ///
    /// # Panics
    /// Panics when `json_string` does not deserialize into a
    /// `GenAIEvalProfile`.
    #[staticmethod]
    pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
        serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
    }
491
492 #[staticmethod]
493 pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
494 let file = std::fs::read_to_string(&path)?;
495
496 Ok(serde_json::from_str(&file)?)
497 }
498
499 #[allow(clippy::too_many_arguments)]
500 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
501 pub fn update_config_args(
502 &mut self,
503 space: Option<String>,
504 name: Option<String>,
505 version: Option<String>,
506 uid: Option<String>,
507 alert_config: Option<GenAIAlertConfig>,
508 ) -> Result<(), TypeError> {
509 self.config
510 .update_config_args(space, name, version, uid, alert_config)
511 }
512
513 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
515 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
516 None
517 } else {
518 Some(self.config.version.clone())
519 };
520
521 Ok(ProfileRequest {
522 space: self.config.space.clone(),
523 profile: self.model_dump_json(),
524 drift_type: self.config.drift_type.clone(),
525 version_request: Some(VersionRequest {
526 version,
527 version_type: VersionType::Minor,
528 pre_tag: None,
529 build_tag: None,
530 }),
531 active: false,
532 deactivate_others: false,
533 })
534 }
535
536 pub fn has_llm_tasks(&self) -> bool {
537 !self.tasks.judge.is_empty()
538 }
539
540 pub fn has_assertions(&self) -> bool {
542 !self.tasks.assertion.is_empty()
543 }
544
545 pub fn has_trace_assertions(&self) -> bool {
546 !self.tasks.trace.is_empty()
547 }
548
549 pub fn has_agent_assertions(&self) -> bool {
550 !self.tasks.agent.is_empty()
551 }
552
553 pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
555 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
556 let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
557 let mut in_degree: HashMap<String, usize> = HashMap::new();
558
559 initialize_node_graphs(
560 &self.tasks.assertion,
561 &mut graph,
562 &mut reverse_graph,
563 &mut in_degree,
564 );
565 initialize_node_graphs(
566 &self.tasks.judge,
567 &mut graph,
568 &mut reverse_graph,
569 &mut in_degree,
570 );
571
572 initialize_node_graphs(
573 &self.tasks.trace,
574 &mut graph,
575 &mut reverse_graph,
576 &mut in_degree,
577 );
578
579 initialize_node_graphs(
580 &self.tasks.agent,
581 &mut graph,
582 &mut reverse_graph,
583 &mut in_degree,
584 );
585
586 build_dependency_edges(
587 &self.tasks.assertion,
588 &mut graph,
589 &mut reverse_graph,
590 &mut in_degree,
591 );
592 build_dependency_edges(
593 &self.tasks.judge,
594 &mut graph,
595 &mut reverse_graph,
596 &mut in_degree,
597 );
598
599 build_dependency_edges(
600 &self.tasks.trace,
601 &mut graph,
602 &mut reverse_graph,
603 &mut in_degree,
604 );
605
606 build_dependency_edges(
607 &self.tasks.agent,
608 &mut graph,
609 &mut reverse_graph,
610 &mut in_degree,
611 );
612 let mut stages = Vec::new();
613 let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
614 let mut current_level: Vec<String> = in_degree
615 .iter()
616 .filter(|(_, °ree)| degree == 0)
617 .map(|(id, _)| id.clone())
618 .collect();
619
620 let mut stage_idx = 0;
621
622 while !current_level.is_empty() {
623 stages.push(current_level.clone());
624
625 for task_id in ¤t_level {
626 nodes.insert(
627 task_id.clone(),
628 ExecutionNode {
629 id: task_id.clone(),
630 stage: stage_idx,
631 parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
632 children: graph.get(task_id).cloned().unwrap_or_default(),
633 },
634 );
635 }
636
637 let mut next_level = Vec::new();
638 for task_id in ¤t_level {
639 if let Some(dependents) = graph.get(task_id) {
640 for dependent in dependents {
641 if let Some(degree) = in_degree.get_mut(dependent) {
642 *degree -= 1;
643 if *degree == 0 {
644 next_level.push(dependent.clone());
645 }
646 }
647 }
648 }
649 }
650
651 current_level = next_level;
652 stage_idx += 1;
653 }
654
655 let total_tasks = self.tasks.assertion.len()
656 + self.tasks.judge.len()
657 + self.tasks.trace.len()
658 + self.tasks.agent.len();
659 let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();
660
661 if processed_tasks != total_tasks {
662 return Err(ProfileError::CircularDependency);
663 }
664
665 Ok(ExecutionPlan { stages, nodes })
666 }
667
    /// Prints a coloured, tree-style rendering of the execution plan to
    /// standard output.
    ///
    /// For each stage it lists the tasks with their type, marks conditional
    /// tasks, and shows each task's dependencies split into plain dependencies
    /// and conditional gates. A summary line (task and stage counts) follows,
    /// plus a branch count when any task is conditional.
    ///
    /// # Errors
    /// Propagates the error from `get_execution_plan`, and returns
    /// `ProfileError::NoTasksFoundError` when a planned id cannot be resolved
    /// to a task.
    pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
        use owo_colors::OwoColorize;

        let plan = self.get_execution_plan()?;

        println!("\n{}", "Evaluation Execution Plan".bold().green());
        println!("{}", "═".repeat(70).green());

        // Number of tasks whose `condition` flag is set.
        let mut conditional_count = 0;

        for (level_idx, level) in plan.stages.iter().enumerate() {
            // Stages are displayed one-based.
            let stage_label = format!("Stage {}", level_idx + 1);
            println!("\n{}", stage_label.bold().cyan());

            for (task_idx, task_id) in level.iter().enumerate() {
                // The last task of a stage gets the closing branch glyph.
                let is_last = task_idx == level.len() - 1;
                let prefix = if is_last { "└─" } else { "├─" };

                let task = self.get_task_by_id(task_id).ok_or_else(|| {
                    ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
                })?;

                // Read the `condition` flag from whichever task list holds the id.
                let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
                    assertion.condition
                } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
                    judge.condition
                } else if let Some(trace) = self.get_trace_assertion_by_id(task_id) {
                    trace.condition
                } else if let Some(request) = self.get_agent_assertion_by_id(task_id) {
                    request.condition
                } else {
                    false
                };

                if is_conditional {
                    conditional_count += 1;
                }

                // Label and colour by task list; anything not found in the
                // assertion, trace or agent lists is labelled as an LLM judge.
                let (task_type, color_fn): (&str, fn(&str) -> String) =
                    if self.tasks.assertion.iter().any(|t| &t.id == task_id) {
                        ("Assertion", |s: &str| s.yellow().to_string())
                    } else if self.tasks.trace.iter().any(|t| &t.id == task_id) {
                        ("Trace Assertion", |s: &str| s.bright_blue().to_string())
                    } else if self.tasks.agent.iter().any(|t| &t.id == task_id) {
                        ("Request Assertion", |s: &str| s.bright_green().to_string())
                    } else {
                        ("LLM Judge", |s: &str| s.purple().to_string())
                    };

                let conditional_marker = if is_conditional {
                    " [CONDITIONAL]".bright_red().to_string()
                } else {
                    String::new()
                };

                println!(
                    "{} {} ({}){}",
                    prefix,
                    task_id.bold(),
                    color_fn(task_type),
                    conditional_marker
                );

                let deps = task.depends_on();
                if !deps.is_empty() {
                    let dep_prefix = if is_last { " " } else { "│ " };

                    // Split dependencies by whether the task they point at is
                    // conditional; ids that resolve to no task count as normal.
                    let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
                        deps.iter().partition(|dep_id| {
                            self.get_assertion_by_id(dep_id)
                                .map(|t| t.condition)
                                .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
                                .or_else(|| {
                                    self.get_trace_assertion_by_id(dep_id).map(|t| t.condition)
                                })
                                .or_else(|| {
                                    self.get_agent_assertion_by_id(dep_id).map(|t| t.condition)
                                })
                                .unwrap_or(false)
                        });

                    if !normal_deps.is_empty() {
                        println!(
                            "{} {} {}",
                            dep_prefix,
                            "depends on:".dimmed(),
                            normal_deps
                                .iter()
                                .map(|s| s.as_str())
                                .collect::<Vec<_>>()
                                .join(", ")
                                .dimmed()
                        );
                    }

                    if !conditional_deps.is_empty() {
                        println!(
                            "{} {} {}",
                            dep_prefix,
                            "▶ conditional gate:".bright_red().dimmed(),
                            conditional_deps
                                .iter()
                                .map(|d| format!("{} must pass", d))
                                .collect::<Vec<_>>()
                                .join(", ")
                                .red()
                                .dimmed()
                        );
                    }
                }

                if is_conditional {
                    let continuation = if is_last { " " } else { "│ " };
                    println!(
                        "{} {} {}",
                        continuation,
                        "▶".bright_red(),
                        "creates conditional branch".bright_red().dimmed()
                    );
                }
            }
        }

        println!("\n{}", "═".repeat(70).green());
        println!(
            "{}: {} tasks across {} stages",
            "Summary".bold(),
            self.tasks.assertion.len()
                + self.tasks.judge.len()
                + self.tasks.trace.len()
                + self.tasks.agent.len(),
            plan.stages.len()
        );

        if conditional_count > 0 {
            println!(
                "{}: {} conditional tasks that create execution branches",
                "Branches".bold().bright_red(),
                conditional_count
            );
        }

        println!();

        Ok(())
    }
814}
815
816impl Default for GenAIEvalProfile {
817 fn default() -> Self {
818 Self {
819 config: GenAIEvalConfig::default(),
820 tasks: AssertionTasks {
821 assertion: Vec::new(),
822 judge: Vec::new(),
823 trace: Vec::new(),
824 agent: Vec::new(),
825 },
826 scouter_version: scouter_version(),
827 workflow: None,
828 task_ids: BTreeSet::new(),
829 alias: None,
830 }
831 }
832}
833
834impl GenAIEvalProfile {
835 pub fn build_from_parts(
837 config: GenAIEvalConfig,
838 tasks: AssertionTasks,
839 alias: Option<String>,
840 ) -> Result<GenAIEvalProfile, ProfileError> {
841 let (workflow, task_ids) =
842 app_state().block_on(async { GenAIEvalProfile::build_profile(&tasks).await })?;
843
844 Ok(GenAIEvalProfile {
845 config,
846 tasks,
847 scouter_version: scouter_version(),
848 workflow,
849 task_ids,
850 alias,
851 })
852 }
853
854 pub async fn build_from_parts_async(
856 config: GenAIEvalConfig,
857 tasks: AssertionTasks,
858 alias: Option<String>,
859 ) -> Result<GenAIEvalProfile, ProfileError> {
860 let (workflow, task_ids) = GenAIEvalProfile::build_profile(&tasks).await?;
861
862 Ok(GenAIEvalProfile {
863 config,
864 tasks,
865 scouter_version: scouter_version(),
866 workflow,
867 task_ids,
868 alias,
869 })
870 }
871
872 #[instrument(skip_all)]
873 pub async fn new(
874 config: GenAIEvalConfig,
875 tasks: Vec<EvaluationTask>,
876 ) -> Result<Self, ProfileError> {
877 let tasks = separate_tasks(tasks);
878 let (workflow, task_ids) = Self::build_profile(&tasks).await?;
879
880 Ok(Self {
881 config,
882 tasks,
883 scouter_version: scouter_version(),
884 workflow,
885 task_ids,
886 alias: None,
887 })
888 }
889
890 async fn build_profile(
891 tasks: &AssertionTasks,
892 ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
893 if tasks.assertion.is_empty()
894 && tasks.judge.is_empty()
895 && tasks.trace.is_empty()
896 && tasks.agent.is_empty()
897 {
898 return Err(ProfileError::EmptyTaskList);
899 }
900
901 let workflow = if !tasks.judge.is_empty() {
902 let workflow = Self::build_workflow_from_judges(tasks).await?;
903 validate_workflow(&workflow)?;
904 Some(workflow)
905 } else {
906 None
907 };
908
909 for judge in &tasks.judge {
911 validate_prompt_parameters(&judge.prompt, &judge.id)?;
912 }
913
914 let task_ids = tasks.collect_all_task_ids()?;
916
917 Ok((workflow, task_ids))
918 }
919
920 async fn get_or_create_agent(
921 agents: &mut HashMap<potato_head::Provider, Agent>,
922 workflow: &mut Workflow,
923 provider: &potato_head::Provider,
924 ) -> Result<Agent, ProfileError> {
925 match agents.entry(provider.clone()) {
926 Entry::Occupied(entry) => Ok(entry.get().clone()),
927 Entry::Vacant(entry) => {
928 let agent = Agent::new(provider.clone(), None).await?;
929 workflow.add_agent(&agent);
930 Ok(entry.insert(agent).clone())
931 }
932 }
933 }
934
935 fn filter_judge_dependencies(
936 depends_on: &[String],
937 non_judge_task_ids: &BTreeSet<String>,
938 ) -> Option<Vec<String>> {
939 let filtered: Vec<String> = depends_on
940 .iter()
941 .filter(|dep_id| !non_judge_task_ids.contains(*dep_id))
942 .cloned()
943 .collect();
944
945 if filtered.is_empty() {
946 None
947 } else {
948 Some(filtered)
949 }
950 }
951
952 pub async fn build_workflow_from_judges(
958 tasks: &AssertionTasks,
959 ) -> Result<Workflow, ProfileError> {
960 let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
961 let mut agents = HashMap::new();
962 let non_judge_task_ids = tasks.collect_non_judge_task_ids();
963
964 for judge in &tasks.judge {
965 let agent =
966 Self::get_or_create_agent(&mut agents, &mut workflow, &judge.prompt.provider)
967 .await?;
968
969 let task_deps = Self::filter_judge_dependencies(&judge.depends_on, &non_judge_task_ids);
970
971 let task = Task::new(
972 &agent.id,
973 judge.prompt.clone(),
974 &judge.id,
975 task_deps,
976 judge.max_retries,
977 )?;
978
979 workflow.add_task(task)?;
980 }
981
982 Ok(workflow)
983 }
984}
985
986impl ProfileExt for GenAIEvalProfile {
987 #[inline]
988 fn id(&self) -> &str {
989 &self.config.uid
990 }
991
992 fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
993 if let Some(assertion) = self.tasks.assertion.iter().find(|t| t.id() == id) {
994 return Some(assertion);
995 }
996
997 if let Some(judge) = self.tasks.judge.iter().find(|t| t.id() == id) {
998 return Some(judge);
999 }
1000
1001 if let Some(trace) = self.tasks.trace.iter().find(|t| t.id() == id) {
1002 return Some(trace);
1003 }
1004
1005 if let Some(request) = self.tasks.agent.iter().find(|t| t.id() == id) {
1006 return Some(request);
1007 }
1008
1009 None
1010 }
1011
1012 #[inline]
1013 fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
1015 self.tasks.assertion.iter().find(|t| t.id() == id)
1016 }
1017
1018 #[inline]
1019 fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
1020 self.tasks.judge.iter().find(|t| t.id() == id)
1021 }
1022
1023 #[inline]
1024 fn get_trace_assertion_by_id(&self, id: &str) -> Option<&TraceAssertionTask> {
1025 self.tasks.trace.iter().find(|t| t.id() == id)
1026 }
1027
1028 #[inline]
1029 fn get_agent_assertion_by_id(&self, id: &str) -> Option<&AgentAssertionTask> {
1030 self.tasks.agent.iter().find(|t| t.id() == id)
1031 }
1032
1033 #[inline]
1034 fn has_llm_tasks(&self) -> bool {
1035 !self.tasks.judge.is_empty()
1036 }
1037
1038 #[inline]
1039 fn has_trace_assertions(&self) -> bool {
1040 !self.tasks.trace.is_empty()
1041 }
1042
1043 #[inline]
1044 fn has_agent_assertions(&self) -> bool {
1045 !self.tasks.agent.is_empty()
1046 }
1047}
1048
1049impl ProfileBaseArgs for GenAIEvalProfile {
1050 type Config = GenAIEvalConfig;
1051
1052 fn config(&self) -> &Self::Config {
1053 &self.config
1054 }
1055 fn get_base_args(&self) -> ProfileArgs {
1056 ProfileArgs {
1057 name: self.config.name.clone(),
1058 space: self.config.space.clone(),
1059 version: Some(self.config.version.clone()),
1060 schedule: self.config.alert_config.schedule.clone(),
1061 scouter_version: self.scouter_version.clone(),
1062 drift_type: self.config.drift_type.clone(),
1063 }
1064 }
1065
1066 fn to_value(&self) -> Value {
1067 serde_json::to_value(self).unwrap()
1068 }
1069}
1070
/// Task-level results of one evaluation together with its workflow-level
/// summary.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalSet {
    /// Individual task results.
    #[pyo3(get)]
    pub records: Vec<EvalTaskResult>,
    /// Workflow-level summary (counts, pass rate, duration, identifiers).
    pub inner: GenAIEvalWorkflowResult,
}
1078
1079impl EvalSet {
1080 pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
1081 self.records
1084 .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
1085
1086 self.records
1087 .iter()
1088 .map(|record| record.to_table_entry(record_id))
1089 .collect()
1090 }
1091
1092 pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
1093 vec![self.inner.to_table_entry()]
1094 }
1095
1096 pub fn new(records: Vec<EvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
1097 Self { records, inner }
1098 }
1099
1100 pub fn empty() -> Self {
1101 Self {
1102 records: Vec::new(),
1103 inner: GenAIEvalWorkflowResult {
1104 created_at: Utc::now(),
1105 record_uid: String::new(),
1106 entity_id: 0,
1107 total_tasks: 0,
1108 passed_tasks: 0,
1109 failed_tasks: 0,
1110 pass_rate: 0.0,
1111 duration_ms: 0,
1112 entity_uid: String::new(),
1113 execution_plan: ExecutionPlan::default(),
1114 id: 0,
1115 },
1116 }
1117 }
1118}
1119
1120#[pymethods]
1121impl EvalSet {
1122 #[getter]
1123 pub fn created_at(&self) -> DateTime<Utc> {
1124 self.inner.created_at
1125 }
1126
1127 #[getter]
1128 pub fn record_uid(&self) -> String {
1129 self.inner.record_uid.clone()
1130 }
1131
1132 #[getter]
1133 pub fn total_tasks(&self) -> i32 {
1134 self.inner.total_tasks
1135 }
1136
1137 #[getter]
1138 pub fn passed_tasks(&self) -> i32 {
1139 self.inner.passed_tasks
1140 }
1141
1142 #[getter]
1143 pub fn failed_tasks(&self) -> i32 {
1144 self.inner.failed_tasks
1145 }
1146
1147 #[getter]
1148 pub fn pass_rate(&self) -> f64 {
1149 self.inner.pass_rate
1150 }
1151
1152 #[getter]
1153 pub fn duration_ms(&self) -> i64 {
1154 self.inner.duration_ms
1155 }
1156
1157 pub fn __str__(&self) -> String {
1158 PyHelperFuncs::__str__(self)
1160 }
1161}
1162
/// A collection of `EvalSet`s.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResultSet {
    /// The contained evaluation sets.
    #[pyo3(get)]
    pub records: Vec<EvalSet>,
}
1169
1170#[pymethods]
1171impl EvalResultSet {
1172 pub fn record(&self, id: &str) -> Option<EvalSet> {
1173 self.records.iter().find(|r| r.record_uid() == id).cloned()
1174 }
1175 pub fn __str__(&self) -> String {
1176 PyHelperFuncs::__str__(self)
1178 }
1179}
1180
// Unit tests; compiled only for test builds with the `mock` feature enabled.
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    // Checks constructor values and that `update_config_args` replaces only
    // the name and the alert config when those are the only `Some` arguments.
    #[test]
    fn test_genai_drift_config() {
        let mut drift_config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    // Builds a profile from two judge tasks and checks that it serializes to
    // parseable JSON, keeps both tasks and records the crate version.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        // Round-trip through JSON only to prove the dump is well-formed.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks().len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}