tower_llm/sessions/
mod.rs

1//! Sessions and persistence (memory/replay) for the `next` stack
2//!
3//! What this module provides (spec)
4//! - A clear, Tower-native way to persist and replay conversational state
5//! - No dynamic lookups; all dependencies are constructor-injected
6//! - Interoperates with the codec (RunItem ↔ raw messages) and recording
7//!
8//! Exports (public API surface)
9//! - Models
10//!   - `SessionId` (newtype)
11//!   - `LoadSession { id: SessionId }`, `SaveSession { id: SessionId, history: History }`
//!   - `History` = `Vec<RawChatMessage>` (or a small newtype wrapper)
13//! - Services
14//!   - `SessionLoadStore: Service<LoadSession, Response=History, Error=BoxError>`
15//!   - `SessionSaveStore: Service<SaveSession, Response=(), Error=BoxError>`
16//!     - Impl examples: `SqliteSessionStore`, `InMemorySessionStore`
17//! - Layers
18//!   - `MemoryLayer<S>` where `S: Service<RawChatRequest, Response=StepOutcome>`
19//!     - On call: loads `History`, merges into request messages, forwards, then appends new messages and saves
20//!   - `RecorderLayer<S>` (see recording module)
21//!   - `ReplayLayer<S>` (short-circuits with canned outcomes)
22//! - Utils (sugar)
23//!   - AgentBuilder: `.session(load_store, save_store, session_id)`
24//!   - Helpers: `merge_history(history, request_messages)`
25//!
26//! Implementation strategy
27//! - Session stores are plain services with typed requests (no global registries)
28//! - `MemoryLayer` holds `Arc<SessionLoadStore>`, `Arc<SessionSaveStore>`, and `SessionId`
29//! - On each call:
30//!   1) `load_store.call(LoadSession { id })` → `History`
31//!   2) Compose `RawChatRequest` by prefixing `History` to current messages
32//!   3) Forward to inner step/agent
33//!   4) Extract newly produced messages from `StepOutcome`/`AgentRun` and append to `History`
34//!   5) `save_store.call(SaveSession { id, history })`
35//! - Errors bubble up; store errors are surfaced explicitly
36//!
37//! Composition examples
38//! - `ServiceBuilder::new().layer(MemoryLayer::new(load, save, session_id)).service(step)`
39//! - Combine with `RecorderLayer` if you want both persistence and replay traces
40//!
41//! Testing strategy
42//! - Unit tests
43//!   - Fake stores using `tower::service_fn` to simulate load/save
44//!   - Assert correct merge order (history first) and that saves receive appended messages
45//!   - Error propagation when load/save fails
46//! - Integration tests
47//!   - With a fake model provider and a real `InMemorySessionStore`, verify multi-turn accumulation
48//!   - With `ReplayLayer`, verify deterministic reproduction of a captured trace
49//!
50//! Notes and constraints
51//! - Keep the session I/O isolated behind services; do not push DB/file logic into layers
52//! - Prefer separate load/save services to keep signatures simple and testable
53//! - The replay logic defers to the `recording` and `codec` modules for trace fidelity
54
55use std::collections::HashMap;
56use std::future::Future;
57use std::pin::Pin;
58use std::sync::{Arc, Mutex};
59
60use async_openai::types::{
61    ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
62};
63use tower::{BoxError, Layer, Service};
64
/// Opaque key naming one persisted conversation.
///
/// Equality and hashing are derived from the wrapped string, which is what
/// lets it key the in-memory store's map.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionId(pub String);
68
69/// History of chat messages for a session.
70pub type History = Vec<ChatCompletionRequestMessage>;
71/// Trait for extracting conversation messages from various response types.
72pub trait ConversationMessages {
73    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage>;
74}
75
76impl ConversationMessages for CreateChatCompletionRequest {
77    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage> {
78        self.messages.clone()
79    }
80}
81
82impl ConversationMessages for crate::core::AgentRun {
83    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage> {
84        self.messages.clone()
85    }
86}
87
88impl ConversationMessages for crate::core::StepOutcome {
89    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage> {
90        match self {
91            crate::core::StepOutcome::Next { messages, .. } => messages.clone(),
92            crate::core::StepOutcome::Done { messages, .. } => messages.clone(),
93        }
94    }
95}
96
97/// Load request for a session.
98#[derive(Debug, Clone)]
99pub struct LoadSession {
100    pub id: SessionId,
101}
102
103/// Save request for a session.
104#[derive(Debug, Clone)]
105pub struct SaveSession {
106    pub id: SessionId,
107    pub history: History,
108}
109
110/// A simple in-memory session store implementing both load and save services.
111#[derive(Default, Clone)]
112pub struct InMemorySessionStore {
113    inner: Arc<Mutex<HashMap<SessionId, History>>>,
114}
115
116impl Service<LoadSession> for InMemorySessionStore {
117    type Response = History;
118    type Error = BoxError;
119    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
120
121    fn poll_ready(
122        &mut self,
123        _cx: &mut std::task::Context<'_>,
124    ) -> std::task::Poll<Result<(), Self::Error>> {
125        std::task::Poll::Ready(Ok(()))
126    }
127
128    fn call(&mut self, req: LoadSession) -> Self::Future {
129        let inner = self.inner.clone();
130        Box::pin(async move {
131            let map = inner.lock().unwrap();
132            Ok(map.get(&req.id).cloned().unwrap_or_default())
133        })
134    }
135}
136
137impl Service<SaveSession> for InMemorySessionStore {
138    type Response = ();
139    type Error = BoxError;
140    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
141
142    fn poll_ready(
143        &mut self,
144        _cx: &mut std::task::Context<'_>,
145    ) -> std::task::Poll<Result<(), Self::Error>> {
146        std::task::Poll::Ready(Ok(()))
147    }
148
149    fn call(&mut self, req: SaveSession) -> Self::Future {
150        let inner = self.inner.clone();
151        Box::pin(async move {
152            let mut map = inner.lock().unwrap();
153            map.insert(req.id, req.history);
154            Ok(())
155        })
156    }
157}
158
159/// Layer configuration for memory persistence.
160#[derive(Clone)]
161pub struct MemoryLayer<L, S> {
162    load: Arc<L>,
163    save: Arc<S>,
164    session_id: SessionId,
165}
166
167impl<L, S> MemoryLayer<L, S> {
168    pub fn new(load: Arc<L>, save: Arc<S>, session_id: SessionId) -> Self {
169        Self {
170            load,
171            save,
172            session_id,
173        }
174    }
175}
176
177/// Wrapped service that loads history before the call and saves after.
178pub struct Memory<S, L, Sv> {
179    inner: Arc<tokio::sync::Mutex<S>>,
180    load: L,
181    save: Sv,
182    session_id: SessionId,
183}
184
185impl<S, L, Sv> Layer<S> for MemoryLayer<L, Sv>
186where
187    L: Service<LoadSession, Response = History, Error = BoxError> + Send + Clone + 'static,
188    L::Future: Send + 'static,
189    Sv: Service<SaveSession, Response = (), Error = BoxError> + Send + Clone + 'static,
190    Sv::Future: Send + 'static,
191{
192    type Service = Memory<S, L, Sv>;
193
194    fn layer(&self, inner: S) -> Self::Service {
195        Memory {
196            inner: Arc::new(tokio::sync::Mutex::new(inner)),
197            load: (*self.load).clone(),
198            save: (*self.save).clone(),
199            session_id: self.session_id.clone(),
200        }
201    }
202}
203
204impl<S, Ls, Ss, R> Service<CreateChatCompletionRequest> for Memory<S, Ls, Ss>
205where
206    S: Service<CreateChatCompletionRequest, Response = R> + Send + 'static,
207    R: ConversationMessages + Send + 'static,
208    S::Error: Into<BoxError>,
209    S::Future: Send + 'static,
210    Ls: Service<LoadSession, Response = History, Error = BoxError> + Send + Sync + Clone + 'static,
211    Ls::Future: Send + 'static,
212    Ss: Service<SaveSession, Response = (), Error = BoxError> + Send + Sync + Clone + 'static,
213    Ss::Future: Send + 'static,
214{
215    type Response = R;
216    type Error = BoxError;
217    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
218
219    fn poll_ready(
220        &mut self,
221        cx: &mut std::task::Context<'_>,
222    ) -> std::task::Poll<Result<(), Self::Error>> {
223        let _ = cx;
224        std::task::Poll::Ready(Ok(()))
225    }
226
227    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
228        let session_id = self.session_id.clone();
229        let load = self.load.clone();
230        let save = self.save.clone();
231        let inner = self.inner.clone();
232
233        Box::pin(async move {
234            // Load history
235            let mut load_svc = load;
236            let history = Service::call(
237                &mut load_svc,
238                LoadSession {
239                    id: session_id.clone(),
240                },
241            )
242            .await?;
243
244            // Build combined request preserving model/tools/knobs
245            let mut builder = CreateChatCompletionRequestArgs::default();
246            builder.model(&req.model);
247            // Combine history and current messages
248            let mut combined = history;
249            combined.extend(req.messages.clone());
250            builder.messages(combined.clone());
251            if let Some(t) = req.temperature {
252                builder.temperature(t);
253            }
254            if let Some(ts) = req.tools.clone() {
255                builder.tools(ts);
256            }
257            let combined_req = builder
258                .build()
259                .map_err(|e| -> BoxError { e.to_string().into() })?;
260
261            // Call inner
262            let resp: R = {
263                let mut guard = inner.lock().await;
264                Service::call(&mut *guard, combined_req)
265                    .await
266                    .map_err(Into::<BoxError>::into)?
267            };
268
269            // Persist the latest messages; for simplicity, overwrite full history
270            let latest_messages = resp.to_messages();
271            let mut save_svc = save;
272            Service::call(
273                &mut save_svc,
274                SaveSession {
275                    id: session_id,
276                    history: latest_messages,
277                },
278            )
279            .await?;
280
281            Ok(resp)
282        })
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validation::{validate_conversation, ValidationPolicy};
    use async_openai::types::{
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
        ChatCompletionRequestUserMessageArgs,
    };
    use tower::service_fn;

    /// Builds a `gpt-4o` request carrying exactly `messages`.
    fn req_with_messages(
        messages: Vec<ChatCompletionRequestMessage>,
    ) -> CreateChatCompletionRequest {
        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model("gpt-4o");
        builder.messages(messages);
        builder.build().unwrap()
    }

    /// Builds a system message with the given text.
    fn system(text: &'static str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestSystemMessageArgs::default()
            .content(text)
            .build()
            .unwrap()
            .into()
    }

    /// Builds a user message with the given text.
    fn user(text: &'static str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestUserMessageArgs::default()
            .content(text)
            .build()
            .unwrap()
            .into()
    }

    /// Builds an assistant message with the given text.
    fn assistant(text: &'static str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestAssistantMessageArgs::default()
            .content(text)
            .build()
            .unwrap()
            .into()
    }

    /// Prior conversation used to seed sessions: a system prompt plus one user turn.
    fn prior_history() -> History {
        vec![system("sys"), user("prev")]
    }

    /// Writes `history` straight into `store` under `id`.
    async fn seed(store: &InMemorySessionStore, id: &SessionId, history: History) {
        let mut handle = store.clone();
        tower::Service::call(
            &mut handle,
            SaveSession {
                id: id.clone(),
                history,
            },
        )
        .await
        .unwrap();
    }

    /// Reads whatever `store` currently holds for `id`.
    async fn stored_history(store: &InMemorySessionStore, id: &SessionId) -> History {
        let mut handle = store.clone();
        tower::Service::call(&mut handle, LoadSession { id: id.clone() })
            .await
            .unwrap()
    }

    /// Builds a `MemoryLayer` that loads from and saves to the same in-memory store.
    fn memory_layer(
        store: &InMemorySessionStore,
        id: &SessionId,
    ) -> MemoryLayer<InMemorySessionStore, InMemorySessionStore> {
        MemoryLayer::new(Arc::new(store.clone()), Arc::new(store.clone()), id.clone())
    }

    /// Validation policy that tolerates consecutive messages with the same role.
    fn lenient_policy() -> ValidationPolicy {
        ValidationPolicy {
            allow_repeated_roles: true,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn memory_layer_loads_and_saves_history() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s1".into());
        seed(&store, &session_id, prior_history()).await;

        // Inner service records the merged request it receives and echoes it back.
        let captured: Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>> =
            Arc::new(tokio::sync::Mutex::new(None));
        let captured_clone = captured.clone();
        let inner = service_fn(move |req: CreateChatCompletionRequest| {
            let captured_inner = captured_clone.clone();
            async move {
                *captured_inner.lock().await = Some(req.clone());
                Ok::<_, BoxError>(req)
            }
        });

        let mut svc = memory_layer(&store, &session_id).layer(inner);
        let _resp = tower::Service::call(&mut svc, req_with_messages(vec![user("hello")]))
            .await
            .unwrap();

        // The merged (preflight) request must be a valid conversation.
        let merged = captured.lock().await.clone().expect("captured");
        assert!(validate_conversation(&merged.messages, &lenient_policy()).is_none());

        // The store now holds prior(2) + the new user message.
        let history = stored_history(&store, &session_id).await;
        assert_eq!(history.len(), 3);
        assert!(validate_conversation(&history, &lenient_policy()).is_none());
    }

    #[tokio::test]
    async fn memory_layer_persists_step_outcome_done_messages() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s2".into());
        seed(&store, &session_id, prior_history()).await;

        // Inner returns StepOutcome::Done: incoming messages plus one assistant reply.
        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            let mut messages = req.messages;
            messages.push(assistant("ok"));
            Ok::<_, BoxError>(crate::core::StepOutcome::Done {
                messages,
                aux: Default::default(),
            })
        });

        let mut svc = memory_layer(&store, &session_id).layer(inner);
        let _ = tower::Service::call(&mut svc, req_with_messages(vec![user("hello")]))
            .await
            .unwrap();

        let history = stored_history(&store, &session_id).await;
        assert_eq!(history.len(), 4); // prior(2) + new user + appended assistant
    }

    #[tokio::test]
    async fn memory_layer_persists_step_outcome_next_messages() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s3".into());

        // No prior history; inner returns StepOutcome::Next with one assistant reply.
        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            let mut messages = req.messages;
            messages.push(assistant("next"));
            Ok::<_, BoxError>(crate::core::StepOutcome::Next {
                messages,
                aux: Default::default(),
                invoked_tools: vec![],
            })
        });

        let mut svc = memory_layer(&store, &session_id).layer(inner);
        let _ = tower::Service::call(&mut svc, req_with_messages(vec![user("hey")]))
            .await
            .unwrap();

        let history = stored_history(&store, &session_id).await;
        assert_eq!(history.len(), 2); // user + assistant("next")
    }

    #[tokio::test]
    async fn memory_layer_persists_agent_run_messages() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s4".into());

        // Inner returns an AgentRun echoing the request messages plus one assistant.
        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            let mut messages = req.messages;
            messages.push(assistant("from agent"));
            Ok::<_, BoxError>(crate::core::AgentRun {
                messages,
                steps: 1,
                stop: crate::core::AgentStopReason::DoneNoToolCalls,
            })
        });

        let mut svc = memory_layer(&store, &session_id).layer(inner);
        let _ = tower::Service::call(&mut svc, req_with_messages(vec![user("hey")]))
            .await
            .unwrap();

        let history = stored_history(&store, &session_id).await;
        assert_eq!(history.len(), 2); // user + assistant("from agent")
    }

    #[tokio::test]
    async fn memory_layer_surfaces_load_errors() {
        let store = InMemorySessionStore::default();
        let failing_load = service_fn(|_req: LoadSession| async {
            Err::<History, BoxError>("load failed".into())
        });
        let layer = MemoryLayer::new(
            Arc::new(failing_load),
            Arc::new(store.clone()),
            SessionId("s5".into()),
        );
        let inner =
            service_fn(|req: CreateChatCompletionRequest| async move { Ok::<_, BoxError>(req) });
        let mut svc = layer.layer(inner);

        let err = tower::Service::call(&mut svc, req_with_messages(vec![user("hi")]))
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "load failed");
    }

    #[tokio::test]
    async fn memory_layer_surfaces_save_errors() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s6".into());
        let failing_save = service_fn(|_req: SaveSession| async {
            Err::<(), BoxError>("save failed".into())
        });
        let layer = MemoryLayer::new(
            Arc::new(store.clone()),
            Arc::new(failing_save),
            session_id.clone(),
        );
        let inner =
            service_fn(|req: CreateChatCompletionRequest| async move { Ok::<_, BoxError>(req) });
        let mut svc = layer.layer(inner);

        let err = tower::Service::call(&mut svc, req_with_messages(vec![user("hi")]))
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "save failed");
        // Nothing reached the real store because the save service failed.
        assert!(stored_history(&store, &session_id).await.is_empty());
    }
}