1use super::choreography::Choreography;
34use super::protocol::Protocol;
35use super::Annotations;
36use std::collections::HashMap;
37use std::fmt;
38
/// One step of an [`OperationPath`], identifying a protocol construct
/// relative to its parent.
///
/// Numeric variants carry a positional index; labelled variants carry the
/// label of the branch, selection or recursion point. The textual form
/// (see the `Display` impl) is `kind:value`, e.g. `send:0` or `branch:Accept`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationStep {
    /// Index of a send operation.
    Send(usize),
    /// Index of a receive operation.
    Recv(usize),
    /// A branch, identified by its label.
    Branch(String),
    /// A selection, identified by its label.
    Select(String),
    /// Index of a loop. `ExecutionHints::extract_from_protocol` also uses this
    /// variant to index the components of a parallel composition.
    Loop(usize),
    /// A recursion point, identified by its label.
    Rec(String),
}

impl OperationStep {
    /// Short tag used as the prefix of this step's textual form.
    fn kind(&self) -> &'static str {
        match self {
            Self::Send(_) => "send",
            Self::Recv(_) => "recv",
            Self::Branch(_) => "branch",
            Self::Select(_) => "select",
            Self::Loop(_) => "loop",
            Self::Rec(_) => "rec",
        }
    }
}

impl fmt::Display for OperationStep {
    /// Renders the step as `kind:value`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Send(index) | Self::Recv(index) | Self::Loop(index) => {
                write!(f, "{}:{}", self.kind(), index)
            }
            Self::Branch(label) | Self::Select(label) | Self::Rec(label) => {
                write!(f, "{}:{}", self.kind(), label)
            }
        }
    }
}

/// A location inside a protocol tree: the sequence of [`OperationStep`]s
/// taken from the root.
///
/// The empty path denotes the root and displays as `<root>`; otherwise the
/// steps are joined with `/`, e.g. `send:0/branch:Accept/recv:1`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct OperationPath(Vec<OperationStep>);

impl OperationPath {
    /// Creates the empty (root) path.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Wraps an explicit list of steps, ordered from the root outwards.
    #[must_use]
    pub fn from_steps(steps: Vec<OperationStep>) -> Self {
        Self(steps)
    }

    /// Returns a new path with `step` appended; `self` is left unchanged.
    #[must_use]
    pub fn push(&self, step: OperationStep) -> Self {
        let mut extended = Vec::with_capacity(self.0.len() + 1);
        extended.extend_from_slice(&self.0);
        extended.push(step);
        Self(extended)
    }

    /// The steps of this path, ordered from the root outwards.
    #[must_use]
    pub fn steps(&self) -> &[OperationStep] {
        self.0.as_slice()
    }

    /// `true` for the root path (no steps).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.steps().is_empty()
    }

    /// Number of steps in this path.
    #[must_use]
    pub fn len(&self) -> usize {
        self.steps().len()
    }
}

impl fmt::Display for OperationPath {
    /// Writes `<root>` for the empty path, otherwise the steps separated by `/`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.0.split_first() {
            None => f.write_str("<root>"),
            Some((first, rest)) => {
                write!(f, "{}", first)?;
                for step in rest {
                    write!(f, "/{}", step)?;
                }
                Ok(())
            }
        }
    }
}
129
/// Execution hints attached to a single protocol operation.
///
/// The `Default` value carries no hints: both flags `false`, no minimum.
#[derive(Debug, Clone, Default)]
pub struct OperationHints {
    /// Parallel-execution flag (set from the `parallel` annotation during extraction).
    pub parallel: bool,

    /// Minimum number of responses, if one was requested.
    pub min_responses: Option<u32>,

    /// Ordering flag (set from the `ordered` annotation during extraction).
    pub ordered: bool,
}

impl OperationHints {
    /// Hints with only `parallel` set.
    #[must_use]
    pub fn parallel() -> Self {
        Self::default().with_parallel()
    }

    /// Hints with only `min_responses` set to `min`.
    #[must_use]
    pub fn with_min_responses(min: u32) -> Self {
        Self::default().set_min_responses(Some(min))
    }

    /// Hints with both `parallel` and `ordered` set.
    #[must_use]
    pub fn parallel_ordered() -> Self {
        Self::default().with_parallel().with_ordered()
    }

    /// Returns these hints with `parallel` turned on.
    #[must_use]
    pub fn with_parallel(self) -> Self {
        Self {
            parallel: true,
            ..self
        }
    }

    /// Returns these hints with `parallel` turned off.
    #[must_use]
    pub fn sequential(self) -> Self {
        Self {
            parallel: false,
            ..self
        }
    }

    /// Returns these hints with `min_responses` replaced by `min`.
    #[must_use]
    pub fn set_min_responses(self, min: Option<u32>) -> Self {
        Self {
            min_responses: min,
            ..self
        }
    }

    /// Returns these hints with `ordered` turned on.
    #[must_use]
    pub fn with_ordered(self) -> Self {
        Self {
            ordered: true,
            ..self
        }
    }

    /// Returns these hints with `ordered` turned off.
    #[must_use]
    pub fn unordered(self) -> Self {
        Self {
            ordered: false,
            ..self
        }
    }

    /// Combines two hint sets.
    ///
    /// Boolean flags are OR-ed. For `min_responses`, `self`'s value wins when
    /// present; otherwise `other`'s value is used.
    #[must_use]
    pub fn merge(&self, other: &Self) -> Self {
        let Self {
            parallel,
            min_responses,
            ordered,
        } = other;
        Self {
            parallel: self.parallel || *parallel,
            min_responses: self.min_responses.or(*min_responses),
            ordered: self.ordered || *ordered,
        }
    }
}
221
222#[derive(Debug, Clone, Default)]
227pub struct ExecutionHints {
228 hints: HashMap<OperationPath, OperationHints>,
230
231 role: Option<String>,
233}
234
235impl ExecutionHints {
236 #[must_use]
238 pub fn new() -> Self {
239 Self::default()
240 }
241
242 #[must_use]
244 pub fn for_role(role: impl Into<String>) -> Self {
245 Self {
246 hints: HashMap::new(),
247 role: Some(role.into()),
248 }
249 }
250
251 #[must_use]
253 pub fn role(&self) -> Option<&str> {
254 self.role.as_deref()
255 }
256
257 pub fn insert(&mut self, path: OperationPath, hints: OperationHints) {
259 self.hints.insert(path, hints);
260 }
261
262 #[must_use]
264 pub fn get(&self, path: &OperationPath) -> Option<&OperationHints> {
265 self.hints.get(path)
266 }
267
268 #[must_use]
270 pub fn is_parallel(&self, path: &OperationPath) -> bool {
271 self.get(path).map(|h| h.parallel).unwrap_or(false)
272 }
273
274 #[must_use]
276 pub fn min_responses(&self, path: &OperationPath) -> Option<u32> {
277 self.get(path).and_then(|h| h.min_responses)
278 }
279
280 #[must_use]
282 pub fn is_ordered(&self, path: &OperationPath) -> bool {
283 self.get(path).map(|h| h.ordered).unwrap_or(false)
284 }
285
286 #[must_use]
288 pub fn is_empty(&self) -> bool {
289 self.hints.is_empty()
290 }
291
292 #[must_use]
294 pub fn len(&self) -> usize {
295 self.hints.len()
296 }
297
298 pub fn iter(&self) -> impl Iterator<Item = (&OperationPath, &OperationHints)> {
300 self.hints.iter()
301 }
302
303 #[must_use]
305 pub fn merge(&self, other: &Self) -> Self {
306 let mut merged = self.clone();
307 for (path, hints) in &other.hints {
308 merged
309 .hints
310 .entry(path.clone())
311 .and_modify(|h| *h = h.merge(hints))
312 .or_insert_with(|| hints.clone());
313 }
314 merged
315 }
316
317 #[must_use]
322 pub fn extract_from_protocol(protocol: &Protocol) -> Self {
323 let mut hints = Self::new();
324 let mut counters = HintExtractionCounters::default();
325 Self::extract_recursive(protocol, &OperationPath::new(), &mut hints, &mut counters);
326 hints
327 }
328
329 fn extract_recursive(
331 protocol: &Protocol,
332 path: &OperationPath,
333 hints: &mut ExecutionHints,
334 counters: &mut HintExtractionCounters,
335 ) {
336 match protocol {
337 Protocol::Begin { continuation, .. }
338 | Protocol::Await { continuation, .. }
339 | Protocol::Resolve { continuation, .. }
340 | Protocol::Invalidate { continuation, .. } => {
341 Self::extract_recursive(continuation, path, hints, counters);
342 }
343 Protocol::Send {
344 annotations,
345 continuation,
346 ..
347 } => {
348 let send_path = path.push(OperationStep::Send(counters.send_count));
349 counters.send_count += 1;
350
351 if let Some(op_hints) = Self::hints_from_annotations(annotations) {
352 hints.insert(send_path.clone(), op_hints);
353 }
354
355 Self::extract_recursive(continuation, &send_path, hints, counters);
356 }
357
358 Protocol::Broadcast {
359 annotations,
360 continuation,
361 ..
362 } => {
363 let send_path = path.push(OperationStep::Send(counters.send_count));
364 counters.send_count += 1;
365
366 if let Some(op_hints) = Self::hints_from_annotations(annotations) {
367 hints.insert(send_path.clone(), op_hints);
368 }
369
370 Self::extract_recursive(continuation, &send_path, hints, counters);
371 }
372
373 Protocol::Choice {
374 branches,
375 annotations,
376 ..
377 } => {
378 if let Some(op_hints) = Self::hints_from_annotations(annotations) {
380 hints.insert(path.clone(), op_hints);
381 }
382
383 for branch in branches.as_slice() {
384 let branch_path = path.push(OperationStep::Branch(branch.label.to_string()));
385 let mut branch_counters = HintExtractionCounters::default();
386 Self::extract_recursive(
387 &branch.protocol,
388 &branch_path,
389 hints,
390 &mut branch_counters,
391 );
392 }
393 }
394 Protocol::Case { branches, .. } => {
395 for branch in branches.as_slice() {
396 let branch_path =
397 path.push(OperationStep::Branch(branch.pattern.constructor.clone()));
398 let mut branch_counters = HintExtractionCounters::default();
399 Self::extract_recursive(
400 &branch.protocol,
401 &branch_path,
402 hints,
403 &mut branch_counters,
404 );
405 }
406 }
407 Protocol::Timeout {
408 body,
409 on_timeout,
410 on_cancel,
411 ..
412 } => {
413 Self::extract_recursive(body, path, hints, counters);
414 let timeout_path = path.push(OperationStep::Branch("timeout".to_string()));
415 let mut timeout_counters = HintExtractionCounters::default();
416 Self::extract_recursive(on_timeout, &timeout_path, hints, &mut timeout_counters);
417 if let Some(on_cancel) = on_cancel.as_deref() {
418 let cancel_path = path.push(OperationStep::Branch("cancel".to_string()));
419 let mut cancel_counters = HintExtractionCounters::default();
420 Self::extract_recursive(on_cancel, &cancel_path, hints, &mut cancel_counters);
421 }
422 }
423
424 Protocol::Loop { body, .. } => {
425 let loop_path = path.push(OperationStep::Loop(counters.loop_count));
426 counters.loop_count += 1;
427 let mut loop_counters = HintExtractionCounters::default();
428 Self::extract_recursive(body, &loop_path, hints, &mut loop_counters);
429 }
430
431 Protocol::Rec { label, body } => {
432 let rec_path = path.push(OperationStep::Rec(label.to_string()));
433 let mut rec_counters = HintExtractionCounters::default();
434 Self::extract_recursive(body, &rec_path, hints, &mut rec_counters);
435 }
436 Protocol::Publish { continuation, .. }
437 | Protocol::PublishAuthority { continuation, .. }
438 | Protocol::Materialize { continuation, .. }
439 | Protocol::Handoff { continuation, .. }
440 | Protocol::DependentWork { continuation, .. } => {
441 Self::extract_recursive(continuation, path, hints, counters);
442 }
443
444 Protocol::Parallel { protocols } => {
445 for (i, proto) in protocols.as_slice().iter().enumerate() {
446 let parallel_path = path.push(OperationStep::Loop(i)); let mut parallel_counters = HintExtractionCounters::default();
448 Self::extract_recursive(proto, ¶llel_path, hints, &mut parallel_counters);
449 }
450 }
451
452 Protocol::Extension {
453 annotations,
454 continuation,
455 ..
456 } => {
457 if let Some(op_hints) = Self::hints_from_annotations(annotations) {
458 hints.insert(path.clone(), op_hints);
459 }
460 Self::extract_recursive(continuation, path, hints, counters);
461 }
462 Protocol::Let { continuation, .. } => {
463 Self::extract_recursive(continuation, path, hints, counters);
464 }
465
466 Protocol::Var(_) | Protocol::End => {
467 }
469 }
470 }
471
472 fn hints_from_annotations(annotations: &Annotations) -> Option<OperationHints> {
474 let parallel = annotations.has_parallel();
475 let ordered = annotations.has_ordered();
476 let min_responses = annotations.min_responses();
477
478 if parallel || ordered || min_responses.is_some() {
479 Some(OperationHints {
480 parallel,
481 ordered,
482 min_responses,
483 })
484 } else {
485 None
486 }
487 }
488}
489
490#[derive(Default)]
492struct HintExtractionCounters {
493 send_count: usize,
494 loop_count: usize,
495}
496
497#[derive(Debug)]
503pub struct ChoreographyWithHints {
504 pub choreography: Choreography,
506 pub hints: ExecutionHints,
508}
509
510impl ChoreographyWithHints {
511 #[must_use]
515 pub fn from_choreography(choreography: Choreography) -> Self {
516 let hints = ExecutionHints::extract_from_protocol(&choreography.protocol);
517 Self {
518 choreography,
519 hints,
520 }
521 }
522
523 #[must_use]
525 pub fn new(choreography: Choreography, hints: ExecutionHints) -> Self {
526 Self {
527 choreography,
528 hints,
529 }
530 }
531}
532
533#[derive(Debug, Default)]
535pub struct ExecutionHintsBuilder {
536 hints: ExecutionHints,
537 current_path: OperationPath,
538}
539
540impl ExecutionHintsBuilder {
541 #[must_use]
543 pub fn new() -> Self {
544 Self::default()
545 }
546
547 #[must_use]
549 pub fn for_role(role: impl Into<String>) -> Self {
550 Self {
551 hints: ExecutionHints::for_role(role),
552 current_path: OperationPath::new(),
553 }
554 }
555
556 #[must_use]
558 pub fn at_path(mut self, path: OperationPath) -> Self {
559 self.current_path = path;
560 self
561 }
562
563 #[must_use]
565 pub fn parallel(mut self) -> Self {
566 let hints = self
567 .hints
568 .hints
569 .entry(self.current_path.clone())
570 .or_default();
571 hints.parallel = true;
572 self
573 }
574
575 #[must_use]
577 pub fn min_responses(mut self, min: u32) -> Self {
578 let hints = self
579 .hints
580 .hints
581 .entry(self.current_path.clone())
582 .or_default();
583 hints.min_responses = Some(min);
584 self
585 }
586
587 #[must_use]
589 pub fn ordered(mut self) -> Self {
590 let hints = self
591 .hints
592 .hints
593 .entry(self.current_path.clone())
594 .or_default();
595 hints.ordered = true;
596 self
597 }
598
599 #[must_use]
601 pub fn with_hints(mut self, path: OperationPath, hints: OperationHints) -> Self {
602 self.hints.insert(path, hints);
603 self
604 }
605
606 #[must_use]
608 pub fn build(self) -> ExecutionHints {
609 self.hints
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::annotation::Annotations;
    use crate::ast::role::Role;
    use crate::ast::{MessageType, ProtocolAnnotation};
    use proc_macro2::{Ident, Span};

    /// Creates an identifier with a call-site span.
    fn ident(name: &str) -> Ident {
        Ident::new(name, Span::call_site())
    }

    /// Collects `items` into an annotation set, pushing them in order.
    fn annotations_with(items: Vec<ProtocolAnnotation>) -> Annotations {
        let mut annotations = Annotations::new();
        for item in items {
            annotations.push(item);
        }
        annotations
    }

    /// Builds `A -> B : Msg` followed by `End`, with `annotations` on the send
    /// itself and no role annotations.
    fn send_a_to_b(annotations: Annotations) -> Protocol {
        Protocol::Send {
            from: Role::new(ident("A")).unwrap(),
            to: Role::new(ident("B")).unwrap(),
            message: MessageType {
                name: ident("Msg"),
                type_annotation: None,
                payload: None,
            },
            continuation: Box::new(Protocol::End),
            annotations,
            from_annotations: Annotations::new(),
            to_annotations: Annotations::new(),
        }
    }

    /// The path of the first send at the root scope.
    fn first_send() -> OperationPath {
        OperationPath::from_steps(vec![OperationStep::Send(0)])
    }

    #[test]
    fn test_operation_path_display() {
        let root = OperationPath::new();
        assert_eq!(root.to_string(), "<root>");

        let sent = root.push(OperationStep::Send(0));
        assert_eq!(sent.to_string(), "send:0");

        let branched = sent.push(OperationStep::Branch("Accept".to_string()));
        assert_eq!(branched.to_string(), "send:0/branch:Accept");

        let received = branched.push(OperationStep::Recv(1));
        assert_eq!(received.to_string(), "send:0/branch:Accept/recv:1");
    }

    #[test]
    fn test_execution_hints_basic() {
        let mut hints = ExecutionHints::new();
        let path = first_send();

        hints.insert(path.clone(), OperationHints::parallel());

        assert!(hints.is_parallel(&path));
        assert!(!hints.is_ordered(&path));
        assert_eq!(hints.min_responses(&path), None);
    }

    #[test]
    fn test_execution_hints_min_responses() {
        let mut hints = ExecutionHints::new();
        let path = OperationPath::from_steps(vec![OperationStep::Recv(0)]);

        hints.insert(
            path.clone(),
            OperationHints::with_min_responses(3).with_parallel(),
        );

        assert!(hints.is_parallel(&path));
        assert_eq!(hints.min_responses(&path), Some(3));
    }

    #[test]
    fn test_execution_hints_builder() {
        let send_path = first_send();
        let recv_path = OperationPath::from_steps(vec![OperationStep::Recv(0)]);

        let hints = ExecutionHintsBuilder::for_role("Coordinator")
            .at_path(send_path.clone())
            .parallel()
            .at_path(recv_path.clone())
            .min_responses(3)
            .ordered()
            .build();

        assert!(hints.is_parallel(&send_path));
        assert!(!hints.is_ordered(&send_path));

        assert!(!hints.is_parallel(&recv_path));
        assert!(hints.is_ordered(&recv_path));
        assert_eq!(hints.min_responses(&recv_path), Some(3));
    }

    #[test]
    fn test_operation_hints_merge() {
        let left = OperationHints {
            parallel: true,
            min_responses: None,
            ordered: false,
        };
        let right = OperationHints {
            parallel: false,
            min_responses: Some(3),
            ordered: true,
        };

        let merged = left.merge(&right);
        assert!(merged.parallel);
        assert_eq!(merged.min_responses, Some(3));
        assert!(merged.ordered);
    }

    #[test]
    fn test_execution_hints_default_values() {
        let hints = ExecutionHints::new();
        let path = first_send();

        assert!(!hints.is_parallel(&path));
        assert_eq!(hints.min_responses(&path), None);
        assert!(!hints.is_ordered(&path));
    }

    #[test]
    fn test_extract_from_protocol_with_parallel() {
        let protocol = send_a_to_b(annotations_with(vec![ProtocolAnnotation::Parallel]));

        let hints = ExecutionHints::extract_from_protocol(&protocol);
        let path = first_send();

        assert!(hints.is_parallel(&path));
        assert!(!hints.is_ordered(&path));
        assert_eq!(hints.min_responses(&path), None);
    }

    #[test]
    fn test_extract_from_protocol_with_min_responses() {
        let protocol = send_a_to_b(annotations_with(vec![ProtocolAnnotation::MinResponses(3)]));

        let hints = ExecutionHints::extract_from_protocol(&protocol);
        let path = first_send();

        assert!(!hints.is_parallel(&path));
        assert_eq!(hints.min_responses(&path), Some(3));
    }

    #[test]
    fn test_extract_from_protocol_combined() {
        let protocol = send_a_to_b(annotations_with(vec![
            ProtocolAnnotation::Parallel,
            ProtocolAnnotation::Ordered,
            ProtocolAnnotation::MinResponses(2),
        ]));

        let hints = ExecutionHints::extract_from_protocol(&protocol);
        let path = first_send();

        assert!(hints.is_parallel(&path));
        assert!(hints.is_ordered(&path));
        assert_eq!(hints.min_responses(&path), Some(2));
    }

    #[test]
    fn test_extract_no_hints_when_no_annotations() {
        let protocol = send_a_to_b(Annotations::new());

        let hints = ExecutionHints::extract_from_protocol(&protocol);

        assert!(hints.is_empty());
    }
}