Skip to main content

sh_layer2/session_manager/
manager.rs

1//! # Concurrent Session Manager
2//!
3//! 并发安全的会话管理器实现。
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8
9use crate::session_manager::{
10    ReadWriteLock, Session, SessionConfig, SessionManagerTrait, SessionStats,
11};
12use crate::types::{AgentState, Layer2Error, Layer2Result, Message, SessionId, SessionMeta};
13
14/// 会话锁包装
15struct SessionLock {
16    session: Session,
17    lock: ReadWriteLock,
18}
19
20/// 并发安全会话管理器
21///
22/// 使用读写分离锁,读操作可并发,写操作互斥。
23pub struct ConcurrentSessionManager {
24    sessions: RwLock<HashMap<SessionId, SessionLock>>,
25    max_sessions: usize,
26}
27
28impl ConcurrentSessionManager {
29    /// 创建新的会话管理器
30    pub fn new(max_sessions: usize) -> Self {
31        Self {
32            sessions: RwLock::new(HashMap::new()),
33            max_sessions,
34        }
35    }
36
37    /// 使用默认配置创建
38    pub fn default_config() -> Self {
39        Self::new(100)
40    }
41
42    /// 获取会话锁
43    #[allow(dead_code)]
44    fn get_session_lock(&self, id: &SessionId) -> Option<SessionLock> {
45        let guard = self.sessions.read();
46        guard.get(id).map(|s| SessionLock {
47            session: s.session.clone(),
48            lock: ReadWriteLock::new(), // 每次返回新的锁实例
49        })
50    }
51
52    /// 同步获取会话状态(用于 AgentRuntime::status 同步方法)
53    pub fn get_state_sync(&self, id: &SessionId) -> Option<AgentState> {
54        let guard = self.sessions.read();
55        guard.get(id).map(|s| s.session.state)
56    }
57}
58
59impl Default for ConcurrentSessionManager {
60    fn default() -> Self {
61        Self::default_config()
62    }
63}
64
65#[async_trait]
66impl SessionManagerTrait for ConcurrentSessionManager {
67    async fn create(&self, config: SessionConfig) -> Layer2Result<SessionId> {
68        let mut sessions = self.sessions.write();
69
70        if sessions.len() >= self.max_sessions {
71            return Err(Layer2Error::MaxSessionsReached(self.max_sessions).into());
72        }
73
74        let session = Session::new(&config);
75        let session_id = session.session_id.clone();
76
77        sessions.insert(
78            session_id.clone(),
79            SessionLock {
80                session,
81                lock: ReadWriteLock::new(),
82            },
83        );
84
85        Ok(session_id)
86    }
87
88    async fn get(&self, id: &SessionId) -> Layer2Result<Option<Session>> {
89        let sessions = self.sessions.read();
90        Ok(sessions.get(id).map(|s| s.session.clone()))
91    }
92
93    async fn get_or_create(
94        &self,
95        id: Option<&SessionId>,
96        config: SessionConfig,
97    ) -> Layer2Result<SessionId> {
98        let mut sessions = self.sessions.write();
99
100        // 如果指定了 ID 且存在,直接返回
101        if let Some(session_id) = id {
102            if sessions.contains_key(session_id) {
103                return Ok(session_id.clone());
104            }
105        }
106
107        // 检查限制
108        if sessions.len() >= self.max_sessions {
109            return Err(Layer2Error::MaxSessionsReached(self.max_sessions).into());
110        }
111
112        // 创建新会话
113        let session = Session::new(&config);
114        let session_id = session.session_id.clone();
115
116        // 如果指定了 ID,使用指定的 ID
117        let final_id = id.cloned().unwrap_or_else(|| session_id.clone());
118
119        sessions.insert(
120            final_id.clone(),
121            SessionLock {
122                session,
123                lock: ReadWriteLock::new(),
124            },
125        );
126
127        Ok(final_id)
128    }
129
130    async fn save(&self, session: &Session) -> Layer2Result<()> {
131        let mut sessions = self.sessions.write();
132
133        if let Some(session_lock) = sessions.get_mut(&session.session_id) {
134            session_lock.session = session.clone();
135            session_lock.session.touch();
136        }
137
138        Ok(())
139    }
140
141    async fn delete(&self, id: &SessionId) -> Layer2Result<bool> {
142        let mut sessions = self.sessions.write();
143        Ok(sessions.remove(id).is_some())
144    }
145
146    async fn list(&self) -> Layer2Result<Vec<SessionMeta>> {
147        let sessions = self.sessions.read();
148        Ok(sessions
149            .values()
150            .map(|s| SessionMeta {
151                session_id: s.session.session_id.clone(),
152                agent_id: s.session.agent_id.clone(),
153                state: s.session.state,
154                created_at: s.session.created_at,
155                last_updated: s.session.last_updated,
156                message_count: s.session.messages.len(),
157                checkpoint_count: s.session.checkpoint_count,
158            })
159            .collect())
160    }
161
162    async fn update<F>(&self, id: &SessionId, update_fn: F) -> Layer2Result<bool>
163    where
164        F: FnOnce(&mut Session) + Send,
165    {
166        let mut sessions = self.sessions.write();
167
168        if let Some(session_lock) = sessions.get_mut(id) {
169            session_lock.lock.write(|| {
170                update_fn(&mut session_lock.session);
171                session_lock.session.touch();
172            });
173            Ok(true)
174        } else {
175            Ok(false)
176        }
177    }
178
179    async fn read<F, T>(&self, id: &SessionId, read_fn: F) -> Layer2Result<Option<T>>
180    where
181        F: FnOnce(&Session) -> T + Send,
182        T: Send,
183    {
184        let sessions = self.sessions.read();
185
186        if let Some(session_lock) = sessions.get(id) {
187            // 使用读锁
188            let result = session_lock.lock.read(|| read_fn(&session_lock.session));
189            Ok(Some(result))
190        } else {
191            Ok(None)
192        }
193    }
194
195    async fn get_state(&self, id: &SessionId) -> Layer2Result<Option<AgentState>> {
196        self.read(id, |s| s.state).await
197    }
198
199    async fn set_state(&self, id: &SessionId, state: AgentState) -> Layer2Result<bool> {
200        self.update(id, |s| s.state = state).await
201    }
202
203    async fn add_message(&self, id: &SessionId, message: Message) -> Layer2Result<bool> {
204        self.update(id, |s| {
205            s.messages.push(message);
206            s.iteration += 1;
207        })
208        .await
209    }
210
211    async fn get_messages(&self, id: &SessionId) -> Layer2Result<Option<Vec<Message>>> {
212        self.read(id, |s| s.messages.clone()).await
213    }
214
215    fn stats(&self) -> SessionStats {
216        let sessions = self.sessions.read();
217        SessionStats {
218            total_sessions: sessions.len(),
219            max_sessions: self.max_sessions,
220            active_sessions: sessions
221                .values()
222                .filter(|s| matches!(s.session.state, AgentState::Running))
223                .count(),
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();
        assert!(!session_id.0.is_empty());
    }

    #[tokio::test]
    async fn test_get_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();
        let session = manager.get(&session_id).await.unwrap();

        assert!(session.is_some());
        assert_eq!(session.unwrap().session_id, session_id);
    }

    #[tokio::test]
    async fn test_update_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();

        manager
            .update(&session_id, |s| {
                s.add_user_message("Hello");
            })
            .await
            .unwrap();

        let messages = manager.get_messages(&session_id).await.unwrap().unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[tokio::test]
    async fn test_delete_session() {
        let manager = ConcurrentSessionManager::default_config();
        let config = SessionConfig::default();

        let session_id = manager.create(config).await.unwrap();
        let deleted = manager.delete(&session_id).await.unwrap();

        assert!(deleted);

        let session = manager.get(&session_id).await.unwrap();
        assert!(session.is_none());

        // Deleting a session that is already gone reports `false`.
        assert!(!manager.delete(&session_id).await.unwrap());
    }

    #[tokio::test]
    async fn test_session_stats() {
        let manager = ConcurrentSessionManager::new(10);

        let config = SessionConfig::default();
        manager.create(config).await.unwrap();

        let stats = manager.stats();
        assert_eq!(stats.total_sessions, 1);
        assert_eq!(stats.max_sessions, 10);
    }

    #[tokio::test]
    async fn test_max_sessions_limit() {
        let manager = ConcurrentSessionManager::new(1);

        manager.create(SessionConfig::default()).await.unwrap();

        // Both creation paths must refuse to exceed the capacity.
        assert!(manager.create(SessionConfig::default()).await.is_err());
        assert!(manager
            .get_or_create(None, SessionConfig::default())
            .await
            .is_err());
        assert_eq!(manager.stats().total_sessions, 1);
    }

    #[tokio::test]
    async fn test_get_or_create_reuses_existing_session() {
        let manager = ConcurrentSessionManager::default_config();

        let created = manager
            .get_or_create(None, SessionConfig::default())
            .await
            .unwrap();
        let reused = manager
            .get_or_create(Some(&created), SessionConfig::default())
            .await
            .unwrap();

        assert_eq!(created, reused);
        assert_eq!(manager.stats().total_sessions, 1);

        let stored = manager.get(&created).await.unwrap().unwrap();
        assert_eq!(stored.session_id, created);
    }
}