1use crate::ast::{LocalType, Role};
8use crate::compiler::projection::ProjectionError;
9use std::any::{Any, TypeId};
10use std::collections::BTreeMap;
11use std::fmt::Debug;
12
/// Human-readable documentation that a grammar extension can attach to itself.
///
/// All fields are plain owned data, so the type is cheap to clone and can be
/// compared for equality (useful when asserting on generated documentation).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtensionDocumentation {
    /// Overview text for the extension.
    pub overview: String,
    /// Guide to the syntax the extension introduces.
    pub syntax_guide: String,
    /// Intended use cases, one entry per item.
    pub use_cases: Vec<String>,
    /// Known limitations, one entry per item.
    pub limitations: Vec<String>,
    /// Related references, one entry per item.
    pub see_also: Vec<String>,
}

impl Default for ExtensionDocumentation {
    /// Placeholder documentation: fixed "not provided" texts and empty lists.
    fn default() -> Self {
        Self {
            overview: String::from("No documentation provided"),
            syntax_guide: String::from("No syntax guide provided"),
            use_cases: Vec::new(),
            limitations: Vec::new(),
            see_also: Vec::new(),
        }
    }
}
34
/// A titled, worked example for a grammar extension.
///
/// Derives `PartialEq`/`Eq` alongside `Clone` so examples can be compared
/// directly, e.g. in tests over documentation output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtensionExample {
    /// Short title of the example.
    pub title: String,
    /// Prose describing what the example demonstrates.
    pub description: String,
    /// Example source text.
    pub code: String,
    /// Expected output, when the example specifies one.
    pub expected_output: Option<String>,
}
43
/// A pluggable fragment of grammar that adds new statement rules.
///
/// Implementors must be `Send + Sync + Debug` so they can be stored as
/// `Box<dyn GrammarExtension>` inside a shared registry.
pub trait GrammarExtension: Send + Sync + Debug {
    /// Stable identifier of this extension; used as its registry key.
    fn extension_id(&self) -> &'static str;

    /// Grammar source text contributed by this extension.
    fn grammar_rules(&self) -> &'static str;

    /// Names of the statement rules this extension defines.
    fn statement_rules(&self) -> Vec<&'static str>;

    /// Priority used to settle rule-name clashes between extensions.
    ///
    /// Defaults to `100`; the higher value wins a contested rule.
    fn priority(&self) -> u32 {
        100
    }
}
60
61pub trait DocumentedGrammarExtension: GrammarExtension {
63 fn documentation(&self) -> ExtensionDocumentation {
65 ExtensionDocumentation::default()
66 }
67
68 fn examples(&self) -> Vec<ExtensionExample> {
70 vec![]
71 }
72
73 fn rule_descriptions(&self) -> std::collections::HashMap<String, String> {
75 std::collections::HashMap::new()
76 }
77}
78
79pub trait StatementParser: Send + Sync + Debug {
81 fn can_parse(&self, rule_name: &str) -> bool;
83
84 fn supported_rules(&self) -> Vec<String>;
86
87 fn parse_statement(
97 &self,
98 rule_name: &str,
99 content: &str,
100 context: &ParseContext,
101 ) -> Result<Box<dyn ProtocolExtension>, ParseError>;
102}
103
104pub trait ProtocolExtension: Send + Sync + Debug {
106 fn type_name(&self) -> &'static str;
108
109 fn mentions_role(&self, role: &Role) -> bool;
111
112 fn validate(&self, roles: &[Role]) -> Result<(), ExtensionValidationError>;
114
115 fn project(
117 &self,
118 role: &Role,
119 context: &ProjectionContext,
120 ) -> Result<LocalType, ProjectionError>;
121
122 fn generate_code(&self, context: &CodegenContext) -> proc_macro2::TokenStream;
124
125 fn as_any(&self) -> &dyn Any;
127 fn as_any_mut(&mut self) -> &mut dyn Any;
128 fn type_id(&self) -> TypeId;
129 fn clone_box(&self) -> Box<dyn ProtocolExtension>;
130}
131
132impl Clone for Box<dyn ProtocolExtension> {
133 fn clone(&self) -> Self {
134 self.clone_box()
135 }
136}
137
138#[derive(Debug, Default)]
140pub struct ExtensionRegistry {
141 grammar_extensions: BTreeMap<String, Box<dyn GrammarExtension>>,
142 statement_parsers: BTreeMap<String, Box<dyn StatementParser>>,
143 rule_to_parser: BTreeMap<String, String>,
144 rule_conflicts: BTreeMap<String, Vec<String>>,
146 extension_dependencies: BTreeMap<String, Vec<String>>,
148 extension_versions: BTreeMap<String, String>,
150}
151
152impl ExtensionRegistry {
153 pub fn new() -> Self {
155 Self::default()
156 }
157
158 pub fn register_grammar<T: GrammarExtension + 'static>(
160 &mut self,
161 extension: T,
162 ) -> Result<(), ParseError> {
163 let id = extension.extension_id().to_string();
164 let rules = extension.statement_rules();
165 let priority = extension.priority();
166
167 for rule in &rules {
169 if let Some(existing_id) = self.rule_to_parser.get(*rule) {
170 let existing_priority = self
171 .grammar_extensions
172 .get(existing_id)
173 .map(|e| e.priority())
174 .unwrap_or(0);
175
176 if priority > existing_priority {
177 self.rule_conflicts
179 .entry((*rule).to_string())
180 .or_default()
181 .push(existing_id.clone());
182 self.rule_to_parser.insert((*rule).to_string(), id.clone());
183 } else if priority == existing_priority {
184 return Err(ParseError::PriorityConflict {
186 extension1: existing_id.clone(),
187 extension2: id.clone(),
188 priority1: existing_priority,
189 priority2: priority,
190 rule: (*rule).to_string(),
191 });
192 }
193 } else {
195 self.rule_to_parser.insert((*rule).to_string(), id.clone());
196 }
197 }
198
199 self.grammar_extensions
200 .insert(id.clone(), Box::new(extension));
201 self.extension_versions
203 .entry(id)
204 .or_insert_with(|| "0.1.0".to_string());
205 Ok(())
206 }
207
208 pub fn register_parser<T: StatementParser + 'static>(&mut self, parser: T, parser_id: String) {
210 self.statement_parsers.insert(parser_id, Box::new(parser));
211 }
212
213 pub fn compose_grammar(&self, base_grammar: &str) -> String {
215 let mut composed = base_grammar.to_string();
216
217 let mut extensions: Vec<_> = self.grammar_extensions.iter().collect();
219 extensions.sort_by(|(id_a, ext_a), (id_b, ext_b)| {
220 std::cmp::Reverse(ext_a.priority())
221 .cmp(&std::cmp::Reverse(ext_b.priority()))
222 .then_with(|| id_a.cmp(id_b))
223 });
224
225 for (_, extension) in extensions {
226 composed.push('\n');
227 composed.push_str(extension.grammar_rules());
228 }
229
230 composed
231 }
232
233 pub fn find_parser(&self, rule_name: &str) -> Option<&dyn StatementParser> {
235 if let Some(parser_id) = self.rule_to_parser.get(rule_name) {
236 self.statement_parsers.get(parser_id).map(|p| p.as_ref())
237 } else {
238 None
239 }
240 }
241
242 pub fn can_handle(&self, rule_name: &str) -> bool {
244 self.rule_to_parser.contains_key(rule_name)
245 }
246
247 pub fn has_extensions(&self) -> bool {
249 !self.grammar_extensions.is_empty() || !self.statement_parsers.is_empty()
250 }
251
252 pub fn grammar_extensions(&self) -> impl Iterator<Item = &dyn GrammarExtension> {
254 let mut ordered: Vec<_> = self.grammar_extensions.iter().collect();
255 ordered.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
256 ordered.into_iter().map(|(_, e)| e.as_ref())
257 }
258
259 pub fn has_extension(&self, extension_id: &str) -> bool {
261 self.grammar_extensions.contains_key(extension_id)
262 }
263
264 pub fn get_parser_for_rule(&self, rule_name: &str) -> Option<&str> {
266 self.rule_to_parser.get(rule_name).map(String::as_str)
267 }
268
269 pub fn get_statement_parser(&self, parser_id: &str) -> Option<&dyn StatementParser> {
271 self.statement_parsers.get(parser_id).map(|p| p.as_ref())
272 }
273
274 pub fn statement_parser_count(&self) -> usize {
276 self.statement_parsers.len()
277 }
278
279 pub fn statement_rules(&self) -> Vec<&str> {
281 let mut rules: Vec<_> = self.rule_to_parser.keys().map(String::as_str).collect();
282 rules.sort_unstable();
283 rules
284 }
285
286 pub fn add_dependency(&mut self, dependent: &str, required: &str) {
288 self.extension_dependencies
289 .entry(dependent.to_string())
290 .or_default()
291 .push(required.to_string());
292 }
293
294 pub fn validate_dependencies(&self) -> Result<(), ParseError> {
296 for (dependent, requirements) in &self.extension_dependencies {
297 for required in requirements {
298 if !self.grammar_extensions.contains_key(required) {
299 return Err(ParseError::MissingDependency {
300 extension: dependent.clone(),
301 dependency: required.clone(),
302 });
303 }
304 }
305 }
306 Ok(())
307 }
308
309 pub fn get_conflicts(&self) -> &BTreeMap<String, Vec<String>> {
311 &self.rule_conflicts
312 }
313
314 pub fn get_detailed_conflicts(&self) -> Vec<String> {
316 let mut details = Vec::new();
317 let unknown_ext = "unknown".to_string();
318
319 let mut conflicts: Vec<_> = self.rule_conflicts.iter().collect();
320 conflicts.sort_by(|(rule_a, _), (rule_b, _)| rule_a.cmp(rule_b));
321
322 for (rule, conflicting_extensions) in conflicts {
323 if !conflicting_extensions.is_empty() {
324 let active_extension = self.rule_to_parser.get(rule).unwrap_or(&unknown_ext);
325 let active_priority = self
326 .grammar_extensions
327 .get(active_extension)
328 .map(|e| e.priority())
329 .unwrap_or(0);
330
331 let mut conflicting_extensions = conflicting_extensions.clone();
332 conflicting_extensions.sort();
333
334 for conflicting in &conflicting_extensions {
335 let conflicting_priority = self
336 .grammar_extensions
337 .get(conflicting)
338 .map(|e| e.priority())
339 .unwrap_or(0);
340
341 details.push(format!(
342 "Rule '{}': Extension '{}' (priority {}) overrode '{}' (priority {}). \
343 To resolve: 1) Adjust priorities, 2) Use different rule names, or 3) Merge functionality.",
344 rule, active_extension, active_priority, conflicting, conflicting_priority
345 ));
346 }
347 }
348 }
349
350 details
351 }
352
353 pub fn check_compatibility(&self, extension_ids: &[&str]) -> Result<(), ParseError> {
355 let mut rules_used = BTreeMap::new();
357
358 for &extension_id in extension_ids {
359 if let Some(extension) = self.grammar_extensions.get(extension_id) {
360 for rule in extension.statement_rules() {
361 if let Some(existing) = rules_used.get(rule) {
362 if existing != &extension_id {
363 return Err(ParseError::IncompatibleExtensions {
364 details: format!(
365 "Extensions '{}' and '{}' both define rule '{}'. Use different rule names or register extensions with different priorities.",
366 existing, extension_id, rule
367 ),
368 });
369 }
370 }
371 rules_used.insert(rule.to_string(), extension_id);
372 }
373 }
374 }
375 Ok(())
376 }
377
378 pub fn with_builtin_extensions() -> Self {
380 let mut registry = Self::new();
381
382 registry
384 .register_grammar(timeout::TimeoutGrammarExtension)
385 .expect("builtin timeout extension should register successfully");
386 registry.register_parser(timeout::TimeoutStatementParser, "timeout".to_string());
387
388 registry
389 }
390
391 pub fn for_third_party() -> Self {
393 Self::new()
394 }
395
396 pub fn generate_docs(&self) -> String {
398 let mut docs = String::from("# Extension Documentation\n\n");
399
400 let mut entries: Vec<_> = self.grammar_extensions.iter().collect();
401 entries.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
402
403 for (id, extension) in entries {
404 docs.push_str(&format!("## {}\n\n", id));
405 docs.push_str(&format!("**Priority:** {}\n\n", extension.priority()));
406 docs.push_str(&format!(
407 "**Rules:** {}\n\n",
408 extension.statement_rules().join(", ")
409 ));
410
411 if let Some(version) = self.extension_versions.get(id) {
412 docs.push_str(&format!("**Version:** {}\n\n", version));
413 }
414
415 docs.push_str("**Grammar:**\n```\n");
416 docs.push_str(extension.grammar_rules());
417 docs.push_str("\n```\n\n");
418 }
419
420 docs
421 }
422}
423
424#[derive(Debug)]
426pub struct ParseContext<'a> {
427 pub declared_roles: &'a [Role],
429 pub input: &'a str,
431}
432
433#[derive(Debug)]
435pub struct ProjectionContext<'a> {
436 pub all_roles: &'a [Role],
438 pub current_role: &'a Role,
440}
441
442#[derive(Debug)]
444pub struct CodegenContext<'a> {
445 pub choreography_name: &'a str,
447 pub roles: &'a [Role],
449 pub namespace: Option<&'a str>,
451}
452
453impl<'a> Default for CodegenContext<'a> {
454 fn default() -> Self {
455 Self {
456 choreography_name: "Default",
457 roles: &[],
458 namespace: None,
459 }
460 }
461}
462
463#[derive(Debug, thiserror::Error)]
465pub enum ParseError {
466 #[error("Syntax error: {message}")]
467 Syntax { message: String },
468
469 #[error("Unknown role '{role}' used in extension")]
470 UnknownRole { role: String },
471
472 #[error("Invalid extension syntax: {details}")]
473 InvalidSyntax { details: String },
474
475 #[error("Extension conflict: {message}")]
476 Conflict { message: String },
477
478 #[error("Extension priority conflict: Extension '{extension1}' (priority {priority1}) conflicts with '{extension2}' (priority {priority2}) for rule '{rule}'. Consider adjusting priorities or using different rule names.")]
479 PriorityConflict {
480 extension1: String,
481 extension2: String,
482 priority1: u32,
483 priority2: u32,
484 rule: String,
485 },
486
487 #[error("Missing dependency: Extension '{extension}' requires '{dependency}' which is not registered. Please register the required extension first.")]
488 MissingDependency {
489 extension: String,
490 dependency: String,
491 },
492
493 #[error("Extension registration failed: Extension '{extension}' with rule '{rule}' cannot be registered. {details}")]
494 RegistrationFailed {
495 extension: String,
496 rule: String,
497 details: String,
498 },
499
500 #[error("Incompatible extensions: {details}")]
501 IncompatibleExtensions { details: String },
502}
503
504#[derive(Debug, thiserror::Error)]
506pub enum ExtensionValidationError {
507 #[error("Role '{role}' not declared")]
508 UndeclaredRole { role: String },
509
510 #[error("Invalid protocol structure: {reason}")]
511 InvalidStructure { reason: String },
512
513 #[error("Extension validation failed: {message}")]
514 ExtensionFailed { message: String },
515}
516
/// Registers a grammar extension with a registry value.
///
/// Expands to `$registry.register_grammar($extension)` and evaluates to that
/// call's `Result`, so registration failures can be propagated with `?` or
/// inspected instead of being silently discarded. The extension expression is
/// evaluated before the registry expression. A trailing comma is accepted.
#[macro_export]
macro_rules! register_extension {
    ($registry:expr, $extension:expr $(,)?) => {{
        let ext = $extension;
        $registry.register_grammar(ext)
    }};
}
526
527pub trait RegisterExtension {
529 fn register_all(registry: &mut ExtensionRegistry);
530}
531
532pub mod discovery;
533pub mod timeout;
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double contributing a single `timeout_stmt` rule at the default priority.
    #[derive(Debug)]
    struct MockGrammarExtension;

    impl GrammarExtension for MockGrammarExtension {
        fn extension_id(&self) -> &'static str {
            "mock_timeout"
        }

        fn grammar_rules(&self) -> &'static str {
            "timeout_stmt = { \"timeout\" ~ integer ~ protocol_block }"
        }

        fn statement_rules(&self) -> Vec<&'static str> {
            vec!["timeout_stmt"]
        }
    }

    /// Data-driven grammar extension so each test can declare its id, grammar,
    /// rules and priority inline.
    #[derive(Debug)]
    struct FixedExt {
        id: &'static str,
        grammar: &'static str,
        rules: &'static [&'static str],
        priority: u32,
    }

    impl GrammarExtension for FixedExt {
        fn grammar_rules(&self) -> &'static str {
            self.grammar
        }

        fn statement_rules(&self) -> Vec<&'static str> {
            self.rules.to_vec()
        }

        fn priority(&self) -> u32 {
            self.priority
        }

        fn extension_id(&self) -> &'static str {
            self.id
        }
    }

    /// Registration makes the rule handleable, and composition keeps both the
    /// base grammar and the extension's grammar.
    #[test]
    fn test_extension_registry() {
        let mut reg = ExtensionRegistry::new();
        reg.register_grammar(MockGrammarExtension)
            .expect("extension registration should succeed");

        assert!(reg.can_handle("timeout_stmt"));
        assert!(!reg.can_handle("unknown_rule"));

        let grammar = reg.compose_grammar("basic_rule = { \"test\" }");
        assert!(grammar.contains("basic_rule"));
        assert!(grammar.contains("timeout_stmt"));
    }

    /// Error messages include their actionable hints.
    #[test]
    fn test_enhanced_error_messages() {
        use crate::extensions::ParseError;

        let rendered = ParseError::PriorityConflict {
            extension1: "ext1".to_string(),
            extension2: "ext2".to_string(),
            priority1: 100,
            priority2: 100,
            rule: "test_rule".to_string(),
        }
        .to_string();
        assert!(rendered.contains("Consider adjusting priorities"));

        let rendered = ParseError::MissingDependency {
            extension: "dependent_ext".to_string(),
            dependency: "required_ext".to_string(),
        }
        .to_string();
        assert!(rendered.contains("Please register the required extension first"));

        let rendered = ParseError::IncompatibleExtensions {
            details: "Test incompatibility".to_string(),
        }
        .to_string();
        assert!(rendered.contains("Incompatible extensions"));
    }

    /// A higher-priority extension overriding a rule shows up in the
    /// detailed conflict report.
    #[test]
    fn test_detailed_conflicts() {
        let low = FixedExt {
            id: "test_ext2",
            grammar: "rule1 = { \"test2\" }",
            rules: &["rule1"],
            priority: 100,
        };
        let high = FixedExt {
            id: "test_ext1",
            grammar: "rule1 = { \"test1\" }",
            rules: &["rule1"],
            priority: 200,
        };

        let mut reg = ExtensionRegistry::new();
        reg.register_grammar(low)
            .expect("lower priority extension should register");
        reg.register_grammar(high)
            .expect("higher priority extension should override");

        let report = reg.get_detailed_conflicts();
        assert!(!report.is_empty());
        assert!(report[0].contains("overrode"));
        assert!(report[0].contains("priority"));
    }

    /// Generated docs include priority and a pre-seeded version, and
    /// registration does not overwrite that version.
    #[test]
    fn test_documentation_system() {
        let mut reg = ExtensionRegistry::new();
        reg.extension_versions
            .insert("mock_timeout".to_string(), "1.0.0".to_string());
        reg.register_grammar(MockGrammarExtension)
            .expect("grammar extension should register");

        let docs = reg.generate_docs();
        for expected in [
            "# Extension Documentation",
            "mock_timeout",
            "**Priority:** 100",
            "**Version:** 1.0.0",
        ] {
            assert!(docs.contains(expected));
        }

        assert_eq!(
            reg.extension_versions.get("mock_timeout"),
            Some(&"1.0.0".to_string())
        );
    }

    /// Equal-priority extensions are composed in id order, not registration order.
    #[test]
    fn test_compose_grammar_is_stable_for_equal_priorities() {
        let beta = FixedExt {
            id: "beta_ext",
            grammar: "beta_stmt = { \"beta\" }",
            rules: &["beta_stmt"],
            priority: 100,
        };
        let alpha = FixedExt {
            id: "alpha_ext",
            grammar: "alpha_stmt = { \"alpha\" }",
            rules: &["alpha_stmt"],
            priority: 100,
        };

        let mut reg = ExtensionRegistry::new();
        reg.register_grammar(beta).unwrap();
        reg.register_grammar(alpha).unwrap();

        let grammar = reg.compose_grammar("base = { \"x\" }");
        let alpha_at = grammar.find("alpha_stmt").unwrap();
        let beta_at = grammar.find("beta_stmt").unwrap();
        assert!(alpha_at < beta_at);
    }

    /// `ParseContext` exposes the roles and input it was built from.
    #[test]
    fn test_parse_context() {
        use proc_macro2::{Ident, Span};

        let roles: Vec<Role> = ["Alice", "Bob"]
            .iter()
            .map(|&name| Role::new(Ident::new(name, Span::call_site())).unwrap())
            .collect();

        let context = ParseContext {
            declared_roles: &roles,
            input: "test input",
        };

        assert_eq!(context.declared_roles.len(), 2);
        assert_eq!(context.input, "test input");
    }
}