1use crate::error::{ProfileError, TypeError};
2use crate::genai::alert::GenAIAlertConfig;
3use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
4use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
5use crate::genai::utils::extract_assertion_tasks_from_pylist;
6use crate::util::{json_to_pyobject, pyobject_to_json, ConfigExt};
7use crate::{
8 scouter_version, GenAIEvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry,
9};
10use crate::{
11 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
12 PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
13};
14use crate::{ProfileRequest, TaskResultTableEntry};
15use chrono::{DateTime, Utc};
16use core::fmt::Debug;
17use potato_head::prompt_types::Prompt;
18use potato_head::Agent;
19use potato_head::Workflow;
20use potato_head::{create_uuid7, Task};
21use pyo3::prelude::*;
22use pyo3::types::{PyDict, PyList};
23use scouter_semver::VersionType;
24use scouter_state::app_state;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::hash_map::Entry;
28use std::collections::BTreeSet;
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32
33use tracing::instrument;
34
/// Configuration for a GenAI evaluation profile.
///
/// Carries identity (`space`/`name`/`version`/`uid`), sampling ratio, and
/// alerting settings; exposed to Python via pyo3.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    // Fraction of records to evaluate; `new` clamps this to [0.0, 1.0].
    #[pyo3(get, set)]
    pub sample_ratio: f64,

    #[pyo3(get, set)]
    pub space: String,

    #[pyo3(get, set)]
    pub name: String,

    #[pyo3(get, set)]
    pub version: String,

    // Unique identifier; a fresh uuid7 is generated by `new`/`Default`.
    #[pyo3(get, set)]
    pub uid: String,

    #[pyo3(get, set)]
    pub alert_config: GenAIAlertConfig,

    // Defaults to DriftType::GenAI when absent from serialized data.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
60
// Accessors required by the shared `ConfigExt` trait; all borrow from `self`.
impl ConfigExt for GenAIEvalConfig {
    fn space(&self) -> &str {
        &self.space
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn version(&self) -> &str {
        &self.version
    }
}
74
/// Serde default for `GenAIEvalConfig::drift_type` (see `#[serde(default)]`).
fn default_drift_type() -> DriftType {
    DriftType::GenAI
}
78
impl DispatchDriftConfig for GenAIEvalConfig {
    /// Bundle the config's identity plus its alert dispatch settings for the
    /// drift dispatcher.
    fn get_drift_args(&self) -> DriftArgs {
        DriftArgs {
            name: self.name.clone(),
            space: self.space.clone(),
            version: self.version.clone(),
            dispatch_config: self.alert_config.dispatch_config.clone(),
        }
    }
}
89
90#[pymethods]
91#[allow(clippy::too_many_arguments)]
92impl GenAIEvalConfig {
93 #[new]
94 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
95 pub fn new(
96 space: &str,
97 name: &str,
98 version: &str,
99 sample_ratio: f64,
100 alert_config: GenAIAlertConfig,
101 config_path: Option<PathBuf>,
102 ) -> Result<Self, ProfileError> {
103 if let Some(config_path) = config_path {
104 let config = GenAIEvalConfig::load_from_json_file(config_path)?;
105 return Ok(config);
106 }
107
108 Ok(Self {
109 sample_ratio: sample_ratio.clamp(0.0, 1.0),
110 space: space.to_string(),
111 name: name.to_string(),
112 uid: create_uuid7(),
113 version: version.to_string(),
114 alert_config,
115 drift_type: DriftType::GenAI,
116 })
117 }
118
119 #[staticmethod]
120 pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
121 let file = std::fs::read_to_string(&path)?;
124
125 Ok(serde_json::from_str(&file)?)
126 }
127
128 pub fn __str__(&self) -> String {
129 PyHelperFuncs::__str__(self)
131 }
132
133 pub fn model_dump_json(&self) -> String {
134 PyHelperFuncs::__json__(self)
136 }
137
138 #[allow(clippy::too_many_arguments)]
139 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
140 pub fn update_config_args(
141 &mut self,
142 space: Option<String>,
143 name: Option<String>,
144 version: Option<String>,
145 uid: Option<String>,
146 alert_config: Option<GenAIAlertConfig>,
147 ) -> Result<(), TypeError> {
148 if name.is_some() {
149 self.name = name.ok_or(TypeError::MissingNameError)?;
150 }
151
152 if space.is_some() {
153 self.space = space.ok_or(TypeError::MissingSpaceError)?;
154 }
155
156 if version.is_some() {
157 self.version = version.ok_or(TypeError::MissingVersionError)?;
158 }
159
160 if alert_config.is_some() {
161 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
162 }
163
164 if uid.is_some() {
165 self.uid = uid.ok_or(TypeError::MissingUidError)?;
166 }
167
168 Ok(())
169 }
170}
171
impl Default for GenAIEvalConfig {
    /// Defaults: full sampling, placeholder space/name, default version,
    /// fresh uuid7 uid, default alerting, GenAI drift type.
    fn default() -> Self {
        Self {
            sample_ratio: 1.0,
            space: "default".to_string(),
            name: "default_genai_profile".to_string(),
            version: DEFAULT_VERSION.to_string(),
            uid: create_uuid7(),
            alert_config: GenAIAlertConfig::default(),
            drift_type: DriftType::GenAI,
        }
    }
}
185
186fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
202 let has_at_least_one_param = !prompt.parameters.is_empty();
203
204 if !has_at_least_one_param {
205 return Err(ProfileError::NeedAtLeastOneBoundParameterError(
206 id.to_string(),
207 ));
208 }
209
210 Ok(())
211}
212
213fn get_workflow_task<'a>(
214 workflow: &'a Workflow,
215 task_id: &'a str,
216) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
217 workflow
218 .task_list
219 .tasks
220 .get(task_id)
221 .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
222}
223
/// Validate that every task scheduled in the first execution stage (stage key
/// `1`) has at least one bound prompt parameter.
fn validate_first_tasks(
    workflow: &Workflow,
    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
) -> Result<(), ProfileError> {
    // Stage numbering starts at 1; a missing first stage means nothing can run.
    let first_tasks = execution_order
        .get(&1)
        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;

    for task_id in first_tasks {
        let task = get_workflow_task(workflow, task_id)?;
        // A poisoned lock is surfaced as a task-lookup failure.
        let task_guard = task
            .read()
            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;

        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
    }

    Ok(())
}
244
245fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
261 let execution_order = workflow.execution_plan()?;
262
263 validate_first_tasks(workflow, &execution_order)?;
265
266 Ok(())
267}
268
/// One task in the execution-plan graph, with its stage index and edges.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ExecutionNode {
    pub id: String,
    // 0-based index of the stage this task runs in.
    pub stage: usize,
    // Ids of tasks this task depends on.
    pub parents: Vec<String>,
    // Ids of tasks that depend on this task.
    pub children: Vec<String>,
}
277
/// Staged execution plan: `stages[i]` lists the task ids runnable in stage
/// `i`; `nodes` maps each task id to its graph metadata.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
286
287fn initialize_node_graphs(
288 tasks: &[impl TaskAccessor],
289 graph: &mut HashMap<String, Vec<String>>,
290 reverse_graph: &mut HashMap<String, Vec<String>>,
291 in_degree: &mut HashMap<String, usize>,
292) {
293 for task in tasks {
294 let task_id = task.id().to_string();
295 graph.entry(task_id.clone()).or_default();
296 reverse_graph.entry(task_id.clone()).or_default();
297 in_degree.entry(task_id).or_insert(0);
298 }
299}
300
301fn build_dependency_edges(
302 tasks: &[impl TaskAccessor],
303 graph: &mut HashMap<String, Vec<String>>,
304 reverse_graph: &mut HashMap<String, Vec<String>>,
305 in_degree: &mut HashMap<String, usize>,
306) {
307 for task in tasks {
308 let task_id = task.id().to_string();
309 for dep in task.depends_on() {
310 graph.entry(dep.clone()).or_default().push(task_id.clone());
311 reverse_graph
312 .entry(task_id.clone())
313 .or_default()
314 .push(dep.clone());
315 *in_degree.entry(task_id.clone()).or_insert(0) += 1;
316 }
317 }
318}
319
/// A complete GenAI evaluation profile: configuration plus the assertion and
/// LLM-judge tasks to run, and the pre-built workflow for the judge tasks.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    // Deterministic (non-LLM) assertion tasks.
    #[pyo3(get)]
    pub assertion_tasks: Vec<AssertionTask>,

    // Tasks scored by an LLM judge.
    #[pyo3(get)]
    pub llm_judge_tasks: Vec<LLMJudgeTask>,

    // Version of scouter that produced this profile.
    #[pyo3(get)]
    pub scouter_version: String,

    // Built only when at least one LLM judge task exists (see `build_profile`).
    pub workflow: Option<Workflow>,

    // Unique ids across both task lists; duplicates are rejected at build time.
    pub task_ids: BTreeSet<String>,
}
339
340#[pymethods]
341impl GenAIEvalProfile {
342 #[new]
343 #[pyo3(signature = (config, tasks))]
344 #[instrument(skip_all)]
354 pub fn new_py(
355 config: GenAIEvalConfig,
356 tasks: &Bound<'_, PyList>,
357 ) -> Result<Self, ProfileError> {
358 let (assertion_tasks, llm_judge_tasks) = extract_assertion_tasks_from_pylist(tasks)?;
359
360 let (workflow, task_ids) = app_state()
361 .block_on(async { Self::build_profile(&llm_judge_tasks, &assertion_tasks).await })?;
362
363 Ok(Self {
364 config,
365 assertion_tasks,
366 llm_judge_tasks,
367
368 scouter_version: scouter_version(),
369 workflow,
370 task_ids,
371 })
372 }
373
374 pub fn __str__(&self) -> String {
375 PyHelperFuncs::__str__(self)
377 }
378
379 pub fn model_dump_json(&self) -> String {
380 PyHelperFuncs::__json__(self)
382 }
383
384 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
385 let json_value = serde_json::to_value(self)?;
386
387 let dict = PyDict::new(py);
389
390 json_to_pyobject(py, &json_value, &dict)?;
392
393 Ok(dict.into())
395 }
396
397 #[getter]
398 pub fn uid(&self) -> String {
399 self.config.uid.clone()
400 }
401
402 #[setter]
403 pub fn set_uid(&mut self, uid: String) {
404 self.config.uid = uid;
405 }
406
407 #[pyo3(signature = (path=None))]
408 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
409 Ok(PyHelperFuncs::save_to_json(
410 self,
411 path,
412 FileName::GenAIEvalProfile.to_str(),
413 )?)
414 }
415
416 #[staticmethod]
417 pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
418 let json_value = pyobject_to_json(data).unwrap();
419
420 let string = serde_json::to_string(&json_value).unwrap();
421 serde_json::from_str(&string).expect("Failed to load drift profile")
422 }
423
424 #[staticmethod]
425 pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
426 serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
428 }
429
430 #[staticmethod]
431 pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
432 let file = std::fs::read_to_string(&path)?;
433
434 Ok(serde_json::from_str(&file)?)
435 }
436
437 #[allow(clippy::too_many_arguments)]
438 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
439 pub fn update_config_args(
440 &mut self,
441 space: Option<String>,
442 name: Option<String>,
443 version: Option<String>,
444 uid: Option<String>,
445 alert_config: Option<GenAIAlertConfig>,
446 ) -> Result<(), TypeError> {
447 self.config
448 .update_config_args(space, name, version, uid, alert_config)
449 }
450
451 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
453 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
454 None
455 } else {
456 Some(self.config.version.clone())
457 };
458
459 Ok(ProfileRequest {
460 space: self.config.space.clone(),
461 profile: self.model_dump_json(),
462 drift_type: self.config.drift_type.clone(),
463 version_request: Some(VersionRequest {
464 version,
465 version_type: VersionType::Minor,
466 pre_tag: None,
467 build_tag: None,
468 }),
469 active: false,
470 deactivate_others: false,
471 })
472 }
473
474 pub fn has_llm_tasks(&self) -> bool {
475 !self.llm_judge_tasks.is_empty()
476 }
477
478 pub fn has_assertions(&self) -> bool {
480 !self.assertion_tasks.is_empty()
481 }
482
483 pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
485 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
486 let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
487 let mut in_degree: HashMap<String, usize> = HashMap::new();
488
489 initialize_node_graphs(
490 &self.assertion_tasks,
491 &mut graph,
492 &mut reverse_graph,
493 &mut in_degree,
494 );
495 initialize_node_graphs(
496 &self.llm_judge_tasks,
497 &mut graph,
498 &mut reverse_graph,
499 &mut in_degree,
500 );
501
502 build_dependency_edges(
503 &self.assertion_tasks,
504 &mut graph,
505 &mut reverse_graph,
506 &mut in_degree,
507 );
508 build_dependency_edges(
509 &self.llm_judge_tasks,
510 &mut graph,
511 &mut reverse_graph,
512 &mut in_degree,
513 );
514
515 let mut stages = Vec::new();
516 let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
517 let mut current_level: Vec<String> = in_degree
518 .iter()
519 .filter(|(_, °ree)| degree == 0)
520 .map(|(id, _)| id.clone())
521 .collect();
522
523 let mut stage_idx = 0;
524
525 while !current_level.is_empty() {
526 stages.push(current_level.clone());
527
528 for task_id in ¤t_level {
529 nodes.insert(
530 task_id.clone(),
531 ExecutionNode {
532 id: task_id.clone(),
533 stage: stage_idx,
534 parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
535 children: graph.get(task_id).cloned().unwrap_or_default(),
536 },
537 );
538 }
539
540 let mut next_level = Vec::new();
541 for task_id in ¤t_level {
542 if let Some(dependents) = graph.get(task_id) {
543 for dependent in dependents {
544 if let Some(degree) = in_degree.get_mut(dependent) {
545 *degree -= 1;
546 if *degree == 0 {
547 next_level.push(dependent.clone());
548 }
549 }
550 }
551 }
552 }
553
554 current_level = next_level;
555 stage_idx += 1;
556 }
557
558 let total_tasks = self.assertion_tasks.len() + self.llm_judge_tasks.len();
559 let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();
560
561 if processed_tasks != total_tasks {
562 return Err(ProfileError::CircularDependency);
563 }
564
565 Ok(ExecutionPlan { stages, nodes })
566 }
567
568 pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
569 use owo_colors::OwoColorize;
570
571 let plan = self.get_execution_plan()?;
572
573 println!("\n{}", "Evaluation Execution Plan".bold().green());
574 println!("{}", "═".repeat(70).green());
575
576 let mut conditional_count = 0;
577
578 for (level_idx, level) in plan.stages.iter().enumerate() {
579 let stage_label = format!("Stage {}", level_idx + 1);
580 println!("\n{}", stage_label.bold().cyan());
581
582 for (task_idx, task_id) in level.iter().enumerate() {
583 let is_last = task_idx == level.len() - 1;
584 let prefix = if is_last { "└─" } else { "├─" };
585
586 let task = self.get_task_by_id(task_id).ok_or_else(|| {
587 ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
588 })?;
589
590 let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
591 assertion.condition
592 } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
593 judge.condition
594 } else {
595 false
596 };
597
598 if is_conditional {
599 conditional_count += 1;
600 }
601
602 let (task_type, color_fn): (&str, fn(&str) -> String) =
603 if self.assertion_tasks.iter().any(|t| &t.id == task_id) {
604 ("Assertion", |s: &str| s.yellow().to_string())
605 } else {
606 ("LLM Judge", |s: &str| s.purple().to_string())
607 };
608
609 let conditional_marker = if is_conditional {
610 " [CONDITIONAL]".bright_red().to_string()
611 } else {
612 String::new()
613 };
614
615 println!(
616 "{} {} ({}){}",
617 prefix,
618 task_id.bold(),
619 color_fn(task_type),
620 conditional_marker
621 );
622
623 let deps = task.depends_on();
624 if !deps.is_empty() {
625 let dep_prefix = if is_last { " " } else { "│ " };
626
627 let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
628 deps.iter().partition(|dep_id| {
629 self.get_assertion_by_id(dep_id)
630 .map(|t| t.condition)
631 .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
632 .unwrap_or(false)
633 });
634
635 if !normal_deps.is_empty() {
636 println!(
637 "{} {} {}",
638 dep_prefix,
639 "depends on:".dimmed(),
640 normal_deps
641 .iter()
642 .map(|s| s.as_str())
643 .collect::<Vec<_>>()
644 .join(", ")
645 .dimmed()
646 );
647 }
648
649 if !conditional_deps.is_empty() {
650 println!(
651 "{} {} {}",
652 dep_prefix,
653 "▶ conditional gate:".bright_red().dimmed(),
654 conditional_deps
655 .iter()
656 .map(|d| format!("{} must pass", d))
657 .collect::<Vec<_>>()
658 .join(", ")
659 .red()
660 .dimmed()
661 );
662 }
663 }
664
665 if is_conditional {
666 let continuation = if is_last { " " } else { "│ " };
667 println!(
668 "{} {} {}",
669 continuation,
670 "▶".bright_red(),
671 "creates conditional branch".bright_red().dimmed()
672 );
673 }
674 }
675 }
676
677 println!("\n{}", "═".repeat(70).green());
678 println!(
679 "{}: {} tasks across {} stages",
680 "Summary".bold(),
681 self.assertion_tasks.len() + self.llm_judge_tasks.len(),
682 plan.stages.len()
683 );
684
685 if conditional_count > 0 {
686 println!(
687 "{}: {} conditional tasks that create execution branches",
688 "Branches".bold().bright_red(),
689 conditional_count
690 );
691 }
692
693 println!();
694
695 Ok(())
696 }
697}
698
impl Default for GenAIEvalProfile {
    /// Empty profile with a default config: no tasks, no workflow.
    fn default() -> Self {
        Self {
            config: GenAIEvalConfig::default(),
            assertion_tasks: Vec::new(),
            llm_judge_tasks: Vec::new(),
            scouter_version: scouter_version(),
            workflow: None,
            task_ids: BTreeSet::new(),
        }
    }
}
711
impl GenAIEvalProfile {
    /// Rust-native async constructor: separate `tasks` into assertion and
    /// LLM-judge tasks, then build and validate the backing workflow.
    #[instrument(skip_all)]
    pub async fn new(
        config: GenAIEvalConfig,
        tasks: Vec<EvaluationTask>,
    ) -> Result<Self, ProfileError> {
        let (assertion_tasks, llm_judge_tasks) = separate_tasks(tasks);
        let (workflow, task_ids) = Self::build_profile(&llm_judge_tasks, &assertion_tasks).await?;

        Ok(Self {
            config,
            assertion_tasks,
            llm_judge_tasks,

            scouter_version: scouter_version(),
            workflow,
            task_ids,
        })
    }

    /// Build the optional judge workflow and the combined task-id set.
    ///
    /// Errors: `EmptyTaskList` when no tasks at all were supplied;
    /// `DuplicateTaskIds` when ids collide across the two task lists; any
    /// workflow-construction or validation error otherwise.
    async fn build_profile(
        judge_tasks: &[LLMJudgeTask],
        assertion_tasks: &[AssertionTask],
    ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
        if assertion_tasks.is_empty() && judge_tasks.is_empty() {
            return Err(ProfileError::EmptyTaskList);
        }

        // Only judge tasks require a workflow; a pure-assertion profile has none.
        let workflow = if !judge_tasks.is_empty() {
            let assertion_ids: BTreeSet<String> =
                assertion_tasks.iter().map(|t| t.id.clone()).collect();
            let workflow = Self::build_workflow_from_judges(judge_tasks, &assertion_ids).await?;

            validate_workflow(&workflow)?;
            Some(workflow)
        } else {
            None
        };

        // Every judge prompt must bind at least one parameter.
        for judge in judge_tasks {
            validate_prompt_parameters(&judge.prompt, &judge.id)?;
        }

        let mut task_ids = BTreeSet::new();
        for task in assertion_tasks {
            task_ids.insert(task.id.clone());
        }
        for task in judge_tasks {
            task_ids.insert(task.id.clone());
        }

        // The set collapses duplicates, so a size mismatch means a repeated id.
        let total_tasks = assertion_tasks.len() + judge_tasks.len();

        if task_ids.len() != total_tasks {
            return Err(ProfileError::DuplicateTaskIds);
        }
        Ok((workflow, task_ids))
    }

    /// Assemble a workflow from the judge tasks, creating one agent per
    /// distinct prompt provider. Dependencies on assertion tasks are stripped
    /// from the workflow graph (assertions run outside the workflow).
    pub async fn build_workflow_from_judges(
        judges: &[LLMJudgeTask],
        assertion_ids: &BTreeSet<String>,
    ) -> Result<Workflow, ProfileError> {
        let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
        let mut agents = HashMap::new();

        for judge in judges {
            // Reuse the agent for this provider, or create and register one.
            let agent = match agents.entry(&judge.prompt.provider) {
                Entry::Occupied(entry) => entry.into_mut(),
                Entry::Vacant(entry) => {
                    let agent = Agent::new(judge.prompt.provider.clone(), None).await?;
                    workflow.add_agent(&agent);
                    entry.insert(agent)
                }
            };

            // Keep only dependencies on other judge tasks.
            let workflow_dependencies: Vec<String> = judge
                .depends_on
                .iter()
                .filter(|dep_id| !assertion_ids.contains(*dep_id))
                .cloned()
                .collect();

            let task_deps = if workflow_dependencies.is_empty() {
                None
            } else {
                Some(workflow_dependencies)
            };

            // NOTE(review): `Task::new` failure is unwrapped here — presumably
            // it cannot fail for inputs built from a valid judge task; confirm.
            let task = Task::new(
                &agent.id,
                judge.prompt.clone(),
                &judge.id,
                task_deps,
                judge.max_retries,
            )
            .unwrap();

            workflow.add_task(task)?;
        }

        Ok(workflow)
    }
}
827
828impl ProfileExt for GenAIEvalProfile {
829 #[inline]
830 fn id(&self) -> &str {
831 &self.config.uid
832 }
833
834 fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
835 if let Some(assertion) = self.assertion_tasks.iter().find(|t| t.id() == id) {
836 return Some(assertion);
837 }
838
839 if let Some(judge) = self.llm_judge_tasks.iter().find(|t| t.id() == id) {
840 return Some(judge);
841 }
842
843 None
844 }
845
846 #[inline]
847 fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
849 self.assertion_tasks.iter().find(|t| t.id() == id)
850 }
851
852 #[inline]
853 fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
854 self.llm_judge_tasks.iter().find(|t| t.id() == id)
855 }
856
857 #[inline]
858 fn has_llm_tasks(&self) -> bool {
859 !self.llm_judge_tasks.is_empty()
860 }
861}
862
impl ProfileBaseArgs for GenAIEvalProfile {
    type Config = GenAIEvalConfig;

    fn config(&self) -> &Self::Config {
        &self.config
    }
    /// Flatten the identifying fields needed for profile registration.
    fn get_base_args(&self) -> ProfileArgs {
        ProfileArgs {
            name: self.config.name.clone(),
            space: self.config.space.clone(),
            version: Some(self.config.version.clone()),
            schedule: self.config.alert_config.schedule.clone(),
            scouter_version: self.scouter_version.clone(),
            drift_type: self.config.drift_type.clone(),
        }
    }

    /// Serialize the full profile to a JSON value.
    // NOTE(review): unwrap assumes serialization of this type cannot fail.
    fn to_value(&self) -> Value {
        serde_json::to_value(self).unwrap()
    }
}
884
/// Evaluation results for a single record: per-task results plus the
/// aggregate workflow-level result.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalSet {
    #[pyo3(get)]
    pub records: Vec<GenAIEvalTaskResult>,
    // Aggregate workflow result backing the Python getters below.
    pub inner: GenAIEvalWorkflowResult,
}
892
893impl GenAIEvalSet {
894 pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
895 self.records
898 .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
899
900 self.records
901 .iter()
902 .map(|record| record.to_table_entry(record_id))
903 .collect()
904 }
905
906 pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
907 vec![self.inner.to_table_entry()]
908 }
909
910 pub fn new(records: Vec<GenAIEvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
911 Self { records, inner }
912 }
913
914 pub fn empty() -> Self {
915 Self {
916 records: Vec::new(),
917 inner: GenAIEvalWorkflowResult {
918 created_at: Utc::now(),
919 record_uid: String::new(),
920 entity_id: 0,
921 total_tasks: 0,
922 passed_tasks: 0,
923 failed_tasks: 0,
924 pass_rate: 0.0,
925 duration_ms: 0,
926 entity_uid: String::new(),
927 execution_plan: ExecutionPlan::default(),
928 id: 0,
929 },
930 }
931 }
932}
933
// Python-facing getters delegating to the aggregate workflow result.
#[pymethods]
impl GenAIEvalSet {
    #[getter]
    pub fn created_at(&self) -> DateTime<Utc> {
        self.inner.created_at
    }

    #[getter]
    pub fn record_uid(&self) -> String {
        self.inner.record_uid.clone()
    }

    #[getter]
    pub fn total_tasks(&self) -> i32 {
        self.inner.total_tasks
    }

    #[getter]
    pub fn passed_tasks(&self) -> i32 {
        self.inner.passed_tasks
    }

    #[getter]
    pub fn failed_tasks(&self) -> i32 {
        self.inner.failed_tasks
    }

    #[getter]
    pub fn pass_rate(&self) -> f64 {
        self.inner.pass_rate
    }

    #[getter]
    pub fn duration_ms(&self) -> i64 {
        self.inner.duration_ms
    }

    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
976
/// Collection of per-record evaluation result sets.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalResultSet {
    #[pyo3(get)]
    pub records: Vec<GenAIEvalSet>,
}
983
#[pymethods]
impl GenAIEvalResultSet {
    /// Find the result set for record uid `id`, cloning it out if present.
    pub fn record(&self, id: &str) -> Option<GenAIEvalSet> {
        self.records.iter().find(|r| r.record_uid() == id).cloned()
    }
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
994
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    /// Config construction falls back to MISSING placeholders, and
    /// `update_config_args` applies only the provided fields.
    #[test]
    fn test_genai_drift_config() {
        let mut drift_config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // Update only name and alert_config; space/version/uid stay untouched.
        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    /// A profile built from two judge tasks round-trips through JSON and
    /// records the crate version.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        // Serialized profile must parse back as valid JSON.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}