Skip to main content

traitclaw_test_utils/
memory.rs

1//! Mock memory backend for testing.
2//!
3//! [`MockMemory`] provides per-session message storage using
4//! [`tokio::sync::Mutex`] for async-safe access in tests.
5//!
6//! Unlike the production [`InMemoryMemory`](traitclaw_core::memory::in_memory::InMemoryMemory),
7//! this mock supports multi-session isolation via a `HashMap`.
8//!
9//! # Example
10//!
11//! ```rust
12//! use traitclaw_test_utils::memory::MockMemory;
13//!
14//! let memory = MockMemory::new();
15//! // Use with AgentRuntime or Agent for isolated test sessions
16//! ```
17
18use std::collections::HashMap;
19
20use async_trait::async_trait;
21use tokio::sync::Mutex;
22
23use traitclaw_core::traits::memory::MemoryEntry;
24use traitclaw_core::types::message::Message;
25use traitclaw_core::{Memory, Result};
26
27/// In-memory mock that stores messages per session.
28///
29/// Each session gets its own `Vec<Message>`, ensuring test
30/// isolation when multiple sessions are used concurrently.
31///
32/// # Example
33///
34/// ```rust
35/// use traitclaw_test_utils::memory::MockMemory;
36/// use traitclaw_core::types::message::Message;
37/// use traitclaw_core::Memory;
38///
39/// # tokio_test::block_on(async {
40/// let mem = MockMemory::new();
41/// mem.append("s1", Message::user("hi")).await.unwrap();
42/// let msgs = mem.messages("s1").await.unwrap();
43/// assert_eq!(msgs.len(), 1);
44/// # });
45/// ```
46pub struct MockMemory {
47    /// Per-session message storage.
48    messages: Mutex<HashMap<String, Vec<Message>>>,
49}
50
51impl MockMemory {
52    /// Create a new empty mock memory.
53    pub fn new() -> Self {
54        Self {
55            messages: Mutex::new(HashMap::new()),
56        }
57    }
58}
59
60impl Default for MockMemory {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66#[async_trait]
67impl Memory for MockMemory {
68    async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
69        let store = self.messages.lock().await;
70        Ok(store.get(session_id).cloned().unwrap_or_default())
71    }
72
73    async fn append(&self, session_id: &str, message: Message) -> Result<()> {
74        let mut store = self.messages.lock().await;
75        store
76            .entry(session_id.to_string())
77            .or_default()
78            .push(message);
79        Ok(())
80    }
81
82    async fn get_context(
83        &self,
84        _session_id: &str,
85        _key: &str,
86    ) -> Result<Option<serde_json::Value>> {
87        Ok(None)
88    }
89
90    async fn set_context(
91        &self,
92        _session_id: &str,
93        _key: &str,
94        _value: serde_json::Value,
95    ) -> Result<()> {
96        Ok(())
97    }
98
99    async fn recall(&self, _query: &str, _limit: usize) -> Result<Vec<MemoryEntry>> {
100        Ok(vec![])
101    }
102
103    async fn store(&self, _entry: MemoryEntry) -> Result<()> {
104        Ok(())
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built mock has no history for any session.
    #[tokio::test]
    async fn test_new_memory_returns_empty_messages() {
        let memory = MockMemory::new();
        let history = memory.messages("session-1").await.unwrap();
        assert!(history.is_empty());
    }

    /// Messages appended to a session are all returned for that session.
    #[tokio::test]
    async fn test_append_and_retrieve() {
        let memory = MockMemory::new();
        memory.append("s1", Message::user("hello")).await.unwrap();
        memory.append("s1", Message::assistant("hi")).await.unwrap();

        let history = memory.messages("s1").await.unwrap();
        assert_eq!(history.len(), 2);
    }

    /// Writes to one session are invisible to every other session.
    #[tokio::test]
    async fn test_sessions_are_isolated() {
        let memory = MockMemory::new();
        memory.append("s1", Message::user("one")).await.unwrap();
        memory.append("s2", Message::user("two")).await.unwrap();

        // Each written session holds exactly its own single message.
        for session in ["s1", "s2"] {
            let history = memory.messages(session).await.unwrap();
            assert_eq!(history.len(), 1);
        }

        // A session that was never written to stays empty.
        let untouched = memory.messages("s3").await.unwrap();
        assert!(untouched.is_empty());
    }

    /// The mock does not model context, so lookups always miss.
    #[tokio::test]
    async fn test_get_context_returns_none() {
        let memory = MockMemory::new();
        let value = memory.get_context("s1", "key").await.unwrap();
        assert!(value.is_none());
    }

    /// The mock does not model recall, so queries return nothing.
    #[tokio::test]
    async fn test_recall_returns_empty() {
        let memory = MockMemory::new();
        let hits = memory.recall("query", 10).await.unwrap();
        assert!(hits.is_empty());
    }

    /// Compile-time check that the mock is `Send + Sync`.
    #[test]
    fn test_mock_memory_is_send_sync() {
        fn assert_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        assert_send_sync::<MockMemory>();
    }
}