Skip to main content

smcp_server_core/
session.rs

1//! 会话管理模块 / Session management module
2
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6use thiserror::Error;
7
// Identifier aliases shared by the session types below.

/// Identifier of a client session.
pub type SessionId = String;
/// Identifier of an office (room) a session may belong to.
pub type OfficeId = String;
11
12/// 会话错误类型
13#[derive(Error, Debug, serde::Serialize)]
14pub enum SessionError {
15    #[error("Session not found: {0}")]
16    NotFound(String),
17    #[error("Name already registered: {0}")]
18    NameAlreadyRegistered(String),
19    #[error("Agent already in room: {0}")]
20    AgentAlreadyInRoom(OfficeId),
21    #[error("Agent already exists in room")]
22    AgentAlreadyExists,
23    #[error("Computer with name '{0}' already exists in room '{1}'")]
24    ComputerAlreadyExists(String, OfficeId),
25    #[error("Invalid session state: {0}")]
26    InvalidState(String),
27}
28
29impl SessionError {
30    /// 获取错误码 / Get error code
31    pub fn error_code(&self) -> i32 {
32        match self {
33            SessionError::NotFound(_) => smcp::error_codes::NOT_FOUND,
34            SessionError::NameAlreadyRegistered(_) => smcp::error_codes::FORBIDDEN,
35            SessionError::AgentAlreadyInRoom(_) => smcp::error_codes::FORBIDDEN,
36            SessionError::AgentAlreadyExists => smcp::error_codes::ROOM_FULL,
37            SessionError::ComputerAlreadyExists(_, _) => smcp::error_codes::FORBIDDEN,
38            SessionError::InvalidState(_) => smcp::error_codes::BAD_REQUEST,
39        }
40    }
41}
42
43/// 客户端角色
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45pub enum ClientRole {
46    Agent,
47    Computer,
48}
49
50impl From<smcp::Role> for ClientRole {
51    fn from(role: smcp::Role) -> Self {
52        match role {
53            smcp::Role::Agent => ClientRole::Agent,
54            smcp::Role::Computer => ClientRole::Computer,
55        }
56    }
57}
58
59impl From<ClientRole> for smcp::Role {
60    fn from(role: ClientRole) -> Self {
61        match role {
62            ClientRole::Agent => smcp::Role::Agent,
63            ClientRole::Computer => smcp::Role::Computer,
64        }
65    }
66}
67
68impl std::fmt::Display for ClientRole {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        match self {
71            ClientRole::Agent => write!(f, "agent"),
72            ClientRole::Computer => write!(f, "computer"),
73        }
74    }
75}
76
77/// 会话数据
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct SessionData {
80    /// 会话 ID
81    pub sid: SessionId,
82    /// 客户端名称
83    pub name: String,
84    /// 客户端角色
85    pub role: ClientRole,
86    /// 当前所在的办公室 ID
87    pub office_id: Option<OfficeId>,
88    /// 其他扩展数据
89    pub extra: serde_json::Value,
90}
91
92impl SessionData {
93    /// 创建新的会话数据
94    pub fn new(sid: SessionId, name: String, role: ClientRole) -> Self {
95        Self {
96            sid,
97            name,
98            role,
99            office_id: None,
100            extra: serde_json::Value::Object(Default::default()),
101        }
102    }
103
104    /// 设置办公室 ID
105    pub fn with_office_id(mut self, office_id: OfficeId) -> Self {
106        self.office_id = Some(office_id);
107        self
108    }
109
110    /// 设置扩展数据
111    pub fn with_extra(mut self, extra: serde_json::Value) -> Self {
112        self.extra = extra;
113        self
114    }
115}
116
117/// 会话管理器
118#[derive(Debug)]
119pub struct SessionManager {
120    /// sid -> session_data 映射
121    sessions: Arc<DashMap<SessionId, SessionData>>,
122    /// name -> sid 映射(用于通过 name 查找 session)
123    name_to_sid: Arc<DashMap<String, SessionId>>,
124}
125
126impl SessionManager {
127    fn name_key(role: &ClientRole, office_id: Option<&OfficeId>, name: &str) -> String {
128        match role {
129            ClientRole::Agent => format!("agent:{}", name),
130            ClientRole::Computer => match office_id {
131                Some(office_id) => format!("computer:{}:{}", office_id, name),
132                None => format!("computer::{}", name),
133            },
134        }
135    }
136
137    /// 创建新的会话管理器
138    pub fn new() -> Self {
139        Self {
140            sessions: Arc::new(DashMap::new()),
141            name_to_sid: Arc::new(DashMap::new()),
142        }
143    }
144
145    /// 注册新会话
146    pub fn register_session(&self, session: SessionData) -> Result<(), SessionError> {
147        let key = Self::name_key(&session.role, session.office_id.as_ref(), &session.name);
148        // 检查 name 是否已被其他 sid 使用
149        if let Some(existing_sid) = self.name_to_sid.get(&key) {
150            if *existing_sid != session.sid {
151                return Err(SessionError::NameAlreadyRegistered(session.name));
152            }
153            // 如果是同一个 sid 重新注册,允许(幂等操作)
154            tracing::debug!("Name '{}' re-registered by same sid", session.name);
155            return Ok(());
156        }
157
158        // 注册映射
159        self.sessions.insert(session.sid.clone(), session.clone());
160        self.name_to_sid.insert(key, session.sid.clone());
161
162        tracing::debug!("Registered session: {} -> {}", session.name, session.sid);
163        Ok(())
164    }
165
166    /// 注销会话
167    pub fn unregister_session(&self, sid: &SessionId) -> Option<SessionData> {
168        let session = self.sessions.remove(sid)?;
169
170        // 清理 name 映射
171        let key = Self::name_key(
172            &session.1.role,
173            session.1.office_id.as_ref(),
174            &session.1.name,
175        );
176        self.name_to_sid.remove(&key);
177
178        tracing::debug!("Unregistered session: {} -> {}", session.1.name, sid);
179        Some(session.1)
180    }
181
182    /// 获取会话数据
183    pub fn get_session(&self, sid: &SessionId) -> Option<SessionData> {
184        self.sessions.get(sid).map(|s| s.clone())
185    }
186
187    /// 通过名称获取会话 ID
188    pub fn get_sid_by_name(&self, name: &str) -> Option<SessionId> {
189        let key = Self::name_key(&ClientRole::Agent, None, name);
190        self.name_to_sid.get(&key).map(|s| s.clone())
191    }
192
193    /// 更新会话的办公室 ID
194    pub fn update_office_id(
195        &self,
196        sid: &SessionId,
197        office_id: Option<OfficeId>,
198    ) -> Result<(), SessionError> {
199        let mut session = self
200            .sessions
201            .get_mut(sid)
202            .ok_or_else(|| SessionError::NotFound(sid.clone()))?;
203
204        let role = session.role.clone();
205        let name = session.name.clone();
206        let old_office_id = session.office_id.clone();
207
208        let old_key = Self::name_key(&role, old_office_id.as_ref(), &name);
209        let new_key = Self::name_key(&role, office_id.as_ref(), &name);
210
211        if old_key != new_key {
212            if let Some(existing_sid) = self.name_to_sid.get(&new_key) {
213                if *existing_sid != *sid {
214                    return Err(SessionError::NameAlreadyRegistered(name));
215                }
216            }
217
218            self.name_to_sid.remove(&old_key);
219            self.name_to_sid.insert(new_key, sid.clone());
220        }
221
222        session.office_id = office_id;
223        Ok(())
224    }
225
226    /// 获取指定办公室内的所有会话
227    pub fn get_sessions_in_office(&self, office_id: &OfficeId) -> Vec<SessionData> {
228        self.sessions
229            .iter()
230            .filter(|s| s.office_id.as_ref() == Some(office_id))
231            .map(|s| s.clone())
232            .collect()
233    }
234
235    /// 检查房间内是否已有 Agent
236    pub fn has_agent_in_office(&self, office_id: &OfficeId) -> bool {
237        self.sessions
238            .iter()
239            .any(|s| s.office_id.as_ref() == Some(office_id) && s.role == ClientRole::Agent)
240    }
241
242    /// 检查房间内是否有指定名称的 Computer
243    pub fn has_computer_in_office(&self, office_id: &OfficeId, name: &str) -> bool {
244        self.sessions.iter().any(|s| {
245            s.office_id.as_ref() == Some(office_id)
246                && s.role == ClientRole::Computer
247                && s.name == name
248        })
249    }
250
251    /// 获取房间内指定 Computer 的 sid
252    pub fn get_computer_sid_in_office(
253        &self,
254        office_id: &OfficeId,
255        name: &str,
256    ) -> Option<SessionId> {
257        self.sessions.iter().find_map(|s| {
258            if s.office_id.as_ref() == Some(office_id)
259                && s.role == ClientRole::Computer
260                && s.name == name
261            {
262                Some(s.sid.clone())
263            } else {
264                None
265            }
266        })
267    }
268
269    /// 获取所有会话
270    pub fn get_all_sessions(&self) -> Vec<SessionData> {
271        self.sessions.iter().map(|s| s.clone()).collect()
272    }
273
274    /// 获取会话统计信息
275    pub fn get_stats(&self) -> SessionStats {
276        let total = self.sessions.len();
277        let agents = self
278            .sessions
279            .iter()
280            .filter(|s| s.role == ClientRole::Agent)
281            .count();
282        let computers = self
283            .sessions
284            .iter()
285            .filter(|s| s.role == ClientRole::Computer)
286            .count();
287
288        SessionStats {
289            total,
290            agents,
291            computers,
292        }
293    }
294}
295
296/// 会话统计信息
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct SessionStats {
299    /// 总会话数
300    pub total: usize,
301    /// Agent 数量
302    pub agents: usize,
303    /// Computer 数量
304    pub computers: usize,
305}
306
307impl Default for SessionManager {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
313#[cfg(test)]
314mod tests {
315    use super::*;
316    use serde_json::json;
317    use uuid::Uuid;
318
319    #[test]
320    fn test_session_registration() {
321        let manager = SessionManager::new();
322        let sid = Uuid::new_v4().to_string();
323        let session = SessionData::new(sid.clone(), "test_agent".to_string(), ClientRole::Agent);
324
325        // 注册会话
326        assert!(manager.register_session(session).is_ok());
327
328        // 获取会话
329        let retrieved = manager.get_session(&sid);
330        assert!(retrieved.is_some());
331        assert_eq!(retrieved.unwrap().name, "test_agent");
332
333        // 通过名称获取 sid
334        let found_sid = manager.get_sid_by_name("test_agent");
335        assert_eq!(found_sid, Some(sid));
336    }
337
338    #[test]
339    fn test_duplicate_name_registration() {
340        let manager = SessionManager::new();
341        let sid1 = Uuid::new_v4().to_string();
342        let sid2 = Uuid::new_v4().to_string();
343        let sid3 = Uuid::new_v4().to_string();
344        let sid4 = Uuid::new_v4().to_string();
345
346        let session1 = SessionData::new(
347            sid1.clone(),
348            "duplicate_name".to_string(),
349            ClientRole::Agent,
350        );
351        let session2 = SessionData::new(
352            sid2.clone(),
353            "duplicate_name".to_string(),
354            ClientRole::Agent,
355        );
356
357        let session3 = SessionData::new(
358            sid3.clone(),
359            "duplicate_name".to_string(),
360            ClientRole::Computer,
361        )
362        .with_office_id("office1".to_string());
363
364        let session4 = SessionData::new(
365            sid4.clone(),
366            "duplicate_name".to_string(),
367            ClientRole::Computer,
368        )
369        .with_office_id("office2".to_string());
370
371        // 第一个注册成功
372        assert!(manager.register_session(session1).is_ok());
373
374        // 第二个注册失败(Agent 名称全局唯一)
375        assert!(manager.register_session(session2).is_err());
376
377        // Computer 名称按 office 唯一:同 office 冲突
378        assert!(manager.register_session(session3.clone()).is_ok());
379        assert!(manager.register_session(session3).is_ok());
380
381        let dup_same_office = SessionData::new(
382            Uuid::new_v4().to_string(),
383            "duplicate_name".to_string(),
384            ClientRole::Computer,
385        )
386        .with_office_id("office1".to_string());
387        assert!(manager.register_session(dup_same_office).is_err());
388
389        // 不同 office 允许同名
390        assert!(manager.register_session(session4).is_ok());
391    }
392
393    #[test]
394    fn test_office_management() {
395        let manager = SessionManager::new();
396        let office_id = "office_123".to_string();
397        let sid = Uuid::new_v4().to_string();
398        let session = SessionData::new(
399            sid.clone(),
400            "test_computer".to_string(),
401            ClientRole::Computer,
402        )
403        .with_office_id(office_id.clone());
404
405        manager.register_session(session).unwrap();
406
407        // 检查房间内的会话
408        let sessions = manager.get_sessions_in_office(&office_id);
409        assert_eq!(sessions.len(), 1);
410        assert_eq!(sessions[0].sid, sid);
411
412        // 检查是否有 Agent
413        assert!(!manager.has_agent_in_office(&office_id));
414
415        // 检查是否有指定 Computer
416        assert!(manager.has_computer_in_office(&office_id, "test_computer"));
417    }
418
419    #[test]
420    fn test_session_unregistration() {
421        let manager = SessionManager::new();
422        let sid = Uuid::new_v4().to_string();
423        let session = SessionData::new(sid.clone(), "test_agent".to_string(), ClientRole::Agent);
424
425        manager.register_session(session).unwrap();
426
427        // 注销会话
428        let removed = manager.unregister_session(&sid);
429        assert!(removed.is_some());
430
431        // 验证会话已删除
432        assert!(manager.get_session(&sid).is_none());
433        assert!(manager.get_sid_by_name("test_agent").is_none());
434    }
435
436    #[test]
437    fn test_stats() {
438        let manager = SessionManager::new();
439
440        // 添加一些会话
441        let agent_session = SessionData::new(
442            Uuid::new_v4().to_string(),
443            "agent1".to_string(),
444            ClientRole::Agent,
445        );
446        let computer_session1 = SessionData::new(
447            Uuid::new_v4().to_string(),
448            "computer1".to_string(),
449            ClientRole::Computer,
450        );
451        let computer_session2 = SessionData::new(
452            Uuid::new_v4().to_string(),
453            "computer2".to_string(),
454            ClientRole::Computer,
455        );
456
457        manager.register_session(agent_session).unwrap();
458        manager.register_session(computer_session1).unwrap();
459        manager.register_session(computer_session2).unwrap();
460
461        let stats = manager.get_stats();
462        assert_eq!(stats.total, 3);
463        assert_eq!(stats.agents, 1);
464        assert_eq!(stats.computers, 2);
465    }
466
467    #[test]
468    fn test_register_session_idempotent_same_sid() {
469        let manager = SessionManager::new();
470        let sid = Uuid::new_v4().to_string();
471        let session1 = SessionData::new(sid.clone(), "same_name".to_string(), ClientRole::Agent);
472        let session2 = SessionData::new(sid.clone(), "same_name".to_string(), ClientRole::Agent);
473
474        assert!(manager.register_session(session1).is_ok());
475        assert!(manager.register_session(session2).is_ok());
476
477        let retrieved = manager.get_session(&sid).unwrap();
478        assert_eq!(retrieved.name, "same_name");
479    }
480
481    #[test]
482    fn test_update_office_id_ok_and_not_found() {
483        let manager = SessionManager::new();
484        let sid = Uuid::new_v4().to_string();
485        let session = SessionData::new(sid.clone(), "test".to_string(), ClientRole::Agent);
486        manager.register_session(session).unwrap();
487
488        assert!(manager
489            .update_office_id(&sid, Some("office_x".to_string()))
490            .is_ok());
491        assert_eq!(
492            manager.get_session(&sid).unwrap().office_id,
493            Some("office_x".to_string())
494        );
495
496        let missing_sid = Uuid::new_v4().to_string();
497        let err = manager
498            .update_office_id(&missing_sid, Some("office_y".to_string()))
499            .unwrap_err();
500        assert!(matches!(err, SessionError::NotFound(s) if s == missing_sid));
501    }
502
503    #[test]
504    fn test_get_all_sessions() {
505        let manager = SessionManager::new();
506        let s1 = SessionData::new(
507            Uuid::new_v4().to_string(),
508            "a1".to_string(),
509            ClientRole::Agent,
510        );
511        let s2 = SessionData::new(
512            Uuid::new_v4().to_string(),
513            "c1".to_string(),
514            ClientRole::Computer,
515        );
516        manager.register_session(s1).unwrap();
517        manager.register_session(s2).unwrap();
518
519        let all = manager.get_all_sessions();
520        assert_eq!(all.len(), 2);
521    }
522
523    #[test]
524    fn test_client_role_convert_and_display() {
525        let agent: ClientRole = smcp::Role::Agent.into();
526        let computer: ClientRole = smcp::Role::Computer.into();
527        assert_eq!(agent.to_string(), "agent");
528        assert_eq!(computer.to_string(), "computer");
529
530        let back_agent: smcp::Role = agent.into();
531        let back_computer: smcp::Role = computer.into();
532        assert!(matches!(back_agent, smcp::Role::Agent));
533        assert!(matches!(back_computer, smcp::Role::Computer));
534    }
535
536    #[test]
537    fn test_session_data_with_extra() {
538        let sid = Uuid::new_v4().to_string();
539        let extra = json!({"k": "v", "n": 1});
540        let session =
541            SessionData::new(sid, "n".to_string(), ClientRole::Computer).with_extra(extra.clone());
542        assert_eq!(session.extra, extra);
543    }
544}