tower_llm/recording/
mod.rs

//! Recording and replay I/O surfaces
//!
//! What this module provides (spec)
//! - Pluggable I/O surfaces for capturing and replaying runs
//! - Integrates with `codec` for lossless message/event fidelity
//!
//! Exports
//! - Services
//!   - `TraceWriter: Service<WriteTrace { id, items }, Response=()>`
//!   - `TraceReader: Service<ReadTrace { id }, Response=Trace>`
//!   - `ReplayService: Service<RawChatRequest, Response=StepOutcome>` (reads from a `Trace`;
//!     the request type is currently `CreateChatCompletionRequest`)
//! - Layers
//!   - `RecorderLayer<W>` wraps a step service and taps its `StepOutcome` (spec: also `AgentRun`),
//!     writing via the `TraceWriter` `W`
//! - Utils
//!   - Trace format (ndjson), `TraceVersion`, integrity checks (hashes) — planned; not yet
//!     implemented in this module
//!
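//! Implementation strategy
//! - `RecorderLayer` converts recorded messages with `codec::messages_to_items`; `ReplayService`
//!   converts them back with `codec::items_to_messages`
//! - `ReplayService` reads precomputed outcomes and serves them in sequence (for step) or as a
//!   final run (for agent); currently it always returns the whole stored trace as a single
//!   `StepOutcome::Done`
//! - Writers/readers are constructor-injected services supporting file/db backends
//!   (`InMemoryTraceStore` is the only backend provided here)
//!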
//! Composition
//! - `ServiceBuilder::new().layer(RecorderLayer::new(writer, trace_id)).service(step)`
//! - Or: `let agent = ReplayService::new(reader, trace_id, model);`
//!
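//! Testing strategy
//! - Roundtrip with fake provider: live run → record → replay; assert the same final
//!   messages/events
//! - Corruption tests: an invalid trace produces an explicit error (note: `InMemoryTraceStore`
//!   currently returns an empty `Trace` for an unknown id rather than an error)
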
30use std::collections::HashMap;
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34
35use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionRequestArgs};
36use tokio::sync::Mutex;
37use tower::{BoxError, Layer, Service, ServiceExt};
38
39use crate::codec::{items_to_messages, messages_to_items};
40use crate::core::StepOutcome;
41use crate::items::RunItem;
42
43#[derive(Debug, Clone)]
44pub struct WriteTrace {
45    pub id: String,
46    pub items: Vec<RunItem>,
47}
48#[derive(Debug, Clone)]
49pub struct ReadTrace {
50    pub id: String,
51}
52#[derive(Debug, Clone, Default)]
53pub struct Trace {
54    pub items: Vec<RunItem>,
55}
56
57/// Simple in-memory trace store for tests.
58#[derive(Default, Clone)]
59pub struct InMemoryTraceStore(Arc<Mutex<HashMap<String, Trace>>>);
60
61impl Service<WriteTrace> for InMemoryTraceStore {
62    type Response = ();
63    type Error = BoxError;
64    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
65    fn poll_ready(
66        &mut self,
67        _cx: &mut std::task::Context<'_>,
68    ) -> std::task::Poll<Result<(), Self::Error>> {
69        std::task::Poll::Ready(Ok(()))
70    }
71    fn call(&mut self, req: WriteTrace) -> Self::Future {
72        let store = self.0.clone();
73        Box::pin(async move {
74            store
75                .lock()
76                .await
77                .insert(req.id, Trace { items: req.items });
78            Ok(())
79        })
80    }
81}
82
83impl Service<ReadTrace> for InMemoryTraceStore {
84    type Response = Trace;
85    type Error = BoxError;
86    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
87    fn poll_ready(
88        &mut self,
89        _cx: &mut std::task::Context<'_>,
90    ) -> std::task::Poll<Result<(), Self::Error>> {
91        std::task::Poll::Ready(Ok(()))
92    }
93    fn call(&mut self, req: ReadTrace) -> Self::Future {
94        let store = self.0.clone();
95        Box::pin(async move {
96            let trace = store.lock().await.get(&req.id).cloned().unwrap_or_default();
97            Ok(trace)
98        })
99    }
100}
101
/// Recorder layer that captures step outcomes into a trace writer.
///
/// Wrapping an inner step service yields a [`Recorder`]; each wrapped service
/// gets its own clone of `writer` and writes under the same `trace_id`.
///
/// The layer is `Clone` (when `W` is) so that `ServiceBuilder` stacks containing
/// it can themselves be cloned, like tower's own layers.
#[derive(Clone, Debug)]
pub struct RecorderLayer<W> {
    writer: W,
    trace_id: String,
}

impl<W> RecorderLayer<W> {
    /// Creates a layer that records through `writer` under `trace_id`.
    pub fn new(writer: W, trace_id: impl Into<String>) -> Self {
        Self {
            writer,
            trace_id: trace_id.into(),
        }
    }
}
115
/// Service produced by [`RecorderLayer`].
///
/// Forwards each request to `inner`. For every successful `StepOutcome`, it
/// writes the outcome's messages to `writer` under `trace_id`. `Clone` is derived
/// (when `S` and `W` are `Clone`) so the wrapped service can be cloned like any
/// other tower service.
#[derive(Clone, Debug)]
pub struct Recorder<S, W> {
    inner: S,
    writer: W,
    trace_id: String,
}
121
122impl<S, W> Layer<S> for RecorderLayer<W>
123where
124    W: Clone,
125{
126    type Service = Recorder<S, W>;
127    fn layer(&self, inner: S) -> Self::Service {
128        Recorder {
129            inner,
130            writer: self.writer.clone(),
131            trace_id: self.trace_id.clone(),
132        }
133    }
134}
135
136impl<S, W> Service<CreateChatCompletionRequest> for Recorder<S, W>
137where
138    S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
139        + Send
140        + 'static,
141    S::Future: Send + 'static,
142    W: Service<WriteTrace, Response = (), Error = BoxError> + Clone + Send + 'static,
143    W::Future: Send + 'static,
144{
145    type Response = StepOutcome;
146    type Error = BoxError;
147    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
148
149    fn poll_ready(
150        &mut self,
151        cx: &mut std::task::Context<'_>,
152    ) -> std::task::Poll<Result<(), Self::Error>> {
153        self.inner.poll_ready(cx)
154    }
155
156    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
157        let mut writer = self.writer.clone();
158        let trace_id = self.trace_id.clone();
159        let fut = self.inner.call(req);
160        Box::pin(async move {
161            let out = fut.await?;
162            let messages = match &out {
163                StepOutcome::Next { messages, .. } | StepOutcome::Done { messages, .. } => {
164                    messages.clone()
165                }
166            };
167            let items = messages_to_items(&messages).map_err(|e| format!("codec: {}", e))?;
168            ServiceExt::ready(&mut writer)
169                .await?
170                .call(WriteTrace {
171                    id: trace_id,
172                    items,
173                })
174                .await?;
175            Ok(out)
176        })
177    }
178}
179
/// Service that replays a stored trace as a `StepOutcome::Done` using codec reconstruction.
///
/// Each call reads the trace `trace_id` from `reader` and rebuilds chat messages
/// with `items_to_messages`. It returns those messages; the incoming request is
/// ignored. `Clone` is derived (when `R` is `Clone`) so the service can be cloned
/// like other tower services.
#[derive(Clone, Debug)]
pub struct ReplayService<R> {
    reader: R,
    trace_id: String,
    model: String,
}

impl<R> ReplayService<R> {
    /// Creates a replay service for the trace stored under `trace_id`.
    ///
    /// `model` is the model name used when the replayed messages are re-assembled
    /// into a `CreateChatCompletionRequest`.
    pub fn new(reader: R, trace_id: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            reader,
            trace_id: trace_id.into(),
            model: model.into(),
        }
    }
}
195
196impl<R> Service<CreateChatCompletionRequest> for ReplayService<R>
197where
198    R: Service<ReadTrace, Response = Trace, Error = BoxError> + Send + Clone + 'static,
199    R::Future: Send + 'static,
200{
201    type Response = StepOutcome;
202    type Error = BoxError;
203    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
204
205    fn poll_ready(
206        &mut self,
207        _cx: &mut std::task::Context<'_>,
208    ) -> std::task::Poll<Result<(), Self::Error>> {
209        std::task::Poll::Ready(Ok(()))
210    }
211
212    fn call(&mut self, _req: CreateChatCompletionRequest) -> Self::Future {
213        let mut reader = self.reader.clone();
214        let trace_id = self.trace_id.clone();
215        let model = self.model.clone();
216        Box::pin(async move {
217            let trace = Service::call(&mut reader, ReadTrace { id: trace_id }).await?;
218            let messages = items_to_messages(&trace.items);
219            let _req = CreateChatCompletionRequestArgs::default()
220                .model(model)
221                .messages(messages.clone())
222                .build()?;
223            // Return Done with reconstructed messages
224            Ok(StepOutcome::Done {
225                messages,
226                aux: Default::default(),
227            })
228        })
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validation::{gen, validate_conversation, ValidationPolicy};
    use async_openai::types::{
        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
        ChatCompletionRequestMessage, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessageArgs,
        ChatCompletionToolType, FunctionCall,
    };
    use proptest::prop_assert;
    use tower::service_fn;

    /// A `gpt-4o` request carrying one user message with text `s`.
    fn req_with_user(s: &str) -> CreateChatCompletionRequest {
        let user = ChatCompletionRequestUserMessageArgs::default()
            .content(s)
            .build()
            .unwrap();
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![user.into()])
            .build()
            .unwrap()
    }

    /// An assistant message calling `name(args)` as `call_id`, followed by the
    /// tool message answering that call with `output`.
    fn tool_round_trip(
        call_id: &str,
        name: &str,
        args: &str,
        output: &str,
    ) -> Vec<ChatCompletionRequestMessage> {
        let call = ChatCompletionMessageToolCall {
            id: call_id.to_string(),
            r#type: ChatCompletionToolType::Function,
            function: FunctionCall {
                name: name.to_string(),
                arguments: args.to_string(),
            },
        };
        let assistant = ChatCompletionRequestAssistantMessageArgs::default()
            .content("")
            .tool_calls(vec![call])
            .build()
            .unwrap();
        let tool = ChatCompletionRequestToolMessageArgs::default()
            .content(output)
            .tool_call_id(call_id)
            .build()
            .unwrap();
        vec![
            ChatCompletionRequestMessage::Assistant(assistant),
            ChatCompletionRequestMessage::Tool(tool),
        ]
    }

    /// Stores `items` under `id` in (a clone of) `store`.
    async fn write_trace(store: &InMemoryTraceStore, id: &str, items: Vec<RunItem>) {
        ServiceExt::<WriteTrace>::oneshot(
            store.clone(),
            WriteTrace {
                id: id.into(),
                items,
            },
        )
        .await
        .unwrap();
    }

    /// Loads the trace stored under `id` from (a clone of) `store`.
    async fn read_trace(store: &InMemoryTraceStore, id: &str) -> Trace {
        ServiceExt::<ReadTrace>::oneshot(store.clone(), ReadTrace { id: id.into() })
            .await
            .unwrap()
    }

    /// Runs `inner` once behind a `RecorderLayer` that records into `store` under `id`.
    async fn run_recorded<S>(
        inner: S,
        store: &InMemoryTraceStore,
        id: &str,
        req: CreateChatCompletionRequest,
    ) -> StepOutcome
    where
        S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
            + Send
            + 'static,
        S::Future: Send + 'static,
    {
        let recorder = RecorderLayer::new(store.clone(), id).layer(inner);
        ServiceExt::<CreateChatCompletionRequest>::oneshot(recorder, req)
            .await
            .unwrap()
    }

    /// Replays trace `id` from `store` once, sending `req` (which replay ignores).
    async fn replay(
        store: InMemoryTraceStore,
        id: &str,
        req: CreateChatCompletionRequest,
    ) -> StepOutcome {
        let svc = ReplayService::new(store, id, "gpt-4o");
        ServiceExt::<CreateChatCompletionRequest>::oneshot(svc, req)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn records_trace_on_step_done() {
        let store = InMemoryTraceStore::default();
        // Echo step: finishes immediately with the request's own messages.
        let echo = service_fn(|req: CreateChatCompletionRequest| async move {
            Ok::<_, BoxError>(StepOutcome::Done {
                messages: req.messages,
                aux: Default::default(),
            })
        });
        run_recorded(echo, &store, "t1", req_with_user("hi")).await;
        assert!(!read_trace(&store, "t1").await.items.is_empty());
    }

    #[tokio::test]
    async fn replay_restores_messages() {
        let store = InMemoryTraceStore::default();
        let items = messages_to_items(&req_with_user("hi").messages).unwrap();
        write_trace(&store, "t2", items).await;

        match replay(store.clone(), "t2", req_with_user("ignored")).await {
            StepOutcome::Done { messages, .. } => assert!(!messages.is_empty()),
            _ => panic!("expected done"),
        }
    }

    #[tokio::test]
    async fn recording_preserves_tool_output_and_calls() {
        let conversation = tool_round_trip("call_1", "calc", "{\"a\":1}", "{\"sum\":2}");
        // Step that ignores its input and finishes with the fixed tool conversation.
        let fixed = service_fn(move |_req: CreateChatCompletionRequest| {
            let messages = conversation.clone();
            async move {
                Ok::<_, BoxError>(StepOutcome::Done {
                    messages,
                    aux: Default::default(),
                })
            }
        });
        let store = InMemoryTraceStore::default();
        run_recorded(fixed, &store, "t3", req_with_user("start")).await;

        // The recorded trace must keep the tool call and the tool output's JSON.
        let trace = read_trace(&store, "t3").await;
        assert!(trace
            .items
            .iter()
            .any(|item| matches!(item, RunItem::ToolCall(_))));
        let output = trace
            .items
            .iter()
            .find_map(|item| match item {
                RunItem::ToolOutput(o) => Some(o),
                _ => None,
            })
            .unwrap();
        assert_eq!(output.tool_call_id, "call_1");
        assert_eq!(output.output, serde_json::json!({"sum":2}));
    }

    #[tokio::test]
    async fn replay_reconstructs_tool_messages_fidelity() {
        let conversation = tool_round_trip("id1", "echo", "{}", "{\"ok\":true}");
        let store = InMemoryTraceStore::default();
        write_trace(&store, "t4", messages_to_items(&conversation).unwrap()).await;

        let messages = match replay(store, "t4", req_with_user("ignored")).await {
            StepOutcome::Done { messages, .. } => messages,
            _ => panic!("expected done"),
        };
        // The replayed tool message must still carry its JSON payload as text.
        let tool = messages
            .iter()
            .find_map(|m| match m {
                ChatCompletionRequestMessage::Tool(t) => Some(t),
                _ => None,
            })
            .unwrap();
        match &tool.content {
            ChatCompletionRequestToolMessageContent::Text(txt) => {
                let val: serde_json::Value = serde_json::from_str(txt).unwrap();
                assert_eq!(val, serde_json::json!({"ok": true}));
            }
            _ => panic!("expected text content"),
        }
    }

    proptest::proptest! {
        #[test]
        fn replay_service_returns_valid_messages_for_valid_trace(msgs in gen::valid_conversation(gen::GeneratorConfig::default())) {
            let items = crate::codec::messages_to_items(&msgs).unwrap();
            let store = InMemoryTraceStore::default();
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(write_trace(&store, "t-valid", items));
            let empty_request = CreateChatCompletionRequestArgs::default()
                .model("gpt-4o")
                .messages(vec![])
                .build()
                .unwrap();
            let out = rt.block_on(replay(store, "t-valid", empty_request));
            match out {
                StepOutcome::Done { messages, .. } => {
                    prop_assert!(validate_conversation(&messages, &ValidationPolicy::default()).is_none());
                }
                _ => {
                    prop_assert!(false, "expected Done");
                }
            }
        }
    }
}