Skip to main content

telltale_runtime/compiler/codegen/
topology.rs

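//! Topology integration code generation.
//!
//! Generates topology-aware protocol handlers that support local testing
//! and distributed deployment configurations. The emitted `topology` module
//! provides `handler(role)` (local mode), `with_topology(topology, role)`
//! (validated custom topology), and, for inline topology definitions, a
//! `topologies` submodule with one constructor and one handler helper per
//! named topology.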
5
6use crate::ast::{Branch, Choreography, LocalType, Protocol, Role};
7use crate::topology::{Location, Topology, TopologyConstraint, TopologyMode};
8use proc_macro2::{Ident, TokenStream};
9use quote::{format_ident, quote};
10
11use super::generate_choreography_code;
12
13/// Parsed inline topology definition for code generation
14#[derive(Debug, Clone)]
15pub struct InlineTopology {
16    /// Name of the topology (e.g., "Dev", "Prod")
17    pub name: String,
18    /// The topology configuration
19    pub topology: Topology,
20}
21
/// Compile-time description of one choice edge that must be able to carry
/// a branch label.
///
/// Each spec is later emitted as a `BranchRequirement::new(sender, receiver,
/// label_count)` literal in the generated `with_topology` method.
#[derive(Debug, Clone, PartialEq, Eq)]
struct BranchRequirementSpec {
    /// Name of the role that sends the branch-selecting message.
    sender: String,
    /// Name of the role that receives the branch-selecting message.
    receiver: String,
    /// Number of alternatives in the enclosing choice (saturated at `u32::MAX`).
    label_count: u32,
}
28
29/// Generate topology-aware protocol handlers
30///
31/// This generates:
32/// - `Protocol::handler(role)` - Creates a TopologyHandler with local mode
33/// - `Protocol::with_topology(topo, role)` - Creates a TopologyHandler with custom topology
34/// - Named topology constants for inline definitions
35#[must_use]
36pub fn generate_topology_integration(
37    choreography: &Choreography,
38    inline_topologies: &[InlineTopology],
39) -> TokenStream {
40    let _protocol_name = &choreography.name;
41
42    // Collect role names for validation
43    let role_names: Vec<&Ident> = choreography.roles.iter().map(|r| r.name()).collect();
44    let role_name_strs: Vec<String> = role_names.iter().map(|r| r.to_string()).collect();
45
46    // Generate handler method
47    let handler_method = generate_handler_method();
48
49    // Collect branch requirements for capacity checking
50    let branch_requirements = collect_branch_requirements(&choreography.protocol);
51
52    // Generate with_topology method
53    let with_topology_method = generate_with_topology_method(&role_name_strs, &branch_requirements);
54
55    // Generate topology constants
56    let topology_constants = generate_topology_constants(inline_topologies, &role_name_strs);
57
58    quote! {
59        /// Topology integration for the #protocol_name_str protocol
60        pub mod topology {
61            use super::*;
62            use ::telltale_runtime::topology::{
63                BranchRequirement, Location, Topology, TopologyBuilder, TopologyHandler,
64                TopologyMode,
65            };
66            use ::telltale_runtime::{
67                ChannelCapacity, Region, RoleFamilyConstraint, RoleName, TopologyEndpoint,
68            };
69
70            #handler_method
71            #with_topology_method
72            #topology_constants
73        }
74    }
75}
76
77fn collect_branch_requirements(protocol: &Protocol) -> Vec<BranchRequirementSpec> {
78    let mut requirements = Vec::new();
79    collect_branch_requirements_from_protocol(protocol, &mut requirements);
80    requirements
81}
82
83fn collect_branch_requirements_from_protocol(
84    protocol: &Protocol,
85    requirements: &mut Vec<BranchRequirementSpec>,
86) {
87    match protocol {
88        Protocol::Choice { branches, .. } => {
89            let label_count = u32::try_from(branches.len()).unwrap_or(u32::MAX);
90            for branch in branches {
91                collect_branch_requirement_from_branch(branch, label_count, requirements);
92                collect_branch_requirements_from_protocol(&branch.protocol, requirements);
93            }
94        }
95        Protocol::Case { branches, .. } => {
96            for branch in branches {
97                collect_branch_requirements_from_protocol(&branch.protocol, requirements);
98            }
99        }
100        Protocol::Timeout {
101            body,
102            on_timeout,
103            on_cancel,
104            ..
105        } => {
106            collect_branch_requirements_from_protocol(body, requirements);
107            collect_branch_requirements_from_protocol(on_timeout, requirements);
108            if let Some(on_cancel) = on_cancel.as_deref() {
109                collect_branch_requirements_from_protocol(on_cancel, requirements);
110            }
111        }
112        Protocol::Send { continuation, .. } => {
113            collect_branch_requirements_from_protocol(continuation, requirements);
114        }
115        Protocol::Broadcast { continuation, .. } => {
116            collect_branch_requirements_from_protocol(continuation, requirements);
117        }
118        Protocol::Loop { body, .. } => {
119            collect_branch_requirements_from_protocol(body, requirements);
120        }
121        Protocol::Parallel { protocols } => {
122            for p in protocols {
123                collect_branch_requirements_from_protocol(p, requirements);
124            }
125        }
126        Protocol::Rec { body, .. } => {
127            collect_branch_requirements_from_protocol(body, requirements);
128        }
129        Protocol::Begin { continuation, .. }
130        | Protocol::Await { continuation, .. }
131        | Protocol::Resolve { continuation, .. }
132        | Protocol::Invalidate { continuation, .. }
133        | Protocol::Extension { continuation, .. }
134        | Protocol::Let { continuation, .. }
135        | Protocol::Publish { continuation, .. }
136        | Protocol::PublishAuthority { continuation, .. }
137        | Protocol::Materialize { continuation, .. }
138        | Protocol::Handoff { continuation, .. }
139        | Protocol::DependentWork { continuation, .. } => {
140            collect_branch_requirements_from_protocol(continuation, requirements);
141        }
142        Protocol::Var(_) | Protocol::End => {}
143    }
144}
145
146fn collect_branch_requirement_from_branch(
147    branch: &Branch,
148    label_count: u32,
149    requirements: &mut Vec<BranchRequirementSpec>,
150) {
151    match &branch.protocol {
152        Protocol::Send { from, to, .. } => {
153            requirements.push(BranchRequirementSpec {
154                sender: from.name().to_string(),
155                receiver: to.name().to_string(),
156                label_count,
157            });
158        }
159        Protocol::Broadcast { from, to_all, .. } => {
160            for to in to_all {
161                requirements.push(BranchRequirementSpec {
162                    sender: from.name().to_string(),
163                    receiver: to.name().to_string(),
164                    label_count,
165                });
166            }
167        }
168        _ => {}
169    }
170}
171
172/// Generate the `handler(role)` method that returns a local TopologyHandler
173fn generate_handler_method() -> TokenStream {
174    quote! {
175        /// Create a handler for this protocol with local-mode topology.
176        ///
177        /// This is suitable for testing and single-process execution where
178        /// all roles run in the same process using in-memory channels.
179        ///
180        /// # Arguments
181        ///
182        /// * `role` - The role this handler will act as
183        ///
184        /// # Example
185        ///
186        /// ```ignore
187        /// let handler = MyProtocol::handler(Role::Alice);
188        /// ```
189        pub fn handler(role: Role) -> TopologyHandler {
190            TopologyHandler::local(role.role_name())
191        }
192    }
193}
194
195/// Generate the `with_topology(topo, role)` method
196fn generate_with_topology_method(
197    role_names: &[String],
198    branch_requirements: &[BranchRequirementSpec],
199) -> TokenStream {
200    let role_name_literals: Vec<TokenStream> = role_names
201        .iter()
202        .map(|role| quote! { RoleName::from_static(#role) })
203        .collect();
204
205    let branch_requirement_literals: Vec<TokenStream> = branch_requirements
206        .iter()
207        .map(|req| {
208            let sender = &req.sender;
209            let receiver = &req.receiver;
210            let label_count = req.label_count;
211            quote! {
212                BranchRequirement::new(
213                    RoleName::from_static(#sender),
214                    RoleName::from_static(#receiver),
215                    #label_count
216                )
217            }
218        })
219        .collect();
220
221    quote! {
222        /// Create a handler for this protocol with a custom topology.
223        ///
224        /// This allows specifying where each role is deployed, enabling
225        /// distributed execution across multiple processes or machines.
226        ///
227        /// # Arguments
228        ///
229        /// * `topology` - The topology configuration
230        /// * `role` - The role this handler will act as
231        ///
232        /// # Example
233        ///
234        /// ```ignore
235        /// let topology = Topology::builder()
236        ///     .local_role(RoleName::from_static("Alice"))
237        ///     .remote_role(
238        ///         RoleName::from_static("Bob"),
239        ///         TopologyEndpoint::new("192.168.1.10:8080").unwrap(),
240        ///     )
241        ///     .build();
242        ///
243        /// let handler = MyProtocol::with_topology(topology, Role::Alice)?;
244        /// ```
245        pub fn with_topology(
246            topology: Topology,
247            role: Role,
248        ) -> Result<TopologyHandler, String> {
249            let roles = [#(#role_name_literals),*];
250            let branch_requirements: &[BranchRequirement] = &[#(#branch_requirement_literals),*];
251
252            // Validate topology against protocol roles
253            let validation = topology.validate_with_branches(&roles, &branch_requirements);
254            if !validation.is_valid() {
255                return Err(format!("Topology validation failed: {:?}", validation));
256            }
257
258            Ok(TopologyHandler::new(topology, role.role_name()))
259        }
260    }
261}
262
263/// Generate named topology constants from inline definitions
264fn generate_topology_constants(
265    inline_topologies: &[InlineTopology],
266    role_names: &[String],
267) -> TokenStream {
268    if inline_topologies.is_empty() {
269        return quote! {};
270    }
271
272    let constants: Vec<TokenStream> = inline_topologies
273        .iter()
274        .map(|topo| {
275            let _const_name = format_ident!("{}", topo.name.to_uppercase());
276            let fn_name = format_ident!("{}", topo.name.to_lowercase());
277            let handler_fn_name = format_ident!("{}_handler", topo.name.to_lowercase());
278
279            // Generate the topology builder calls
280            let builder_calls = generate_topology_builder(&topo.topology, role_names);
281
282            quote! {
283                /// Pre-configured topology: #const_name
284                pub fn #fn_name() -> Topology {
285                    #builder_calls
286                }
287
288                /// Get handler for the #const_name topology
289                pub fn #handler_fn_name(role: Role) -> Result<TopologyHandler, String> {
290                    with_topology(#fn_name(), role)
291                }
292            }
293        })
294        .collect();
295
296    quote! {
297        /// Pre-configured topologies for this protocol
298        pub mod topologies {
299            use super::*;
300
301            #(#constants)*
302        }
303    }
304}
305
306/// Generate topology builder code from a Topology
307fn generate_topology_builder(topology: &Topology, _role_names: &[String]) -> TokenStream {
308    let mut builder_calls = Vec::new();
309
310    // Add mode if specified
311    if let Some(ref mode) = topology.mode {
312        builder_calls.push(generate_mode_builder_call(mode));
313    }
314
315    // Add role locations
316    for (role, location) in &topology.locations {
317        builder_calls.push(generate_location_builder_call(role, location));
318    }
319
320    // Add constraints
321    for constraint in &topology.constraints {
322        builder_calls.push(generate_constraint_builder_call(constraint));
323    }
324
325    // Add channel capacities
326    for ((sender, receiver), capacity) in &topology.channel_capacities {
327        builder_calls.push(generate_channel_capacity_builder_call(
328            sender, receiver, capacity,
329        ));
330    }
331
332    // Add role-family constraints
333    for (family, constraint) in &topology.role_constraints {
334        builder_calls.push(generate_role_family_constraint_builder_call(
335            family, constraint,
336        ));
337    }
338
339    if builder_calls.is_empty() {
340        quote! {
341            TopologyBuilder::new().build()
342        }
343    } else {
344        quote! {
345            TopologyBuilder::new()
346                #(#builder_calls)*
347                .build()
348        }
349    }
350}
351
352fn generate_mode_builder_call(mode: &TopologyMode) -> TokenStream {
353    match mode {
354        TopologyMode::Local => quote! { .mode(TopologyMode::Local) },
355    }
356}
357
358fn generate_location_builder_call(
359    role: &crate::identifiers::RoleName,
360    location: &Location,
361) -> TokenStream {
362    let role_literal = role.as_str();
363    match location {
364        Location::Local => quote! { .local_role(RoleName::from_static(#role_literal)) },
365        Location::Remote(endpoint) => {
366            let endpoint_literal = endpoint.as_str();
367            quote! {
368                .remote_role(
369                    RoleName::from_static(#role_literal),
370                    TopologyEndpoint::new(#endpoint_literal).unwrap()
371                )
372            }
373        }
374        Location::Colocated(peer) => {
375            let peer_literal = peer.as_str();
376            quote! {
377                .colocated_role(
378                    RoleName::from_static(#role_literal),
379                    RoleName::from_static(#peer_literal)
380                )
381            }
382        }
383    }
384}
385
386fn generate_pinned_location_expr(location: &Location) -> TokenStream {
387    match location {
388        Location::Local => quote! { Location::Local },
389        Location::Remote(endpoint) => {
390            let endpoint_literal = endpoint.as_str();
391            quote! { Location::Remote(TopologyEndpoint::new(#endpoint_literal).unwrap()) }
392        }
393        Location::Colocated(peer) => {
394            let peer_literal = peer.as_str();
395            quote! { Location::Colocated(RoleName::from_static(#peer_literal)) }
396        }
397    }
398}
399
400fn generate_constraint_builder_call(constraint: &TopologyConstraint) -> TokenStream {
401    match constraint {
402        TopologyConstraint::Colocated(r1, r2) => {
403            let r1_literal = r1.as_str();
404            let r2_literal = r2.as_str();
405            quote! {
406                .colocated(
407                    RoleName::from_static(#r1_literal),
408                    RoleName::from_static(#r2_literal)
409                )
410            }
411        }
412        TopologyConstraint::Separated(r1, r2) => {
413            let r1_literal = r1.as_str();
414            let r2_literal = r2.as_str();
415            quote! {
416                .separated(
417                    RoleName::from_static(#r1_literal),
418                    RoleName::from_static(#r2_literal)
419                )
420            }
421        }
422        TopologyConstraint::Pinned(role, location) => {
423            let role_literal = role.as_str();
424            let location_expr = generate_pinned_location_expr(location);
425            quote! { .pinned(RoleName::from_static(#role_literal), #location_expr) }
426        }
427        TopologyConstraint::Region(role, region) => {
428            let role_literal = role.as_str();
429            let region_literal = region.as_str();
430            quote! {
431                .region(
432                    RoleName::from_static(#role_literal),
433                    Region::new(#region_literal).unwrap()
434                )
435            }
436        }
437    }
438}
439
440fn generate_channel_capacity_builder_call(
441    sender: &crate::identifiers::RoleName,
442    receiver: &crate::identifiers::RoleName,
443    capacity: &crate::ChannelCapacity,
444) -> TokenStream {
445    let sender_literal = sender.as_str();
446    let receiver_literal = receiver.as_str();
447    let capacity_value = capacity.get();
448    quote! {
449        .channel_capacity(
450            RoleName::from_static(#sender_literal),
451            RoleName::from_static(#receiver_literal),
452            ChannelCapacity::try_new(#capacity_value)
453                .expect("generated channel capacity must be within declared bounds")
454        )
455    }
456}
457
458fn generate_role_family_constraint_builder_call(
459    family: &str,
460    constraint: &crate::topology::RoleFamilyConstraint,
461) -> TokenStream {
462    let min = constraint.min;
463    match constraint.max {
464        Some(max) => quote! {
465            .role_family_constraint(#family, RoleFamilyConstraint::bounded(#min, #max))
466        },
467        None => quote! {
468            .role_family_constraint(#family, RoleFamilyConstraint::min_only(#min))
469        },
470    }
471}
472
473/// Generate complete choreography code with topology integration
474#[must_use]
475pub fn generate_choreography_code_with_topology(
476    choreography: &Choreography,
477    local_types: &[(Role, LocalType)],
478    inline_topologies: &[InlineTopology],
479) -> TokenStream {
480    let name = choreography.name.to_string();
481    let base_code = generate_choreography_code(&name, &choreography.roles, local_types);
482    let topology_code = generate_topology_integration(choreography, inline_topologies);
483
484    quote! {
485        #base_code
486        #topology_code
487    }
488}
489
#[cfg(test)]
mod tests {
    // Unit tests for topology code generation. Assertions match on
    // `TokenStream::to_string()` output, which separates tokens with spaces
    // (hence patterns like "RoleName :: from_static").
    use super::*;
    use crate::ast::Protocol;
    use crate::identifiers::RoleName;

    /// Build a minimal two-role (`Alice`, `Bob`) choreography whose protocol
    /// is just `End`.
    fn create_test_choreography() -> Choreography {
        use quote::format_ident;

        Choreography {
            name: format_ident!("TestProtocol"),
            namespace: None,
            roles: vec![
                Role::new(format_ident!("Alice")).unwrap(),
                Role::new(format_ident!("Bob")).unwrap(),
            ],
            protocol: Protocol::End,
            attrs: std::collections::HashMap::new(),
        }
    }

    /// With no inline topologies, the `topology` module and both handler
    /// constructors are still generated.
    #[test]
    fn test_generate_topology_integration_basic() {
        let choreography = create_test_choreography();
        let inline_topologies = vec![];

        let tokens = generate_topology_integration(&choreography, &inline_topologies);
        let code = tokens.to_string();

        // Should generate the topology module
        assert!(code.contains("pub mod topology"));
        // Should construct role names for validation
        assert!(code.contains("RoleName :: from_static"));
        // Should contain handler function
        assert!(code.contains("pub fn handler"));
        // Should contain with_topology function
        assert!(code.contains("pub fn with_topology"));
    }

    /// Inline topologies produce a `topologies` submodule with a lowercase
    /// constructor and a `<name>_handler` helper for each entry.
    #[test]
    fn test_generate_topology_integration_with_inline_topologies() {
        let choreography = create_test_choreography();

        let dev_topology = Topology::builder()
            .mode(TopologyMode::Local)
            .local_role(RoleName::from_static("Alice"))
            .local_role(RoleName::from_static("Bob"))
            .build();

        let prod_topology = Topology::builder()
            .remote_role(
                RoleName::from_static("Alice"),
                crate::identifiers::Endpoint::new("alice.prod:8080").unwrap(),
            )
            .remote_role(
                RoleName::from_static("Bob"),
                crate::identifiers::Endpoint::new("bob.prod:8081").unwrap(),
            )
            .build();

        let inline_topologies = vec![
            InlineTopology {
                name: "Dev".to_string(),
                topology: dev_topology,
            },
            InlineTopology {
                name: "Prod".to_string(),
                topology: prod_topology,
            },
        ];

        let tokens = generate_topology_integration(&choreography, &inline_topologies);
        let code = tokens.to_string();

        // Should generate topology constants module
        assert!(code.contains("pub mod topologies"));
        // Should generate dev topology function
        assert!(code.contains("pub fn dev"));
        // Should generate prod topology function
        assert!(code.contains("pub fn prod"));
        // Should generate handler functions for each
        assert!(code.contains("dev_handler"));
        assert!(code.contains("prod_handler"));
    }

    /// `handler(role)` delegates to `TopologyHandler::local` with the role's name.
    #[test]
    fn test_generate_handler_method() {
        let tokens = generate_handler_method();
        let code = tokens.to_string();

        assert!(code.contains("pub fn handler"));
        assert!(code.contains("TopologyHandler :: local"));
        assert!(code.contains("role_name"));
    }

    /// `with_topology` validates with branch requirements before building a handler.
    #[test]
    fn test_generate_with_topology_method() {
        let tokens = generate_with_topology_method(&["Alice".to_string(), "Bob".to_string()], &[]);
        let code = tokens.to_string();

        assert!(code.contains("pub fn with_topology"));
        assert!(code.contains("TopologyHandler :: new"));
        assert!(code.contains("topology . validate_with_branches"));
    }

    /// A topology with only a mode set emits a `.mode(TopologyMode::Local)` call.
    #[test]
    fn test_generate_topology_builder_local_mode() {
        let topology = Topology::builder().mode(TopologyMode::Local).build();

        let tokens =
            generate_topology_builder(&topology, &["Alice".to_string(), "Bob".to_string()]);
        let code = tokens.to_string();

        assert!(code.contains("TopologyMode :: Local"));
    }

    /// Generated helpers must not pull in any deployment-backend names.
    #[test]
    fn test_generated_topology_helpers_do_not_import_deployment_backends() {
        let choreography = create_test_choreography();
        let tokens = generate_topology_integration(&choreography, &[]);
        let code = tokens.to_string();

        assert!(!code.contains("Datacenter"));
        assert!(!code.contains("Kubernetes"));
        assert!(!code.contains("Consul"));
    }

    /// Local and remote role locations become `local_role` / `remote_role`
    /// calls, and the endpoint string is embedded verbatim.
    #[test]
    fn test_generate_topology_builder_with_roles() {
        let topology = Topology::builder()
            .local_role(RoleName::from_static("Alice"))
            .remote_role(
                RoleName::from_static("Bob"),
                crate::identifiers::Endpoint::new("localhost:8080").unwrap(),
            )
            .build();

        let tokens =
            generate_topology_builder(&topology, &["Alice".to_string(), "Bob".to_string()]);
        let code = tokens.to_string();

        assert!(code.contains("local_role"));
        assert!(code.contains("remote_role"));
        assert!(code.contains("localhost:8080"));
    }

    /// Colocation, separation and bounded role-family constraints are all emitted.
    #[test]
    fn test_generate_topology_builder_with_constraints() {
        let topology = Topology::builder()
            .local_role(RoleName::from_static("Alice"))
            .local_role(RoleName::from_static("Bob"))
            .colocated(RoleName::from_static("Alice"), RoleName::from_static("Bob"))
            .separated(
                RoleName::from_static("Alice"),
                RoleName::from_static("Carol"),
            )
            .role_family_constraint(
                "Witness",
                crate::topology::RoleFamilyConstraint::bounded(2, 5),
            )
            .build();

        let tokens = generate_topology_builder(
            &topology,
            &["Alice".to_string(), "Bob".to_string(), "Carol".to_string()],
        );
        let code = tokens.to_string();

        assert!(code.contains("colocated"));
        assert!(code.contains("separated"));
        assert!(code.contains("role_family_constraint"));
        assert!(code.contains("RoleFamilyConstraint :: bounded"));
    }

    /// The combined generator emits both the base choreography code and the
    /// topology module.
    #[test]
    fn test_generate_choreography_code_with_topology() {
        let choreography = create_test_choreography();
        let local_types = vec![
            (
                Role::new(format_ident!("Alice")).unwrap(),
                crate::ast::LocalType::End,
            ),
            (
                Role::new(format_ident!("Bob")).unwrap(),
                crate::ast::LocalType::End,
            ),
        ];
        let inline_topologies = vec![];

        let tokens = generate_choreography_code_with_topology(
            &choreography,
            &local_types,
            &inline_topologies,
        );
        let code = tokens.to_string();

        // Should contain role definitions
        assert!(code.contains("Alice") || code.contains("Roles"));
        // Should contain topology integration
        assert!(code.contains("pub mod topology"));
    }
}
687        assert!(code.contains("Alice") || code.contains("Roles"));
688        // Should contain topology integration
689        assert!(code.contains("pub mod topology"));
690    }
691}