Skip to main content

scouter_evaluate/evaluate/
trace.rs

1use crate::error::EvaluationError;
2use crate::evaluate::agent::AgentContextBuilder;
3use crate::tasks::evaluator::AssertionEvaluator;
4use regex::Regex;
5use scouter_types::genai::{
6    AggregationType, AttributeFilterTask, MultiResponseMode, SpanFilter, SpanStatus, TraceAssertion,
7};
8use scouter_types::sql::TraceSpan;
9use serde_json::{json, Value};
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12use tracing::{debug, warn};
13
14#[derive(Debug, Clone)]
15pub struct TraceContextBuilder {
16    /// We want to share trace spans across multiple evaluations
17    pub(crate) spans: Arc<Vec<TraceSpan>>,
18}
19
20fn collect_filter_regexes<'a>(
21    filter: &'a SpanFilter,
22    out: &mut HashMap<&'a str, Regex>,
23) -> Result<(), EvaluationError> {
24    match filter {
25        SpanFilter::ByNamePattern { pattern } => {
26            if !out.contains_key(pattern.as_str()) {
27                out.insert(pattern.as_str(), Regex::new(pattern)?);
28            }
29        }
30        SpanFilter::And { filters } | SpanFilter::Or { filters } => {
31            for f in filters {
32                collect_filter_regexes(f, out)?;
33            }
34        }
35        _ => {}
36    }
37    Ok(())
38}
39
40impl TraceContextBuilder {
41    pub fn new(spans: Arc<Vec<TraceSpan>>) -> Self {
42        Self { spans }
43    }
44
45    /// Converts trace data into a JSON context that AssertionEvaluator can process
46    pub fn build_context(&self, assertion: &TraceAssertion) -> Result<Value, EvaluationError> {
47        match assertion {
48            TraceAssertion::SpanSequence { span_names } => {
49                Ok(json!(self.match_span_sequence(span_names)?))
50            }
51            TraceAssertion::SpanSet { span_names } => Ok(json!(self.match_span_set(span_names)?)),
52            TraceAssertion::SpanCount { filter } => Ok(json!(self.count_spans(filter)?)),
53            TraceAssertion::SpanExists { filter } => Ok(json!(self.span_exists(filter)?)),
54            TraceAssertion::SpanAttribute {
55                filter,
56                attribute_key,
57            } => self.extract_span_attribute(filter, attribute_key),
58            TraceAssertion::SpanDuration { filter } => self.extract_span_duration(filter),
59            TraceAssertion::SpanAggregation {
60                filter,
61                attribute_key,
62                aggregation,
63            } => self.aggregate_span_attribute(filter, attribute_key, aggregation),
64            TraceAssertion::TraceDuration {} => Ok(json!(self.calculate_trace_duration())),
65            TraceAssertion::TraceSpanCount {} => Ok(json!(self.spans.len())),
66            TraceAssertion::TraceErrorCount {} => Ok(json!(self.count_error_spans())),
67            TraceAssertion::TraceServiceCount {} => Ok(json!(self.count_unique_services())),
68            TraceAssertion::TraceMaxDepth {} => Ok(json!(self.calculate_max_depth())),
69            TraceAssertion::TraceAttribute { attribute_key } => {
70                self.extract_trace_attribute(attribute_key)
71            }
72            TraceAssertion::AttributeFilter { key, task, mode } => {
73                self.evaluate_attribute_filter(key, task, mode)
74            }
75        }
76    }
77
78    // Span filtering logic
79    fn filter_spans(&self, filter: &SpanFilter) -> Result<Vec<&TraceSpan>, EvaluationError> {
80        let mut regexes: HashMap<&str, Regex> = HashMap::new();
81        collect_filter_regexes(filter, &mut regexes)?;
82
83        let mut filtered = Vec::new();
84        for span in self.spans.iter() {
85            if self.matches_filter(span, filter, &regexes)? {
86                filtered.push(span);
87            }
88        }
89
90        debug!(
91            "Filtered spans count: {} with filter {:?}",
92            filtered.len(),
93            filter
94        );
95
96        Ok(filtered)
97    }
98
99    fn matches_filter(
100        &self,
101        span: &TraceSpan,
102        filter: &SpanFilter,
103        regexes: &HashMap<&str, Regex>,
104    ) -> Result<bool, EvaluationError> {
105        match filter {
106            SpanFilter::ByName { name } => Ok(span.span_name == *name),
107
108            SpanFilter::ByNamePattern { pattern } => {
109                let regex = regexes
110                    .get(pattern.as_str())
111                    .expect("regex pre-compiled by collect_filter_regexes");
112                Ok(regex.is_match(&span.span_name))
113            }
114
115            SpanFilter::WithAttribute { key } => {
116                Ok(span.attributes.iter().any(|attr| attr.key == *key))
117            }
118
119            SpanFilter::WithAttributeValue { key, value } => {
120                Ok(span.attributes.iter().any(|attr| {
121                    attr.key == *key && self.attribute_value_matches(&attr.value, &value.0)
122                }))
123            }
124
125            SpanFilter::WithStatus { status } => {
126                Ok(self.map_status_code(span.status_code) == *status)
127            }
128
129            SpanFilter::WithDuration { min_ms, max_ms } => {
130                let duration_f64 = span.duration_ms as f64;
131                let min_ok = min_ms.is_none_or(|min| duration_f64 >= min);
132                let max_ok = max_ms.is_none_or(|max| duration_f64 <= max);
133                Ok(min_ok && max_ok)
134            }
135
136            SpanFilter::And { filters } => {
137                for f in filters {
138                    if !self.matches_filter(span, f, regexes)? {
139                        return Ok(false);
140                    }
141                }
142                Ok(true)
143            }
144
145            SpanFilter::Or { filters } => {
146                for f in filters {
147                    if self.matches_filter(span, f, regexes)? {
148                        return Ok(true);
149                    }
150                }
151                Ok(false)
152            }
153
154            SpanFilter::Sequence { .. } => Err(EvaluationError::InvalidFilter(
155                "Sequence filter not applicable to individual spans".to_string(),
156            )),
157        }
158    }
159
160    /// Get ordered list of span names
161    fn match_span_sequence(&self, span_names: &[String]) -> Result<bool, EvaluationError> {
162        let executed_names = self.get_ordered_span_names()?;
163        Ok(executed_names == span_names)
164    }
165
166    /// Get unique set of span names. Order does not matter.
167    fn match_span_set(&self, span_names: &[String]) -> Result<bool, EvaluationError> {
168        let unique_names: HashSet<_> = self.spans.iter().map(|s| s.span_name.clone()).collect();
169        for name in span_names {
170            if !unique_names.contains(name) {
171                return Ok(false);
172            }
173        }
174        Ok(true)
175    }
176
177    fn count_spans(&self, filter: &SpanFilter) -> Result<usize, EvaluationError> {
178        match filter {
179            SpanFilter::Sequence { names } => self.count_sequence_occurrences(names),
180            _ => Ok(self.filter_spans(filter)?.len()),
181        }
182    }
183
184    /// Count how many times a specific sequence of span names appears consecutively
185    fn count_sequence_occurrences(
186        &self,
187        target_sequence: &[String],
188    ) -> Result<usize, EvaluationError> {
189        if target_sequence.is_empty() {
190            return Ok(0);
191        }
192
193        let all_span_names = self.get_ordered_span_names()?;
194
195        if all_span_names.len() < target_sequence.len() {
196            return Ok(0);
197        }
198
199        Ok(all_span_names
200            .windows(target_sequence.len())
201            .filter(|window| *window == target_sequence)
202            .count())
203    }
204
205    fn get_ordered_span_names(&self) -> Result<Vec<String>, EvaluationError> {
206        let mut ordered_spans: Vec<_> = self.spans.iter().collect();
207        ordered_spans.sort_by_key(|s| s.span_order);
208
209        Ok(ordered_spans
210            .into_iter()
211            .map(|s| s.span_name.clone())
212            .collect())
213    }
214
215    fn span_exists(&self, filter: &SpanFilter) -> Result<bool, EvaluationError> {
216        Ok(!self.filter_spans(filter)?.is_empty())
217    }
218
219    fn extract_span_attribute(
220        &self,
221        filter: &SpanFilter,
222        attribute_key: &str,
223    ) -> Result<Value, EvaluationError> {
224        let filtered_spans = self.filter_spans(filter)?;
225
226        if filtered_spans.is_empty() {
227            return Ok(Value::Null);
228        }
229
230        let mut values: Vec<Value> = filtered_spans
231            .iter()
232            .filter_map(|span| {
233                span.attributes
234                    .iter()
235                    .find(|attr| attr.key == attribute_key)
236                    .map(|attr| attr.value.clone())
237            })
238            .collect();
239
240        if values.len() == 1 {
241            Ok(values.remove(0))
242        } else {
243            Ok(Value::Array(values))
244        }
245    }
246
247    fn extract_span_duration(&self, filter: &SpanFilter) -> Result<Value, EvaluationError> {
248        let filtered_spans = self.filter_spans(filter)?;
249
250        let durations: Vec<i64> = filtered_spans.iter().map(|span| span.duration_ms).collect();
251
252        if durations.len() == 1 {
253            Ok(json!(durations[0]))
254        } else {
255            Ok(json!(durations))
256        }
257    }
258
259    fn aggregate_span_attribute(
260        &self,
261        filter: &SpanFilter,
262        attribute_key: &str,
263        aggregation: &AggregationType,
264    ) -> Result<Value, EvaluationError> {
265        let filtered_spans = self.filter_spans(filter)?;
266
267        match aggregation {
268            AggregationType::Count => {
269                let count = filtered_spans
270                    .iter()
271                    .filter(|span| span.attributes.iter().any(|attr| attr.key == attribute_key))
272                    .count();
273                Ok(json!(count))
274            }
275            _ => {
276                let values: Vec<f64> = filtered_spans
277                    .iter()
278                    .filter_map(|span| {
279                        span.attributes
280                            .iter()
281                            .find(|attr| attr.key == attribute_key)
282                            .and_then(|attr| attr.value.as_f64())
283                    })
284                    .collect();
285
286                if values.is_empty() {
287                    return Ok(Value::Null);
288                }
289
290                let result = match aggregation {
291                    AggregationType::Count => unreachable!(),
292                    AggregationType::Sum => values.iter().sum(),
293                    AggregationType::Average => values.iter().sum::<f64>() / values.len() as f64,
294                    AggregationType::Min => values.iter().copied().fold(f64::INFINITY, f64::min),
295                    AggregationType::Max => {
296                        values.iter().copied().fold(f64::NEG_INFINITY, f64::max)
297                    }
298                    AggregationType::First => values[0],
299                    AggregationType::Last => values[values.len() - 1],
300                };
301
302                Ok(json!(result))
303            }
304        }
305    }
306
307    // Trace-level calculations
308    fn calculate_trace_duration(&self) -> i64 {
309        let min_start = self.spans.iter().map(|s| s.start_time).min();
310        let max_end = self.spans.iter().map(|s| s.end_time).max();
311        match (min_start, max_end) {
312            (Some(start), Some(end)) => (end - start).num_milliseconds().max(0),
313            _ => 0,
314        }
315    }
316
317    fn count_error_spans(&self) -> usize {
318        self.spans
319            .iter()
320            .filter(|s| self.map_status_code(s.status_code) == SpanStatus::Error)
321            .count()
322    }
323
324    fn count_unique_services(&self) -> usize {
325        self.spans
326            .iter()
327            .map(|s| &s.service_name)
328            .collect::<HashSet<_>>()
329            .len()
330    }
331
332    fn calculate_max_depth(&self) -> i32 {
333        self.spans.iter().map(|s| s.depth).max().unwrap_or(0)
334    }
335
336    fn extract_trace_attribute(&self, attribute_key: &str) -> Result<Value, EvaluationError> {
337        let root_span = self
338            .spans
339            .iter()
340            .find(|s| s.depth == 0)
341            .ok_or_else(|| EvaluationError::NoRootSpan)?;
342
343        root_span
344            .attributes
345            .iter()
346            .find(|attr| attr.key == attribute_key)
347            .map(|attr| attr.value.clone())
348            .ok_or_else(|| EvaluationError::AttributeNotFound(attribute_key.to_string()))
349    }
350
351    fn evaluate_attribute_filter(
352        &self,
353        key: &str,
354        task: &AttributeFilterTask,
355        mode: &MultiResponseMode,
356    ) -> Result<Value, EvaluationError> {
357        // Find all spans with the key attribute and extract values
358        let values: Vec<Value> = self
359            .spans
360            .iter()
361            .filter_map(|span| {
362                span.attributes
363                    .iter()
364                    .find(|attr| attr.key == key)
365                    .map(|attr| attr.value.clone())
366            })
367            .collect();
368
369        if values.is_empty() {
370            return Ok(json!(false));
371        }
372
373        let results: Vec<bool> = values
374            .iter()
375            .map(|value| self.evaluate_inner_task(value, task))
376            .collect::<Result<Vec<bool>, EvaluationError>>()?;
377
378        let passed = match mode {
379            MultiResponseMode::Any => results.iter().any(|r| *r),
380            MultiResponseMode::All => results.iter().all(|r| *r),
381        };
382
383        Ok(json!(passed))
384    }
385
386    fn evaluate_inner_task(
387        &self,
388        value: &Value,
389        task: &AttributeFilterTask,
390    ) -> Result<bool, EvaluationError> {
391        // Attribute values stored as JSON strings need parsing
392        let parsed = match value {
393            Value::String(s) => serde_json::from_str(s).unwrap_or_else(|_| {
394                warn!("Attribute value is a non-JSON string; evaluating as raw string value");
395                value.clone()
396            }),
397            _ => value.clone(),
398        };
399
400        match task {
401            AttributeFilterTask::Assertion(assertion_task) => {
402                let result = AssertionEvaluator::evaluate_assertion(&parsed, assertion_task)?;
403                Ok(result.passed)
404            }
405            AttributeFilterTask::AgentAssertion(agent_task) => {
406                let context_builder =
407                    AgentContextBuilder::from_context(&parsed, agent_task.provider.as_ref())?;
408                let resolved = context_builder.build_context(&agent_task.assertion)?;
409                let result = AssertionEvaluator::evaluate_assertion(&resolved, agent_task)?;
410                Ok(result.passed)
411            }
412        }
413    }
414
415    // Helper methods
416    fn map_status_code(&self, code: i32) -> SpanStatus {
417        match code {
418            0 => SpanStatus::Unset,
419            1 => SpanStatus::Ok,
420            2 => SpanStatus::Error,
421            _ => SpanStatus::Unset,
422        }
423    }
424
425    fn attribute_value_matches(&self, attr_value: &Value, expected: &Value) -> bool {
426        attr_value == expected
427    }
428}
429
430#[cfg(test)]
431mod tests {
432    use scouter_types::genai::PyValueWrapper;
433
434    use super::*;
435
436    use potato_head::Provider;
437    use scouter_mocks::{
438        create_adk_agent_trace, create_gemini_agent_trace, create_multi_service_trace,
439        create_nested_trace, create_sequence_pattern_trace, create_simple_trace,
440        create_trace_with_attributes, create_trace_with_errors,
441    };
442    use scouter_types::genai::{
443        AgentAssertion, AgentAssertionTask, AssertionTask, ComparisonOperator, EvaluationTaskType,
444    };
445
446    #[test]
447    fn test_simple_trace_structure() {
448        let spans = create_simple_trace();
449        assert_eq!(spans.len(), 3);
450        assert_eq!(spans[0].span_name, "root");
451        assert_eq!(spans[0].depth, 0);
452        assert_eq!(
453            spans[1].parent_span_id,
454            Some("7370616e5f300000".to_string()) // span_0 as hex from span_id
455        );
456    }
457
458    #[test]
459    fn test_nested_trace_depth() {
460        let spans = create_nested_trace();
461        let builder = TraceContextBuilder::new(Arc::new(spans));
462        assert_eq!(builder.calculate_max_depth(), 2);
463    }
464
465    #[test]
466    fn test_error_counting() {
467        let spans = create_trace_with_errors();
468        let builder = TraceContextBuilder::new(Arc::new(spans));
469        assert_eq!(builder.count_error_spans(), 1);
470    }
471
472    #[test]
473    fn test_attribute_filtering() {
474        let spans = create_trace_with_attributes();
475        let builder = TraceContextBuilder::new(Arc::new(spans));
476
477        let filter = SpanFilter::WithAttribute {
478            key: "model".to_string(),
479        };
480
481        let result = builder.span_exists(&filter).unwrap();
482        assert!(result);
483    }
484
485    #[test]
486    fn test_sequence_pattern_detection() {
487        let spans = create_sequence_pattern_trace();
488        let builder = TraceContextBuilder::new(Arc::new(spans));
489
490        let filter = SpanFilter::Sequence {
491            names: vec!["call_tool".to_string(), "run_agent".to_string()],
492        };
493
494        let count = builder.count_spans(&filter).unwrap();
495        assert_eq!(count, 2);
496    }
497
498    #[test]
499    fn test_multi_service_trace() {
500        let spans = create_multi_service_trace();
501        let builder = TraceContextBuilder::new(Arc::new(spans));
502        assert_eq!(builder.count_unique_services(), 3);
503    }
504
505    #[test]
506    fn test_aggregation_with_numeric_attributes() {
507        let spans = create_trace_with_attributes();
508        let builder = TraceContextBuilder::new(Arc::new(spans));
509
510        let filter = SpanFilter::WithAttribute {
511            key: "tokens.input".to_string(),
512        };
513
514        let result = builder
515            .aggregate_span_attribute(&filter, "tokens.input", &AggregationType::Sum)
516            .unwrap();
517
518        assert_eq!(result, json!(150.0));
519    }
520
521    #[test]
522    fn test_trace_assertion_span_sequence_evaluation() {
523        let spans = create_simple_trace();
524        let builder = TraceContextBuilder::new(Arc::new(spans));
525
526        let assertion = TraceAssertion::SpanSequence {
527            span_names: vec![
528                "root".to_string(),
529                "child_1".to_string(),
530                "child_2".to_string(),
531            ],
532        };
533
534        let context = builder.build_context(&assertion).unwrap();
535        assert_eq!(context, json!(true));
536    }
537
538    #[test]
539    fn test_trace_assertion_span_set_evaluation() {
540        let spans = create_simple_trace();
541        let builder = TraceContextBuilder::new(Arc::new(spans));
542
543        let assertion = TraceAssertion::SpanSet {
544            span_names: vec![
545                "root".to_string(),
546                "child_1".to_string(),
547                "child_2".to_string(),
548            ],
549        };
550
551        let context = builder.build_context(&assertion).unwrap();
552        assert_eq!(context, json!(true));
553    }
554
555    #[test]
556    fn test_trace_assertion_span_count() {
557        let spans = create_simple_trace();
558        let builder = TraceContextBuilder::new(Arc::new(spans));
559
560        let filter = SpanFilter::ByName {
561            name: "child_1".to_string(),
562        };
563
564        let assertion = TraceAssertion::SpanCount { filter };
565
566        let context = builder.build_context(&assertion).unwrap();
567        assert_eq!(context, json!(1));
568
569        // Test with name pattern
570        let filter_pattern = SpanFilter::ByNamePattern {
571            pattern: "^child_.*".to_string(),
572        };
573
574        let assertion_pattern = TraceAssertion::SpanCount {
575            filter: filter_pattern,
576        };
577        let context_pattern = builder.build_context(&assertion_pattern).unwrap();
578        assert_eq!(context_pattern, json!(2));
579
580        // Test span count with attribute filter
581        let trace_with_attributes = create_trace_with_attributes();
582        let builder_attr = TraceContextBuilder::new(Arc::new(trace_with_attributes));
583
584        let filter_attr = SpanFilter::WithAttribute {
585            key: "model".to_string(),
586        };
587
588        let assertion_attr = TraceAssertion::SpanCount {
589            filter: filter_attr,
590        };
591        let context_attr = builder_attr.build_context(&assertion_attr).unwrap();
592        assert_eq!(context_attr, json!(1));
593
594        // Test span count with attribute value filter
595        let filter_attr_value = SpanFilter::WithAttributeValue {
596            key: "http.method".to_string(),
597            value: PyValueWrapper(json!("POST")),
598        };
599
600        let assertion_attr_value = TraceAssertion::SpanCount {
601            filter: filter_attr_value,
602        };
603        let context_attr_value = builder_attr.build_context(&assertion_attr_value).unwrap();
604        assert_eq!(context_attr_value, json!(1));
605
606        // test span count with status filter
607        let filter_status = SpanFilter::WithStatus {
608            status: SpanStatus::Ok,
609        };
610        let assertion_status = TraceAssertion::SpanCount {
611            filter: filter_status,
612        };
613        let context_status = builder_attr.build_context(&assertion_status).unwrap();
614        assert_eq!(context_status, json!(2));
615
616        // test duration filter
617        let filter_duration = SpanFilter::WithDuration {
618            min_ms: Some(80.0),
619            max_ms: Some(120.0),
620        };
621        let assertion_duration = TraceAssertion::SpanCount {
622            filter: filter_duration,
623        };
624        let context_duration = builder_attr.build_context(&assertion_duration).unwrap();
625        assert_eq!(context_duration, json!(1));
626
627        // test complex AND filter
628        let filter_and = SpanFilter::And {
629            filters: vec![
630                SpanFilter::WithAttribute {
631                    key: "http.method".to_string(),
632                },
633                SpanFilter::WithStatus {
634                    status: SpanStatus::Ok,
635                },
636            ],
637        };
638        let assertion_and = TraceAssertion::SpanCount { filter: filter_and };
639        let context_and = builder_attr.build_context(&assertion_and).unwrap();
640        assert_eq!(context_and, json!(1));
641
642        // test complex OR filter
643        let filter_or = SpanFilter::Or {
644            filters: vec![
645                SpanFilter::WithAttributeValue {
646                    key: "http.method".to_string(),
647                    value: PyValueWrapper(json!("GET")),
648                },
649                SpanFilter::WithAttributeValue {
650                    key: "model".to_string(),
651                    value: PyValueWrapper(json!("gpt-4")),
652                },
653            ],
654        };
655        let assertion_or = TraceAssertion::SpanCount { filter: filter_or };
656        let context_or = builder_attr.build_context(&assertion_or).unwrap();
657        assert_eq!(context_or, json!(1));
658    }
659
660    #[test]
661    fn test_span_exists() {
662        let spans = create_simple_trace();
663        let builder = TraceContextBuilder::new(Arc::new(spans));
664        let filter = SpanFilter::ByName {
665            name: "child_1".to_string(),
666        };
667        let assertion = TraceAssertion::SpanExists { filter };
668        let context = builder.build_context(&assertion).unwrap();
669        assert_eq!(context, json!(true));
670    }
671
672    #[test]
673    fn test_span_attribute() {
674        // test model
675        let spans = create_trace_with_attributes();
676        let builder = TraceContextBuilder::new(Arc::new(spans));
677        let filter = SpanFilter::ByName {
678            name: "api_call".to_string(),
679        };
680        let assertion = TraceAssertion::SpanAttribute {
681            filter,
682            attribute_key: "model".to_string(),
683        };
684        let context = builder.build_context(&assertion).unwrap();
685        assert_eq!(context, json!("gpt-4"));
686
687        // check response
688        let spans = create_trace_with_attributes();
689        let builder = TraceContextBuilder::new(Arc::new(spans));
690        let filter = SpanFilter::ByName {
691            name: "api_call".to_string(),
692        };
693        let assertion = TraceAssertion::SpanAttribute {
694            filter,
695            attribute_key: "response".to_string(),
696        };
697        let context = builder.build_context(&assertion).unwrap();
698        assert_eq!(context, json!({"success": true, "data": {"id": 12345}}));
699    }
700
701    #[test]
702    fn test_span_attribute_aggregation() {
703        let spans = create_trace_with_attributes();
704        let builder = TraceContextBuilder::new(Arc::new(spans));
705        let filter = SpanFilter::ByName {
706            name: "api_call".to_string(),
707        };
708        let assertion = TraceAssertion::SpanAggregation {
709            filter,
710            attribute_key: "tokens.output".to_string(),
711            aggregation: AggregationType::Sum,
712        };
713        let context = builder.build_context(&assertion).unwrap();
714        assert_eq!(context, json!(300.0));
715    }
716
717    /// check common sequence patterns
718    #[test]
719    fn test_sequence_pattern_counting() {
720        // count how often "call_tool" followed by "run_agent" occurs
721        let spans = create_sequence_pattern_trace();
722        let builder = TraceContextBuilder::new(Arc::new(spans));
723        let filter = SpanFilter::Sequence {
724            names: vec!["call_tool".to_string(), "run_agent".to_string()],
725        };
726        let assertion = TraceAssertion::SpanCount { filter };
727        let context = builder.build_context(&assertion).unwrap();
728        assert_eq!(context, json!(2));
729
730        // count how often "call_tool" occurs
731        let spans = create_sequence_pattern_trace();
732        let builder = TraceContextBuilder::new(Arc::new(spans));
733        let filter = SpanFilter::ByName {
734            name: "call_tool".to_string(),
735        };
736        let assertion = TraceAssertion::SpanCount { filter };
737        let context = builder.build_context(&assertion).unwrap();
738        assert_eq!(context, json!(2));
739    }
740
741    #[test]
742    fn test_attribute_filter_agent_assertion_any() {
743        let spans = create_gemini_agent_trace();
744        let builder = TraceContextBuilder::new(Arc::new(spans));
745
746        // Check that at least one span has a tool call to "transfer_to_agent"
747        let assertion = TraceAssertion::AttributeFilter {
748            key: "gen_ai.response".to_string(),
749            task: AttributeFilterTask::AgentAssertion(AgentAssertionTask {
750                id: "inner_tool".to_string(),
751                assertion: AgentAssertion::ToolCalled {
752                    name: "transfer_to_agent".to_string(),
753                },
754                operator: ComparisonOperator::Equals,
755                expected_value: json!(true),
756                description: None,
757                depends_on: vec![],
758                task_type: EvaluationTaskType::AgentAssertion,
759                result: None,
760                condition: false,
761                provider: None,
762            }),
763            mode: MultiResponseMode::Any,
764        };
765
766        let context = builder.build_context(&assertion).unwrap();
767        assert_eq!(context, json!(true));
768    }
769
770    #[test]
771    fn test_attribute_filter_agent_assertion_all() {
772        let spans = create_gemini_agent_trace();
773        let builder = TraceContextBuilder::new(Arc::new(spans));
774
775        // Check that ALL spans have total tokens < 5000
776        let assertion = TraceAssertion::AttributeFilter {
777            key: "gen_ai.response".to_string(),
778            task: AttributeFilterTask::AgentAssertion(AgentAssertionTask {
779                id: "inner_tokens".to_string(),
780                assertion: AgentAssertion::ResponseTotalTokens {},
781                operator: ComparisonOperator::LessThan,
782                expected_value: json!(5000),
783                description: None,
784                depends_on: vec![],
785                task_type: EvaluationTaskType::AgentAssertion,
786                result: None,
787                condition: false,
788                provider: None,
789            }),
790            mode: MultiResponseMode::All,
791        };
792
793        let context = builder.build_context(&assertion).unwrap();
794        assert_eq!(context, json!(true));
795    }
796
797    #[test]
798    fn test_attribute_filter_assertion_raw_value() {
799        let spans = create_gemini_agent_trace();
800        let builder = TraceContextBuilder::new(Arc::new(spans));
801
802        // Check that all gen_ai.response values have finishReason == "STOP"
803        let assertion = TraceAssertion::AttributeFilter {
804            key: "gen_ai.response".to_string(),
805            task: AttributeFilterTask::Assertion(AssertionTask {
806                id: "inner_finish".to_string(),
807                context_path: Some("candidates[0].finishReason".to_string()),
808                item_context_path: None,
809                operator: ComparisonOperator::Equals,
810                expected_value: json!("STOP"),
811                description: None,
812                depends_on: vec![],
813                task_type: EvaluationTaskType::Assertion,
814                result: None,
815                condition: false,
816            }),
817            mode: MultiResponseMode::All,
818        };
819
820        let context = builder.build_context(&assertion).unwrap();
821        assert_eq!(context, json!(true));
822    }
823
824    #[test]
825    fn test_attribute_filter_no_matching_spans() {
826        let spans = create_gemini_agent_trace();
827        let builder = TraceContextBuilder::new(Arc::new(spans));
828
829        // No span has "nonexistent" attribute
830        let assertion = TraceAssertion::AttributeFilter {
831            key: "nonexistent".to_string(),
832            task: AttributeFilterTask::Assertion(AssertionTask {
833                id: "inner".to_string(),
834                context_path: None,
835                item_context_path: None,
836                operator: ComparisonOperator::Equals,
837                expected_value: json!(true),
838                description: None,
839                depends_on: vec![],
840                task_type: EvaluationTaskType::Assertion,
841                result: None,
842                condition: false,
843            }),
844            mode: MultiResponseMode::Any,
845        };
846
847        let context = builder.build_context(&assertion).unwrap();
848        assert_eq!(context, json!(false));
849    }
850
851    #[test]
852    fn test_attribute_filter_all_mode_fails_when_one_fails() {
853        let spans = create_gemini_agent_trace();
854        let builder = TraceContextBuilder::new(Arc::new(spans));
855
856        // Only one span has tool calls, so "All" mode should fail for ToolCalled
857        let assertion = TraceAssertion::AttributeFilter {
858            key: "gen_ai.response".to_string(),
859            task: AttributeFilterTask::AgentAssertion(AgentAssertionTask {
860                id: "inner_tool".to_string(),
861                assertion: AgentAssertion::ToolCalled {
862                    name: "transfer_to_agent".to_string(),
863                },
864                operator: ComparisonOperator::Equals,
865                expected_value: json!(true),
866                description: None,
867                depends_on: vec![],
868                task_type: EvaluationTaskType::AgentAssertion,
869                result: None,
870                condition: false,
871                provider: None,
872            }),
873            mode: MultiResponseMode::All,
874        };
875
876        let context = builder.build_context(&assertion).unwrap();
877        // Only one of two spans has this tool call, so All fails
878        assert_eq!(context, json!(false));
879    }
880
881    #[test]
882    fn test_adk_attribute_filter_agent_assertion_any() {
883        let spans = create_adk_agent_trace();
884        let builder = TraceContextBuilder::new(Arc::new(spans));
885
886        // Check that at least one ADK span has a tool call to "transfer_to_agent"
887        let assertion = TraceAssertion::AttributeFilter {
888            key: "gen_ai.response".to_string(),
889            task: AttributeFilterTask::AgentAssertion(AgentAssertionTask {
890                id: "inner_tool_adk".to_string(),
891                assertion: AgentAssertion::ToolCalled {
892                    name: "transfer_to_agent".to_string(),
893                },
894                operator: ComparisonOperator::Equals,
895                expected_value: json!(true),
896                description: None,
897                depends_on: vec![],
898                task_type: EvaluationTaskType::AgentAssertion,
899                result: None,
900                condition: false,
901                provider: Some(Provider::GoogleAdk),
902            }),
903            mode: MultiResponseMode::Any,
904        };
905
906        let context = builder.build_context(&assertion).unwrap();
907        assert_eq!(context, json!(true));
908    }
909
910    #[test]
911    fn test_adk_attribute_filter_assertion_raw_value() {
912        let spans = create_adk_agent_trace();
913        let builder = TraceContextBuilder::new(Arc::new(spans));
914
915        // Check that all ADK gen_ai.response values have finish_reason == "STOP"
916        let assertion = TraceAssertion::AttributeFilter {
917            key: "gen_ai.response".to_string(),
918            task: AttributeFilterTask::Assertion(AssertionTask {
919                id: "inner_finish_adk".to_string(),
920                context_path: Some("finish_reason".to_string()),
921                item_context_path: None,
922                operator: ComparisonOperator::Equals,
923                expected_value: json!("STOP"),
924                description: None,
925                depends_on: vec![],
926                task_type: EvaluationTaskType::Assertion,
927                result: None,
928                condition: false,
929            }),
930            mode: MultiResponseMode::All,
931        };
932
933        let context = builder.build_context(&assertion).unwrap();
934        assert_eq!(context, json!(true));
935    }
936
937    #[test]
938    fn test_adk_attribute_filter_agent_assertion_all() {
939        let spans = create_adk_agent_trace();
940        let builder = TraceContextBuilder::new(Arc::new(spans));
941
942        // Check that ALL ADK spans have model_version containing "gemini"
943        let assertion = TraceAssertion::AttributeFilter {
944            key: "gen_ai.response".to_string(),
945            task: AttributeFilterTask::AgentAssertion(AgentAssertionTask {
946                id: "inner_model_adk".to_string(),
947                assertion: AgentAssertion::ResponseModel {},
948                operator: ComparisonOperator::Contains,
949                expected_value: json!("gemini"),
950                description: None,
951                depends_on: vec![],
952                task_type: EvaluationTaskType::AgentAssertion,
953                result: None,
954                condition: false,
955                provider: Some(Provider::GoogleAdk),
956            }),
957            mode: MultiResponseMode::All,
958        };
959
960        let context = builder.build_context(&assertion).unwrap();
961        assert_eq!(context, json!(true));
962    }
963
964    #[test]
965    fn test_invalid_regex_returns_error() {
966        let spans = create_simple_trace();
967        let builder = TraceContextBuilder::new(Arc::new(spans));
968        let result = builder.build_context(&TraceAssertion::SpanCount {
969            filter: SpanFilter::ByNamePattern {
970                pattern: "[invalid".to_string(),
971            },
972        });
973        assert!(result.is_err());
974    }
975
976    #[test]
977    fn test_sequence_filter_in_span_exists_returns_invalid_filter_error() {
978        let spans = create_simple_trace();
979        let builder = TraceContextBuilder::new(Arc::new(spans));
980        let result = builder.build_context(&TraceAssertion::SpanExists {
981            filter: SpanFilter::Sequence {
982                names: vec!["root".to_string()],
983            },
984        });
985        assert!(matches!(result, Err(EvaluationError::InvalidFilter(_))));
986    }
987
988    #[test]
989    fn test_extract_trace_attribute_no_root_span() {
990        let mut spans = create_simple_trace();
991        for span in &mut spans {
992            span.depth = 1;
993        }
994        let builder = TraceContextBuilder::new(Arc::new(spans));
995        let result = builder.build_context(&TraceAssertion::TraceAttribute {
996            attribute_key: "model".to_string(),
997        });
998        assert!(matches!(result, Err(EvaluationError::NoRootSpan)));
999    }
1000
1001    #[test]
1002    fn test_extract_trace_attribute_key_not_found() {
1003        let spans = create_simple_trace();
1004        let builder = TraceContextBuilder::new(Arc::new(spans));
1005        let result = builder.build_context(&TraceAssertion::TraceAttribute {
1006            attribute_key: "nonexistent_key_xyz".to_string(),
1007        });
1008        assert!(matches!(result, Err(EvaluationError::AttributeNotFound(_))));
1009    }
1010
1011    #[test]
1012    fn test_aggregate_null_when_no_numeric_values() {
1013        let spans = create_simple_trace();
1014        let builder = TraceContextBuilder::new(Arc::new(spans));
1015        let filter = SpanFilter::ByName {
1016            name: "root".to_string(),
1017        };
1018        let result = builder
1019            .aggregate_span_attribute(&filter, "nonexistent_numeric", &AggregationType::Sum)
1020            .unwrap();
1021        assert_eq!(result, Value::Null);
1022    }
1023
1024    #[test]
1025    fn test_extract_span_attribute_multi_span_returns_array() {
1026        // create_gemini_agent_trace has 2 spans with gen_ai.response
1027        let spans = create_gemini_agent_trace();
1028        let builder = TraceContextBuilder::new(Arc::new(spans));
1029        let filter = SpanFilter::WithAttribute {
1030            key: "gen_ai.response".to_string(),
1031        };
1032        let result = builder
1033            .extract_span_attribute(&filter, "gen_ai.response")
1034            .unwrap();
1035        assert!(result.is_array());
1036    }
1037}