telltale_runtime/effects/handlers/
in_memory.rs1use async_trait::async_trait;
8use cfg_if::cfg_if;
9use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
10use futures::StreamExt;
11use serde::{de::DeserializeOwned, Serialize};
12use std::collections::BTreeMap;
13use std::time::Duration;
14
15use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, RoleId};
16use crate::RoleName;
17
18type MessageChannelPair = (UnboundedSender<Vec<u8>>, UnboundedReceiver<Vec<u8>>);
19type ChoiceChannelPair<L> = (UnboundedSender<L>, UnboundedReceiver<L>);
20type MessageChannelMap =
21 std::sync::Arc<std::sync::Mutex<BTreeMap<(RoleKey, RoleKey), MessageChannelPair>>>;
22type ChoiceChannelMap<L> =
23 std::sync::Arc<std::sync::Mutex<BTreeMap<(RoleKey, RoleKey), ChoiceChannelPair<L>>>>;
24
25#[doc(hidden)]
26#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
27pub struct RoleKey {
28 name: RoleName,
29 index: Option<u32>,
30}
31
32impl RoleKey {
33 fn from_role<R: RoleId>(role: R) -> Self {
34 Self {
35 name: role.role_name(),
36 index: role.role_index(),
37 }
38 }
39}
40
41pub struct InMemoryHandler<R: RoleId> {
43 role: R,
44 channels: MessageChannelMap,
46 choice_channels: ChoiceChannelMap<R::Label>,
48}
49
50impl<R: RoleId> InMemoryHandler<R> {
51 pub fn new(role: R) -> Self {
52 Self {
53 role,
54 channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
55 choice_channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
56 }
57 }
58
59 pub fn with_channels(
61 role: R,
62 channels: MessageChannelMap,
63 choice_channels: ChoiceChannelMap<R::Label>,
64 ) -> Self {
65 Self {
66 role,
67 channels,
68 choice_channels,
69 }
70 }
71
72 fn get_or_create_channel(&self, from: R, to: R) -> UnboundedSender<Vec<u8>> {
74 let mut channels = self
75 .channels
76 .lock()
77 .unwrap_or_else(std::sync::PoisonError::into_inner);
78 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
79 channels.entry(key).or_insert_with(unbounded).0.clone()
80 }
81
82 fn get_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<Vec<u8>>> {
84 let mut channels = self
85 .channels
86 .lock()
87 .unwrap_or_else(std::sync::PoisonError::into_inner);
88 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
89 channels.remove(&key).map(|(_, rx)| rx)
90 }
91
92 fn get_or_create_choice_channel(&self, from: R, to: R) -> UnboundedSender<R::Label> {
94 let mut channels = self
95 .choice_channels
96 .lock()
97 .unwrap_or_else(std::sync::PoisonError::into_inner);
98 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
99 channels.entry(key).or_insert_with(unbounded).0.clone()
100 }
101
102 fn get_choice_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<R::Label>> {
104 let mut channels = self
105 .choice_channels
106 .lock()
107 .unwrap_or_else(std::sync::PoisonError::into_inner);
108 let key = (RoleKey::from_role(from), RoleKey::from_role(to));
109 channels.remove(&key).map(|(_, rx)| rx)
110 }
111}
112
113#[async_trait]
114impl<R: RoleId + 'static> ChoreoHandler for InMemoryHandler<R> {
115 type Role = R;
116 type Endpoint = ();
117
118 async fn send<M: Serialize + Send + Sync>(
119 &mut self,
120 _ep: &mut Self::Endpoint,
121 to: Self::Role,
122 msg: &M,
123 ) -> ChoreoResult<()> {
124 let bytes =
126 bincode::serialize(msg).map_err(|e| ChoreographyError::Serialization(e.to_string()))?;
127
128 let sender = self.get_or_create_channel(self.role, to);
130 sender.unbounded_send(bytes).map_err(|_| {
131 ChoreographyError::Transport(format!(
132 "Failed to send message from {:?} to {:?}",
133 self.role, to
134 ))
135 })?;
136
137 tracing::trace!(?to, "InMemoryHandler: send success");
138 Ok(())
139 }
140
141 async fn recv<M: DeserializeOwned + Send>(
142 &mut self,
143 _ep: &mut Self::Endpoint,
144 from: Self::Role,
145 ) -> ChoreoResult<M> {
146 tracing::trace!(?from, "InMemoryHandler: recv start");
147
148 let mut receiver = self.get_receiver(from, self.role).ok_or_else(|| {
150 ChoreographyError::Transport(format!("No channel from {:?} to {:?}", from, self.role))
151 })?;
152
153 let bytes = receiver.next().await.ok_or_else(|| {
155 ChoreographyError::Transport("Channel closed while waiting for message".into())
156 })?;
157
158 {
160 let mut channels = self
161 .channels
162 .lock()
163 .unwrap_or_else(std::sync::PoisonError::into_inner);
164 let key = (RoleKey::from_role(from), RoleKey::from_role(self.role));
165 if let Some((tx, _)) = channels.remove(&key) {
166 channels.insert(key, (tx, receiver));
167 }
168 }
169
170 let msg = bincode::deserialize(&bytes)
172 .map_err(|e| ChoreographyError::Serialization(e.to_string()))?;
173
174 tracing::trace!(?from, "InMemoryHandler: recv success");
175 Ok(msg)
176 }
177
178 async fn choose(
179 &mut self,
180 _ep: &mut Self::Endpoint,
181 who: Self::Role,
182 label: <Self::Role as RoleId>::Label,
183 ) -> ChoreoResult<()> {
184 let sender = self.get_or_create_choice_channel(self.role, who);
186 sender.unbounded_send(label).map_err(|_| {
187 ChoreographyError::Transport(format!(
188 "Failed to send choice from {:?} to {:?}",
189 self.role, who
190 ))
191 })?;
192
193 tracing::trace!(?who, ?label, "InMemoryHandler: sent choice");
194 Ok(())
195 }
196
197 async fn offer(
198 &mut self,
199 _ep: &mut Self::Endpoint,
200 from: Self::Role,
201 ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
202 tracing::trace!(?from, "InMemoryHandler: waiting for choice");
203
204 let mut receiver = self.get_choice_receiver(from, self.role).ok_or_else(|| {
206 ChoreographyError::Transport(format!(
207 "No choice channel from {:?} to {:?}",
208 from, self.role
209 ))
210 })?;
211
212 let label = receiver.next().await.ok_or_else(|| {
214 ChoreographyError::Transport("Choice channel closed while waiting for label".into())
215 })?;
216
217 {
219 let mut channels = self
220 .choice_channels
221 .lock()
222 .unwrap_or_else(std::sync::PoisonError::into_inner);
223 let key = (RoleKey::from_role(from), RoleKey::from_role(self.role));
224 if let Some((tx, _)) = channels.remove(&key) {
225 channels.insert(key, (tx, receiver));
226 }
227 }
228
229 tracing::trace!(?from, ?label, "InMemoryHandler: received choice");
230 Ok(label)
231 }
232
233 async fn with_timeout<F, T>(
234 &mut self,
235 _ep: &mut Self::Endpoint,
236 at: Self::Role,
237 dur: Duration,
238 body: F,
239 ) -> ChoreoResult<T>
240 where
241 F: std::future::Future<Output = ChoreoResult<T>> + Send,
242 {
243 if at == self.role {
244 cfg_if! {
246 if #[cfg(target_arch = "wasm32")] {
247 use futures::future::{select, Either};
249 use futures::pin_mut;
250 use wasm_timer::Delay;
251
252 let timeout = Delay::new(dur);
253 pin_mut!(body);
254 pin_mut!(timeout);
255
256 match select(body, timeout).await {
257 Either::Left((result, _)) => result,
258 Either::Right(_) => Err(ChoreographyError::Timeout(dur)),
259 }
260 } else {
261 match tokio::time::timeout(dur, body).await {
262 Ok(result) => result,
263 Err(_) => Err(ChoreographyError::Timeout(dur)),
264 }
265 }
266 }
267 } else {
268 body.await
269 }
270 }
271}