Skip to main content

scouter_types/genai/
profile.rs

1use crate::error::{ProfileError, TypeError};
2use crate::genai::alert::GenAIAlertConfig;
3use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
4use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
5use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
6use crate::genai::{AgentAssertionTask, TraceAssertionTask};
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json};
9use crate::{scouter_version, EvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry};
10use crate::{
11    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
12    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
13};
14use crate::{ProfileRequest, TaskResultTableEntry};
15use chrono::{DateTime, Utc};
16use core::fmt::Debug;
17use potato_head::prompt_types::Prompt;
18use potato_head::Agent;
19use potato_head::Workflow;
20use potato_head::{create_uuid7, Task};
21use pyo3::prelude::*;
22use pyo3::types::{PyDict, PyList};
23use scouter_semver::VersionType;
24use scouter_state::app_state;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::hash_map::Entry;
28use std::collections::BTreeSet;
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32
33use tracing::instrument;
34
/// Serde default for `GenAIEvalConfig::sample_ratio`: sample everything.
fn default_sample_ratio() -> f64 {
    1.0_f64
}
38
39fn default_space() -> String {
40    MISSING.to_string()
41}
42
43fn default_name() -> String {
44    MISSING.to_string()
45}
46
47fn default_version() -> String {
48    DEFAULT_VERSION.to_string()
49}
50
/// Serde default for `GenAIEvalConfig::uid`: a newly generated id from `create_uuid7()`.
fn default_uid() -> String {
    create_uuid7()
}
54
/// Serde default for `GenAIEvalConfig::drift_type`: always `DriftType::GenAI`.
fn default_drift_type() -> DriftType {
    DriftType::GenAI
}
58
59fn default_alert_config() -> GenAIAlertConfig {
60    GenAIAlertConfig::default()
61}
62
/// Configuration for a GenAI evaluation profile: identity, sampling and alerting settings.
///
/// Every field carries a serde default, so a JSON document missing any of them still
/// deserializes.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    /// Sampling ratio. The Python constructor clamps it to `[0.0, 1.0]`.
    /// Serde default: `1.0`.
    #[pyo3(get, set)]
    #[serde(default = "default_sample_ratio")]
    pub sample_ratio: f64,

    /// Space the profile belongs to. Serde default: `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_space")]
    pub space: String,

    /// Profile name. Serde default: `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_name")]
    pub name: String,

    /// Profile version. Serde default: `DEFAULT_VERSION`.
    #[pyo3(get, set)]
    #[serde(default = "default_version")]
    pub version: String,

    /// Unique id of the profile. Serde default: a new id from `create_uuid7()`.
    #[pyo3(get, set)]
    #[serde(default = "default_uid")]
    pub uid: String,

    /// Alerting configuration; its `dispatch_config` is forwarded by `get_drift_args`.
    #[pyo3(get, set)]
    #[serde(default = "default_alert_config")]
    pub alert_config: GenAIAlertConfig,

    /// Drift type tag. Serde default and constructors use `DriftType::GenAI`.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
94
95impl ConfigExt for GenAIEvalConfig {
96    fn space(&self) -> &str {
97        &self.space
98    }
99
100    fn name(&self) -> &str {
101        &self.name
102    }
103
104    fn version(&self) -> &str {
105        &self.version
106    }
107    fn uid(&self) -> &str {
108        &self.uid
109    }
110}
111
112impl DispatchDriftConfig for GenAIEvalConfig {
113    fn get_drift_args(&self) -> DriftArgs {
114        DriftArgs {
115            name: self.name.clone(),
116            space: self.space.clone(),
117            version: self.version.clone(),
118            dispatch_config: self.alert_config.dispatch_config.clone(),
119        }
120    }
121}
122
123#[pymethods]
124#[allow(clippy::too_many_arguments)]
125impl GenAIEvalConfig {
126    #[new]
127    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
128    pub fn new(
129        space: &str,
130        name: &str,
131        version: &str,
132        sample_ratio: f64,
133        alert_config: GenAIAlertConfig,
134        config_path: Option<PathBuf>,
135    ) -> Result<Self, ProfileError> {
136        if let Some(config_path) = config_path {
137            let config = GenAIEvalConfig::load_from_json_file(config_path)?;
138            return Ok(config);
139        }
140
141        Ok(Self {
142            sample_ratio: sample_ratio.clamp(0.0, 1.0),
143            space: space.to_string(),
144            name: name.to_string(),
145            uid: create_uuid7(),
146            version: version.to_string(),
147            alert_config,
148            drift_type: DriftType::GenAI,
149        })
150    }
151
152    #[staticmethod]
153    pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
154        // deserialize the string to a struct
155
156        let file = std::fs::read_to_string(&path)?;
157
158        Ok(serde_json::from_str(&file)?)
159    }
160
161    pub fn __str__(&self) -> String {
162        // serialize the struct to a string
163        PyHelperFuncs::__str__(self)
164    }
165
166    pub fn model_dump_json(&self) -> String {
167        // serialize the struct to a string
168        PyHelperFuncs::__json__(self)
169    }
170
171    #[allow(clippy::too_many_arguments)]
172    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
173    pub fn update_config_args(
174        &mut self,
175        space: Option<String>,
176        name: Option<String>,
177        version: Option<String>,
178        uid: Option<String>,
179        alert_config: Option<GenAIAlertConfig>,
180    ) -> Result<(), TypeError> {
181        if name.is_some() {
182            self.name = name.ok_or(TypeError::MissingNameError)?;
183        }
184
185        if space.is_some() {
186            self.space = space.ok_or(TypeError::MissingSpaceError)?;
187        }
188
189        if version.is_some() {
190            self.version = version.ok_or(TypeError::MissingVersionError)?;
191        }
192
193        if alert_config.is_some() {
194            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
195        }
196
197        if uid.is_some() {
198            self.uid = uid.ok_or(TypeError::MissingUidError)?;
199        }
200
201        Ok(())
202    }
203}
204
205impl Default for GenAIEvalConfig {
206    fn default() -> Self {
207        Self {
208            sample_ratio: 1.0,
209            space: "default".to_string(),
210            name: "default_genai_profile".to_string(),
211            version: DEFAULT_VERSION.to_string(),
212            uid: create_uuid7(),
213            alert_config: GenAIAlertConfig::default(),
214            drift_type: DriftType::GenAI,
215        }
216    }
217}
218
219/// Validates that a prompt contains at least one required parameter.
220///
221/// LLM evaluation prompts must have either "input" or "output" parameters
222/// to access the data being evaluated.
223///
224/// # Arguments
225/// * `prompt` - The prompt to validate
226/// * `id` - Identifier for error reporting
227///
228/// # Returns
229/// * `Ok(())` if validation passes
230/// * `Err(ProfileError::MissingPromptParametersError)` if no required parameters found
231///
232/// # Errors
233/// Returns an error if the prompt lacks both "input" and "output" parameters.
234fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
235    let has_at_least_one_param = !prompt.parameters.is_empty();
236
237    if !has_at_least_one_param {
238        return Err(ProfileError::NeedAtLeastOneBoundParameterError(
239            id.to_string(),
240        ));
241    }
242
243    Ok(())
244}
245
246fn get_workflow_task<'a>(
247    workflow: &'a Workflow,
248    task_id: &'a str,
249) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
250    workflow
251        .task_list
252        .tasks
253        .get(task_id)
254        .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
255}
256
257/// Helper function to validate first tasks in workflow execution.
258fn validate_first_tasks(
259    workflow: &Workflow,
260    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
261) -> Result<(), ProfileError> {
262    let first_tasks = execution_order
263        .get(&1)
264        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
265
266    for task_id in first_tasks {
267        let task = get_workflow_task(workflow, task_id)?;
268        let task_guard = task
269            .read()
270            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
271
272        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
273    }
274
275    Ok(())
276}
277
278/// Validates workflow execution parameters and response types.
279///
280/// Ensures that:
281/// - First tasks have required prompt parameters
282/// - Last tasks have a response type e
283///
284/// # Arguments
285/// * `workflow` - The workflow to validate
286///
287/// # Returns
288/// * `Ok(())` if validation passes
289/// * `Err(ProfileError)` if validation fails
290///
291/// # Errors
292/// Returns various ProfileError types based on validation failures.
293fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
294    let execution_order = workflow.execution_plan()?;
295
296    // Validate first tasks have required parameters
297    validate_first_tasks(workflow, &execution_order)?;
298
299    Ok(())
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
303#[pyclass]
304pub struct ExecutionNode {
305    pub id: String,
306    pub stage: usize,
307    pub parents: Vec<String>,
308    pub children: Vec<String>,
309}
310
/// Stage-layered execution plan for a profile's evaluation tasks.
///
/// As produced by `GenAIEvalProfile::get_execution_plan`, every task in a stage depends
/// only on tasks from earlier stages.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    /// Task ids grouped by stage, in execution order.
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    /// Per-task node details (stage, parents, children), keyed by task id.
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
319
320fn initialize_node_graphs(
321    tasks: &[impl TaskAccessor],
322    graph: &mut HashMap<String, Vec<String>>,
323    reverse_graph: &mut HashMap<String, Vec<String>>,
324    in_degree: &mut HashMap<String, usize>,
325) {
326    for task in tasks {
327        let task_id = task.id().to_string();
328        graph.entry(task_id.clone()).or_default();
329        reverse_graph.entry(task_id.clone()).or_default();
330        in_degree.entry(task_id).or_insert(0);
331    }
332}
333
334fn build_dependency_edges(
335    tasks: &[impl TaskAccessor],
336    graph: &mut HashMap<String, Vec<String>>,
337    reverse_graph: &mut HashMap<String, Vec<String>>,
338    in_degree: &mut HashMap<String, usize>,
339) {
340    for task in tasks {
341        let task_id = task.id().to_string();
342        for dep in task.depends_on() {
343            graph.entry(dep.clone()).or_default().push(task_id.clone());
344            reverse_graph
345                .entry(task_id.clone())
346                .or_default()
347                .push(dep.clone());
348            *in_degree.entry(task_id.clone()).or_insert(0) += 1;
349        }
350    }
351}
352
/// A GenAI evaluation profile: configuration plus the evaluation tasks to run.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    /// Identity, sampling and alerting configuration.
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    /// Evaluation tasks, split into assertion, LLM judge, trace and agent buckets.
    pub tasks: AssertionTasks,

    /// Scouter version; constructors set it from `scouter_version()`.
    #[pyo3(get)]
    pub scouter_version: String,

    /// Workflow built from the LLM judge tasks; constructors leave it `None` when there
    /// are no judge tasks.
    pub workflow: Option<Workflow>,

    /// Ids of the profile's tasks, as collected by `AssertionTasks::collect_all_task_ids`.
    pub task_ids: BTreeSet<String>,

    /// Optional alias for the profile.
    pub alias: Option<String>,
}
370
371#[pymethods]
372impl GenAIEvalProfile {
373    #[new]
374    #[pyo3(signature = (tasks, config=None, alias=None))]
375    /// Create a new GenAIEvalProfile
376    /// GenAI evaluations are run asynchronously on the scouter server.
377    /// # Arguments
378    /// * `config` - GenAIEvalConfig - The configuration for the GenAI drift profile
379    /// * `tasks` - PyList - List of AssertionTask, LLMJudgeTask or ConditionalTask
380    /// * `alias` - Option<String> - Optional alias for the profile
381    /// # Returns
382    /// * `Result<Self, ProfileError>` - The GenAIEvalProfile
383    /// # Errors
384    /// * `ProfileError::MissingWorkflowError` - If the workflow is
385    #[instrument(skip_all)]
386    pub fn new_py(
387        tasks: &Bound<'_, PyList>,
388        config: Option<GenAIEvalConfig>,
389        alias: Option<String>,
390    ) -> Result<Self, ProfileError> {
391        let tasks = extract_assertion_tasks_from_pylist(tasks)?;
392
393        let (workflow, task_ids) =
394            app_state().block_on(async { Self::build_profile(&tasks).await })?;
395
396        Ok(Self {
397            config: config.unwrap_or_default(),
398            tasks,
399            scouter_version: scouter_version(),
400            workflow,
401            task_ids,
402            alias,
403        })
404    }
405
406    pub fn __str__(&self) -> String {
407        // serialize the struct to a string
408        PyHelperFuncs::__str__(self)
409    }
410
411    pub fn model_dump_json(&self) -> String {
412        // serialize the struct to a string
413        PyHelperFuncs::__json__(self)
414    }
415
416    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
417        let json_value = serde_json::to_value(self)?;
418
419        // Create a new Python dictionary
420        let dict = PyDict::new(py);
421
422        // Convert JSON to Python dict
423        json_to_pyobject(py, &json_value, &dict)?;
424
425        // Return the Python dictionary
426        Ok(dict.into())
427    }
428
429    #[getter]
430    pub fn drift_type(&self) -> DriftType {
431        self.config.drift_type.clone()
432    }
433
434    #[getter]
435    pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
436        self.tasks.assertion.clone()
437    }
438
439    #[getter]
440    pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
441        self.tasks.judge.clone()
442    }
443
444    #[getter]
445    pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
446        self.tasks.trace.clone()
447    }
448
449    #[getter]
450    pub fn agent_assertion_tasks(&self) -> Vec<AgentAssertionTask> {
451        self.tasks.agent.clone()
452    }
453
454    #[getter]
455    pub fn alias(&self) -> Option<String> {
456        self.alias.clone()
457    }
458
459    #[getter]
460    pub fn uid(&self) -> String {
461        self.config.uid.clone()
462    }
463
464    #[setter]
465    pub fn set_uid(&mut self, uid: String) {
466        self.config.uid = uid;
467    }
468
469    #[pyo3(signature = (path=None))]
470    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
471        Ok(PyHelperFuncs::save_to_json(
472            self,
473            path,
474            FileName::GenAIEvalProfile.to_str(),
475        )?)
476    }
477
478    #[staticmethod]
479    pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
480        let json_value = pyobject_to_json(data).unwrap();
481
482        let string = serde_json::to_string(&json_value).unwrap();
483        serde_json::from_str(&string).expect("Failed to load drift profile")
484    }
485
486    #[staticmethod]
487    pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
488        // deserialize the string to a struct
489        serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
490    }
491
492    #[staticmethod]
493    pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
494        let file = std::fs::read_to_string(&path)?;
495
496        Ok(serde_json::from_str(&file)?)
497    }
498
499    #[allow(clippy::too_many_arguments)]
500    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
501    pub fn update_config_args(
502        &mut self,
503        space: Option<String>,
504        name: Option<String>,
505        version: Option<String>,
506        uid: Option<String>,
507        alert_config: Option<GenAIAlertConfig>,
508    ) -> Result<(), TypeError> {
509        self.config
510            .update_config_args(space, name, version, uid, alert_config)
511    }
512
513    /// Create a profile request from the profile
514    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
515        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
516            None
517        } else {
518            Some(self.config.version.clone())
519        };
520
521        Ok(ProfileRequest {
522            space: self.config.space.clone(),
523            profile: self.model_dump_json(),
524            drift_type: self.config.drift_type.clone(),
525            version_request: Some(VersionRequest {
526                version,
527                version_type: VersionType::Minor,
528                pre_tag: None,
529                build_tag: None,
530            }),
531            active: false,
532            deactivate_others: false,
533        })
534    }
535
536    pub fn has_llm_tasks(&self) -> bool {
537        !self.tasks.judge.is_empty()
538    }
539
540    /// Check if this profile has assertions
541    pub fn has_assertions(&self) -> bool {
542        !self.tasks.assertion.is_empty()
543    }
544
545    pub fn has_trace_assertions(&self) -> bool {
546        !self.tasks.trace.is_empty()
547    }
548
549    pub fn has_agent_assertions(&self) -> bool {
550        !self.tasks.agent.is_empty()
551    }
552
553    /// Get execution order for all tasks (assertions + LLM judges + trace assertions + request assertions)
554    pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
555        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
556        let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
557        let mut in_degree: HashMap<String, usize> = HashMap::new();
558
559        initialize_node_graphs(
560            &self.tasks.assertion,
561            &mut graph,
562            &mut reverse_graph,
563            &mut in_degree,
564        );
565        initialize_node_graphs(
566            &self.tasks.judge,
567            &mut graph,
568            &mut reverse_graph,
569            &mut in_degree,
570        );
571
572        initialize_node_graphs(
573            &self.tasks.trace,
574            &mut graph,
575            &mut reverse_graph,
576            &mut in_degree,
577        );
578
579        initialize_node_graphs(
580            &self.tasks.agent,
581            &mut graph,
582            &mut reverse_graph,
583            &mut in_degree,
584        );
585
586        build_dependency_edges(
587            &self.tasks.assertion,
588            &mut graph,
589            &mut reverse_graph,
590            &mut in_degree,
591        );
592        build_dependency_edges(
593            &self.tasks.judge,
594            &mut graph,
595            &mut reverse_graph,
596            &mut in_degree,
597        );
598
599        build_dependency_edges(
600            &self.tasks.trace,
601            &mut graph,
602            &mut reverse_graph,
603            &mut in_degree,
604        );
605
606        build_dependency_edges(
607            &self.tasks.agent,
608            &mut graph,
609            &mut reverse_graph,
610            &mut in_degree,
611        );
612        let mut stages = Vec::new();
613        let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
614        let mut current_level: Vec<String> = in_degree
615            .iter()
616            .filter(|(_, &degree)| degree == 0)
617            .map(|(id, _)| id.clone())
618            .collect();
619
620        let mut stage_idx = 0;
621
622        while !current_level.is_empty() {
623            stages.push(current_level.clone());
624
625            for task_id in &current_level {
626                nodes.insert(
627                    task_id.clone(),
628                    ExecutionNode {
629                        id: task_id.clone(),
630                        stage: stage_idx,
631                        parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
632                        children: graph.get(task_id).cloned().unwrap_or_default(),
633                    },
634                );
635            }
636
637            let mut next_level = Vec::new();
638            for task_id in &current_level {
639                if let Some(dependents) = graph.get(task_id) {
640                    for dependent in dependents {
641                        if let Some(degree) = in_degree.get_mut(dependent) {
642                            *degree -= 1;
643                            if *degree == 0 {
644                                next_level.push(dependent.clone());
645                            }
646                        }
647                    }
648                }
649            }
650
651            current_level = next_level;
652            stage_idx += 1;
653        }
654
655        let total_tasks = self.tasks.assertion.len()
656            + self.tasks.judge.len()
657            + self.tasks.trace.len()
658            + self.tasks.agent.len();
659        let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();
660
661        if processed_tasks != total_tasks {
662            return Err(ProfileError::CircularDependency);
663        }
664
665        Ok(ExecutionPlan { stages, nodes })
666    }
667
668    pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
669        use owo_colors::OwoColorize;
670
671        let plan = self.get_execution_plan()?;
672
673        println!("\n{}", "Evaluation Execution Plan".bold().green());
674        println!("{}", "═".repeat(70).green());
675
676        let mut conditional_count = 0;
677
678        for (level_idx, level) in plan.stages.iter().enumerate() {
679            let stage_label = format!("Stage {}", level_idx + 1);
680            println!("\n{}", stage_label.bold().cyan());
681
682            for (task_idx, task_id) in level.iter().enumerate() {
683                let is_last = task_idx == level.len() - 1;
684                let prefix = if is_last { "└─" } else { "├─" };
685
686                let task = self.get_task_by_id(task_id).ok_or_else(|| {
687                    ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
688                })?;
689
690                let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
691                    assertion.condition
692                } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
693                    judge.condition
694                } else if let Some(trace) = self.get_trace_assertion_by_id(task_id) {
695                    trace.condition
696                } else if let Some(request) = self.get_agent_assertion_by_id(task_id) {
697                    request.condition
698                } else {
699                    false
700                };
701
702                if is_conditional {
703                    conditional_count += 1;
704                }
705
706                let (task_type, color_fn): (&str, fn(&str) -> String) =
707                    if self.tasks.assertion.iter().any(|t| &t.id == task_id) {
708                        ("Assertion", |s: &str| s.yellow().to_string())
709                    } else if self.tasks.trace.iter().any(|t| &t.id == task_id) {
710                        ("Trace Assertion", |s: &str| s.bright_blue().to_string())
711                    } else if self.tasks.agent.iter().any(|t| &t.id == task_id) {
712                        ("Request Assertion", |s: &str| s.bright_green().to_string())
713                    } else {
714                        ("LLM Judge", |s: &str| s.purple().to_string())
715                    };
716
717                let conditional_marker = if is_conditional {
718                    " [CONDITIONAL]".bright_red().to_string()
719                } else {
720                    String::new()
721                };
722
723                println!(
724                    "{} {} ({}){}",
725                    prefix,
726                    task_id.bold(),
727                    color_fn(task_type),
728                    conditional_marker
729                );
730
731                let deps = task.depends_on();
732                if !deps.is_empty() {
733                    let dep_prefix = if is_last { "  " } else { "│ " };
734
735                    let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
736                        deps.iter().partition(|dep_id| {
737                            self.get_assertion_by_id(dep_id)
738                                .map(|t| t.condition)
739                                .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
740                                .or_else(|| {
741                                    self.get_trace_assertion_by_id(dep_id).map(|t| t.condition)
742                                })
743                                .or_else(|| {
744                                    self.get_agent_assertion_by_id(dep_id).map(|t| t.condition)
745                                })
746                                .unwrap_or(false)
747                        });
748
749                    if !normal_deps.is_empty() {
750                        println!(
751                            "{}   {} {}",
752                            dep_prefix,
753                            "depends on:".dimmed(),
754                            normal_deps
755                                .iter()
756                                .map(|s| s.as_str())
757                                .collect::<Vec<_>>()
758                                .join(", ")
759                                .dimmed()
760                        );
761                    }
762
763                    if !conditional_deps.is_empty() {
764                        println!(
765                            "{}   {} {}",
766                            dep_prefix,
767                            "▶ conditional gate:".bright_red().dimmed(),
768                            conditional_deps
769                                .iter()
770                                .map(|d| format!("{} must pass", d))
771                                .collect::<Vec<_>>()
772                                .join(", ")
773                                .red()
774                                .dimmed()
775                        );
776                    }
777                }
778
779                if is_conditional {
780                    let continuation = if is_last { "  " } else { "│ " };
781                    println!(
782                        "{}   {} {}",
783                        continuation,
784                        "▶".bright_red(),
785                        "creates conditional branch".bright_red().dimmed()
786                    );
787                }
788            }
789        }
790
791        println!("\n{}", "═".repeat(70).green());
792        println!(
793            "{}: {} tasks across {} stages",
794            "Summary".bold(),
795            self.tasks.assertion.len()
796                + self.tasks.judge.len()
797                + self.tasks.trace.len()
798                + self.tasks.agent.len(),
799            plan.stages.len()
800        );
801
802        if conditional_count > 0 {
803            println!(
804                "{}: {} conditional tasks that create execution branches",
805                "Branches".bold().bright_red(),
806                conditional_count
807            );
808        }
809
810        println!();
811
812        Ok(())
813    }
814}
815
816impl Default for GenAIEvalProfile {
817    fn default() -> Self {
818        Self {
819            config: GenAIEvalConfig::default(),
820            tasks: AssertionTasks {
821                assertion: Vec::new(),
822                judge: Vec::new(),
823                trace: Vec::new(),
824                agent: Vec::new(),
825            },
826            scouter_version: scouter_version(),
827            workflow: None,
828            task_ids: BTreeSet::new(),
829            alias: None,
830        }
831    }
832}
833
834impl GenAIEvalProfile {
835    /// Helper method to build profile from given tasks
836    pub fn build_from_parts(
837        config: GenAIEvalConfig,
838        tasks: AssertionTasks,
839        alias: Option<String>,
840    ) -> Result<GenAIEvalProfile, ProfileError> {
841        let (workflow, task_ids) =
842            app_state().block_on(async { GenAIEvalProfile::build_profile(&tasks).await })?;
843
844        Ok(GenAIEvalProfile {
845            config,
846            tasks,
847            scouter_version: scouter_version(),
848            workflow,
849            task_ids,
850            alias,
851        })
852    }
853
854    /// Async version of `build_from_parts` — safe to call from within an async context.
855    pub async fn build_from_parts_async(
856        config: GenAIEvalConfig,
857        tasks: AssertionTasks,
858        alias: Option<String>,
859    ) -> Result<GenAIEvalProfile, ProfileError> {
860        let (workflow, task_ids) = GenAIEvalProfile::build_profile(&tasks).await?;
861
862        Ok(GenAIEvalProfile {
863            config,
864            tasks,
865            scouter_version: scouter_version(),
866            workflow,
867            task_ids,
868            alias,
869        })
870    }
871
872    #[instrument(skip_all)]
873    pub async fn new(
874        config: GenAIEvalConfig,
875        tasks: Vec<EvaluationTask>,
876    ) -> Result<Self, ProfileError> {
877        let tasks = separate_tasks(tasks);
878        let (workflow, task_ids) = Self::build_profile(&tasks).await?;
879
880        Ok(Self {
881            config,
882            tasks,
883            scouter_version: scouter_version(),
884            workflow,
885            task_ids,
886            alias: None,
887        })
888    }
889
890    async fn build_profile(
891        tasks: &AssertionTasks,
892    ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
893        if tasks.assertion.is_empty()
894            && tasks.judge.is_empty()
895            && tasks.trace.is_empty()
896            && tasks.agent.is_empty()
897        {
898            return Err(ProfileError::EmptyTaskList);
899        }
900
901        let workflow = if !tasks.judge.is_empty() {
902            let workflow = Self::build_workflow_from_judges(tasks).await?;
903            validate_workflow(&workflow)?;
904            Some(workflow)
905        } else {
906            None
907        };
908
909        // Validate LLM judge prompts individually
910        for judge in &tasks.judge {
911            validate_prompt_parameters(&judge.prompt, &judge.id)?;
912        }
913
914        // Collect all task IDs
915        let task_ids = tasks.collect_all_task_ids()?;
916
917        Ok((workflow, task_ids))
918    }
919
920    async fn get_or_create_agent(
921        agents: &mut HashMap<potato_head::Provider, Agent>,
922        workflow: &mut Workflow,
923        provider: &potato_head::Provider,
924    ) -> Result<Agent, ProfileError> {
925        match agents.entry(provider.clone()) {
926            Entry::Occupied(entry) => Ok(entry.get().clone()),
927            Entry::Vacant(entry) => {
928                let agent = Agent::new(provider.clone(), None).await?;
929                workflow.add_agent(&agent);
930                Ok(entry.insert(agent).clone())
931            }
932        }
933    }
934
935    fn filter_judge_dependencies(
936        depends_on: &[String],
937        non_judge_task_ids: &BTreeSet<String>,
938    ) -> Option<Vec<String>> {
939        let filtered: Vec<String> = depends_on
940            .iter()
941            .filter(|dep_id| !non_judge_task_ids.contains(*dep_id))
942            .cloned()
943            .collect();
944
945        if filtered.is_empty() {
946            None
947        } else {
948            Some(filtered)
949        }
950    }
951
952    /// Build workflow from LLM judge tasks
953    /// # Arguments
954    /// * `tasks` - Reference to AssertionTasks
955    /// # Returns
956    /// * `Result<Workflow, ProfileError>` - The constructed workflow
957    pub async fn build_workflow_from_judges(
958        tasks: &AssertionTasks,
959    ) -> Result<Workflow, ProfileError> {
960        let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
961        let mut agents = HashMap::new();
962        let non_judge_task_ids = tasks.collect_non_judge_task_ids();
963
964        for judge in &tasks.judge {
965            let agent =
966                Self::get_or_create_agent(&mut agents, &mut workflow, &judge.prompt.provider)
967                    .await?;
968
969            let task_deps = Self::filter_judge_dependencies(&judge.depends_on, &non_judge_task_ids);
970
971            let task = Task::new(
972                &agent.id,
973                judge.prompt.clone(),
974                &judge.id,
975                task_deps,
976                judge.max_retries,
977            )?;
978
979            workflow.add_task(task)?;
980        }
981
982        Ok(workflow)
983    }
984}
985
986impl ProfileExt for GenAIEvalProfile {
987    #[inline]
988    fn id(&self) -> &str {
989        &self.config.uid
990    }
991
992    fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
993        if let Some(assertion) = self.tasks.assertion.iter().find(|t| t.id() == id) {
994            return Some(assertion);
995        }
996
997        if let Some(judge) = self.tasks.judge.iter().find(|t| t.id() == id) {
998            return Some(judge);
999        }
1000
1001        if let Some(trace) = self.tasks.trace.iter().find(|t| t.id() == id) {
1002            return Some(trace);
1003        }
1004
1005        if let Some(request) = self.tasks.agent.iter().find(|t| t.id() == id) {
1006            return Some(request);
1007        }
1008
1009        None
1010    }
1011
1012    #[inline]
1013    /// Get assertion task by ID, first checking AssertionTasks, then ConditionalTasks
1014    fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
1015        self.tasks.assertion.iter().find(|t| t.id() == id)
1016    }
1017
1018    #[inline]
1019    fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
1020        self.tasks.judge.iter().find(|t| t.id() == id)
1021    }
1022
1023    #[inline]
1024    fn get_trace_assertion_by_id(&self, id: &str) -> Option<&TraceAssertionTask> {
1025        self.tasks.trace.iter().find(|t| t.id() == id)
1026    }
1027
1028    #[inline]
1029    fn get_agent_assertion_by_id(&self, id: &str) -> Option<&AgentAssertionTask> {
1030        self.tasks.agent.iter().find(|t| t.id() == id)
1031    }
1032
1033    #[inline]
1034    fn has_llm_tasks(&self) -> bool {
1035        !self.tasks.judge.is_empty()
1036    }
1037
1038    #[inline]
1039    fn has_trace_assertions(&self) -> bool {
1040        !self.tasks.trace.is_empty()
1041    }
1042
1043    #[inline]
1044    fn has_agent_assertions(&self) -> bool {
1045        !self.tasks.agent.is_empty()
1046    }
1047}
1048
1049impl ProfileBaseArgs for GenAIEvalProfile {
1050    type Config = GenAIEvalConfig;
1051
1052    fn config(&self) -> &Self::Config {
1053        &self.config
1054    }
1055    fn get_base_args(&self) -> ProfileArgs {
1056        ProfileArgs {
1057            name: self.config.name.clone(),
1058            space: self.config.space.clone(),
1059            version: Some(self.config.version.clone()),
1060            schedule: self.config.alert_config.schedule.clone(),
1061            scouter_version: self.scouter_version.clone(),
1062            drift_type: self.config.drift_type.clone(),
1063        }
1064    }
1065
1066    fn to_value(&self) -> Value {
1067        serde_json::to_value(self).unwrap()
1068    }
1069}
1070
/// Evaluation output paired with its workflow-level summary.
///
/// `records` holds the per-task results. `inner` holds the aggregate workflow
/// result (task counts, pass rate, duration, `record_uid`, ...).
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalSet {
    /// Per-task evaluation results (read-only attribute from Python).
    #[pyo3(get)]
    pub records: Vec<EvalTaskResult>,
    /// Aggregate workflow result; exposed to Python only through `EvalSet`'s getters.
    pub inner: GenAIEvalWorkflowResult,
}
1078
1079impl EvalSet {
1080    pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
1081        // sort records by stage, then by task_id
1082
1083        self.records
1084            .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
1085
1086        self.records
1087            .iter()
1088            .map(|record| record.to_table_entry(record_id))
1089            .collect()
1090    }
1091
1092    pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
1093        vec![self.inner.to_table_entry()]
1094    }
1095
1096    pub fn new(records: Vec<EvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
1097        Self { records, inner }
1098    }
1099
1100    pub fn empty() -> Self {
1101        Self {
1102            records: Vec::new(),
1103            inner: GenAIEvalWorkflowResult {
1104                created_at: Utc::now(),
1105                record_uid: String::new(),
1106                entity_id: 0,
1107                total_tasks: 0,
1108                passed_tasks: 0,
1109                failed_tasks: 0,
1110                pass_rate: 0.0,
1111                duration_ms: 0,
1112                entity_uid: String::new(),
1113                execution_plan: ExecutionPlan::default(),
1114                id: 0,
1115            },
1116        }
1117    }
1118}
1119
1120#[pymethods]
1121impl EvalSet {
1122    #[getter]
1123    pub fn created_at(&self) -> DateTime<Utc> {
1124        self.inner.created_at
1125    }
1126
1127    #[getter]
1128    pub fn record_uid(&self) -> String {
1129        self.inner.record_uid.clone()
1130    }
1131
1132    #[getter]
1133    pub fn total_tasks(&self) -> i32 {
1134        self.inner.total_tasks
1135    }
1136
1137    #[getter]
1138    pub fn passed_tasks(&self) -> i32 {
1139        self.inner.passed_tasks
1140    }
1141
1142    #[getter]
1143    pub fn failed_tasks(&self) -> i32 {
1144        self.inner.failed_tasks
1145    }
1146
1147    #[getter]
1148    pub fn pass_rate(&self) -> f64 {
1149        self.inner.pass_rate
1150    }
1151
1152    #[getter]
1153    pub fn duration_ms(&self) -> i64 {
1154        self.inner.duration_ms
1155    }
1156
1157    pub fn __str__(&self) -> String {
1158        // serialize the struct to a string
1159        PyHelperFuncs::__str__(self)
1160    }
1161}
1162
/// Collection of [`EvalSet`]s, exposed to Python with a read-only `records` attribute.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResultSet {
    /// The eval sets held by this result collection.
    #[pyo3(get)]
    pub records: Vec<EvalSet>,
}
1169
1170#[pymethods]
1171impl EvalResultSet {
1172    pub fn record(&self, id: &str) -> Option<EvalSet> {
1173        self.records.iter().find(|r| r.record_uid() == id).cloned()
1174    }
1175    pub fn __str__(&self) -> String {
1176        // serialize the struct to a string
1177        PyHelperFuncs::__str__(self)
1178    }
1179}
1180
// Tests below are compiled only when the `mock` feature is enabled (they rely on `potato_head::mock`).
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    // Checks that `GenAIEvalConfig::new` stores the given identifiers with a default
    // dispatch config, and that `update_config_args` overrides the name and alert config.
    #[test]
    fn test_genai_drift_config() {
        let mut drift_config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();
        // `MISSING` is the "__missing__" placeholder, as these assertions show.
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // Only the name (2nd argument) and the alert config (5th argument) are supplied;
        // the `None` arguments presumably leave their fields unchanged.
        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    // Builds a profile from two LLM-judge tasks sharing one mock prompt, then checks
    // that it serializes to valid JSON, keeps both judges, and records this crate's version.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        // Mock prompt from `potato_head::mock`; the argument presumably lists the
        // prompt's input parameter names — TODO confirm.
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        // Only checks that `model_dump_json` yields parseable JSON; the value is discarded.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks().len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}