// zer_core/comparison.rs

1use crate::record::RecordId;
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash,
4         serde::Serialize, serde::Deserialize)]
5#[repr(u8)]
6pub enum ComparisonLevel {
7    None    = 0,
8    Partial = 1,
9    Close   = 2,
10    Exact   = 3,
11    /// Field structurally absent on one or both sides (cross-schema linkage).
12    /// Never fed to EM; both E-step and M-step skip pairs where any field carries this level.
13    Null    = 255,
14}
15
16impl ComparisonLevel {
17    pub fn as_u8(self) -> u8 {
18        self as u8
19    }
20
21    #[inline]
22    pub fn from_u8(v: u8) -> Self {
23        match v {
24            1   => Self::Partial,
25            2   => Self::Close,
26            3   => Self::Exact,
27            255 => Self::Null,
28            _   => Self::None,
29        }
30    }
31}
32
33impl PartialOrd for ComparisonLevel {
34    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
35        Some(self.cmp(other))
36    }
37}
38
39impl Ord for ComparisonLevel {
40    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
41        self.as_u8().cmp(&other.as_u8())
42    }
43}
44
// ── ComparisonVector (single-pair) ────────────────────────────────────────────
46
47/// Comparison result for a single candidate pair.
48///
49/// Used for single-pair comparisons (`Comparator::compare`) and as the
50/// per-pair view stored in `ScoredPair`.  For batch operations use
51/// `ComparisonBatch`.
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53pub struct ComparisonVector {
54    pub record_a: RecordId,
55    pub record_b: RecordId,
56    pub levels:   Vec<ComparisonLevel>,
57}
58
59impl ComparisonVector {
60    pub fn new(record_a: RecordId, record_b: RecordId, levels: Vec<ComparisonLevel>) -> Self {
61        Self { record_a, record_b, levels }
62    }
63}
64
// ── ComparisonBatch (field-major SoA) ─────────────────────────────────────────
66
67/// Field-major SoA batch of comparison results for many pairs.
68///
69/// # Layout
70///
71/// ```text
72/// levels[field_idx * n_pairs + pair_idx] = ComparisonLevel as u8
73/// ```
74///
75/// All values for field 0 across every pair are contiguous, then field 1, etc.
76#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
77pub struct ComparisonBatch {
78    pub n_pairs:  usize,
79    pub n_fields: usize,
80    /// `(record_a_id, record_b_id)` for pair `p`.
81    pub pair_ids: Vec<(RecordId, RecordId)>,
82    /// Field-major levels: `levels[f * n_pairs + p]`.
83    pub levels:   Vec<u8>,
84}
85
86impl ComparisonBatch {
87    /// Allocate a zeroed batch (all levels = `ComparisonLevel::None`).
88    pub fn new(n_pairs: usize, n_fields: usize, pair_ids: Vec<(RecordId, RecordId)>) -> Self {
89        Self {
90            n_pairs,
91            n_fields,
92            pair_ids,
93            levels: vec![0u8; n_fields * n_pairs],
94        }
95    }
96
97    /// Read the level for `(field, pair)`.
98    #[inline]
99    pub fn level(&self, field: usize, pair: usize) -> ComparisonLevel {
100        ComparisonLevel::from_u8(self.levels[field * self.n_pairs + pair])
101    }
102
103    /// Write the level for `(field, pair)`.
104    #[inline]
105    pub fn set_level(&mut self, field: usize, pair: usize, level: ComparisonLevel) {
106        self.levels[field * self.n_pairs + pair] = level as u8;
107    }
108
109    /// Reconstruct a `ComparisonVector` for pair `p`.
110    pub fn pair_as_vector(&self, pair_idx: usize) -> ComparisonVector {
111        let (a, b) = self.pair_ids[pair_idx];
112        let levels = (0..self.n_fields)
113            .map(|f| self.level(f, pair_idx))
114            .collect();
115        ComparisonVector::new(a, b, levels)
116    }
117
118    /// Build from an existing `Vec<ComparisonVector>` for migration / tests.
119    pub fn from_vectors(vectors: &[ComparisonVector]) -> Self {
120        if vectors.is_empty() {
121            return Self::new(0, 0, vec![]);
122        }
123        let n_pairs  = vectors.len();
124        let n_fields = vectors[0].levels.len();
125        let pair_ids = vectors.iter().map(|v| (v.record_a, v.record_b)).collect();
126        let mut batch = Self::new(n_pairs, n_fields, pair_ids);
127        for (p, v) in vectors.iter().enumerate() {
128            for (f, &level) in v.levels.iter().enumerate() {
129                batch.set_level(f, p, level);
130            }
131        }
132        batch
133    }
134
135    /// Convert back to `Vec<ComparisonVector>` for callers that still need it.
136    pub fn into_vectors(&self) -> Vec<ComparisonVector> {
137        (0..self.n_pairs).map(|p| self.pair_as_vector(p)).collect()
138    }
139
140    /// Concatenate multiple field-major batches (same `n_fields`) into one.
141    ///
142    /// Each chunk may have a different `n_pairs`.  The merged layout remains
143    /// field-major with `n_pairs_total = sum of all chunk n_pairs`.
144    pub fn concat(chunks: &[Self]) -> Self {
145        let chunks: Vec<&Self> = chunks.iter().filter(|c| c.n_pairs > 0).collect();
146        if chunks.is_empty() {
147            return Self::new(0, 0, vec![]);
148        }
149        let n_fields = chunks[0].n_fields;
150        let n_total: usize = chunks.iter().map(|c| c.n_pairs).sum();
151
152        let mut pair_ids = Vec::with_capacity(n_total);
153        let mut levels = vec![0u8; n_fields * n_total];
154
155        let mut offset = 0usize;
156        for chunk in &chunks {
157            pair_ids.extend_from_slice(&chunk.pair_ids);
158            for f in 0..n_fields {
159                let dst = f * n_total + offset;
160                let src = f * chunk.n_pairs;
161                levels[dst..dst + chunk.n_pairs]
162                    .copy_from_slice(&chunk.levels[src..src + chunk.n_pairs]);
163            }
164            offset += chunk.n_pairs;
165        }
166
167        Self { n_pairs: n_total, n_fields, pair_ids, levels }
168    }
169}
170
// ── Tests ─────────────────────────────────────────────────────────────────────
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn comparison_level_ordering() {
        assert!(ComparisonLevel::Exact   > ComparisonLevel::Close);
        assert!(ComparisonLevel::Close   > ComparisonLevel::Partial);
        assert!(ComparisonLevel::Partial > ComparisonLevel::None);
        // Ordering follows the u8 discriminant, so Null (255) sorts above Exact.
        assert!(ComparisonLevel::Null    > ComparisonLevel::Exact);
    }

    #[test]
    fn comparison_level_repr_values() {
        assert_eq!(ComparisonLevel::Exact.as_u8(),   3);
        assert_eq!(ComparisonLevel::Close.as_u8(),   2);
        assert_eq!(ComparisonLevel::Partial.as_u8(), 1);
        assert_eq!(ComparisonLevel::None.as_u8(),    0);
        assert_eq!(ComparisonLevel::Null.as_u8(),    255);
    }

    #[test]
    fn comparison_level_round_trip() {
        for &l in &[
            ComparisonLevel::None,
            ComparisonLevel::Partial,
            ComparisonLevel::Close,
            ComparisonLevel::Exact,
            ComparisonLevel::Null,
        ] {
            assert_eq!(ComparisonLevel::from_u8(l.as_u8()), l);
        }
        assert_eq!(ComparisonLevel::from_u8(99), ComparisonLevel::None);
    }

    #[test]
    fn batch_field_major_layout() {
        // n_fields=2, n_pairs=3
        let pair_ids = vec![(1, 2), (3, 4), (5, 6)];
        let mut batch = ComparisonBatch::new(3, 2, pair_ids);

        // Set levels in a known pattern
        batch.set_level(0, 0, ComparisonLevel::Exact);   // field 0, pair 0
        batch.set_level(0, 1, ComparisonLevel::Close);   // field 0, pair 1
        batch.set_level(0, 2, ComparisonLevel::Partial); // field 0, pair 2
        batch.set_level(1, 0, ComparisonLevel::None);    // field 1, pair 0
        batch.set_level(1, 1, ComparisonLevel::Exact);   // field 1, pair 1
        batch.set_level(1, 2, ComparisonLevel::Close);   // field 1, pair 2

        // Field-major: field 0 values at indices 0,1,2; field 1 at 3,4,5
        assert_eq!(batch.levels[0], ComparisonLevel::Exact   as u8);
        assert_eq!(batch.levels[1], ComparisonLevel::Close   as u8);
        assert_eq!(batch.levels[2], ComparisonLevel::Partial as u8);
        assert_eq!(batch.levels[3], ComparisonLevel::None    as u8);
        assert_eq!(batch.levels[4], ComparisonLevel::Exact   as u8);
        assert_eq!(batch.levels[5], ComparisonLevel::Close   as u8);

        // pair_as_vector reconstructs correctly
        let v = batch.pair_as_vector(1); // pair index 1
        assert_eq!(v.record_a, 3);
        assert_eq!(v.record_b, 4);
        assert_eq!(v.levels, vec![ComparisonLevel::Close, ComparisonLevel::Exact]);
    }

    #[test]
    fn batch_from_vectors_round_trips() {
        let vectors = vec![
            ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact, ComparisonLevel::None]),
            ComparisonVector::new(3, 4, vec![ComparisonLevel::Partial, ComparisonLevel::Close]),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        assert_eq!(batch.n_pairs, 2);
        assert_eq!(batch.n_fields, 2);

        let back = batch.into_vectors();
        // `zip` stops at the shorter side, so check lengths explicitly.
        assert_eq!(back.len(), vectors.len());
        for (orig, got) in vectors.iter().zip(back.iter()) {
            assert_eq!(orig.record_a, got.record_a);
            assert_eq!(orig.record_b, got.record_b);
            assert_eq!(orig.levels,   got.levels);
        }
    }

    #[test]
    fn batch_empty_is_valid() {
        let batch = ComparisonBatch::from_vectors(&[]);
        assert_eq!(batch.n_pairs, 0);
        assert_eq!(batch.n_fields, 0);
        assert!(batch.levels.is_empty());
    }

    #[test]
    fn batch_concat_preserves_field_major_layout() {
        let a = ComparisonBatch::from_vectors(&[
            ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact, ComparisonLevel::None]),
            ComparisonVector::new(3, 4, vec![ComparisonLevel::Partial, ComparisonLevel::Close]),
        ]);
        let b = ComparisonBatch::from_vectors(&[
            ComparisonVector::new(5, 6, vec![ComparisonLevel::Close, ComparisonLevel::Exact]),
        ]);
        // Empty chunks are skipped.
        let empty = ComparisonBatch::new(0, 2, vec![]);

        let merged = ComparisonBatch::concat(&[a, empty, b]);
        assert_eq!(merged.n_pairs, 3);
        assert_eq!(merged.n_fields, 2);
        assert_eq!(merged.pair_ids, vec![(1, 2), (3, 4), (5, 6)]);

        // Field 0 for all pairs, then field 1 for all pairs.
        let expected = vec![
            ComparisonLevel::Exact   as u8,
            ComparisonLevel::Partial as u8,
            ComparisonLevel::Close   as u8,
            ComparisonLevel::None    as u8,
            ComparisonLevel::Close   as u8,
            ComparisonLevel::Exact   as u8,
        ];
        assert_eq!(merged.levels, expected);

        let v = merged.pair_as_vector(2);
        assert_eq!(v.record_a, 5);
        assert_eq!(v.record_b, 6);
        assert_eq!(v.levels, vec![ComparisonLevel::Close, ComparisonLevel::Exact]);
    }

    #[test]
    fn batch_concat_of_empty_inputs_is_empty() {
        let merged = ComparisonBatch::concat(&[]);
        assert_eq!(merged.n_pairs, 0);
        assert!(merged.pair_ids.is_empty());
        assert!(merged.levels.is_empty());

        let merged = ComparisonBatch::concat(&[ComparisonBatch::new(0, 3, vec![])]);
        assert_eq!(merged.n_pairs, 0);
        assert!(merged.levels.is_empty());
    }
}