// Source: telltale_runtime/effects/handlers/telltale.rs

1// Telltale session-typed effect handler.
2//
3// This handler uses one canonical session boundary:
4// - SessionTypeDynamic: object-safe async session interface
5// - TelltaleSession: boxed dynamic session wrapper
6// - TelltaleEndpoint: per-peer session registry plus metadata
7// - TelltaleHandler: ChoreoHandler over registered TelltaleSession values
8
9use async_trait::async_trait;
10use cfg_if::cfg_if;
11use serde::{de::DeserializeOwned, Serialize};
12use std::{collections::HashMap, fmt::Debug, marker::PhantomData, time::Duration};
13
14use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, LabelId, RoleId};
15use telltale::{Message, Role};
16
17#[path = "telltale_session.rs"]
18mod session;
19pub use session::{SessionMetadata, SessionTypeDynamic, SessionUpdate, TelltaleSession};
20
21struct ChannelRecord {
22    session: TelltaleSession,
23    metadata: SessionMetadata,
24}
25
26/// Endpoint that manages per-peer channels/sessions plus metadata.
27pub struct TelltaleEndpoint<R>
28where
29    R: Role + Eq + std::hash::Hash + Clone + Debug,
30{
31    local_role: R,
32    channels: HashMap<R, ChannelRecord>,
33}
34
35impl<R> TelltaleEndpoint<R>
36where
37    R: Role + Eq + std::hash::Hash + Clone + Debug,
38{
39    pub fn new(local_role: R) -> Self {
40        Self {
41            local_role,
42            channels: HashMap::new(),
43        }
44    }
45
46    /// Register a dynamic session for a peer.
47    pub fn register_session(&mut self, peer: R, session: TelltaleSession) {
48        tracing::debug!(peer = ?peer, session = session.type_name(), "Registering dynamic session");
49        self.channels.insert(
50            peer,
51            ChannelRecord {
52                session,
53                metadata: SessionMetadata::default(),
54            },
55        );
56    }
57
58    fn take_record(&mut self, peer: &R) -> Option<ChannelRecord> {
59        self.channels.remove(peer)
60    }
61
62    fn put_record(&mut self, peer: R, record: ChannelRecord) {
63        self.channels.insert(peer, record);
64    }
65
66    pub fn has_channel(&self, peer: &R) -> bool {
67        self.channels.contains_key(peer)
68    }
69
70    pub fn close_channel(&mut self, peer: &R) -> bool {
71        self.channels.remove(peer).is_some()
72    }
73
74    pub fn close_all_channels(&mut self) -> usize {
75        let count = self.channels.len();
76        self.channels.clear();
77        count
78    }
79
80    pub fn is_all_closed(&self) -> bool {
81        self.channels.is_empty()
82    }
83
84    pub fn active_channel_count(&self) -> usize {
85        self.channels.len()
86    }
87
88    pub fn local_role(&self) -> &R {
89        &self.local_role
90    }
91
92    pub fn get_metadata(&self, peer: &R) -> Option<&SessionMetadata> {
93        self.channels.get(peer).map(|record| &record.metadata)
94    }
95
96    pub fn all_metadata(&self) -> Vec<(R, &SessionMetadata)> {
97        self.channels
98            .iter()
99            .map(|(peer, record)| (peer.clone(), &record.metadata))
100            .collect()
101    }
102}
103
104impl<R> Drop for TelltaleEndpoint<R>
105where
106    R: Role + Eq + std::hash::Hash + Clone + Debug,
107{
108    fn drop(&mut self) {
109        let active = self.active_channel_count();
110        if active > 0 {
111            tracing::warn!(active, "Endpoint dropped with active channels; closing");
112            self.close_all_channels();
113        }
114    }
115}
116
/// Effect handler backed by Telltale sessions.
///
/// The handler holds no runtime state of its own. All per-peer session state
/// lives in the `TelltaleEndpoint` passed to each operation; `R` (role) and
/// `M` (message) exist only at the type level.
pub struct TelltaleHandler<R, M> {
    /// Type-level marker for the role and message types.
    _phantom: PhantomData<(R, M)>,
}
121
122impl<R, M> TelltaleHandler<R, M>
123where
124    R: Role + Eq + std::hash::Hash + Clone + Debug,
125{
126    #[must_use]
127    pub fn new() -> Self {
128        Self {
129            _phantom: PhantomData,
130        }
131    }
132
133    async fn with_channel_operation<T, F, Fut>(
134        ep: &mut TelltaleEndpoint<R>,
135        peer: &R,
136        default_description: &str,
137        f: F,
138    ) -> ChoreoResult<T>
139    where
140        F: FnOnce(TelltaleSession) -> Fut,
141        Fut: std::future::Future<Output = ChoreoResult<(T, TelltaleSession, Option<String>, bool)>>,
142    {
143        let mut record = ep
144            .take_record(peer)
145            .ok_or_else(|| ChoreographyError::NoPeerChannel {
146                peer: format!("{peer:?}"),
147            })?;
148
149        let (result, next_session, description, completed) = f(record.session).await?;
150        record.session = next_session;
151        record.metadata.operation_count += 1;
152        record.metadata.state_description =
153            description.unwrap_or_else(|| default_description.to_string());
154        if completed {
155            record.metadata.is_complete = true;
156        }
157
158        ep.put_record(peer.clone(), record);
159        Ok(result)
160    }
161}
162
163impl<R, M> Default for TelltaleHandler<R, M>
164where
165    R: Role + Eq + std::hash::Hash + Clone + Debug,
166{
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172#[async_trait]
173impl<R, M> ChoreoHandler for TelltaleHandler<R, M>
174where
175    R: Role<Message = M> + Send + Sync + RoleId + Eq + std::hash::Hash + Clone + Debug + 'static,
176    M: Message<Box<dyn std::any::Any + Send>> + Send + Sync + 'static,
177{
178    type Role = R;
179    type Endpoint = TelltaleEndpoint<R>;
180
181    async fn send<Msg: Serialize + Send + Sync>(
182        &mut self,
183        ep: &mut Self::Endpoint,
184        to: Self::Role,
185        msg: &Msg,
186    ) -> ChoreoResult<()> {
187        let serialized =
188            bincode::serialize(msg).map_err(|e| ChoreographyError::MessageSerializationFailed {
189                operation: "Serialization",
190                type_name: std::any::type_name::<Msg>(),
191                reason: e.to_string(),
192            })?;
193
194        Self::with_channel_operation(ep, &to, "Send", |state| async move {
195            let mut session = state;
196            let update = session.send(serialized).await?;
197            Ok(((), session, update.description, update.is_complete))
198        })
199        .await
200    }
201
202    async fn recv<Msg: DeserializeOwned + Send>(
203        &mut self,
204        ep: &mut Self::Endpoint,
205        from: Self::Role,
206    ) -> ChoreoResult<Msg> {
207        Self::with_channel_operation(ep, &from, "Recv", |state| async move {
208            let mut session = state;
209            let update = session.recv().await?;
210            let msg = bincode::deserialize(&update.output).map_err(|e| {
211                ChoreographyError::MessageSerializationFailed {
212                    operation: "Deserialization",
213                    type_name: std::any::type_name::<Msg>(),
214                    reason: e.to_string(),
215                }
216            })?;
217            Ok((msg, session, update.description, update.is_complete))
218        })
219        .await
220    }
221
222    async fn choose(
223        &mut self,
224        ep: &mut Self::Endpoint,
225        who: Self::Role,
226        label: <Self::Role as RoleId>::Label,
227    ) -> ChoreoResult<()> {
228        let label_str = label.as_str().to_string();
229        Self::with_channel_operation(ep, &who, "Choose", |state| async move {
230            let mut session = state;
231            let update = session.choose(&label_str).await?;
232            Ok(((), session, update.description, update.is_complete))
233        })
234        .await
235    }
236
237    async fn offer(
238        &mut self,
239        ep: &mut Self::Endpoint,
240        from: Self::Role,
241    ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
242        Self::with_channel_operation(ep, &from, "Offer", |state| async move {
243            let mut session = state;
244            let update = session.offer().await?;
245            let label =
246                <Self::Role as RoleId>::Label::from_str(&update.output).ok_or_else(|| {
247                    ChoreographyError::ProtocolViolation(format!(
248                        "Unknown label '{}'",
249                        update.output
250                    ))
251                })?;
252            Ok((label, session, update.description, update.is_complete))
253        })
254        .await
255    }
256
257    async fn with_timeout<F, T>(
258        &mut self,
259        _ep: &mut Self::Endpoint,
260        _at: Self::Role,
261        dur: Duration,
262        body: F,
263    ) -> ChoreoResult<T>
264    where
265        F: std::future::Future<Output = ChoreoResult<T>> + Send,
266    {
267        cfg_if! {
268            if #[cfg(target_arch = "wasm32")] {
269                use futures::future::{select, Either};
270                use futures::pin_mut;
271                use wasm_timer::Delay;
272
273                let timeout = Delay::new(dur);
274                pin_mut!(body);
275                pin_mut!(timeout);
276
277                match select(body, timeout).await {
278                    Either::Left((result, _)) => result,
279                    Either::Right(_) => Err(ChoreographyError::Timeout(dur)),
280                }
281            } else {
282                match tokio::time::timeout(dur, body).await {
283                    Ok(result) => result,
284                    Err(_) => Err(ChoreographyError::Timeout(dur)),
285                }
286            }
287        }
288    }
289}