Skip to main content

trace_weft/
eval.rs

1use std::sync::{Arc, Mutex};
2use trace_weft_core::{EventRecord, SpanRecord, TraceWeftSpanKind};
3use trace_weft_recorder::TraceStore;
4
5/// In-memory trace store designed specifically for capturing spans during local unit tests
6/// and evaluation pipelines.
7#[derive(Clone, Default)]
8pub struct MemoryStore {
9    pub spans: Arc<Mutex<Vec<SpanRecord>>>,
10    pub events: Arc<Mutex<Vec<EventRecord>>>,
11}
12
13impl MemoryStore {
14    pub fn new() -> Self {
15        Self::default()
16    }
17
18    pub fn get_trajectory(&self) -> TraceTrajectory {
19        let spans = self.spans.lock().unwrap().clone();
20        TraceTrajectory { spans }
21    }
22
23    pub fn clear(&self) {
24        self.spans.lock().unwrap().clear();
25        self.events.lock().unwrap().clear();
26    }
27}
28
29#[async_trait::async_trait]
30impl TraceStore for MemoryStore {
31    async fn record_span(&self, span: SpanRecord) -> anyhow::Result<()> {
32        self.spans.lock().unwrap().push(span);
33        Ok(())
34    }
35
36    async fn record_event(&self, event: EventRecord) -> anyhow::Result<()> {
37        self.events.lock().unwrap().push(event);
38        Ok(())
39    }
40}
41
42/// A wrapper around a collection of spans to facilitate easy trajectory assertions.
43pub struct TraceTrajectory {
44    pub spans: Vec<SpanRecord>,
45}
46
47impl TraceTrajectory {
48    /// Checks if a specific tool was called during the trace.
49    pub fn contains_tool_call(&self, tool_name: &str) -> bool {
50        self.spans
51            .iter()
52            .any(|s| s.span_kind == TraceWeftSpanKind::Tool && s.name == tool_name)
53    }
54
55    /// Checks if an error span was recorded.
56    pub fn has_errors(&self) -> bool {
57        self.spans.iter().any(|s| {
58            s.status == trace_weft_core::SpanStatus::Error
59                || s.span_kind == TraceWeftSpanKind::Error
60        })
61    }
62
63    /// Calculates the total cost estimate of all spans in the trajectory.
64    pub fn total_cost(&self) -> f64 {
65        self.spans
66            .iter()
67            .filter_map(|s| s.cost_estimate.as_ref())
68            .map(|c| c.amount)
69            .sum()
70    }
71
72    /// Returns the latency of the root workflow/agent span.
73    pub fn total_latency_ms(&self) -> u64 {
74        self.spans
75            .iter()
76            .filter(|s| s.parent_span_id.is_none())
77            .map(|s| s.latency_ms.unwrap_or(0))
78            .sum()
79    }
80
81    /// Returns the total number of input tokens consumed across all LLM calls.
82    pub fn total_input_tokens(&self) -> u64 {
83        self.spans
84            .iter()
85            .filter_map(|s| s.token_usage.as_ref())
86            .map(|u| u.input)
87            .sum()
88    }
89}