// vernier_partial/merge.rs

//! Paradigm-agnostic merge policy: image-id partition disjointness
//! and strict-mode rank-collision detection.
//!
//! Each paradigm's outer merge accumulator embeds a
//! [`BaseMergeAccumulator`] and delegates the policy checks to its
//! [`BaseMergeAccumulator::ingest_image_ids`] and
//! [`BaseMergeAccumulator::ingest_rank_id`] methods, then folds in its
//! own paradigm-specific cell store.
use std::collections::{HashMap, HashSet};

/// `RankId` type alias defined here (mirroring the one paradigm crates
/// would otherwise pull from `envelope`) so they don't have to import
/// from `envelope` separately.
pub type RankId = u32;

/// Sentinel `RankId` carried in the partition-overlap error when one
/// of the colliding partials lacked a rank id (single-rank flow). Real
/// `rank_id`s should always be `< u32::MAX` in any DDP shape.
pub(crate) const UNRANKED_SENTINEL: RankId = u32::MAX;

use crate::error::PartialError;

/// Cross-partial state that every paradigm's merge accumulator
/// shares: image-id ownership (for the disjoint-partition rule) and
/// strict-mode rank-id distinctness.
///
/// Paradigms wrap this in their own outer accumulator and call
/// [`Self::ingest_rank_id`] + [`Self::ingest_image_ids`] from their
/// `ingest` per partial, before folding the paradigm-specific body.
#[derive(Debug)]
pub struct BaseMergeAccumulator {
    /// Image id → rank that ingested it. Source of truth for the
    /// merged `seen_images` set. Partials without a `rank_id` are
    /// recorded under [`UNRANKED_SENTINEL`].
    pub image_owner: HashMap<i64, RankId>,
    /// Rank ids observed so far. Only populated in strict mode (the
    /// distinctness invariant is strict-only); empty in corrected.
    pub seen_rank_ids: HashSet<RankId>,
    /// Whether the receiver is in strict mode.
    pub strict: bool,
}

42impl BaseMergeAccumulator {
43    /// Construct an empty accumulator. `strict` controls whether the
44    /// rank-id distinctness check fires.
45    pub fn new(strict: bool) -> Self {
46        Self {
47            image_owner: HashMap::new(),
48            seen_rank_ids: HashSet::new(),
49            strict,
50        }
51    }
52
53    /// Record one partial's `rank_id`. In strict mode, returns
54    /// [`PartialError::RankCollision`] if a previous partial declared
55    /// the same id. In corrected mode this is a no-op.
56    pub fn ingest_rank_id(&mut self, rank_id: Option<RankId>) -> Result<(), PartialError> {
57        if !self.strict {
58            return Ok(());
59        }
60        if let Some(rid) = rank_id {
61            if !self.seen_rank_ids.insert(rid) {
62                return Err(PartialError::RankCollision { rank_id: rid });
63            }
64        }
65        Ok(())
66    }
67
68    /// Record one partial's `seen_images` against its declared
69    /// `rank_id`. Returns [`PartialError::PartitionOverlap`] if any
70    /// `image_id` was already registered by a different rank.
71    pub fn ingest_image_ids(
72        &mut self,
73        rank_id: Option<RankId>,
74        image_ids: impl IntoIterator<Item = i64>,
75    ) -> Result<(), PartialError> {
76        let owner = rank_id.unwrap_or(UNRANKED_SENTINEL);
77        for id in image_ids {
78            if let Some(&prev) = self.image_owner.get(&id) {
79                let (a, b) = (prev.min(owner), prev.max(owner));
80                return Err(PartialError::PartitionOverlap {
81                    rank_a: a,
82                    rank_b: b,
83                    image_id: id,
84                });
85            }
86            self.image_owner.insert(id, owner);
87        }
88        Ok(())
89    }
90
91    /// Borrow the merged image-id set. Order is unspecified — sort
92    /// before consuming if determinism is needed.
93    pub fn image_ids(&self) -> impl Iterator<Item = i64> + '_ {
94        self.image_owner.keys().copied()
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rank_collision_strict() {
        let mut acc = BaseMergeAccumulator::new(true);
        assert!(acc.ingest_rank_id(Some(0)).is_ok());
        // Second declaration of the same rank id must be rejected.
        assert!(matches!(
            acc.ingest_rank_id(Some(0)),
            Err(PartialError::RankCollision { rank_id: 0 })
        ));
    }

    #[test]
    fn rank_collision_corrected_tolerated() {
        let mut acc = BaseMergeAccumulator::new(false);
        // Corrected mode tolerates duplicate rank_ids — they're informational only.
        for _ in 0..2 {
            acc.ingest_rank_id(Some(0)).unwrap();
        }
    }

    #[test]
    fn partition_overlap_named_ranks() {
        let mut acc = BaseMergeAccumulator::new(true);
        acc.ingest_image_ids(Some(0), [1, 2, 3]).unwrap();
        // Image 3 is claimed by both rank 0 and rank 1.
        let err = acc.ingest_image_ids(Some(1), [3, 4]).unwrap_err();
        assert!(matches!(
            err,
            PartialError::PartitionOverlap {
                rank_a: 0,
                rank_b: 1,
                image_id: 3,
            }
        ));
    }

    #[test]
    fn partition_overlap_unranked_sentinel() {
        let mut acc = BaseMergeAccumulator::new(false);
        acc.ingest_image_ids(None, [7]).unwrap();
        // Unranked partials collide under the sentinel on both sides.
        let err = acc.ingest_image_ids(None, [7]).unwrap_err();
        assert!(matches!(
            err,
            PartialError::PartitionOverlap {
                rank_a: UNRANKED_SENTINEL,
                rank_b: UNRANKED_SENTINEL,
                ..
            }
        ));
    }
}