Skip to main content

telltale_language/
extensions.rs

1//! DSL Extension System for Telltale
2//!
3//! This module provides a clean, composable system for extending choreographic DSL syntax.
4//! Extensions can add new grammar rules, custom statement parsers, and protocol behaviors
5//! while maintaining compatibility with the core choreographic infrastructure.
6
7use crate::ast::{LocalType, Role};
8use crate::compiler::projection::ProjectionError;
9use std::any::{Any, TypeId};
10use std::collections::BTreeMap;
11use std::fmt::Debug;
12
/// Human-readable documentation attached to an extension.
///
/// Returned by [`DocumentedGrammarExtension::documentation`]. The [`Default`] value carries
/// placeholder text for the two prose fields and empty lists for everything else.
#[derive(Debug, Clone)]
pub struct ExtensionDocumentation {
    /// Overview text (placeholder: "No documentation provided").
    pub overview: String,
    /// Syntax guide text (placeholder: "No syntax guide provided").
    pub syntax_guide: String,
    /// Use cases for the extension (empty by default).
    pub use_cases: Vec<String>,
    /// Known limitations (empty by default).
    pub limitations: Vec<String>,
    /// Related references (empty by default).
    pub see_also: Vec<String>,
}

impl Default for ExtensionDocumentation {
    /// Placeholder documentation for extensions that do not provide their own.
    fn default() -> Self {
        ExtensionDocumentation {
            overview: String::from("No documentation provided"),
            syntax_guide: String::from("No syntax guide provided"),
            use_cases: Vec::new(),
            limitations: Vec::new(),
            see_also: Vec::new(),
        }
    }
}
34
/// Example usage for an extension.
///
/// Returned by [`DocumentedGrammarExtension::examples`].
#[derive(Debug, Clone)]
pub struct ExtensionExample {
    /// Short title of the example.
    pub title: String,
    /// Explanation of what the example demonstrates.
    pub description: String,
    /// Example code (presumably choreographic DSL source — confirm with implementors).
    pub code: String,
    /// Expected output of the example, or `None` when none is documented.
    pub expected_output: Option<String>,
}
43
/// Trait for adding new grammar rules to the choreographic DSL.
///
/// Implementors are registered with [`ExtensionRegistry::register_grammar`], which uses
/// [`extension_id`](Self::extension_id), [`statement_rules`](Self::statement_rules) and
/// [`priority`](Self::priority) to decide which extension owns each statement rule.
pub trait GrammarExtension: Send + Sync + Debug {
    /// Return the Pest grammar rules this extension provides.
    ///
    /// [`ExtensionRegistry::compose_grammar`] appends this text verbatim (preceded by a
    /// newline) to the base grammar.
    fn grammar_rules(&self) -> &'static str;

    /// List of statement rule names this extension handles.
    ///
    /// On registration each name is mapped to this extension's id, subject to
    /// priority-based conflict resolution.
    fn statement_rules(&self) -> Vec<&'static str>;

    /// Priority for conflict resolution (higher = more precedence).
    ///
    /// Defaults to `100`. When two extensions claim the same rule the higher priority wins;
    /// equal priorities make registration fail with [`ParseError::PriorityConflict`].
    /// `compose_grammar` also emits extension grammars in descending priority order.
    fn priority(&self) -> u32 {
        100
    }

    /// Extension identifier for debugging and registration.
    ///
    /// Used as the registry key. [`ExtensionRegistry::find_parser`] looks statement parsers
    /// up by the id of the extension owning a rule, so a companion parser should be
    /// registered (via `register_parser`) under this same id.
    fn extension_id(&self) -> &'static str;
}
60
/// Trait for self-documenting extensions.
///
/// Every method has a default implementation, so implementors only override what they
/// actually document.
pub trait DocumentedGrammarExtension: GrammarExtension {
    /// Documentation for this extension.
    ///
    /// Defaults to [`ExtensionDocumentation::default`] (placeholder text, empty lists).
    fn documentation(&self) -> ExtensionDocumentation {
        ExtensionDocumentation::default()
    }

    /// Examples showing how to use this extension (none by default).
    fn examples(&self) -> Vec<ExtensionExample> {
        vec![]
    }

    /// Grammar rules with human-readable descriptions (empty by default).
    ///
    /// Presumably keyed by rule name — confirm with implementors.
    /// NOTE(review): a `HashMap` has unspecified iteration order; sort the entries before
    /// rendering them if deterministic output is needed (the registry itself uses
    /// `BTreeMap` everywhere for that reason).
    fn rule_descriptions(&self) -> std::collections::HashMap<String, String> {
        std::collections::HashMap::new()
    }
}
78
/// Trait for parsing custom protocol statements.
///
/// Parsers are stored with [`ExtensionRegistry::register_parser`] and retrieved with
/// [`ExtensionRegistry::find_parser`] / [`ExtensionRegistry::get_statement_parser`].
pub trait StatementParser: Send + Sync + Debug {
    /// Check if this parser can handle the given rule name.
    ///
    /// Note: `ExtensionRegistry::find_parser` selects a parser by rule ownership and does
    /// not call this method.
    fn can_parse(&self, rule_name: &str) -> bool;

    /// Return all rules this parser supports.
    fn supported_rules(&self) -> Vec<String>;

    /// Parse a statement into a protocol extension
    ///
    /// # Arguments
    /// * `rule_name` - The grammar rule name being parsed
    /// * `content` - The matched content as a string
    /// * `context` - Parsing context with declared roles
    ///
    /// # Returns
    /// A boxed protocol extension representing the parsed statement
    ///
    /// # Errors
    /// Returns a [`ParseError`] when the statement cannot be parsed.
    fn parse_statement(
        &self,
        rule_name: &str,
        content: &str,
        context: &ParseContext,
    ) -> Result<Box<dyn ProtocolExtension>, ParseError>;
}
103
/// Trait for custom protocol behaviors that can be projected and validated.
///
/// Instances are produced by [`StatementParser::parse_statement`] as
/// `Box<dyn ProtocolExtension>`.
pub trait ProtocolExtension: Send + Sync + Debug {
    /// Unique identifier for this protocol extension type
    fn type_name(&self) -> &'static str;

    /// Check if this protocol mentions a specific role
    fn mentions_role(&self, role: &Role) -> bool;

    /// Validate this protocol against declared roles
    ///
    /// # Errors
    /// Returns an [`ExtensionValidationError`] describing why validation failed.
    fn validate(&self, roles: &[Role]) -> Result<(), ExtensionValidationError>;

    /// Project this protocol to a local type for a specific role
    ///
    /// `context` carries every role of the choreography plus the role being projected.
    ///
    /// # Errors
    /// Returns a [`ProjectionError`] when the protocol cannot be projected for `role`.
    fn project(
        &self,
        role: &Role,
        context: &ProjectionContext,
    ) -> Result<LocalType, ProjectionError>;

    /// Generate code for this protocol extension
    fn generate_code(&self, context: &CodegenContext) -> proc_macro2::TokenStream;

    /// For trait object safety and downcasting: expose the value as `&dyn Any`
    /// (implementations presumably return `self`).
    fn as_any(&self) -> &dyn Any;
    /// Mutable counterpart of `as_any`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// `TypeId` of the concrete implementing type.
    ///
    /// NOTE(review): this shares its name with `Any::type_id`. With `Any` in scope, calling
    /// `.type_id()` on a `Box<dyn ProtocolExtension>` (rather than on the trait object
    /// itself) can resolve to `Any::type_id` of the box — call it through
    /// `&dyn ProtocolExtension` or with UFCS to be unambiguous.
    fn type_id(&self) -> TypeId;
}
130
/// Registry for managing DSL extensions with conflict resolution.
///
/// Grammar extensions and statement parsers live in separate maps keyed by id. Each
/// statement rule is owned by at most one grammar extension (see `register_grammar`).
/// Every map is a `BTreeMap`, so iterating any of them is ordered by key.
#[derive(Debug, Default)]
pub struct ExtensionRegistry {
    /// Registered grammar extensions, keyed by `GrammarExtension::extension_id`.
    grammar_extensions: BTreeMap<String, Box<dyn GrammarExtension>>,
    /// Registered statement parsers, keyed by the id passed to `register_parser`.
    statement_parsers: BTreeMap<String, Box<dyn StatementParser>>,
    /// Rule name -> id of the grammar extension that currently owns the rule. The same id
    /// is used as the key into `statement_parsers` by `find_parser`.
    rule_to_parser: BTreeMap<String, String>,
    /// Track rule conflicts for resolution: rule name -> ids of extensions that a
    /// higher-priority extension displaced for that rule.
    rule_conflicts: BTreeMap<String, Vec<String>>,
    /// Extension dependencies: dependent extension id -> ids it requires.
    extension_dependencies: BTreeMap<String, Vec<String>>,
    /// Extension version information: extension id -> version string. `register_grammar`
    /// fills in "0.1.0" when no version was recorded; within this module it is only read
    /// by `generate_docs`.
    extension_versions: BTreeMap<String, String>,
}
144
145impl ExtensionRegistry {
146    /// Create a new empty extension registry
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    /// Register a grammar extension with conflict detection
152    pub fn register_grammar<T: GrammarExtension + 'static>(
153        &mut self,
154        extension: T,
155    ) -> Result<(), ParseError> {
156        let id = extension.extension_id().to_string();
157        let rules = extension.statement_rules();
158        let priority = extension.priority();
159
160        // Check for conflicts and resolve by priority
161        for rule in &rules {
162            if let Some(existing_id) = self.rule_to_parser.get(*rule) {
163                let existing_priority = self
164                    .grammar_extensions
165                    .get(existing_id)
166                    .map(|e| e.priority())
167                    .unwrap_or(0);
168
169                if priority > existing_priority {
170                    // New extension wins, record conflict
171                    self.rule_conflicts
172                        .entry((*rule).to_string())
173                        .or_default()
174                        .push(existing_id.clone());
175                    self.rule_to_parser.insert((*rule).to_string(), id.clone());
176                } else if priority == existing_priority {
177                    // Equal priority - this is a conflict
178                    return Err(ParseError::PriorityConflict {
179                        extension1: existing_id.clone(),
180                        extension2: id.clone(),
181                        priority1: existing_priority,
182                        priority2: priority,
183                        rule: (*rule).to_string(),
184                    });
185                }
186                // Lower priority - existing extension wins
187            } else {
188                self.rule_to_parser.insert((*rule).to_string(), id.clone());
189            }
190        }
191
192        self.grammar_extensions
193            .insert(id.clone(), Box::new(extension));
194        // Set default version if not specified
195        self.extension_versions
196            .entry(id)
197            .or_insert_with(|| "0.1.0".to_string());
198        Ok(())
199    }
200
201    /// Register a statement parser
202    pub fn register_parser<T: StatementParser + 'static>(&mut self, parser: T, parser_id: String) {
203        self.statement_parsers.insert(parser_id, Box::new(parser));
204    }
205
206    /// Get all grammar rules from registered extensions
207    pub fn compose_grammar(&self, base_grammar: &str) -> String {
208        let mut composed = base_grammar.to_string();
209
210        // Sort extensions by priority (highest first)
211        let mut extensions: Vec<_> = self.grammar_extensions.iter().collect();
212        extensions.sort_by(|(id_a, ext_a), (id_b, ext_b)| {
213            std::cmp::Reverse(ext_a.priority())
214                .cmp(&std::cmp::Reverse(ext_b.priority()))
215                .then_with(|| id_a.cmp(id_b))
216        });
217
218        for (_, extension) in extensions {
219            composed.push('\n');
220            composed.push_str(extension.grammar_rules());
221        }
222
223        composed
224    }
225
226    /// Find parser for a given rule name
227    pub fn find_parser(&self, rule_name: &str) -> Option<&dyn StatementParser> {
228        if let Some(parser_id) = self.rule_to_parser.get(rule_name) {
229            self.statement_parsers.get(parser_id).map(|p| p.as_ref())
230        } else {
231            None
232        }
233    }
234
235    /// Check if a rule is handled by an extension
236    pub fn can_handle(&self, rule_name: &str) -> bool {
237        self.rule_to_parser.contains_key(rule_name)
238    }
239
240    /// Check if any extensions are registered
241    pub fn has_extensions(&self) -> bool {
242        !self.grammar_extensions.is_empty() || !self.statement_parsers.is_empty()
243    }
244
245    /// Get all grammar extensions
246    pub fn grammar_extensions(&self) -> impl Iterator<Item = &dyn GrammarExtension> {
247        let mut ordered: Vec<_> = self.grammar_extensions.iter().collect();
248        ordered.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
249        ordered.into_iter().map(|(_, e)| e.as_ref())
250    }
251
252    /// Check if a specific extension is registered
253    pub fn has_extension(&self, extension_id: &str) -> bool {
254        self.grammar_extensions.contains_key(extension_id)
255    }
256
257    /// Get parser for a rule name
258    pub fn get_parser_for_rule(&self, rule_name: &str) -> Option<&str> {
259        self.rule_to_parser.get(rule_name).map(String::as_str)
260    }
261
262    /// Get statement parser by ID
263    pub fn get_statement_parser(&self, parser_id: &str) -> Option<&dyn StatementParser> {
264        self.statement_parsers.get(parser_id).map(|p| p.as_ref())
265    }
266
267    /// Add dependency between extensions
268    pub fn add_dependency(&mut self, dependent: &str, required: &str) {
269        self.extension_dependencies
270            .entry(dependent.to_string())
271            .or_default()
272            .push(required.to_string());
273    }
274
275    /// Validate all extension dependencies are satisfied
276    pub fn validate_dependencies(&self) -> Result<(), ParseError> {
277        for (dependent, requirements) in &self.extension_dependencies {
278            for required in requirements {
279                if !self.grammar_extensions.contains_key(required) {
280                    return Err(ParseError::MissingDependency {
281                        extension: dependent.clone(),
282                        dependency: required.clone(),
283                    });
284                }
285            }
286        }
287        Ok(())
288    }
289
290    /// Get all rule conflicts for debugging
291    pub fn get_conflicts(&self) -> &BTreeMap<String, Vec<String>> {
292        &self.rule_conflicts
293    }
294
295    /// Get detailed conflict information with resolution suggestions
296    pub fn get_detailed_conflicts(&self) -> Vec<String> {
297        let mut details = Vec::new();
298        let unknown_ext = "unknown".to_string();
299
300        let mut conflicts: Vec<_> = self.rule_conflicts.iter().collect();
301        conflicts.sort_by(|(rule_a, _), (rule_b, _)| rule_a.cmp(rule_b));
302
303        for (rule, conflicting_extensions) in conflicts {
304            if !conflicting_extensions.is_empty() {
305                let active_extension = self.rule_to_parser.get(rule).unwrap_or(&unknown_ext);
306                let active_priority = self
307                    .grammar_extensions
308                    .get(active_extension)
309                    .map(|e| e.priority())
310                    .unwrap_or(0);
311
312                let mut conflicting_extensions = conflicting_extensions.clone();
313                conflicting_extensions.sort();
314
315                for conflicting in &conflicting_extensions {
316                    let conflicting_priority = self
317                        .grammar_extensions
318                        .get(conflicting)
319                        .map(|e| e.priority())
320                        .unwrap_or(0);
321
322                    details.push(format!(
323                        "Rule '{}': Extension '{}' (priority {}) overrode '{}' (priority {}). \
324                         To resolve: 1) Adjust priorities, 2) Use different rule names, or 3) Merge functionality.",
325                        rule, active_extension, active_priority, conflicting, conflicting_priority
326                    ));
327                }
328            }
329        }
330
331        details
332    }
333
334    /// Check extension compatibility
335    pub fn check_compatibility(&self, extension_ids: &[&str]) -> Result<(), ParseError> {
336        // Check for direct conflicts between the specified extensions
337        let mut rules_used = BTreeMap::new();
338
339        for &extension_id in extension_ids {
340            if let Some(extension) = self.grammar_extensions.get(extension_id) {
341                for rule in extension.statement_rules() {
342                    if let Some(existing) = rules_used.get(rule) {
343                        if existing != &extension_id {
344                            return Err(ParseError::IncompatibleExtensions {
345                                details: format!(
346                                    "Extensions '{}' and '{}' both define rule '{}'. Use different rule names or register extensions with different priorities.",
347                                    existing, extension_id, rule
348                                ),
349                            });
350                        }
351                    }
352                    rules_used.insert(rule.to_string(), extension_id);
353                }
354            }
355        }
356        Ok(())
357    }
358
359    /// Create a registry with built-in extensions
360    pub fn with_builtin_extensions() -> Self {
361        let mut registry = Self::new();
362
363        // Register timeout extension
364        registry
365            .register_grammar(timeout::TimeoutGrammarExtension)
366            .expect("builtin timeout extension should register successfully");
367        registry.register_parser(timeout::TimeoutStatementParser, "timeout".to_string());
368
369        registry
370    }
371
372    /// Create a minimal registry for 3rd party integration
373    pub fn for_third_party() -> Self {
374        Self::new()
375    }
376
377    /// Generate basic documentation for all registered extensions
378    pub fn generate_docs(&self) -> String {
379        let mut docs = String::from("# Extension Documentation\n\n");
380
381        let mut entries: Vec<_> = self.grammar_extensions.iter().collect();
382        entries.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
383
384        for (id, extension) in entries {
385            docs.push_str(&format!("## {}\n\n", id));
386            docs.push_str(&format!("**Priority:** {}\n\n", extension.priority()));
387            docs.push_str(&format!(
388                "**Rules:** {}\n\n",
389                extension.statement_rules().join(", ")
390            ));
391
392            if let Some(version) = self.extension_versions.get(id) {
393                docs.push_str(&format!("**Version:** {}\n\n", version));
394            }
395
396            docs.push_str("**Grammar:**\n```\n");
397            docs.push_str(extension.grammar_rules());
398            docs.push_str("\n```\n\n");
399        }
400
401        docs
402    }
403}
404
/// Context provided during statement parsing.
///
/// Passed to [`StatementParser::parse_statement`].
#[derive(Debug)]
pub struct ParseContext<'a> {
    /// Roles declared in the choreography
    pub declared_roles: &'a [Role],
    /// Original input string for error reporting
    pub input: &'a str,
}
413
/// Context provided during projection.
///
/// Passed to [`ProtocolExtension::project`].
#[derive(Debug)]
pub struct ProjectionContext<'a> {
    /// All roles in the choreography
    pub all_roles: &'a [Role],
    /// Current role being projected
    pub current_role: &'a Role,
}
422
423/// Context provided during code generation
424#[derive(Debug)]
425pub struct CodegenContext<'a> {
426    /// The choreography being generated
427    pub choreography_name: &'a str,
428    /// All roles in the choreography
429    pub roles: &'a [Role],
430    /// Namespace for generated code
431    pub namespace: Option<&'a str>,
432}
433
434impl<'a> Default for CodegenContext<'a> {
435    fn default() -> Self {
436        Self {
437            choreography_name: "Default",
438            roles: &[],
439            namespace: None,
440        }
441    }
442}
443
/// Errors that can occur during extension parsing and registration.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    /// Generic syntax error carrying a free-form message.
    #[error("Syntax error: {message}")]
    Syntax { message: String },

    /// A role name used by an extension statement is not declared (presumably checked
    /// against `ParseContext::declared_roles` by statement parsers).
    #[error("Unknown role '{role}' used in extension")]
    UnknownRole { role: String },

    /// Extension-specific syntax is malformed.
    #[error("Invalid extension syntax: {details}")]
    InvalidSyntax { details: String },

    /// Generic extension conflict with a free-form message.
    #[error("Extension conflict: {message}")]
    Conflict { message: String },

    /// Returned by `ExtensionRegistry::register_grammar` when a rule is already owned by an
    /// extension of the same priority. `extension1`/`priority1` describe the current owner,
    /// `extension2`/`priority2` the extension being registered.
    #[error("Extension priority conflict: Extension '{extension1}' (priority {priority1}) conflicts with '{extension2}' (priority {priority2}) for rule '{rule}'. Consider adjusting priorities or using different rule names.")]
    PriorityConflict {
        extension1: String,
        extension2: String,
        priority1: u32,
        priority2: u32,
        rule: String,
    },

    /// Returned by `ExtensionRegistry::validate_dependencies` when `extension` requires
    /// `dependency` but no grammar extension with that id is registered.
    #[error("Missing dependency: Extension '{extension}' requires '{dependency}' which is not registered. Please register the required extension first.")]
    MissingDependency {
        extension: String,
        dependency: String,
    },

    /// Registration of `extension` failed for `rule`; `details` explains why.
    #[error("Extension registration failed: Extension '{extension}' with rule '{rule}' cannot be registered. {details}")]
    RegistrationFailed {
        extension: String,
        rule: String,
        details: String,
    },

    /// Returned by `ExtensionRegistry::check_compatibility` when two listed extensions
    /// define the same rule.
    #[error("Incompatible extensions: {details}")]
    IncompatibleExtensions { details: String },
}
484
/// Validation errors for protocol extensions.
///
/// Returned by [`ProtocolExtension::validate`].
#[derive(Debug, thiserror::Error)]
pub enum ExtensionValidationError {
    /// A role used by the protocol is not among the declared roles.
    #[error("Role '{role}' not declared")]
    UndeclaredRole { role: String },

    /// The protocol is structurally invalid; `reason` explains why.
    #[error("Invalid protocol structure: {reason}")]
    InvalidStructure { reason: String },

    /// Any other extension-specific validation failure.
    #[error("Extension validation failed: {message}")]
    ExtensionFailed { message: String },
}
497
/// Convenience macro for registering a grammar extension.
///
/// Expands to `$registry.register_grammar($extension)` and evaluates to its
/// `Result<(), ParseError>`. Callers can therefore propagate registration failures (for
/// example [`ParseError::PriorityConflict`]) with `?` instead of having them silently
/// dropped.
///
/// ```ignore
/// let mut registry = ExtensionRegistry::new();
/// register_extension!(registry, MyExtension)?;
/// ```
#[macro_export]
macro_rules! register_extension {
    ($registry:expr, $extension:expr) => {{
        // Bind the extension first so it is evaluated before the registry expression.
        let ext = $extension;
        $registry.register_grammar(ext)
    }};
}
507
/// Utility trait for easy extension registration.
///
/// `register_all` is an associated function (it takes no `self`): an implementing type
/// adds its grammar extensions and/or parsers to `registry`.
/// NOTE(review): nothing in this module calls it; confirm the intended call site.
pub trait RegisterExtension {
    fn register_all(registry: &mut ExtensionRegistry);
}
512
513pub mod discovery;
514/// Built-in extensions
515pub mod timeout;
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Grammar extension claiming `timeout_stmt` at the default priority (100).
    #[derive(Debug)]
    struct MockGrammarExtension;

    impl GrammarExtension for MockGrammarExtension {
        fn grammar_rules(&self) -> &'static str {
            "timeout_stmt = { \"timeout\" ~ integer ~ protocol_block }"
        }

        fn statement_rules(&self) -> Vec<&'static str> {
            vec!["timeout_stmt"]
        }

        fn extension_id(&self) -> &'static str {
            "mock_timeout"
        }
    }

    #[test]
    fn test_extension_registry() {
        let mut registry = ExtensionRegistry::new();
        registry
            .register_grammar(MockGrammarExtension)
            .expect("extension registration should succeed");

        // Only rules claimed by a registered extension are routed to it.
        assert!(registry.can_handle("timeout_stmt"));
        assert!(!registry.can_handle("unknown_rule"));

        // The composed grammar keeps the base rules and appends the extension's rules.
        let composed = registry.compose_grammar("basic_rule = { \"test\" }");
        for expected in &["basic_rule", "timeout_stmt"] {
            assert!(
                composed.contains(*expected),
                "composed grammar should contain `{}`",
                expected
            );
        }
    }

    #[test]
    fn test_enhanced_error_messages() {
        let priority_conflict = ParseError::PriorityConflict {
            extension1: "ext1".to_string(),
            extension2: "ext2".to_string(),
            priority1: 100,
            priority2: 100,
            rule: "test_rule".to_string(),
        };
        let missing_dependency = ParseError::MissingDependency {
            extension: "dependent_ext".to_string(),
            dependency: "required_ext".to_string(),
        };
        let incompatible = ParseError::IncompatibleExtensions {
            details: "Test incompatibility".to_string(),
        };

        // Each rendered message must carry its actionable hint.
        let cases = [
            (priority_conflict, "Consider adjusting priorities"),
            (
                missing_dependency,
                "Please register the required extension first",
            ),
            (incompatible, "Incompatible extensions"),
        ];
        for (err, hint) in cases.iter() {
            assert!(
                err.to_string().contains(*hint),
                "`{}` should mention `{}`",
                err,
                hint
            );
        }
    }

    #[test]
    fn test_detailed_conflicts() {
        #[derive(Debug)]
        struct HighPriorityExt;
        impl GrammarExtension for HighPriorityExt {
            fn grammar_rules(&self) -> &'static str {
                "rule1 = { \"test1\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["rule1"]
            }
            fn priority(&self) -> u32 {
                200
            }
            fn extension_id(&self) -> &'static str {
                "test_ext1"
            }
        }

        #[derive(Debug)]
        struct LowPriorityExt;
        impl GrammarExtension for LowPriorityExt {
            fn grammar_rules(&self) -> &'static str {
                "rule1 = { \"test2\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["rule1"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "test_ext2"
            }
        }

        let mut registry = ExtensionRegistry::new();
        // The lower-priority extension claims the rule first, then gets overridden.
        registry
            .register_grammar(LowPriorityExt)
            .expect("lower priority extension should register");
        registry
            .register_grammar(HighPriorityExt)
            .expect("higher priority extension should override");

        let conflicts = registry.get_detailed_conflicts();
        let first = conflicts
            .first()
            .expect("an override should be reported as a conflict");
        assert!(first.contains("overrode"));
        assert!(first.contains("priority"));
    }

    #[test]
    fn test_documentation_system() {
        let mut registry = ExtensionRegistry::new();
        // Pre-seed a version; registration must keep it instead of the "0.1.0" default.
        registry
            .extension_versions
            .insert("mock_timeout".to_string(), "1.0.0".to_string());
        registry
            .register_grammar(MockGrammarExtension)
            .expect("grammar extension should register");

        let docs = registry.generate_docs();
        let expected_fragments = [
            "# Extension Documentation",
            "mock_timeout",
            "**Priority:** 100",
            "**Version:** 1.0.0",
        ];
        for fragment in expected_fragments.iter() {
            assert!(docs.contains(*fragment), "docs should contain `{}`", fragment);
        }

        assert_eq!(
            registry
                .extension_versions
                .get("mock_timeout")
                .map(String::as_str),
            Some("1.0.0")
        );
    }

    #[test]
    fn test_compose_grammar_is_stable_for_equal_priorities() {
        #[derive(Debug)]
        struct AlphaExt;
        impl GrammarExtension for AlphaExt {
            fn grammar_rules(&self) -> &'static str {
                "alpha_stmt = { \"alpha\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["alpha_stmt"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "alpha_ext"
            }
        }

        #[derive(Debug)]
        struct BetaExt;
        impl GrammarExtension for BetaExt {
            fn grammar_rules(&self) -> &'static str {
                "beta_stmt = { \"beta\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["beta_stmt"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "beta_ext"
            }
        }

        let mut registry = ExtensionRegistry::new();
        // Register out of alphabetical order; ties must still be emitted sorted by id.
        registry
            .register_grammar(BetaExt)
            .expect("beta extension should register");
        registry
            .register_grammar(AlphaExt)
            .expect("alpha extension should register");

        let composed = registry.compose_grammar("base = { \"x\" }");
        let position = |needle: &str| {
            composed
                .find(needle)
                .unwrap_or_else(|| panic!("`{}` missing from composed grammar", needle))
        };
        assert!(position("alpha_stmt") < position("beta_stmt"));
    }

    #[test]
    fn test_parse_context() {
        use proc_macro2::{Ident, Span};

        let role = |name: &str| Role::new(Ident::new(name, Span::call_site())).unwrap();
        let roles = vec![role("Alice"), role("Bob")];

        let context = ParseContext {
            declared_roles: &roles,
            input: "test input",
        };

        assert_eq!(context.declared_roles.len(), 2);
        assert_eq!(context.input, "test input");
    }
}
723
724        assert_eq!(context.declared_roles.len(), 2);
725        assert_eq!(context.input, "test input");
726    }
727}