Skip to main content

telltale_runtime/compiler/codegen/
topology.rs

1//! Topology integration code generation.
2//!
3//! Generates topology-aware protocol handlers that support local testing
4//! and distributed deployment configurations.
5
6use crate::ast::{Branch, Choreography, LocalType, Protocol, Role};
7use crate::topology::{Location, Topology, TopologyConstraint, TopologyMode};
8use proc_macro2::{Ident, TokenStream};
9use quote::{format_ident, quote};
10
11use super::generate_choreography_code;
12
13/// Parsed inline topology definition for code generation
14#[derive(Debug, Clone)]
15pub struct InlineTopology {
16    /// Name of the topology (e.g., "Dev", "Prod")
17    pub name: String,
18    /// The topology configuration
19    pub topology: Topology,
20}
21
/// Compile-time description of one choice branch's channel need, later emitted as a
/// runtime `BranchRequirement` literal.
#[derive(Clone, Debug)]
struct BranchRequirementSpec {
    /// Name of the role that sends the branch's first message.
    sender: String,
    /// Name of the role that receives it.
    receiver: String,
    /// Number of branches in the enclosing choice (saturated to `u32::MAX`).
    label_count: u32,
}
28
29/// Generate topology-aware protocol handlers
30///
31/// This generates:
32/// - `Protocol::handler(role)` - Creates a TopologyHandler with local mode
33/// - `Protocol::with_topology(topo, role)` - Creates a TopologyHandler with custom topology
34/// - Named topology constants for inline definitions
35#[must_use]
36pub fn generate_topology_integration(
37    choreography: &Choreography,
38    inline_topologies: &[InlineTopology],
39) -> TokenStream {
40    let _protocol_name = &choreography.name;
41
42    // Collect role names for validation
43    let role_names: Vec<&Ident> = choreography.roles.iter().map(|r| r.name()).collect();
44    let role_name_strs: Vec<String> = role_names.iter().map(|r| r.to_string()).collect();
45
46    // Generate handler method
47    let handler_method = generate_handler_method();
48
49    // Collect branch requirements for capacity checking
50    let branch_requirements = collect_branch_requirements(&choreography.protocol);
51
52    // Generate with_topology method
53    let with_topology_method = generate_with_topology_method(&role_name_strs, &branch_requirements);
54
55    // Generate topology constants
56    let topology_constants = generate_topology_constants(inline_topologies, &role_name_strs);
57
58    quote! {
59        /// Topology integration for the #protocol_name_str protocol
60        pub mod topology {
61            use super::*;
62            use ::telltale_runtime::topology::{
63                BranchRequirement, Location, Topology, TopologyBuilder, TopologyHandler,
64                TopologyMode,
65            };
66            use ::telltale_runtime::{
67                ChannelCapacity, Datacenter, Namespace, Region, RoleName, TopologyEndpoint,
68            };
69
70            #handler_method
71            #with_topology_method
72            #topology_constants
73        }
74    }
75}
76
77fn collect_branch_requirements(protocol: &Protocol) -> Vec<BranchRequirementSpec> {
78    let mut requirements = Vec::new();
79    collect_branch_requirements_from_protocol(protocol, &mut requirements);
80    requirements
81}
82
83fn collect_branch_requirements_from_protocol(
84    protocol: &Protocol,
85    requirements: &mut Vec<BranchRequirementSpec>,
86) {
87    match protocol {
88        Protocol::Choice { branches, .. } => {
89            let label_count = u32::try_from(branches.len()).unwrap_or(u32::MAX);
90            for branch in branches {
91                collect_branch_requirement_from_branch(branch, label_count, requirements);
92                collect_branch_requirements_from_protocol(&branch.protocol, requirements);
93            }
94        }
95        Protocol::Case { branches, .. } => {
96            for branch in branches {
97                collect_branch_requirements_from_protocol(&branch.protocol, requirements);
98            }
99        }
100        Protocol::Timeout {
101            body,
102            on_timeout,
103            on_cancel,
104            ..
105        } => {
106            collect_branch_requirements_from_protocol(body, requirements);
107            collect_branch_requirements_from_protocol(on_timeout, requirements);
108            if let Some(on_cancel) = on_cancel.as_deref() {
109                collect_branch_requirements_from_protocol(on_cancel, requirements);
110            }
111        }
112        Protocol::Send { continuation, .. } => {
113            collect_branch_requirements_from_protocol(continuation, requirements);
114        }
115        Protocol::Broadcast { continuation, .. } => {
116            collect_branch_requirements_from_protocol(continuation, requirements);
117        }
118        Protocol::Loop { body, .. } => {
119            collect_branch_requirements_from_protocol(body, requirements);
120        }
121        Protocol::Parallel { protocols } => {
122            for p in protocols {
123                collect_branch_requirements_from_protocol(p, requirements);
124            }
125        }
126        Protocol::Rec { body, .. } => {
127            collect_branch_requirements_from_protocol(body, requirements);
128        }
129        Protocol::Begin { continuation, .. }
130        | Protocol::Await { continuation, .. }
131        | Protocol::Resolve { continuation, .. }
132        | Protocol::Invalidate { continuation, .. }
133        | Protocol::Extension { continuation, .. }
134        | Protocol::Let { continuation, .. }
135        | Protocol::Publish { continuation, .. }
136        | Protocol::PublishAuthority { continuation, .. }
137        | Protocol::Materialize { continuation, .. }
138        | Protocol::Handoff { continuation, .. }
139        | Protocol::DependentWork { continuation, .. } => {
140            collect_branch_requirements_from_protocol(continuation, requirements);
141        }
142        Protocol::Var(_) | Protocol::End => {}
143    }
144}
145
146fn collect_branch_requirement_from_branch(
147    branch: &Branch,
148    label_count: u32,
149    requirements: &mut Vec<BranchRequirementSpec>,
150) {
151    match &branch.protocol {
152        Protocol::Send { from, to, .. } => {
153            requirements.push(BranchRequirementSpec {
154                sender: from.name().to_string(),
155                receiver: to.name().to_string(),
156                label_count,
157            });
158        }
159        Protocol::Broadcast { from, to_all, .. } => {
160            for to in to_all {
161                requirements.push(BranchRequirementSpec {
162                    sender: from.name().to_string(),
163                    receiver: to.name().to_string(),
164                    label_count,
165                });
166            }
167        }
168        _ => {}
169    }
170}
171
172/// Generate the `handler(role)` method that returns a local TopologyHandler
173fn generate_handler_method() -> TokenStream {
174    quote! {
175        /// Create a handler for this protocol with local-mode topology.
176        ///
177        /// This is suitable for testing and single-process execution where
178        /// all roles run in the same process using in-memory channels.
179        ///
180        /// # Arguments
181        ///
182        /// * `role` - The role this handler will act as
183        ///
184        /// # Example
185        ///
186        /// ```ignore
187        /// let handler = MyProtocol::handler(Role::Alice);
188        /// ```
189        pub fn handler(role: Role) -> TopologyHandler {
190            TopologyHandler::local(role.role_name())
191        }
192    }
193}
194
195/// Generate the `with_topology(topo, role)` method
196fn generate_with_topology_method(
197    role_names: &[String],
198    branch_requirements: &[BranchRequirementSpec],
199) -> TokenStream {
200    let role_name_literals: Vec<TokenStream> = role_names
201        .iter()
202        .map(|role| quote! { RoleName::from_static(#role) })
203        .collect();
204
205    let branch_requirement_literals: Vec<TokenStream> = branch_requirements
206        .iter()
207        .map(|req| {
208            let sender = &req.sender;
209            let receiver = &req.receiver;
210            let label_count = req.label_count;
211            quote! {
212                BranchRequirement::new(
213                    RoleName::from_static(#sender),
214                    RoleName::from_static(#receiver),
215                    #label_count
216                )
217            }
218        })
219        .collect();
220
221    quote! {
222        /// Create a handler for this protocol with a custom topology.
223        ///
224        /// This allows specifying where each role is deployed, enabling
225        /// distributed execution across multiple processes or machines.
226        ///
227        /// # Arguments
228        ///
229        /// * `topology` - The topology configuration
230        /// * `role` - The role this handler will act as
231        ///
232        /// # Example
233        ///
234        /// ```ignore
235        /// let topology = Topology::builder()
236        ///     .local_role(RoleName::from_static("Alice"))
237        ///     .remote_role(
238        ///         RoleName::from_static("Bob"),
239        ///         TopologyEndpoint::new("192.168.1.10:8080").unwrap(),
240        ///     )
241        ///     .build();
242        ///
243        /// let handler = MyProtocol::with_topology(topology, Role::Alice)?;
244        /// ```
245        pub fn with_topology(
246            topology: Topology,
247            role: Role,
248        ) -> Result<TopologyHandler, String> {
249            let roles = [#(#role_name_literals),*];
250            let branch_requirements: &[BranchRequirement] = &[#(#branch_requirement_literals),*];
251
252            // Validate topology against protocol roles
253            let validation = topology.validate_with_branches(&roles, &branch_requirements);
254            if !validation.is_valid() {
255                return Err(format!("Topology validation failed: {:?}", validation));
256            }
257
258            Ok(TopologyHandler::new(topology, role.role_name()))
259        }
260    }
261}
262
263/// Generate named topology constants from inline definitions
264fn generate_topology_constants(
265    inline_topologies: &[InlineTopology],
266    role_names: &[String],
267) -> TokenStream {
268    if inline_topologies.is_empty() {
269        return quote! {};
270    }
271
272    let constants: Vec<TokenStream> = inline_topologies
273        .iter()
274        .map(|topo| {
275            let _const_name = format_ident!("{}", topo.name.to_uppercase());
276            let fn_name = format_ident!("{}", topo.name.to_lowercase());
277            let handler_fn_name = format_ident!("{}_handler", topo.name.to_lowercase());
278
279            // Generate the topology builder calls
280            let builder_calls = generate_topology_builder(&topo.topology, role_names);
281
282            quote! {
283                /// Pre-configured topology: #const_name
284                pub fn #fn_name() -> Topology {
285                    #builder_calls
286                }
287
288                /// Get handler for the #const_name topology
289                pub fn #handler_fn_name(role: Role) -> Result<TopologyHandler, String> {
290                    with_topology(#fn_name(), role)
291                }
292            }
293        })
294        .collect();
295
296    quote! {
297        /// Pre-configured topologies for this protocol
298        pub mod topologies {
299            use super::*;
300
301            #(#constants)*
302        }
303    }
304}
305
306/// Generate topology builder code from a Topology
307fn generate_topology_builder(topology: &Topology, _role_names: &[String]) -> TokenStream {
308    let mut builder_calls = Vec::new();
309
310    // Add mode if specified
311    if let Some(ref mode) = topology.mode {
312        builder_calls.push(generate_mode_builder_call(mode));
313    }
314
315    // Add role locations
316    for (role, location) in &topology.locations {
317        builder_calls.push(generate_location_builder_call(role, location));
318    }
319
320    // Add constraints
321    for constraint in &topology.constraints {
322        builder_calls.push(generate_constraint_builder_call(constraint));
323    }
324
325    // Add channel capacities
326    for ((sender, receiver), capacity) in &topology.channel_capacities {
327        builder_calls.push(generate_channel_capacity_builder_call(
328            sender, receiver, capacity,
329        ));
330    }
331
332    if builder_calls.is_empty() {
333        quote! {
334            TopologyBuilder::new().build()
335        }
336    } else {
337        quote! {
338            TopologyBuilder::new()
339                #(#builder_calls)*
340                .build()
341        }
342    }
343}
344
345fn generate_mode_builder_call(mode: &TopologyMode) -> TokenStream {
346    match mode {
347        TopologyMode::Local => quote! { .mode(TopologyMode::Local) },
348        TopologyMode::PerRole => quote! { .mode(TopologyMode::PerRole) },
349        TopologyMode::Kubernetes(ns) => {
350            let ns_literal = ns.as_str();
351            quote! { .mode(TopologyMode::Kubernetes(Namespace::new(#ns_literal).unwrap())) }
352        }
353        TopologyMode::Consul(dc) => {
354            let dc_literal = dc.as_str();
355            quote! { .mode(TopologyMode::Consul(Datacenter::new(#dc_literal).unwrap())) }
356        }
357    }
358}
359
360fn generate_location_builder_call(
361    role: &crate::identifiers::RoleName,
362    location: &Location,
363) -> TokenStream {
364    let role_literal = role.as_str();
365    match location {
366        Location::Local => quote! { .local_role(RoleName::from_static(#role_literal)) },
367        Location::Remote(endpoint) => {
368            let endpoint_literal = endpoint.as_str();
369            quote! {
370                .remote_role(
371                    RoleName::from_static(#role_literal),
372                    TopologyEndpoint::new(#endpoint_literal).unwrap()
373                )
374            }
375        }
376        Location::Colocated(peer) => {
377            let peer_literal = peer.as_str();
378            quote! {
379                .colocated_role(
380                    RoleName::from_static(#role_literal),
381                    RoleName::from_static(#peer_literal)
382                )
383            }
384        }
385    }
386}
387
388fn generate_pinned_location_expr(location: &Location) -> TokenStream {
389    match location {
390        Location::Local => quote! { Location::Local },
391        Location::Remote(endpoint) => {
392            let endpoint_literal = endpoint.as_str();
393            quote! { Location::Remote(TopologyEndpoint::new(#endpoint_literal).unwrap()) }
394        }
395        Location::Colocated(peer) => {
396            let peer_literal = peer.as_str();
397            quote! { Location::Colocated(RoleName::from_static(#peer_literal)) }
398        }
399    }
400}
401
402fn generate_constraint_builder_call(constraint: &TopologyConstraint) -> TokenStream {
403    match constraint {
404        TopologyConstraint::Colocated(r1, r2) => {
405            let r1_literal = r1.as_str();
406            let r2_literal = r2.as_str();
407            quote! {
408                .colocated(
409                    RoleName::from_static(#r1_literal),
410                    RoleName::from_static(#r2_literal)
411                )
412            }
413        }
414        TopologyConstraint::Separated(r1, r2) => {
415            let r1_literal = r1.as_str();
416            let r2_literal = r2.as_str();
417            quote! {
418                .separated(
419                    RoleName::from_static(#r1_literal),
420                    RoleName::from_static(#r2_literal)
421                )
422            }
423        }
424        TopologyConstraint::Pinned(role, location) => {
425            let role_literal = role.as_str();
426            let location_expr = generate_pinned_location_expr(location);
427            quote! { .pinned(RoleName::from_static(#role_literal), #location_expr) }
428        }
429        TopologyConstraint::Region(role, region) => {
430            let role_literal = role.as_str();
431            let region_literal = region.as_str();
432            quote! {
433                .region(
434                    RoleName::from_static(#role_literal),
435                    Region::new(#region_literal).unwrap()
436                )
437            }
438        }
439    }
440}
441
442fn generate_channel_capacity_builder_call(
443    sender: &crate::identifiers::RoleName,
444    receiver: &crate::identifiers::RoleName,
445    capacity: &crate::ChannelCapacity,
446) -> TokenStream {
447    let sender_literal = sender.as_str();
448    let receiver_literal = receiver.as_str();
449    let capacity_value = capacity.get();
450    quote! {
451        .channel_capacity(
452            RoleName::from_static(#sender_literal),
453            RoleName::from_static(#receiver_literal),
454            ChannelCapacity::try_new(#capacity_value)
455                .expect("generated channel capacity must be within declared bounds")
456        )
457    }
458}
459
460/// Generate complete choreography code with topology integration
461#[must_use]
462pub fn generate_choreography_code_with_topology(
463    choreography: &Choreography,
464    local_types: &[(Role, LocalType)],
465    inline_topologies: &[InlineTopology],
466) -> TokenStream {
467    let name = choreography.name.to_string();
468    let base_code = generate_choreography_code(&name, &choreography.roles, local_types);
469    let topology_code = generate_topology_integration(choreography, inline_topologies);
470
471    quote! {
472        #base_code
473        #topology_code
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Protocol;
    use crate::identifiers::RoleName;

    /// Build a `Role` from a plain name; fixtures only use valid identifiers.
    fn role(name: &str) -> Role {
        Role::new(format_ident!("{}", name)).unwrap()
    }

    /// Owned role-name list in the shape the generators expect.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| (*name).to_owned()).collect()
    }

    /// Fail (printing the generated code) if any `needle` is missing from `code`.
    fn assert_contains_all(code: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                code.contains(*needle),
                "generated code is missing `{}`:\n{}",
                needle,
                code
            );
        }
    }

    /// Two-role (`Alice`, `Bob`) choreography whose protocol is just `End`.
    fn create_test_choreography() -> Choreography {
        Choreography {
            name: format_ident!("TestProtocol"),
            namespace: None,
            roles: vec![role("Alice"), role("Bob")],
            protocol: Protocol::End,
            attrs: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_generate_topology_integration_basic() {
        let choreography = create_test_choreography();
        let code = generate_topology_integration(&choreography, &[]).to_string();

        assert_contains_all(
            &code,
            &[
                "pub mod topology",
                "RoleName :: from_static",
                "pub fn handler",
                "pub fn with_topology",
            ],
        );
    }

    #[test]
    fn test_generate_topology_integration_with_inline_topologies() {
        let dev = InlineTopology {
            name: "Dev".to_string(),
            topology: Topology::builder()
                .mode(TopologyMode::Local)
                .local_role(RoleName::from_static("Alice"))
                .local_role(RoleName::from_static("Bob"))
                .build(),
        };
        let prod = InlineTopology {
            name: "Prod".to_string(),
            topology: Topology::builder()
                .remote_role(
                    RoleName::from_static("Alice"),
                    crate::identifiers::Endpoint::new("alice.prod:8080").unwrap(),
                )
                .remote_role(
                    RoleName::from_static("Bob"),
                    crate::identifiers::Endpoint::new("bob.prod:8081").unwrap(),
                )
                .build(),
        };

        let choreography = create_test_choreography();
        let code = generate_topology_integration(&choreography, &[dev, prod]).to_string();

        assert_contains_all(
            &code,
            &[
                "pub mod topologies",
                "pub fn dev",
                "pub fn prod",
                "dev_handler",
                "prod_handler",
            ],
        );
    }

    #[test]
    fn test_generate_handler_method() {
        let code = generate_handler_method().to_string();

        assert_contains_all(
            &code,
            &["pub fn handler", "TopologyHandler :: local", "role_name"],
        );
    }

    #[test]
    fn test_generate_with_topology_method() {
        let code = generate_with_topology_method(&names(&["Alice", "Bob"]), &[]).to_string();

        assert_contains_all(
            &code,
            &[
                "pub fn with_topology",
                "TopologyHandler :: new",
                "topology . validate_with_branches",
            ],
        );
    }

    #[test]
    fn test_generate_topology_builder_local_mode() {
        let topology = Topology::builder().mode(TopologyMode::Local).build();
        let code = generate_topology_builder(&topology, &names(&["Alice", "Bob"])).to_string();

        assert_contains_all(&code, &["TopologyMode :: Local"]);
    }

    #[test]
    fn test_generate_topology_builder_with_roles() {
        let topology = Topology::builder()
            .local_role(RoleName::from_static("Alice"))
            .remote_role(
                RoleName::from_static("Bob"),
                crate::identifiers::Endpoint::new("localhost:8080").unwrap(),
            )
            .build();
        let code = generate_topology_builder(&topology, &names(&["Alice", "Bob"])).to_string();

        assert_contains_all(&code, &["local_role", "remote_role", "localhost:8080"]);
    }

    #[test]
    fn test_generate_topology_builder_with_constraints() {
        let topology = Topology::builder()
            .local_role(RoleName::from_static("Alice"))
            .local_role(RoleName::from_static("Bob"))
            .colocated(RoleName::from_static("Alice"), RoleName::from_static("Bob"))
            .separated(
                RoleName::from_static("Alice"),
                RoleName::from_static("Carol"),
            )
            .build();
        let code = generate_topology_builder(&topology, &names(&["Alice", "Bob", "Carol"]))
            .to_string();

        assert_contains_all(&code, &["colocated", "separated"]);
    }

    #[test]
    fn test_generate_choreography_code_with_topology() {
        let choreography = create_test_choreography();
        let local_types: Vec<(Role, crate::ast::LocalType)> = ["Alice", "Bob"]
            .iter()
            .map(|name| (role(name), crate::ast::LocalType::End))
            .collect();

        let code =
            generate_choreography_code_with_topology(&choreography, &local_types, &[]).to_string();

        // Role definitions come from the base choreography code.
        assert!(code.contains("Alice") || code.contains("Roles"));
        // The topology integration module is appended after it.
        assert!(code.contains("pub mod topology"));
    }
}