1use async_trait::async_trait;
10use cfg_if::cfg_if;
11use serde::{de::DeserializeOwned, Serialize};
12use std::{collections::HashMap, fmt::Debug, marker::PhantomData, time::Duration};
13
14use crate::effects::contract::{
15 DeliveryModel, DocumentedHandlerContract, ExtensionDispatchContract, ExtensionDispatchMode,
16 HandlerContractProfile, HandlerContractTier, ProtocolSemanticContract, RetryPolicy,
17 TimeoutPolicy, TransportPolicyContract,
18};
19use crate::effects::{ChoreoHandler, ChoreoResult, ChoreographyError, LabelId, RoleId};
20use telltale::{Message, Role};
21
22#[path = "telltale_session.rs"]
23mod session;
24pub use session::{SessionMetadata, SessionTypeDynamic, SessionUpdate, TelltaleSession};
25
/// Per-peer channel state stored inside a [`TelltaleEndpoint`].
///
/// Pairs the dynamic session used to talk to one peer with the bookkeeping
/// metadata that `TelltaleHandler::with_channel_operation` refreshes after each
/// successful operation.
struct ChannelRecord {
    /// Dynamic session that performs send/recv/choose/offer with the peer.
    session: TelltaleSession,
    /// Bookkeeping for this channel (operation count, last state description,
    /// completion flag, and whatever else `SessionMetadata` carries).
    metadata: SessionMetadata,
}
30
/// Endpoint for one local role, holding at most one dynamic-session channel per
/// peer role.
///
/// Channels are keyed by the peer role. If the endpoint is dropped while
/// channels are still registered, a warning is logged before they are cleared.
pub struct TelltaleEndpoint<R>
where
    R: Role + Eq + std::hash::Hash + Clone + Debug,
{
    /// The role this endpoint acts as.
    local_role: R,
    /// Open channels, keyed by peer role.
    channels: HashMap<R, ChannelRecord>,
}
39
40impl<R> TelltaleEndpoint<R>
41where
42 R: Role + Eq + std::hash::Hash + Clone + Debug,
43{
44 pub fn new(local_role: R) -> Self {
45 Self {
46 local_role,
47 channels: HashMap::new(),
48 }
49 }
50
51 pub fn register_session(&mut self, peer: R, session: TelltaleSession) {
53 tracing::debug!(peer = ?peer, session = session.type_name(), "Registering dynamic session");
54 self.channels.insert(
55 peer,
56 ChannelRecord {
57 session,
58 metadata: SessionMetadata::default(),
59 },
60 );
61 }
62
63 fn take_record(&mut self, peer: &R) -> Option<ChannelRecord> {
64 self.channels.remove(peer)
65 }
66
67 fn put_record(&mut self, peer: R, record: ChannelRecord) {
68 self.channels.insert(peer, record);
69 }
70
71 pub fn has_channel(&self, peer: &R) -> bool {
72 self.channels.contains_key(peer)
73 }
74
75 pub fn close_channel(&mut self, peer: &R) -> bool {
76 self.channels.remove(peer).is_some()
77 }
78
79 pub fn close_all_channels(&mut self) -> usize {
80 let count = self.channels.len();
81 self.channels.clear();
82 count
83 }
84
85 pub fn is_all_closed(&self) -> bool {
86 self.channels.is_empty()
87 }
88
89 pub fn active_channel_count(&self) -> usize {
90 self.channels.len()
91 }
92
93 pub fn local_role(&self) -> &R {
94 &self.local_role
95 }
96
97 pub fn get_metadata(&self, peer: &R) -> Option<&SessionMetadata> {
98 self.channels.get(peer).map(|record| &record.metadata)
99 }
100
101 pub fn all_metadata(&self) -> Vec<(R, &SessionMetadata)> {
102 self.channels
103 .iter()
104 .map(|(peer, record)| (peer.clone(), &record.metadata))
105 .collect()
106 }
107}
108
109impl<R> Drop for TelltaleEndpoint<R>
110where
111 R: Role + Eq + std::hash::Hash + Clone + Debug,
112{
113 fn drop(&mut self) {
114 let active = self.active_channel_count();
115 if active > 0 {
116 tracing::warn!(active, "Endpoint dropped with active channels; closing");
117 self.close_all_channels();
118 }
119 }
120}
121
/// [`ChoreoHandler`] that drives the per-peer dynamic sessions stored in a
/// [`TelltaleEndpoint`].
///
/// The handler itself holds no runtime state. `R` is the role type and `M` the
/// role's message type; both are only tracked at the type level via `PhantomData`.
pub struct TelltaleHandler<R, M> {
    _phantom: PhantomData<(R, M)>,
}
126
127impl<R, M> TelltaleHandler<R, M>
128where
129 R: Role + Eq + std::hash::Hash + Clone + Debug,
130{
131 #[must_use]
132 pub fn new() -> Self {
133 Self {
134 _phantom: PhantomData,
135 }
136 }
137
138 async fn with_channel_operation<T, F, Fut>(
139 ep: &mut TelltaleEndpoint<R>,
140 peer: &R,
141 default_description: &str,
142 f: F,
143 ) -> ChoreoResult<T>
144 where
145 F: FnOnce(TelltaleSession) -> Fut,
146 Fut: std::future::Future<Output = ChoreoResult<(T, TelltaleSession, Option<String>, bool)>>,
147 {
148 let mut record = ep
149 .take_record(peer)
150 .ok_or_else(|| ChoreographyError::NoPeerChannel {
151 peer: format!("{peer:?}"),
152 })?;
153
154 let (result, next_session, description, completed) = f(record.session).await?;
155 record.session = next_session;
156 record.metadata.operation_count += 1;
157 record.metadata.state_description =
158 description.unwrap_or_else(|| default_description.to_string());
159 if completed {
160 record.metadata.is_complete = true;
161 }
162
163 ep.put_record(peer.clone(), record);
164 Ok(result)
165 }
166}
167
168impl<R, M> Default for TelltaleHandler<R, M>
169where
170 R: Role + Eq + std::hash::Hash + Clone + Debug,
171{
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl<R, M> DocumentedHandlerContract for TelltaleHandler<R, M>
178where
179 R: Role + Eq + std::hash::Hash + Clone + Debug,
180{
181 fn contract_profile() -> HandlerContractProfile {
182 HandlerContractProfile {
183 handler_name: std::any::type_name::<Self>(),
184 tier: HandlerContractTier::FullProtocol,
185 semantics: ProtocolSemanticContract {
186 typed_send_recv_roundtrip: true,
187 exact_choice_label_preservation: true,
188 fail_closed_transport_errors: true,
189 timeouts_scoped_to_enforcing_role: true,
190 deterministic_for_regression: true,
191 can_materialize_values: true,
192 },
193 transport: TransportPolicyContract {
194 delivery_model: DeliveryModel::SessionBoundary,
195 retry_policy: RetryPolicy::None,
196 timeout_policy: TimeoutPolicy::EnforcingRoleOnly,
197 },
198 extension_dispatch: ExtensionDispatchContract {
199 mode: ExtensionDispatchMode::Unsupported,
200 fail_closed_when_unregistered: false,
201 type_exact_before_side_effects: false,
202 },
203 notes: vec![
204 "session metadata is updated only after the underlying dynamic session operation succeeds",
205 "serialization/deserialization failures are surfaced as protocol-visible errors",
206 ],
207 }
208 }
209}
210
211#[async_trait]
212impl<R, M> ChoreoHandler for TelltaleHandler<R, M>
213where
214 R: Role<Message = M> + Send + Sync + RoleId + Eq + std::hash::Hash + Clone + Debug + 'static,
215 M: Message<Box<dyn std::any::Any + Send>> + Send + Sync + 'static,
216{
217 type Role = R;
218 type Endpoint = TelltaleEndpoint<R>;
219
220 async fn send<Msg: Serialize + Send + Sync>(
221 &mut self,
222 ep: &mut Self::Endpoint,
223 to: Self::Role,
224 msg: &Msg,
225 ) -> ChoreoResult<()> {
226 let serialized =
227 bincode::serialize(msg).map_err(|e| ChoreographyError::MessageSerializationFailed {
228 operation: "Serialization",
229 type_name: std::any::type_name::<Msg>(),
230 reason: e.to_string(),
231 })?;
232
233 Self::with_channel_operation(ep, &to, "Send", |state| async move {
234 let mut session = state;
235 let update = session.send(serialized).await?;
236 Ok(((), session, update.description, update.is_complete))
237 })
238 .await
239 }
240
241 async fn recv<Msg: DeserializeOwned + Send>(
242 &mut self,
243 ep: &mut Self::Endpoint,
244 from: Self::Role,
245 ) -> ChoreoResult<Msg> {
246 Self::with_channel_operation(ep, &from, "Recv", |state| async move {
247 let mut session = state;
248 let update = session.recv().await?;
249 let msg = bincode::deserialize(&update.output).map_err(|e| {
250 ChoreographyError::MessageSerializationFailed {
251 operation: "Deserialization",
252 type_name: std::any::type_name::<Msg>(),
253 reason: e.to_string(),
254 }
255 })?;
256 Ok((msg, session, update.description, update.is_complete))
257 })
258 .await
259 }
260
261 async fn choose(
262 &mut self,
263 ep: &mut Self::Endpoint,
264 who: Self::Role,
265 label: <Self::Role as RoleId>::Label,
266 ) -> ChoreoResult<()> {
267 let label_str = label.as_str().to_string();
268 Self::with_channel_operation(ep, &who, "Choose", |state| async move {
269 let mut session = state;
270 let update = session.choose(&label_str).await?;
271 Ok(((), session, update.description, update.is_complete))
272 })
273 .await
274 }
275
276 async fn offer(
277 &mut self,
278 ep: &mut Self::Endpoint,
279 from: Self::Role,
280 ) -> ChoreoResult<<Self::Role as RoleId>::Label> {
281 Self::with_channel_operation(ep, &from, "Offer", |state| async move {
282 let mut session = state;
283 let update = session.offer().await?;
284 let label =
285 <Self::Role as RoleId>::Label::from_str(&update.output).ok_or_else(|| {
286 ChoreographyError::ProtocolViolation(format!(
287 "Unknown label '{}'",
288 update.output
289 ))
290 })?;
291 Ok((label, session, update.description, update.is_complete))
292 })
293 .await
294 }
295
296 async fn with_timeout<F, T>(
297 &mut self,
298 _ep: &mut Self::Endpoint,
299 _at: Self::Role,
300 dur: Duration,
301 body: F,
302 ) -> ChoreoResult<T>
303 where
304 F: std::future::Future<Output = ChoreoResult<T>> + Send,
305 {
306 cfg_if! {
307 if #[cfg(target_arch = "wasm32")] {
308 use futures::future::{select, Either};
309 use futures::pin_mut;
310 use wasm_timer::Delay;
311
312 let timeout = Delay::new(dur);
313 pin_mut!(body);
314 pin_mut!(timeout);
315
316 match select(body, timeout).await {
317 Either::Left((result, _)) => result,
318 Either::Right(_) => Err(ChoreographyError::Timeout(dur)),
319 }
320 } else {
321 match tokio::time::timeout(dur, body).await {
322 Ok(result) => result,
323 Err(_) => Err(ChoreographyError::Timeout(dur)),
324 }
325 }
326 }
327 }
328}