telltale_runtime/effects/handlers/
telltale.rs1use async_trait::async_trait;
10use cfg_if::cfg_if;
11use serde::{de::DeserializeOwned, Serialize};
12use std::{collections::HashMap, fmt::Debug, marker::PhantomData, time::Duration};
13
14use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, LabelId, RoleId};
15use telltale::{Message, Role};
16
17#[path = "telltale_session.rs"]
18mod session;
19pub use session::{SessionMetadata, SessionTypeDynamic, SessionUpdate, TelltaleSession};
20
21struct ChannelRecord {
22 session: TelltaleSession,
23 metadata: SessionMetadata,
24}
25
26pub struct TelltaleEndpoint<R>
28where
29 R: Role + Eq + std::hash::Hash + Clone + Debug,
30{
31 local_role: R,
32 channels: HashMap<R, ChannelRecord>,
33}
34
35impl<R> TelltaleEndpoint<R>
36where
37 R: Role + Eq + std::hash::Hash + Clone + Debug,
38{
39 pub fn new(local_role: R) -> Self {
40 Self {
41 local_role,
42 channels: HashMap::new(),
43 }
44 }
45
46 pub fn register_session(&mut self, peer: R, session: TelltaleSession) {
48 tracing::debug!(peer = ?peer, session = session.type_name(), "Registering dynamic session");
49 self.channels.insert(
50 peer,
51 ChannelRecord {
52 session,
53 metadata: SessionMetadata::default(),
54 },
55 );
56 }
57
58 fn take_record(&mut self, peer: &R) -> Option<ChannelRecord> {
59 self.channels.remove(peer)
60 }
61
62 fn put_record(&mut self, peer: R, record: ChannelRecord) {
63 self.channels.insert(peer, record);
64 }
65
66 pub fn has_channel(&self, peer: &R) -> bool {
67 self.channels.contains_key(peer)
68 }
69
70 pub fn close_channel(&mut self, peer: &R) -> bool {
71 self.channels.remove(peer).is_some()
72 }
73
74 pub fn close_all_channels(&mut self) -> usize {
75 let count = self.channels.len();
76 self.channels.clear();
77 count
78 }
79
80 pub fn is_all_closed(&self) -> bool {
81 self.channels.is_empty()
82 }
83
84 pub fn active_channel_count(&self) -> usize {
85 self.channels.len()
86 }
87
88 pub fn local_role(&self) -> &R {
89 &self.local_role
90 }
91
92 pub fn get_metadata(&self, peer: &R) -> Option<&SessionMetadata> {
93 self.channels.get(peer).map(|record| &record.metadata)
94 }
95
96 pub fn all_metadata(&self) -> Vec<(R, &SessionMetadata)> {
97 self.channels
98 .iter()
99 .map(|(peer, record)| (peer.clone(), &record.metadata))
100 .collect()
101 }
102}
103
104impl<R> Drop for TelltaleEndpoint<R>
105where
106 R: Role + Eq + std::hash::Hash + Clone + Debug,
107{
108 fn drop(&mut self) {
109 let active = self.active_channel_count();
110 if active > 0 {
111 tracing::warn!(active, "Endpoint dropped with active channels; closing");
112 self.close_all_channels();
113 }
114 }
115}
116
/// Choreography handler that drives the Telltale sessions stored in a
/// [`TelltaleEndpoint`].
///
/// The handler keeps no runtime state; the role type `R` and message type `M`
/// are tracked purely at the type level through a zero-sized marker.
pub struct TelltaleHandler<R, M> {
    _phantom: std::marker::PhantomData<(R, M)>,
}
121
122impl<R, M> TelltaleHandler<R, M>
123where
124 R: Role + Eq + std::hash::Hash + Clone + Debug,
125{
126 #[must_use]
127 pub fn new() -> Self {
128 Self {
129 _phantom: PhantomData,
130 }
131 }
132
133 async fn with_channel_operation<T, F, Fut>(
134 ep: &mut TelltaleEndpoint<R>,
135 peer: &R,
136 default_description: &str,
137 f: F,
138 ) -> ChoreoResult<T>
139 where
140 F: FnOnce(TelltaleSession) -> Fut,
141 Fut: std::future::Future<Output = ChoreoResult<(T, TelltaleSession, Option<String>, bool)>>,
142 {
143 let mut record = ep
144 .take_record(peer)
145 .ok_or_else(|| ChoreographyError::NoPeerChannel {
146 peer: format!("{peer:?}"),
147 })?;
148
149 let (result, next_session, description, completed) = f(record.session).await?;
150 record.session = next_session;
151 record.metadata.operation_count += 1;
152 record.metadata.state_description =
153 description.unwrap_or_else(|| default_description.to_string());
154 if completed {
155 record.metadata.is_complete = true;
156 }
157
158 ep.put_record(peer.clone(), record);
159 Ok(result)
160 }
161}
162
163impl<R, M> Default for TelltaleHandler<R, M>
164where
165 R: Role + Eq + std::hash::Hash + Clone + Debug,
166{
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172#[async_trait]
173impl<R, M> ChoreoHandler for TelltaleHandler<R, M>
174where
175 R: Role<Message = M> + Send + Sync + RoleId + Eq + std::hash::Hash + Clone + Debug + 'static,
176 M: Message<Box<dyn std::any::Any + Send>> + Send + Sync + 'static,
177{
178 type Role = R;
179 type Endpoint = TelltaleEndpoint<R>;
180
181 async fn send<Msg: Serialize + Send + Sync>(
182 &mut self,
183 ep: &mut Self::Endpoint,
184 to: Self::Role,
185 msg: &Msg,
186 ) -> ChoreoResult<()> {
187 let serialized =
188 bincode::serialize(msg).map_err(|e| ChoreographyError::MessageSerializationFailed {
189 operation: "Serialization",
190 type_name: std::any::type_name::<Msg>(),
191 reason: e.to_string(),
192 })?;
193
194 Self::with_channel_operation(ep, &to, "Send", |state| async move {
195 let mut session = state;
196 let update = session.send(serialized).await?;
197 Ok(((), session, update.description, update.is_complete))
198 })
199 .await
200 }
201
202 async fn recv<Msg: DeserializeOwned + Send>(
203 &mut self,
204 ep: &mut Self::Endpoint,
205 from: Self::Role,
206 ) -> ChoreoResult<Msg> {
207 Self::with_channel_operation(ep, &from, "Recv", |state| async move {
208 let mut session = state;
209 let update = session.recv().await?;
210 let msg = bincode::deserialize(&update.output).map_err(|e| {
211 ChoreographyError::MessageSerializationFailed {
212 operation: "Deserialization",
213 type_name: std::any::type_name::<Msg>(),
214 reason: e.to_string(),
215 }
216 })?;
217 Ok((msg, session, update.description, update.is_complete))
218 })
219 .await
220 }
221
222 async fn choose(
223 &mut self,
224 ep: &mut Self::Endpoint,
225 who: Self::Role,
226 label: <Self::Role as RoleId>::Label,
227 ) -> ChoreoResult<()> {
228 let label_str = label.as_str().to_string();
229 Self::with_channel_operation(ep, &who, "Choose", |state| async move {
230 let mut session = state;
231 let update = session.choose(&label_str).await?;
232 Ok(((), session, update.description, update.is_complete))
233 })
234 .await
235 }
236
237 async fn offer(
238 &mut self,
239 ep: &mut Self::Endpoint,
240 from: Self::Role,
241 ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
242 Self::with_channel_operation(ep, &from, "Offer", |state| async move {
243 let mut session = state;
244 let update = session.offer().await?;
245 let label =
246 <Self::Role as RoleId>::Label::from_str(&update.output).ok_or_else(|| {
247 ChoreographyError::ProtocolViolation(format!(
248 "Unknown label '{}'",
249 update.output
250 ))
251 })?;
252 Ok((label, session, update.description, update.is_complete))
253 })
254 .await
255 }
256
257 async fn with_timeout<F, T>(
258 &mut self,
259 _ep: &mut Self::Endpoint,
260 _at: Self::Role,
261 dur: Duration,
262 body: F,
263 ) -> ChoreoResult<T>
264 where
265 F: std::future::Future<Output = ChoreoResult<T>> + Send,
266 {
267 cfg_if! {
268 if #[cfg(target_arch = "wasm32")] {
269 use futures::future::{select, Either};
270 use futures::pin_mut;
271 use wasm_timer::Delay;
272
273 let timeout = Delay::new(dur);
274 pin_mut!(body);
275 pin_mut!(timeout);
276
277 match select(body, timeout).await {
278 Either::Left((result, _)) => result,
279 Either::Right(_) => Err(ChoreographyError::Timeout(dur)),
280 }
281 } else {
282 match tokio::time::timeout(dur, body).await {
283 Ok(result) => result,
284 Err(_) => Err(ChoreographyError::Timeout(dur)),
285 }
286 }
287 }
288 }
289}