Skip to main content

scouter_evaluate/evaluate/
agent.rs

1use crate::error::EvaluationError;
2use crate::tasks::evaluator::{PATH_REGEX, REGEX_FIELD_PARSE_PATTERN};
3use potato_head::ChatResponse;
4use scouter_types::genai::AgentAssertion;
5use serde_json::{json, Value};
6use tracing::error;
7
/// Maximum length, in bytes (`str::len`), of a `ResponseField` path accepted by
/// `AgentContextBuilder::parse_path_segments`; longer paths are rejected with `PathTooLong`.
const MAX_PATH_LEN: usize = 512;
/// Maximum number of key/index segments a `ResponseField` path may parse into; more are
/// rejected with `TooManyPathSegments`.
const MAX_PATH_SEGMENTS: usize = 32;
10
11/// Builds evaluation context from vendor LLM request/response data.
12/// Normalizes vendor-specific formats into a standard structure for assertion evaluation.
13#[derive(Debug, Clone)]
14pub struct AgentContextBuilder {
15    response: ChatResponse,
16    raw: Value,
17}
18
19impl AgentContextBuilder {
20    /// Build an AgentContextBuilder from raw context JSON.
21    ///
22    /// Extraction strategy (priority order):
23    /// 1. Pre-normalized — if top-level keys match our standard shape
24    /// 2. OpenAI format — choices[].message.tool_calls, usage, model
25    /// 3. Anthropic format — content[] with ToolUseBlock, usage, model
26    /// 4. Google/Gemini format — candidates[].content.parts[] with function_call
27    /// 5. Fallback — walk JSON tree for known patterns
28    pub fn from_context(context: &Value) -> Result<Self, EvaluationError> {
29        let response_val = context.get("response").unwrap_or(context);
30        let response = ChatResponse::from_response_value(response_val.clone()).map_err(|e| {
31            error!("Failed to parse response: {}", e);
32            EvaluationError::InvalidProviderResponse
33        })?;
34        Ok(Self {
35            response,
36            raw: response_val.clone(),
37        })
38    }
39
40    /// Resolve an AgentAssertion variant to a JSON value for the shared AssertionEvaluator.
41    pub fn build_context(&self, assertion: &AgentAssertion) -> Result<Value, EvaluationError> {
42        match assertion {
43            AgentAssertion::ToolCalled { name } => {
44                let found = self
45                    .response
46                    .get_tool_calls()
47                    .iter()
48                    .any(|tc| tc.name == *name);
49                Ok(json!(found))
50            }
51            AgentAssertion::ToolNotCalled { name } => {
52                let not_found = !self
53                    .response
54                    .get_tool_calls()
55                    .iter()
56                    .any(|tc| tc.name == *name);
57                Ok(json!(not_found))
58            }
59            AgentAssertion::ToolCalledWithArgs { name, arguments } => {
60                let matched =
61                    self.response.get_tool_calls().iter().any(|tc| {
62                        tc.name == *name && Self::partial_match(&tc.arguments, &arguments.0)
63                    });
64                Ok(json!(matched))
65            }
66            AgentAssertion::ToolCallSequence { names } => {
67                let actual: Vec<String> = self
68                    .response
69                    .get_tool_calls()
70                    .iter()
71                    .map(|tc| tc.name.clone())
72                    .collect();
73                let mut expected_iter = names.iter();
74                let mut current = expected_iter.next();
75                for actual_name in &actual {
76                    if let Some(exp) = current {
77                        if actual_name == exp {
78                            current = expected_iter.next();
79                        }
80                    }
81                }
82                Ok(json!(current.is_none()))
83            }
84            AgentAssertion::ToolCallCount { name } => {
85                let tools = &self.response.get_tool_calls();
86                let count = if let Some(name) = name {
87                    tools.iter().filter(|tc| tc.name == *name).count()
88                } else {
89                    tools.len()
90                };
91                Ok(json!(count))
92            }
93            AgentAssertion::ToolArgument { name, argument_key } => {
94                let value = self
95                    .response
96                    .get_tool_calls()
97                    .iter()
98                    .find(|tc| tc.name == *name)
99                    .and_then(|tc| tc.arguments.get(argument_key))
100                    .cloned()
101                    .unwrap_or(Value::Null);
102
103                Ok(value)
104            }
105            AgentAssertion::ToolResult { name } => {
106                let value = self
107                    .response
108                    .get_tool_calls()
109                    .iter()
110                    .find(|tc| tc.name == *name)
111                    .and_then(|tc| tc.result.clone())
112                    .unwrap_or(Value::Null);
113
114                Ok(value)
115            }
116            AgentAssertion::ResponseContent {} => {
117                let text = self.response.response_text();
118                if text.is_empty() {
119                    Ok(Value::Null)
120                } else {
121                    Ok(json!(text))
122                }
123            }
124            AgentAssertion::ResponseModel {} => Ok(self
125                .response
126                .model_name()
127                .map(|m| json!(m))
128                .unwrap_or(Value::Null)),
129            AgentAssertion::ResponseFinishReason {} => Ok(self
130                .response
131                .finish_reason_str()
132                .map(|f| json!(f))
133                .unwrap_or(Value::Null)),
134            AgentAssertion::ResponseInputTokens {} => Ok(self
135                .response
136                .input_tokens()
137                .map(|t| json!(t))
138                .unwrap_or(Value::Null)),
139            AgentAssertion::ResponseOutputTokens {} => Ok(self
140                .response
141                .output_tokens()
142                .map(|t| json!(t))
143                .unwrap_or(Value::Null)),
144            AgentAssertion::ResponseTotalTokens {} => Ok(self
145                .response
146                .total_tokens()
147                .map(|t| json!(t))
148                .unwrap_or(Value::Null)),
149            AgentAssertion::ResponseField { path } => Self::extract_by_path(&self.raw, path),
150        }
151    }
152
153    // ─── Helpers ───────────────────────────────────────────────────────
154
155    /// Check if all specified args are present and equal in actual args (partial match).
156    fn partial_match(actual: &Value, expected: &Value) -> bool {
157        match (actual, expected) {
158            (Value::Object(actual_map), Value::Object(expected_map)) => {
159                for (key, expected_val) in expected_map {
160                    match actual_map.get(key) {
161                        Some(actual_val) => {
162                            if !Self::partial_match(actual_val, expected_val) {
163                                return false;
164                            }
165                        }
166                        None => return false,
167                    }
168                }
169                true
170            }
171            _ => actual == expected,
172        }
173    }
174
175    /// Extract a value from JSON using dot-notation path with array indexing.
176    /// Supports: "foo.bar", "foo[0].bar", "candidates[0].content.parts[0].text"
177    fn extract_by_path(val: &Value, path: &str) -> Result<Value, EvaluationError> {
178        let mut current = val.clone();
179
180        for segment in Self::parse_path_segments(path)? {
181            match segment {
182                PathSegment::Key(key) => {
183                    current = current.get(&key).cloned().unwrap_or(Value::Null);
184                }
185                PathSegment::Index(idx) => {
186                    current = current
187                        .as_array()
188                        .and_then(|arr| arr.get(idx))
189                        .cloned()
190                        .unwrap_or(Value::Null);
191                }
192            }
193        }
194
195        Ok(current)
196    }
197
198    fn parse_path_segments(path: &str) -> Result<Vec<PathSegment>, EvaluationError> {
199        if path.len() > MAX_PATH_LEN {
200            return Err(EvaluationError::PathTooLong(path.len()));
201        }
202
203        let regex = PATH_REGEX.get_or_init(|| {
204            regex::Regex::new(REGEX_FIELD_PARSE_PATTERN)
205                .expect("Invalid regex pattern in REGEX_FIELD_PARSE_PATTERN")
206        });
207
208        let mut segments = Vec::new();
209
210        for capture in regex.find_iter(path) {
211            let s = capture.as_str();
212            if s.starts_with('[') && s.ends_with(']') {
213                let idx_str = &s[1..s.len() - 1];
214                let idx = idx_str
215                    .parse::<usize>()
216                    .map_err(|_| EvaluationError::InvalidArrayIndex(idx_str.to_string()))?;
217                segments.push(PathSegment::Index(idx));
218            } else {
219                segments.push(PathSegment::Key(s.to_string()));
220            }
221        }
222
223        if segments.is_empty() {
224            return Err(EvaluationError::EmptyFieldPath);
225        }
226
227        if segments.len() > MAX_PATH_SEGMENTS {
228            return Err(EvaluationError::TooManyPathSegments(segments.len()));
229        }
230
231        Ok(segments)
232    }
233}
234
/// One step of a parsed `ResponseField` path.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PathSegment {
    /// Object key lookup, e.g. `content` in `candidates[0].content`.
    Key(String),
    /// Array element lookup, e.g. `0` in `candidates[0]`.
    Index(usize),
}
239
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::genai::PyValueWrapper;

    #[test]
    fn test_tool_called_assertion() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "web_search", "arguments": "{\"query\": \"test\"}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        let result = builder
            .build_context(&AgentAssertion::ToolCalled {
                name: "web_search".to_string(),
            })
            .unwrap();
        assert_eq!(result, json!(true));

        // A tool that was never called
        let result = builder
            .build_context(&AgentAssertion::ToolCalled {
                name: "delete_user".to_string(),
            })
            .unwrap();
        assert_eq!(result, json!(false));

        let result = builder
            .build_context(&AgentAssertion::ToolNotCalled {
                name: "delete_user".to_string(),
            })
            .unwrap();
        assert_eq!(result, json!(true));

        let result = builder
            .build_context(&AgentAssertion::ToolNotCalled {
                name: "web_search".to_string(),
            })
            .unwrap();
        assert_eq!(result, json!(false));

        let result = builder
            .build_context(&AgentAssertion::ToolCallCount { name: None })
            .unwrap();
        assert_eq!(result, json!(1));

        // Count restricted to a single tool name
        let result = builder
            .build_context(&AgentAssertion::ToolCallCount {
                name: Some("web_search".to_string()),
            })
            .unwrap();
        assert_eq!(result, json!(1));

        let result = builder
            .build_context(&AgentAssertion::ToolCallCount {
                name: Some("delete_user".to_string()),
            })
            .unwrap();
        assert_eq!(result, json!(0));
    }

    #[test]
    fn test_tool_called_with_args_partial_match() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "web_search", "arguments": "{\"query\": \"weather NYC\", \"lang\": \"en\", \"limit\": 5}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        // Partial match - only checking "query"
        let result = builder
            .build_context(&AgentAssertion::ToolCalledWithArgs {
                name: "web_search".to_string(),
                arguments: PyValueWrapper(json!({"query": "weather NYC"})),
            })
            .unwrap();
        assert_eq!(result, json!(true));

        // Non-matching arg
        let result = builder
            .build_context(&AgentAssertion::ToolCalledWithArgs {
                name: "web_search".to_string(),
                arguments: PyValueWrapper(json!({"query": "weather LA"})),
            })
            .unwrap();
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_tool_call_sequence() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [
                        {"id": "call_1", "type": "function", "function": {"name": "web_search", "arguments": "{}"}},
                        {"id": "call_2", "type": "function", "function": {"name": "summarize", "arguments": "{}"}},
                        {"id": "call_3", "type": "function", "function": {"name": "respond", "arguments": "{}"}}
                    ]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        let result = builder
            .build_context(&AgentAssertion::ToolCallSequence {
                names: vec![
                    "web_search".to_string(),
                    "summarize".to_string(),
                    "respond".to_string(),
                ],
            })
            .unwrap();
        assert_eq!(result, json!(true));

        // Wrong order
        let result = builder
            .build_context(&AgentAssertion::ToolCallSequence {
                names: vec!["respond".to_string(), "web_search".to_string()],
            })
            .unwrap();
        assert_eq!(result, json!(false));

        // An empty expected sequence is trivially satisfied
        let result = builder
            .build_context(&AgentAssertion::ToolCallSequence { names: vec![] })
            .unwrap();
        assert_eq!(result, json!(true));
    }

    #[test]
    fn test_response_field_escape_hatch() {
        let context = json!({
            "response": {
                "candidates": [{
                    "content": {"role": "model", "parts": [{"text": "hello"}]},
                    "finishReason": "STOP",
                    "safety_ratings": [{"category": "HARM_CATEGORY_SAFE"}]
                }],
                "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
            }
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        // The path is resolved against the value under the top-level "response" key,
        // not against the full context
        let result = builder
            .build_context(&AgentAssertion::ResponseField {
                path: "candidates[0].safety_ratings[0].category".to_string(),
            })
            .unwrap();
        assert_eq!(result, json!("HARM_CATEGORY_SAFE"));
    }

    #[test]
    fn test_response_field_missing_path_is_null() {
        let context = json!({
            "response": {
                "candidates": [{
                    "content": {"role": "model", "parts": [{"text": "hello"}]},
                    "finishReason": "STOP"
                }]
            }
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        // Missing key
        let result = builder
            .build_context(&AgentAssertion::ResponseField {
                path: "candidates[0].missing_key".to_string(),
            })
            .unwrap();
        assert_eq!(result, Value::Null);

        // Out-of-range index, followed by further segments
        let result = builder
            .build_context(&AgentAssertion::ResponseField {
                path: "candidates[9].content".to_string(),
            })
            .unwrap();
        assert_eq!(result, Value::Null);

        // Indexing into an object rather than an array
        let result = builder
            .build_context(&AgentAssertion::ResponseField {
                path: "candidates[0].content[0]".to_string(),
            })
            .unwrap();
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_no_tool_calls() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Just a text response."
                },
                "finish_reason": "stop"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        let result = builder
            .build_context(&AgentAssertion::ToolNotCalled {
                name: "web_search".to_string(),
            })
            .unwrap();
        assert_eq!(result, json!(true));

        let result = builder
            .build_context(&AgentAssertion::ToolCallCount { name: None })
            .unwrap();
        assert_eq!(result, json!(0));
    }

    #[test]
    fn test_from_context_invalid_json() {
        // Empty object has no recognizable vendor keys
        let context = json!({});
        let result = AgentContextBuilder::from_context(&context);
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(EvaluationError::InvalidProviderResponse)
        ));
    }

    #[test]
    fn test_tool_call_sequence_subsequence() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [
                        {"id": "c1", "type": "function", "function": {"name": "search", "arguments": "{}"}},
                        {"id": "c2", "type": "function", "function": {"name": "filter", "arguments": "{}"}},
                        {"id": "c3", "type": "function", "function": {"name": "rank", "arguments": "{}"}},
                        {"id": "c4", "type": "function", "function": {"name": "respond", "arguments": "{}"}}
                    ]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        // Non-contiguous in-order subsequence should pass
        let result = builder
            .build_context(&AgentAssertion::ToolCallSequence {
                names: vec![
                    "search".to_string(),
                    "rank".to_string(),
                    "respond".to_string(),
                ],
            })
            .unwrap();
        assert_eq!(result, json!(true));

        // Out-of-order should fail
        let result = builder
            .build_context(&AgentAssertion::ToolCallSequence {
                names: vec!["respond".to_string(), "search".to_string()],
            })
            .unwrap();
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_parse_path_segments_errors() {
        // Empty string -> EmptyFieldPath
        let result = AgentContextBuilder::parse_path_segments("");
        assert!(matches!(result, Err(EvaluationError::EmptyFieldPath)));

        // Path exceeding max length
        let long_path = "a".repeat(MAX_PATH_LEN + 1);
        let result = AgentContextBuilder::parse_path_segments(&long_path);
        assert!(matches!(result, Err(EvaluationError::PathTooLong(_))));

        // Too many segments
        let many_segments = (0..MAX_PATH_SEGMENTS + 1)
            .map(|i| format!("seg{}", i))
            .collect::<Vec<_>>()
            .join(".");
        let result = AgentContextBuilder::parse_path_segments(&many_segments);
        assert!(matches!(
            result,
            Err(EvaluationError::TooManyPathSegments(_))
        ));
    }

    #[test]
    fn test_response_content_empty() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null
                },
                "finish_reason": "stop"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();
        let result = builder
            .build_context(&AgentAssertion::ResponseContent {})
            .unwrap();
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_partial_match_nested() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "c1",
                        "type": "function",
                        "function": {"name": "create_item", "arguments": "{\"item\": {\"name\": \"widget\", \"price\": 9.99, \"tags\": [\"sale\"]}}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        // Nested partial match - only check inner "name"
        let result = builder
            .build_context(&AgentAssertion::ToolCalledWithArgs {
                name: "create_item".to_string(),
                arguments: PyValueWrapper(json!({"item": {"name": "widget"}})),
            })
            .unwrap();
        assert_eq!(result, json!(true));

        // Nested mismatch
        let result = builder
            .build_context(&AgentAssertion::ToolCalledWithArgs {
                name: "create_item".to_string(),
                arguments: PyValueWrapper(json!({"item": {"name": "gadget"}})),
            })
            .unwrap();
        assert_eq!(result, json!(false));
    }

    #[test]
    fn test_tool_result_extraction() {
        // Tool results are not part of an OpenAI chat-completion response, so ToolResult
        // yields Null when no result is attached to the tool call in the response JSON
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "c1",
                        "type": "function",
                        "function": {"name": "web_search", "arguments": "{\"query\": \"test\"}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        // Named tool with no result -> Null
        let result = builder
            .build_context(&AgentAssertion::ToolResult {
                name: "web_search".to_string(),
            })
            .unwrap();
        assert_eq!(result, Value::Null);

        // Missing tool name -> Null
        let result = builder
            .build_context(&AgentAssertion::ToolResult {
                name: "nonexistent".to_string(),
            })
            .unwrap();
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_tool_argument_extraction() {
        let context = json!({
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "web_search", "arguments": "{\"query\": \"test query\", \"limit\": 10}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let builder = AgentContextBuilder::from_context(&context).unwrap();

        let result = builder
            .build_context(&AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "query".to_string(),
            })
            .unwrap();
        assert_eq!(result, json!("test query"));

        let result = builder
            .build_context(&AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "missing".to_string(),
            })
            .unwrap();
        assert_eq!(result, Value::Null);

        // Tool that was never called -> Null
        let result = builder
            .build_context(&AgentAssertion::ToolArgument {
                name: "nonexistent".to_string(),
                argument_key: "query".to_string(),
            })
            .unwrap();
        assert_eq!(result, Value::Null);
    }
}