Skip to main content

scouter_evaluate/evaluate/
trace.rs

1use crate::error::EvaluationError;
2use regex::Regex;
// Core logic for evaluating trace spans as part of a `TraceAssertionTask`.
//
// Trace spans are supplied as `scouter_types::sql::TraceSpan` values.
6use scouter_types::genai::{AggregationType, SpanFilter, SpanStatus, TraceAssertion};
7use scouter_types::sql::TraceSpan;
8use serde_json::{json, Value};
9use std::collections::HashSet;
10use std::sync::Arc;
11use tracing::debug;
12
/// Builds JSON evaluation contexts from the spans of a trace.
///
/// Each [`TraceAssertion`] variant is translated into a `serde_json::Value`
/// by [`TraceContextBuilder::build_context`]. The spans are held behind an
/// `Arc`, so cloning the builder only bumps a reference count.
#[derive(Debug, Clone)]
pub struct TraceContextBuilder {
    /// We want to share trace spans across multiple evaluations
    pub(crate) spans: Arc<Vec<TraceSpan>>,
}
18
19impl TraceContextBuilder {
20    pub fn new(spans: Arc<Vec<TraceSpan>>) -> Self {
21        Self { spans }
22    }
23
24    /// Converts trace data into a JSON context that AssertionEvaluator can process
25    pub fn build_context(&self, assertion: &TraceAssertion) -> Result<Value, EvaluationError> {
26        match assertion {
27            TraceAssertion::SpanSequence { span_names } => {
28                Ok(json!(self.match_span_sequence(span_names)?))
29            }
30            TraceAssertion::SpanSet { span_names } => Ok(json!(self.match_span_set(span_names)?)),
31            TraceAssertion::SpanCount { filter } => Ok(json!(self.count_spans(filter)?)),
32            TraceAssertion::SpanExists { filter } => Ok(json!(self.span_exists(filter)?)),
33            TraceAssertion::SpanAttribute {
34                filter,
35                attribute_key,
36            } => self.extract_span_attribute(filter, attribute_key),
37            TraceAssertion::SpanDuration { filter } => self.extract_span_duration(filter),
38            TraceAssertion::SpanAggregation {
39                filter,
40                attribute_key,
41                aggregation,
42            } => self.aggregate_span_attribute(filter, attribute_key, aggregation),
43            TraceAssertion::TraceDuration {} => Ok(json!(self.calculate_trace_duration())),
44            TraceAssertion::TraceSpanCount {} => Ok(json!(self.spans.len())),
45            TraceAssertion::TraceErrorCount {} => Ok(json!(self.count_error_spans())),
46            TraceAssertion::TraceServiceCount {} => Ok(json!(self.count_unique_services())),
47            TraceAssertion::TraceMaxDepth {} => Ok(json!(self.calculate_max_depth())),
48            TraceAssertion::TraceAttribute { attribute_key } => {
49                self.extract_trace_attribute(attribute_key)
50            }
51        }
52    }
53
54    // Span filtering logic
55    fn filter_spans(&self, filter: &SpanFilter) -> Result<Vec<&TraceSpan>, EvaluationError> {
56        let mut filtered = Vec::new();
57
58        for span in self.spans.iter() {
59            if self.matches_filter(span, filter)? {
60                filtered.push(span);
61            }
62        }
63
64        debug!(
65            "Filtered spans count: {} with filter {:?}",
66            filtered.len(),
67            filter
68        );
69
70        Ok(filtered)
71    }
72
73    fn matches_filter(
74        &self,
75        span: &TraceSpan,
76        filter: &SpanFilter,
77    ) -> Result<bool, EvaluationError> {
78        match filter {
79            SpanFilter::ByName { name } => Ok(span.span_name == *name),
80
81            SpanFilter::ByNamePattern { pattern } => {
82                let regex = Regex::new(pattern)?;
83                Ok(regex.is_match(&span.span_name))
84            }
85
86            SpanFilter::WithAttribute { key } => {
87                Ok(span.attributes.iter().any(|attr| attr.key == *key))
88            }
89
90            SpanFilter::WithAttributeValue { key, value } => {
91                Ok(span.attributes.iter().any(|attr| {
92                    attr.key == *key && self.attribute_value_matches(&attr.value, &value.0)
93                }))
94            }
95
96            SpanFilter::WithStatus { status } => {
97                Ok(self.map_status_code(span.status_code) == *status)
98            }
99
100            SpanFilter::WithDuration { min_ms, max_ms } => {
101                let duration_f64 = span.duration_ms as f64;
102                let min_ok = min_ms.is_none_or(|min| duration_f64 >= min);
103                let max_ok = max_ms.is_none_or(|max| duration_f64 <= max);
104                Ok(min_ok && max_ok)
105            }
106
107            SpanFilter::And { filters } => {
108                for f in filters {
109                    if !self.matches_filter(span, f)? {
110                        return Ok(false);
111                    }
112                }
113                Ok(true)
114            }
115
116            SpanFilter::Or { filters } => {
117                for f in filters {
118                    if self.matches_filter(span, f)? {
119                        return Ok(true);
120                    }
121                }
122                Ok(false)
123            }
124
125            SpanFilter::Sequence { .. } => Err(EvaluationError::InvalidFilter(
126                "Sequence filter not applicable to individual spans".to_string(),
127            )),
128        }
129    }
130
131    /// Get ordered list of span names
132    fn match_span_sequence(&self, span_names: &[String]) -> Result<bool, EvaluationError> {
133        let executed_names = self.get_ordered_span_names()?;
134        Ok(executed_names == span_names)
135    }
136
137    /// Get unique set of span names. Order does not matter.
138    fn match_span_set(&self, span_names: &[String]) -> Result<bool, EvaluationError> {
139        let unique_names: HashSet<_> = self.spans.iter().map(|s| s.span_name.clone()).collect();
140        for name in span_names {
141            if !unique_names.contains(name) {
142                return Ok(false);
143            }
144        }
145        Ok(true)
146    }
147
148    fn count_spans(&self, filter: &SpanFilter) -> Result<usize, EvaluationError> {
149        match filter {
150            SpanFilter::Sequence { names } => self.count_sequence_occurrences(names),
151            _ => Ok(self.filter_spans(filter)?.len()),
152        }
153    }
154
155    /// Count how many times a specific sequence of span names appears consecutively
156    fn count_sequence_occurrences(
157        &self,
158        target_sequence: &[String],
159    ) -> Result<usize, EvaluationError> {
160        if target_sequence.is_empty() {
161            return Ok(0);
162        }
163
164        let all_span_names = self.get_ordered_span_names()?;
165
166        if all_span_names.len() < target_sequence.len() {
167            return Ok(0);
168        }
169
170        Ok(all_span_names
171            .windows(target_sequence.len())
172            .filter(|window| *window == target_sequence)
173            .count())
174    }
175
176    fn get_ordered_span_names(&self) -> Result<Vec<String>, EvaluationError> {
177        let mut ordered_spans: Vec<_> = self.spans.iter().collect();
178        ordered_spans.sort_by_key(|s| s.span_order);
179
180        Ok(ordered_spans
181            .into_iter()
182            .map(|s| s.span_name.clone())
183            .collect())
184    }
185
186    fn span_exists(&self, filter: &SpanFilter) -> Result<bool, EvaluationError> {
187        Ok(!self.filter_spans(filter)?.is_empty())
188    }
189
190    fn extract_span_attribute(
191        &self,
192        filter: &SpanFilter,
193        attribute_key: &str,
194    ) -> Result<Value, EvaluationError> {
195        let filtered_spans = self.filter_spans(filter)?;
196
197        if filtered_spans.is_empty() {
198            return Ok(Value::Null);
199        }
200
201        let values: Vec<Value> = filtered_spans
202            .iter()
203            .filter_map(|span| {
204                span.attributes
205                    .iter()
206                    .find(|attr| attr.key == attribute_key)
207                    .map(|attr| attr.value.clone())
208            })
209            .collect();
210
211        if values.len() == 1 {
212            Ok(values[0].clone())
213        } else {
214            Ok(Value::Array(values))
215        }
216    }
217
218    fn extract_span_duration(&self, filter: &SpanFilter) -> Result<Value, EvaluationError> {
219        let filtered_spans = self.filter_spans(filter)?;
220
221        let durations: Vec<i64> = filtered_spans.iter().map(|span| span.duration_ms).collect();
222
223        if durations.len() == 1 {
224            Ok(json!(durations[0]))
225        } else {
226            Ok(json!(durations))
227        }
228    }
229
230    fn aggregate_span_attribute(
231        &self,
232        filter: &SpanFilter,
233        attribute_key: &str,
234        aggregation: &AggregationType,
235    ) -> Result<Value, EvaluationError> {
236        let filtered_spans = self.filter_spans(filter)?;
237
238        match aggregation {
239            AggregationType::Count => {
240                let count = filtered_spans
241                    .iter()
242                    .filter(|span| span.attributes.iter().any(|attr| attr.key == attribute_key))
243                    .count();
244                Ok(json!(count))
245            }
246            _ => {
247                let values: Vec<f64> = filtered_spans
248                    .iter()
249                    .filter_map(|span| {
250                        span.attributes
251                            .iter()
252                            .find(|attr| attr.key == attribute_key)
253                            .and_then(|attr| attr.value.as_f64())
254                    })
255                    .collect();
256
257                if values.is_empty() {
258                    return Ok(Value::Null);
259                }
260
261                let result = match aggregation {
262                    AggregationType::Count => unreachable!(),
263                    AggregationType::Sum => values.iter().sum(),
264                    AggregationType::Average => values.iter().sum::<f64>() / values.len() as f64,
265                    AggregationType::Min => values.iter().copied().fold(f64::INFINITY, f64::min),
266                    AggregationType::Max => {
267                        values.iter().copied().fold(f64::NEG_INFINITY, f64::max)
268                    }
269                    AggregationType::First => values[0],
270                    AggregationType::Last => values[values.len() - 1],
271                };
272
273                Ok(json!(result))
274            }
275        }
276    }
277
278    // Trace-level calculations
279    fn calculate_trace_duration(&self) -> i64 {
280        self.spans.iter().map(|s| s.duration_ms).max().unwrap_or(0)
281    }
282
283    fn count_error_spans(&self) -> usize {
284        self.spans
285            .iter()
286            .filter(|s| s.status_code == 2) // Error status
287            .count()
288    }
289
290    fn count_unique_services(&self) -> usize {
291        self.spans
292            .iter()
293            .map(|s| &s.service_name)
294            .collect::<HashSet<_>>()
295            .len()
296    }
297
298    fn calculate_max_depth(&self) -> i32 {
299        self.spans.iter().map(|s| s.depth).max().unwrap_or(0)
300    }
301
302    fn extract_trace_attribute(&self, attribute_key: &str) -> Result<Value, EvaluationError> {
303        let root_span = self
304            .spans
305            .iter()
306            .find(|s| s.depth == 0)
307            .ok_or_else(|| EvaluationError::NoRootSpan)?;
308
309        root_span
310            .attributes
311            .iter()
312            .find(|attr| attr.key == attribute_key)
313            .map(|attr| attr.value.clone())
314            .ok_or_else(|| EvaluationError::AttributeNotFound(attribute_key.to_string()))
315    }
316
317    // Helper methods
318    fn map_status_code(&self, code: i32) -> SpanStatus {
319        match code {
320            0 => SpanStatus::Unset,
321            1 => SpanStatus::Ok,
322            2 => SpanStatus::Error,
323            _ => SpanStatus::Unset,
324        }
325    }
326
327    fn attribute_value_matches(&self, attr_value: &Value, expected: &Value) -> bool {
328        attr_value == expected
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    use scouter_mocks::{
        create_multi_service_trace, create_nested_trace, create_sequence_pattern_trace,
        create_simple_trace, create_trace_with_attributes, create_trace_with_errors,
    };
    use scouter_types::genai::PyValueWrapper;

    /// Wraps a mock trace in a builder.
    fn builder_for(spans: Vec<TraceSpan>) -> TraceContextBuilder {
        TraceContextBuilder::new(Arc::new(spans))
    }

    /// Converts string literals into owned span names.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Filter selecting spans by exact name.
    fn by_name(name: &str) -> SpanFilter {
        SpanFilter::ByName {
            name: name.to_string(),
        }
    }

    /// Filter selecting spans that carry an attribute key.
    fn has_attribute(key: &str) -> SpanFilter {
        SpanFilter::WithAttribute {
            key: key.to_string(),
        }
    }

    /// Filter selecting spans whose attribute equals a JSON value.
    fn attribute_equals(key: &str, value: Value) -> SpanFilter {
        SpanFilter::WithAttributeValue {
            key: key.to_string(),
            value: PyValueWrapper(value),
        }
    }

    /// Evaluates a `SpanCount` assertion for `filter` through `build_context`.
    fn span_count(builder: &TraceContextBuilder, filter: SpanFilter) -> Value {
        builder
            .build_context(&TraceAssertion::SpanCount { filter })
            .unwrap()
    }

    #[test]
    fn test_simple_trace_structure() {
        let spans = create_simple_trace();

        assert_eq!(spans.len(), 3);
        assert_eq!(spans[0].span_name, "root");
        assert_eq!(spans[0].depth, 0);
        // span_0 as hex from span_id
        assert_eq!(
            spans[1].parent_span_id.as_deref(),
            Some("7370616e5f300000")
        );
    }

    #[test]
    fn test_nested_trace_depth() {
        let builder = builder_for(create_nested_trace());
        assert_eq!(builder.calculate_max_depth(), 2);
    }

    #[test]
    fn test_error_counting() {
        let builder = builder_for(create_trace_with_errors());
        assert_eq!(builder.count_error_spans(), 1);
    }

    #[test]
    fn test_attribute_filtering() {
        let builder = builder_for(create_trace_with_attributes());
        let exists = builder.span_exists(&has_attribute("model")).unwrap();
        assert!(exists);
    }

    #[test]
    fn test_sequence_pattern_detection() {
        let builder = builder_for(create_sequence_pattern_trace());
        let filter = SpanFilter::Sequence {
            names: owned(&["call_tool", "run_agent"]),
        };

        assert_eq!(builder.count_spans(&filter).unwrap(), 2);
    }

    #[test]
    fn test_multi_service_trace() {
        let builder = builder_for(create_multi_service_trace());
        assert_eq!(builder.count_unique_services(), 3);
    }

    #[test]
    fn test_aggregation_with_numeric_attributes() {
        let builder = builder_for(create_trace_with_attributes());

        let total = builder
            .aggregate_span_attribute(
                &has_attribute("tokens.input"),
                "tokens.input",
                &AggregationType::Sum,
            )
            .unwrap();

        assert_eq!(total, json!(150.0));
    }

    #[test]
    fn test_trace_assertion_span_sequence_evaluation() {
        let builder = builder_for(create_simple_trace());
        let assertion = TraceAssertion::SpanSequence {
            span_names: owned(&["root", "child_1", "child_2"]),
        };

        assert_eq!(builder.build_context(&assertion).unwrap(), json!(true));
    }

    #[test]
    fn test_trace_assertion_span_set_evaluation() {
        let builder = builder_for(create_simple_trace());
        let assertion = TraceAssertion::SpanSet {
            span_names: owned(&["root", "child_1", "child_2"]),
        };

        assert_eq!(builder.build_context(&assertion).unwrap(), json!(true));
    }

    #[test]
    fn test_trace_assertion_span_count() {
        let simple = builder_for(create_simple_trace());

        // Exact name.
        assert_eq!(span_count(&simple, by_name("child_1")), json!(1));

        // Test with name pattern
        let pattern = SpanFilter::ByNamePattern {
            pattern: "^child_.*".to_string(),
        };
        assert_eq!(span_count(&simple, pattern), json!(2));

        // Test span count with attribute filter
        let with_attrs = builder_for(create_trace_with_attributes());
        assert_eq!(span_count(&with_attrs, has_attribute("model")), json!(1));

        // Test span count with attribute value filter
        assert_eq!(
            span_count(&with_attrs, attribute_equals("http.method", json!("POST"))),
            json!(1)
        );

        // test span count with status filter
        let ok_status = SpanFilter::WithStatus {
            status: SpanStatus::Ok,
        };
        assert_eq!(span_count(&with_attrs, ok_status), json!(2));

        // test duration filter
        let in_range = SpanFilter::WithDuration {
            min_ms: Some(80.0),
            max_ms: Some(120.0),
        };
        assert_eq!(span_count(&with_attrs, in_range), json!(1));

        // test complex AND filter
        let both = SpanFilter::And {
            filters: vec![
                has_attribute("http.method"),
                SpanFilter::WithStatus {
                    status: SpanStatus::Ok,
                },
            ],
        };
        assert_eq!(span_count(&with_attrs, both), json!(1));

        // test complex OR filter
        let either = SpanFilter::Or {
            filters: vec![
                attribute_equals("http.method", json!("GET")),
                attribute_equals("model", json!("gpt-4")),
            ],
        };
        assert_eq!(span_count(&with_attrs, either), json!(1));
    }

    #[test]
    fn test_span_exists() {
        let builder = builder_for(create_simple_trace());
        let assertion = TraceAssertion::SpanExists {
            filter: by_name("child_1"),
        };

        assert_eq!(builder.build_context(&assertion).unwrap(), json!(true));
    }

    #[test]
    fn test_span_attribute() {
        // Reads one attribute of the "api_call" span from a fresh mock trace.
        let api_call_attribute = |key: &str| {
            let builder = builder_for(create_trace_with_attributes());
            let assertion = TraceAssertion::SpanAttribute {
                filter: by_name("api_call"),
                attribute_key: key.to_string(),
            };
            builder.build_context(&assertion).unwrap()
        };

        // test model
        assert_eq!(api_call_attribute("model"), json!("gpt-4"));

        // check response
        assert_eq!(
            api_call_attribute("response"),
            json!({"success": true, "data": {"id": 12345}})
        );
    }

    #[test]
    fn test_span_attribute_aggregation() {
        let builder = builder_for(create_trace_with_attributes());
        let assertion = TraceAssertion::SpanAggregation {
            filter: by_name("api_call"),
            attribute_key: "tokens.output".to_string(),
            aggregation: AggregationType::Sum,
        };

        assert_eq!(builder.build_context(&assertion).unwrap(), json!(300.0));
    }

    /// check common sequence patterns
    #[test]
    fn test_sequence_pattern_counting() {
        // count how often "call_tool" followed by "run_agent" occurs
        let builder = builder_for(create_sequence_pattern_trace());
        let sequence = SpanFilter::Sequence {
            names: owned(&["call_tool", "run_agent"]),
        };
        assert_eq!(span_count(&builder, sequence), json!(2));

        // count how often "call_tool" occurs
        let builder = builder_for(create_sequence_pattern_trace());
        assert_eq!(span_count(&builder, by_name("call_tool")), json!(2));
    }
}