Skip to main content

scouter_evaluate/evaluate/
agent.rs

1use crate::error::EvaluationError;
2use crate::tasks::evaluator::{PATH_REGEX, REGEX_FIELD_PARSE_PATTERN};
3use potato_head::{ChatResponse, Provider};
4use scouter_types::genai::AgentAssertion;
5use serde_json::{json, Value};
6use tracing::error;
7
/// Longest `ResponseField` path (in bytes) that `parse_path_segments` will accept.
const MAX_PATH_LEN: usize = 512;
/// Largest number of key/index segments a parsed `ResponseField` path may contain.
const MAX_PATH_SEGMENTS: usize = 32;
10
11/// Builds evaluation context from vendor LLM request/response data.
12/// Normalizes vendor-specific formats into a standard structure for assertion evaluation.
13#[derive(Debug, Clone)]
14pub struct AgentContextBuilder {
15    response: ChatResponse,
16    raw: Value,
17}
18
19impl AgentContextBuilder {
20    /// Build an AgentContextBuilder from raw context JSON.
21    ///
22    /// Extraction strategy (priority order):
23    /// 1. Pre-normalized — if top-level keys match our standard shape
24    /// 2. OpenAI format — choices[].message.tool_calls, usage, model
25    /// 3. Anthropic format — content[] with ToolUseBlock, usage, model
26    /// 4. Google/Gemini format — candidates[].content.parts[] with function_call
27    /// 5. Fallback — walk JSON tree for known patterns
28    pub fn from_context(
29        context: &Value,
30        provider: Option<&Provider>,
31    ) -> Result<Self, EvaluationError> {
32        let response_val = context.get("response").unwrap_or(context);
33        let response =
34            ChatResponse::from_response_value(response_val.clone(), provider).map_err(|e| {
35                error!("Failed to parse response: {}", e);
36                EvaluationError::InvalidProviderResponse
37            })?;
38        Ok(Self {
39            response,
40            raw: response_val.clone(),
41        })
42    }
43
44    /// Resolve an AgentAssertion variant to a JSON value for the shared AssertionEvaluator.
45    pub fn build_context(&self, assertion: &AgentAssertion) -> Result<Value, EvaluationError> {
46        match assertion {
47            AgentAssertion::ToolCalled { name } => {
48                let found = self
49                    .response
50                    .get_tool_calls()
51                    .iter()
52                    .any(|tc| tc.name == *name);
53                Ok(json!(found))
54            }
55            AgentAssertion::ToolNotCalled { name } => {
56                let not_found = !self
57                    .response
58                    .get_tool_calls()
59                    .iter()
60                    .any(|tc| tc.name == *name);
61                Ok(json!(not_found))
62            }
63            AgentAssertion::ToolCalledWithArgs { name, arguments } => {
64                let matched =
65                    self.response.get_tool_calls().iter().any(|tc| {
66                        tc.name == *name && Self::partial_match(&tc.arguments, &arguments.0)
67                    });
68                Ok(json!(matched))
69            }
70            AgentAssertion::ToolCallSequence { names } => {
71                let actual: Vec<String> = self
72                    .response
73                    .get_tool_calls()
74                    .iter()
75                    .map(|tc| tc.name.clone())
76                    .collect();
77                let mut expected_iter = names.iter();
78                let mut current = expected_iter.next();
79                for actual_name in &actual {
80                    if let Some(exp) = current {
81                        if actual_name == exp {
82                            current = expected_iter.next();
83                        }
84                    }
85                }
86                Ok(json!(current.is_none()))
87            }
88            AgentAssertion::ToolCallCount { name } => {
89                let tools = &self.response.get_tool_calls();
90                let count = if let Some(name) = name {
91                    tools.iter().filter(|tc| tc.name == *name).count()
92                } else {
93                    tools.len()
94                };
95                Ok(json!(count))
96            }
97            AgentAssertion::ToolArgument { name, argument_key } => {
98                let value = self
99                    .response
100                    .get_tool_calls()
101                    .iter()
102                    .find(|tc| tc.name == *name)
103                    .and_then(|tc| tc.arguments.get(argument_key))
104                    .cloned()
105                    .unwrap_or(Value::Null);
106
107                Ok(value)
108            }
109            AgentAssertion::ToolResult { name } => {
110                let value = self
111                    .response
112                    .get_tool_calls()
113                    .iter()
114                    .find(|tc| tc.name == *name)
115                    .and_then(|tc| tc.result.clone())
116                    .unwrap_or(Value::Null);
117
118                Ok(value)
119            }
120            AgentAssertion::ResponseContent {} => {
121                let text = self.response.response_text();
122                if text.is_empty() {
123                    Ok(Value::Null)
124                } else {
125                    Ok(json!(text))
126                }
127            }
128            AgentAssertion::ResponseModel {} => Ok(self
129                .response
130                .model_name()
131                .map(|m| json!(m))
132                .unwrap_or(Value::Null)),
133            AgentAssertion::ResponseFinishReason {} => Ok(self
134                .response
135                .finish_reason_str()
136                .map(|f| json!(f))
137                .unwrap_or(Value::Null)),
138            AgentAssertion::ResponseInputTokens {} => Ok(self
139                .response
140                .input_tokens()
141                .map(|t| json!(t))
142                .unwrap_or(Value::Null)),
143            AgentAssertion::ResponseOutputTokens {} => Ok(self
144                .response
145                .output_tokens()
146                .map(|t| json!(t))
147                .unwrap_or(Value::Null)),
148            AgentAssertion::ResponseTotalTokens {} => Ok(self
149                .response
150                .total_tokens()
151                .map(|t| json!(t))
152                .unwrap_or(Value::Null)),
153            AgentAssertion::ResponseField { path } => Self::extract_by_path(&self.raw, path),
154        }
155    }
156
157    // ─── Helpers ───────────────────────────────────────────────────────
158
159    /// Check if all specified args are present and equal in actual args (partial match).
160    fn partial_match(actual: &Value, expected: &Value) -> bool {
161        match (actual, expected) {
162            (Value::Object(actual_map), Value::Object(expected_map)) => {
163                for (key, expected_val) in expected_map {
164                    match actual_map.get(key) {
165                        Some(actual_val) => {
166                            if !Self::partial_match(actual_val, expected_val) {
167                                return false;
168                            }
169                        }
170                        None => return false,
171                    }
172                }
173                true
174            }
175            _ => actual == expected,
176        }
177    }
178
179    /// Extract a value from JSON using dot-notation path with array indexing.
180    /// Supports: "foo.bar", "foo[0].bar", "candidates[0].content.parts[0].text"
181    fn extract_by_path(val: &Value, path: &str) -> Result<Value, EvaluationError> {
182        let mut current = val.clone();
183
184        for segment in Self::parse_path_segments(path)? {
185            match segment {
186                PathSegment::Key(key) => {
187                    current = current.get(&key).cloned().unwrap_or(Value::Null);
188                }
189                PathSegment::Index(idx) => {
190                    current = current
191                        .as_array()
192                        .and_then(|arr| arr.get(idx))
193                        .cloned()
194                        .unwrap_or(Value::Null);
195                }
196            }
197        }
198
199        Ok(current)
200    }
201
202    fn parse_path_segments(path: &str) -> Result<Vec<PathSegment>, EvaluationError> {
203        if path.len() > MAX_PATH_LEN {
204            return Err(EvaluationError::PathTooLong(path.len()));
205        }
206
207        let regex = PATH_REGEX.get_or_init(|| {
208            regex::Regex::new(REGEX_FIELD_PARSE_PATTERN)
209                .expect("Invalid regex pattern in REGEX_FIELD_PARSE_PATTERN")
210        });
211
212        let mut segments = Vec::new();
213
214        for capture in regex.find_iter(path) {
215            let s = capture.as_str();
216            if s.starts_with('[') && s.ends_with(']') {
217                let idx_str = &s[1..s.len() - 1];
218                let idx = idx_str
219                    .parse::<usize>()
220                    .map_err(|_| EvaluationError::InvalidArrayIndex(idx_str.to_string()))?;
221                segments.push(PathSegment::Index(idx));
222            } else {
223                segments.push(PathSegment::Key(s.to_string()));
224            }
225        }
226
227        if segments.is_empty() {
228            return Err(EvaluationError::EmptyFieldPath);
229        }
230
231        if segments.len() > MAX_PATH_SEGMENTS {
232            return Err(EvaluationError::TooManyPathSegments(segments.len()));
233        }
234
235        Ok(segments)
236    }
237}
238
/// One step of a parsed `ResponseField` path.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PathSegment {
    /// Object member lookup, e.g. `content` in `candidates[0].content`.
    Key(String),
    /// Zero-based array index, e.g. the `0` in `candidates[0]`.
    Index(usize),
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::genai::PyValueWrapper;

    // ─── Fixtures ──────────────────────────────────────────────────────

    /// Build an OpenAI chat-completion payload with a single assistant choice.
    ///
    /// The `tool_calls` key is left out of the message entirely when `tool_calls` is `None`.
    fn openai_response(content: Value, tool_calls: Option<Value>, finish_reason: &str) -> Value {
        let mut message = json!({ "role": "assistant", "content": content });
        if let Some(calls) = tool_calls {
            message["tool_calls"] = calls;
        }
        json!({
            "model": "gpt-4o",
            "choices": [{ "message": message, "finish_reason": finish_reason }]
        })
    }

    /// Build one OpenAI `function` tool call; `arguments` is the JSON-encoded argument string.
    fn function_call(id: &str, name: &str, arguments: &str) -> Value {
        json!({
            "id": id,
            "type": "function",
            "function": { "name": name, "arguments": arguments }
        })
    }

    /// Build a tool-calling response (`content: null`, `finish_reason: "tool_calls"`).
    fn tool_call_response(calls: Vec<Value>) -> Value {
        openai_response(Value::Null, Some(Value::Array(calls)), "tool_calls")
    }

    /// Parse `context` without a provider hint, panicking if the fixture is not recognized.
    fn builder_for(context: &Value) -> AgentContextBuilder {
        AgentContextBuilder::from_context(context, None).unwrap()
    }

    /// Resolve `assertion` against `builder`, panicking on evaluation errors.
    fn resolve(builder: &AgentContextBuilder, assertion: AgentAssertion) -> Value {
        builder.build_context(&assertion).unwrap()
    }

    // ─── Tool assertions ───────────────────────────────────────────────

    #[test]
    fn test_tool_called_assertion() {
        let context = tool_call_response(vec![function_call(
            "call_1",
            "web_search",
            r#"{"query": "test"}"#,
        )]);
        let builder = builder_for(&context);

        let result = resolve(
            &builder,
            AgentAssertion::ToolCalled {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(result, json!(true));

        let result = resolve(
            &builder,
            AgentAssertion::ToolCalled {
                name: "delete_user".to_string(),
            },
        );
        assert_eq!(result, json!(false));

        let result = resolve(
            &builder,
            AgentAssertion::ToolNotCalled {
                name: "delete_user".to_string(),
            },
        );
        assert_eq!(result, json!(true));

        let result = resolve(&builder, AgentAssertion::ToolCallCount { name: None });
        assert_eq!(result, json!(1));
    }

    #[test]
    fn test_tool_called_with_args_partial_match() {
        let context = tool_call_response(vec![function_call(
            "call_1",
            "web_search",
            r#"{"query": "weather NYC", "lang": "en", "limit": 5}"#,
        )]);
        let builder = builder_for(&context);

        // Partial match - only checking "query"
        let result = resolve(
            &builder,
            AgentAssertion::ToolCalledWithArgs {
                name: "web_search".to_string(),
                arguments: PyValueWrapper(json!({"query": "weather NYC"})),
            },
        );
        assert_eq!(result, json!(true));

        // Non-matching arg
        let result = resolve(
            &builder,
            AgentAssertion::ToolCalledWithArgs {
                name: "web_search".to_string(),
                arguments: PyValueWrapper(json!({"query": "weather LA"})),
            },
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_partial_match_nested() {
        let context = tool_call_response(vec![function_call(
            "c1",
            "create_item",
            r#"{"item": {"name": "widget", "price": 9.99, "tags": ["sale"]}}"#,
        )]);
        let builder = builder_for(&context);

        // Nested partial match - only check inner "name"
        let result = resolve(
            &builder,
            AgentAssertion::ToolCalledWithArgs {
                name: "create_item".to_string(),
                arguments: PyValueWrapper(json!({"item": {"name": "widget"}})),
            },
        );
        assert_eq!(result, json!(true));

        // Nested mismatch
        let result = resolve(
            &builder,
            AgentAssertion::ToolCalledWithArgs {
                name: "create_item".to_string(),
                arguments: PyValueWrapper(json!({"item": {"name": "gadget"}})),
            },
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_tool_call_sequence() {
        let context = tool_call_response(vec![
            function_call("call_1", "web_search", "{}"),
            function_call("call_2", "summarize", "{}"),
            function_call("call_3", "respond", "{}"),
        ]);
        let builder = builder_for(&context);

        let result = resolve(
            &builder,
            AgentAssertion::ToolCallSequence {
                names: vec![
                    "web_search".to_string(),
                    "summarize".to_string(),
                    "respond".to_string(),
                ],
            },
        );
        assert_eq!(result, json!(true));

        // Wrong order
        let result = resolve(
            &builder,
            AgentAssertion::ToolCallSequence {
                names: vec!["respond".to_string(), "web_search".to_string()],
            },
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_tool_call_sequence_subsequence() {
        let context = tool_call_response(vec![
            function_call("c1", "search", "{}"),
            function_call("c2", "filter", "{}"),
            function_call("c3", "rank", "{}"),
            function_call("c4", "respond", "{}"),
        ]);
        let builder = builder_for(&context);

        // Non-contiguous in-order subsequence should pass
        let result = resolve(
            &builder,
            AgentAssertion::ToolCallSequence {
                names: vec![
                    "search".to_string(),
                    "rank".to_string(),
                    "respond".to_string(),
                ],
            },
        );
        assert_eq!(result, json!(true));

        // Out-of-order should fail
        let result = resolve(
            &builder,
            AgentAssertion::ToolCallSequence {
                names: vec!["respond".to_string(), "search".to_string()],
            },
        );
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_tool_call_sequence_edge_cases() {
        let context = tool_call_response(vec![
            function_call("c1", "web_search", "{}"),
            function_call("c2", "summarize", "{}"),
            function_call("c3", "web_search", "{}"),
        ]);
        let builder = builder_for(&context);
        let sequence = |names: &[&str]| {
            resolve(
                &builder,
                AgentAssertion::ToolCallSequence {
                    names: names.iter().map(|n| n.to_string()).collect(),
                },
            )
        };

        // An empty expectation is trivially satisfied.
        assert_eq!(sequence(&[]), json!(true));
        // Each repeated expected name must be matched by a distinct call.
        assert_eq!(sequence(&["web_search", "web_search"]), json!(true));
        assert_eq!(
            sequence(&["web_search", "web_search", "web_search"]),
            json!(false)
        );
        assert_eq!(sequence(&["summarize", "summarize"]), json!(false));
    }

    #[test]
    fn test_tool_call_count_by_name() {
        let context = tool_call_response(vec![
            function_call("c1", "web_search", "{}"),
            function_call("c2", "summarize", "{}"),
            function_call("c3", "web_search", "{}"),
        ]);
        let builder = builder_for(&context);
        let count = |name: Option<&str>| {
            resolve(
                &builder,
                AgentAssertion::ToolCallCount {
                    name: name.map(str::to_string),
                },
            )
        };

        assert_eq!(count(None), json!(3));
        assert_eq!(count(Some("web_search")), json!(2));
        assert_eq!(count(Some("summarize")), json!(1));
        assert_eq!(count(Some("missing")), json!(0));
    }

    #[test]
    fn test_no_tool_calls() {
        let context = openai_response(json!("Just a text response."), None, "stop");
        let builder = builder_for(&context);

        let result = resolve(
            &builder,
            AgentAssertion::ToolNotCalled {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(result, json!(true));
    }

    #[test]
    fn test_tool_result_extraction() {
        // OpenAI responses carry no tool results (those arrive in a later request), so
        // ToolResult resolves to Null both for a called tool and for an unknown one.
        let context = tool_call_response(vec![function_call(
            "c1",
            "web_search",
            r#"{"query": "test"}"#,
        )]);
        let builder = builder_for(&context);

        // Named tool with no result -> Null
        let result = resolve(
            &builder,
            AgentAssertion::ToolResult {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(result, Value::Null);

        // Missing tool name -> Null
        let result = resolve(
            &builder,
            AgentAssertion::ToolResult {
                name: "nonexistent".to_string(),
            },
        );
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_tool_argument_extraction() {
        let context = tool_call_response(vec![function_call(
            "call_1",
            "web_search",
            r#"{"query": "test query", "limit": 10}"#,
        )]);
        let builder = builder_for(&context);

        let result = resolve(
            &builder,
            AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "query".to_string(),
            },
        );
        assert_eq!(result, json!("test query"));

        let result = resolve(
            &builder,
            AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "missing".to_string(),
            },
        );
        assert_eq!(result, Value::Null);
    }

    // ─── Response assertions ───────────────────────────────────────────

    #[test]
    fn test_response_content_empty() {
        let context = openai_response(Value::Null, None, "stop");
        let builder = builder_for(&context);

        let result = resolve(&builder, AgentAssertion::ResponseContent {});
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_response_field_escape_hatch() {
        let context = json!({
            "response": {
                "candidates": [{
                    "content": {"role": "model", "parts": [{"text": "hello"}]},
                    "finishReason": "STOP",
                    "safety_ratings": [{"category": "HARM_CATEGORY_SAFE"}]
                }],
                "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
            }
        });
        let builder = builder_for(&context);
        let field = |path: &str| {
            resolve(
                &builder,
                AgentAssertion::ResponseField {
                    path: path.to_string(),
                },
            )
        };

        // Paths resolve against the inner "response" value, not the outer context.
        assert_eq!(
            field("candidates[0].safety_ratings[0].category"),
            json!("HARM_CATEGORY_SAFE")
        );

        // Out-of-range index and missing keys resolve to Null rather than erroring.
        assert_eq!(field("candidates[3].content"), Value::Null);
        assert_eq!(field("candidates[0].missing.deeper"), Value::Null);
    }

    // ─── Construction & path parsing ───────────────────────────────────

    #[test]
    fn test_from_context_invalid_json() {
        // Valid JSON, but an empty object matches no vendor response shape.
        let context = json!({});
        let result = AgentContextBuilder::from_context(&context, None);
        assert!(matches!(
            result,
            Err(EvaluationError::InvalidProviderResponse)
        ));
    }

    #[test]
    fn test_parse_path_segments_keys_and_indices() {
        let segments =
            AgentContextBuilder::parse_path_segments("candidates[0].content.parts[1]").unwrap();
        assert!(matches!(
            segments.as_slice(),
            [
                PathSegment::Key(a),
                PathSegment::Index(0),
                PathSegment::Key(b),
                PathSegment::Key(c),
                PathSegment::Index(1),
            ] if a == "candidates" && b == "content" && c == "parts"
        ));
    }

    #[test]
    fn test_parse_path_segments_errors() {
        // Empty string -> EmptyFieldPath
        let result = AgentContextBuilder::parse_path_segments("");
        assert!(matches!(result, Err(EvaluationError::EmptyFieldPath)));

        // Path exceeding max length
        let long_path = "a".repeat(MAX_PATH_LEN + 1);
        let result = AgentContextBuilder::parse_path_segments(&long_path);
        assert!(matches!(result, Err(EvaluationError::PathTooLong(_))));

        // Too many segments
        let many_segments = (0..MAX_PATH_SEGMENTS + 1)
            .map(|i| format!("seg{}", i))
            .collect::<Vec<_>>()
            .join(".");
        let result = AgentContextBuilder::parse_path_segments(&many_segments);
        assert!(matches!(
            result,
            Err(EvaluationError::TooManyPathSegments(_))
        ));
    }
}