Skip to main content

sage_runtime/
session.rs

//! Phase 3: Session types infrastructure for protocol verification.
//!
//! This module provides runtime support for session types, so that agents'
//! message exchanges can be checked against their declared protocols while
//! they run. It includes:
//!
//! - `SessionId`: Unique identifier for protocol sessions
//! - `SenderHandle`: Handle for replying to messages within a session
//! - `SessionState` / `SessionRegistry`: Per-agent registry of active sessions
//! - `ProtocolStateMachine`: Trait for protocol state machines
//! - `ProtocolViolation`: Description of how a session broke its protocol
//! - `SharedSessionRegistry`: Lock-guarded registry shareable across tasks
10
11use crate::error::{SageError, SageResult};
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use tokio::sync::{mpsc, RwLock};
16
/// Unique identifier for a protocol session.
///
/// A thin, `Copy`able wrapper around a raw `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(u64);

impl SessionId {
    /// Wrap a raw `u64` as a session ID.
    ///
    /// This is a `const fn`, so fixed IDs can be declared in constant contexts
    /// (e.g. `const ID: SessionId = SessionId::new(7);`).
    #[must_use]
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    /// Return the raw `u64` behind this session ID.
    #[must_use]
    pub const fn value(&self) -> u64 {
        self.0
    }
}

impl std::fmt::Display for SessionId {
    /// Render as `session-<n>`, e.g. `session-42`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "session-{}", self.0)
    }
}
40
41/// Handle for sending replies within a protocol session.
42///
43/// This is used by `reply()` to send messages back to the sender
44/// within the context of a session.
45#[derive(Debug, Clone)]
46pub struct SenderHandle {
47    /// Channel for sending reply messages.
48    reply_tx: mpsc::Sender<crate::agent::Message>,
49    /// The protocol this session belongs to (if any).
50    pub protocol: Option<String>,
51    /// The session ID for this message exchange.
52    pub session_id: Option<SessionId>,
53}
54
55impl SenderHandle {
56    /// Create a new sender handle.
57    #[must_use]
58    pub fn new(
59        reply_tx: mpsc::Sender<crate::agent::Message>,
60        protocol: Option<String>,
61        session_id: Option<SessionId>,
62    ) -> Self {
63        Self {
64            reply_tx,
65            protocol,
66            session_id,
67        }
68    }
69
70    /// Send a reply message.
71    pub async fn send<M: serde::Serialize>(&self, msg: M) -> SageResult<()> {
72        let message = crate::agent::Message::new(msg)?;
73        self.reply_tx
74            .send(message)
75            .await
76            .map_err(|e| SageError::Agent(format!("Failed to send reply: {e}")))
77    }
78}
79
80/// State of an active protocol session.
81#[derive(Debug)]
82pub struct SessionState {
83    /// The protocol this session is following.
84    pub protocol: String,
85    /// The current state of the protocol state machine.
86    pub state: Box<dyn ProtocolStateMachine>,
87    /// The role this agent plays in the protocol.
88    pub role: String,
89    /// Handle to send messages to the session partner.
90    pub partner: SenderHandle,
91}
92
93/// Registry of active protocol sessions for an agent.
94///
95/// Each agent maintains its own session registry to track
96/// ongoing protocol sessions.
97#[derive(Debug, Default)]
98pub struct SessionRegistry {
99    /// Active sessions indexed by session ID.
100    sessions: HashMap<SessionId, SessionState>,
101    /// Counter for generating unique session IDs.
102    next_session_id: AtomicU64,
103}
104
105impl SessionRegistry {
106    /// Create a new empty session registry.
107    #[must_use]
108    pub fn new() -> Self {
109        Self::default()
110    }
111
112    /// Generate a new unique session ID.
113    pub fn next_id(&self) -> SessionId {
114        SessionId(self.next_session_id.fetch_add(1, Ordering::SeqCst))
115    }
116
117    /// Start a new protocol session.
118    pub fn start_session(
119        &mut self,
120        session_id: SessionId,
121        protocol: String,
122        role: String,
123        state: Box<dyn ProtocolStateMachine>,
124        partner: SenderHandle,
125    ) {
126        self.sessions.insert(
127            session_id,
128            SessionState {
129                protocol,
130                state,
131                role,
132                partner,
133            },
134        );
135    }
136
137    /// Get a session by ID.
138    #[must_use]
139    pub fn get(&self, session_id: &SessionId) -> Option<&SessionState> {
140        self.sessions.get(session_id)
141    }
142
143    /// Get a mutable reference to a session by ID.
144    pub fn get_mut(&mut self, session_id: &SessionId) -> Option<&mut SessionState> {
145        self.sessions.get_mut(session_id)
146    }
147
148    /// Remove and return a session (e.g., when protocol completes).
149    pub fn remove(&mut self, session_id: &SessionId) -> Option<SessionState> {
150        self.sessions.remove(session_id)
151    }
152
153    /// Check if a session exists.
154    #[must_use]
155    pub fn has(&self, session_id: &SessionId) -> bool {
156        self.sessions.contains_key(session_id)
157    }
158
159    /// Get the number of active sessions.
160    #[must_use]
161    pub fn len(&self) -> usize {
162        self.sessions.len()
163    }
164
165    /// Check if there are no active sessions.
166    #[must_use]
167    pub fn is_empty(&self) -> bool {
168        self.sessions.is_empty()
169    }
170}
171
172/// Protocol violation error details.
173#[derive(Debug, Clone)]
174pub enum ProtocolViolation {
175    /// Received an unexpected message for the current protocol state.
176    UnexpectedMessage {
177        /// The protocol that was violated.
178        protocol: String,
179        /// The expected message type(s).
180        expected: String,
181        /// The received message type.
182        received: String,
183        /// The current state when the violation occurred.
184        state: String,
185    },
186
187    /// Protocol terminated early (session ended before completion).
188    EarlyTermination {
189        /// The protocol that was violated.
190        protocol: String,
191        /// The state when termination occurred.
192        state: String,
193    },
194
195    /// Message received from wrong sender role.
196    WrongSender {
197        /// The protocol that was violated.
198        protocol: String,
199        /// The expected sender role.
200        expected_role: String,
201        /// The actual sender role.
202        actual_role: String,
203    },
204
205    /// No session found for the given session ID.
206    NoSession {
207        /// The missing session ID.
208        session_id: SessionId,
209    },
210
211    /// Attempt to reply outside of a message handler.
212    ReplyOutsideHandler,
213}
214
215impl std::fmt::Display for ProtocolViolation {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        match self {
218            ProtocolViolation::UnexpectedMessage {
219                protocol,
220                expected,
221                received,
222                state,
223            } => write!(
224                f,
225                "unexpected message in protocol '{}': expected '{}', got '{}' (state: {})",
226                protocol, expected, received, state
227            ),
228            ProtocolViolation::EarlyTermination { protocol, state } => {
229                write!(
230                    f,
231                    "protocol '{}' terminated early in state '{}'",
232                    protocol, state
233                )
234            }
235            ProtocolViolation::WrongSender {
236                protocol,
237                expected_role,
238                actual_role,
239            } => write!(
240                f,
241                "wrong sender in protocol '{}': expected role '{}', got '{}'",
242                protocol, expected_role, actual_role
243            ),
244            ProtocolViolation::NoSession { session_id } => {
245                write!(f, "no session found with id {}", session_id)
246            }
247            ProtocolViolation::ReplyOutsideHandler => {
248                write!(f, "reply() called outside of message handler")
249            }
250        }
251    }
252}
253
254impl From<ProtocolViolation> for SageError {
255    fn from(v: ProtocolViolation) -> Self {
256        SageError::Protocol(v.to_string())
257    }
258}
259
260/// Trait for protocol state machines.
261///
262/// This trait is implemented by generated code for each protocol declaration.
263/// It tracks the current state and validates message transitions.
264pub trait ProtocolStateMachine: Send + Sync + std::fmt::Debug {
265    /// Get the name of the current state.
266    fn state_name(&self) -> &str;
267
268    /// Check if a message type can be sent from the given role in the current state.
269    fn can_send(&self, msg_type: &str, from_role: &str) -> bool;
270
271    /// Check if a message type can be received by the given role in the current state.
272    fn can_receive(&self, msg_type: &str, to_role: &str) -> bool;
273
274    /// Transition the state machine based on a message.
275    ///
276    /// # Errors
277    ///
278    /// Returns a `ProtocolViolation` if the transition is invalid.
279    fn transition(&mut self, msg_type: &str) -> Result<(), ProtocolViolation>;
280
281    /// Check if the protocol has reached a terminal (accepting) state.
282    fn is_terminal(&self) -> bool;
283
284    /// Get the protocol name.
285    fn protocol_name(&self) -> &str;
286
287    /// Clone the state machine into a boxed trait object.
288    fn clone_box(&self) -> Box<dyn ProtocolStateMachine>;
289}
290
291/// Thread-safe shared session registry.
292pub type SharedSessionRegistry = Arc<RwLock<SessionRegistry>>;
293
294/// Create a new shared session registry.
295#[must_use]
296pub fn shared_registry() -> SharedSessionRegistry {
297    Arc::new(RwLock::new(SessionRegistry::new()))
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Ping` -> `Pong` protocol used to drive the registry and the
    /// `ProtocolStateMachine` trait object in these tests.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum PingPongMachine {
        Idle,
        AwaitingPong,
        Done,
    }

    impl PingPongMachine {
        /// The message type this state is waiting for.
        fn expected(self) -> &'static str {
            match self {
                Self::Idle => "Ping",
                Self::AwaitingPong => "Pong",
                Self::Done => "<nothing>",
            }
        }
    }

    impl ProtocolStateMachine for PingPongMachine {
        fn state_name(&self) -> &str {
            match self {
                Self::Idle => "Idle",
                Self::AwaitingPong => "AwaitingPong",
                Self::Done => "Done",
            }
        }

        fn can_send(&self, msg_type: &str, _from_role: &str) -> bool {
            !self.is_terminal() && self.expected() == msg_type
        }

        fn can_receive(&self, msg_type: &str, _to_role: &str) -> bool {
            self.can_send(msg_type, "")
        }

        fn transition(&mut self, msg_type: &str) -> Result<(), ProtocolViolation> {
            let next = match (*self, msg_type) {
                (Self::Idle, "Ping") => Self::AwaitingPong,
                (Self::AwaitingPong, "Pong") => Self::Done,
                _ => {
                    return Err(ProtocolViolation::UnexpectedMessage {
                        protocol: self.protocol_name().to_string(),
                        expected: self.expected().to_string(),
                        received: msg_type.to_string(),
                        state: self.state_name().to_string(),
                    });
                }
            };
            *self = next;
            Ok(())
        }

        fn is_terminal(&self) -> bool {
            *self == Self::Done
        }

        fn protocol_name(&self) -> &str {
            "PingPong"
        }

        fn clone_box(&self) -> Box<dyn ProtocolStateMachine> {
            Box::new(*self)
        }
    }

    /// Build a partner handle for `id`. The receiver is dropped immediately;
    /// these tests never send on the channel.
    fn test_partner(id: SessionId) -> SenderHandle {
        let (reply_tx, _reply_rx) = mpsc::channel(1);
        SenderHandle::new(reply_tx, Some("PingPong".to_string()), Some(id))
    }

    #[test]
    fn session_id_display() {
        let id = SessionId::new(42);
        assert_eq!(format!("{}", id), "session-42");
        assert_eq!(id.value(), 42);
    }

    #[test]
    fn session_registry_basic() {
        let registry = SessionRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        let id1 = registry.next_id();
        let id2 = registry.next_id();
        assert_ne!(id1, id2);
        assert_eq!(id2.value(), id1.value() + 1);
        // Minting IDs must not register sessions.
        assert!(registry.is_empty());
        assert!(!registry.has(&id1));
    }

    #[test]
    fn session_registry_lifecycle() {
        let mut registry = SessionRegistry::new();
        let id = registry.next_id();
        registry.start_session(
            id,
            "PingPong".to_string(),
            "Client".to_string(),
            Box::new(PingPongMachine::Idle),
            test_partner(id),
        );
        assert!(registry.has(&id));
        assert_eq!(registry.len(), 1);

        let session = registry.get(&id).expect("session was just started");
        assert_eq!(session.protocol, "PingPong");
        assert_eq!(session.role, "Client");
        assert_eq!(session.partner.session_id, Some(id));
        assert_eq!(session.state.state_name(), "Idle");

        let session = registry.get_mut(&id).expect("session was just started");
        session.state.transition("Ping").expect("Ping is valid from Idle");
        session
            .state
            .transition("Pong")
            .expect("Pong is valid from AwaitingPong");
        assert!(session.state.is_terminal());

        let removed = registry.remove(&id).expect("session is still registered");
        assert_eq!(removed.state.state_name(), "Done");
        assert!(registry.is_empty());
        assert!(registry.get(&id).is_none());
        assert!(registry.remove(&id).is_none());
    }

    #[test]
    fn start_session_replaces_existing_session() {
        let mut registry = SessionRegistry::new();
        let id = registry.next_id();
        for role in ["Client", "Server"] {
            registry.start_session(
                id,
                "PingPong".to_string(),
                role.to_string(),
                Box::new(PingPongMachine::Idle),
                test_partner(id),
            );
        }
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get(&id).map(|s| s.role.as_str()), Some("Server"));
    }

    #[test]
    fn invalid_transition_reports_violation_and_keeps_state() {
        let mut machine = PingPongMachine::Idle;
        let violation = machine
            .transition("Pong")
            .expect_err("Pong is not valid from Idle");
        assert_eq!(
            violation.to_string(),
            "unexpected message in protocol 'PingPong': expected 'Ping', got 'Pong' (state: Idle)"
        );
        assert_eq!(machine, PingPongMachine::Idle);
    }

    #[test]
    fn clone_box_produces_independent_copy() {
        let mut original: Box<dyn ProtocolStateMachine> = Box::new(PingPongMachine::Idle);
        let snapshot = original.clone_box();
        original.transition("Ping").expect("Ping is valid from Idle");
        assert_eq!(original.state_name(), "AwaitingPong");
        assert_eq!(snapshot.state_name(), "Idle");
    }

    #[test]
    fn protocol_violation_display() {
        let violation = ProtocolViolation::UnexpectedMessage {
            protocol: "PingPong".to_string(),
            expected: "Pong".to_string(),
            received: "Ping".to_string(),
            state: "AwaitingPong".to_string(),
        };
        // Compare the full message. Substring checks such as
        // `contains("Ping")` would pass vacuously, because the protocol name
        // "PingPong" already contains both message names.
        assert_eq!(
            violation.to_string(),
            "unexpected message in protocol 'PingPong': expected 'Pong', got 'Ping' (state: AwaitingPong)"
        );
    }

    #[test]
    fn protocol_violation_display_other_variants() {
        let cases = [
            (
                ProtocolViolation::EarlyTermination {
                    protocol: "PingPong".to_string(),
                    state: "AwaitingPong".to_string(),
                },
                "protocol 'PingPong' terminated early in state 'AwaitingPong'",
            ),
            (
                ProtocolViolation::WrongSender {
                    protocol: "PingPong".to_string(),
                    expected_role: "Client".to_string(),
                    actual_role: "Server".to_string(),
                },
                "wrong sender in protocol 'PingPong': expected role 'Client', got 'Server'",
            ),
            (
                ProtocolViolation::NoSession {
                    session_id: SessionId::new(7),
                },
                "no session found with id session-7",
            ),
            (
                ProtocolViolation::ReplyOutsideHandler,
                "reply() called outside of message handler",
            ),
        ];
        for (violation, expected) in cases {
            assert_eq!(violation.to_string(), expected);
        }
    }

    #[test]
    fn protocol_violation_to_error() {
        let error: SageError = ProtocolViolation::ReplyOutsideHandler.into();
        assert!(matches!(
            &error,
            SageError::Protocol(msg) if msg == "reply() called outside of message handler"
        ));
    }
}