1use crate::{Label, LocalTypeR};
43use std::collections::BTreeMap;
44use thiserror::Error;
45
46#[derive(Debug, Clone, Error)]
48pub enum MergeError {
49 #[error("cannot merge end with non-end type: {0:?}")]
51 EndMismatch(LocalTypeR),
52
53 #[error("partner mismatch in merge: expected {expected}, found {found}")]
55 PartnerMismatch { expected: String, found: String },
56
57 #[error("direction mismatch in merge: cannot merge send with recv")]
59 DirectionMismatch,
60
61 #[error("incompatible continuations for label '{label}'")]
63 IncompatibleContinuations { label: String },
64
65 #[error("payload annotation mismatch for label '{label}': left={left}, right={right}")]
67 PayloadAnnotationMismatch {
68 label: String,
69 left: String,
70 right: String,
71 },
72
73 #[error("send branch label mismatch: cannot merge sends with different labels '{left}' vs '{right}'")]
75 SendLabelMismatch { left: String, right: String },
76
77 #[error("send branch count mismatch: {left} labels vs {right} labels")]
79 SendBranchCountMismatch { left: usize, right: usize },
80
81 #[error("recursive variable mismatch: expected {expected}, found {found}")]
83 RecursiveVariableMismatch { expected: String, found: String },
84
85 #[error("type variable mismatch: expected {expected}, found {found}")]
87 VariableMismatch { expected: String, found: String },
88
89 #[error("cannot merge incompatible types")]
91 IncompatibleTypes,
92}
93
94pub type MergeResult = Result<LocalTypeR, MergeError>;
96
97pub fn merge(t1: &LocalTypeR, t2: &LocalTypeR) -> MergeResult {
132 if t1 == t2 {
133 return Ok(t1.clone());
134 }
135
136 match (t1, t2) {
137 (LocalTypeR::End, LocalTypeR::End) => Ok(LocalTypeR::End),
138
139 (LocalTypeR::End, other) | (other, LocalTypeR::End) => {
140 Err(MergeError::EndMismatch(other.clone()))
141 }
142
143 (
144 LocalTypeR::Send {
145 partner: p1,
146 branches: b1,
147 },
148 LocalTypeR::Send {
149 partner: p2,
150 branches: b2,
151 },
152 ) => merge_send_pair(p1, b1, p2, b2),
153
154 (
155 LocalTypeR::Recv {
156 partner: p1,
157 branches: b1,
158 },
159 LocalTypeR::Recv {
160 partner: p2,
161 branches: b2,
162 },
163 ) => merge_recv_pair(p1, b1, p2, b2),
164
165 (
166 LocalTypeR::Mu {
167 var: v1,
168 body: body1,
169 },
170 LocalTypeR::Mu {
171 var: v2,
172 body: body2,
173 },
174 ) => merge_recursive_pair(v1, body1, v2, body2),
175
176 (LocalTypeR::Var(v1), LocalTypeR::Var(v2)) => merge_var_pair(v1, v2),
177
178 (LocalTypeR::Send { .. }, LocalTypeR::Recv { .. })
179 | (LocalTypeR::Recv { .. }, LocalTypeR::Send { .. }) => Err(MergeError::DirectionMismatch),
180
181 _ => Err(MergeError::IncompatibleTypes),
182 }
183}
184
185fn merge_send_pair(
186 p1: &str,
187 b1: &[(Label, Option<crate::ValType>, LocalTypeR)],
188 p2: &str,
189 b2: &[(Label, Option<crate::ValType>, LocalTypeR)],
190) -> MergeResult {
191 if p1 != p2 {
192 return Err(MergeError::PartnerMismatch {
193 expected: p1.to_string(),
194 found: p2.to_string(),
195 });
196 }
197 let merged_branches = merge_send_branches(b1, b2)?;
198 Ok(LocalTypeR::Send {
199 partner: p1.to_string(),
200 branches: merged_branches,
201 })
202}
203
204fn merge_recv_pair(
205 p1: &str,
206 b1: &[(Label, Option<crate::ValType>, LocalTypeR)],
207 p2: &str,
208 b2: &[(Label, Option<crate::ValType>, LocalTypeR)],
209) -> MergeResult {
210 if p1 != p2 {
211 return Err(MergeError::PartnerMismatch {
212 expected: p1.to_string(),
213 found: p2.to_string(),
214 });
215 }
216 let merged_branches = merge_recv_branches(b1, b2)?;
217 Ok(LocalTypeR::Recv {
218 partner: p1.to_string(),
219 branches: merged_branches,
220 })
221}
222
223fn merge_recursive_pair(v1: &str, body1: &LocalTypeR, v2: &str, body2: &LocalTypeR) -> MergeResult {
224 if v1 != v2 {
225 return Err(MergeError::RecursiveVariableMismatch {
226 expected: v1.to_string(),
227 found: v2.to_string(),
228 });
229 }
230 let merged_body = merge(body1, body2)?;
231 Ok(LocalTypeR::Mu {
232 var: v1.to_string(),
233 body: Box::new(merged_body),
234 })
235}
236
237fn merge_var_pair(v1: &str, v2: &str) -> MergeResult {
238 if v1 != v2 {
239 return Err(MergeError::VariableMismatch {
240 expected: v1.to_string(),
241 found: v2.to_string(),
242 });
243 }
244 Ok(LocalTypeR::Var(v1.to_string()))
245}
246
247fn merge_payload_annotations(
248 label: &Label,
249 left: &Option<crate::ValType>,
250 right: &Option<crate::ValType>,
251) -> Result<Option<crate::ValType>, MergeError> {
252 if left == right {
253 return Ok(left.clone());
254 }
255 Err(MergeError::PayloadAnnotationMismatch {
256 label: label.name.clone(),
257 left: format!("{left:?}"),
258 right: format!("{right:?}"),
259 })
260}
261
262fn merge_send_branches(
275 branches1: &[(Label, Option<crate::ValType>, LocalTypeR)],
276 branches2: &[(Label, Option<crate::ValType>, LocalTypeR)],
277) -> Result<Vec<(Label, Option<crate::ValType>, LocalTypeR)>, MergeError> {
278 let mut sorted1: Vec<_> = branches1.to_vec();
280 let mut sorted2: Vec<_> = branches2.to_vec();
281 sorted1.sort_by(|a, b| a.0.name.cmp(&b.0.name));
282 sorted2.sort_by(|a, b| a.0.name.cmp(&b.0.name));
283
284 if sorted1.len() != sorted2.len() {
286 return Err(MergeError::SendBranchCountMismatch {
287 left: sorted1.len(),
288 right: sorted2.len(),
289 });
290 }
291
292 let mut result = Vec::with_capacity(sorted1.len());
294 for ((label1, vt1, cont1), (label2, vt2, cont2)) in sorted1.iter().zip(sorted2.iter()) {
295 if label1.name != label2.name {
297 return Err(MergeError::SendLabelMismatch {
298 left: label1.name.clone(),
299 right: label2.name.clone(),
300 });
301 }
302 if label1.sort != label2.sort {
303 return Err(MergeError::IncompatibleContinuations {
304 label: label1.name.clone(),
305 });
306 }
307
308 let merged_cont =
310 merge(cont1, cont2).map_err(|_| MergeError::IncompatibleContinuations {
311 label: label1.name.clone(),
312 })?;
313 let merged_vt = merge_payload_annotations(label1, vt1, vt2)?;
314
315 result.push((label1.clone(), merged_vt, merged_cont));
316 }
317
318 Ok(result)
319}
320
321fn merge_recv_branches(
334 branches1: &[(Label, Option<crate::ValType>, LocalTypeR)],
335 branches2: &[(Label, Option<crate::ValType>, LocalTypeR)],
336) -> Result<Vec<(Label, Option<crate::ValType>, LocalTypeR)>, MergeError> {
337 let mut result: BTreeMap<String, (Label, Option<crate::ValType>, LocalTypeR)> = BTreeMap::new();
338
339 for (label, vt, cont) in branches1 {
341 result.insert(
342 label.name.clone(),
343 (label.clone(), vt.clone(), cont.clone()),
344 );
345 }
346
347 for (label, vt, cont) in branches2 {
349 if let Some((existing_label, existing_vt, existing_cont)) = result.get(&label.name) {
350 let merged_cont =
352 merge(existing_cont, cont).map_err(|_| MergeError::IncompatibleContinuations {
353 label: label.name.clone(),
354 })?;
355 if existing_label.sort != label.sort {
357 return Err(MergeError::IncompatibleContinuations {
358 label: label.name.clone(),
359 });
360 }
361 let merged_vt = merge_payload_annotations(label, existing_vt, vt)?;
362 result.insert(label.name.clone(), (label.clone(), merged_vt, merged_cont));
363 } else {
364 result.insert(
366 label.name.clone(),
367 (label.clone(), vt.clone(), cont.clone()),
368 );
369 }
370 }
371
372 Ok(result.into_values().collect())
373}
374
375pub fn merge_all(types: &[LocalTypeR]) -> MergeResult {
385 match types {
386 [] => Err(MergeError::IncompatibleTypes),
387 [single] => Ok(single.clone()),
388 [first, rest @ ..] => {
389 let mut result = first.clone();
390 for t in rest {
391 result = merge(&result, t)?;
392 }
393 Ok(result)
394 }
395 }
396}
397
398#[must_use]
402pub fn can_merge(t1: &LocalTypeR, t2: &LocalTypeR) -> bool {
403 merge(t1, t2).is_ok()
404}
405
#[cfg(test)]
mod tests {
    //! Unit tests for local-type merging: sends require identical label sets,
    //! receives take the union of labels, and `End`/`Mu`/`Var` require exact
    //! structural agreement.

    use super::*;
    use crate::ValType;
    use assert_matches::assert_matches;

    #[test]
    fn test_merge_identical_end() {
        let result = merge(&LocalTypeR::End, &LocalTypeR::End).unwrap();
        assert_eq!(result, LocalTypeR::End);
    }

    #[test]
    fn test_merge_identical_send() {
        let t = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let result = merge(&t, &t).unwrap();
        assert_eq!(result, t);
    }

    #[test]
    fn test_merge_sends_same_labels_succeeds() {
        let t1 = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);

        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Send { partner, branches } => {
            assert_eq!(partner, "B");
            assert_eq!(branches.len(), 1);
            assert_eq!(branches[0].0.name, "x");
        });
    }

    #[test]
    fn test_merge_sends_permuted_branches_succeeds() {
        // Same label set in a different order: branches are paired by name.
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![
                (Label::new("x"), None, LocalTypeR::End),
                (Label::new("y"), None, LocalTypeR::End),
            ],
        };
        let t2 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![
                (Label::new("y"), None, LocalTypeR::End),
                (Label::new("x"), None, LocalTypeR::End),
            ],
        };

        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Send { partner, branches } => {
            assert_eq!(partner, "B");
            let labels: Vec<_> = branches.iter().map(|(l, _, _)| l.name.as_str()).collect();
            assert_eq!(labels, vec!["x", "y"]);
        });
    }

    #[test]
    fn test_merge_sends_different_labels_fails() {
        let t1 = LocalTypeR::send("B", Label::new("yes"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("no"), LocalTypeR::End);

        let result = merge(&t1, &t2);
        assert!(
            matches!(result, Err(MergeError::SendLabelMismatch { .. })),
            "Expected SendLabelMismatch, got {:?}",
            result
        );
    }

    #[test]
    fn test_merge_sends_different_count_fails() {
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![
                (Label::new("x"), None, LocalTypeR::End),
                (Label::new("y"), None, LocalTypeR::End),
            ],
        };
        let t2 = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);

        let result = merge(&t1, &t2);
        assert!(
            matches!(result, Err(MergeError::SendBranchCountMismatch { .. })),
            "Expected SendBranchCountMismatch, got {:?}",
            result
        );
    }

    #[test]
    fn test_merge_sends_incompatible_continuations_fails() {
        let t1 = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);
        let t2 = LocalTypeR::send(
            "B",
            Label::new("x"),
            LocalTypeR::send("C", Label::new("y"), LocalTypeR::End),
        );

        let result = merge(&t1, &t2);
        assert_matches!(result, Err(MergeError::IncompatibleContinuations { label }) => {
            assert_eq!(label, "x");
        });
    }

    #[test]
    fn test_merge_sends_payload_annotation_mismatch_fails() {
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), Some(ValType::Bool), LocalTypeR::End)],
        };

        let result = merge(&t1, &t2);
        assert!(matches!(
            result,
            Err(MergeError::PayloadAnnotationMismatch { .. })
        ));
    }

    #[test]
    fn test_merge_sends_payload_annotation_none_some_mismatch_fails() {
        let t1 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), None, LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(Label::new("x"), Some(ValType::Nat), LocalTypeR::End)],
        };

        let result = merge(&t1, &t2);
        assert!(matches!(
            result,
            Err(MergeError::PayloadAnnotationMismatch { .. })
        ));
    }

    #[test]
    fn test_merge_sends_different_partners_fails() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::send("C", Label::new("msg"), LocalTypeR::End);

        let result = merge(&t1, &t2);
        assert!(matches!(result, Err(MergeError::PartnerMismatch { .. })));
    }

    #[test]
    fn test_merge_recvs_different_labels_succeeds() {
        let t1 = LocalTypeR::recv("A", Label::new("x"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("A", Label::new("y"), LocalTypeR::End);

        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Recv { partner, branches } => {
            assert_eq!(partner, "A");
            assert_eq!(branches.len(), 2);
            let labels: Vec<_> = branches.iter().map(|(l, _, _)| l.name.as_str()).collect();
            assert!(labels.contains(&"x"));
            assert!(labels.contains(&"y"));
        });
    }

    #[test]
    fn test_merge_recvs_different_partners_fails() {
        let t1 = LocalTypeR::recv("A", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("C", Label::new("msg"), LocalTypeR::End);

        let result = merge(&t1, &t2);
        assert!(matches!(result, Err(MergeError::PartnerMismatch { .. })));
    }

    #[test]
    fn test_merge_recvs_same_label_merges_continuations() {
        let t1 = LocalTypeR::recv(
            "A",
            Label::new("x"),
            LocalTypeR::send("B", Label::new("m"), LocalTypeR::End),
        );
        let t2 = LocalTypeR::recv(
            "A",
            Label::new("x"),
            LocalTypeR::send("B", Label::new("m"), LocalTypeR::End),
        );

        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Recv { branches, .. } => {
            assert_eq!(branches.len(), 1);
            assert_matches!(&branches[0].2, LocalTypeR::Send { partner, .. } => {
                assert_eq!(partner, "B");
            });
        });
    }

    #[test]
    fn test_merge_recvs_overlapping_labels() {
        let t1 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![
                (Label::new("x"), None, LocalTypeR::End),
                (Label::new("y"), None, LocalTypeR::End),
            ],
        };
        let t2 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![
                (Label::new("y"), None, LocalTypeR::End),
                (Label::new("z"), None, LocalTypeR::End),
            ],
        };

        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Recv { partner, branches } => {
            assert_eq!(partner, "A");
            assert_eq!(branches.len(), 3);
            let labels: Vec<_> = branches.iter().map(|(l, _, _)| l.name.as_str()).collect();
            assert!(labels.contains(&"x"));
            assert!(labels.contains(&"y"));
            assert!(labels.contains(&"z"));
        });
    }

    #[test]
    fn test_merge_recvs_overlapping_payload_annotation_mismatch_fails() {
        let t1 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Bool), LocalTypeR::End)],
        };

        let result = merge(&t1, &t2);
        assert!(matches!(
            result,
            Err(MergeError::PayloadAnnotationMismatch { .. })
        ));
    }

    #[test]
    fn test_merge_recvs_overlapping_payload_annotation_match_succeeds() {
        let t1 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Nat), LocalTypeR::End)],
        };
        let t2 = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("y"), Some(ValType::Nat), LocalTypeR::End)],
        };

        let result = merge(&t1, &t2).expect("matching payload annotations should merge");
        assert_matches!(result, LocalTypeR::Recv { branches, .. } => {
            assert_eq!(branches.len(), 1);
            assert_eq!(branches[0].1, Some(ValType::Nat));
        });
    }

    #[test]
    fn test_merge_send_recv_fails() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("B", Label::new("msg"), LocalTypeR::End);

        let result = merge(&t1, &t2);
        assert!(matches!(result, Err(MergeError::DirectionMismatch)));
    }

    #[test]
    fn test_merge_end_with_non_end_fails_both_orders() {
        let s = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);

        assert_matches!(merge(&LocalTypeR::End, &s), Err(MergeError::EndMismatch(other)) => {
            assert_eq!(other, s);
        });
        assert_matches!(merge(&s, &LocalTypeR::End), Err(MergeError::EndMismatch(other)) => {
            assert_eq!(other, s);
        });
    }

    #[test]
    fn test_merge_mu_same_var_merges_bodies() {
        let t1 = LocalTypeR::Mu {
            var: "t".to_string(),
            body: Box::new(LocalTypeR::recv(
                "A",
                Label::new("x"),
                LocalTypeR::Var("t".to_string()),
            )),
        };
        let t2 = LocalTypeR::Mu {
            var: "t".to_string(),
            body: Box::new(LocalTypeR::recv(
                "A",
                Label::new("y"),
                LocalTypeR::Var("t".to_string()),
            )),
        };

        let result = merge(&t1, &t2).unwrap();
        assert_matches!(result, LocalTypeR::Mu { var, body } => {
            assert_eq!(var, "t");
            assert_matches!(&*body, LocalTypeR::Recv { partner, branches } => {
                assert_eq!(partner, "A");
                assert_eq!(branches.len(), 2);
            });
        });
    }

    #[test]
    fn test_merge_mu_different_vars_fails() {
        let t1 = LocalTypeR::Mu {
            var: "t".to_string(),
            body: Box::new(LocalTypeR::End),
        };
        let t2 = LocalTypeR::Mu {
            var: "s".to_string(),
            body: Box::new(LocalTypeR::End),
        };

        let result = merge(&t1, &t2);
        assert!(matches!(
            result,
            Err(MergeError::RecursiveVariableMismatch { .. })
        ));
    }

    #[test]
    fn test_merge_vars_different_names_fails() {
        let t1 = LocalTypeR::Var("t".to_string());
        let t2 = LocalTypeR::Var("s".to_string());

        let result = merge(&t1, &t2);
        assert!(matches!(result, Err(MergeError::VariableMismatch { .. })));
    }

    #[test]
    fn test_merge_mu_with_var_fails() {
        let t1 = LocalTypeR::Mu {
            var: "t".to_string(),
            body: Box::new(LocalTypeR::End),
        };
        let t2 = LocalTypeR::Var("t".to_string());

        let result = merge(&t1, &t2);
        assert!(matches!(result, Err(MergeError::IncompatibleTypes)));
    }

    #[test]
    fn test_merge_all_empty_fails() {
        let result = merge_all(&[]);
        assert!(matches!(result, Err(MergeError::IncompatibleTypes)));
    }

    #[test]
    fn test_merge_all_single_returns_clone() {
        let t = LocalTypeR::send("B", Label::new("x"), LocalTypeR::End);
        let result = merge_all(std::slice::from_ref(&t)).unwrap();
        assert_eq!(result, t);
    }

    #[test]
    fn test_merge_all_sends_same_labels() {
        let types = vec![
            LocalTypeR::send("B", Label::new("x"), LocalTypeR::End),
            LocalTypeR::send("B", Label::new("x"), LocalTypeR::End),
            LocalTypeR::send("B", Label::new("x"), LocalTypeR::End),
        ];

        let result = merge_all(&types).unwrap();
        assert_matches!(result, LocalTypeR::Send { branches, .. } => {
            assert_eq!(branches.len(), 1);
            assert_eq!(branches[0].0.name, "x");
        });
    }

    #[test]
    fn test_merge_all_sends_different_labels_fails() {
        let types = vec![
            LocalTypeR::send("B", Label::new("a"), LocalTypeR::End),
            LocalTypeR::send("B", Label::new("b"), LocalTypeR::End),
        ];

        let result = merge_all(&types);
        assert!(result.is_err());
    }

    #[test]
    fn test_merge_all_recvs_different_labels() {
        let types = vec![
            LocalTypeR::recv("A", Label::new("a"), LocalTypeR::End),
            LocalTypeR::recv("A", Label::new("b"), LocalTypeR::End),
            LocalTypeR::recv("A", Label::new("c"), LocalTypeR::End),
        ];

        let result = merge_all(&types).unwrap();
        assert_matches!(result, LocalTypeR::Recv { branches, .. } => {
            assert_eq!(branches.len(), 3);
        });
    }

    #[test]
    fn test_can_merge_sends_same_label() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        assert!(can_merge(&t1, &t2));
    }

    #[test]
    fn test_can_merge_sends_different_labels_false() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::send("B", Label::new("other"), LocalTypeR::End);
        assert!(!can_merge(&t1, &t2));
    }

    #[test]
    fn test_can_merge_recvs_different_labels_true() {
        let t1 = LocalTypeR::recv("A", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("A", Label::new("other"), LocalTypeR::End);
        assert!(can_merge(&t1, &t2));
    }

    #[test]
    fn test_can_merge_send_recv_false() {
        let t1 = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let t2 = LocalTypeR::recv("B", Label::new("msg"), LocalTypeR::End);
        assert!(!can_merge(&t1, &t2));
    }
}