sh_layer2/session_manager/
manager.rs1use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8
9use crate::session_manager::{
10 ReadWriteLock, Session, SessionConfig, SessionManagerTrait, SessionStats,
11};
12use crate::types::{AgentState, Layer2Error, Layer2Result, Message, SessionId, SessionMeta};
13
14struct SessionLock {
16 session: Session,
17 lock: ReadWriteLock,
18}
19
20pub struct ConcurrentSessionManager {
24 sessions: RwLock<HashMap<SessionId, SessionLock>>,
25 max_sessions: usize,
26}
27
28impl ConcurrentSessionManager {
29 pub fn new(max_sessions: usize) -> Self {
31 Self {
32 sessions: RwLock::new(HashMap::new()),
33 max_sessions,
34 }
35 }
36
37 pub fn default_config() -> Self {
39 Self::new(100)
40 }
41
42 #[allow(dead_code)]
44 fn get_session_lock(&self, id: &SessionId) -> Option<SessionLock> {
45 let guard = self.sessions.read();
46 guard.get(id).map(|s| SessionLock {
47 session: s.session.clone(),
48 lock: ReadWriteLock::new(), })
50 }
51
52 pub fn get_state_sync(&self, id: &SessionId) -> Option<AgentState> {
54 let guard = self.sessions.read();
55 guard.get(id).map(|s| s.session.state)
56 }
57}
58
59impl Default for ConcurrentSessionManager {
60 fn default() -> Self {
61 Self::default_config()
62 }
63}
64
65#[async_trait]
66impl SessionManagerTrait for ConcurrentSessionManager {
67 async fn create(&self, config: SessionConfig) -> Layer2Result<SessionId> {
68 let mut sessions = self.sessions.write();
69
70 if sessions.len() >= self.max_sessions {
71 return Err(Layer2Error::MaxSessionsReached(self.max_sessions).into());
72 }
73
74 let session = Session::new(&config);
75 let session_id = session.session_id.clone();
76
77 sessions.insert(
78 session_id.clone(),
79 SessionLock {
80 session,
81 lock: ReadWriteLock::new(),
82 },
83 );
84
85 Ok(session_id)
86 }
87
88 async fn get(&self, id: &SessionId) -> Layer2Result<Option<Session>> {
89 let sessions = self.sessions.read();
90 Ok(sessions.get(id).map(|s| s.session.clone()))
91 }
92
93 async fn get_or_create(
94 &self,
95 id: Option<&SessionId>,
96 config: SessionConfig,
97 ) -> Layer2Result<SessionId> {
98 let mut sessions = self.sessions.write();
99
100 if let Some(session_id) = id {
102 if sessions.contains_key(session_id) {
103 return Ok(session_id.clone());
104 }
105 }
106
107 if sessions.len() >= self.max_sessions {
109 return Err(Layer2Error::MaxSessionsReached(self.max_sessions).into());
110 }
111
112 let session = Session::new(&config);
114 let session_id = session.session_id.clone();
115
116 let final_id = id.cloned().unwrap_or_else(|| session_id.clone());
118
119 sessions.insert(
120 final_id.clone(),
121 SessionLock {
122 session,
123 lock: ReadWriteLock::new(),
124 },
125 );
126
127 Ok(final_id)
128 }
129
130 async fn save(&self, session: &Session) -> Layer2Result<()> {
131 let mut sessions = self.sessions.write();
132
133 if let Some(session_lock) = sessions.get_mut(&session.session_id) {
134 session_lock.session = session.clone();
135 session_lock.session.touch();
136 }
137
138 Ok(())
139 }
140
141 async fn delete(&self, id: &SessionId) -> Layer2Result<bool> {
142 let mut sessions = self.sessions.write();
143 Ok(sessions.remove(id).is_some())
144 }
145
146 async fn list(&self) -> Layer2Result<Vec<SessionMeta>> {
147 let sessions = self.sessions.read();
148 Ok(sessions
149 .values()
150 .map(|s| SessionMeta {
151 session_id: s.session.session_id.clone(),
152 agent_id: s.session.agent_id.clone(),
153 state: s.session.state,
154 created_at: s.session.created_at,
155 last_updated: s.session.last_updated,
156 message_count: s.session.messages.len(),
157 checkpoint_count: s.session.checkpoint_count,
158 })
159 .collect())
160 }
161
162 async fn update<F>(&self, id: &SessionId, update_fn: F) -> Layer2Result<bool>
163 where
164 F: FnOnce(&mut Session) + Send,
165 {
166 let mut sessions = self.sessions.write();
167
168 if let Some(session_lock) = sessions.get_mut(id) {
169 session_lock.lock.write(|| {
170 update_fn(&mut session_lock.session);
171 session_lock.session.touch();
172 });
173 Ok(true)
174 } else {
175 Ok(false)
176 }
177 }
178
179 async fn read<F, T>(&self, id: &SessionId, read_fn: F) -> Layer2Result<Option<T>>
180 where
181 F: FnOnce(&Session) -> T + Send,
182 T: Send,
183 {
184 let sessions = self.sessions.read();
185
186 if let Some(session_lock) = sessions.get(id) {
187 let result = session_lock.lock.read(|| read_fn(&session_lock.session));
189 Ok(Some(result))
190 } else {
191 Ok(None)
192 }
193 }
194
195 async fn get_state(&self, id: &SessionId) -> Layer2Result<Option<AgentState>> {
196 self.read(id, |s| s.state).await
197 }
198
199 async fn set_state(&self, id: &SessionId, state: AgentState) -> Layer2Result<bool> {
200 self.update(id, |s| s.state = state).await
201 }
202
203 async fn add_message(&self, id: &SessionId, message: Message) -> Layer2Result<bool> {
204 self.update(id, |s| {
205 s.messages.push(message);
206 s.iteration += 1;
207 })
208 .await
209 }
210
211 async fn get_messages(&self, id: &SessionId) -> Layer2Result<Option<Vec<Message>>> {
212 self.read(id, |s| s.messages.clone()).await
213 }
214
215 fn stats(&self) -> SessionStats {
216 let sessions = self.sessions.read();
217 SessionStats {
218 total_sessions: sessions.len(),
219 max_sessions: self.max_sessions,
220 active_sessions: sessions
221 .values()
222 .filter(|s| matches!(s.session.state, AgentState::Running))
223 .count(),
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();
        assert!(!session_id.0.is_empty());
    }

    #[tokio::test]
    async fn test_get_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();
        let session = manager.get(&session_id).await.unwrap();

        assert!(session.is_some());
        assert_eq!(session.unwrap().session_id, session_id);
    }

    #[tokio::test]
    async fn test_update_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();

        manager
            .update(&session_id, |s| {
                s.add_user_message("Hello");
            })
            .await
            .unwrap();

        let messages = manager.get_messages(&session_id).await.unwrap().unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[tokio::test]
    async fn test_delete_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();
        let deleted = manager.delete(&session_id).await.unwrap();

        assert!(deleted);

        let session = manager.get(&session_id).await.unwrap();
        assert!(session.is_none());
    }

    #[tokio::test]
    async fn test_session_stats() {
        let manager = ConcurrentSessionManager::new(10);

        let config = SessionConfig::default();
        manager.create(config).await.unwrap();

        let stats = manager.stats();
        assert_eq!(stats.total_sessions, 1);
        assert_eq!(stats.max_sessions, 10);
    }

    /// Both creation paths must refuse to grow past `max_sessions`.
    #[tokio::test]
    async fn test_max_sessions_enforced() {
        let manager = ConcurrentSessionManager::new(1);
        manager.create(SessionConfig::default()).await.unwrap();

        assert!(manager.create(SessionConfig::default()).await.is_err());
        assert!(manager
            .get_or_create(None, SessionConfig::default())
            .await
            .is_err());
        assert_eq!(manager.stats().total_sessions, 1);
    }

    /// Reusing an existing id must succeed even when the manager is full.
    #[tokio::test]
    async fn test_get_or_create_reuses_existing() {
        let manager = ConcurrentSessionManager::new(1);
        let first = manager
            .get_or_create(None, SessionConfig::default())
            .await
            .unwrap();

        let again = manager
            .get_or_create(Some(&first), SessionConfig::default())
            .await
            .unwrap();

        assert_eq!(again, first);
        assert_eq!(manager.stats().total_sessions, 1);
    }

    #[tokio::test]
    async fn test_state_roundtrip_and_active_count() {
        let manager = ConcurrentSessionManager::default_config();
        let id = manager.create(SessionConfig::default()).await.unwrap();

        assert!(manager.set_state(&id, AgentState::Running).await.unwrap());

        let state = manager.get_state(&id).await.unwrap();
        assert!(matches!(state, Some(AgentState::Running)));
        assert!(matches!(
            manager.get_state_sync(&id),
            Some(AgentState::Running)
        ));
        assert_eq!(manager.stats().active_sessions, 1);
    }

    #[tokio::test]
    async fn test_operations_on_missing_session() {
        let manager = ConcurrentSessionManager::default_config();
        let id = manager.create(SessionConfig::default()).await.unwrap();
        assert!(manager.delete(&id).await.unwrap());

        assert!(!manager.delete(&id).await.unwrap());
        assert!(!manager.update(&id, |_| {}).await.unwrap());
        assert!(manager.read(&id, |s| s.state).await.unwrap().is_none());
        assert!(manager.get_messages(&id).await.unwrap().is_none());
        assert!(manager.get_state_sync(&id).is_none());
    }
}