1mod annotation;
11mod dynamic;
12mod topology;
13
14pub(crate) use annotation::{
16 generate_annotation_docs, generate_annotation_metadata, generate_runtime_annotation_access,
17};
18
19pub use dynamic::{generate_choreography_code_with_dynamic_roles, generate_dynamic_role_support};
20pub use topology::{
21 generate_choreography_code_with_topology, generate_topology_integration, InlineTopology,
22};
23
24use crate::ast::{Choreography, LocalType, MessageType, Role};
25use crate::extensions::ProtocolExtension;
26use proc_macro2::{Ident, TokenStream};
27use quote::{format_ident, quote};
28use std::collections::HashMap;
29
30#[must_use]
32pub fn generate_session_type(
33 role: &Role,
34 local_type: &LocalType,
35 protocol_name: &str,
36) -> TokenStream {
37 let type_name = format_ident!("{}_{}", role.name(), protocol_name);
38 let inner_type = generate_type_expr(local_type);
39
40 quote! {
41 #[session]
42 type #type_name = #inner_type;
43 }
44}
45
46fn generate_type_expr(local_type: &LocalType) -> TokenStream {
49 match local_type {
50 LocalType::Send {
51 to,
52 message,
53 continuation,
54 } => generate_send_type_expr(to, message, continuation),
55 LocalType::Receive {
56 from,
57 message,
58 continuation,
59 } => generate_receive_type_expr(from, message, continuation),
60 LocalType::Select { to, branches } => generate_select_type_expr(to, branches),
61 LocalType::Branch { from, branches } => generate_branch_type_expr(from, branches),
62 LocalType::LocalChoice { branches } => generate_local_choice_type_expr(branches),
63 LocalType::Loop { condition, body } => generate_loop_type_expr(condition, body),
64 LocalType::Rec {
65 label: _label,
66 body,
67 } => generate_rec_type_expr(body),
68
69 LocalType::Var(label) => {
70 quote! { #label }
74 }
75
76 LocalType::End => {
77 quote! { End }
78 }
79
80 LocalType::Timeout {
81 duration: _,
82 body,
83 on_timeout: _,
84 on_cancel: _,
85 } => {
86 generate_type_expr(body)
88 }
89 }
90}
91
92fn generate_send_type_expr(
93 to: &Role,
94 message: &MessageType,
95 continuation: &LocalType,
96) -> TokenStream {
97 let to_name = to.name();
98 let msg_name = &message.name;
99 let cont = generate_type_expr(continuation);
100 quote! { Send<#to_name, #msg_name, #cont> }
101}
102
103fn generate_receive_type_expr(
104 from: &Role,
105 message: &MessageType,
106 continuation: &LocalType,
107) -> TokenStream {
108 let from_name = from.name();
109 let msg_name = &message.name;
110 let cont = generate_type_expr(continuation);
111 quote! { Receive<#from_name, #msg_name, #cont> }
112}
113
114fn generate_select_type_expr(to: &Role, branches: &[(Ident, LocalType)]) -> TokenStream {
115 let to_name = to.name();
116 let choice_type = generate_choice_enum(branches, true);
117 quote! { Select<#to_name, #choice_type> }
118}
119
120fn generate_branch_type_expr(from: &Role, branches: &[(Ident, LocalType)]) -> TokenStream {
121 let from_name = from.name();
122 let choice_type = generate_choice_enum(branches, false);
123 quote! { Branch<#from_name, #choice_type> }
124}
125
126fn generate_local_choice_type_expr(branches: &[(Ident, LocalType)]) -> TokenStream {
127 let choice_type = generate_choice_enum(branches, true);
128 quote! { LocalChoice<#choice_type> }
129}
130
131fn generate_loop_type_expr(
132 condition: &Option<crate::ast::Condition>,
133 body: &LocalType,
134) -> TokenStream {
135 let body_expr = generate_type_expr(body);
136 match condition {
137 Some(crate::ast::Condition::Count(_)) => quote! { Loop<#body_expr> },
138 Some(crate::ast::Condition::RoleDecides(_)) => quote! { Loop<#body_expr> },
139 Some(crate::ast::Condition::Custom(_)) => quote! { Loop<#body_expr> },
140 Some(crate::ast::Condition::Fuel(_)) => quote! { Loop<#body_expr> },
141 Some(crate::ast::Condition::YieldAfter(_)) => quote! { Loop<#body_expr> },
142 Some(crate::ast::Condition::YieldWhen(_)) => quote! { Loop<#body_expr> },
143 None => quote! { Loop<#body_expr> },
144 }
145}
146
147fn generate_rec_type_expr(body: &LocalType) -> TokenStream {
148 let body_expr = generate_type_expr(body);
149 quote! { #body_expr }
150}
151
152fn generate_choice_enum(branches: &[(Ident, LocalType)], _is_select: bool) -> TokenStream {
154 let enum_name = format_ident!(
155 "Choice{}",
156 branches
157 .iter()
158 .map(|(l, _)| l.to_string())
159 .collect::<String>()
160 );
161
162 let variants: Vec<TokenStream> = branches
163 .iter()
164 .map(|(label, local_type)| {
165 let continuation = generate_type_expr(local_type);
166 quote! {
167 #label(#label, #continuation)
168 }
169 })
170 .collect();
171
172 quote! {
173 {
174 #[session]
175 enum #enum_name {
176 #(#variants),*
177 }
178 #enum_name
179 }
180 }
181}
182
183#[must_use]
185pub fn generate_choreography_code(
186 name: &str,
187 roles: &[Role],
188 local_types: &[(Role, LocalType)],
189) -> TokenStream {
190 let role_struct_defs = generate_role_structs(roles);
191 let session_type_defs = local_types
192 .iter()
193 .map(|(role, local_type)| generate_session_type(role, local_type, name));
194
195 quote! {
196 #role_struct_defs
197 #(#session_type_defs)*
198 }
199}
200
201pub fn generate_choreography_code_with_extensions(
203 choreography: &Choreography,
204 local_types: &[(Role, LocalType)],
205 extensions: &[Box<dyn ProtocolExtension>],
206) -> TokenStream {
207 let base_code = generate_choreography_code(
209 &choreography.name.to_string(),
210 &choreography.roles,
211 local_types,
212 );
213
214 let extension_code = generate_extension_code(extensions, choreography);
216
217 quote! {
219 #base_code
220 #extension_code
221 }
222}
223
224fn generate_extension_code(
226 extensions: &[Box<dyn ProtocolExtension>],
227 choreography: &Choreography,
228) -> TokenStream {
229 if extensions.is_empty() {
230 return quote! {};
231 }
232
233 let mut extension_impls = Vec::new();
234
235 for extension in extensions {
236 let context = crate::extensions::CodegenContext {
237 choreography_name: &choreography.name.to_string(),
238 roles: &choreography.roles,
239 namespace: choreography.namespace.as_deref(),
240 };
241 let ext_code = extension.generate_code(&context);
242 extension_impls.push(ext_code);
243 }
244
245 quote! {
246 #(#extension_impls)*
248
249 pub fn create_extension_registry() -> ::telltale_runtime::extensions::ExtensionRegistry {
251 let mut registry = ::telltale_runtime::extensions::ExtensionRegistry::new();
252
253 registry
255 }
256 }
257}
258
259fn generate_role_structs(roles: &[Role]) -> TokenStream {
261 let _n = roles.len();
262 let role_names: Vec<&Ident> = roles.iter().map(|r| r.name()).collect();
263
264 let roles_struct = quote! {
266 #[derive(Roles)]
267 struct Roles(#(#role_names),*);
268 };
269
270 let role_structs = roles.iter().enumerate().map(|(i, role)| {
272 let role_name = role.name();
273 let other_roles: Vec<_> = roles
274 .iter()
275 .enumerate()
276 .filter(|(j, _)| i != *j)
277 .map(|(_, r)| r.name())
278 .collect();
279
280 if other_roles.is_empty() {
281 quote! {
283 #[derive(Role)]
284 #[message(Label)]
285 struct #role_name;
286 }
287 } else {
288 let routes = other_roles.iter().map(|other| {
289 quote! {
290 #[route(#other)] Channel
291 }
292 });
293
294 quote! {
295 #[derive(Role)]
296 #[message(Label)]
297 struct #role_name(#(#routes),*);
298 }
299 }
300 });
301
302 quote! {
303 #roles_struct
304 #(#role_structs)*
305 }
306}
307
308#[must_use]
310pub fn generate_role_implementations(
311 role: &Role,
312 local_type: &LocalType,
313 protocol_name: &str,
314) -> TokenStream {
315 let role_name = role.name();
316 let fn_name = format_ident!("{}_protocol", role_name.to_string().to_lowercase());
317 let session_type = format_ident!("{}_{}", role_name, protocol_name);
318
319 let impl_body = generate_implementation_body(local_type);
320
321 quote! {
322 async fn #fn_name(role: &mut #role_name) -> Result<()> {
323 try_session(role, |s: #session_type<'_, _>| async move {
324 #impl_body
325 Ok(((), s))
326 }).await
327 }
328 }
329}
330
331fn generate_implementation_body(local_type: &LocalType) -> TokenStream {
334 match local_type {
335 LocalType::Send {
336 message,
337 continuation,
338 ..
339 } => generate_send_impl(&message.name, continuation),
340
341 LocalType::Receive {
342 message,
343 continuation,
344 ..
345 } => generate_recv_impl(&message.name, continuation),
346
347 LocalType::Select { branches, .. } => generate_select_impl(branches),
348
349 LocalType::Branch { branches, .. } => {
350 let match_arms = branches.iter().map(generate_branch_match_arm);
351
352 quote! {
353 let s = match s.branch().await? {
354 #(#match_arms)*
355 };
356 }
357 }
358
359 LocalType::End => quote! {},
360
361 _ => quote! { },
362 }
363}
364
365fn generate_send_impl(msg_name: &Ident, continuation: &LocalType) -> TokenStream {
366 let cont_impl = generate_implementation_body(continuation);
367 quote! {
368 let s = s.send(#msg_name()).await?;
369 #cont_impl
370 }
371}
372
373fn generate_recv_impl(msg_name: &Ident, continuation: &LocalType) -> TokenStream {
374 let cont_impl = generate_implementation_body(continuation);
375 quote! {
376 let (#msg_name(value), s) = s.receive().await?;
377 #cont_impl
378 }
379}
380
381fn generate_select_impl(branches: &[(Ident, LocalType)]) -> TokenStream {
382 let first_branch = &branches[0];
383 let label = &first_branch.0;
384 let cont_impl = generate_implementation_body(&first_branch.1);
385 quote! {
386 let s = s.select(#label()).await?;
387 #cont_impl
388 }
389}
390
391fn generate_branch_match_arm(branch: &(Ident, LocalType)) -> TokenStream {
392 let (label, local_type) = branch;
393 let impl_body = generate_implementation_body(local_type);
394 quote! {
395 Choice::#label(value, s) => {
396 #impl_body
397 }
398 }
399}
400
401#[must_use]
403pub fn generate_helpers(_name: &str, messages: &[MessageType]) -> TokenStream {
404 let message_enum = if messages.is_empty() {
405 quote! {}
406 } else {
407 let variants = messages.iter().map(|msg| {
408 let name = &msg.name;
409 quote! { #name(#name) }
410 });
411
412 quote! {
413 #[derive(Message)]
414 enum Label {
415 #(#variants),*
416 }
417 }
418 };
419
420 let message_structs = messages.iter().map(|msg| {
421 let name = &msg.name;
422 if let Some(payload) = &msg.payload {
423 quote! { struct #name #payload; }
424 } else {
425 quote! { struct #name; }
426 }
427 });
428
429 quote! {
430 #message_enum
431 #(#message_structs)*
432
433 type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
434 type Channel = Bidirectional<UnboundedSender<Label>, UnboundedReceiver<Label>>;
435 }
436}
437
438#[must_use]
440pub fn generate_choreography_code_with_namespacing(
441 choreo: &Choreography,
442 local_types: &[(Role, LocalType)],
443) -> TokenStream {
444 let inner_code = generate_choreography_code_with_annotations(
445 &choreo.name.to_string(),
446 &choreo.roles,
447 local_types,
448 choreo,
449 );
450
451 let choreo_docs = generate_annotation_docs(choreo.get_attributes());
453 let choreo_metadata =
454 generate_annotation_metadata(&choreo.name.to_string(), choreo.get_attributes());
455
456 match &choreo.namespace {
457 Some(ns) => {
458 let ns_ident = format_ident!("{}", ns);
459 quote! {
460 #choreo_docs
461 #[allow(dead_code, unused_imports, unused_variables)]
462 pub mod #ns_ident {
463 use super::*;
464
465 #choreo_metadata
466 #inner_code
467 }
468 }
469 }
470 None => {
471 quote! {
472 #choreo_docs
473 #[allow(dead_code, unused_imports, unused_variables)]
474 mod __generated_choreography {
475 use super::*;
476 #choreo_metadata
477 #inner_code
478 }
479 pub use __generated_choreography::*;
480 }
481 }
482 }
483}
484
485#[must_use]
487pub fn generate_choreography_code_with_annotations(
488 name: &str,
489 roles: &[Role],
490 local_types: &[(Role, LocalType)],
491 choreo: &Choreography,
492) -> TokenStream {
493 let role_struct_defs = generate_role_structs(roles);
494 let session_type_defs = local_types
495 .iter()
496 .map(|(role, local_type)| generate_session_type(role, local_type, name));
497
498 let protocol_annotation_access = generate_runtime_annotation_access(name, &choreo.protocol);
500
501 let role_metadata: Vec<TokenStream> = roles
503 .iter()
504 .filter(|role| role.index().is_some() || role.param().is_some())
505 .map(|role| {
506 let mut role_annotations = HashMap::new();
507 if role.index().is_some() {
508 role_annotations.insert("indexed".to_string(), "true".to_string());
509 }
510 if role.param().is_some() {
511 role_annotations.insert("parameterized".to_string(), "true".to_string());
512 }
513 generate_annotation_metadata(&role.name().to_string(), &role_annotations)
514 })
515 .collect();
516
517 quote! {
518 #role_struct_defs
519 #(#session_type_defs)*
520 #protocol_annotation_access
521 #(#role_metadata)*
522
523 pub mod annotations {
525 use super::*;
526 use std::collections::HashMap;
527
528 pub fn get_all_protocol_annotations() -> HashMap<String, HashMap<String, String>> {
530 let mut all_annotations = HashMap::new();
531 all_annotations
533 }
534 }
535 }
536}