1use crate::ast::{LocalType, Role};
8use crate::compiler::projection::ProjectionError;
9use std::any::{Any, TypeId};
10use std::collections::BTreeMap;
11use std::fmt::Debug;
12
/// Human-readable documentation attached to a grammar extension.
///
/// Produced by [`DocumentedGrammarExtension::documentation`]. The [`Default`]
/// value carries placeholder text in the prose fields and empty lists.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtensionDocumentation {
    /// Short description of what the extension adds.
    pub overview: String,
    /// Explanation of the syntax the extension introduces.
    pub syntax_guide: String,
    /// Situations in which the extension is useful.
    pub use_cases: Vec<String>,
    /// Known restrictions or caveats of the extension.
    pub limitations: Vec<String>,
    /// Pointers to related extensions or documentation.
    pub see_also: Vec<String>,
}
22
23impl Default for ExtensionDocumentation {
24 fn default() -> Self {
25 Self {
26 overview: "No documentation provided".to_string(),
27 syntax_guide: "No syntax guide provided".to_string(),
28 use_cases: vec![],
29 limitations: vec![],
30 see_also: vec![],
31 }
32 }
33}
34
/// A worked example showing how to use a grammar extension.
///
/// Returned by [`DocumentedGrammarExtension::examples`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExtensionExample {
    /// Short title of the example.
    pub title: String,
    /// What the example demonstrates.
    pub description: String,
    /// Source code of the example.
    pub code: String,
    /// Expected output of the example, if one is documented.
    pub expected_output: Option<String>,
}
43
/// A plug-in that contributes grammar rules to the choreography grammar.
///
/// Extensions are stored as `Box<dyn GrammarExtension>` inside an
/// [`ExtensionRegistry`], keyed by [`GrammarExtension::extension_id`].
pub trait GrammarExtension: Send + Sync + Debug {
    /// Grammar source text that `ExtensionRegistry::compose_grammar` appends
    /// to the base grammar.
    fn grammar_rules(&self) -> &'static str;

    /// Names of the statement rules this extension claims in the registry.
    fn statement_rules(&self) -> Vec<&'static str>;

    /// Priority used when two extensions declare the same rule: the higher
    /// value owns the rule, and equal values are a registration error.
    /// Defaults to `100`.
    fn priority(&self) -> u32 {
        100
    }

    /// Identifier used as this extension's registry key.
    fn extension_id(&self) -> &'static str;
}
60
/// Optional documentation hooks layered on top of [`GrammarExtension`].
///
/// Every method has a default implementation, so implementors override only
/// what they actually document.
pub trait DocumentedGrammarExtension: GrammarExtension {
    /// Prose documentation for the extension.
    ///
    /// Defaults to [`ExtensionDocumentation::default`], which is placeholder text.
    fn documentation(&self) -> ExtensionDocumentation {
        ExtensionDocumentation::default()
    }

    /// Worked examples of the extension's syntax; empty by default.
    fn examples(&self) -> Vec<ExtensionExample> {
        vec![]
    }

    /// Descriptions of the extension's rules; empty by default.
    ///
    /// NOTE(review): presumably keyed by rule name (as returned by
    /// `statement_rules`) — confirm with implementors.
    fn rule_descriptions(&self) -> std::collections::HashMap<String, String> {
        std::collections::HashMap::new()
    }
}
78
/// Turns the text matched by an extension rule into a [`ProtocolExtension`] node.
///
/// Parsers are stored as `Box<dyn StatementParser>` in an [`ExtensionRegistry`].
pub trait StatementParser: Send + Sync + Debug {
    /// Returns `true` if this parser handles the grammar rule `rule_name`.
    fn can_parse(&self, rule_name: &str) -> bool;

    /// Names of the grammar rules this parser handles.
    fn supported_rules(&self) -> Vec<String>;

    /// Parses `content`, the text matched by `rule_name`, into a boxed protocol
    /// extension.
    ///
    /// `context` exposes the declared roles and the parser input.
    ///
    /// # Errors
    /// Returns a [`ParseError`] when the statement cannot be parsed.
    fn parse_statement(
        &self,
        rule_name: &str,
        content: &str,
        context: &ParseContext,
    ) -> Result<Box<dyn ProtocolExtension>, ParseError>;
}
103
/// A protocol construct produced by a [`StatementParser`].
///
/// The construct is validated, projected onto each role's local type and
/// turned into generated code.
pub trait ProtocolExtension: Send + Sync + Debug {
    /// Static name identifying the concrete extension type.
    fn type_name(&self) -> &'static str;

    /// Returns `true` if `role` is referenced by this construct.
    fn mentions_role(&self, role: &Role) -> bool;

    /// Checks the construct against the choreography's declared `roles`.
    ///
    /// # Errors
    /// Returns an [`ExtensionValidationError`] if the construct is invalid.
    fn validate(&self, roles: &[Role]) -> Result<(), ExtensionValidationError>;

    /// Projects the construct onto the local view of `role`.
    ///
    /// # Errors
    /// Returns a `ProjectionError` if the construct cannot be projected.
    fn project(
        &self,
        role: &Role,
        context: &ProjectionContext,
    ) -> Result<LocalType, ProjectionError>;

    /// Emits the token stream implementing this construct.
    fn generate_code(&self, context: &CodegenContext) -> proc_macro2::TokenStream;

    /// Upcast to [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Mutable counterpart of `as_any`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Presumably the [`TypeId`] of the concrete implementor — confirm with implementors.
    ///
    /// NOTE(review): this name collides with `Any::type_id`.
    /// - With `Any` in scope, `boxed.type_id()` on a `Box<dyn ProtocolExtension>`
    ///   resolves to `Any::type_id` of the `Box` itself.
    /// - `(*boxed).type_id()` is ambiguous.
    ///
    /// Prefer `ProtocolExtension::type_id(&*boxed)`.
    fn type_id(&self) -> TypeId;
}
130
/// Central registry of grammar extensions, statement parsers and rule ownership.
///
/// All maps are `BTreeMap`s, so iterating any of them is ordered by key.
#[derive(Debug, Default)]
pub struct ExtensionRegistry {
    /// Registered grammar extensions, keyed by `GrammarExtension::extension_id`.
    grammar_extensions: BTreeMap<String, Box<dyn GrammarExtension>>,
    /// Registered statement parsers, keyed by the id passed to `register_parser`.
    statement_parsers: BTreeMap<String, Box<dyn StatementParser>>,
    /// Rule name -> id of the grammar extension that currently owns the rule.
    ///
    /// `find_parser` reuses this id as the statement-parser id.
    rule_to_parser: BTreeMap<String, String>,
    /// Rule name -> ids of extensions that lost the rule to a higher-priority extension.
    rule_conflicts: BTreeMap<String, Vec<String>>,
    /// Extension id -> ids of the extensions it requires (checked by `validate_dependencies`).
    extension_dependencies: BTreeMap<String, Vec<String>>,
    /// Extension id -> version string.
    ///
    /// `register_grammar` inserts `"0.1.0"` when no version is recorded yet.
    extension_versions: BTreeMap<String, String>,
}
144
145impl ExtensionRegistry {
146 pub fn new() -> Self {
148 Self::default()
149 }
150
151 pub fn register_grammar<T: GrammarExtension + 'static>(
153 &mut self,
154 extension: T,
155 ) -> Result<(), ParseError> {
156 let id = extension.extension_id().to_string();
157 let rules = extension.statement_rules();
158 let priority = extension.priority();
159
160 for rule in &rules {
162 if let Some(existing_id) = self.rule_to_parser.get(*rule) {
163 let existing_priority = self
164 .grammar_extensions
165 .get(existing_id)
166 .map(|e| e.priority())
167 .unwrap_or(0);
168
169 if priority > existing_priority {
170 self.rule_conflicts
172 .entry((*rule).to_string())
173 .or_default()
174 .push(existing_id.clone());
175 self.rule_to_parser.insert((*rule).to_string(), id.clone());
176 } else if priority == existing_priority {
177 return Err(ParseError::PriorityConflict {
179 extension1: existing_id.clone(),
180 extension2: id.clone(),
181 priority1: existing_priority,
182 priority2: priority,
183 rule: (*rule).to_string(),
184 });
185 }
186 } else {
188 self.rule_to_parser.insert((*rule).to_string(), id.clone());
189 }
190 }
191
192 self.grammar_extensions
193 .insert(id.clone(), Box::new(extension));
194 self.extension_versions
196 .entry(id)
197 .or_insert_with(|| "0.1.0".to_string());
198 Ok(())
199 }
200
201 pub fn register_parser<T: StatementParser + 'static>(&mut self, parser: T, parser_id: String) {
203 self.statement_parsers.insert(parser_id, Box::new(parser));
204 }
205
206 pub fn compose_grammar(&self, base_grammar: &str) -> String {
208 let mut composed = base_grammar.to_string();
209
210 let mut extensions: Vec<_> = self.grammar_extensions.iter().collect();
212 extensions.sort_by(|(id_a, ext_a), (id_b, ext_b)| {
213 std::cmp::Reverse(ext_a.priority())
214 .cmp(&std::cmp::Reverse(ext_b.priority()))
215 .then_with(|| id_a.cmp(id_b))
216 });
217
218 for (_, extension) in extensions {
219 composed.push('\n');
220 composed.push_str(extension.grammar_rules());
221 }
222
223 composed
224 }
225
226 pub fn find_parser(&self, rule_name: &str) -> Option<&dyn StatementParser> {
228 if let Some(parser_id) = self.rule_to_parser.get(rule_name) {
229 self.statement_parsers.get(parser_id).map(|p| p.as_ref())
230 } else {
231 None
232 }
233 }
234
235 pub fn can_handle(&self, rule_name: &str) -> bool {
237 self.rule_to_parser.contains_key(rule_name)
238 }
239
240 pub fn has_extensions(&self) -> bool {
242 !self.grammar_extensions.is_empty() || !self.statement_parsers.is_empty()
243 }
244
245 pub fn grammar_extensions(&self) -> impl Iterator<Item = &dyn GrammarExtension> {
247 let mut ordered: Vec<_> = self.grammar_extensions.iter().collect();
248 ordered.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
249 ordered.into_iter().map(|(_, e)| e.as_ref())
250 }
251
252 pub fn has_extension(&self, extension_id: &str) -> bool {
254 self.grammar_extensions.contains_key(extension_id)
255 }
256
257 pub fn get_parser_for_rule(&self, rule_name: &str) -> Option<&str> {
259 self.rule_to_parser.get(rule_name).map(String::as_str)
260 }
261
262 pub fn get_statement_parser(&self, parser_id: &str) -> Option<&dyn StatementParser> {
264 self.statement_parsers.get(parser_id).map(|p| p.as_ref())
265 }
266
267 pub fn add_dependency(&mut self, dependent: &str, required: &str) {
269 self.extension_dependencies
270 .entry(dependent.to_string())
271 .or_default()
272 .push(required.to_string());
273 }
274
275 pub fn validate_dependencies(&self) -> Result<(), ParseError> {
277 for (dependent, requirements) in &self.extension_dependencies {
278 for required in requirements {
279 if !self.grammar_extensions.contains_key(required) {
280 return Err(ParseError::MissingDependency {
281 extension: dependent.clone(),
282 dependency: required.clone(),
283 });
284 }
285 }
286 }
287 Ok(())
288 }
289
290 pub fn get_conflicts(&self) -> &BTreeMap<String, Vec<String>> {
292 &self.rule_conflicts
293 }
294
295 pub fn get_detailed_conflicts(&self) -> Vec<String> {
297 let mut details = Vec::new();
298 let unknown_ext = "unknown".to_string();
299
300 let mut conflicts: Vec<_> = self.rule_conflicts.iter().collect();
301 conflicts.sort_by(|(rule_a, _), (rule_b, _)| rule_a.cmp(rule_b));
302
303 for (rule, conflicting_extensions) in conflicts {
304 if !conflicting_extensions.is_empty() {
305 let active_extension = self.rule_to_parser.get(rule).unwrap_or(&unknown_ext);
306 let active_priority = self
307 .grammar_extensions
308 .get(active_extension)
309 .map(|e| e.priority())
310 .unwrap_or(0);
311
312 let mut conflicting_extensions = conflicting_extensions.clone();
313 conflicting_extensions.sort();
314
315 for conflicting in &conflicting_extensions {
316 let conflicting_priority = self
317 .grammar_extensions
318 .get(conflicting)
319 .map(|e| e.priority())
320 .unwrap_or(0);
321
322 details.push(format!(
323 "Rule '{}': Extension '{}' (priority {}) overrode '{}' (priority {}). \
324 To resolve: 1) Adjust priorities, 2) Use different rule names, or 3) Merge functionality.",
325 rule, active_extension, active_priority, conflicting, conflicting_priority
326 ));
327 }
328 }
329 }
330
331 details
332 }
333
334 pub fn check_compatibility(&self, extension_ids: &[&str]) -> Result<(), ParseError> {
336 let mut rules_used = BTreeMap::new();
338
339 for &extension_id in extension_ids {
340 if let Some(extension) = self.grammar_extensions.get(extension_id) {
341 for rule in extension.statement_rules() {
342 if let Some(existing) = rules_used.get(rule) {
343 if existing != &extension_id {
344 return Err(ParseError::IncompatibleExtensions {
345 details: format!(
346 "Extensions '{}' and '{}' both define rule '{}'. Use different rule names or register extensions with different priorities.",
347 existing, extension_id, rule
348 ),
349 });
350 }
351 }
352 rules_used.insert(rule.to_string(), extension_id);
353 }
354 }
355 }
356 Ok(())
357 }
358
359 pub fn with_builtin_extensions() -> Self {
361 let mut registry = Self::new();
362
363 registry
365 .register_grammar(timeout::TimeoutGrammarExtension)
366 .expect("builtin timeout extension should register successfully");
367 registry.register_parser(timeout::TimeoutStatementParser, "timeout".to_string());
368
369 registry
370 }
371
372 pub fn for_third_party() -> Self {
374 Self::new()
375 }
376
377 pub fn generate_docs(&self) -> String {
379 let mut docs = String::from("# Extension Documentation\n\n");
380
381 let mut entries: Vec<_> = self.grammar_extensions.iter().collect();
382 entries.sort_by(|(id_a, _), (id_b, _)| id_a.cmp(id_b));
383
384 for (id, extension) in entries {
385 docs.push_str(&format!("## {}\n\n", id));
386 docs.push_str(&format!("**Priority:** {}\n\n", extension.priority()));
387 docs.push_str(&format!(
388 "**Rules:** {}\n\n",
389 extension.statement_rules().join(", ")
390 ));
391
392 if let Some(version) = self.extension_versions.get(id) {
393 docs.push_str(&format!("**Version:** {}\n\n", version));
394 }
395
396 docs.push_str("**Grammar:**\n```\n");
397 docs.push_str(extension.grammar_rules());
398 docs.push_str("\n```\n\n");
399 }
400
401 docs
402 }
403}
404
/// Information passed to [`StatementParser::parse_statement`].
#[derive(Debug)]
pub struct ParseContext<'a> {
    /// Roles declared by the choreography being parsed.
    pub declared_roles: &'a [Role],
    /// Text available to the parser.
    ///
    /// NOTE(review): presumably the full choreography source rather than just the
    /// statement — confirm with callers.
    pub input: &'a str,
}
413
/// Information passed to [`ProtocolExtension::project`].
#[derive(Debug)]
pub struct ProjectionContext<'a> {
    /// All roles of the choreography.
    pub all_roles: &'a [Role],
    /// The role whose local view is being computed.
    pub current_role: &'a Role,
}
422
/// Information passed to [`ProtocolExtension::generate_code`].
#[derive(Debug)]
pub struct CodegenContext<'a> {
    /// Name of the choreography being generated.
    pub choreography_name: &'a str,
    /// Roles of the choreography.
    pub roles: &'a [Role],
    /// Optional namespace for generated code; `None` when not set.
    pub namespace: Option<&'a str>,
}
433
434impl<'a> Default for CodegenContext<'a> {
435 fn default() -> Self {
436 Self {
437 choreography_name: "Default",
438 roles: &[],
439 namespace: None,
440 }
441 }
442}
443
/// Errors raised while registering grammar extensions or parsing extension statements.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    /// Syntax error carrying a free-form message.
    #[error("Syntax error: {message}")]
    Syntax { message: String },

    /// An extension referenced a role named `role` that was not declared.
    #[error("Unknown role '{role}' used in extension")]
    UnknownRole { role: String },

    /// Extension text did not match the expected syntax.
    #[error("Invalid extension syntax: {details}")]
    InvalidSyntax { details: String },

    /// Extension conflict carrying a free-form message.
    #[error("Extension conflict: {message}")]
    Conflict { message: String },

    /// Two different extensions with equal priority declare the same `rule`.
    ///
    /// Returned by `ExtensionRegistry::register_grammar`; `extension1` is the
    /// existing owner and `extension2` the one being registered.
    #[error("Extension priority conflict: Extension '{extension1}' (priority {priority1}) conflicts with '{extension2}' (priority {priority2}) for rule '{rule}'. Consider adjusting priorities or using different rule names.")]
    PriorityConflict {
        extension1: String,
        extension2: String,
        priority1: u32,
        priority2: u32,
        rule: String,
    },

    /// `extension` depends on `dependency`, which is not registered.
    ///
    /// Returned by `ExtensionRegistry::validate_dependencies`.
    #[error("Missing dependency: Extension '{extension}' requires '{dependency}' which is not registered. Please register the required extension first.")]
    MissingDependency {
        extension: String,
        dependency: String,
    },

    /// Registration of `extension` for `rule` failed for the reason in `details`.
    ///
    /// Not constructed anywhere in this file.
    #[error("Extension registration failed: Extension '{extension}' with rule '{rule}' cannot be registered. {details}")]
    RegistrationFailed {
        extension: String,
        rule: String,
        details: String,
    },

    /// Two of the checked extensions declare the same rule.
    ///
    /// Returned by `ExtensionRegistry::check_compatibility`.
    #[error("Incompatible extensions: {details}")]
    IncompatibleExtensions { details: String },
}
484
/// Errors returned by [`ProtocolExtension::validate`].
#[derive(Debug, thiserror::Error)]
pub enum ExtensionValidationError {
    /// The construct references `role`, which the choreography does not declare.
    #[error("Role '{role}' not declared")]
    UndeclaredRole { role: String },

    /// The construct's structure is invalid for the given `reason`.
    #[error("Invalid protocol structure: {reason}")]
    InvalidStructure { reason: String },

    /// Validation failure carrying a free-form message.
    #[error("Extension validation failed: {message}")]
    ExtensionFailed { message: String },
}
497
/// Registers a grammar extension with a registry and returns the registration result.
///
/// Expands to `$registry.register_grammar($extension)`, evaluating `$extension`
/// before `$registry`. The `Result` is handed back to the caller instead of being
/// discarded, so registration conflicts are not silently lost. A trailing comma
/// is accepted.
///
/// ```ignore
/// register_extension!(registry, MyExtension)?;
/// ```
#[macro_export]
macro_rules! register_extension {
    ($registry:expr, $extension:expr $(,)?) => {{
        // Bind first to keep the original evaluation order (extension, then registry).
        let ext = $extension;
        $registry.register_grammar(ext)
    }};
}
507
/// Implemented by types that register a set of extensions in a single call.
pub trait RegisterExtension {
    /// Registers every extension this type provides into `registry`.
    fn register_all(registry: &mut ExtensionRegistry);
}
512
513pub mod discovery;
514pub mod timeout;
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal grammar extension claiming the single rule `timeout_stmt`
    /// with the default priority (100).
    #[derive(Debug)]
    struct MockGrammarExtension;

    impl GrammarExtension for MockGrammarExtension {
        fn grammar_rules(&self) -> &'static str {
            "timeout_stmt = { \"timeout\" ~ integer ~ protocol_block }"
        }

        fn statement_rules(&self) -> Vec<&'static str> {
            vec!["timeout_stmt"]
        }

        fn extension_id(&self) -> &'static str {
            "mock_timeout"
        }
    }

    /// Registering an extension makes its rules handleable and appends its
    /// grammar to the composed grammar.
    #[test]
    fn test_extension_registry() {
        let mut registry = ExtensionRegistry::new();

        // Registration into an empty registry cannot conflict.
        registry
            .register_grammar(MockGrammarExtension)
            .expect("extension registration should succeed");

        // Only declared rules are routed to the extension.
        assert!(registry.can_handle("timeout_stmt"));
        assert!(!registry.can_handle("unknown_rule"));

        let base = "basic_rule = { \"test\" }";
        // The composed grammar keeps the base grammar and adds the extension's rules.
        let composed = registry.compose_grammar(base);
        assert!(composed.contains("basic_rule"));
        assert!(composed.contains("timeout_stmt"));
    }

    /// The `Display` text of registry errors includes actionable guidance.
    #[test]
    fn test_enhanced_error_messages() {
        use crate::extensions::ParseError;

        // Equal-priority conflict suggests adjusting priorities.
        let err = ParseError::PriorityConflict {
            extension1: "ext1".to_string(),
            extension2: "ext2".to_string(),
            priority1: 100,
            priority2: 100,
            rule: "test_rule".to_string(),
        };
        assert!(err.to_string().contains("Consider adjusting priorities"));

        // Missing dependency tells the user to register it.
        let err = ParseError::MissingDependency {
            extension: "dependent_ext".to_string(),
            dependency: "required_ext".to_string(),
        };
        assert!(err
            .to_string()
            .contains("Please register the required extension first"));

        // Incompatibility errors carry the variant prefix.
        let err = ParseError::IncompatibleExtensions {
            details: "Test incompatibility".to_string(),
        };
        assert!(err.to_string().contains("Incompatible extensions"));
    }

    /// A higher-priority extension overriding a lower-priority one for the
    /// same rule is reported by `get_detailed_conflicts`.
    #[test]
    fn test_detailed_conflicts() {
        #[derive(Debug)]
        struct TestExt1;
        impl GrammarExtension for TestExt1 {
            fn grammar_rules(&self) -> &'static str {
                "rule1 = { \"test1\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["rule1"]
            }
            fn priority(&self) -> u32 {
                200
            }
            fn extension_id(&self) -> &'static str {
                "test_ext1"
            }
        }

        #[derive(Debug)]
        struct TestExt2;
        impl GrammarExtension for TestExt2 {
            fn grammar_rules(&self) -> &'static str {
                "rule1 = { \"test2\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["rule1"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "test_ext2"
            }
        }

        let mut registry = ExtensionRegistry::new();

        // Lower priority first, so the second registration overrides it.
        registry
            .register_grammar(TestExt2)
            .expect("lower priority extension should register");
        // Different priorities: override instead of an error.
        registry
            .register_grammar(TestExt1)
            .expect("higher priority extension should override");

        let conflicts = registry.get_detailed_conflicts();
        assert!(!conflicts.is_empty());
        assert!(conflicts[0].contains("overrode"));
        assert!(conflicts[0].contains("priority"));
    }

    /// Generated docs include id, priority and version. A version recorded
    /// before registration is not overwritten by the `"0.1.0"` default.
    #[test]
    fn test_documentation_system() {
        let mut registry = ExtensionRegistry::new();

        registry
            .extension_versions
            .insert("mock_timeout".to_string(), "1.0.0".to_string());
        registry
            .register_grammar(MockGrammarExtension)
            .expect("grammar extension should register");

        let docs = registry.generate_docs();
        // Header, section id, default priority and the pre-seeded version.
        assert!(docs.contains("# Extension Documentation"));
        assert!(docs.contains("mock_timeout"));
        assert!(docs.contains("**Priority:** 100"));
        assert!(docs.contains("**Version:** 1.0.0"));

        assert_eq!(
            registry.extension_versions.get("mock_timeout"),
            Some(&"1.0.0".to_string())
        );
    }

    /// Equal-priority extensions are composed in extension-id order,
    /// independent of registration order.
    #[test]
    fn test_compose_grammar_is_stable_for_equal_priorities() {
        #[derive(Debug)]
        struct AlphaExt;
        impl GrammarExtension for AlphaExt {
            fn grammar_rules(&self) -> &'static str {
                "alpha_stmt = { \"alpha\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["alpha_stmt"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "alpha_ext"
            }
        }

        #[derive(Debug)]
        struct BetaExt;
        impl GrammarExtension for BetaExt {
            fn grammar_rules(&self) -> &'static str {
                "beta_stmt = { \"beta\" }"
            }
            fn statement_rules(&self) -> Vec<&'static str> {
                vec!["beta_stmt"]
            }
            fn priority(&self) -> u32 {
                100
            }
            fn extension_id(&self) -> &'static str {
                "beta_ext"
            }
        }

        let mut registry = ExtensionRegistry::new();
        // Registered in reverse alphabetical order on purpose.
        registry.register_grammar(BetaExt).unwrap();
        registry.register_grammar(AlphaExt).unwrap();

        let composed = registry.compose_grammar("base = { \"x\" }");
        let alpha_idx = composed.find("alpha_stmt").unwrap();
        let beta_idx = composed.find("beta_stmt").unwrap();
        assert!(alpha_idx < beta_idx);
    }

    /// `ParseContext` exposes the roles and input it was built with.
    #[test]
    fn test_parse_context() {
        use proc_macro2::Span;
        let roles = vec![
            Role::new(proc_macro2::Ident::new("Alice", Span::call_site())).unwrap(),
            Role::new(proc_macro2::Ident::new("Bob", Span::call_site())).unwrap(),
        ];

        let context = ParseContext {
            declared_roles: &roles,
            input: "test input",
        };

        assert_eq!(context.declared_roles.len(), 2);
        assert_eq!(context.input, "test input");
    }
}