1use crate::ast::{Branch, Choreography, LocalType, Protocol, Role};
7use crate::topology::{Location, Topology, TopologyConstraint, TopologyMode};
8use proc_macro2::{Ident, TokenStream};
9use quote::{format_ident, quote};
10
11use super::generate_choreography_code;
12
13#[derive(Debug, Clone)]
15pub struct InlineTopology {
16 pub name: String,
18 pub topology: Topology,
20}
21
/// Macro-time description of one message that opens a branch of a `Choice`.
///
/// Each spec is emitted as a `BranchRequirement::new(sender, receiver, label_count)`
/// expression inside the generated `with_topology` function.
#[derive(Clone, Debug)]
struct BranchRequirementSpec {
    /// Name of the role sending the branch-opening message.
    sender: String,
    /// Name of the role receiving the branch-opening message.
    receiver: String,
    /// Number of branches in the enclosing choice.
    label_count: u32,
}
28
29#[must_use]
36pub fn generate_topology_integration(
37 choreography: &Choreography,
38 inline_topologies: &[InlineTopology],
39) -> TokenStream {
40 let _protocol_name = &choreography.name;
41
42 let role_names: Vec<&Ident> = choreography.roles.iter().map(|r| r.name()).collect();
44 let role_name_strs: Vec<String> = role_names.iter().map(|r| r.to_string()).collect();
45
46 let handler_method = generate_handler_method();
48
49 let branch_requirements = collect_branch_requirements(&choreography.protocol);
51
52 let with_topology_method = generate_with_topology_method(&role_name_strs, &branch_requirements);
54
55 let topology_constants = generate_topology_constants(inline_topologies, &role_name_strs);
57
58 quote! {
59 pub mod topology {
61 use super::*;
62 use ::telltale_runtime::topology::{
63 BranchRequirement, Location, Topology, TopologyBuilder, TopologyHandler,
64 TopologyMode,
65 };
66 use ::telltale_runtime::{
67 ChannelCapacity, Datacenter, Namespace, Region, RoleName, TopologyEndpoint,
68 };
69
70 #handler_method
71 #with_topology_method
72 #topology_constants
73 }
74 }
75}
76
77fn collect_branch_requirements(protocol: &Protocol) -> Vec<BranchRequirementSpec> {
78 let mut requirements = Vec::new();
79 collect_branch_requirements_from_protocol(protocol, &mut requirements);
80 requirements
81}
82
83fn collect_branch_requirements_from_protocol(
84 protocol: &Protocol,
85 requirements: &mut Vec<BranchRequirementSpec>,
86) {
87 match protocol {
88 Protocol::Choice { branches, .. } => {
89 let label_count = u32::try_from(branches.len()).unwrap_or(u32::MAX);
90 for branch in branches {
91 collect_branch_requirement_from_branch(branch, label_count, requirements);
92 collect_branch_requirements_from_protocol(&branch.protocol, requirements);
93 }
94 }
95 Protocol::Case { branches, .. } => {
96 for branch in branches {
97 collect_branch_requirements_from_protocol(&branch.protocol, requirements);
98 }
99 }
100 Protocol::Timeout {
101 body,
102 on_timeout,
103 on_cancel,
104 ..
105 } => {
106 collect_branch_requirements_from_protocol(body, requirements);
107 collect_branch_requirements_from_protocol(on_timeout, requirements);
108 if let Some(on_cancel) = on_cancel.as_deref() {
109 collect_branch_requirements_from_protocol(on_cancel, requirements);
110 }
111 }
112 Protocol::Send { continuation, .. } => {
113 collect_branch_requirements_from_protocol(continuation, requirements);
114 }
115 Protocol::Broadcast { continuation, .. } => {
116 collect_branch_requirements_from_protocol(continuation, requirements);
117 }
118 Protocol::Loop { body, .. } => {
119 collect_branch_requirements_from_protocol(body, requirements);
120 }
121 Protocol::Parallel { protocols } => {
122 for p in protocols {
123 collect_branch_requirements_from_protocol(p, requirements);
124 }
125 }
126 Protocol::Rec { body, .. } => {
127 collect_branch_requirements_from_protocol(body, requirements);
128 }
129 Protocol::Begin { continuation, .. }
130 | Protocol::Await { continuation, .. }
131 | Protocol::Resolve { continuation, .. }
132 | Protocol::Invalidate { continuation, .. }
133 | Protocol::Extension { continuation, .. }
134 | Protocol::Let { continuation, .. }
135 | Protocol::Publish { continuation, .. }
136 | Protocol::PublishAuthority { continuation, .. }
137 | Protocol::Materialize { continuation, .. }
138 | Protocol::Handoff { continuation, .. }
139 | Protocol::DependentWork { continuation, .. } => {
140 collect_branch_requirements_from_protocol(continuation, requirements);
141 }
142 Protocol::Var(_) | Protocol::End => {}
143 }
144}
145
146fn collect_branch_requirement_from_branch(
147 branch: &Branch,
148 label_count: u32,
149 requirements: &mut Vec<BranchRequirementSpec>,
150) {
151 match &branch.protocol {
152 Protocol::Send { from, to, .. } => {
153 requirements.push(BranchRequirementSpec {
154 sender: from.name().to_string(),
155 receiver: to.name().to_string(),
156 label_count,
157 });
158 }
159 Protocol::Broadcast { from, to_all, .. } => {
160 for to in to_all {
161 requirements.push(BranchRequirementSpec {
162 sender: from.name().to_string(),
163 receiver: to.name().to_string(),
164 label_count,
165 });
166 }
167 }
168 _ => {}
169 }
170}
171
172fn generate_handler_method() -> TokenStream {
174 quote! {
175 pub fn handler(role: Role) -> TopologyHandler {
190 TopologyHandler::local(role.role_name())
191 }
192 }
193}
194
195fn generate_with_topology_method(
197 role_names: &[String],
198 branch_requirements: &[BranchRequirementSpec],
199) -> TokenStream {
200 let role_name_literals: Vec<TokenStream> = role_names
201 .iter()
202 .map(|role| quote! { RoleName::from_static(#role) })
203 .collect();
204
205 let branch_requirement_literals: Vec<TokenStream> = branch_requirements
206 .iter()
207 .map(|req| {
208 let sender = &req.sender;
209 let receiver = &req.receiver;
210 let label_count = req.label_count;
211 quote! {
212 BranchRequirement::new(
213 RoleName::from_static(#sender),
214 RoleName::from_static(#receiver),
215 #label_count
216 )
217 }
218 })
219 .collect();
220
221 quote! {
222 pub fn with_topology(
246 topology: Topology,
247 role: Role,
248 ) -> Result<TopologyHandler, String> {
249 let roles = [#(#role_name_literals),*];
250 let branch_requirements: &[BranchRequirement] = &[#(#branch_requirement_literals),*];
251
252 let validation = topology.validate_with_branches(&roles, &branch_requirements);
254 if !validation.is_valid() {
255 return Err(format!("Topology validation failed: {:?}", validation));
256 }
257
258 Ok(TopologyHandler::new(topology, role.role_name()))
259 }
260 }
261}
262
263fn generate_topology_constants(
265 inline_topologies: &[InlineTopology],
266 role_names: &[String],
267) -> TokenStream {
268 if inline_topologies.is_empty() {
269 return quote! {};
270 }
271
272 let constants: Vec<TokenStream> = inline_topologies
273 .iter()
274 .map(|topo| {
275 let _const_name = format_ident!("{}", topo.name.to_uppercase());
276 let fn_name = format_ident!("{}", topo.name.to_lowercase());
277 let handler_fn_name = format_ident!("{}_handler", topo.name.to_lowercase());
278
279 let builder_calls = generate_topology_builder(&topo.topology, role_names);
281
282 quote! {
283 pub fn #fn_name() -> Topology {
285 #builder_calls
286 }
287
288 pub fn #handler_fn_name(role: Role) -> Result<TopologyHandler, String> {
290 with_topology(#fn_name(), role)
291 }
292 }
293 })
294 .collect();
295
296 quote! {
297 pub mod topologies {
299 use super::*;
300
301 #(#constants)*
302 }
303 }
304}
305
306fn generate_topology_builder(topology: &Topology, _role_names: &[String]) -> TokenStream {
308 let mut builder_calls = Vec::new();
309
310 if let Some(ref mode) = topology.mode {
312 builder_calls.push(generate_mode_builder_call(mode));
313 }
314
315 for (role, location) in &topology.locations {
317 builder_calls.push(generate_location_builder_call(role, location));
318 }
319
320 for constraint in &topology.constraints {
322 builder_calls.push(generate_constraint_builder_call(constraint));
323 }
324
325 for ((sender, receiver), capacity) in &topology.channel_capacities {
327 builder_calls.push(generate_channel_capacity_builder_call(
328 sender, receiver, capacity,
329 ));
330 }
331
332 if builder_calls.is_empty() {
333 quote! {
334 TopologyBuilder::new().build()
335 }
336 } else {
337 quote! {
338 TopologyBuilder::new()
339 #(#builder_calls)*
340 .build()
341 }
342 }
343}
344
345fn generate_mode_builder_call(mode: &TopologyMode) -> TokenStream {
346 match mode {
347 TopologyMode::Local => quote! { .mode(TopologyMode::Local) },
348 TopologyMode::PerRole => quote! { .mode(TopologyMode::PerRole) },
349 TopologyMode::Kubernetes(ns) => {
350 let ns_literal = ns.as_str();
351 quote! { .mode(TopologyMode::Kubernetes(Namespace::new(#ns_literal).unwrap())) }
352 }
353 TopologyMode::Consul(dc) => {
354 let dc_literal = dc.as_str();
355 quote! { .mode(TopologyMode::Consul(Datacenter::new(#dc_literal).unwrap())) }
356 }
357 }
358}
359
360fn generate_location_builder_call(
361 role: &crate::identifiers::RoleName,
362 location: &Location,
363) -> TokenStream {
364 let role_literal = role.as_str();
365 match location {
366 Location::Local => quote! { .local_role(RoleName::from_static(#role_literal)) },
367 Location::Remote(endpoint) => {
368 let endpoint_literal = endpoint.as_str();
369 quote! {
370 .remote_role(
371 RoleName::from_static(#role_literal),
372 TopologyEndpoint::new(#endpoint_literal).unwrap()
373 )
374 }
375 }
376 Location::Colocated(peer) => {
377 let peer_literal = peer.as_str();
378 quote! {
379 .colocated_role(
380 RoleName::from_static(#role_literal),
381 RoleName::from_static(#peer_literal)
382 )
383 }
384 }
385 }
386}
387
388fn generate_pinned_location_expr(location: &Location) -> TokenStream {
389 match location {
390 Location::Local => quote! { Location::Local },
391 Location::Remote(endpoint) => {
392 let endpoint_literal = endpoint.as_str();
393 quote! { Location::Remote(TopologyEndpoint::new(#endpoint_literal).unwrap()) }
394 }
395 Location::Colocated(peer) => {
396 let peer_literal = peer.as_str();
397 quote! { Location::Colocated(RoleName::from_static(#peer_literal)) }
398 }
399 }
400}
401
402fn generate_constraint_builder_call(constraint: &TopologyConstraint) -> TokenStream {
403 match constraint {
404 TopologyConstraint::Colocated(r1, r2) => {
405 let r1_literal = r1.as_str();
406 let r2_literal = r2.as_str();
407 quote! {
408 .colocated(
409 RoleName::from_static(#r1_literal),
410 RoleName::from_static(#r2_literal)
411 )
412 }
413 }
414 TopologyConstraint::Separated(r1, r2) => {
415 let r1_literal = r1.as_str();
416 let r2_literal = r2.as_str();
417 quote! {
418 .separated(
419 RoleName::from_static(#r1_literal),
420 RoleName::from_static(#r2_literal)
421 )
422 }
423 }
424 TopologyConstraint::Pinned(role, location) => {
425 let role_literal = role.as_str();
426 let location_expr = generate_pinned_location_expr(location);
427 quote! { .pinned(RoleName::from_static(#role_literal), #location_expr) }
428 }
429 TopologyConstraint::Region(role, region) => {
430 let role_literal = role.as_str();
431 let region_literal = region.as_str();
432 quote! {
433 .region(
434 RoleName::from_static(#role_literal),
435 Region::new(#region_literal).unwrap()
436 )
437 }
438 }
439 }
440}
441
442fn generate_channel_capacity_builder_call(
443 sender: &crate::identifiers::RoleName,
444 receiver: &crate::identifiers::RoleName,
445 capacity: &crate::ChannelCapacity,
446) -> TokenStream {
447 let sender_literal = sender.as_str();
448 let receiver_literal = receiver.as_str();
449 let capacity_value = capacity.get();
450 quote! {
451 .channel_capacity(
452 RoleName::from_static(#sender_literal),
453 RoleName::from_static(#receiver_literal),
454 ChannelCapacity::try_new(#capacity_value)
455 .expect("generated channel capacity must be within declared bounds")
456 )
457 }
458}
459
460#[must_use]
462pub fn generate_choreography_code_with_topology(
463 choreography: &Choreography,
464 local_types: &[(Role, LocalType)],
465 inline_topologies: &[InlineTopology],
466) -> TokenStream {
467 let name = choreography.name.to_string();
468 let base_code = generate_choreography_code(&name, &choreography.roles, local_types);
469 let topology_code = generate_topology_integration(choreography, inline_topologies);
470
471 quote! {
472 #base_code
473 #topology_code
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Protocol;
    use crate::identifiers::RoleName;

    /// Two-role choreography (`Alice`, `Bob`) whose protocol is just `End`.
    fn create_test_choreography() -> Choreography {
        Choreography {
            name: format_ident!("TestProtocol"),
            namespace: None,
            roles: vec![
                Role::new(format_ident!("Alice")).unwrap(),
                Role::new(format_ident!("Bob")).unwrap(),
            ],
            protocol: Protocol::End,
            attrs: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_generate_topology_integration_basic() {
        let choreography = create_test_choreography();
        let inline_topologies = vec![];

        let tokens = generate_topology_integration(&choreography, &inline_topologies);
        let code = tokens.to_string();

        // The wrapper module and both entry points are always emitted.
        assert!(code.contains("pub mod topology"));
        assert!(code.contains("RoleName :: from_static"));
        assert!(code.contains("pub fn handler"));
        assert!(code.contains("pub fn with_topology"));
    }

    #[test]
    fn test_generate_topology_integration_with_inline_topologies() {
        let choreography = create_test_choreography();

        let dev_topology = Topology::builder()
            .mode(TopologyMode::Local)
            .local_role(RoleName::from_static("Alice"))
            .local_role(RoleName::from_static("Bob"))
            .build();

        let prod_topology = Topology::builder()
            .remote_role(
                RoleName::from_static("Alice"),
                crate::identifiers::Endpoint::new("alice.prod:8080").unwrap(),
            )
            .remote_role(
                RoleName::from_static("Bob"),
                crate::identifiers::Endpoint::new("bob.prod:8081").unwrap(),
            )
            .build();

        let inline_topologies = vec![
            InlineTopology {
                name: "Dev".to_string(),
                topology: dev_topology,
            },
            InlineTopology {
                name: "Prod".to_string(),
                topology: prod_topology,
            },
        ];

        let tokens = generate_topology_integration(&choreography, &inline_topologies);
        let code = tokens.to_string();

        // Each inline topology yields a lowercased constructor and a handler.
        assert!(code.contains("pub mod topologies"));
        assert!(code.contains("pub fn dev"));
        assert!(code.contains("pub fn prod"));
        assert!(code.contains("dev_handler"));
        assert!(code.contains("prod_handler"));
    }

    #[test]
    fn test_generate_handler_method() {
        let tokens = generate_handler_method();
        let code = tokens.to_string();

        assert!(code.contains("pub fn handler"));
        assert!(code.contains("TopologyHandler :: local"));
        assert!(code.contains("role_name"));
    }

    #[test]
    fn test_generate_with_topology_method() {
        let tokens = generate_with_topology_method(&["Alice".to_string(), "Bob".to_string()], &[]);
        let code = tokens.to_string();

        assert!(code.contains("pub fn with_topology"));
        assert!(code.contains("TopologyHandler :: new"));
        assert!(code.contains("topology . validate_with_branches"));
    }

    #[test]
    fn test_generate_with_topology_method_emits_branch_requirements() {
        let requirements = vec![BranchRequirementSpec {
            sender: "Alice".to_string(),
            receiver: "Bob".to_string(),
            label_count: 2,
        }];

        let tokens = generate_with_topology_method(
            &["Alice".to_string(), "Bob".to_string()],
            &requirements,
        );
        let code = tokens.to_string();

        assert!(code.contains("BranchRequirement :: new"));
        assert!(code.contains("\"Alice\""));
        assert!(code.contains("\"Bob\""));
        // The label count is interpolated as a suffixed `u32` literal.
        assert!(code.contains("2u32"));
    }

    #[test]
    fn test_collect_branch_requirements_end_is_empty() {
        assert!(collect_branch_requirements(&Protocol::End).is_empty());
    }

    #[test]
    fn test_generate_topology_constants_empty() {
        let tokens = generate_topology_constants(&[], &["Alice".to_string()]);
        assert!(tokens.is_empty());
    }

    #[test]
    fn test_generate_topology_builder_local_mode() {
        let topology = Topology::builder().mode(TopologyMode::Local).build();

        let tokens =
            generate_topology_builder(&topology, &["Alice".to_string(), "Bob".to_string()]);
        let code = tokens.to_string();

        assert!(code.contains("TopologyMode :: Local"));
    }

    #[test]
    fn test_generate_topology_builder_with_roles() {
        let topology = Topology::builder()
            .local_role(RoleName::from_static("Alice"))
            .remote_role(
                RoleName::from_static("Bob"),
                crate::identifiers::Endpoint::new("localhost:8080").unwrap(),
            )
            .build();

        let tokens =
            generate_topology_builder(&topology, &["Alice".to_string(), "Bob".to_string()]);
        let code = tokens.to_string();

        assert!(code.contains("local_role"));
        assert!(code.contains("remote_role"));
        assert!(code.contains("localhost:8080"));
    }

    #[test]
    fn test_generate_topology_builder_with_constraints() {
        let topology = Topology::builder()
            .local_role(RoleName::from_static("Alice"))
            .local_role(RoleName::from_static("Bob"))
            .colocated(RoleName::from_static("Alice"), RoleName::from_static("Bob"))
            .separated(
                RoleName::from_static("Alice"),
                RoleName::from_static("Carol"),
            )
            .build();

        let tokens = generate_topology_builder(
            &topology,
            &["Alice".to_string(), "Bob".to_string(), "Carol".to_string()],
        );
        let code = tokens.to_string();

        assert!(code.contains("colocated"));
        assert!(code.contains("separated"));
    }

    #[test]
    fn test_generate_choreography_code_with_topology() {
        let choreography = create_test_choreography();
        let local_types = vec![
            (
                Role::new(format_ident!("Alice")).unwrap(),
                crate::ast::LocalType::End,
            ),
            (
                Role::new(format_ident!("Bob")).unwrap(),
                crate::ast::LocalType::End,
            ),
        ];
        let inline_topologies = vec![];

        let tokens = generate_choreography_code_with_topology(
            &choreography,
            &local_types,
            &inline_topologies,
        );
        let code = tokens.to_string();

        // Base choreography code is present, followed by the topology module.
        assert!(code.contains("Alice") || code.contains("Roles"));
        assert!(code.contains("pub mod topology"));
    }
}