1use crate::error::{SageError, SageResult};
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use tokio::sync::{mpsc, RwLock};
16
/// Opaque identifier for a protocol session.
///
/// A thin newtype over `u64`. It is `Copy`, `Eq` and `Hash`, so it can be used
/// directly as a map key (as [`SessionRegistry`] does).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(u64);

impl SessionId {
    /// Wraps a raw numeric id.
    #[must_use]
    pub fn new(id: u64) -> Self {
        SessionId(id)
    }

    /// Returns the raw numeric id.
    #[must_use]
    pub fn value(&self) -> u64 {
        let SessionId(raw) = *self;
        raw
    }
}

/// Renders an id as `session-<n>`, e.g. `session-42`.
impl std::fmt::Display for SessionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.value();
        write!(f, "session-{raw}")
    }
}
40
41#[derive(Debug, Clone)]
46pub struct SenderHandle {
47 reply_tx: mpsc::Sender<crate::agent::Message>,
49 pub protocol: Option<String>,
51 pub session_id: Option<SessionId>,
53}
54
55impl SenderHandle {
56 #[must_use]
58 pub fn new(
59 reply_tx: mpsc::Sender<crate::agent::Message>,
60 protocol: Option<String>,
61 session_id: Option<SessionId>,
62 ) -> Self {
63 Self {
64 reply_tx,
65 protocol,
66 session_id,
67 }
68 }
69
70 pub async fn send<M: serde::Serialize>(&self, msg: M) -> SageResult<()> {
72 let message = crate::agent::Message::new(msg)?;
73 self.reply_tx
74 .send(message)
75 .await
76 .map_err(|e| SageError::Agent(format!("Failed to send reply: {e}")))
77 }
78}
79
/// Bookkeeping for one session, stored per `SessionId` in `SessionRegistry`.
#[derive(Debug)]
pub struct SessionState {
    /// Protocol this session runs (a name string; presumably matches the state
    /// machine's `protocol_name()` — confirm with callers).
    pub protocol: String,
    /// State machine tracking the session's progress through the protocol.
    pub state: Box<dyn ProtocolStateMachine>,
    /// Role this side plays in the session (NOTE(review): assumed to be a role
    /// name from the protocol definition — confirm).
    pub role: String,
    /// Handle for sending messages to the session partner (presumably the
    /// counterparty's reply channel — confirm).
    pub partner: SenderHandle,
}
92
93#[derive(Debug, Default)]
98pub struct SessionRegistry {
99 sessions: HashMap<SessionId, SessionState>,
101 next_session_id: AtomicU64,
103}
104
105impl SessionRegistry {
106 #[must_use]
108 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn next_id(&self) -> SessionId {
114 SessionId(self.next_session_id.fetch_add(1, Ordering::SeqCst))
115 }
116
117 pub fn start_session(
119 &mut self,
120 session_id: SessionId,
121 protocol: String,
122 role: String,
123 state: Box<dyn ProtocolStateMachine>,
124 partner: SenderHandle,
125 ) {
126 self.sessions.insert(
127 session_id,
128 SessionState {
129 protocol,
130 state,
131 role,
132 partner,
133 },
134 );
135 }
136
137 #[must_use]
139 pub fn get(&self, session_id: &SessionId) -> Option<&SessionState> {
140 self.sessions.get(session_id)
141 }
142
143 pub fn get_mut(&mut self, session_id: &SessionId) -> Option<&mut SessionState> {
145 self.sessions.get_mut(session_id)
146 }
147
148 pub fn remove(&mut self, session_id: &SessionId) -> Option<SessionState> {
150 self.sessions.remove(session_id)
151 }
152
153 #[must_use]
155 pub fn has(&self, session_id: &SessionId) -> bool {
156 self.sessions.contains_key(session_id)
157 }
158
159 #[must_use]
161 pub fn len(&self) -> usize {
162 self.sessions.len()
163 }
164
165 #[must_use]
167 pub fn is_empty(&self) -> bool {
168 self.sessions.is_empty()
169 }
170}
171
172#[derive(Debug, Clone)]
174pub enum ProtocolViolation {
175 UnexpectedMessage {
177 protocol: String,
179 expected: String,
181 received: String,
183 state: String,
185 },
186
187 EarlyTermination {
189 protocol: String,
191 state: String,
193 },
194
195 WrongSender {
197 protocol: String,
199 expected_role: String,
201 actual_role: String,
203 },
204
205 NoSession {
207 session_id: SessionId,
209 },
210
211 ReplyOutsideHandler,
213}
214
215impl std::fmt::Display for ProtocolViolation {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 match self {
218 ProtocolViolation::UnexpectedMessage {
219 protocol,
220 expected,
221 received,
222 state,
223 } => write!(
224 f,
225 "unexpected message in protocol '{}': expected '{}', got '{}' (state: {})",
226 protocol, expected, received, state
227 ),
228 ProtocolViolation::EarlyTermination { protocol, state } => {
229 write!(
230 f,
231 "protocol '{}' terminated early in state '{}'",
232 protocol, state
233 )
234 }
235 ProtocolViolation::WrongSender {
236 protocol,
237 expected_role,
238 actual_role,
239 } => write!(
240 f,
241 "wrong sender in protocol '{}': expected role '{}', got '{}'",
242 protocol, expected_role, actual_role
243 ),
244 ProtocolViolation::NoSession { session_id } => {
245 write!(f, "no session found with id {}", session_id)
246 }
247 ProtocolViolation::ReplyOutsideHandler => {
248 write!(f, "reply() called outside of message handler")
249 }
250 }
251 }
252}
253
254impl From<ProtocolViolation> for SageError {
255 fn from(v: ProtocolViolation) -> Self {
256 SageError::Protocol(v.to_string())
257 }
258}
259
/// Runtime state machine driving one protocol session.
///
/// The `Send + Sync` bounds allow a `SessionRegistry` holding boxed machines to be
/// shared across tasks through `SharedSessionRegistry`. `Debug` is required by the
/// derived `Debug` on `SessionState`.
///
/// NOTE(review): the per-method descriptions below are inferred from names and
/// signatures; implementors define the exact rules — confirm against them.
pub trait ProtocolStateMachine: Send + Sync + std::fmt::Debug {
    /// Name of the current state (presumably used in diagnostics such as the
    /// `state` field of `ProtocolViolation`).
    fn state_name(&self) -> &str;

    /// Whether a message of type `msg_type` may be sent by role `from_role` now.
    fn can_send(&self, msg_type: &str, from_role: &str) -> bool;

    /// Whether a message of type `msg_type` may be received by role `to_role` now.
    fn can_receive(&self, msg_type: &str, to_role: &str) -> bool;

    /// Advances the machine on a message of type `msg_type`.
    ///
    /// # Errors
    ///
    /// Returns a `ProtocolViolation` describing why the transition was rejected.
    fn transition(&mut self, msg_type: &str) -> Result<(), ProtocolViolation>;

    /// Whether the machine has reached a terminal state.
    fn is_terminal(&self) -> bool;

    /// Name of the protocol this machine implements.
    fn protocol_name(&self) -> &str;

    /// Clones the machine into a new box; an object-safe stand-in for `Clone`.
    fn clone_box(&self) -> Box<dyn ProtocolStateMachine>;
}
290
291pub type SharedSessionRegistry = Arc<RwLock<SessionRegistry>>;
293
294#[must_use]
296pub fn shared_registry() -> SharedSessionRegistry {
297 Arc::new(RwLock::new(SessionRegistry::new()))
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn session_id_display() {
        let id = SessionId::new(42);
        assert_eq!(id.to_string(), "session-42");
        assert_eq!(id.value(), 42);
    }

    #[test]
    fn session_registry_basic() {
        let registry = SessionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        // Ids come from a counter that starts at zero and increments by one.
        let id1 = registry.next_id();
        let id2 = registry.next_id();
        assert_ne!(id1, id2);
        assert_eq!(id1.value(), 0);
        assert_eq!(id2.value(), 1);

        // Minting ids must not register any session.
        assert!(registry.is_empty());
        assert!(!registry.has(&id1));
        assert!(registry.get(&id1).is_none());
    }

    #[test]
    fn protocol_violation_display() {
        let violation = ProtocolViolation::UnexpectedMessage {
            protocol: "PingPong".to_string(),
            expected: "Pong".to_string(),
            received: "Ping".to_string(),
            state: "AwaitingPong".to_string(),
        };
        // Compare the whole message: `contains("Ping")` / `contains("Pong")` would
        // already be satisfied by the protocol name "PingPong" and prove nothing.
        assert_eq!(
            violation.to_string(),
            "unexpected message in protocol 'PingPong': expected 'Pong', got 'Ping' (state: AwaitingPong)"
        );
    }

    #[test]
    fn protocol_violation_display_other_variants() {
        let early = ProtocolViolation::EarlyTermination {
            protocol: "PingPong".to_string(),
            state: "AwaitingPong".to_string(),
        };
        assert_eq!(
            early.to_string(),
            "protocol 'PingPong' terminated early in state 'AwaitingPong'"
        );

        let wrong = ProtocolViolation::WrongSender {
            protocol: "PingPong".to_string(),
            expected_role: "Pinger".to_string(),
            actual_role: "Ponger".to_string(),
        };
        assert_eq!(
            wrong.to_string(),
            "wrong sender in protocol 'PingPong': expected role 'Pinger', got 'Ponger'"
        );

        let missing = ProtocolViolation::NoSession {
            session_id: SessionId::new(7),
        };
        assert_eq!(missing.to_string(), "no session found with id session-7");

        assert_eq!(
            ProtocolViolation::ReplyOutsideHandler.to_string(),
            "reply() called outside of message handler"
        );
    }

    #[test]
    fn protocol_violation_to_error() {
        let violation = ProtocolViolation::ReplyOutsideHandler;
        let error: SageError = violation.into();
        // The conversion must carry the violation's Display text, not just the variant.
        assert!(matches!(
            &error,
            SageError::Protocol(message) if message == "reply() called outside of message handler"
        ));
    }
}
342}