telltale_runtime/testing/
state_machine.rs1use serde::{Deserialize, Serialize};
7
8use super::envelope::ProtocolEnvelope;
9use crate::effects::{ChoreographyError, LabelId};
10use crate::identifiers::RoleName;
11
12#[derive(Debug, Clone, PartialEq, Eq)]
17pub enum BlockedOn<L: LabelId> {
18 Send {
20 to: RoleName,
22 message_type: String,
24 },
25 Recv {
27 from: RoleName,
29 expected_types: Vec<String>,
31 },
32 Choice {
34 branches: Vec<L>,
36 },
37 Offer {
39 from: RoleName,
41 branches: Vec<L>,
43 },
44 Complete,
46 Failed(String),
48}
49
50impl<L: LabelId> BlockedOn<L> {
51 #[must_use]
53 pub fn is_terminal(&self) -> bool {
54 matches!(self, BlockedOn::Complete | BlockedOn::Failed(_))
55 }
56
57 #[must_use]
59 pub fn is_send(&self) -> bool {
60 matches!(self, BlockedOn::Send { .. })
61 }
62
63 #[must_use]
65 pub fn is_recv(&self) -> bool {
66 matches!(self, BlockedOn::Recv { .. })
67 }
68
69 #[must_use]
71 pub fn is_choice(&self) -> bool {
72 matches!(self, BlockedOn::Choice { .. } | BlockedOn::Offer { .. })
73 }
74}
75
76#[derive(Debug, Clone)]
78pub enum StepInput<L: LabelId> {
79 SendMessage(ProtocolEnvelope),
81 RecvMessage(ProtocolEnvelope),
83 MakeChoice(L),
85 ReceiveOffer(L),
87 Timeout,
89 Error(String),
91}
92
93impl<L: LabelId> StepInput<L> {
94 pub fn send(envelope: ProtocolEnvelope) -> Self {
96 Self::SendMessage(envelope)
97 }
98
99 pub fn recv(envelope: ProtocolEnvelope) -> Self {
101 Self::RecvMessage(envelope)
102 }
103
104 pub fn choice(branch: L) -> Self {
106 Self::MakeChoice(branch)
107 }
108
109 pub fn offer(branch: L) -> Self {
111 Self::ReceiveOffer(branch)
112 }
113}
114
115#[derive(Debug, Clone)]
117pub enum StepOutput<L: LabelId> {
118 Sent(ProtocolEnvelope),
120 Received {
122 envelope: ProtocolEnvelope,
124 response: Option<ProtocolEnvelope>,
126 },
127 ChoiceMade(L),
129 OfferReceived(L),
131 Completed,
133 NoProgress,
135}
136
137impl<L: LabelId> StepOutput<L> {
138 #[must_use]
140 pub fn is_completed(&self) -> bool {
141 matches!(self, StepOutput::Completed)
142 }
143
144 #[must_use]
146 pub fn made_progress(&self) -> bool {
147 !matches!(self, StepOutput::NoProgress)
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct Checkpoint {
154 pub protocol: String,
156 pub role: RoleName,
158 pub state_id: String,
160 pub state_data: Vec<u8>,
162 pub sequence: u64,
164 pub metadata: std::collections::BTreeMap<String, String>,
166}
167
168impl Checkpoint {
169 pub fn new(protocol: impl Into<String>, role: RoleName, state_id: impl Into<String>) -> Self {
171 Self {
172 protocol: protocol.into(),
173 role,
174 state_id: state_id.into(),
175 state_data: Vec::new(),
176 sequence: 0,
177 metadata: std::collections::BTreeMap::new(),
178 }
179 }
180
181 #[must_use]
183 pub fn with_data(mut self, data: Vec<u8>) -> Self {
184 self.state_data = data;
185 self
186 }
187
188 #[must_use]
190 pub fn with_sequence(mut self, seq: u64) -> Self {
191 self.sequence = seq;
192 self
193 }
194
195 #[must_use]
197 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
198 self.metadata.insert(key.into(), value.into());
199 self
200 }
201
202 pub fn to_bytes(&self) -> Result<Vec<u8>, CheckpointError> {
204 bincode::serialize(self).map_err(|e| CheckpointError::Serialization(e.to_string()))
205 }
206
207 pub fn from_bytes(bytes: &[u8]) -> Result<Self, CheckpointError> {
209 bincode::deserialize(bytes).map_err(|e| CheckpointError::Deserialization(e.to_string()))
210 }
211}
212
213#[derive(Debug, thiserror::Error)]
215pub enum CheckpointError {
216 #[error("Checkpoint serialization error: {0}")]
218 Serialization(String),
219
220 #[error("Checkpoint deserialization error: {0}")]
222 Deserialization(String),
223
224 #[error("Incompatible checkpoint: {0}")]
226 Incompatible(String),
227}
228
229pub trait ProtocolStateMachine: Send {
234 type Label: LabelId;
235 fn protocol_name(&self) -> &str;
237
238 fn role(&self) -> &RoleName;
240
241 fn blocked_on(&self) -> BlockedOn<Self::Label>;
243
244 fn step(
249 &mut self,
250 input: StepInput<Self::Label>,
251 ) -> Result<StepOutput<Self::Label>, ChoreographyError>;
252
253 fn checkpoint(&self) -> Result<Checkpoint, CheckpointError>;
255
256 fn restore(&mut self, checkpoint: &Checkpoint) -> Result<(), CheckpointError>;
258
259 fn sequence(&self) -> u64;
261
262 fn is_complete(&self) -> bool {
264 self.blocked_on().is_terminal()
265 }
266}
267
268#[derive(Debug)]
272pub struct LinearStateMachine<L: LabelId> {
273 protocol: String,
274 role: RoleName,
275 states: Vec<BlockedOn<L>>,
276 current_state: usize,
277 sequence: u64,
278}
279
280impl<L: LabelId> LinearStateMachine<L> {
281 pub fn new(protocol: impl Into<String>, role: RoleName, states: Vec<BlockedOn<L>>) -> Self {
283 Self {
284 protocol: protocol.into(),
285 role,
286 states,
287 current_state: 0,
288 sequence: 0,
289 }
290 }
291
292 fn advance(&mut self) {
294 if self.current_state < self.states.len() {
295 self.current_state += 1;
296 self.sequence += 1;
297 }
298 }
299}
300
301impl<L: LabelId> ProtocolStateMachine for LinearStateMachine<L> {
302 type Label = L;
303
304 fn protocol_name(&self) -> &str {
305 &self.protocol
306 }
307
308 fn role(&self) -> &RoleName {
309 &self.role
310 }
311
312 fn blocked_on(&self) -> BlockedOn<Self::Label> {
313 self.states
314 .get(self.current_state)
315 .cloned()
316 .unwrap_or(BlockedOn::Complete)
317 }
318
319 fn step(
320 &mut self,
321 input: StepInput<Self::Label>,
322 ) -> Result<StepOutput<Self::Label>, ChoreographyError> {
323 let current = self.blocked_on();
324
325 match (¤t, &input) {
326 (BlockedOn::Send { .. }, StepInput::SendMessage(env)) => {
327 self.advance();
328 Ok(StepOutput::Sent(env.clone()))
329 }
330 (BlockedOn::Recv { .. }, StepInput::RecvMessage(env)) => {
331 self.advance();
332 Ok(StepOutput::Received {
333 envelope: env.clone(),
334 response: None,
335 })
336 }
337 (BlockedOn::Choice { branches }, StepInput::MakeChoice(branch)) => {
338 if branches.contains(branch) {
339 self.advance();
340 Ok(StepOutput::ChoiceMade(*branch))
341 } else {
342 Err(ChoreographyError::InvalidChoice {
343 expected: branches
344 .iter()
345 .map(|label| label.as_str().to_string())
346 .collect(),
347 actual: branch.as_str().to_string(),
348 })
349 }
350 }
351 (BlockedOn::Offer { branches, .. }, StepInput::ReceiveOffer(branch)) => {
352 if branches.contains(branch) {
353 self.advance();
354 Ok(StepOutput::OfferReceived(*branch))
355 } else {
356 Err(ChoreographyError::InvalidChoice {
357 expected: branches
358 .iter()
359 .map(|label| label.as_str().to_string())
360 .collect(),
361 actual: branch.as_str().to_string(),
362 })
363 }
364 }
365 (BlockedOn::Complete, _) => Ok(StepOutput::Completed),
366 (BlockedOn::Failed(msg), _) => Err(ChoreographyError::ExecutionError(msg.clone())),
367 _ => Ok(StepOutput::NoProgress),
368 }
369 }
370
371 fn checkpoint(&self) -> Result<Checkpoint, CheckpointError> {
372 let state_data = bincode::serialize(&self.current_state)
373 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
374
375 Ok(Checkpoint::new(
376 &self.protocol,
377 self.role.clone(),
378 format!("state_{}", self.current_state),
379 )
380 .with_data(state_data)
381 .with_sequence(self.sequence))
382 }
383
384 fn restore(&mut self, checkpoint: &Checkpoint) -> Result<(), CheckpointError> {
385 if checkpoint.protocol != self.protocol {
386 return Err(CheckpointError::Incompatible(format!(
387 "Protocol mismatch: expected {}, got {}",
388 self.protocol, checkpoint.protocol
389 )));
390 }
391
392 self.current_state = bincode::deserialize(&checkpoint.state_data)
393 .map_err(|e| CheckpointError::Deserialization(e.to_string()))?;
394 self.sequence = checkpoint.sequence;
395
396 Ok(())
397 }
398
399 fn sequence(&self) -> u64 {
400 self.sequence
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
    enum TestLabel {
        Accept,
        Reject,
        Other,
    }

    impl LabelId for TestLabel {
        fn as_str(&self) -> &'static str {
            match self {
                TestLabel::Accept => "Accept",
                TestLabel::Reject => "Reject",
                TestLabel::Other => "Other",
            }
        }

        fn from_str(label: &str) -> Option<Self> {
            match label {
                "Accept" => Some(TestLabel::Accept),
                "Reject" => Some(TestLabel::Reject),
                "Other" => Some(TestLabel::Other),
                _ => None,
            }
        }
    }

    /// Builds an empty-payload `TestProto` envelope.
    fn envelope(
        sender: &'static str,
        recipient: &'static str,
        message_type: &'static str,
    ) -> ProtocolEnvelope {
        ProtocolEnvelope::builder()
            .protocol("TestProto")
            .sender(RoleName::from_static(sender))
            .recipient(RoleName::from_static(recipient))
            .message_type(message_type)
            .payload(vec![])
            .build()
            .unwrap()
    }

    /// Builds a `Client` machine for `protocol` over `states`.
    fn client_machine(
        protocol: &'static str,
        states: Vec<BlockedOn<TestLabel>>,
    ) -> LinearStateMachine<TestLabel> {
        LinearStateMachine::new(protocol, RoleName::from_static("Client"), states)
    }

    /// A `Send` state addressed to `Server`.
    fn send_to_server(message_type: &str) -> BlockedOn<TestLabel> {
        BlockedOn::Send {
            to: RoleName::from_static("Server"),
            message_type: message_type.to_string(),
        }
    }

    #[test]
    fn test_blocked_on_terminal() {
        assert!(BlockedOn::<TestLabel>::Complete.is_terminal());
        assert!(BlockedOn::<TestLabel>::Failed("error".to_string()).is_terminal());
        assert!(!send_to_server("Request").is_terminal());
    }

    #[test]
    fn test_linear_state_machine() {
        let states = vec![
            send_to_server("Request"),
            BlockedOn::Recv {
                from: RoleName::from_static("Server"),
                expected_types: vec!["Response".to_string()],
            },
        ];
        let mut sm = client_machine("TestProto", states);
        assert!(sm.blocked_on().is_send());

        let result = sm.step(StepInput::send(envelope("Client", "Server", "Request")));
        assert!(matches!(result, Ok(StepOutput::Sent(_))));
        assert!(sm.blocked_on().is_recv());

        let result = sm.step(StepInput::recv(envelope("Server", "Client", "Response")));
        assert!(matches!(
            result,
            Ok(StepOutput::Received { response: None, .. })
        ));
        assert!(sm.blocked_on().is_terminal());
        assert_eq!(sm.sequence(), 2);
    }

    #[test]
    fn test_mismatched_input_makes_no_progress() {
        let mut sm = client_machine("Proto", vec![send_to_server("Msg")]);

        let result = sm.step(StepInput::choice(TestLabel::Accept));
        assert!(matches!(result, Ok(StepOutput::NoProgress)));
        assert!(sm.blocked_on().is_send());
        assert_eq!(sm.sequence(), 0);
    }

    #[test]
    fn test_step_after_completion_reports_completed() {
        let mut sm = client_machine("Proto", Vec::new());
        assert!(sm.is_complete());

        let result = sm.step(StepInput::choice(TestLabel::Accept));
        assert!(matches!(result, Ok(StepOutput::Completed)));
        assert_eq!(sm.sequence(), 0);
    }

    #[test]
    fn test_failed_state_returns_error() {
        let mut sm = client_machine("Proto", vec![BlockedOn::Failed("boom".to_string())]);

        let result = sm.step(StepInput::choice(TestLabel::Accept));
        assert!(matches!(
            result,
            Err(ChoreographyError::ExecutionError(_))
        ));
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let sm = client_machine("Proto", vec![send_to_server("Msg")]);
        let checkpoint = sm.checkpoint().unwrap();

        let bytes = checkpoint.to_bytes().unwrap();
        let restored = Checkpoint::from_bytes(&bytes).unwrap();

        assert_eq!(checkpoint.protocol, restored.protocol);
        assert_eq!(checkpoint.role, restored.role);
        assert_eq!(checkpoint.state_id, restored.state_id);
        assert_eq!(checkpoint.state_data, restored.state_data);
        assert_eq!(checkpoint.sequence, restored.sequence);
    }

    #[test]
    fn test_restore_resumes_from_checkpoint() {
        let states = vec![
            BlockedOn::Choice {
                branches: vec![TestLabel::Accept, TestLabel::Reject],
            },
            send_to_server("Msg"),
        ];
        let mut original = client_machine("Proto", states.clone());
        let result = original.step(StepInput::choice(TestLabel::Accept));
        assert!(matches!(
            result,
            Ok(StepOutput::ChoiceMade(TestLabel::Accept))
        ));

        let bytes = original.checkpoint().unwrap().to_bytes().unwrap();
        let checkpoint = Checkpoint::from_bytes(&bytes).unwrap();

        let mut resumed = client_machine("Proto", states);
        resumed.restore(&checkpoint).unwrap();
        assert!(resumed.blocked_on().is_send());
        assert_eq!(resumed.sequence(), 1);
    }

    #[test]
    fn test_restore_rejects_other_protocol() {
        let checkpoint = client_machine("OtherProto", Vec::new())
            .checkpoint()
            .unwrap();
        let mut sm = client_machine("Proto", Vec::new());

        assert!(matches!(
            sm.restore(&checkpoint),
            Err(CheckpointError::Incompatible(_))
        ));
    }

    #[test]
    fn test_choice_validation() {
        let states = vec![BlockedOn::Choice {
            branches: vec![TestLabel::Accept, TestLabel::Reject],
        }];
        let mut sm = client_machine("Proto", states);

        let result = sm.step(StepInput::choice(TestLabel::Other));
        assert!(matches!(
            result,
            Err(ChoreographyError::InvalidChoice { .. })
        ));
        // A rejected choice must not advance the machine.
        assert!(sm.blocked_on().is_choice());

        let result = sm.step(StepInput::choice(TestLabel::Accept));
        assert!(matches!(
            result,
            Ok(StepOutput::ChoiceMade(TestLabel::Accept))
        ));
    }

    #[test]
    fn test_offer_validation() {
        let states = vec![BlockedOn::Offer {
            from: RoleName::from_static("Server"),
            branches: vec![TestLabel::Accept, TestLabel::Reject],
        }];
        let mut sm = client_machine("Proto", states);

        let result = sm.step(StepInput::offer(TestLabel::Other));
        assert!(matches!(
            result,
            Err(ChoreographyError::InvalidChoice { .. })
        ));
        assert!(sm.blocked_on().is_choice());

        let result = sm.step(StepInput::offer(TestLabel::Reject));
        assert!(matches!(
            result,
            Ok(StepOutput::OfferReceived(TestLabel::Reject))
        ));
        assert!(sm.is_complete());
    }
}