1use std::collections::HashMap;
56use std::future::Future;
57use std::pin::Pin;
58use std::sync::{Arc, Mutex};
59
60use async_openai::types::{
61    ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
62};
63use tower::{BoxError, Layer, Service};
64
/// Key identifying one stored conversation; wraps an arbitrary string.
///
/// Hashable and comparable so it can be used directly as a map key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionId(pub String);
68
69pub type History = Vec<ChatCompletionRequestMessage>;
71pub trait ConversationMessages {
73    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage>;
74}
75
76impl ConversationMessages for CreateChatCompletionRequest {
77    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage> {
78        self.messages.clone()
79    }
80}
81
82impl ConversationMessages for crate::core::AgentRun {
83    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage> {
84        self.messages.clone()
85    }
86}
87
88impl ConversationMessages for crate::core::StepOutcome {
89    fn to_messages(&self) -> Vec<ChatCompletionRequestMessage> {
90        match self {
91            crate::core::StepOutcome::Next { messages, .. } => messages.clone(),
92            crate::core::StepOutcome::Done { messages, .. } => messages.clone(),
93        }
94    }
95}
96
97#[derive(Debug, Clone)]
99pub struct LoadSession {
100    pub id: SessionId,
101}
102
103#[derive(Debug, Clone)]
105pub struct SaveSession {
106    pub id: SessionId,
107    pub history: History,
108}
109
110#[derive(Default, Clone)]
112pub struct InMemorySessionStore {
113    inner: Arc<Mutex<HashMap<SessionId, History>>>,
114}
115
116impl Service<LoadSession> for InMemorySessionStore {
117    type Response = History;
118    type Error = BoxError;
119    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
120
121    fn poll_ready(
122        &mut self,
123        _cx: &mut std::task::Context<'_>,
124    ) -> std::task::Poll<Result<(), Self::Error>> {
125        std::task::Poll::Ready(Ok(()))
126    }
127
128    fn call(&mut self, req: LoadSession) -> Self::Future {
129        let inner = self.inner.clone();
130        Box::pin(async move {
131            let map = inner.lock().unwrap();
132            Ok(map.get(&req.id).cloned().unwrap_or_default())
133        })
134    }
135}
136
137impl Service<SaveSession> for InMemorySessionStore {
138    type Response = ();
139    type Error = BoxError;
140    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
141
142    fn poll_ready(
143        &mut self,
144        _cx: &mut std::task::Context<'_>,
145    ) -> std::task::Poll<Result<(), Self::Error>> {
146        std::task::Poll::Ready(Ok(()))
147    }
148
149    fn call(&mut self, req: SaveSession) -> Self::Future {
150        let inner = self.inner.clone();
151        Box::pin(async move {
152            let mut map = inner.lock().unwrap();
153            map.insert(req.id, req.history);
154            Ok(())
155        })
156    }
157}
158
159#[derive(Clone)]
161pub struct MemoryLayer<L, S> {
162    load: Arc<L>,
163    save: Arc<S>,
164    session_id: SessionId,
165}
166
167impl<L, S> MemoryLayer<L, S> {
168    pub fn new(load: Arc<L>, save: Arc<S>, session_id: SessionId) -> Self {
169        Self {
170            load,
171            save,
172            session_id,
173        }
174    }
175}
176
177pub struct Memory<S, L, Sv> {
179    inner: Arc<tokio::sync::Mutex<S>>,
180    load: L,
181    save: Sv,
182    session_id: SessionId,
183}
184
185impl<S, L, Sv> Layer<S> for MemoryLayer<L, Sv>
186where
187    L: Service<LoadSession, Response = History, Error = BoxError> + Send + Clone + 'static,
188    L::Future: Send + 'static,
189    Sv: Service<SaveSession, Response = (), Error = BoxError> + Send + Clone + 'static,
190    Sv::Future: Send + 'static,
191{
192    type Service = Memory<S, L, Sv>;
193
194    fn layer(&self, inner: S) -> Self::Service {
195        Memory {
196            inner: Arc::new(tokio::sync::Mutex::new(inner)),
197            load: (*self.load).clone(),
198            save: (*self.save).clone(),
199            session_id: self.session_id.clone(),
200        }
201    }
202}
203
204impl<S, Ls, Ss, R> Service<CreateChatCompletionRequest> for Memory<S, Ls, Ss>
205where
206    S: Service<CreateChatCompletionRequest, Response = R> + Send + 'static,
207    R: ConversationMessages + Send + 'static,
208    S::Error: Into<BoxError>,
209    S::Future: Send + 'static,
210    Ls: Service<LoadSession, Response = History, Error = BoxError> + Send + Sync + Clone + 'static,
211    Ls::Future: Send + 'static,
212    Ss: Service<SaveSession, Response = (), Error = BoxError> + Send + Sync + Clone + 'static,
213    Ss::Future: Send + 'static,
214{
215    type Response = R;
216    type Error = BoxError;
217    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
218
219    fn poll_ready(
220        &mut self,
221        cx: &mut std::task::Context<'_>,
222    ) -> std::task::Poll<Result<(), Self::Error>> {
223        let _ = cx;
224        std::task::Poll::Ready(Ok(()))
225    }
226
227    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
228        let session_id = self.session_id.clone();
229        let load = self.load.clone();
230        let save = self.save.clone();
231        let inner = self.inner.clone();
232
233        Box::pin(async move {
234            let mut load_svc = load;
236            let history = Service::call(
237                &mut load_svc,
238                LoadSession {
239                    id: session_id.clone(),
240                },
241            )
242            .await?;
243
244            let mut builder = CreateChatCompletionRequestArgs::default();
246            builder.model(&req.model);
247            let mut combined = history;
249            combined.extend(req.messages.clone());
250            builder.messages(combined.clone());
251            if let Some(t) = req.temperature {
252                builder.temperature(t);
253            }
254            if let Some(ts) = req.tools.clone() {
255                builder.tools(ts);
256            }
257            let combined_req = builder
258                .build()
259                .map_err(|e| -> BoxError { e.to_string().into() })?;
260
261            let resp: R = {
263                let mut guard = inner.lock().await;
264                Service::call(&mut *guard, combined_req)
265                    .await
266                    .map_err(Into::<BoxError>::into)?
267            };
268
269            let latest_messages = resp.to_messages();
271            let mut save_svc = save;
272            Service::call(
273                &mut save_svc,
274                SaveSession {
275                    id: session_id,
276                    history: latest_messages,
277                },
278            )
279            .await?;
280
281            Ok(resp)
282        })
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validation::{validate_conversation, ValidationPolicy};
    use async_openai::types::{
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
        ChatCompletionRequestUserMessageArgs,
    };
    use tower::service_fn;

    /// Builds a `gpt-4o` request carrying exactly `messages`.
    fn req_with_messages(
        messages: Vec<ChatCompletionRequestMessage>,
    ) -> CreateChatCompletionRequest {
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(messages)
            .build()
            .unwrap()
    }

    fn system_msg(text: &str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestSystemMessageArgs::default()
            .content(text)
            .build()
            .unwrap()
            .into()
    }

    fn user_msg(text: &str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestUserMessageArgs::default()
            .content(text)
            .build()
            .unwrap()
            .into()
    }

    fn assistant_msg(text: &str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestAssistantMessageArgs::default()
            .content(text)
            .build()
            .unwrap()
            .into()
    }

    /// History seeded before a turn: one system message and one user message.
    fn prior_history() -> History {
        vec![system_msg("sys"), user_msg("prev")]
    }

    /// Validation policy that tolerates consecutive messages with the same role.
    fn lenient_policy() -> ValidationPolicy {
        ValidationPolicy {
            allow_repeated_roles: true,
            ..Default::default()
        }
    }

    /// Writes `history` for `id` straight into the store.
    async fn seed_history(store: &InMemorySessionStore, id: &SessionId, history: History) {
        let mut writer = store.clone();
        tower::Service::call(
            &mut writer,
            SaveSession {
                id: id.clone(),
                history,
            },
        )
        .await
        .unwrap();
    }

    /// Reads back whatever the store currently holds for `id`.
    async fn stored_history(store: &InMemorySessionStore, id: SessionId) -> History {
        let mut reader = store.clone();
        tower::Service::call(&mut reader, LoadSession { id })
            .await
            .unwrap()
    }

    /// Wraps `inner` with a memory layer that loads from and saves to `store`.
    fn with_memory<S>(
        store: &InMemorySessionStore,
        id: &SessionId,
        inner: S,
    ) -> Memory<S, InMemorySessionStore, InMemorySessionStore> {
        let layer = MemoryLayer::new(
            Arc::new(store.clone()),
            Arc::new(store.clone()),
            id.clone(),
        );
        layer.layer(inner)
    }

    #[tokio::test]
    async fn memory_layer_loads_and_saves_history() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s1".into());
        seed_history(&store, &session_id, prior_history()).await;

        // Record the request that reaches the inner service and echo it back unchanged.
        let captured = Arc::new(tokio::sync::Mutex::new(None::<CreateChatCompletionRequest>));
        let sink = Arc::clone(&captured);
        let inner = service_fn(move |req: CreateChatCompletionRequest| {
            let sink = Arc::clone(&sink);
            async move {
                *sink.lock().await = Some(req.clone());
                Ok::<_, BoxError>(req)
            }
        });
        let mut svc = with_memory(&store, &session_id, inner);

        let _resp = tower::Service::call(&mut svc, req_with_messages(vec![user_msg("hello")]))
            .await
            .unwrap();

        let merged = captured.lock().await.clone().expect("captured");
        assert!(validate_conversation(&merged.messages, &lenient_policy()).is_none());

        let history = stored_history(&store, session_id).await;
        assert_eq!(history.len(), 3);
        assert!(validate_conversation(&history, &lenient_policy()).is_none());
    }

    #[tokio::test]
    async fn memory_layer_persists_step_outcome_done_messages() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s2".into());
        seed_history(&store, &session_id, prior_history()).await;

        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            let mut messages = req.messages;
            messages.push(assistant_msg("ok"));
            Ok::<_, BoxError>(crate::core::StepOutcome::Done {
                messages,
                aux: Default::default(),
            })
        });
        let mut svc = with_memory(&store, &session_id, inner);

        let _ = tower::Service::call(&mut svc, req_with_messages(vec![user_msg("hello")]))
            .await
            .unwrap();

        assert_eq!(stored_history(&store, session_id).await.len(), 4);
    }

    #[tokio::test]
    async fn memory_layer_persists_step_outcome_next_messages() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s3".into());

        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            let mut messages = req.messages;
            messages.push(assistant_msg("next"));
            Ok::<_, BoxError>(crate::core::StepOutcome::Next {
                messages,
                aux: Default::default(),
                invoked_tools: vec![],
            })
        });
        let mut svc = with_memory(&store, &session_id, inner);

        let _ = tower::Service::call(&mut svc, req_with_messages(vec![user_msg("hey")]))
            .await
            .unwrap();

        assert_eq!(stored_history(&store, session_id).await.len(), 2);
    }

    #[tokio::test]
    async fn memory_layer_persists_agent_run_messages() {
        let store = InMemorySessionStore::default();
        let session_id = SessionId("s4".into());

        let inner = service_fn(|req: CreateChatCompletionRequest| async move {
            let mut messages = req.messages;
            messages.push(assistant_msg("from agent"));
            Ok::<_, BoxError>(crate::core::AgentRun {
                messages,
                steps: 1,
                stop: crate::core::AgentStopReason::DoneNoToolCalls,
            })
        });
        let mut svc = with_memory(&store, &session_id, inner);

        let _ = tower::Service::call(&mut svc, req_with_messages(vec![user_msg("hey")]))
            .await
            .unwrap();

        assert_eq!(stored_history(&store, session_id).await.len(), 2);
    }
}