Skip to main content

scouter_types/genai/
profile.rs

1use crate::error::{ProfileError, TypeError};
2use crate::genai::alert::GenAIAlertConfig;
3use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
4use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
5use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
6use crate::genai::TraceAssertionTask;
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json};
9use crate::{scouter_version, EvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry};
10use crate::{
11    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
12    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
13};
14use crate::{ProfileRequest, TaskResultTableEntry};
15use chrono::{DateTime, Utc};
16use core::fmt::Debug;
17use potato_head::prompt_types::Prompt;
18use potato_head::Agent;
19use potato_head::Workflow;
20use potato_head::{create_uuid7, Task};
21use pyo3::prelude::*;
22use pyo3::types::{PyDict, PyList};
23use scouter_semver::VersionType;
24use scouter_state::app_state;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::hash_map::Entry;
28use std::collections::BTreeSet;
29use std::collections::HashMap;
30use std::path::PathBuf;
31use std::sync::Arc;
32
33use tracing::instrument;
34
// Serde default helpers referenced by `#[serde(default = "...")]` on
// `GenAIEvalConfig`, so partial/older JSON payloads deserialize cleanly.

/// Default sampling ratio: evaluate 100% of records.
fn default_sample_ratio() -> f64 {
    1.0
}

/// Placeholder space until the caller supplies a real one.
fn default_space() -> String {
    MISSING.to_string()
}

/// Placeholder name until the caller supplies a real one.
fn default_name() -> String {
    MISSING.to_string()
}

/// Default semantic version string.
fn default_version() -> String {
    DEFAULT_VERSION.to_string()
}

/// Freshly generated UUIDv7 identifier.
fn default_uid() -> String {
    create_uuid7()
}

/// This config always belongs to the GenAI drift type.
fn default_drift_type() -> DriftType {
    DriftType::GenAI
}

/// Default alerting configuration.
fn default_alert_config() -> GenAIAlertConfig {
    GenAIAlertConfig::default()
}
62
/// Configuration for a GenAI evaluation profile, exposed to Python.
///
/// All fields carry serde defaults so the struct deserializes from partial
/// JSON (see the `default_*` helpers).
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    /// Fraction of records to evaluate; clamped to [0.0, 1.0] by the constructor.
    #[pyo3(get, set)]
    #[serde(default = "default_sample_ratio")]
    pub sample_ratio: f64,

    /// Space (namespace) the profile belongs to.
    #[pyo3(get, set)]
    #[serde(default = "default_space")]
    pub space: String,

    /// Profile name.
    #[pyo3(get, set)]
    #[serde(default = "default_name")]
    pub name: String,

    /// Semantic version string.
    #[pyo3(get, set)]
    #[serde(default = "default_version")]
    pub version: String,

    /// Unique identifier (UUIDv7).
    #[pyo3(get, set)]
    #[serde(default = "default_uid")]
    pub uid: String,

    /// Alerting configuration (dispatch targets, schedule, ...).
    #[pyo3(get, set)]
    #[serde(default = "default_alert_config")]
    pub alert_config: GenAIAlertConfig,

    /// Always `DriftType::GenAI` for this config.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
94
// Borrowing accessors for the common config fields (shared `ConfigExt` trait).
impl ConfigExt for GenAIEvalConfig {
    fn space(&self) -> &str {
        &self.space
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn version(&self) -> &str {
        &self.version
    }
    fn uid(&self) -> &str {
        &self.uid
    }
}
111
impl DispatchDriftConfig for GenAIEvalConfig {
    /// Bundles the identifying fields plus the alert dispatch config into the
    /// `DriftArgs` consumed by the alert dispatcher.
    fn get_drift_args(&self) -> DriftArgs {
        DriftArgs {
            name: self.name.clone(),
            space: self.space.clone(),
            version: self.version.clone(),
            dispatch_config: self.alert_config.dispatch_config.clone(),
        }
    }
}
122
123#[pymethods]
124#[allow(clippy::too_many_arguments)]
125impl GenAIEvalConfig {
126    #[new]
127    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
128    pub fn new(
129        space: &str,
130        name: &str,
131        version: &str,
132        sample_ratio: f64,
133        alert_config: GenAIAlertConfig,
134        config_path: Option<PathBuf>,
135    ) -> Result<Self, ProfileError> {
136        if let Some(config_path) = config_path {
137            let config = GenAIEvalConfig::load_from_json_file(config_path)?;
138            return Ok(config);
139        }
140
141        Ok(Self {
142            sample_ratio: sample_ratio.clamp(0.0, 1.0),
143            space: space.to_string(),
144            name: name.to_string(),
145            uid: create_uuid7(),
146            version: version.to_string(),
147            alert_config,
148            drift_type: DriftType::GenAI,
149        })
150    }
151
152    #[staticmethod]
153    pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
154        // deserialize the string to a struct
155
156        let file = std::fs::read_to_string(&path)?;
157
158        Ok(serde_json::from_str(&file)?)
159    }
160
161    pub fn __str__(&self) -> String {
162        // serialize the struct to a string
163        PyHelperFuncs::__str__(self)
164    }
165
166    pub fn model_dump_json(&self) -> String {
167        // serialize the struct to a string
168        PyHelperFuncs::__json__(self)
169    }
170
171    #[allow(clippy::too_many_arguments)]
172    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
173    pub fn update_config_args(
174        &mut self,
175        space: Option<String>,
176        name: Option<String>,
177        version: Option<String>,
178        uid: Option<String>,
179        alert_config: Option<GenAIAlertConfig>,
180    ) -> Result<(), TypeError> {
181        if name.is_some() {
182            self.name = name.ok_or(TypeError::MissingNameError)?;
183        }
184
185        if space.is_some() {
186            self.space = space.ok_or(TypeError::MissingSpaceError)?;
187        }
188
189        if version.is_some() {
190            self.version = version.ok_or(TypeError::MissingVersionError)?;
191        }
192
193        if alert_config.is_some() {
194            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
195        }
196
197        if uid.is_some() {
198            self.uid = uid.ok_or(TypeError::MissingUidError)?;
199        }
200
201        Ok(())
202    }
203}
204
205impl Default for GenAIEvalConfig {
206    fn default() -> Self {
207        Self {
208            sample_ratio: 1.0,
209            space: "default".to_string(),
210            name: "default_genai_profile".to_string(),
211            version: DEFAULT_VERSION.to_string(),
212            uid: create_uuid7(),
213            alert_config: GenAIAlertConfig::default(),
214            drift_type: DriftType::GenAI,
215        }
216    }
217}
218
219/// Validates that a prompt contains at least one required parameter.
220///
221/// LLM evaluation prompts must have either "input" or "output" parameters
222/// to access the data being evaluated.
223///
224/// # Arguments
225/// * `prompt` - The prompt to validate
226/// * `id` - Identifier for error reporting
227///
228/// # Returns
229/// * `Ok(())` if validation passes
230/// * `Err(ProfileError::MissingPromptParametersError)` if no required parameters found
231///
232/// # Errors
233/// Returns an error if the prompt lacks both "input" and "output" parameters.
234fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
235    let has_at_least_one_param = !prompt.parameters.is_empty();
236
237    if !has_at_least_one_param {
238        return Err(ProfileError::NeedAtLeastOneBoundParameterError(
239            id.to_string(),
240        ));
241    }
242
243    Ok(())
244}
245
246fn get_workflow_task<'a>(
247    workflow: &'a Workflow,
248    task_id: &'a str,
249) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
250    workflow
251        .task_list
252        .tasks
253        .get(task_id)
254        .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
255}
256
257/// Helper function to validate first tasks in workflow execution.
258fn validate_first_tasks(
259    workflow: &Workflow,
260    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
261) -> Result<(), ProfileError> {
262    let first_tasks = execution_order
263        .get(&1)
264        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
265
266    for task_id in first_tasks {
267        let task = get_workflow_task(workflow, task_id)?;
268        let task_guard = task
269            .read()
270            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
271
272        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
273    }
274
275    Ok(())
276}
277
/// Validates workflow execution parameters.
///
/// Currently this only checks that first-stage tasks declare the required
/// prompt parameters (via `validate_first_tasks`). No validation is performed
/// on the final tasks' response types here.
///
/// # Arguments
/// * `workflow` - The workflow to validate
///
/// # Returns
/// * `Ok(())` if validation passes
/// * `Err(ProfileError)` if validation fails
///
/// # Errors
/// Propagates errors from building the execution plan or from first-task
/// prompt-parameter validation.
fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
    let execution_order = workflow.execution_plan()?;

    // Validate first tasks have required parameters
    validate_first_tasks(workflow, &execution_order)?;

    Ok(())
}
301
/// A single task in the computed execution plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ExecutionNode {
    // Task identifier.
    pub id: String,
    // Zero-based stage index this task runs in.
    pub stage: usize,
    // Task ids this node depends on (must run earlier).
    pub parents: Vec<String>,
    // Task ids that depend on this node.
    pub children: Vec<String>,
}

/// Topologically ordered execution plan: `stages[i]` holds the ids of tasks
/// runnable in parallel at stage `i`; `nodes` maps each id to its node info.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
319
320fn initialize_node_graphs(
321    tasks: &[impl TaskAccessor],
322    graph: &mut HashMap<String, Vec<String>>,
323    reverse_graph: &mut HashMap<String, Vec<String>>,
324    in_degree: &mut HashMap<String, usize>,
325) {
326    for task in tasks {
327        let task_id = task.id().to_string();
328        graph.entry(task_id.clone()).or_default();
329        reverse_graph.entry(task_id.clone()).or_default();
330        in_degree.entry(task_id).or_insert(0);
331    }
332}
333
334fn build_dependency_edges(
335    tasks: &[impl TaskAccessor],
336    graph: &mut HashMap<String, Vec<String>>,
337    reverse_graph: &mut HashMap<String, Vec<String>>,
338    in_degree: &mut HashMap<String, usize>,
339) {
340    for task in tasks {
341        let task_id = task.id().to_string();
342        for dep in task.depends_on() {
343            graph.entry(dep.clone()).or_default().push(task_id.clone());
344            reverse_graph
345                .entry(task_id.clone())
346                .or_default()
347                .push(dep.clone());
348            *in_degree.entry(task_id.clone()).or_insert(0) += 1;
349        }
350    }
351}
352
/// A GenAI evaluation profile: the task set plus the workflow (if any LLM
/// judge tasks exist) that the scouter server executes asynchronously.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    // Identifying/alerting configuration.
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    // Assertion, LLM-judge, and trace-assertion tasks (exposed via getters).
    pub tasks: AssertionTasks,

    // Version of scouter that created this profile.
    #[pyo3(get)]
    pub scouter_version: String,

    // Built only when judge tasks are present; None otherwise.
    pub workflow: Option<Workflow>,

    // All task ids across every task kind.
    pub task_ids: BTreeSet<String>,

    // Optional caller-supplied alias for the profile.
    pub alias: Option<String>,
}
370
#[pymethods]
impl GenAIEvalProfile {
    #[new]
    #[pyo3(signature = (tasks, config=None, alias=None))]
    /// Create a new GenAIEvalProfile
    /// GenAI evaluations are run asynchronously on the scouter server.
    /// # Arguments
    /// * `config` - GenAIEvalConfig - The configuration for the GenAI drift profile
    /// * `tasks` - PyList - List of AssertionTask, LLMJudgeTask or ConditionalTask
    /// * `alias` - Option<String> - Optional alias for the profile
    /// # Returns
    /// * `Result<Self, ProfileError>` - The GenAIEvalProfile
    /// # Errors
    /// * `ProfileError` - If task extraction fails or the judge workflow
    ///   cannot be built/validated
    #[instrument(skip_all)]
    pub fn new_py(
        tasks: &Bound<'_, PyList>,
        config: Option<GenAIEvalConfig>,
        alias: Option<String>,
    ) -> Result<Self, ProfileError> {
        let tasks = extract_assertion_tasks_from_pylist(tasks)?;

        // Workflow construction is async (agent creation); block on the
        // shared runtime since this is called from synchronous Python.
        let (workflow, task_ids) =
            app_state().block_on(async { Self::build_profile(&tasks).await })?;

        Ok(Self {
            config: config.unwrap_or_default(),
            tasks,
            scouter_version: scouter_version(),
            workflow,
            task_ids,
            alias,
        })
    }

    /// Human-readable representation for Python `str()`.
    pub fn __str__(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__str__(self)
    }

    /// Serialize the profile to a JSON string (pydantic-style API).
    pub fn model_dump_json(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__json__(self)
    }

    /// Serialize the profile into a Python dictionary.
    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
        let json_value = serde_json::to_value(self)?;

        // Create a new Python dictionary
        let dict = PyDict::new(py);

        // Convert JSON to Python dict
        json_to_pyobject(py, &json_value, &dict)?;

        // Return the Python dictionary
        Ok(dict.into())
    }

    /// Drift type of the underlying config (always GenAI).
    #[getter]
    pub fn drift_type(&self) -> DriftType {
        self.config.drift_type.clone()
    }

    /// Cloned list of plain assertion tasks.
    #[getter]
    pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
        self.tasks.assertion.clone()
    }

    /// Cloned list of LLM judge tasks.
    #[getter]
    pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
        self.tasks.judge.clone()
    }

    /// Cloned list of trace assertion tasks.
    #[getter]
    pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
        self.tasks.trace.clone()
    }

    /// Optional profile alias.
    #[getter]
    pub fn alias(&self) -> Option<String> {
        self.alias.clone()
    }

    /// Unique identifier of the underlying config.
    #[getter]
    pub fn uid(&self) -> String {
        self.config.uid.clone()
    }

    #[setter]
    pub fn set_uid(&mut self, uid: String) {
        self.config.uid = uid;
    }

    /// Save the profile to a JSON file; defaults to the standard profile
    /// filename when `path` is None. Returns the path written.
    #[pyo3(signature = (path=None))]
    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
        Ok(PyHelperFuncs::save_to_json(
            self,
            path,
            FileName::GenAIEvalProfile.to_str(),
        )?)
    }

    /// Rebuild a profile from a Python dict.
    ///
    /// NOTE(review): panics (`unwrap`/`expect`) on malformed input rather than
    /// returning an error — consistent with `model_validate_json` below.
    #[staticmethod]
    pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
        let json_value = pyobject_to_json(data).unwrap();

        let string = serde_json::to_string(&json_value).unwrap();
        serde_json::from_str(&string).expect("Failed to load drift profile")
    }

    /// Rebuild a profile from a JSON string. Panics on malformed input.
    #[staticmethod]
    pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
        // deserialize the string to a struct
        serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
    }

    /// Load a profile from a JSON file on disk.
    #[staticmethod]
    pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
        let file = std::fs::read_to_string(&path)?;

        Ok(serde_json::from_str(&file)?)
    }

    /// Update any subset of config fields in place; delegates to the config.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
    pub fn update_config_args(
        &mut self,
        space: Option<String>,
        name: Option<String>,
        version: Option<String>,
        uid: Option<String>,
        alert_config: Option<GenAIAlertConfig>,
    ) -> Result<(), TypeError> {
        self.config
            .update_config_args(space, name, version, uid, alert_config)
    }

    /// Create a profile request from the profile
    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
        // DEFAULT_VERSION means "let the server pick"; anything else is pinned.
        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
            None
        } else {
            Some(self.config.version.clone())
        };

        Ok(ProfileRequest {
            space: self.config.space.clone(),
            profile: self.model_dump_json(),
            drift_type: self.config.drift_type.clone(),
            version_request: Some(VersionRequest {
                version,
                version_type: VersionType::Minor,
                pre_tag: None,
                build_tag: None,
            }),
            active: false,
            deactivate_others: false,
        })
    }

    /// True when the profile contains at least one LLM judge task.
    pub fn has_llm_tasks(&self) -> bool {
        !self.tasks.judge.is_empty()
    }

    /// Check if this profile has assertions
    pub fn has_assertions(&self) -> bool {
        !self.tasks.assertion.is_empty()
    }

    /// True when the profile contains at least one trace assertion task.
    pub fn has_trace_assertions(&self) -> bool {
        !self.tasks.trace.is_empty()
    }

    /// Get execution order for all tasks (assertions + LLM judges + trace assertions)
    ///
    /// Builds a dependency graph over every task kind, then runs Kahn's
    /// topological sort to group tasks into parallelizable stages.
    ///
    /// # Errors
    /// Returns `ProfileError::CircularDependency` when not all tasks can be
    /// scheduled (i.e. the dependency graph contains a cycle).
    pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
        let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
        let mut in_degree: HashMap<String, usize> = HashMap::new();

        // Seed every task id with empty adjacency lists and in-degree 0.
        initialize_node_graphs(
            &self.tasks.assertion,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        initialize_node_graphs(
            &self.tasks.judge,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        initialize_node_graphs(
            &self.tasks.trace,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        // Add dependency edges from every task kind.
        build_dependency_edges(
            &self.tasks.assertion,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        build_dependency_edges(
            &self.tasks.judge,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        build_dependency_edges(
            &self.tasks.trace,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        let mut stages = Vec::new();
        let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
        // Kahn's algorithm: start with all tasks that have no dependencies.
        let mut current_level: Vec<String> = in_degree
            .iter()
            .filter(|(_, &degree)| degree == 0)
            .map(|(id, _)| id.clone())
            .collect();

        let mut stage_idx = 0;

        while !current_level.is_empty() {
            stages.push(current_level.clone());

            for task_id in &current_level {
                nodes.insert(
                    task_id.clone(),
                    ExecutionNode {
                        id: task_id.clone(),
                        stage: stage_idx,
                        parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
                        children: graph.get(task_id).cloned().unwrap_or_default(),
                    },
                );
            }

            // Releasing this stage decrements each dependent's in-degree;
            // tasks that reach 0 form the next stage.
            let mut next_level = Vec::new();
            for task_id in &current_level {
                if let Some(dependents) = graph.get(task_id) {
                    for dependent in dependents {
                        if let Some(degree) = in_degree.get_mut(dependent) {
                            *degree -= 1;
                            if *degree == 0 {
                                next_level.push(dependent.clone());
                            }
                        }
                    }
                }
            }

            current_level = next_level;
            stage_idx += 1;
        }

        // Any task never scheduled implies a cycle in the dependency graph.
        let total_tasks =
            self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len();
        let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();

        if processed_tasks != total_tasks {
            return Err(ProfileError::CircularDependency);
        }

        Ok(ExecutionPlan { stages, nodes })
    }

    /// Pretty-print the execution plan to stdout with colored, tree-style
    /// output, highlighting conditional tasks and their gates.
    pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
        use owo_colors::OwoColorize;

        let plan = self.get_execution_plan()?;

        println!("\n{}", "Evaluation Execution Plan".bold().green());
        println!("{}", "═".repeat(70).green());

        let mut conditional_count = 0;

        for (level_idx, level) in plan.stages.iter().enumerate() {
            let stage_label = format!("Stage {}", level_idx + 1);
            println!("\n{}", stage_label.bold().cyan());

            for (task_idx, task_id) in level.iter().enumerate() {
                let is_last = task_idx == level.len() - 1;
                let prefix = if is_last { "└─" } else { "├─" };

                let task = self.get_task_by_id(task_id).ok_or_else(|| {
                    ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
                })?;

                // Resolve the task's `condition` flag from whichever task kind it is.
                let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
                    assertion.condition
                } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
                    judge.condition
                } else if let Some(trace) = self.get_trace_assertion_by_id(task_id) {
                    trace.condition
                } else {
                    false
                };

                if is_conditional {
                    conditional_count += 1;
                }

                // Pick a label and color per task kind (default: LLM Judge).
                let (task_type, color_fn): (&str, fn(&str) -> String) =
                    if self.tasks.assertion.iter().any(|t| &t.id == task_id) {
                        ("Assertion", |s: &str| s.yellow().to_string())
                    } else if self.tasks.trace.iter().any(|t| &t.id == task_id) {
                        ("Trace Assertion", |s: &str| s.bright_blue().to_string())
                    } else {
                        ("LLM Judge", |s: &str| s.purple().to_string())
                    };

                let conditional_marker = if is_conditional {
                    " [CONDITIONAL]".bright_red().to_string()
                } else {
                    String::new()
                };

                println!(
                    "{} {} ({}){}",
                    prefix,
                    task_id.bold(),
                    color_fn(task_type),
                    conditional_marker
                );

                let deps = task.depends_on();
                if !deps.is_empty() {
                    let dep_prefix = if is_last { "  " } else { "│ " };

                    // Split dependencies into conditional gates vs. ordinary deps.
                    let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
                        deps.iter().partition(|dep_id| {
                            self.get_assertion_by_id(dep_id)
                                .map(|t| t.condition)
                                .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
                                .or_else(|| {
                                    self.get_trace_assertion_by_id(dep_id).map(|t| t.condition)
                                })
                                .unwrap_or(false)
                        });

                    if !normal_deps.is_empty() {
                        println!(
                            "{}   {} {}",
                            dep_prefix,
                            "depends on:".dimmed(),
                            normal_deps
                                .iter()
                                .map(|s| s.as_str())
                                .collect::<Vec<_>>()
                                .join(", ")
                                .dimmed()
                        );
                    }

                    if !conditional_deps.is_empty() {
                        println!(
                            "{}   {} {}",
                            dep_prefix,
                            "▶ conditional gate:".bright_red().dimmed(),
                            conditional_deps
                                .iter()
                                .map(|d| format!("{} must pass", d))
                                .collect::<Vec<_>>()
                                .join(", ")
                                .red()
                                .dimmed()
                        );
                    }
                }

                if is_conditional {
                    let continuation = if is_last { "  " } else { "│ " };
                    println!(
                        "{}   {} {}",
                        continuation,
                        "▶".bright_red(),
                        "creates conditional branch".bright_red().dimmed()
                    );
                }
            }
        }

        println!("\n{}", "═".repeat(70).green());
        println!(
            "{}: {} tasks across {} stages",
            "Summary".bold(),
            self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len(),
            plan.stages.len()
        );

        if conditional_count > 0 {
            println!(
                "{}: {} conditional tasks that create execution branches",
                "Branches".bold().bright_red(),
                conditional_count
            );
        }

        println!();

        Ok(())
    }
}
780
781impl Default for GenAIEvalProfile {
782    fn default() -> Self {
783        Self {
784            config: GenAIEvalConfig::default(),
785            tasks: AssertionTasks {
786                assertion: Vec::new(),
787                judge: Vec::new(),
788                trace: Vec::new(),
789            },
790            scouter_version: scouter_version(),
791            workflow: None,
792            task_ids: BTreeSet::new(),
793            alias: None,
794        }
795    }
796}
797
impl GenAIEvalProfile {
    /// Helper method to build profile from given tasks
    ///
    /// Synchronous entry point: blocks on the shared async runtime to build
    /// the judge workflow (if any) before assembling the profile.
    pub fn build_from_parts(
        config: GenAIEvalConfig,
        tasks: AssertionTasks,
        alias: Option<String>,
    ) -> Result<GenAIEvalProfile, ProfileError> {
        let (workflow, task_ids) =
            app_state().block_on(async { GenAIEvalProfile::build_profile(&tasks).await })?;

        Ok(GenAIEvalProfile {
            config,
            tasks,
            scouter_version: scouter_version(),
            workflow,
            task_ids,
            alias,
        })
    }

    /// Async constructor from a flat list of `EvaluationTask`s, which are
    /// first separated by kind (assertion / judge / trace).
    #[instrument(skip_all)]
    pub async fn new(
        config: GenAIEvalConfig,
        tasks: Vec<EvaluationTask>,
    ) -> Result<Self, ProfileError> {
        let tasks = separate_tasks(tasks);
        let (workflow, task_ids) = Self::build_profile(&tasks).await?;

        Ok(Self {
            config,
            tasks,
            scouter_version: scouter_version(),
            workflow,
            task_ids,
            alias: None,
        })
    }

    /// Builds the optional judge workflow and collects all task ids.
    ///
    /// # Errors
    /// * `ProfileError::EmptyTaskList` when no tasks of any kind are provided
    /// * Propagates workflow-construction and prompt-validation errors
    async fn build_profile(
        tasks: &AssertionTasks,
    ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
        if tasks.assertion.is_empty() && tasks.judge.is_empty() && tasks.trace.is_empty() {
            return Err(ProfileError::EmptyTaskList);
        }

        // A workflow only exists when there are LLM judge tasks to run.
        let workflow = if !tasks.judge.is_empty() {
            let workflow = Self::build_workflow_from_judges(tasks).await?;
            validate_workflow(&workflow)?;
            Some(workflow)
        } else {
            None
        };

        // Validate LLM judge prompts individually
        for judge in &tasks.judge {
            validate_prompt_parameters(&judge.prompt, &judge.id)?;
        }

        // Collect all task IDs
        let task_ids = tasks.collect_all_task_ids()?;

        Ok((workflow, task_ids))
    }

    /// Returns the cached agent for `provider`, creating it (and registering
    /// it with the workflow) on first use.
    async fn get_or_create_agent(
        agents: &mut HashMap<potato_head::Provider, Agent>,
        workflow: &mut Workflow,
        provider: &potato_head::Provider,
    ) -> Result<Agent, ProfileError> {
        match agents.entry(provider.clone()) {
            Entry::Occupied(entry) => Ok(entry.get().clone()),
            Entry::Vacant(entry) => {
                let agent = Agent::new(provider.clone(), None).await?;
                workflow.add_agent(&agent);
                Ok(entry.insert(agent).clone())
            }
        }
    }

    /// Drops dependencies on non-judge tasks: the workflow only contains
    /// judge tasks, so other ids would be dangling references.
    /// Returns None when no judge-to-judge dependencies remain.
    fn filter_judge_dependencies(
        depends_on: &[String],
        non_judge_task_ids: &BTreeSet<String>,
    ) -> Option<Vec<String>> {
        let filtered: Vec<String> = depends_on
            .iter()
            .filter(|dep_id| !non_judge_task_ids.contains(*dep_id))
            .cloned()
            .collect();

        if filtered.is_empty() {
            None
        } else {
            Some(filtered)
        }
    }

    /// Build workflow from LLM judge tasks
    /// # Arguments
    /// * `tasks` - Reference to AssertionTasks
    /// # Returns
    /// * `Result<Workflow, ProfileError>` - The constructed workflow
    pub async fn build_workflow_from_judges(
        tasks: &AssertionTasks,
    ) -> Result<Workflow, ProfileError> {
        let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
        let mut agents = HashMap::new();
        let non_judge_task_ids = tasks.collect_non_judge_task_ids();

        for judge in &tasks.judge {
            // One agent per provider, shared across judge tasks.
            let agent =
                Self::get_or_create_agent(&mut agents, &mut workflow, &judge.prompt.provider)
                    .await?;

            let task_deps = Self::filter_judge_dependencies(&judge.depends_on, &non_judge_task_ids);

            let task = Task::new(
                &agent.id,
                judge.prompt.clone(),
                &judge.id,
                task_deps,
                judge.max_retries,
            )?;

            workflow.add_task(task)?;
        }

        Ok(workflow)
    }
}
927
928impl ProfileExt for GenAIEvalProfile {
929    #[inline]
930    fn id(&self) -> &str {
931        &self.config.uid
932    }
933
934    fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
935        if let Some(assertion) = self.tasks.assertion.iter().find(|t| t.id() == id) {
936            return Some(assertion);
937        }
938
939        if let Some(judge) = self.tasks.judge.iter().find(|t| t.id() == id) {
940            return Some(judge);
941        }
942
943        if let Some(trace) = self.tasks.trace.iter().find(|t| t.id() == id) {
944            return Some(trace);
945        }
946
947        None
948    }
949
950    #[inline]
951    /// Get assertion task by ID, first checking AssertionTasks, then ConditionalTasks
952    fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
953        self.tasks.assertion.iter().find(|t| t.id() == id)
954    }
955
956    #[inline]
957    fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
958        self.tasks.judge.iter().find(|t| t.id() == id)
959    }
960
961    #[inline]
962    fn get_trace_assertion_by_id(&self, id: &str) -> Option<&TraceAssertionTask> {
963        self.tasks.trace.iter().find(|t| t.id() == id)
964    }
965
966    #[inline]
967    fn has_llm_tasks(&self) -> bool {
968        !self.tasks.judge.is_empty()
969    }
970
971    #[inline]
972    fn has_trace_assertions(&self) -> bool {
973        !self.tasks.trace.is_empty()
974    }
975}
976
977impl ProfileBaseArgs for GenAIEvalProfile {
978    type Config = GenAIEvalConfig;
979
980    fn config(&self) -> &Self::Config {
981        &self.config
982    }
983    fn get_base_args(&self) -> ProfileArgs {
984        ProfileArgs {
985            name: self.config.name.clone(),
986            space: self.config.space.clone(),
987            version: Some(self.config.version.clone()),
988            schedule: self.config.alert_config.schedule.clone(),
989            scouter_version: self.scouter_version.clone(),
990            drift_type: self.config.drift_type.clone(),
991        }
992    }
993
994    fn to_value(&self) -> Value {
995        serde_json::to_value(self).unwrap()
996    }
997}
998
/// A single evaluation record set: the per-task results plus the aggregate
/// workflow-level result they belong to.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalSet {
    // Individual task results for this evaluation run (exposed to Python).
    #[pyo3(get)]
    pub records: Vec<EvalTaskResult>,
    // Aggregate workflow result; surfaced to Python via explicit getters.
    pub inner: GenAIEvalWorkflowResult,
}
1006
1007impl EvalSet {
1008    pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
1009        // sort records by stage, then by task_id
1010
1011        self.records
1012            .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
1013
1014        self.records
1015            .iter()
1016            .map(|record| record.to_table_entry(record_id))
1017            .collect()
1018    }
1019
1020    pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
1021        vec![self.inner.to_table_entry()]
1022    }
1023
1024    pub fn new(records: Vec<EvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
1025        Self { records, inner }
1026    }
1027
1028    pub fn empty() -> Self {
1029        Self {
1030            records: Vec::new(),
1031            inner: GenAIEvalWorkflowResult {
1032                created_at: Utc::now(),
1033                record_uid: String::new(),
1034                entity_id: 0,
1035                total_tasks: 0,
1036                passed_tasks: 0,
1037                failed_tasks: 0,
1038                pass_rate: 0.0,
1039                duration_ms: 0,
1040                entity_uid: String::new(),
1041                execution_plan: ExecutionPlan::default(),
1042                id: 0,
1043            },
1044        }
1045    }
1046}
1047
// Python-facing getters for `EvalSet`; each delegates to the aggregate
// workflow result stored in `inner`.
#[pymethods]
impl EvalSet {
    /// Timestamp when the aggregate workflow result was created.
    #[getter]
    pub fn created_at(&self) -> DateTime<Utc> {
        self.inner.created_at
    }

    /// Unique identifier of the evaluated record.
    #[getter]
    pub fn record_uid(&self) -> String {
        self.inner.record_uid.clone()
    }

    /// Total number of tasks executed in the workflow.
    #[getter]
    pub fn total_tasks(&self) -> i32 {
        self.inner.total_tasks
    }

    /// Number of tasks that passed.
    #[getter]
    pub fn passed_tasks(&self) -> i32 {
        self.inner.passed_tasks
    }

    /// Number of tasks that failed.
    #[getter]
    pub fn failed_tasks(&self) -> i32 {
        self.inner.failed_tasks
    }

    /// Fraction of tasks that passed.
    #[getter]
    pub fn pass_rate(&self) -> f64 {
        self.inner.pass_rate
    }

    /// Total workflow execution time in milliseconds.
    #[getter]
    pub fn duration_ms(&self) -> i64 {
        self.inner.duration_ms
    }

    /// String representation for Python's `str()`.
    pub fn __str__(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__str__(self)
    }
}
1090
/// A batch of [`EvalSet`]s, one per evaluated record.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResultSet {
    // All evaluation sets in this result batch (exposed to Python).
    #[pyo3(get)]
    pub records: Vec<EvalSet>,
}
1097
1098#[pymethods]
1099impl EvalResultSet {
1100    pub fn record(&self, id: &str) -> Option<EvalSet> {
1101        self.records.iter().find(|r| r.record_uid() == id).cloned()
1102    }
1103    pub fn __str__(&self) -> String {
1104        // serialize the struct to a string
1105        PyHelperFuncs::__str__(self)
1106    }
1107}
1108
// Tests require the `mock` feature for prompt fixtures (potato_head::mock).
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    /// Default placeholders resolve to "__missing__", and `update_config_args`
    /// applies only the provided fields while leaving the rest untouched.
    #[test]
    fn test_genai_drift_config() {
        let mut drift_config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();
        // MISSING sentinel expands to the "__missing__" placeholder string.
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // Partial update: only name and alert config are supplied.
        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    /// Building a profile from two judge tasks yields a serializable profile
    /// with both tasks registered and the crate's version recorded.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        // Judge passing when score >= 4.
        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        // Judge passing when score <= 2.
        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        // model_dump_json must produce valid JSON round-trippable via serde.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks().len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}