// scouter_types/genai/profile.rs
use std::collections::hash_map::Entry;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

use core::fmt::Debug;

use chrono::{DateTime, Utc};
use potato_head::prompt_types::Prompt;
use potato_head::Agent;
use potato_head::Workflow;
use potato_head::{create_uuid7, Task};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use scouter_semver::VersionType;
use scouter_state::app_state;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::instrument;

use crate::error::{ProfileError, TypeError};
use crate::genai::alert::GenAIAlertConfig;
use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
use crate::genai::utils::extract_assertion_tasks_from_pylist;
use crate::util::{json_to_pyobject, pyobject_to_json, ConfigExt};
use crate::{
    scouter_version, GenAIEvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry,
};
use crate::{
    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
};
use crate::{ProfileRequest, TaskResultTableEntry};
/// Configuration for a GenAI evaluation drift profile.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    // Fraction of records sampled for evaluation; clamped to [0.0, 1.0] in `new`.
    #[pyo3(get, set)]
    pub sample_ratio: f64,

    // Space (namespace) the profile belongs to.
    #[pyo3(get, set)]
    pub space: String,

    // Profile name.
    #[pyo3(get, set)]
    pub name: String,

    // Semver version string; DEFAULT_VERSION acts as a sentinel (see
    // `GenAIEvalProfile::create_profile_request`).
    #[pyo3(get, set)]
    pub version: String,

    // Unique identifier (uuid7), generated at construction time.
    #[pyo3(get, set)]
    pub uid: String,

    // Alerting/dispatch configuration.
    #[pyo3(get, set)]
    pub alert_config: GenAIAlertConfig,

    // Defaults to `DriftType::GenAI` when absent from serialized data
    // (see `default_drift_type`).
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
60
61impl ConfigExt for GenAIEvalConfig {
62    fn space(&self) -> &str {
63        &self.space
64    }
65
66    fn name(&self) -> &str {
67        &self.name
68    }
69
70    fn version(&self) -> &str {
71        &self.version
72    }
73}
74
// Serde default for `GenAIEvalConfig::drift_type` — presumably so serialized
// configs that predate the field still deserialize (TODO confirm intent).
fn default_drift_type() -> DriftType {
    DriftType::GenAI
}
78
79impl DispatchDriftConfig for GenAIEvalConfig {
80    fn get_drift_args(&self) -> DriftArgs {
81        DriftArgs {
82            name: self.name.clone(),
83            space: self.space.clone(),
84            version: self.version.clone(),
85            dispatch_config: self.alert_config.dispatch_config.clone(),
86        }
87    }
88}
89
90#[pymethods]
91#[allow(clippy::too_many_arguments)]
92impl GenAIEvalConfig {
93    #[new]
94    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
95    pub fn new(
96        space: &str,
97        name: &str,
98        version: &str,
99        sample_ratio: f64,
100        alert_config: GenAIAlertConfig,
101        config_path: Option<PathBuf>,
102    ) -> Result<Self, ProfileError> {
103        if let Some(config_path) = config_path {
104            let config = GenAIEvalConfig::load_from_json_file(config_path)?;
105            return Ok(config);
106        }
107
108        Ok(Self {
109            sample_ratio: sample_ratio.clamp(0.0, 1.0),
110            space: space.to_string(),
111            name: name.to_string(),
112            uid: create_uuid7(),
113            version: version.to_string(),
114            alert_config,
115            drift_type: DriftType::GenAI,
116        })
117    }
118
119    #[staticmethod]
120    pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
121        // deserialize the string to a struct
122
123        let file = std::fs::read_to_string(&path)?;
124
125        Ok(serde_json::from_str(&file)?)
126    }
127
128    pub fn __str__(&self) -> String {
129        // serialize the struct to a string
130        PyHelperFuncs::__str__(self)
131    }
132
133    pub fn model_dump_json(&self) -> String {
134        // serialize the struct to a string
135        PyHelperFuncs::__json__(self)
136    }
137
138    #[allow(clippy::too_many_arguments)]
139    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
140    pub fn update_config_args(
141        &mut self,
142        space: Option<String>,
143        name: Option<String>,
144        version: Option<String>,
145        uid: Option<String>,
146        alert_config: Option<GenAIAlertConfig>,
147    ) -> Result<(), TypeError> {
148        if name.is_some() {
149            self.name = name.ok_or(TypeError::MissingNameError)?;
150        }
151
152        if space.is_some() {
153            self.space = space.ok_or(TypeError::MissingSpaceError)?;
154        }
155
156        if version.is_some() {
157            self.version = version.ok_or(TypeError::MissingVersionError)?;
158        }
159
160        if alert_config.is_some() {
161            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
162        }
163
164        if uid.is_some() {
165            self.uid = uid.ok_or(TypeError::MissingUidError)?;
166        }
167
168        Ok(())
169    }
170}
171
172impl Default for GenAIEvalConfig {
173    fn default() -> Self {
174        Self {
175            sample_ratio: 1.0,
176            space: "default".to_string(),
177            name: "default_genai_profile".to_string(),
178            version: DEFAULT_VERSION.to_string(),
179            uid: create_uuid7(),
180            alert_config: GenAIAlertConfig::default(),
181            drift_type: DriftType::GenAI,
182        }
183    }
184}
185
186/// Validates that a prompt contains at least one required parameter.
187///
188/// LLM evaluation prompts must have either "input" or "output" parameters
189/// to access the data being evaluated.
190///
191/// # Arguments
192/// * `prompt` - The prompt to validate
193/// * `id` - Identifier for error reporting
194///
195/// # Returns
196/// * `Ok(())` if validation passes
197/// * `Err(ProfileError::MissingPromptParametersError)` if no required parameters found
198///
199/// # Errors
200/// Returns an error if the prompt lacks both "input" and "output" parameters.
201fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
202    let has_at_least_one_param = !prompt.parameters.is_empty();
203
204    if !has_at_least_one_param {
205        return Err(ProfileError::NeedAtLeastOneBoundParameterError(
206            id.to_string(),
207        ));
208    }
209
210    Ok(())
211}
212
213fn get_workflow_task<'a>(
214    workflow: &'a Workflow,
215    task_id: &'a str,
216) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
217    workflow
218        .task_list
219        .tasks
220        .get(task_id)
221        .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
222}
223
224/// Helper function to validate first tasks in workflow execution.
225fn validate_first_tasks(
226    workflow: &Workflow,
227    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
228) -> Result<(), ProfileError> {
229    let first_tasks = execution_order
230        .get(&1)
231        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
232
233    for task_id in first_tasks {
234        let task = get_workflow_task(workflow, task_id)?;
235        let task_guard = task
236            .read()
237            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
238
239        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
240    }
241
242    Ok(())
243}
244
/// Validates workflow execution parameters.
///
/// Currently only checks that every task in the first execution stage has at
/// least one bound prompt parameter (see `validate_first_tasks`).
///
/// # Arguments
/// * `workflow` - The workflow to validate
///
/// # Returns
/// * `Ok(())` if validation passes
/// * `Err(ProfileError)` if validation fails
///
/// # Errors
/// Propagates errors from computing the execution plan and from first-task
/// validation.
fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
    let execution_order = workflow.execution_plan()?;

    // Validate first tasks have required parameters
    validate_first_tasks(workflow, &execution_order)?;

    Ok(())
}
268
/// A single task in the computed execution plan.
// NOTE(review): fields are not exposed with #[pyo3(get)] unlike ExecutionPlan;
// confirm whether Python access to these fields is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ExecutionNode {
    // Task id this node represents.
    pub id: String,
    // Zero-based stage index at which the task becomes runnable.
    pub stage: usize,
    // Ids of tasks this one depends on.
    pub parents: Vec<String>,
    // Ids of tasks that depend on this one.
    pub children: Vec<String>,
}
277
/// Staged execution plan for a profile's tasks: `stages[i]` lists the task
/// ids runnable in stage `i`; `nodes` maps each task id to its graph node.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
286
287fn initialize_node_graphs(
288    tasks: &[impl TaskAccessor],
289    graph: &mut HashMap<String, Vec<String>>,
290    reverse_graph: &mut HashMap<String, Vec<String>>,
291    in_degree: &mut HashMap<String, usize>,
292) {
293    for task in tasks {
294        let task_id = task.id().to_string();
295        graph.entry(task_id.clone()).or_default();
296        reverse_graph.entry(task_id.clone()).or_default();
297        in_degree.entry(task_id).or_insert(0);
298    }
299}
300
301fn build_dependency_edges(
302    tasks: &[impl TaskAccessor],
303    graph: &mut HashMap<String, Vec<String>>,
304    reverse_graph: &mut HashMap<String, Vec<String>>,
305    in_degree: &mut HashMap<String, usize>,
306) {
307    for task in tasks {
308        let task_id = task.id().to_string();
309        for dep in task.depends_on() {
310            graph.entry(dep.clone()).or_default().push(task_id.clone());
311            reverse_graph
312                .entry(task_id.clone())
313                .or_default()
314                .push(dep.clone());
315            *in_degree.entry(task_id.clone()).or_insert(0) += 1;
316        }
317    }
318}
319
/// A GenAI evaluation drift profile: configuration plus the assertion and
/// LLM-judge tasks to run, and the optional judge workflow built from them.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    // Non-LLM assertion tasks.
    #[pyo3(get)]
    pub assertion_tasks: Vec<AssertionTask>,

    // LLM-judge tasks; when non-empty they are compiled into `workflow`.
    #[pyo3(get)]
    pub llm_judge_tasks: Vec<LLMJudgeTask>,

    // Version of scouter that produced this profile.
    #[pyo3(get)]
    pub scouter_version: String,

    // Workflow built from the judge tasks; None when there are no judges.
    pub workflow: Option<Workflow>,

    // Unique ids across all tasks (BTreeSet for stable ordering).
    pub task_ids: BTreeSet<String>,
}
339
#[pymethods]
impl GenAIEvalProfile {
    #[new]
    #[pyo3(signature = (config, tasks))]
    /// Create a new GenAIEvalProfile
    /// GenAI evaluations are run asynchronously on the scouter server.
    /// # Arguments
    /// * `config` - GenAIEvalConfig - The configuration for the GenAI drift profile
    /// * `tasks` - PyList - List of AssertionTask, LLMJudgeTask or ConditionalTask
    /// # Returns
    /// * `Result<Self, ProfileError>` - The GenAIEvalProfile
    /// # Errors
    /// * `ProfileError` - if task extraction, workflow construction, or validation fails
    #[instrument(skip_all)]
    pub fn new_py(
        config: GenAIEvalConfig,
        tasks: &Bound<'_, PyList>,
    ) -> Result<Self, ProfileError> {
        let (assertion_tasks, llm_judge_tasks) = extract_assertion_tasks_from_pylist(tasks)?;

        // Workflow construction is async; block on the shared runtime from
        // this synchronous pyo3 constructor.
        let (workflow, task_ids) = app_state()
            .block_on(async { Self::build_profile(&llm_judge_tasks, &assertion_tasks).await })?;

        Ok(Self {
            config,
            assertion_tasks,
            llm_judge_tasks,

            scouter_version: scouter_version(),
            workflow,
            task_ids,
        })
    }

    pub fn __str__(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__str__(self)
    }

    /// Serialize the profile to a JSON string.
    pub fn model_dump_json(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__json__(self)
    }

    /// Serialize the profile into a Python dictionary.
    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
        let json_value = serde_json::to_value(self)?;

        // Create a new Python dictionary
        let dict = PyDict::new(py);

        // Convert JSON to Python dict
        json_to_pyobject(py, &json_value, &dict)?;

        // Return the Python dictionary
        Ok(dict.into())
    }

    // The uid lives on the config; expose it on the profile for convenience.
    #[getter]
    pub fn uid(&self) -> String {
        self.config.uid.clone()
    }

    #[setter]
    pub fn set_uid(&mut self, uid: String) {
        self.config.uid = uid;
    }

    /// Save the profile as JSON, defaulting to the standard profile file
    /// name when `path` is not given.
    #[pyo3(signature = (path=None))]
    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
        Ok(PyHelperFuncs::save_to_json(
            self,
            path,
            FileName::GenAIEvalProfile.to_str(),
        )?)
    }

    /// Rebuild a profile from a Python dictionary.
    // NOTE(review): panics on malformed input (`unwrap`/`expect`); pyo3
    // surfaces the panic to Python — confirm this is the intended contract.
    #[staticmethod]
    pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
        let json_value = pyobject_to_json(data).unwrap();

        let string = serde_json::to_string(&json_value).unwrap();
        serde_json::from_str(&string).expect("Failed to load drift profile")
    }

    /// Rebuild a profile from a JSON string; panics on malformed input.
    #[staticmethod]
    pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
        // deserialize the string to a struct
        serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
    }

    /// Load a profile from a JSON file on disk.
    #[staticmethod]
    pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
        let file = std::fs::read_to_string(&path)?;

        Ok(serde_json::from_str(&file)?)
    }

    /// Update identifying fields on the underlying config in place;
    /// `None` arguments are left untouched.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
    pub fn update_config_args(
        &mut self,
        space: Option<String>,
        name: Option<String>,
        version: Option<String>,
        uid: Option<String>,
        alert_config: Option<GenAIAlertConfig>,
    ) -> Result<(), TypeError> {
        self.config
            .update_config_args(space, name, version, uid, alert_config)
    }

    /// Create a profile request from the profile
    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
        // DEFAULT_VERSION acts as a sentinel meaning "no explicit version".
        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
            None
        } else {
            Some(self.config.version.clone())
        };

        Ok(ProfileRequest {
            space: self.config.space.clone(),
            profile: self.model_dump_json(),
            drift_type: self.config.drift_type.clone(),
            version_request: Some(VersionRequest {
                version,
                version_type: VersionType::Minor,
                pre_tag: None,
                build_tag: None,
            }),
            active: false,
            deactivate_others: false,
        })
    }

    /// True when the profile contains at least one LLM-judge task.
    pub fn has_llm_tasks(&self) -> bool {
        !self.llm_judge_tasks.is_empty()
    }

    /// Check if this profile has assertions
    pub fn has_assertions(&self) -> bool {
        !self.assertion_tasks.is_empty()
    }

    /// Get execution order for all tasks (assertions + LLM judges)
    ///
    /// Runs a level-by-level Kahn topological sort over the combined
    /// dependency graph; each level of in-degree-zero tasks becomes one stage.
    ///
    /// # Errors
    /// Returns `ProfileError::CircularDependency` when not every task can be
    /// scheduled (the dependency graph contains a cycle).
    pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
        // Forward edges (task -> dependents), reverse edges (task -> parents),
        // and remaining unsatisfied-dependency counts per task.
        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
        let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
        let mut in_degree: HashMap<String, usize> = HashMap::new();

        initialize_node_graphs(
            &self.assertion_tasks,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        initialize_node_graphs(
            &self.llm_judge_tasks,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        build_dependency_edges(
            &self.assertion_tasks,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        build_dependency_edges(
            &self.llm_judge_tasks,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        let mut stages = Vec::new();
        let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
        // The first stage contains every task with no dependencies.
        let mut current_level: Vec<String> = in_degree
            .iter()
            .filter(|(_, &degree)| degree == 0)
            .map(|(id, _)| id.clone())
            .collect();

        let mut stage_idx = 0;

        while !current_level.is_empty() {
            stages.push(current_level.clone());

            for task_id in &current_level {
                nodes.insert(
                    task_id.clone(),
                    ExecutionNode {
                        id: task_id.clone(),
                        stage: stage_idx,
                        parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
                        children: graph.get(task_id).cloned().unwrap_or_default(),
                    },
                );
            }

            // A dependent becomes runnable once all of its dependencies have
            // been scheduled (its in-degree drops to zero).
            let mut next_level = Vec::new();
            for task_id in &current_level {
                if let Some(dependents) = graph.get(task_id) {
                    for dependent in dependents {
                        if let Some(degree) = in_degree.get_mut(dependent) {
                            *degree -= 1;
                            if *degree == 0 {
                                next_level.push(dependent.clone());
                            }
                        }
                    }
                }
            }

            current_level = next_level;
            stage_idx += 1;
        }

        // If any task was never scheduled, the graph contains a cycle
        // (or a dependency on an id that is not a task).
        let total_tasks = self.assertion_tasks.len() + self.llm_judge_tasks.len();
        let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();

        if processed_tasks != total_tasks {
            return Err(ProfileError::CircularDependency);
        }

        Ok(ExecutionPlan { stages, nodes })
    }

    /// Pretty-print the execution plan to stdout, stage by stage, marking
    /// conditional tasks and conditional dependency gates.
    pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
        use owo_colors::OwoColorize;

        let plan = self.get_execution_plan()?;

        println!("\n{}", "Evaluation Execution Plan".bold().green());
        println!("{}", "═".repeat(70).green());

        let mut conditional_count = 0;

        for (level_idx, level) in plan.stages.iter().enumerate() {
            let stage_label = format!("Stage {}", level_idx + 1);
            println!("\n{}", stage_label.bold().cyan());

            for (task_idx, task_id) in level.iter().enumerate() {
                // Tree-drawing prefix: the last entry of a stage closes the branch.
                let is_last = task_idx == level.len() - 1;
                let prefix = if is_last { "└─" } else { "├─" };

                let task = self.get_task_by_id(task_id).ok_or_else(|| {
                    ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
                })?;

                // A task is conditional when its `condition` flag is set on
                // either the assertion or the judge definition.
                let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
                    assertion.condition
                } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
                    judge.condition
                } else {
                    false
                };

                if is_conditional {
                    conditional_count += 1;
                }

                let (task_type, color_fn): (&str, fn(&str) -> String) =
                    if self.assertion_tasks.iter().any(|t| &t.id == task_id) {
                        ("Assertion", |s: &str| s.yellow().to_string())
                    } else {
                        ("LLM Judge", |s: &str| s.purple().to_string())
                    };

                let conditional_marker = if is_conditional {
                    " [CONDITIONAL]".bright_red().to_string()
                } else {
                    String::new()
                };

                println!(
                    "{} {} ({}){}",
                    prefix,
                    task_id.bold(),
                    color_fn(task_type),
                    conditional_marker
                );

                let deps = task.depends_on();
                if !deps.is_empty() {
                    let dep_prefix = if is_last { "  " } else { "│ " };

                    // Split dependencies into conditional gates and plain deps.
                    let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
                        deps.iter().partition(|dep_id| {
                            self.get_assertion_by_id(dep_id)
                                .map(|t| t.condition)
                                .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
                                .unwrap_or(false)
                        });

                    if !normal_deps.is_empty() {
                        println!(
                            "{}   {} {}",
                            dep_prefix,
                            "depends on:".dimmed(),
                            normal_deps
                                .iter()
                                .map(|s| s.as_str())
                                .collect::<Vec<_>>()
                                .join(", ")
                                .dimmed()
                        );
                    }

                    if !conditional_deps.is_empty() {
                        println!(
                            "{}   {} {}",
                            dep_prefix,
                            "▶ conditional gate:".bright_red().dimmed(),
                            conditional_deps
                                .iter()
                                .map(|d| format!("{} must pass", d))
                                .collect::<Vec<_>>()
                                .join(", ")
                                .red()
                                .dimmed()
                        );
                    }
                }

                if is_conditional {
                    let continuation = if is_last { "  " } else { "│ " };
                    println!(
                        "{}   {} {}",
                        continuation,
                        "▶".bright_red(),
                        "creates conditional branch".bright_red().dimmed()
                    );
                }
            }
        }

        println!("\n{}", "═".repeat(70).green());
        println!(
            "{}: {} tasks across {} stages",
            "Summary".bold(),
            self.assertion_tasks.len() + self.llm_judge_tasks.len(),
            plan.stages.len()
        );

        if conditional_count > 0 {
            println!(
                "{}: {} conditional tasks that create execution branches",
                "Branches".bold().bright_red(),
                conditional_count
            );
        }

        println!();

        Ok(())
    }
}
698
699impl Default for GenAIEvalProfile {
700    fn default() -> Self {
701        Self {
702            config: GenAIEvalConfig::default(),
703            assertion_tasks: Vec::new(),
704            llm_judge_tasks: Vec::new(),
705            scouter_version: scouter_version(),
706            workflow: None,
707            task_ids: BTreeSet::new(),
708        }
709    }
710}
711
impl GenAIEvalProfile {
    /// Rust-side constructor: separate the mixed task list and build the
    /// profile (and judge workflow, if any) asynchronously.
    #[instrument(skip_all)]
    pub async fn new(
        config: GenAIEvalConfig,
        tasks: Vec<EvaluationTask>,
    ) -> Result<Self, ProfileError> {
        let (assertion_tasks, llm_judge_tasks) = separate_tasks(tasks);
        let (workflow, task_ids) = Self::build_profile(&llm_judge_tasks, &assertion_tasks).await?;

        Ok(Self {
            config,
            assertion_tasks,
            llm_judge_tasks,

            scouter_version: scouter_version(),
            workflow,
            task_ids,
        })
    }

    /// Validate the tasks and assemble the optional judge workflow.
    ///
    /// Returns the workflow (`None` when there are no judge tasks) plus the
    /// set of all task ids.
    ///
    /// # Errors
    /// * `ProfileError::EmptyTaskList` when no tasks were supplied
    /// * `ProfileError::DuplicateTaskIds` when two tasks share an id
    /// * prompt/workflow validation errors
    async fn build_profile(
        judge_tasks: &[LLMJudgeTask],
        assertion_tasks: &[AssertionTask],
    ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
        if assertion_tasks.is_empty() && judge_tasks.is_empty() {
            return Err(ProfileError::EmptyTaskList);
        }

        let workflow = if !judge_tasks.is_empty() {
            let assertion_ids: BTreeSet<String> =
                assertion_tasks.iter().map(|t| t.id.clone()).collect();
            let workflow = Self::build_workflow_from_judges(judge_tasks, &assertion_ids).await?;

            // Validate the workflow
            validate_workflow(&workflow)?;
            Some(workflow)
        } else {
            None
        };

        // Validate LLM judge prompts individually
        for judge in judge_tasks {
            validate_prompt_parameters(&judge.prompt, &judge.id)?;
        }

        // Collect all task IDs from assertion and LLM judge tasks
        let mut task_ids = BTreeSet::new();
        for task in assertion_tasks {
            task_ids.insert(task.id.clone());
        }
        for task in judge_tasks {
            task_ids.insert(task.id.clone());
        }

        // check for duplicate task IDs across all task types
        // (the set collapses duplicates, so a size mismatch reveals them)
        let total_tasks = assertion_tasks.len() + judge_tasks.len();

        if task_ids.len() != total_tasks {
            return Err(ProfileError::DuplicateTaskIds);
        }
        Ok((workflow, task_ids))
    }

    /// Build workflow from LLM judge tasks
    /// # Arguments
    /// * `judges` - Slice of LLMJudgeTask
    /// * `assertion_ids` - Ids of assertion tasks; judge dependencies on
    ///   these are filtered out since assertions are not workflow tasks
    /// # Returns
    /// * `Result<Workflow, ProfileError>` - The constructed workflow
    pub async fn build_workflow_from_judges(
        judges: &[LLMJudgeTask],
        assertion_ids: &BTreeSet<String>,
    ) -> Result<Workflow, ProfileError> {
        let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
        // One agent per provider, shared by all judges using that provider.
        let mut agents = HashMap::new();

        for judge in judges {
            let agent = match agents.entry(&judge.prompt.provider) {
                Entry::Occupied(entry) => entry.into_mut(),
                Entry::Vacant(entry) => {
                    let agent = Agent::new(judge.prompt.provider.clone(), None).await?;
                    workflow.add_agent(&agent);
                    entry.insert(agent)
                }
            };

            // filter out assertion IDs from dependencies
            // All dependencies are checked when building a workflow, and non-judge tasks are not part of the workflow
            let workflow_dependencies: Vec<String> = judge
                .depends_on
                .iter()
                .filter(|dep_id| !assertion_ids.contains(*dep_id))
                .cloned()
                .collect();

            let task_deps = if workflow_dependencies.is_empty() {
                None
            } else {
                Some(workflow_dependencies)
            };

            // NOTE(review): `unwrap` panics if Task::new fails; consider
            // propagating the error if its type converts into ProfileError.
            let task = Task::new(
                &agent.id,
                judge.prompt.clone(),
                &judge.id,
                task_deps,
                judge.max_retries,
            )
            .unwrap();

            workflow.add_task(task)?;
        }

        Ok(workflow)
    }
}
827
828impl ProfileExt for GenAIEvalProfile {
829    #[inline]
830    fn id(&self) -> &str {
831        &self.config.uid
832    }
833
834    fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
835        if let Some(assertion) = self.assertion_tasks.iter().find(|t| t.id() == id) {
836            return Some(assertion);
837        }
838
839        if let Some(judge) = self.llm_judge_tasks.iter().find(|t| t.id() == id) {
840            return Some(judge);
841        }
842
843        None
844    }
845
846    #[inline]
847    /// Get assertion task by ID, first checking AssertionTasks, then ConditionalTasks
848    fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
849        self.assertion_tasks.iter().find(|t| t.id() == id)
850    }
851
852    #[inline]
853    fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
854        self.llm_judge_tasks.iter().find(|t| t.id() == id)
855    }
856
857    #[inline]
858    fn has_llm_tasks(&self) -> bool {
859        !self.llm_judge_tasks.is_empty()
860    }
861}
862
863impl ProfileBaseArgs for GenAIEvalProfile {
864    type Config = GenAIEvalConfig;
865
866    fn config(&self) -> &Self::Config {
867        &self.config
868    }
869    fn get_base_args(&self) -> ProfileArgs {
870        ProfileArgs {
871            name: self.config.name.clone(),
872            space: self.config.space.clone(),
873            version: Some(self.config.version.clone()),
874            schedule: self.config.alert_config.schedule.clone(),
875            scouter_version: self.scouter_version.clone(),
876            drift_type: self.config.drift_type.clone(),
877        }
878    }
879
880    fn to_value(&self) -> Value {
881        serde_json::to_value(self).unwrap()
882    }
883}
884
/// Results for one evaluated record: per-task results plus the workflow summary.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalSet {
    // Individual task results for this record.
    #[pyo3(get)]
    pub records: Vec<GenAIEvalTaskResult>,
    // Workflow-level aggregate (pass counts, duration, execution plan).
    pub inner: GenAIEvalWorkflowResult,
}
892
893impl GenAIEvalSet {
894    pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
895        // sort records by stage, then by task_id
896
897        self.records
898            .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
899
900        self.records
901            .iter()
902            .map(|record| record.to_table_entry(record_id))
903            .collect()
904    }
905
906    pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
907        vec![self.inner.to_table_entry()]
908    }
909
910    pub fn new(records: Vec<GenAIEvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
911        Self { records, inner }
912    }
913
914    pub fn empty() -> Self {
915        Self {
916            records: Vec::new(),
917            inner: GenAIEvalWorkflowResult {
918                created_at: Utc::now(),
919                record_uid: String::new(),
920                entity_id: 0,
921                total_tasks: 0,
922                passed_tasks: 0,
923                failed_tasks: 0,
924                pass_rate: 0.0,
925                duration_ms: 0,
926                entity_uid: String::new(),
927                execution_plan: ExecutionPlan::default(),
928                id: 0,
929            },
930        }
931    }
932}
933
934#[pymethods]
935impl GenAIEvalSet {
936    #[getter]
937    pub fn created_at(&self) -> DateTime<Utc> {
938        self.inner.created_at
939    }
940
941    #[getter]
942    pub fn record_uid(&self) -> String {
943        self.inner.record_uid.clone()
944    }
945
946    #[getter]
947    pub fn total_tasks(&self) -> i32 {
948        self.inner.total_tasks
949    }
950
951    #[getter]
952    pub fn passed_tasks(&self) -> i32 {
953        self.inner.passed_tasks
954    }
955
956    #[getter]
957    pub fn failed_tasks(&self) -> i32 {
958        self.inner.failed_tasks
959    }
960
961    #[getter]
962    pub fn pass_rate(&self) -> f64 {
963        self.inner.pass_rate
964    }
965
966    #[getter]
967    pub fn duration_ms(&self) -> i64 {
968        self.inner.duration_ms
969    }
970
971    pub fn __str__(&self) -> String {
972        // serialize the struct to a string
973        PyHelperFuncs::__str__(self)
974    }
975}
976
/// Collection of [`GenAIEvalSet`]s, one per evaluated record.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalResultSet {
    // One eval set per record; looked up by record uid via `record()`.
    #[pyo3(get)]
    pub records: Vec<GenAIEvalSet>,
}
983
984#[pymethods]
985impl GenAIEvalResultSet {
986    pub fn record(&self, id: &str) -> Option<GenAIEvalSet> {
987        self.records.iter().find(|r| r.record_uid() == id).cloned()
988    }
989    pub fn __str__(&self) -> String {
990        // serialize the struct to a string
991        PyHelperFuncs::__str__(self)
992    }
993}
994
// Tests require the `mock` feature, which provides `create_score_prompt`.
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    // Verifies defaulting of missing space/name and in-place update of
    // name + alert config via `update_config_args`.
    #[test]
    fn test_genai_drift_config() {
        let mut drift_config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();
        // MISSING space/name should be stored as the "__missing__" sentinel.
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // Only name and alert_config are supplied; other args stay unchanged.
        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    // Builds a profile from two LLM-judge tasks and checks it serializes to
    // valid JSON with the expected task count and scouter version.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        // model_dump_json must round-trip as valid JSON.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}