Skip to main content

telltale_runtime/testing/
state_machine.rs

1//! Protocol state machine for step-by-step simulation
2//!
3//! This module provides the core abstraction for simulating protocol execution
4//! in a controlled, step-by-step manner.
5
6use serde::{Deserialize, Serialize};
7
8use super::envelope::ProtocolEnvelope;
9use crate::effects::{ChoreographyError, LabelId};
10use crate::identifiers::RoleName;
11
12/// What a protocol state machine is blocked on.
13///
14/// This enum describes what input is needed for the state machine
15/// to make progress.
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub enum BlockedOn<L: LabelId> {
18    /// Waiting to send a message to a role.
19    Send {
20        /// The destination role.
21        to: RoleName,
22        /// The message type being sent.
23        message_type: String,
24    },
25    /// Waiting to receive a message from a role.
26    Recv {
27        /// The source role.
28        from: RoleName,
29        /// Expected message types (any of these is acceptable).
30        expected_types: Vec<String>,
31    },
32    /// Waiting for an internal choice decision.
33    Choice {
34        /// Available branch labels.
35        branches: Vec<L>,
36    },
37    /// Waiting for an external choice (offer).
38    Offer {
39        /// The role making the choice.
40        from: RoleName,
41        /// Expected branch labels.
42        branches: Vec<L>,
43    },
44    /// Protocol has completed successfully.
45    Complete,
46    /// Protocol has failed with an error.
47    Failed(String),
48}
49
50impl<L: LabelId> BlockedOn<L> {
51    /// Check if the state machine is complete (successfully or with error).
52    #[must_use]
53    pub fn is_terminal(&self) -> bool {
54        matches!(self, BlockedOn::Complete | BlockedOn::Failed(_))
55    }
56
57    /// Check if waiting to send.
58    #[must_use]
59    pub fn is_send(&self) -> bool {
60        matches!(self, BlockedOn::Send { .. })
61    }
62
63    /// Check if waiting to receive.
64    #[must_use]
65    pub fn is_recv(&self) -> bool {
66        matches!(self, BlockedOn::Recv { .. })
67    }
68
69    /// Check if waiting for a choice.
70    #[must_use]
71    pub fn is_choice(&self) -> bool {
72        matches!(self, BlockedOn::Choice { .. } | BlockedOn::Offer { .. })
73    }
74}
75
76/// Input to advance the state machine.
77#[derive(Debug, Clone)]
78pub enum StepInput<L: LabelId> {
79    /// Provide a message to send.
80    SendMessage(ProtocolEnvelope),
81    /// Provide a received message.
82    RecvMessage(ProtocolEnvelope),
83    /// Make an internal choice.
84    MakeChoice(L),
85    /// Receive an external choice (offer).
86    ReceiveOffer(L),
87    /// Signal timeout.
88    Timeout,
89    /// Signal error.
90    Error(String),
91}
92
93impl<L: LabelId> StepInput<L> {
94    /// Create a send message input.
95    pub fn send(envelope: ProtocolEnvelope) -> Self {
96        Self::SendMessage(envelope)
97    }
98
99    /// Create a receive message input.
100    pub fn recv(envelope: ProtocolEnvelope) -> Self {
101        Self::RecvMessage(envelope)
102    }
103
104    /// Create a choice input.
105    pub fn choice(branch: L) -> Self {
106        Self::MakeChoice(branch)
107    }
108
109    /// Create an offer input.
110    pub fn offer(branch: L) -> Self {
111        Self::ReceiveOffer(branch)
112    }
113}
114
115/// Output from a state machine step.
116#[derive(Debug, Clone)]
117pub enum StepOutput<L: LabelId> {
118    /// A message was sent.
119    Sent(ProtocolEnvelope),
120    /// A message was received and processed.
121    Received {
122        /// The received envelope.
123        envelope: ProtocolEnvelope,
124        /// Any response to send (for request-response patterns).
125        response: Option<ProtocolEnvelope>,
126    },
127    /// A choice was made.
128    ChoiceMade(L),
129    /// An offer was received.
130    OfferReceived(L),
131    /// The protocol completed.
132    Completed,
133    /// No progress was made (input didn't match what was needed).
134    NoProgress,
135}
136
137impl<L: LabelId> StepOutput<L> {
138    /// Check if this output indicates completion.
139    #[must_use]
140    pub fn is_completed(&self) -> bool {
141        matches!(self, StepOutput::Completed)
142    }
143
144    /// Check if progress was made.
145    #[must_use]
146    pub fn made_progress(&self) -> bool {
147        !matches!(self, StepOutput::NoProgress)
148    }
149}
150
151/// A checkpoint of protocol state for save/restore.
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct Checkpoint {
154    /// Protocol name.
155    pub protocol: String,
156    /// Current role.
157    pub role: RoleName,
158    /// State identifier (implementation-specific).
159    pub state_id: String,
160    /// Serialized state data.
161    pub state_data: Vec<u8>,
162    /// Sequence number at checkpoint time.
163    pub sequence: u64,
164    /// Additional metadata (BTreeMap for deterministic iteration order).
165    pub metadata: std::collections::BTreeMap<String, String>,
166}
167
168impl Checkpoint {
169    /// Create a new checkpoint.
170    pub fn new(protocol: impl Into<String>, role: RoleName, state_id: impl Into<String>) -> Self {
171        Self {
172            protocol: protocol.into(),
173            role,
174            state_id: state_id.into(),
175            state_data: Vec::new(),
176            sequence: 0,
177            metadata: std::collections::BTreeMap::new(),
178        }
179    }
180
181    /// Set the state data.
182    #[must_use]
183    pub fn with_data(mut self, data: Vec<u8>) -> Self {
184        self.state_data = data;
185        self
186    }
187
188    /// Set the sequence number.
189    #[must_use]
190    pub fn with_sequence(mut self, seq: u64) -> Self {
191        self.sequence = seq;
192        self
193    }
194
195    /// Add metadata.
196    #[must_use]
197    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
198        self.metadata.insert(key.into(), value.into());
199        self
200    }
201
202    /// Serialize the checkpoint to bytes.
203    pub fn to_bytes(&self) -> Result<Vec<u8>, CheckpointError> {
204        bincode::serialize(self).map_err(|e| CheckpointError::Serialization(e.to_string()))
205    }
206
207    /// Deserialize a checkpoint from bytes.
208    pub fn from_bytes(bytes: &[u8]) -> Result<Self, CheckpointError> {
209        bincode::deserialize(bytes).map_err(|e| CheckpointError::Deserialization(e.to_string()))
210    }
211}
212
213/// Errors that can occur with checkpoints.
214#[derive(Debug, thiserror::Error)]
215pub enum CheckpointError {
216    /// Serialization failed.
217    #[error("Checkpoint serialization error: {0}")]
218    Serialization(String),
219
220    /// Deserialization failed.
221    #[error("Checkpoint deserialization error: {0}")]
222    Deserialization(String),
223
224    /// Checkpoint is incompatible.
225    #[error("Incompatible checkpoint: {0}")]
226    Incompatible(String),
227}
228
229/// Trait for protocol state machines that can be stepped through.
230///
231/// This trait is the core abstraction for simulation. It allows
232/// external simulators to control protocol execution step-by-step.
233pub trait ProtocolStateMachine: Send {
234    type Label: LabelId;
235    /// Get the protocol name.
236    fn protocol_name(&self) -> &str;
237
238    /// Get the current role.
239    fn role(&self) -> &RoleName;
240
241    /// Get what the state machine is currently blocked on.
242    fn blocked_on(&self) -> BlockedOn<Self::Label>;
243
244    /// Attempt to advance the state machine with the given input.
245    ///
246    /// Returns `Ok(StepOutput)` if the step succeeded, or an error
247    /// if the input was invalid for the current state.
248    fn step(
249        &mut self,
250        input: StepInput<Self::Label>,
251    ) -> Result<StepOutput<Self::Label>, ChoreographyError>;
252
253    /// Create a checkpoint of the current state.
254    fn checkpoint(&self) -> Result<Checkpoint, CheckpointError>;
255
256    /// Restore state from a checkpoint.
257    fn restore(&mut self, checkpoint: &Checkpoint) -> Result<(), CheckpointError>;
258
259    /// Get the current sequence number.
260    fn sequence(&self) -> u64;
261
262    /// Check if the protocol has completed.
263    fn is_complete(&self) -> bool {
264        self.blocked_on().is_terminal()
265    }
266}
267
268/// A simple state machine implementation for testing.
269///
270/// This implementation uses a linear sequence of expected operations.
271#[derive(Debug)]
272pub struct LinearStateMachine<L: LabelId> {
273    protocol: String,
274    role: RoleName,
275    states: Vec<BlockedOn<L>>,
276    current_state: usize,
277    sequence: u64,
278}
279
280impl<L: LabelId> LinearStateMachine<L> {
281    /// Create a new linear state machine.
282    pub fn new(protocol: impl Into<String>, role: RoleName, states: Vec<BlockedOn<L>>) -> Self {
283        Self {
284            protocol: protocol.into(),
285            role,
286            states,
287            current_state: 0,
288            sequence: 0,
289        }
290    }
291
292    /// Advance to the next state.
293    fn advance(&mut self) {
294        if self.current_state < self.states.len() {
295            self.current_state += 1;
296            self.sequence += 1;
297        }
298    }
299}
300
301impl<L: LabelId> ProtocolStateMachine for LinearStateMachine<L> {
302    type Label = L;
303
304    fn protocol_name(&self) -> &str {
305        &self.protocol
306    }
307
308    fn role(&self) -> &RoleName {
309        &self.role
310    }
311
312    fn blocked_on(&self) -> BlockedOn<Self::Label> {
313        self.states
314            .get(self.current_state)
315            .cloned()
316            .unwrap_or(BlockedOn::Complete)
317    }
318
319    fn step(
320        &mut self,
321        input: StepInput<Self::Label>,
322    ) -> Result<StepOutput<Self::Label>, ChoreographyError> {
323        let current = self.blocked_on();
324
325        match (&current, &input) {
326            (BlockedOn::Send { .. }, StepInput::SendMessage(env)) => {
327                self.advance();
328                Ok(StepOutput::Sent(env.clone()))
329            }
330            (BlockedOn::Recv { .. }, StepInput::RecvMessage(env)) => {
331                self.advance();
332                Ok(StepOutput::Received {
333                    envelope: env.clone(),
334                    response: None,
335                })
336            }
337            (BlockedOn::Choice { branches }, StepInput::MakeChoice(branch)) => {
338                if branches.contains(branch) {
339                    self.advance();
340                    Ok(StepOutput::ChoiceMade(*branch))
341                } else {
342                    Err(ChoreographyError::InvalidChoice {
343                        expected: branches
344                            .iter()
345                            .map(|label| label.as_str().to_string())
346                            .collect(),
347                        actual: branch.as_str().to_string(),
348                    })
349                }
350            }
351            (BlockedOn::Offer { branches, .. }, StepInput::ReceiveOffer(branch)) => {
352                if branches.contains(branch) {
353                    self.advance();
354                    Ok(StepOutput::OfferReceived(*branch))
355                } else {
356                    Err(ChoreographyError::InvalidChoice {
357                        expected: branches
358                            .iter()
359                            .map(|label| label.as_str().to_string())
360                            .collect(),
361                        actual: branch.as_str().to_string(),
362                    })
363                }
364            }
365            (BlockedOn::Complete, _) => Ok(StepOutput::Completed),
366            (BlockedOn::Failed(msg), _) => Err(ChoreographyError::ExecutionError(msg.clone())),
367            _ => Ok(StepOutput::NoProgress),
368        }
369    }
370
371    fn checkpoint(&self) -> Result<Checkpoint, CheckpointError> {
372        let state_data = bincode::serialize(&self.current_state)
373            .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
374
375        Ok(Checkpoint::new(
376            &self.protocol,
377            self.role.clone(),
378            format!("state_{}", self.current_state),
379        )
380        .with_data(state_data)
381        .with_sequence(self.sequence))
382    }
383
384    fn restore(&mut self, checkpoint: &Checkpoint) -> Result<(), CheckpointError> {
385        if checkpoint.protocol != self.protocol {
386            return Err(CheckpointError::Incompatible(format!(
387                "Protocol mismatch: expected {}, got {}",
388                self.protocol, checkpoint.protocol
389            )));
390        }
391
392        self.current_state = bincode::deserialize(&checkpoint.state_data)
393            .map_err(|e| CheckpointError::Deserialization(e.to_string()))?;
394        self.sequence = checkpoint.sequence;
395
396        Ok(())
397    }
398
399    fn sequence(&self) -> u64 {
400        self.sequence
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Branch labels used by the tests below.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
    enum TestLabel {
        Accept,
        Reject,
        Other,
    }

    impl LabelId for TestLabel {
        fn as_str(&self) -> &'static str {
            match self {
                Self::Accept => "Accept",
                Self::Reject => "Reject",
                Self::Other => "Other",
            }
        }

        fn from_str(label: &str) -> Option<Self> {
            [Self::Accept, Self::Reject, Self::Other]
                .iter()
                .copied()
                .find(|candidate| candidate.as_str() == label)
        }
    }

    /// Builds an empty-payload envelope for the `TestProto` protocol.
    fn envelope(
        sender: &'static str,
        recipient: &'static str,
        message_type: &'static str,
    ) -> super::super::envelope::ProtocolEnvelope {
        super::super::envelope::ProtocolEnvelope::builder()
            .protocol("TestProto")
            .sender(RoleName::from_static(sender))
            .recipient(RoleName::from_static(recipient))
            .message_type(message_type)
            .payload(vec![])
            .build()
            .unwrap()
    }

    #[test]
    fn test_blocked_on_terminal() {
        let sending = BlockedOn::<TestLabel>::Send {
            to: RoleName::from_static("Server"),
            message_type: "Request".to_string(),
        };
        assert!(!sending.is_terminal());
        assert!(BlockedOn::<TestLabel>::Complete.is_terminal());
        assert!(BlockedOn::<TestLabel>::Failed("error".to_string()).is_terminal());
    }

    #[test]
    fn test_linear_state_machine() {
        let script = vec![
            BlockedOn::Send {
                to: RoleName::from_static("Server"),
                message_type: "Request".to_string(),
            },
            BlockedOn::Recv {
                from: RoleName::from_static("Server"),
                expected_types: vec!["Response".to_string()],
            },
        ];
        let mut machine = LinearStateMachine::<TestLabel>::new(
            "TestProto",
            RoleName::from_static("Client"),
            script,
        );
        assert!(machine.blocked_on().is_send());

        let sent = machine
            .step(StepInput::send(envelope("Client", "Server", "Request")))
            .expect("send step should succeed");
        assert!(matches!(sent, StepOutput::Sent(_)));
        assert!(machine.blocked_on().is_recv());

        machine
            .step(StepInput::recv(envelope("Server", "Client", "Response")))
            .expect("recv step should succeed");
        assert!(machine.blocked_on().is_terminal());
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let script = vec![BlockedOn::Send {
            to: RoleName::from_static("Server"),
            message_type: "Msg".to_string(),
        }];
        let machine =
            LinearStateMachine::<TestLabel>::new("Proto", RoleName::from_static("Client"), script);

        let snapshot = machine.checkpoint().unwrap();
        let encoded = snapshot.to_bytes().unwrap();
        let decoded = Checkpoint::from_bytes(&encoded).unwrap();

        assert_eq!(snapshot.protocol, decoded.protocol);
        assert_eq!(snapshot.sequence, decoded.sequence);
    }

    #[test]
    fn test_choice_validation() {
        let script = vec![BlockedOn::Choice {
            branches: vec![TestLabel::Accept, TestLabel::Reject],
        }];
        let mut machine =
            LinearStateMachine::<TestLabel>::new("Proto", RoleName::from_static("Client"), script);

        // A label outside the scripted branches is rejected...
        assert!(machine.step(StepInput::choice(TestLabel::Other)).is_err());
        // ...while a scripted label is accepted.
        assert!(machine.step(StepInput::choice(TestLabel::Accept)).is_ok());
    }
}