scouter_drift/llm/evaluator.rs

use crate::error::DriftError;
use potato_head::ResponseLogProbs;
use potato_head::{calculate_weighted_score, Score, StructuredOutput, TaskStatus, Workflow};
use scouter_types::llm::LLMDriftProfile;
use scouter_types::{LLMMetricRecord, LLMRecord};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use tracing::{debug, error, instrument, warn};
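/// Result of evaluating a workflow run: the metric records to persist, the
/// per-task `Score`s keyed by task id, and the total workflow duration.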
pub type LLMEvalResult = (Vec<LLMMetricRecord>, HashMap<String, Score>, Option<i32>);

pub struct LLMEvaluator {}

impl LLMEvaluator {
    pub fn new() -> Self {
        LLMEvaluator {}
    }

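    /// Collect metric records and per-task scores from the final step of an
    /// executed workflow, along with the total workflow duration.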
    pub fn get_final_task_results(
        workflow: Arc<RwLock<Workflow>>,
        profile: &LLMDriftProfile,
        record_uid: &str,
    ) -> Result<LLMEvalResult, DriftError> {
        let workflow = workflow.read().unwrap();
        let task_list = &workflow.task_list;
        let execution_plan = workflow.execution_plan()?;

        let max_step = execution_plan.keys().max().copied().unwrap_or(0);

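        // An empty execution plan means there is nothing to evaluate.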
        if max_step == 0 {
            return Ok((Vec::new(), HashMap::new(), None));
        }

        let mut final_results = Vec::new();
        let mut score_map: HashMap<String, Score> = HashMap::new();
        let workflow_duration = workflow.total_duration();

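        // Only tasks in the last step of the execution plan produce metrics.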
        if let Some(final_task_ids) = execution_plan.get(&max_step) {
            for task_id in final_task_ids {
                let Some(task) = task_list.get_task(task_id) else {
                    continue;
                };

                let task_guard = task.read().unwrap();

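                // Skip tasks that did not complete or produced no result.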
                let (TaskStatus::Completed, Some(result)) =
                    (&task_guard.status, &task_guard.result)
                else {
                    continue;
                };

                let task_id = task_guard.id.clone();

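                // The task result content is expected to be a JSON-serialized `Score`.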
                let content = match result.content() {
                    Some(c) => c,
                    None => {
                        warn!("Task result content is empty for task ID: {}", task_id);
                        continue;
                    }
                };

                let score = Score::model_validate_json_str(&content).inspect_err(|e| {
                    error!("Failed to validate score: {:?}", e);
                })?;

                let log_probs: Vec<ResponseLogProbs> = result.log_probs();

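                // Prefer a probability-weighted score when log probs are available;
                // otherwise fall back to the raw score value.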
                let value = if !log_probs.is_empty() {
                    match calculate_weighted_score(&log_probs)? {
                        Some(weighted) => weighted,
                        None => score.score as f64,
                    }
                } else {
                    score.score as f64
                };

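                // Each final task becomes one metric record, named after its task id.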
                let record = LLMMetricRecord {
                    record_uid: record_uid.to_string(),
                    created_at: chrono::Utc::now(),
                    space: profile.config.space.clone(),
                    name: profile.config.name.clone(),
                    version: profile.config.version.clone(),
                    metric: task_id.clone(),
                    value,
                };

                score_map.insert(task_id, score);
                final_results.push(record);
            }
        }

        Ok((final_results, score_map, Some(workflow_duration)))
    }

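    /// Run the profile's workflow against an incoming record and evaluate the
    /// results of its final tasks.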
    #[instrument(skip_all)]
    pub async fn process_drift_record(
        record: &LLMRecord,
        profile: &LLMDriftProfile,
    ) -> Result<LLMEvalResult, DriftError> {
        debug!("Processing workflow");

        let workflow_result = profile.workflow.run(Some(record.context.clone())).await?;
        Self::get_final_task_results(workflow_result, profile, &record.uid)
    }
}

impl Default for LLMEvaluator {
    fn default() -> Self {
        Self::new()
    }
}
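
// Illustrative usage (a sketch, not part of this module; assumes a `record: &LLMRecord`
// and a matching `profile: &LLMDriftProfile` were constructed elsewhere):
//
//     let (metrics, scores, duration) =
//         LLMEvaluator::process_drift_record(record, profile).await?;
//     for metric in &metrics {
//         debug!("metric {} = {}", metric.metric, metric.value);
//     }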