telltale_runtime/effects/handlers/
in_memory.rs1use async_trait::async_trait;
8use cfg_if::cfg_if;
9use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
10use futures::StreamExt;
11use serde::{de::DeserializeOwned, Serialize};
12use std::collections::BTreeMap;
13use std::time::Duration;
14
15use crate::effects::contract::{
16 DeliveryModel, DocumentedHandlerContract, ExtensionDispatchContract, ExtensionDispatchMode,
17 HandlerContractProfile, HandlerContractTier, ProtocolSemanticContract, RetryPolicy,
18 TimeoutPolicy, TransportPolicyContract,
19};
20use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, RoleId};
21use crate::RoleName;
22
23type MessageChannelPair = (UnboundedSender<Vec<u8>>, UnboundedReceiver<Vec<u8>>);
24type ChoiceChannelPair<L> = (UnboundedSender<L>, UnboundedReceiver<L>);
25type MessageChannelMap =
26 std::sync::Arc<std::sync::Mutex<BTreeMap<(RoleKey, RoleKey), MessageChannelPair>>>;
27type ChoiceChannelMap<L> =
28 std::sync::Arc<std::sync::Mutex<BTreeMap<(RoleKey, RoleKey), ChoiceChannelPair<L>>>>;
29
30#[doc(hidden)]
31#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
32pub struct RoleKey {
33 name: RoleName,
34 index: Option<u32>,
35}
36
37impl RoleKey {
38 fn from_role<R: RoleId>(role: R) -> Self {
39 Self {
40 name: role.role_name(),
41 index: role.role_index(),
42 }
43 }
44}
45
46pub struct InMemoryHandler<R: RoleId> {
48 role: R,
49 channels: MessageChannelMap,
51 choice_channels: ChoiceChannelMap<R::Label>,
53}
54
55impl<R: RoleId> InMemoryHandler<R> {
56 pub fn new(role: R) -> Self {
57 Self {
58 role,
59 channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
60 choice_channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
61 }
62 }
63
64 pub fn with_channels(
66 role: R,
67 channels: MessageChannelMap,
68 choice_channels: ChoiceChannelMap<R::Label>,
69 ) -> Self {
70 Self {
71 role,
72 channels,
73 choice_channels,
74 }
75 }
76
77 fn get_or_create_channel(&self, from: R, to: R) -> UnboundedSender<Vec<u8>> {
79 let mut channels = self
80 .channels
81 .lock()
82 .unwrap_or_else(std::sync::PoisonError::into_inner);
83 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
84 channels.entry(key).or_insert_with(unbounded).0.clone()
85 }
86
87 fn get_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<Vec<u8>>> {
89 let mut channels = self
90 .channels
91 .lock()
92 .unwrap_or_else(std::sync::PoisonError::into_inner);
93 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
94 channels.remove(&key).map(|(_, rx)| rx)
95 }
96
97 fn get_or_create_choice_channel(&self, from: R, to: R) -> UnboundedSender<R::Label> {
99 let mut channels = self
100 .choice_channels
101 .lock()
102 .unwrap_or_else(std::sync::PoisonError::into_inner);
103 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
104 channels.entry(key).or_insert_with(unbounded).0.clone()
105 }
106
107 fn get_choice_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<R::Label>> {
109 let mut channels = self
110 .choice_channels
111 .lock()
112 .unwrap_or_else(std::sync::PoisonError::into_inner);
113 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
114 channels.remove(&key).map(|(_, rx)| rx)
115 }
116}
117
118impl<R: RoleId> DocumentedHandlerContract for InMemoryHandler<R> {
119 fn contract_profile() -> HandlerContractProfile {
120 HandlerContractProfile {
121 handler_name: std::any::type_name::<Self>(),
122 tier: HandlerContractTier::FullProtocol,
123 semantics: ProtocolSemanticContract {
124 typed_send_recv_roundtrip: true,
125 exact_choice_label_preservation: true,
126 fail_closed_transport_errors: true,
127 timeouts_scoped_to_enforcing_role: true,
128 deterministic_for_regression: true,
129 can_materialize_values: true,
130 },
131 transport: TransportPolicyContract {
132 delivery_model: DeliveryModel::InMemoryChannels,
133 retry_policy: RetryPolicy::None,
134 timeout_policy: TimeoutPolicy::EnforcingRoleOnly,
135 },
136 extension_dispatch: ExtensionDispatchContract {
137 mode: ExtensionDispatchMode::Unsupported,
138 fail_closed_when_unregistered: false,
139 type_exact_before_side_effects: false,
140 },
141 notes: vec![
142 "intended for deterministic local testing rather than remote transport",
143 "role-pair channels are reinserted after each recv/offer operation",
144 ],
145 }
146 }
147}
148
149#[async_trait]
150impl<R: RoleId + 'static> ChoreoHandler for InMemoryHandler<R> {
151 type Role = R;
152 type Endpoint = ();
153
154 async fn send<M: Serialize + Send + Sync>(
155 &mut self,
156 _ep: &mut Self::Endpoint,
157 to: Self::Role,
158 msg: &M,
159 ) -> ChoreoResult<()> {
160 let bytes =
162 bincode::serialize(msg).map_err(|e| ChoreographyError::Serialization(e.to_string()))?;
163
164 let sender = self.get_or_create_channel(self.role, to);
166 sender.unbounded_send(bytes).map_err(|_| {
167 ChoreographyError::Transport(format!(
168 "Failed to send message from {:?} to {:?}",
169 self.role, to
170 ))
171 })?;
172
173 tracing::trace!(?to, "InMemoryHandler: send success");
174 Ok(())
175 }
176
177 async fn recv<M: DeserializeOwned + Send>(
178 &mut self,
179 _ep: &mut Self::Endpoint,
180 from: Self::Role,
181 ) -> ChoreoResult<M> {
182 tracing::trace!(?from, "InMemoryHandler: recv start");
183
184 let mut receiver = self.get_receiver(from, self.role).ok_or_else(|| {
186 ChoreographyError::Transport(format!("No channel from {:?} to {:?}", from, self.role))
187 })?;
188
189 let bytes = receiver.next().await.ok_or_else(|| {
191 ChoreographyError::Transport("Channel closed while waiting for message".into())
192 })?;
193
194 {
196 let mut channels = self
197 .channels
198 .lock()
199 .unwrap_or_else(std::sync::PoisonError::into_inner);
200 let key = (RoleKey::from_role(from), RoleKey::from_role(self.role));
201 if let Some((tx, _)) = channels.remove(&key) {
202 channels.insert(key, (tx, receiver));
203 }
204 }
205
206 let msg = bincode::deserialize(&bytes)
208 .map_err(|e| ChoreographyError::Serialization(e.to_string()))?;
209
210 tracing::trace!(?from, "InMemoryHandler: recv success");
211 Ok(msg)
212 }
213
214 async fn choose(
215 &mut self,
216 _ep: &mut Self::Endpoint,
217 who: Self::Role,
218 label: <Self::Role as RoleId>::Label,
219 ) -> ChoreoResult<()> {
220 let sender = self.get_or_create_choice_channel(self.role, who);
222 sender.unbounded_send(label).map_err(|_| {
223 ChoreographyError::Transport(format!(
224 "Failed to send choice from {:?} to {:?}",
225 self.role, who
226 ))
227 })?;
228
229 tracing::trace!(?who, ?label, "InMemoryHandler: sent choice");
230 Ok(())
231 }
232
233 async fn offer(
234 &mut self,
235 _ep: &mut Self::Endpoint,
236 from: Self::Role,
237 ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
238 tracing::trace!(?from, "InMemoryHandler: waiting for choice");
239
240 let mut receiver = self.get_choice_receiver(from, self.role).ok_or_else(|| {
242 ChoreographyError::Transport(format!(
243 "No choice channel from {:?} to {:?}",
244 from, self.role
245 ))
246 })?;
247
248 let label = receiver.next().await.ok_or_else(|| {
250 ChoreographyError::Transport("Choice channel closed while waiting for label".into())
251 })?;
252
253 {
255 let mut channels = self
256 .choice_channels
257 .lock()
258 .unwrap_or_else(std::sync::PoisonError::into_inner);
259 let key = (RoleKey::from_role(from), RoleKey::from_role(self.role));
260 if let Some((tx, _)) = channels.remove(&key) {
261 channels.insert(key, (tx, receiver));
262 }
263 }
264
265 tracing::trace!(?from, ?label, "InMemoryHandler: received choice");
266 Ok(label)
267 }
268
269 async fn with_timeout<F, T>(
270 &mut self,
271 _ep: &mut Self::Endpoint,
272 at: Self::Role,
273 dur: Duration,
274 body: F,
275 ) -> ChoreoResult<T>
276 where
277 F: std::future::Future<Output = ChoreoResult<T>> + Send,
278 {
279 if at == self.role {
280 cfg_if! {
282 if #[cfg(target_arch = "wasm32")] {
283 use futures::future::{select, Either};
285 use futures::pin_mut;
286 use wasm_timer::Delay;
287
288 let timeout = Delay::new(dur);
289 pin_mut!(body);
290 pin_mut!(timeout);
291
292 match select(body, timeout).await {
293 Either::Left((result, _)) => result,
294 Either::Right(_) => Err(ChoreographyError::Timeout(dur)),
295 }
296 } else {
297 match tokio::time::timeout(dur, body).await {
298 Ok(result) => result,
299 Err(_) => Err(ChoreographyError::Timeout(dur)),
300 }
301 }
302 }
303 } else {
304 body.await
305 }
306 }
307}