scouter_drift/llm/
evaluator.rs

// Module for evaluating "pending" LLM drift records against an LLM drift profile's workflow
use crate::error::DriftError;
use potato_head::{
    calculate_weighted_score, ResponseLogProbs, Score, StructuredOutput, TaskStatus, Workflow,
};
use scouter_types::llm::LLMDriftProfile;
use scouter_types::{LLMMetricRecord, LLMRecord};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, error, instrument, warn};

pub type LLMEvalResult = (Vec<LLMMetricRecord>, HashMap<String, Score>, Option<i32>); // (metric records, scores keyed by task ID, workflow duration)

pub struct LLMEvaluator {}

impl LLMEvaluator {
    pub fn new() -> Self {
        LLMEvaluator {}
    }

    /// Evaluates the final (last-step) tasks of a completed workflow and converts
    /// their results into drift metric records.
    ///
    /// # Returns
    /// A tuple of metric records, a map from task ID to its validated `Score`,
    /// and the total workflow duration.
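    ///
    /// A minimal usage sketch (assumes a finished workflow handle
    /// (`Arc<RwLock<Workflow>>`) and a loaded profile are already in scope; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let (records, scores, duration) =
    ///     LLMEvaluator::get_final_task_results(workflow, &profile, "record-uid")?;
    /// ```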
    pub fn get_final_task_results(
        workflow: Arc<RwLock<Workflow>>,
        profile: &LLMDriftProfile,
        record_uid: &str,
    ) -> Result<LLMEvalResult, DriftError> {
        let workflow = workflow.read().unwrap();
        let task_list = &workflow.task_list;
        let execution_plan = workflow.execution_plan()?;

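        // The execution plan maps each step index to the task IDs scheduled at that
        // step; the tasks in the highest step carry the final evaluation scores.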
        let max_step = execution_plan.keys().max().copied().unwrap_or(0);

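        // Treat a max step of zero as an empty plan: nothing to score for this record.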
        if max_step == 0 {
            return Ok((Vec::new(), HashMap::new(), None));
        }

        let mut final_results = Vec::new();
        let mut score_map: HashMap<String, Score> = HashMap::new();
        let workflow_duration = workflow.total_duration();

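        // Walk the tasks scheduled in the final step and turn each completed result
        // into a metric record.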
        if let Some(final_task_ids) = execution_plan.get(&max_step) {
            for task_id in final_task_ids {
                // Get the task from the task list
                let Some(task) = task_list.get_task(task_id) else {
                    continue;
                };

                // Lock the task for reading
                let task_guard = task.read().unwrap();

                // Only process completed tasks with a result
                let (TaskStatus::Completed, Some(result)) =
                    (&task_guard.status, &task_guard.result)
                else {
                    continue;
                };

                let task_id = task_guard.id.clone();

                // Content should be returned as a JSON string
                let content = match result.content() {
                    Some(c) => c,
                    None => {
                        warn!("Task result content is empty for task ID: {}", task_id);
                        continue;
                    }
                };

                // Validate the content as a Score object
                let score = Score::model_validate_json_str(&content).inspect_err(|e| {
                    error!("Failed to validate score: {:?}", e);
                })?;

                // Check for log_probs in the result
                let log_probs: Vec<ResponseLogProbs> = result.log_probs();

                // Calculate weighted score if log_probs is not empty
                // Default to score if no log_probs are present or if calculation returns None
                let value = if !log_probs.is_empty() {
                    match calculate_weighted_score(&log_probs)? {
                        Some(weighted) => weighted,
                        None => score.score as f64,
                    }
                } else {
                    score.score as f64
                };

                // Create the LLMMetricRecord
                let record = LLMMetricRecord {
                    record_uid: record_uid.to_string(),
                    created_at: chrono::Utc::now(),
                    space: profile.config.space.clone(),
                    name: profile.config.name.clone(),
                    version: profile.config.version.clone(),
                    metric: task_id.clone(),
                    value,
                };

                // Add the score to the score map
                score_map.insert(task_id, score);
                final_results.push(record);
            }
        }

        Ok((final_results, score_map, Some(workflow_duration)))
    }

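    /// Runs the profile's evaluation workflow against a single drift record and
    /// collects the final task results as metric records.
    ///
    /// A minimal usage sketch (assumes a loaded `LLMRecord` and `LLMDriftProfile`
    /// are already in scope; not compiled as a doctest):
    ///
    /// ```ignore
    /// let (metrics, scores, duration) =
    ///     LLMEvaluator::process_drift_record(&record, &profile).await?;
    /// ```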
    #[instrument(skip_all)]
    pub async fn process_drift_record(
        record: &LLMRecord,
        profile: &LLMDriftProfile,
    ) -> Result<LLMEvalResult, DriftError> {
        debug!("Processing workflow");

        let workflow_result = profile.workflow.run(Some(record.context.clone())).await?;
        Self::get_final_task_results(workflow_result, profile, &record.uid)
    }
}

impl Default for LLMEvaluator {
    fn default() -> Self {
        Self::new()
    }
}