1use std::collections::HashMap;
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34
35use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionRequestArgs};
36use tokio::sync::Mutex;
37use tower::{BoxError, Layer, Service, ServiceExt};
38
39use crate::codec::{items_to_messages, messages_to_items};
40use crate::core::StepOutcome;
41use crate::items::RunItem;
42
43#[derive(Debug, Clone)]
44pub struct WriteTrace {
45 pub id: String,
46 pub items: Vec<RunItem>,
47}
48#[derive(Debug, Clone)]
49pub struct ReadTrace {
50 pub id: String,
51}
52#[derive(Debug, Clone, Default)]
53pub struct Trace {
54 pub items: Vec<RunItem>,
55}
56
57#[derive(Default, Clone)]
59pub struct InMemoryTraceStore(Arc<Mutex<HashMap<String, Trace>>>);
60
61impl Service<WriteTrace> for InMemoryTraceStore {
62 type Response = ();
63 type Error = BoxError;
64 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
65 fn poll_ready(
66 &mut self,
67 _cx: &mut std::task::Context<'_>,
68 ) -> std::task::Poll<Result<(), Self::Error>> {
69 std::task::Poll::Ready(Ok(()))
70 }
71 fn call(&mut self, req: WriteTrace) -> Self::Future {
72 let store = self.0.clone();
73 Box::pin(async move {
74 store
75 .lock()
76 .await
77 .insert(req.id, Trace { items: req.items });
78 Ok(())
79 })
80 }
81}
82
83impl Service<ReadTrace> for InMemoryTraceStore {
84 type Response = Trace;
85 type Error = BoxError;
86 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
87 fn poll_ready(
88 &mut self,
89 _cx: &mut std::task::Context<'_>,
90 ) -> std::task::Poll<Result<(), Self::Error>> {
91 std::task::Poll::Ready(Ok(()))
92 }
93 fn call(&mut self, req: ReadTrace) -> Self::Future {
94 let store = self.0.clone();
95 Box::pin(async move {
96 let trace = store.lock().await.get(&req.id).cloned().unwrap_or_default();
97 Ok(trace)
98 })
99 }
100}
101
/// Tower layer that wraps a service in a [`Recorder`].
///
/// Each service produced by the layer gets its own clone of `writer` and a
/// copy of `trace_id`. The layer is `Clone` (when `W` is) so it can be reused
/// wherever layers are duplicated.
#[derive(Clone, Debug)]
pub struct RecorderLayer<W> {
    writer: W,
    trace_id: String,
}

impl<W> RecorderLayer<W> {
    /// Creates a layer whose recorders write traces through `writer` under
    /// the key `trace_id`.
    pub fn new(writer: W, trace_id: impl Into<String>) -> Self {
        Self {
            writer,
            trace_id: trace_id.into(),
        }
    }
}
115
/// Middleware produced by `RecorderLayer`.
///
/// It forwards each request to `inner` and writes the messages of the
/// resulting step, encoded as run items, through `writer` under `trace_id`.
/// The struct is `Clone` (when `S` and `W` are) so it can be used with
/// combinators that clone services.
#[derive(Clone, Debug)]
pub struct Recorder<S, W> {
    inner: S,
    writer: W,
    trace_id: String,
}
121
122impl<S, W> Layer<S> for RecorderLayer<W>
123where
124 W: Clone,
125{
126 type Service = Recorder<S, W>;
127 fn layer(&self, inner: S) -> Self::Service {
128 Recorder {
129 inner,
130 writer: self.writer.clone(),
131 trace_id: self.trace_id.clone(),
132 }
133 }
134}
135
136impl<S, W> Service<CreateChatCompletionRequest> for Recorder<S, W>
137where
138 S: Service<CreateChatCompletionRequest, Response = StepOutcome, Error = BoxError>
139 + Send
140 + 'static,
141 S::Future: Send + 'static,
142 W: Service<WriteTrace, Response = (), Error = BoxError> + Clone + Send + 'static,
143 W::Future: Send + 'static,
144{
145 type Response = StepOutcome;
146 type Error = BoxError;
147 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
148
149 fn poll_ready(
150 &mut self,
151 cx: &mut std::task::Context<'_>,
152 ) -> std::task::Poll<Result<(), Self::Error>> {
153 self.inner.poll_ready(cx)
154 }
155
156 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
157 let mut writer = self.writer.clone();
158 let trace_id = self.trace_id.clone();
159 let fut = self.inner.call(req);
160 Box::pin(async move {
161 let out = fut.await?;
162 let messages = match &out {
163 StepOutcome::Next { messages, .. } | StepOutcome::Done { messages, .. } => {
164 messages.clone()
165 }
166 };
167 let items = messages_to_items(&messages).map_err(|e| format!("codec: {}", e))?;
168 ServiceExt::ready(&mut writer)
169 .await?
170 .call(WriteTrace {
171 id: trace_id,
172 items,
173 })
174 .await?;
175 Ok(out)
176 })
177 }
178}
179
/// Service that answers requests by replaying a previously stored trace.
///
/// The incoming request is ignored. The trace stored under `trace_id` is read
/// through `reader` and returned as a finished step. `model` is the model
/// name used when re-assembling the replayed messages into a request.
#[derive(Clone, Debug)]
pub struct ReplayService<R> {
    reader: R,
    trace_id: String,
    model: String,
}

impl<R> ReplayService<R> {
    /// Creates a replay service that reads `trace_id` through `reader`.
    pub fn new(reader: R, trace_id: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            reader,
            trace_id: trace_id.into(),
            model: model.into(),
        }
    }
}
195
196impl<R> Service<CreateChatCompletionRequest> for ReplayService<R>
197where
198 R: Service<ReadTrace, Response = Trace, Error = BoxError> + Send + Clone + 'static,
199 R::Future: Send + 'static,
200{
201 type Response = StepOutcome;
202 type Error = BoxError;
203 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
204
205 fn poll_ready(
206 &mut self,
207 _cx: &mut std::task::Context<'_>,
208 ) -> std::task::Poll<Result<(), Self::Error>> {
209 std::task::Poll::Ready(Ok(()))
210 }
211
212 fn call(&mut self, _req: CreateChatCompletionRequest) -> Self::Future {
213 let mut reader = self.reader.clone();
214 let trace_id = self.trace_id.clone();
215 let model = self.model.clone();
216 Box::pin(async move {
217 let trace = Service::call(&mut reader, ReadTrace { id: trace_id }).await?;
218 let messages = items_to_messages(&trace.items);
219 let _req = CreateChatCompletionRequestArgs::default()
220 .model(model)
221 .messages(messages.clone())
222 .build()?;
223 Ok(StepOutcome::Done {
225 messages,
226 aux: Default::default(),
227 })
228 })
229 }
230}
231
// Tests for the trace store, the recording middleware, and replay.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validation::{gen, validate_conversation, ValidationPolicy};
    use async_openai::types::ChatCompletionRequestUserMessageArgs;
    use proptest::prop_assert;
    use tower::service_fn;

    // Builds a `gpt-4o` request containing a single user message with content `s`.
    fn req_with_user(s: &str) -> CreateChatCompletionRequest {
        let msg = ChatCompletionRequestUserMessageArgs::default()
            .content(s)
            .build()
            .unwrap();
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![msg.into()])
            .build()
            .unwrap()
    }

    // A Recorder around an echoing service must leave a non-empty trace under its id.
    #[tokio::test]
    async fn records_trace_on_step_done() {
        let writer = InMemoryTraceStore::default();
        // Inner service echoes the request messages back as a finished step.
        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            Ok::<_, BoxError>(StepOutcome::Done {
                messages: req.messages,
                aux: Default::default(),
            })
        });
        let mut svc = RecorderLayer::new(writer.clone(), "t1").layer(inner);
        let _ = ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(req_with_user("hi"))
            .await
            .unwrap();
        // Store clones share the same map, so the recorded trace is visible here.
        let trace = tower::Service::call(&mut writer.clone(), ReadTrace { id: "t1".into() })
            .await
            .unwrap();
        assert!(!trace.items.is_empty());
    }

    // Replaying a stored trace yields a Done step with non-empty messages,
    // regardless of the request sent to the replay service.
    #[tokio::test]
    async fn replay_restores_messages() {
        let store = InMemoryTraceStore::default();
        let msgs = req_with_user("hi").messages;
        let items = messages_to_items(&msgs).unwrap();
        tower::Service::call(
            &mut store.clone(),
            WriteTrace {
                id: "t2".into(),
                items,
            },
        )
        .await
        .unwrap();
        let mut replay = ReplayService::new(store.clone(), "t2", "gpt-4o");
        let out = ServiceExt::ready(&mut replay)
            .await
            .unwrap()
            // The request content is deliberately irrelevant to replay.
            .call(req_with_user("ignored"))
            .await
            .unwrap();
        match out {
            StepOutcome::Done { messages, .. } => assert!(!messages.is_empty()),
            _ => panic!("expected done"),
        }
    }

    // Recording an assistant tool call plus its tool result must keep both the
    // ToolCall item and the ToolOutput item (with id and JSON-parsed output).
    #[tokio::test]
    async fn recording_preserves_tool_output_and_calls() {
        use async_openai::types::{
            ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
            ChatCompletionRequestMessage, ChatCompletionRequestToolMessageArgs,
            ChatCompletionToolType, FunctionCall,
        };

        let tc = ChatCompletionMessageToolCall {
            id: "call_1".to_string(),
            r#type: ChatCompletionToolType::Function,
            function: FunctionCall {
                name: "calc".to_string(),
                arguments: "{\"a\":1}".to_string(),
            },
        };
        let asst = ChatCompletionRequestAssistantMessageArgs::default()
            .content("")
            .tool_calls(vec![tc])
            .build()
            .unwrap();
        let tool = ChatCompletionRequestToolMessageArgs::default()
            .content("{\"sum\":2}")
            .tool_call_id("call_1")
            .build()
            .unwrap();

        let out_messages = vec![
            ChatCompletionRequestMessage::Assistant(asst),
            ChatCompletionRequestMessage::Tool(tool),
        ];

        // Inner service ignores its input and always returns the fixed tool exchange.
        let inner = service_fn(move |_req: CreateChatCompletionRequest| {
            let msgs = out_messages.clone();
            async move {
                Ok::<_, BoxError>(StepOutcome::Done {
                    messages: msgs,
                    aux: Default::default(),
                })
            }
        });

        let writer = InMemoryTraceStore::default();
        let mut svc = RecorderLayer::new(writer.clone(), "t3").layer(inner);
        let req = req_with_user("start");
        let _ = ServiceExt::ready(&mut svc)
            .await
            .unwrap()
            .call(req)
            .await
            .unwrap();

        let trace = tower::Service::call(&mut writer.clone(), ReadTrace { id: "t3".into() })
            .await
            .unwrap();
        assert!(trace
            .items
            .iter()
            .any(|it| matches!(it, RunItem::ToolCall(_))));
        let out = trace
            .items
            .iter()
            .find_map(|it| {
                if let RunItem::ToolOutput(o) = it {
                    Some(o)
                } else {
                    None
                }
            })
            .unwrap();
        assert_eq!(out.tool_call_id, "call_1");
        // The tool message's text content is compared as parsed JSON.
        assert_eq!(out.output, serde_json::json!({"sum":2}));
    }

    // Round-trip: messages -> items -> store -> replay must give back a Tool
    // message whose text content parses to the original JSON value.
    #[tokio::test]
    async fn replay_reconstructs_tool_messages_fidelity() {
        use async_openai::types::{
            ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
            ChatCompletionRequestMessage, ChatCompletionRequestToolMessageArgs,
            ChatCompletionToolType, FunctionCall,
        };

        let tc = ChatCompletionMessageToolCall {
            id: "id1".to_string(),
            r#type: ChatCompletionToolType::Function,
            function: FunctionCall {
                name: "echo".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let asst = ChatCompletionRequestAssistantMessageArgs::default()
            .content("")
            .tool_calls(vec![tc])
            .build()
            .unwrap();
        let tool = ChatCompletionRequestToolMessageArgs::default()
            .content("{\"ok\":true}")
            .tool_call_id("id1")
            .build()
            .unwrap();
        let msgs = vec![
            ChatCompletionRequestMessage::Assistant(asst),
            ChatCompletionRequestMessage::Tool(tool),
        ];
        let items = messages_to_items(&msgs).unwrap();

        let store = InMemoryTraceStore::default();
        tower::Service::call(
            &mut store.clone(),
            WriteTrace {
                id: "t4".into(),
                items,
            },
        )
        .await
        .unwrap();
        let mut replay = ReplayService::new(store, "t4", "gpt-4o");
        let out = ServiceExt::ready(&mut replay)
            .await
            .unwrap()
            .call(req_with_user("ignored"))
            .await
            .unwrap();
        match out {
            StepOutcome::Done { messages, .. } => {
                let tool_msg = messages
                    .iter()
                    .find(|m| matches!(m, ChatCompletionRequestMessage::Tool(_)))
                    .unwrap();
                if let ChatCompletionRequestMessage::Tool(t) = tool_msg {
                    // Compare as JSON values so key order and whitespace don't matter.
                    if let async_openai::types::ChatCompletionRequestToolMessageContent::Text(txt) =
                        &t.content
                    {
                        let val: serde_json::Value = serde_json::from_str(txt).unwrap();
                        assert_eq!(val, serde_json::json!({"ok": true}));
                    } else {
                        panic!("expected text content");
                    }
                }
            }
            _ => panic!("expected done"),
        }
    }

    proptest::proptest! {
        // Property: for any generated valid conversation, storing it and replaying
        // it yields a Done step whose messages pass `validate_conversation` with
        // the default policy. A plain runtime is built per case because proptest
        // bodies are synchronous.
        #[test]
        fn replay_service_returns_valid_messages_for_valid_trace(msgs in gen::valid_conversation(gen::GeneratorConfig::default())) {
            let items = crate::codec::messages_to_items(&msgs).unwrap();
            let store = InMemoryTraceStore::default();
            let mut writer = store.clone();
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                tower::Service::call(&mut writer, WriteTrace { id: "t-valid".into(), items }).await.unwrap();
            });
            let mut replay = ReplayService::new(store, "t-valid", "gpt-4o");
            let out = rt.block_on(async move {
                ServiceExt::ready(&mut replay)
                    .await
                    .unwrap()
                    .call(async_openai::types::CreateChatCompletionRequestArgs::default().model("gpt-4o").messages(vec![]).build().unwrap())
                    .await
                    .unwrap()
            });
            if let StepOutcome::Done { messages, .. } = out {
                prop_assert!(validate_conversation(&messages, &ValidationPolicy::default()).is_none());
            } else {
                prop_assert!(false, "expected Done");
            }
        }
    }
}