// File: scouter_types/genai/profile.rs

1use crate::error::{ProfileError, TypeError};
2use crate::genai::alert::GenAIAlertConfig;
3use crate::genai::eval::{AssertionTask, EvaluationTask, LLMJudgeTask};
4use crate::genai::traits::{separate_tasks, ProfileExt, TaskAccessor};
5use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
6use crate::genai::TraceAssertionTask;
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json};
9use crate::{
10    scouter_version, GenAIEvalTaskResult, GenAIEvalWorkflowResult, WorkflowResultTableEntry,
11};
12use crate::{
13    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
14    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
15};
16use crate::{ProfileRequest, TaskResultTableEntry};
17use chrono::{DateTime, Utc};
18use core::fmt::Debug;
19use potato_head::prompt_types::Prompt;
20use potato_head::Agent;
21use potato_head::Workflow;
22use potato_head::{create_uuid7, Task};
23use pyo3::prelude::*;
24use pyo3::types::{PyDict, PyList};
25use scouter_semver::VersionType;
26use scouter_state::app_state;
27use serde::{Deserialize, Serialize};
28use serde_json::Value;
29use std::collections::hash_map::Entry;
30use std::collections::BTreeSet;
31use std::collections::HashMap;
32use std::path::PathBuf;
33use std::sync::Arc;
34
35use tracing::instrument;
36
/// Serde default for `sample_ratio`: evaluate every record (ratio of 1.0).
fn default_sample_ratio() -> f64 {
    1.0_f64
}
40
41fn default_space() -> String {
42    MISSING.to_string()
43}
44
45fn default_name() -> String {
46    MISSING.to_string()
47}
48
49fn default_version() -> String {
50    DEFAULT_VERSION.to_string()
51}
52
53fn default_uid() -> String {
54    create_uuid7()
55}
56
57fn default_drift_type() -> DriftType {
58    DriftType::GenAI
59}
60
61fn default_alert_config() -> GenAIAlertConfig {
62    GenAIAlertConfig::default()
63}
64
/// Configuration for a GenAI evaluation profile.
///
/// Carries identity fields (space/name/version/uid), the sampling ratio,
/// and alerting settings. Every field has a serde default so a partial
/// JSON document deserializes cleanly.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalConfig {
    /// Fraction of records to evaluate; constructor clamps to [0.0, 1.0].
    #[pyo3(get, set)]
    #[serde(default = "default_sample_ratio")]
    pub sample_ratio: f64,

    /// Space (namespace) the profile belongs to; defaults to `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_space")]
    pub space: String,

    /// Profile name; defaults to `MISSING`.
    #[pyo3(get, set)]
    #[serde(default = "default_name")]
    pub name: String,

    /// Profile version; defaults to `DEFAULT_VERSION`.
    #[pyo3(get, set)]
    #[serde(default = "default_version")]
    pub version: String,

    /// Unique profile id; defaults to a fresh UUIDv7.
    #[pyo3(get, set)]
    #[serde(default = "default_uid")]
    pub uid: String,

    /// Alert dispatch configuration.
    #[pyo3(get, set)]
    #[serde(default = "default_alert_config")]
    pub alert_config: GenAIAlertConfig,

    /// Drift type; always `DriftType::GenAI` for this config.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
96
97impl ConfigExt for GenAIEvalConfig {
98    fn space(&self) -> &str {
99        &self.space
100    }
101
102    fn name(&self) -> &str {
103        &self.name
104    }
105
106    fn version(&self) -> &str {
107        &self.version
108    }
109    fn uid(&self) -> &str {
110        &self.uid
111    }
112}
113
114impl DispatchDriftConfig for GenAIEvalConfig {
115    fn get_drift_args(&self) -> DriftArgs {
116        DriftArgs {
117            name: self.name.clone(),
118            space: self.space.clone(),
119            version: self.version.clone(),
120            dispatch_config: self.alert_config.dispatch_config.clone(),
121        }
122    }
123}
124
125#[pymethods]
126#[allow(clippy::too_many_arguments)]
127impl GenAIEvalConfig {
128    #[new]
129    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_ratio=1.0, alert_config=GenAIAlertConfig::default(), config_path=None))]
130    pub fn new(
131        space: &str,
132        name: &str,
133        version: &str,
134        sample_ratio: f64,
135        alert_config: GenAIAlertConfig,
136        config_path: Option<PathBuf>,
137    ) -> Result<Self, ProfileError> {
138        if let Some(config_path) = config_path {
139            let config = GenAIEvalConfig::load_from_json_file(config_path)?;
140            return Ok(config);
141        }
142
143        Ok(Self {
144            sample_ratio: sample_ratio.clamp(0.0, 1.0),
145            space: space.to_string(),
146            name: name.to_string(),
147            uid: create_uuid7(),
148            version: version.to_string(),
149            alert_config,
150            drift_type: DriftType::GenAI,
151        })
152    }
153
154    #[staticmethod]
155    pub fn load_from_json_file(path: PathBuf) -> Result<GenAIEvalConfig, ProfileError> {
156        // deserialize the string to a struct
157
158        let file = std::fs::read_to_string(&path)?;
159
160        Ok(serde_json::from_str(&file)?)
161    }
162
163    pub fn __str__(&self) -> String {
164        // serialize the struct to a string
165        PyHelperFuncs::__str__(self)
166    }
167
168    pub fn model_dump_json(&self) -> String {
169        // serialize the struct to a string
170        PyHelperFuncs::__json__(self)
171    }
172
173    #[allow(clippy::too_many_arguments)]
174    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
175    pub fn update_config_args(
176        &mut self,
177        space: Option<String>,
178        name: Option<String>,
179        version: Option<String>,
180        uid: Option<String>,
181        alert_config: Option<GenAIAlertConfig>,
182    ) -> Result<(), TypeError> {
183        if name.is_some() {
184            self.name = name.ok_or(TypeError::MissingNameError)?;
185        }
186
187        if space.is_some() {
188            self.space = space.ok_or(TypeError::MissingSpaceError)?;
189        }
190
191        if version.is_some() {
192            self.version = version.ok_or(TypeError::MissingVersionError)?;
193        }
194
195        if alert_config.is_some() {
196            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
197        }
198
199        if uid.is_some() {
200            self.uid = uid.ok_or(TypeError::MissingUidError)?;
201        }
202
203        Ok(())
204    }
205}
206
207impl Default for GenAIEvalConfig {
208    fn default() -> Self {
209        Self {
210            sample_ratio: 1.0,
211            space: "default".to_string(),
212            name: "default_genai_profile".to_string(),
213            version: DEFAULT_VERSION.to_string(),
214            uid: create_uuid7(),
215            alert_config: GenAIAlertConfig::default(),
216            drift_type: DriftType::GenAI,
217        }
218    }
219}
220
221/// Validates that a prompt contains at least one required parameter.
222///
223/// LLM evaluation prompts must have either "input" or "output" parameters
224/// to access the data being evaluated.
225///
226/// # Arguments
227/// * `prompt` - The prompt to validate
228/// * `id` - Identifier for error reporting
229///
230/// # Returns
231/// * `Ok(())` if validation passes
232/// * `Err(ProfileError::MissingPromptParametersError)` if no required parameters found
233///
234/// # Errors
235/// Returns an error if the prompt lacks both "input" and "output" parameters.
236fn validate_prompt_parameters(prompt: &Prompt, id: &str) -> Result<(), ProfileError> {
237    let has_at_least_one_param = !prompt.parameters.is_empty();
238
239    if !has_at_least_one_param {
240        return Err(ProfileError::NeedAtLeastOneBoundParameterError(
241            id.to_string(),
242        ));
243    }
244
245    Ok(())
246}
247
248fn get_workflow_task<'a>(
249    workflow: &'a Workflow,
250    task_id: &'a str,
251) -> Result<&'a Arc<std::sync::RwLock<potato_head::Task>>, ProfileError> {
252    workflow
253        .task_list
254        .tasks
255        .get(task_id)
256        .ok_or_else(|| ProfileError::NoTasksFoundError(format!("Task '{task_id}' not found")))
257}
258
259/// Helper function to validate first tasks in workflow execution.
260fn validate_first_tasks(
261    workflow: &Workflow,
262    execution_order: &HashMap<i32, std::collections::HashSet<String>>,
263) -> Result<(), ProfileError> {
264    let first_tasks = execution_order
265        .get(&1)
266        .ok_or_else(|| ProfileError::NoTasksFoundError("No initial tasks found".to_string()))?;
267
268    for task_id in first_tasks {
269        let task = get_workflow_task(workflow, task_id)?;
270        let task_guard = task
271            .read()
272            .map_err(|_| ProfileError::NoTasksFoundError("Failed to read task".to_string()))?;
273
274        validate_prompt_parameters(&task_guard.prompt, &task_guard.id)?;
275    }
276
277    Ok(())
278}
279
/// Validates a workflow before it is attached to a profile.
///
/// Currently checks only that every task in the first execution stage has
/// at least one bound prompt parameter (see `validate_first_tasks`).
/// (The previous doc was truncated and claimed a last-task response-type
/// check that this function does not perform.)
///
/// # Arguments
/// * `workflow` - The workflow to validate
///
/// # Returns
/// * `Ok(())` if validation passes
/// * `Err(ProfileError)` if validation fails
///
/// # Errors
/// Propagates errors from building the execution plan or from
/// `validate_first_tasks`.
fn validate_workflow(workflow: &Workflow) -> Result<(), ProfileError> {
    let execution_order = workflow.execution_plan()?;

    // Validate first tasks have required parameters
    validate_first_tasks(workflow, &execution_order)?;

    Ok(())
}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
305#[pyclass]
306pub struct ExecutionNode {
307    pub id: String,
308    pub stage: usize,
309    pub parents: Vec<String>,
310    pub children: Vec<String>,
311}
312
/// Topologically ordered plan for running a profile's tasks.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[pyclass]
pub struct ExecutionPlan {
    /// Stages of task ids; tasks within one stage have no mutual dependencies.
    #[pyo3(get)]
    pub stages: Vec<Vec<String>>,
    /// Per-task node metadata keyed by task id.
    #[pyo3(get)]
    pub nodes: HashMap<String, ExecutionNode>,
}
321
322fn initialize_node_graphs(
323    tasks: &[impl TaskAccessor],
324    graph: &mut HashMap<String, Vec<String>>,
325    reverse_graph: &mut HashMap<String, Vec<String>>,
326    in_degree: &mut HashMap<String, usize>,
327) {
328    for task in tasks {
329        let task_id = task.id().to_string();
330        graph.entry(task_id.clone()).or_default();
331        reverse_graph.entry(task_id.clone()).or_default();
332        in_degree.entry(task_id).or_insert(0);
333    }
334}
335
336fn build_dependency_edges(
337    tasks: &[impl TaskAccessor],
338    graph: &mut HashMap<String, Vec<String>>,
339    reverse_graph: &mut HashMap<String, Vec<String>>,
340    in_degree: &mut HashMap<String, usize>,
341) {
342    for task in tasks {
343        let task_id = task.id().to_string();
344        for dep in task.depends_on() {
345            graph.entry(dep.clone()).or_default().push(task_id.clone());
346            reverse_graph
347                .entry(task_id.clone())
348                .or_default()
349                .push(dep.clone());
350            *in_degree.entry(task_id.clone()).or_insert(0) += 1;
351        }
352    }
353}
354
/// A GenAI evaluation profile: configuration plus the assertion / LLM-judge /
/// trace-assertion tasks to run, and the optional workflow built from the
/// judge tasks.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GenAIEvalProfile {
    /// Identity, sampling and alerting configuration.
    #[pyo3(get)]
    pub config: GenAIEvalConfig,

    /// All evaluation tasks, separated by kind.
    pub tasks: AssertionTasks,

    /// Version of scouter that produced this profile.
    #[pyo3(get)]
    pub scouter_version: String,

    /// Workflow built from LLM judge tasks; `None` when there are no judges.
    pub workflow: Option<Workflow>,

    /// All task ids across every task kind.
    pub task_ids: BTreeSet<String>,

    /// Optional human-friendly alias for the profile.
    pub alias: Option<String>,
}
372
#[pymethods]
impl GenAIEvalProfile {
    #[new]
    #[pyo3(signature = (tasks, config=None, alias=None))]
    /// Create a new GenAIEvalProfile.
    /// GenAI evaluations are run asynchronously on the scouter server.
    /// # Arguments
    /// * `tasks` - PyList - List of AssertionTask, LLMJudgeTask or ConditionalTask
    /// * `config` - Option<GenAIEvalConfig> - Configuration for the GenAI drift
    ///   profile; defaults to `GenAIEvalConfig::default()` when omitted
    /// * `alias` - Option<String> - Optional alias for the profile
    /// # Returns
    /// * `Result<Self, ProfileError>` - The GenAIEvalProfile
    /// # Errors
    /// * `ProfileError` - If task extraction fails, the task list is empty,
    ///   or the workflow built from the LLM judge tasks fails validation.
    #[instrument(skip_all)]
    pub fn new_py(
        tasks: &Bound<'_, PyList>,
        config: Option<GenAIEvalConfig>,
        alias: Option<String>,
    ) -> Result<Self, ProfileError> {
        let tasks = extract_assertion_tasks_from_pylist(tasks)?;

        // Workflow construction is async; block on the shared app runtime.
        let (workflow, task_ids) =
            app_state().block_on(async { Self::build_profile(&tasks).await })?;

        Ok(Self {
            config: config.unwrap_or_default(),
            tasks,
            scouter_version: scouter_version(),
            workflow,
            task_ids,
            alias,
        })
    }

    /// Human-readable representation (Python `__str__`).
    pub fn __str__(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__str__(self)
    }

    /// Serialize the profile to a JSON string.
    pub fn model_dump_json(&self) -> String {
        // serialize the struct to a string
        PyHelperFuncs::__json__(self)
    }

    /// Serialize the profile into a Python dictionary.
    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
        let json_value = serde_json::to_value(self)?;

        // Create a new Python dictionary
        let dict = PyDict::new(py);

        // Convert JSON to Python dict
        json_to_pyobject(py, &json_value, &dict)?;

        // Return the Python dictionary
        Ok(dict.into())
    }

    /// Drift type from the underlying config.
    #[getter]
    pub fn drift_type(&self) -> DriftType {
        self.config.drift_type.clone()
    }

    /// Cloned list of plain assertion tasks.
    #[getter]
    pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
        self.tasks.assertion.clone()
    }

    /// Cloned list of LLM judge tasks.
    #[getter]
    pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
        self.tasks.judge.clone()
    }

    /// Cloned list of trace assertion tasks.
    #[getter]
    pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
        self.tasks.trace.clone()
    }

    /// Optional alias for the profile.
    #[getter]
    pub fn alias(&self) -> Option<String> {
        self.alias.clone()
    }

    /// Unique profile id (lives on the config).
    #[getter]
    pub fn uid(&self) -> String {
        self.config.uid.clone()
    }

    #[setter]
    pub fn set_uid(&mut self, uid: String) {
        self.config.uid = uid;
    }

    /// Save the profile as JSON; uses the default file name when `path` is None.
    #[pyo3(signature = (path=None))]
    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
        Ok(PyHelperFuncs::save_to_json(
            self,
            path,
            FileName::GenAIEvalProfile.to_str(),
        )?)
    }

    /// Rebuild a profile from a Python dictionary.
    /// NOTE(review): panics on malformed input (`unwrap`/`expect`).
    #[staticmethod]
    pub fn model_validate(data: &Bound<'_, PyDict>) -> GenAIEvalProfile {
        let json_value = pyobject_to_json(data).unwrap();

        let string = serde_json::to_string(&json_value).unwrap();
        serde_json::from_str(&string).expect("Failed to load drift profile")
    }

    /// Rebuild a profile from a JSON string.
    /// NOTE(review): panics on malformed input (`expect`).
    #[staticmethod]
    pub fn model_validate_json(json_string: String) -> GenAIEvalProfile {
        // deserialize the string to a struct
        serde_json::from_str(&json_string).expect("Failed to load prompt drift profile")
    }

    /// Load a profile from a JSON file on disk.
    #[staticmethod]
    pub fn from_file(path: PathBuf) -> Result<GenAIEvalProfile, ProfileError> {
        let file = std::fs::read_to_string(&path)?;

        Ok(serde_json::from_str(&file)?)
    }

    /// Update any subset of the config's fields in place; `None` arguments
    /// are left unchanged. Delegates to `GenAIEvalConfig::update_config_args`.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
    pub fn update_config_args(
        &mut self,
        space: Option<String>,
        name: Option<String>,
        version: Option<String>,
        uid: Option<String>,
        alert_config: Option<GenAIAlertConfig>,
    ) -> Result<(), TypeError> {
        self.config
            .update_config_args(space, name, version, uid, alert_config)
    }

    /// Create a profile request from the profile
    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
        // A still-default version means "let the server pick the next version".
        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
            None
        } else {
            Some(self.config.version.clone())
        };

        Ok(ProfileRequest {
            space: self.config.space.clone(),
            profile: self.model_dump_json(),
            drift_type: self.config.drift_type.clone(),
            version_request: Some(VersionRequest {
                version,
                version_type: VersionType::Minor,
                pre_tag: None,
                build_tag: None,
            }),
            active: false,
            deactivate_others: false,
        })
    }

    /// True when the profile has at least one LLM judge task.
    pub fn has_llm_tasks(&self) -> bool {
        !self.tasks.judge.is_empty()
    }

    /// Check if this profile has assertions
    pub fn has_assertions(&self) -> bool {
        !self.tasks.assertion.is_empty()
    }

    /// True when the profile has at least one trace assertion task.
    pub fn has_trace_assertions(&self) -> bool {
        !self.tasks.trace.is_empty()
    }

    /// Get execution order for all tasks (assertions + LLM judges + trace assertions)
    ///
    /// Runs Kahn's algorithm (topological sort by in-degree) over the
    /// combined dependency graph of all three task kinds.
    ///
    /// # Errors
    /// * `ProfileError::CircularDependency` if some tasks never reach an
    ///   in-degree of zero (a dependency cycle).
    pub fn get_execution_plan(&self) -> Result<ExecutionPlan, ProfileError> {
        // task id -> dependents (forward edges)
        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
        // task id -> dependencies (parents)
        let mut reverse_graph: HashMap<String, Vec<String>> = HashMap::new();
        // task id -> number of unmet dependencies
        let mut in_degree: HashMap<String, usize> = HashMap::new();

        initialize_node_graphs(
            &self.tasks.assertion,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        initialize_node_graphs(
            &self.tasks.judge,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        initialize_node_graphs(
            &self.tasks.trace,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        build_dependency_edges(
            &self.tasks.assertion,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        build_dependency_edges(
            &self.tasks.judge,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );

        build_dependency_edges(
            &self.tasks.trace,
            &mut graph,
            &mut reverse_graph,
            &mut in_degree,
        );
        let mut stages = Vec::new();
        let mut nodes: HashMap<String, ExecutionNode> = HashMap::new();
        // Start with every task that has no dependencies.
        let mut current_level: Vec<String> = in_degree
            .iter()
            .filter(|(_, &degree)| degree == 0)
            .map(|(id, _)| id.clone())
            .collect();

        let mut stage_idx = 0;

        while !current_level.is_empty() {
            stages.push(current_level.clone());

            for task_id in &current_level {
                nodes.insert(
                    task_id.clone(),
                    ExecutionNode {
                        id: task_id.clone(),
                        stage: stage_idx,
                        parents: reverse_graph.get(task_id).cloned().unwrap_or_default(),
                        children: graph.get(task_id).cloned().unwrap_or_default(),
                    },
                );
            }

            // Release the dependents of this stage; any that reach in-degree
            // zero form the next stage.
            let mut next_level = Vec::new();
            for task_id in &current_level {
                if let Some(dependents) = graph.get(task_id) {
                    for dependent in dependents {
                        if let Some(degree) = in_degree.get_mut(dependent) {
                            *degree -= 1;
                            if *degree == 0 {
                                next_level.push(dependent.clone());
                            }
                        }
                    }
                }
            }

            current_level = next_level;
            stage_idx += 1;
        }

        // Any task not placed in a stage is part of a cycle.
        let total_tasks =
            self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len();
        let processed_tasks: usize = stages.iter().map(|level| level.len()).sum();

        if processed_tasks != total_tasks {
            return Err(ProfileError::CircularDependency);
        }

        Ok(ExecutionPlan { stages, nodes })
    }

    /// Pretty-print the execution plan to stdout with per-stage trees,
    /// dependency lists and conditional-branch markers.
    pub fn print_execution_plan(&self) -> Result<(), ProfileError> {
        use owo_colors::OwoColorize;

        let plan = self.get_execution_plan()?;

        println!("\n{}", "Evaluation Execution Plan".bold().green());
        println!("{}", "═".repeat(70).green());

        let mut conditional_count = 0;

        for (level_idx, level) in plan.stages.iter().enumerate() {
            let stage_label = format!("Stage {}", level_idx + 1);
            println!("\n{}", stage_label.bold().cyan());

            for (task_idx, task_id) in level.iter().enumerate() {
                let is_last = task_idx == level.len() - 1;
                let prefix = if is_last { "└─" } else { "├─" };

                let task = self.get_task_by_id(task_id).ok_or_else(|| {
                    ProfileError::NoTasksFoundError(format!("Task '{}' not found", task_id))
                })?;

                // A task is conditional if any kind-specific lookup says so.
                let is_conditional = if let Some(assertion) = self.get_assertion_by_id(task_id) {
                    assertion.condition
                } else if let Some(judge) = self.get_llm_judge_by_id(task_id) {
                    judge.condition
                } else if let Some(trace) = self.get_trace_assertion_by_id(task_id) {
                    trace.condition
                } else {
                    false
                };

                if is_conditional {
                    conditional_count += 1;
                }

                // Label and color depend on which task list contains the id.
                let (task_type, color_fn): (&str, fn(&str) -> String) =
                    if self.tasks.assertion.iter().any(|t| &t.id == task_id) {
                        ("Assertion", |s: &str| s.yellow().to_string())
                    } else if self.tasks.trace.iter().any(|t| &t.id == task_id) {
                        ("Trace Assertion", |s: &str| s.bright_blue().to_string())
                    } else {
                        ("LLM Judge", |s: &str| s.purple().to_string())
                    };

                let conditional_marker = if is_conditional {
                    " [CONDITIONAL]".bright_red().to_string()
                } else {
                    String::new()
                };

                println!(
                    "{} {} ({}){}",
                    prefix,
                    task_id.bold(),
                    color_fn(task_type),
                    conditional_marker
                );

                let deps = task.depends_on();
                if !deps.is_empty() {
                    let dep_prefix = if is_last { "  " } else { "│ " };

                    // Split dependencies into conditional gates vs normal ones.
                    let (conditional_deps, normal_deps): (Vec<_>, Vec<_>) =
                        deps.iter().partition(|dep_id| {
                            self.get_assertion_by_id(dep_id)
                                .map(|t| t.condition)
                                .or_else(|| self.get_llm_judge_by_id(dep_id).map(|t| t.condition))
                                .or_else(|| {
                                    self.get_trace_assertion_by_id(dep_id).map(|t| t.condition)
                                })
                                .unwrap_or(false)
                        });

                    if !normal_deps.is_empty() {
                        println!(
                            "{}   {} {}",
                            dep_prefix,
                            "depends on:".dimmed(),
                            normal_deps
                                .iter()
                                .map(|s| s.as_str())
                                .collect::<Vec<_>>()
                                .join(", ")
                                .dimmed()
                        );
                    }

                    if !conditional_deps.is_empty() {
                        println!(
                            "{}   {} {}",
                            dep_prefix,
                            "▶ conditional gate:".bright_red().dimmed(),
                            conditional_deps
                                .iter()
                                .map(|d| format!("{} must pass", d))
                                .collect::<Vec<_>>()
                                .join(", ")
                                .red()
                                .dimmed()
                        );
                    }
                }

                if is_conditional {
                    let continuation = if is_last { "  " } else { "│ " };
                    println!(
                        "{}   {} {}",
                        continuation,
                        "▶".bright_red(),
                        "creates conditional branch".bright_red().dimmed()
                    );
                }
            }
        }

        println!("\n{}", "═".repeat(70).green());
        println!(
            "{}: {} tasks across {} stages",
            "Summary".bold(),
            self.tasks.assertion.len() + self.tasks.judge.len() + self.tasks.trace.len(),
            plan.stages.len()
        );

        if conditional_count > 0 {
            println!(
                "{}: {} conditional tasks that create execution branches",
                "Branches".bold().bright_red(),
                conditional_count
            );
        }

        println!();

        Ok(())
    }
}
782
783impl Default for GenAIEvalProfile {
784    fn default() -> Self {
785        Self {
786            config: GenAIEvalConfig::default(),
787            tasks: AssertionTasks {
788                assertion: Vec::new(),
789                judge: Vec::new(),
790                trace: Vec::new(),
791            },
792            scouter_version: scouter_version(),
793            workflow: None,
794            task_ids: BTreeSet::new(),
795            alias: None,
796        }
797    }
798}
799
impl GenAIEvalProfile {
    /// Helper method to build profile from given tasks.
    /// Blocks on the shared app runtime to run the async builder.
    pub fn build_from_parts(
        config: GenAIEvalConfig,
        tasks: AssertionTasks,
        alias: Option<String>,
    ) -> Result<GenAIEvalProfile, ProfileError> {
        let (workflow, task_ids) =
            app_state().block_on(async { GenAIEvalProfile::build_profile(&tasks).await })?;

        Ok(GenAIEvalProfile {
            config,
            tasks,
            scouter_version: scouter_version(),
            workflow,
            task_ids,
            alias,
        })
    }

    /// Async constructor: separates the flat task list by kind, then builds
    /// the (optional) judge workflow and the full task-id set.
    #[instrument(skip_all)]
    pub async fn new(
        config: GenAIEvalConfig,
        tasks: Vec<EvaluationTask>,
    ) -> Result<Self, ProfileError> {
        let tasks = separate_tasks(tasks);
        let (workflow, task_ids) = Self::build_profile(&tasks).await?;

        Ok(Self {
            config,
            tasks,
            scouter_version: scouter_version(),
            workflow,
            task_ids,
            alias: None,
        })
    }

    /// Build the workflow (only when LLM judge tasks exist) and collect all
    /// task ids.
    ///
    /// # Errors
    /// * `ProfileError::EmptyTaskList` if no tasks of any kind are present.
    /// * Workflow construction / validation errors for judge tasks.
    async fn build_profile(
        tasks: &AssertionTasks,
    ) -> Result<(Option<Workflow>, BTreeSet<String>), ProfileError> {
        if tasks.assertion.is_empty() && tasks.judge.is_empty() && tasks.trace.is_empty() {
            return Err(ProfileError::EmptyTaskList);
        }

        // Only judge tasks require a workflow; pure assertion profiles skip it.
        let workflow = if !tasks.judge.is_empty() {
            let workflow = Self::build_workflow_from_judges(tasks).await?;
            validate_workflow(&workflow)?;
            Some(workflow)
        } else {
            None
        };

        // Validate LLM judge prompts individually
        for judge in &tasks.judge {
            validate_prompt_parameters(&judge.prompt, &judge.id)?;
        }

        // Collect all task IDs
        let task_ids = tasks.collect_all_task_ids()?;

        Ok((workflow, task_ids))
    }

    /// Return the agent for `provider`, creating it (and registering it on
    /// the workflow) on first use. One agent per provider is cached.
    async fn get_or_create_agent(
        agents: &mut HashMap<potato_head::Provider, Agent>,
        workflow: &mut Workflow,
        provider: &potato_head::Provider,
    ) -> Result<Agent, ProfileError> {
        match agents.entry(provider.clone()) {
            Entry::Occupied(entry) => Ok(entry.get().clone()),
            Entry::Vacant(entry) => {
                let agent = Agent::new(provider.clone(), None).await?;
                workflow.add_agent(&agent);
                Ok(entry.insert(agent).clone())
            }
        }
    }

    /// Drop dependencies that point at non-judge tasks (the workflow only
    /// contains judge tasks); returns `None` when nothing remains.
    fn filter_judge_dependencies(
        depends_on: &[String],
        non_judge_task_ids: &BTreeSet<String>,
    ) -> Option<Vec<String>> {
        let filtered: Vec<String> = depends_on
            .iter()
            .filter(|dep_id| !non_judge_task_ids.contains(*dep_id))
            .cloned()
            .collect();

        if filtered.is_empty() {
            None
        } else {
            Some(filtered)
        }
    }

    /// Build workflow from LLM judge tasks
    /// # Arguments
    /// * `tasks` - Reference to AssertionTasks
    /// # Returns
    /// * `Result<Workflow, ProfileError>` - The constructed workflow
    pub async fn build_workflow_from_judges(
        tasks: &AssertionTasks,
    ) -> Result<Workflow, ProfileError> {
        let mut workflow = Workflow::new(&format!("eval_workflow_{}", create_uuid7()));
        let mut agents = HashMap::new();
        let non_judge_task_ids = tasks.collect_non_judge_task_ids();

        for judge in &tasks.judge {
            let agent =
                Self::get_or_create_agent(&mut agents, &mut workflow, &judge.prompt.provider)
                    .await?;

            // Judge-to-judge dependencies only; assertion/trace deps are
            // resolved outside the workflow.
            let task_deps = Self::filter_judge_dependencies(&judge.depends_on, &non_judge_task_ids);

            let task = Task::new(
                &agent.id,
                judge.prompt.clone(),
                &judge.id,
                task_deps,
                judge.max_retries,
            )?;

            workflow.add_task(task)?;
        }

        Ok(workflow)
    }
}
929
930impl ProfileExt for GenAIEvalProfile {
931    #[inline]
932    fn id(&self) -> &str {
933        &self.config.uid
934    }
935
936    fn get_task_by_id(&self, id: &str) -> Option<&dyn TaskAccessor> {
937        if let Some(assertion) = self.tasks.assertion.iter().find(|t| t.id() == id) {
938            return Some(assertion);
939        }
940
941        if let Some(judge) = self.tasks.judge.iter().find(|t| t.id() == id) {
942            return Some(judge);
943        }
944
945        if let Some(trace) = self.tasks.trace.iter().find(|t| t.id() == id) {
946            return Some(trace);
947        }
948
949        None
950    }
951
952    #[inline]
953    /// Get assertion task by ID, first checking AssertionTasks, then ConditionalTasks
954    fn get_assertion_by_id(&self, id: &str) -> Option<&AssertionTask> {
955        self.tasks.assertion.iter().find(|t| t.id() == id)
956    }
957
958    #[inline]
959    fn get_llm_judge_by_id(&self, id: &str) -> Option<&LLMJudgeTask> {
960        self.tasks.judge.iter().find(|t| t.id() == id)
961    }
962
963    #[inline]
964    fn get_trace_assertion_by_id(&self, id: &str) -> Option<&TraceAssertionTask> {
965        self.tasks.trace.iter().find(|t| t.id() == id)
966    }
967
968    #[inline]
969    fn has_llm_tasks(&self) -> bool {
970        !self.tasks.judge.is_empty()
971    }
972
973    #[inline]
974    fn has_trace_assertions(&self) -> bool {
975        !self.tasks.trace.is_empty()
976    }
977}
978
979impl ProfileBaseArgs for GenAIEvalProfile {
980    type Config = GenAIEvalConfig;
981
982    fn config(&self) -> &Self::Config {
983        &self.config
984    }
985    fn get_base_args(&self) -> ProfileArgs {
986        ProfileArgs {
987            name: self.config.name.clone(),
988            space: self.config.space.clone(),
989            version: Some(self.config.version.clone()),
990            schedule: self.config.alert_config.schedule.clone(),
991            scouter_version: self.scouter_version.clone(),
992            drift_type: self.config.drift_type.clone(),
993        }
994    }
995
996    fn to_value(&self) -> Value {
997        serde_json::to_value(self).unwrap()
998    }
999}
1000
/// Result of evaluating a single record: per-task results plus the
/// workflow-level summary.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalSet {
    /// Individual task results for this record (exposed to Python).
    #[pyo3(get)]
    pub records: Vec<GenAIEvalTaskResult>,
    /// Workflow-level summary (counts, pass rate, timing); surfaced to
    /// Python via the getters in the `#[pymethods]` block rather than
    /// directly.
    pub inner: GenAIEvalWorkflowResult,
}
1008
1009impl GenAIEvalSet {
1010    pub fn build_task_entries(&mut self, record_id: &str) -> Vec<TaskResultTableEntry> {
1011        // sort records by stage, then by task_id
1012
1013        self.records
1014            .sort_by(|a, b| a.stage.cmp(&b.stage).then(a.task_id.cmp(&b.task_id)));
1015
1016        self.records
1017            .iter()
1018            .map(|record| record.to_table_entry(record_id))
1019            .collect()
1020    }
1021
1022    pub fn build_workflow_entries(&self) -> Vec<WorkflowResultTableEntry> {
1023        vec![self.inner.to_table_entry()]
1024    }
1025
1026    pub fn new(records: Vec<GenAIEvalTaskResult>, inner: GenAIEvalWorkflowResult) -> Self {
1027        Self { records, inner }
1028    }
1029
1030    pub fn empty() -> Self {
1031        Self {
1032            records: Vec::new(),
1033            inner: GenAIEvalWorkflowResult {
1034                created_at: Utc::now(),
1035                record_uid: String::new(),
1036                entity_id: 0,
1037                total_tasks: 0,
1038                passed_tasks: 0,
1039                failed_tasks: 0,
1040                pass_rate: 0.0,
1041                duration_ms: 0,
1042                entity_uid: String::new(),
1043                execution_plan: ExecutionPlan::default(),
1044                id: 0,
1045            },
1046        }
1047    }
1048}
1049
1050#[pymethods]
1051impl GenAIEvalSet {
1052    #[getter]
1053    pub fn created_at(&self) -> DateTime<Utc> {
1054        self.inner.created_at
1055    }
1056
1057    #[getter]
1058    pub fn record_uid(&self) -> String {
1059        self.inner.record_uid.clone()
1060    }
1061
1062    #[getter]
1063    pub fn total_tasks(&self) -> i32 {
1064        self.inner.total_tasks
1065    }
1066
1067    #[getter]
1068    pub fn passed_tasks(&self) -> i32 {
1069        self.inner.passed_tasks
1070    }
1071
1072    #[getter]
1073    pub fn failed_tasks(&self) -> i32 {
1074        self.inner.failed_tasks
1075    }
1076
1077    #[getter]
1078    pub fn pass_rate(&self) -> f64 {
1079        self.inner.pass_rate
1080    }
1081
1082    #[getter]
1083    pub fn duration_ms(&self) -> i64 {
1084        self.inner.duration_ms
1085    }
1086
1087    pub fn __str__(&self) -> String {
1088        // serialize the struct to a string
1089        PyHelperFuncs::__str__(self)
1090    }
1091}
1092
/// A collection of per-record evaluation sets, exposed to Python for
/// lookup by record uid.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenAIEvalResultSet {
    /// One [`GenAIEvalSet`] per evaluated record.
    #[pyo3(get)]
    pub records: Vec<GenAIEvalSet>,
}
1099
1100#[pymethods]
1101impl GenAIEvalResultSet {
1102    pub fn record(&self, id: &str) -> Option<GenAIEvalSet> {
1103        self.records.iter().find(|r| r.record_uid() == id).cloned()
1104    }
1105    pub fn __str__(&self) -> String {
1106        // serialize the struct to a string
1107        PyHelperFuncs::__str__(self)
1108    }
1109}
1110
// Tests require the `mock` feature for the potato_head prompt factory.
#[cfg(test)]
#[cfg(feature = "mock")]
mod tests {

    use super::*;
    use crate::genai::{ComparisonOperator, EvaluationTasks};
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    use potato_head::mock::create_score_prompt;

    // Config defaults: MISSING space/name pass through as the sentinel
    // string, and update_config_args patches only the provided fields.
    #[test]
    fn test_genai_drift_config() {
        let mut drift_config = GenAIEvalConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            1.0,
            GenAIAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // Only name and alert_config are supplied; everything else stays.
        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    // Profile built from two judge tasks: serializes to valid JSON,
    // exposes both judges, and stamps the crate version.
    #[tokio::test]
    async fn test_genai_drift_profile_metric() {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let task1 = LLMJudgeTask::new_rs(
            "metric1",
            prompt.clone(),
            Value::Number(4.into()),
            None,
            ComparisonOperator::GreaterThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let task2 = LLMJudgeTask::new_rs(
            "metric2",
            prompt.clone(),
            Value::Number(2.into()),
            None,
            ComparisonOperator::LessThanOrEqual,
            None,
            None,
            None,
            None,
        );

        let tasks = EvaluationTasks::new()
            .add_task(task1)
            .add_task(task2)
            .build();

        let alert_config = GenAIAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            GenAIEvalConfig::new("scouter", "ML", "0.1.0", 1.0, alert_config, None).unwrap();

        let profile = GenAIEvalProfile::new(drift_config, tasks).await.unwrap();

        // Round-trip check: the dumped profile must be parseable JSON.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.llm_judge_tasks().len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
    }
}