1use crate::error::EvaluationError;
2use crate::tasks::evaluator::{PATH_REGEX, REGEX_FIELD_PARSE_PATTERN};
3use potato_head::{ChatResponse, Provider};
4use scouter_types::genai::AgentAssertion;
5use serde_json::{json, Value};
6use tracing::error;
7
/// Upper bound on the byte length (`str::len`) of a field path accepted by
/// `AgentContextBuilder::parse_path_segments`; longer paths are rejected
/// before the path regex is run.
const MAX_PATH_LEN: usize = 512;
/// Upper bound on the number of key/index segments a parsed field path may
/// contain; paths that parse into more segments are rejected.
const MAX_PATH_SEGMENTS: usize = 32;
10
/// Resolves [`AgentAssertion`]s against a single model response.
///
/// Holds both the parsed response (used for tool-call, text, model, finish
/// reason and token-count lookups) and the untouched JSON it was parsed from
/// (used for `ResponseField` path lookups).
#[derive(Debug, Clone)]
pub struct AgentContextBuilder {
    /// Response parsed by `ChatResponse::from_response_value`.
    response: ChatResponse,
    /// The JSON value the response was parsed from, kept verbatim so arbitrary
    /// fields can still be reached by path.
    raw: Value,
}
18
19impl AgentContextBuilder {
20 pub fn from_context(
29 context: &Value,
30 provider: Option<&Provider>,
31 ) -> Result<Self, EvaluationError> {
32 let response_val = context.get("response").unwrap_or(context);
33 let response =
34 ChatResponse::from_response_value(response_val.clone(), provider).map_err(|e| {
35 error!("Failed to parse response: {}", e);
36 EvaluationError::InvalidProviderResponse
37 })?;
38 Ok(Self {
39 response,
40 raw: response_val.clone(),
41 })
42 }
43
44 pub fn build_context(&self, assertion: &AgentAssertion) -> Result<Value, EvaluationError> {
46 match assertion {
47 AgentAssertion::ToolCalled { name } => {
48 let found = self
49 .response
50 .get_tool_calls()
51 .iter()
52 .any(|tc| tc.name == *name);
53 Ok(json!(found))
54 }
55 AgentAssertion::ToolNotCalled { name } => {
56 let not_found = !self
57 .response
58 .get_tool_calls()
59 .iter()
60 .any(|tc| tc.name == *name);
61 Ok(json!(not_found))
62 }
63 AgentAssertion::ToolCalledWithArgs { name, arguments } => {
64 let matched =
65 self.response.get_tool_calls().iter().any(|tc| {
66 tc.name == *name && Self::partial_match(&tc.arguments, &arguments.0)
67 });
68 Ok(json!(matched))
69 }
70 AgentAssertion::ToolCallSequence { names } => {
71 let actual: Vec<String> = self
72 .response
73 .get_tool_calls()
74 .iter()
75 .map(|tc| tc.name.clone())
76 .collect();
77 let mut expected_iter = names.iter();
78 let mut current = expected_iter.next();
79 for actual_name in &actual {
80 if let Some(exp) = current {
81 if actual_name == exp {
82 current = expected_iter.next();
83 }
84 }
85 }
86 Ok(json!(current.is_none()))
87 }
88 AgentAssertion::ToolCallCount { name } => {
89 let tools = &self.response.get_tool_calls();
90 let count = if let Some(name) = name {
91 tools.iter().filter(|tc| tc.name == *name).count()
92 } else {
93 tools.len()
94 };
95 Ok(json!(count))
96 }
97 AgentAssertion::ToolArgument { name, argument_key } => {
98 let value = self
99 .response
100 .get_tool_calls()
101 .iter()
102 .find(|tc| tc.name == *name)
103 .and_then(|tc| tc.arguments.get(argument_key))
104 .cloned()
105 .unwrap_or(Value::Null);
106
107 Ok(value)
108 }
109 AgentAssertion::ToolResult { name } => {
110 let value = self
111 .response
112 .get_tool_calls()
113 .iter()
114 .find(|tc| tc.name == *name)
115 .and_then(|tc| tc.result.clone())
116 .unwrap_or(Value::Null);
117
118 Ok(value)
119 }
120 AgentAssertion::ResponseContent {} => {
121 let text = self.response.response_text();
122 if text.is_empty() {
123 Ok(Value::Null)
124 } else {
125 Ok(json!(text))
126 }
127 }
128 AgentAssertion::ResponseModel {} => Ok(self
129 .response
130 .model_name()
131 .map(|m| json!(m))
132 .unwrap_or(Value::Null)),
133 AgentAssertion::ResponseFinishReason {} => Ok(self
134 .response
135 .finish_reason_str()
136 .map(|f| json!(f))
137 .unwrap_or(Value::Null)),
138 AgentAssertion::ResponseInputTokens {} => Ok(self
139 .response
140 .input_tokens()
141 .map(|t| json!(t))
142 .unwrap_or(Value::Null)),
143 AgentAssertion::ResponseOutputTokens {} => Ok(self
144 .response
145 .output_tokens()
146 .map(|t| json!(t))
147 .unwrap_or(Value::Null)),
148 AgentAssertion::ResponseTotalTokens {} => Ok(self
149 .response
150 .total_tokens()
151 .map(|t| json!(t))
152 .unwrap_or(Value::Null)),
153 AgentAssertion::ResponseField { path } => Self::extract_by_path(&self.raw, path),
154 }
155 }
156
157 fn partial_match(actual: &Value, expected: &Value) -> bool {
161 match (actual, expected) {
162 (Value::Object(actual_map), Value::Object(expected_map)) => {
163 for (key, expected_val) in expected_map {
164 match actual_map.get(key) {
165 Some(actual_val) => {
166 if !Self::partial_match(actual_val, expected_val) {
167 return false;
168 }
169 }
170 None => return false,
171 }
172 }
173 true
174 }
175 _ => actual == expected,
176 }
177 }
178
179 fn extract_by_path(val: &Value, path: &str) -> Result<Value, EvaluationError> {
182 let mut current = val.clone();
183
184 for segment in Self::parse_path_segments(path)? {
185 match segment {
186 PathSegment::Key(key) => {
187 current = current.get(&key).cloned().unwrap_or(Value::Null);
188 }
189 PathSegment::Index(idx) => {
190 current = current
191 .as_array()
192 .and_then(|arr| arr.get(idx))
193 .cloned()
194 .unwrap_or(Value::Null);
195 }
196 }
197 }
198
199 Ok(current)
200 }
201
202 fn parse_path_segments(path: &str) -> Result<Vec<PathSegment>, EvaluationError> {
203 if path.len() > MAX_PATH_LEN {
204 return Err(EvaluationError::PathTooLong(path.len()));
205 }
206
207 let regex = PATH_REGEX.get_or_init(|| {
208 regex::Regex::new(REGEX_FIELD_PARSE_PATTERN)
209 .expect("Invalid regex pattern in REGEX_FIELD_PARSE_PATTERN")
210 });
211
212 let mut segments = Vec::new();
213
214 for capture in regex.find_iter(path) {
215 let s = capture.as_str();
216 if s.starts_with('[') && s.ends_with(']') {
217 let idx_str = &s[1..s.len() - 1];
218 let idx = idx_str
219 .parse::<usize>()
220 .map_err(|_| EvaluationError::InvalidArrayIndex(idx_str.to_string()))?;
221 segments.push(PathSegment::Index(idx));
222 } else {
223 segments.push(PathSegment::Key(s.to_string()));
224 }
225 }
226
227 if segments.is_empty() {
228 return Err(EvaluationError::EmptyFieldPath);
229 }
230
231 if segments.len() > MAX_PATH_SEGMENTS {
232 return Err(EvaluationError::TooManyPathSegments(segments.len()));
233 }
234
235 Ok(segments)
236 }
237}
238
/// One step of a parsed field path, as produced by
/// `AgentContextBuilder::parse_path_segments`.
enum PathSegment {
    /// Descend into the object member with this name.
    Key(String),
    /// Descend into the array element at this zero-based position.
    Index(usize),
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::genai::PyValueWrapper;

    /// Wraps an assistant `message` in the `model`/`choices` envelope shared by
    /// most fixtures in this module.
    fn chat_completion(message: Value, finish_reason: &str) -> Value {
        json!({
            "model": "gpt-4o",
            "choices": [{
                "message": message,
                "finish_reason": finish_reason
            }]
        })
    }

    /// Builds a single function-type tool call entry; `arguments` is the
    /// JSON-encoded argument string.
    fn function_call(id: &str, name: &str, arguments: &str) -> Value {
        json!({
            "id": id,
            "type": "function",
            "function": {"name": name, "arguments": arguments}
        })
    }

    /// Builds a response whose assistant message has no content and carries
    /// the given tool calls.
    fn tool_call_completion(tool_calls: Vec<Value>) -> Value {
        chat_completion(
            json!({
                "role": "assistant",
                "content": null,
                "tool_calls": tool_calls
            }),
            "tool_calls",
        )
    }

    /// Parses `context` without a provider hint, panicking on failure.
    fn builder_for(context: &Value) -> AgentContextBuilder {
        AgentContextBuilder::from_context(context, None).unwrap()
    }

    /// Resolves `assertion`, panicking if resolution fails.
    fn resolve(builder: &AgentContextBuilder, assertion: AgentAssertion) -> Value {
        builder.build_context(&assertion).unwrap()
    }

    /// Builds a `ToolCallSequence` assertion from string slices.
    fn sequence(names: &[&str]) -> AgentAssertion {
        AgentAssertion::ToolCallSequence {
            names: names.iter().map(|name| name.to_string()).collect(),
        }
    }

    /// Builds a `ToolCalledWithArgs` assertion.
    fn called_with(name: &str, arguments: Value) -> AgentAssertion {
        AgentAssertion::ToolCalledWithArgs {
            name: name.to_string(),
            arguments: PyValueWrapper(arguments),
        }
    }

    #[test]
    fn test_tool_called_assertion() {
        let context = tool_call_completion(vec![function_call(
            "call_1",
            "web_search",
            "{\"query\": \"test\"}",
        )]);
        let builder = builder_for(&context);

        let called = resolve(
            &builder,
            AgentAssertion::ToolCalled {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(called, json!(true));

        let not_called = resolve(
            &builder,
            AgentAssertion::ToolNotCalled {
                name: "delete_user".to_string(),
            },
        );
        assert_eq!(not_called, json!(true));

        let count = resolve(&builder, AgentAssertion::ToolCallCount { name: None });
        assert_eq!(count, json!(1));
    }

    #[test]
    fn test_tool_called_with_args_partial_match() {
        let context = tool_call_completion(vec![function_call(
            "call_1",
            "web_search",
            "{\"query\": \"weather NYC\", \"lang\": \"en\", \"limit\": 5}",
        )]);
        let builder = builder_for(&context);

        // A subset of the actual arguments is enough to match.
        let subset = resolve(
            &builder,
            called_with("web_search", json!({"query": "weather NYC"})),
        );
        assert_eq!(subset, json!(true));

        // A differing value for an expected key must not match.
        let mismatch = resolve(
            &builder,
            called_with("web_search", json!({"query": "weather LA"})),
        );
        assert_eq!(mismatch, json!(false));
    }

    #[test]
    fn test_tool_call_sequence() {
        let context = tool_call_completion(vec![
            function_call("call_1", "web_search", "{}"),
            function_call("call_2", "summarize", "{}"),
            function_call("call_3", "respond", "{}"),
        ]);
        let builder = builder_for(&context);

        let exact = resolve(&builder, sequence(&["web_search", "summarize", "respond"]));
        assert_eq!(exact, json!(true));

        // Same names in the wrong order.
        let reversed = resolve(&builder, sequence(&["respond", "web_search"]));
        assert_eq!(reversed, json!(false));
    }

    #[test]
    fn test_response_field_escape_hatch() {
        let context = json!({
            "response": {
                "candidates": [{
                    "content": {"role": "model", "parts": [{"text": "hello"}]},
                    "finishReason": "STOP",
                    "safety_ratings": [{"category": "HARM_CATEGORY_SAFE"}]
                }],
                "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
            }
        });
        let builder = builder_for(&context);

        // The path is resolved against the raw payload under `response`.
        let category = resolve(
            &builder,
            AgentAssertion::ResponseField {
                path: "candidates[0].safety_ratings[0].category".to_string(),
            },
        );
        assert_eq!(category, json!("HARM_CATEGORY_SAFE"));
    }

    #[test]
    fn test_no_tool_calls() {
        let context = chat_completion(
            json!({
                "role": "assistant",
                "content": "Just a text response."
            }),
            "stop",
        );
        let builder = builder_for(&context);

        let not_called = resolve(
            &builder,
            AgentAssertion::ToolNotCalled {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(not_called, json!(true));
    }

    #[test]
    fn test_from_context_invalid_json() {
        let context = json!({});
        let outcome = AgentContextBuilder::from_context(&context, None);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(EvaluationError::InvalidProviderResponse)
        ));
    }

    #[test]
    fn test_tool_call_sequence_subsequence() {
        let context = tool_call_completion(vec![
            function_call("c1", "search", "{}"),
            function_call("c2", "filter", "{}"),
            function_call("c3", "rank", "{}"),
            function_call("c4", "respond", "{}"),
        ]);
        let builder = builder_for(&context);

        // `filter` is skipped; the remaining names still appear in order.
        let with_gap = resolve(&builder, sequence(&["search", "rank", "respond"]));
        assert_eq!(with_gap, json!(true));

        // Same names in the wrong order.
        let reversed = resolve(&builder, sequence(&["respond", "search"]));
        assert_eq!(reversed, json!(false));
    }

    #[test]
    fn test_parse_path_segments_errors() {
        // Empty path.
        let empty = AgentContextBuilder::parse_path_segments("");
        assert!(matches!(empty, Err(EvaluationError::EmptyFieldPath)));

        // One byte over the length limit.
        let long_path = "a".repeat(MAX_PATH_LEN + 1);
        let too_long = AgentContextBuilder::parse_path_segments(&long_path);
        assert!(matches!(too_long, Err(EvaluationError::PathTooLong(_))));

        // One segment over the segment limit.
        let many_segments = (0..MAX_PATH_SEGMENTS + 1)
            .map(|i| format!("seg{}", i))
            .collect::<Vec<_>>()
            .join(".");
        let too_many = AgentContextBuilder::parse_path_segments(&many_segments);
        assert!(matches!(
            too_many,
            Err(EvaluationError::TooManyPathSegments(_))
        ));
    }

    #[test]
    fn test_response_content_empty() {
        let context = chat_completion(
            json!({
                "role": "assistant",
                "content": null
            }),
            "stop",
        );
        let builder = builder_for(&context);

        let content = resolve(&builder, AgentAssertion::ResponseContent {});
        assert_eq!(content, Value::Null);
    }

    #[test]
    fn test_partial_match_nested() {
        let context = tool_call_completion(vec![function_call(
            "c1",
            "create_item",
            "{\"item\": {\"name\": \"widget\", \"price\": 9.99, \"tags\": [\"sale\"]}}",
        )]);
        let builder = builder_for(&context);

        // A nested subset matches.
        let nested_subset = resolve(
            &builder,
            called_with("create_item", json!({"item": {"name": "widget"}})),
        );
        assert_eq!(nested_subset, json!(true));

        // A differing nested value does not.
        let nested_mismatch = resolve(
            &builder,
            called_with("create_item", json!({"item": {"name": "gadget"}})),
        );
        assert_eq!(nested_mismatch, json!(false));
    }

    #[test]
    fn test_tool_result_extraction() {
        let context = tool_call_completion(vec![function_call(
            "c1",
            "web_search",
            "{\"query\": \"test\"}",
        )]);
        let builder = builder_for(&context);

        // The fixture carries no tool result, so a known tool resolves to null.
        let known_tool = resolve(
            &builder,
            AgentAssertion::ToolResult {
                name: "web_search".to_string(),
            },
        );
        assert_eq!(known_tool, Value::Null);

        // An unknown tool also resolves to null.
        let unknown_tool = resolve(
            &builder,
            AgentAssertion::ToolResult {
                name: "nonexistent".to_string(),
            },
        );
        assert_eq!(unknown_tool, Value::Null);
    }

    #[test]
    fn test_tool_argument_extraction() {
        let context = tool_call_completion(vec![function_call(
            "call_1",
            "web_search",
            "{\"query\": \"test query\", \"limit\": 10}",
        )]);
        let builder = builder_for(&context);

        let present = resolve(
            &builder,
            AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "query".to_string(),
            },
        );
        assert_eq!(present, json!("test query"));

        let absent = resolve(
            &builder,
            AgentAssertion::ToolArgument {
                name: "web_search".to_string(),
                argument_key: "missing".to_string(),
            },
        );
        assert_eq!(absent, Value::Null);
    }
}