Skip to main content

telltale_runtime/effects/handlers/
in_memory.rs

1// In-memory effect handler for testing
2//
3// Uses futures channels to simulate message passing between roles without
4// requiring actual network communication. Ideal for protocol testing.
5// WASM-compatible.
6
7use async_trait::async_trait;
8use cfg_if::cfg_if;
9use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
10use futures::StreamExt;
11use serde::{de::DeserializeOwned, Serialize};
12use std::collections::BTreeMap;
13use std::time::Duration;
14
15use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, RoleId};
16use crate::RoleName;
17
18type MessageChannelPair = (UnboundedSender<Vec<u8>>, UnboundedReceiver<Vec<u8>>);
19type ChoiceChannelPair<L> = (UnboundedSender<L>, UnboundedReceiver<L>);
20type MessageChannelMap =
21    std::sync::Arc<std::sync::Mutex<BTreeMap<(RoleKey, RoleKey), MessageChannelPair>>>;
22type ChoiceChannelMap<L> =
23    std::sync::Arc<std::sync::Mutex<BTreeMap<(RoleKey, RoleKey), ChoiceChannelPair<L>>>>;
24
25#[doc(hidden)]
26#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
27pub struct RoleKey {
28    name: RoleName,
29    index: Option<u32>,
30}
31
32impl RoleKey {
33    fn from_role<R: RoleId>(role: R) -> Self {
34        Self {
35            name: role.role_name(),
36            index: role.role_index(),
37        }
38    }
39}
40
41/// In-memory handler for testing - uses tokio channels
42pub struct InMemoryHandler<R: RoleId> {
43    role: R,
44    // Channel map for sending/receiving messages between roles
45    channels: MessageChannelMap,
46    // Choice channel for broadcasting/receiving choice labels
47    choice_channels: ChoiceChannelMap<R::Label>,
48}
49
50impl<R: RoleId> InMemoryHandler<R> {
51    pub fn new(role: R) -> Self {
52        Self {
53            role,
54            channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
55            choice_channels: std::sync::Arc::new(std::sync::Mutex::new(BTreeMap::new())),
56        }
57    }
58
59    /// Create a new handler with shared channels for coordinated testing
60    pub fn with_channels(
61        role: R,
62        channels: MessageChannelMap,
63        choice_channels: ChoiceChannelMap<R::Label>,
64    ) -> Self {
65        Self {
66            role,
67            channels,
68            choice_channels,
69        }
70    }
71
72    /// Get or create a channel pair for communication between two roles
73    fn get_or_create_channel(&self, from: R, to: R) -> UnboundedSender<Vec<u8>> {
74        let mut channels = self
75            .channels
76            .lock()
77            .unwrap_or_else(std::sync::PoisonError::into_inner);
78        let key = (RoleKey::from_role(from), RoleKey::from_role(to));
79        channels.entry(key).or_insert_with(unbounded).0.clone()
80    }
81
82    /// Get receiver for a channel pair
83    fn get_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<Vec<u8>>> {
84        let mut channels = self
85            .channels
86            .lock()
87            .unwrap_or_else(std::sync::PoisonError::into_inner);
88        let key = (RoleKey::from_role(from), RoleKey::from_role(to));
89        channels.remove(&key).map(|(_, rx)| rx)
90    }
91
92    /// Get or create a choice channel pair for broadcasting choices
93    fn get_or_create_choice_channel(&self, from: R, to: R) -> UnboundedSender<R::Label> {
94        let mut channels = self
95            .choice_channels
96            .lock()
97            .unwrap_or_else(std::sync::PoisonError::into_inner);
98        let key = (RoleKey::from_role(from), RoleKey::from_role(to));
99        channels.entry(key).or_insert_with(unbounded).0.clone()
100    }
101
102    /// Get choice receiver for a channel pair
103    fn get_choice_receiver(&self, from: R, to: R) -> Option<UnboundedReceiver<R::Label>> {
104        let mut channels = self
105            .choice_channels
106            .lock()
107            .unwrap_or_else(std::sync::PoisonError::into_inner);
108        let key = (RoleKey::from_role(from), RoleKey::from_role(to));
109        channels.remove(&key).map(|(_, rx)| rx)
110    }
111}
112
113#[async_trait]
114impl<R: RoleId + 'static> ChoreoHandler for InMemoryHandler<R> {
115    type Role = R;
116    type Endpoint = ();
117
118    async fn send<M: Serialize + Send + Sync>(
119        &mut self,
120        _ep: &mut Self::Endpoint,
121        to: Self::Role,
122        msg: &M,
123    ) -> ChoreoResult<()> {
124        // Serialize message
125        let bytes =
126            bincode::serialize(msg).map_err(|e| ChoreographyError::Serialization(e.to_string()))?;
127
128        // Get or create channel for (self.role, to) and send bytes
129        let sender = self.get_or_create_channel(self.role, to);
130        sender.unbounded_send(bytes).map_err(|_| {
131            ChoreographyError::Transport(format!(
132                "Failed to send message from {:?} to {:?}",
133                self.role, to
134            ))
135        })?;
136
137        tracing::trace!(?to, "InMemoryHandler: send success");
138        Ok(())
139    }
140
141    async fn recv<M: DeserializeOwned + Send>(
142        &mut self,
143        _ep: &mut Self::Endpoint,
144        from: Self::Role,
145    ) -> ChoreoResult<M> {
146        tracing::trace!(?from, "InMemoryHandler: recv start");
147
148        // Get the receiver for messages from 'from' to 'self.role'
149        let mut receiver = self.get_receiver(from, self.role).ok_or_else(|| {
150            ChoreographyError::Transport(format!("No channel from {:?} to {:?}", from, self.role))
151        })?;
152
153        // Wait for message
154        let bytes = receiver.next().await.ok_or_else(|| {
155            ChoreographyError::Transport("Channel closed while waiting for message".into())
156        })?;
157
158        // Put the receiver back
159        {
160            let mut channels = self
161                .channels
162                .lock()
163                .unwrap_or_else(std::sync::PoisonError::into_inner);
164            let key = (RoleKey::from_role(from), RoleKey::from_role(self.role));
165            if let Some((tx, _)) = channels.remove(&key) {
166                channels.insert(key, (tx, receiver));
167            }
168        }
169
170        // Deserialize message
171        let msg = bincode::deserialize(&bytes)
172            .map_err(|e| ChoreographyError::Serialization(e.to_string()))?;
173
174        tracing::trace!(?from, "InMemoryHandler: recv success");
175        Ok(msg)
176    }
177
178    async fn choose(
179        &mut self,
180        _ep: &mut Self::Endpoint,
181        who: Self::Role,
182        label: <Self::Role as RoleId>::Label,
183    ) -> ChoreoResult<()> {
184        // Send choice label from self.role to who via the choice channel
185        let sender = self.get_or_create_choice_channel(self.role, who);
186        sender.unbounded_send(label).map_err(|_| {
187            ChoreographyError::Transport(format!(
188                "Failed to send choice from {:?} to {:?}",
189                self.role, who
190            ))
191        })?;
192
193        tracing::trace!(?who, ?label, "InMemoryHandler: sent choice");
194        Ok(())
195    }
196
197    async fn offer(
198        &mut self,
199        _ep: &mut Self::Endpoint,
200        from: Self::Role,
201    ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
202        tracing::trace!(?from, "InMemoryHandler: waiting for choice");
203
204        // Get the choice receiver for choices from 'from' to 'self.role'
205        let mut receiver = self.get_choice_receiver(from, self.role).ok_or_else(|| {
206            ChoreographyError::Transport(format!(
207                "No choice channel from {:?} to {:?}",
208                from, self.role
209            ))
210        })?;
211
212        // Wait for choice label
213        let label = receiver.next().await.ok_or_else(|| {
214            ChoreographyError::Transport("Choice channel closed while waiting for label".into())
215        })?;
216
217        // Put the receiver back
218        {
219            let mut channels = self
220                .choice_channels
221                .lock()
222                .unwrap_or_else(std::sync::PoisonError::into_inner);
223            let key = (RoleKey::from_role(from), RoleKey::from_role(self.role));
224            if let Some((tx, _)) = channels.remove(&key) {
225                channels.insert(key, (tx, receiver));
226            }
227        }
228
229        tracing::trace!(?from, ?label, "InMemoryHandler: received choice");
230        Ok(label)
231    }
232
233    async fn with_timeout<F, T>(
234        &mut self,
235        _ep: &mut Self::Endpoint,
236        at: Self::Role,
237        dur: Duration,
238        body: F,
239    ) -> ChoreoResult<T>
240    where
241        F: std::future::Future<Output = ChoreoResult<T>> + Send,
242    {
243        if at == self.role {
244            // Platform-specific timeout implementation
245            cfg_if! {
246                if #[cfg(target_arch = "wasm32")] {
247                    // Use wasm_timer for WASM compatibility.
248                    use futures::future::{select, Either};
249                    use futures::pin_mut;
250                    use wasm_timer::Delay;
251
252                    let timeout = Delay::new(dur);
253                    pin_mut!(body);
254                    pin_mut!(timeout);
255
256                    match select(body, timeout).await {
257                        Either::Left((result, _)) => result,
258                        Either::Right(_) => Err(ChoreographyError::Timeout(dur)),
259                    }
260                } else {
261                    match tokio::time::timeout(dur, body).await {
262                        Ok(result) => result,
263                        Err(_) => Err(ChoreographyError::Timeout(dur)),
264                    }
265                }
266            }
267        } else {
268            body.await
269        }
270    }
271}