smos_application/testkit/
sessions.rs1use std::collections::{HashMap, HashSet};
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14
15use smos_domain::{FactId, MemoryKey, SessionId, SessionState};
16
17use crate::errors::RepoError;
18use crate::ports::SessionRepository;
19
20#[derive(Default, Clone)]
21pub struct InMemorySessions {
22 sessions: Arc<Mutex<HashMap<String, SessionState>>>,
23 injected: Arc<Mutex<HashMap<String, HashSet<String>>>>,
24}
25
26impl InMemorySessions {
27 pub fn seed(&self, state: SessionState) {
28 self.sessions
29 .lock()
30 .unwrap()
31 .insert(state.id().as_str().to_string(), state);
32 }
33
34 pub fn pending_of(&self, id: &SessionId) -> Vec<FactId> {
35 self.sessions
36 .lock()
37 .unwrap()
38 .get(id.as_str())
39 .map(|s| s.pending_facts().to_vec())
40 .unwrap_or_default()
41 }
42}
43
44impl SessionRepository for InMemorySessions {
45 async fn get_or_create(
46 &self,
47 id: &SessionId,
48 memory_key: &MemoryKey,
49 ) -> Result<SessionState, RepoError> {
50 Ok(self
51 .sessions
52 .lock()
53 .unwrap()
54 .entry(id.as_str().to_string())
55 .or_insert_with(|| {
56 SessionState::new(
57 id.clone(),
58 memory_key.clone(),
59 smos_domain::Timestamp::from_unix_secs(0).unwrap(),
60 )
61 })
62 .clone())
63 }
64
65 async fn collect_expired(
66 &self,
67 _timeout: Duration,
68 ) -> Result<Vec<(SessionId, SessionState)>, RepoError> {
69 Ok(Vec::new())
70 }
71
72 async fn snapshot_all(&self) -> Result<Vec<(SessionId, SessionState)>, RepoError> {
73 Ok(self
74 .sessions
75 .lock()
76 .unwrap()
77 .iter()
78 .map(|(k, v)| (SessionId::from_raw(k).unwrap(), v.clone()))
79 .collect())
80 }
81
82 async fn add_pending(&self, id: &SessionId, fact_ids: &[FactId]) -> Result<(), RepoError> {
83 if let Some(state) = self.sessions.lock().unwrap().get_mut(id.as_str()) {
84 state.add_pending(fact_ids);
85 }
86 Ok(())
87 }
88
89 async fn remove_pending_owned(
90 &self,
91 id: &SessionId,
92 owned: &[FactId],
93 ) -> Result<(), RepoError> {
94 if let Some(state) = self.sessions.lock().unwrap().get_mut(id.as_str()) {
95 state.remove_owned(owned);
96 }
97 Ok(())
98 }
99
100 async fn clear_session(&self, id: &SessionId) -> Result<(), RepoError> {
101 self.sessions.lock().unwrap().remove(id.as_str());
102 Ok(())
103 }
104
105 async fn dedup_and_mark(
106 &self,
107 id: &SessionId,
108 _memory_key: &MemoryKey,
109 candidate_ids: &[FactId],
110 ) -> Result<Vec<FactId>, RepoError> {
111 let mut injected = self.injected.lock().unwrap();
112 let seen = injected.entry(id.as_str().to_string()).or_default();
113 let mut new_ids = Vec::new();
114 for cid in candidate_ids {
115 if seen.insert(cid.as_str().to_string()) {
116 new_ids.push(cid.clone());
117 }
118 }
119 Ok(new_ids)
120 }
121
122 async fn save(&self, id: &SessionId, state: &SessionState) -> Result<(), RepoError> {
123 self.sessions
124 .lock()
125 .unwrap()
126 .insert(id.as_str().to_string(), state.clone());
127 Ok(())
128 }
129}