// telltale_machine/nested.rs
use std::collections::BTreeMap;
use std::sync::{Mutex, MutexGuard};

use crate::coroutine::Value;
use crate::effect::{EffectFailure, EffectHandler, EffectResult};
use crate::engine::{ObsEvent, ProtocolMachine, ProtocolMachineError, StepResult};
use crate::semantic_objects::ProtocolMachineSemanticObjects;
13
14struct SiteRunner {
15 machine: Mutex<ProtocolMachine>,
16 handler: Box<dyn EffectHandler>,
17}
18
19pub struct NestedProtocolMachineHandler {
21 sites: BTreeMap<String, SiteRunner>,
22 max_rounds_per_step: usize,
23}
24
25impl NestedProtocolMachineHandler {
26 #[must_use]
28 pub fn new() -> Self {
29 Self {
30 sites: BTreeMap::new(),
31 max_rounds_per_step: 1,
32 }
33 }
34
35 #[must_use]
37 pub fn with_rounds_per_step(mut self, rounds: usize) -> Self {
38 self.max_rounds_per_step = rounds.max(1);
39 self
40 }
41
42 #[must_use]
44 pub fn rounds_per_step(&self) -> usize {
45 self.max_rounds_per_step
46 }
47
48 pub fn add_site(
50 &mut self,
51 name: impl Into<String>,
52 machine: ProtocolMachine,
53 handler: Box<dyn EffectHandler>,
54 ) {
55 self.sites.insert(
56 name.into(),
57 SiteRunner {
58 machine: Mutex::new(machine),
59 handler,
60 },
61 );
62 }
63
64 #[must_use]
70 pub fn site_trace(&self, name: &str) -> Option<Vec<ObsEvent>> {
71 self.sites.get(name).map(|site| {
72 site.machine
73 .lock()
74 .unwrap_or_else(|poisoned| poisoned.into_inner())
75 .trace()
76 .to_vec()
77 })
78 }
79
80 #[must_use]
86 pub fn site_all_done(&self, name: &str) -> Option<bool> {
87 self.sites.get(name).map(|site| {
88 site.machine
89 .lock()
90 .unwrap_or_else(|poisoned| poisoned.into_inner())
91 .all_done()
92 })
93 }
94
95 #[must_use]
101 pub fn site_semantic_objects(&self, name: &str) -> Option<ProtocolMachineSemanticObjects> {
102 self.sites.get(name).map(|site| {
103 site.machine
104 .lock()
105 .unwrap_or_else(|poisoned| poisoned.into_inner())
106 .semantic_objects()
107 })
108 }
109
110 fn step_site(&self, name: &str) -> Result<(), String> {
111 let site = self
112 .sites
113 .get(name)
114 .ok_or_else(|| format!("unknown site: {name}"))?;
115
116 let mut machine = site
117 .machine
118 .lock()
119 .unwrap_or_else(|poisoned| poisoned.into_inner());
120 let handler = site.handler.as_ref();
121
122 for _ in 0..self.max_rounds_per_step {
123 match machine.step_round(handler, 1) {
124 Ok(StepResult::Continue) => {}
125 Ok(StepResult::AllDone | StepResult::Stuck) => break,
126 Err(ProtocolMachineError::Fault { fault, .. }) => {
127 return Err(format!("inner machine fault: {fault}"));
128 }
129 Err(e) => return Err(e.to_string()),
130 }
131 }
132
133 Ok(())
134 }
135}
136
137impl Default for NestedProtocolMachineHandler {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143impl EffectHandler for NestedProtocolMachineHandler {
144 fn handle_send(
145 &self,
146 role: &str,
147 _partner: &str,
148 _label: &str,
149 _state: &[Value],
150 ) -> EffectResult<Value> {
151 match self.step_site(role) {
152 Ok(()) => EffectResult::success(Value::Unit),
153 Err(message) => EffectResult::failure(EffectFailure::contract_violation(message)),
154 }
155 }
156
157 fn handle_recv(
158 &self,
159 role: &str,
160 _partner: &str,
161 _label: &str,
162 _state: &mut Vec<Value>,
163 _payload: &Value,
164 ) -> EffectResult<()> {
165 match self.step_site(role) {
166 Ok(()) => EffectResult::success(()),
167 Err(message) => EffectResult::failure(EffectFailure::contract_violation(message)),
168 }
169 }
170
171 fn handle_choose(
172 &self,
173 _role: &str,
174 _partner: &str,
175 labels: &[String],
176 _state: &[Value],
177 ) -> EffectResult<String> {
178 match labels.first().cloned() {
179 Some(label) => EffectResult::success(label),
180 None => EffectResult::failure(EffectFailure::invalid_input("no labels available")),
181 }
182 }
183
184 fn step(&self, role: &str, _state: &mut Vec<Value>) -> EffectResult<()> {
185 match self.step_site(role) {
186 Ok(()) => EffectResult::success(()),
187 Err(message) => EffectResult::failure(EffectFailure::contract_violation(message)),
188 }
189 }
190}