Skip to main content

zer_core/
comparison.rs

1use crate::record::RecordId;
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
4#[repr(u8)]
5pub enum ComparisonLevel {
6    None = 0,
7    Partial = 1,
8    Close = 2,
9    Exact = 3,
10    /// Field structurally absent on one or both sides (cross-schema linkage).
11    /// Never fed to EM; both E-step and M-step skip pairs where any field carries this level.
12    Null = 255,
13}
14
15impl ComparisonLevel {
16    pub fn as_u8(self) -> u8 {
17        self as u8
18    }
19
20    #[inline]
21    pub fn from_u8(v: u8) -> Self {
22        match v {
23            1 => Self::Partial,
24            2 => Self::Close,
25            3 => Self::Exact,
26            255 => Self::Null,
27            _ => Self::None,
28        }
29    }
30}
31
32impl PartialOrd for ComparisonLevel {
33    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
34        Some(self.cmp(other))
35    }
36}
37
38impl Ord for ComparisonLevel {
39    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
40        self.as_u8().cmp(&other.as_u8())
41    }
42}
43
44// ── ComparisonVector (single-pair) ────────────────────────────────────────────
45
46/// Comparison result for a single candidate pair.
47///
48/// Used for single-pair comparisons (`Comparator::compare`) and as the
49/// per-pair view stored in `ScoredPair`.  For batch operations use
50/// `ComparisonBatch`.
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct ComparisonVector {
53    pub record_a: RecordId,
54    pub record_b: RecordId,
55    pub levels: Vec<ComparisonLevel>,
56}
57
58impl ComparisonVector {
59    pub fn new(record_a: RecordId, record_b: RecordId, levels: Vec<ComparisonLevel>) -> Self {
60        Self {
61            record_a,
62            record_b,
63            levels,
64        }
65    }
66}
67
68// ── ComparisonBatch (field-major SoA) ─────────────────────────────────────────
69
70/// Field-major SoA batch of comparison results for many pairs.
71///
72/// # Layout
73///
74/// ```text
75/// levels[field_idx * n_pairs + pair_idx] = ComparisonLevel as u8
76/// ```
77///
78/// All values for field 0 across every pair are contiguous, then field 1, etc.
79#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
80pub struct ComparisonBatch {
81    pub n_pairs: usize,
82    pub n_fields: usize,
83    /// `(record_a_id, record_b_id)` for pair `p`.
84    pub pair_ids: Vec<(RecordId, RecordId)>,
85    /// Field-major levels: `levels[f * n_pairs + p]`.
86    pub levels: Vec<u8>,
87}
88
89impl ComparisonBatch {
90    /// Allocate a zeroed batch (all levels = `ComparisonLevel::None`).
91    pub fn new(n_pairs: usize, n_fields: usize, pair_ids: Vec<(RecordId, RecordId)>) -> Self {
92        Self {
93            n_pairs,
94            n_fields,
95            pair_ids,
96            levels: vec![0u8; n_fields * n_pairs],
97        }
98    }
99
100    /// Read the level for `(field, pair)`.
101    #[inline]
102    pub fn level(&self, field: usize, pair: usize) -> ComparisonLevel {
103        ComparisonLevel::from_u8(self.levels[field * self.n_pairs + pair])
104    }
105
106    /// Write the level for `(field, pair)`.
107    #[inline]
108    pub fn set_level(&mut self, field: usize, pair: usize, level: ComparisonLevel) {
109        self.levels[field * self.n_pairs + pair] = level as u8;
110    }
111
112    /// Reconstruct a `ComparisonVector` for pair `p`.
113    pub fn pair_as_vector(&self, pair_idx: usize) -> ComparisonVector {
114        let (a, b) = self.pair_ids[pair_idx];
115        let levels = (0..self.n_fields)
116            .map(|f| self.level(f, pair_idx))
117            .collect();
118        ComparisonVector::new(a, b, levels)
119    }
120
121    /// Build from an existing `Vec<ComparisonVector>` for migration / tests.
122    pub fn from_vectors(vectors: &[ComparisonVector]) -> Self {
123        if vectors.is_empty() {
124            return Self::new(0, 0, vec![]);
125        }
126        let n_pairs = vectors.len();
127        let n_fields = vectors[0].levels.len();
128        let pair_ids = vectors.iter().map(|v| (v.record_a, v.record_b)).collect();
129        let mut batch = Self::new(n_pairs, n_fields, pair_ids);
130        for (p, v) in vectors.iter().enumerate() {
131            for (f, &level) in v.levels.iter().enumerate() {
132                batch.set_level(f, p, level);
133            }
134        }
135        batch
136    }
137
138    /// Convert back to `Vec<ComparisonVector>` for callers that still need it.
139    pub fn into_vectors(&self) -> Vec<ComparisonVector> {
140        (0..self.n_pairs).map(|p| self.pair_as_vector(p)).collect()
141    }
142
143    /// Concatenate multiple field-major batches (same `n_fields`) into one.
144    ///
145    /// Each chunk may have a different `n_pairs`.  The merged layout remains
146    /// field-major with `n_pairs_total = sum of all chunk n_pairs`.
147    pub fn concat(chunks: &[Self]) -> Self {
148        let chunks: Vec<&Self> = chunks.iter().filter(|c| c.n_pairs > 0).collect();
149        if chunks.is_empty() {
150            return Self::new(0, 0, vec![]);
151        }
152        let n_fields = chunks[0].n_fields;
153        let n_total: usize = chunks.iter().map(|c| c.n_pairs).sum();
154
155        let mut pair_ids = Vec::with_capacity(n_total);
156        let mut levels = vec![0u8; n_fields * n_total];
157
158        let mut offset = 0usize;
159        for chunk in &chunks {
160            pair_ids.extend_from_slice(&chunk.pair_ids);
161            for f in 0..n_fields {
162                let dst = f * n_total + offset;
163                let src = f * chunk.n_pairs;
164                levels[dst..dst + chunk.n_pairs]
165                    .copy_from_slice(&chunk.levels[src..src + chunk.n_pairs]);
166            }
167            offset += chunk.n_pairs;
168        }
169
170        Self {
171            n_pairs: n_total,
172            n_fields,
173            pair_ids,
174            levels,
175        }
176    }
177}
178
179// ── Tests ─────────────────────────────────────────────────────────────────────
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn comparison_level_ordering() {
        assert!(ComparisonLevel::Exact > ComparisonLevel::Close);
        assert!(ComparisonLevel::Close > ComparisonLevel::Partial);
        assert!(ComparisonLevel::Partial > ComparisonLevel::None);
    }

    /// `Null` orders by its discriminant (255), i.e. above every real level.
    /// Pinned so that any change to this surprising ordering is deliberate.
    #[test]
    fn null_sorts_above_exact() {
        assert!(ComparisonLevel::Null > ComparisonLevel::Exact);
        let levels = [
            ComparisonLevel::Partial,
            ComparisonLevel::Null,
            ComparisonLevel::Exact,
        ];
        assert_eq!(levels.iter().copied().max(), Some(ComparisonLevel::Null));
    }

    #[test]
    fn comparison_level_repr_values() {
        assert_eq!(ComparisonLevel::Exact.as_u8(), 3);
        assert_eq!(ComparisonLevel::Close.as_u8(), 2);
        assert_eq!(ComparisonLevel::Partial.as_u8(), 1);
        assert_eq!(ComparisonLevel::None.as_u8(), 0);
        assert_eq!(ComparisonLevel::Null.as_u8(), 255);
    }

    #[test]
    fn comparison_level_round_trip() {
        for &l in &[
            ComparisonLevel::None,
            ComparisonLevel::Partial,
            ComparisonLevel::Close,
            ComparisonLevel::Exact,
            ComparisonLevel::Null,
        ] {
            assert_eq!(ComparisonLevel::from_u8(l.as_u8()), l);
        }
        assert_eq!(ComparisonLevel::from_u8(99), ComparisonLevel::None);
    }

    #[test]
    fn batch_field_major_layout() {
        // n_fields=2, n_pairs=3
        let pair_ids = vec![(1, 2), (3, 4), (5, 6)];
        let mut batch = ComparisonBatch::new(3, 2, pair_ids);

        // Set levels in a known pattern
        batch.set_level(0, 0, ComparisonLevel::Exact); // field 0, pair 0
        batch.set_level(0, 1, ComparisonLevel::Close); // field 0, pair 1
        batch.set_level(0, 2, ComparisonLevel::Partial); // field 0, pair 2
        batch.set_level(1, 0, ComparisonLevel::None); // field 1, pair 0
        batch.set_level(1, 1, ComparisonLevel::Exact); // field 1, pair 1
        batch.set_level(1, 2, ComparisonLevel::Close); // field 1, pair 2

        // Field-major: field 0 values at indices 0,1,2; field 1 at 3,4,5
        assert_eq!(batch.levels[0], ComparisonLevel::Exact as u8);
        assert_eq!(batch.levels[1], ComparisonLevel::Close as u8);
        assert_eq!(batch.levels[2], ComparisonLevel::Partial as u8);
        assert_eq!(batch.levels[3], ComparisonLevel::None as u8);
        assert_eq!(batch.levels[4], ComparisonLevel::Exact as u8);
        assert_eq!(batch.levels[5], ComparisonLevel::Close as u8);

        // pair_as_vector reconstructs correctly
        let v = batch.pair_as_vector(1); // pair index 1
        assert_eq!(v.record_a, 3);
        assert_eq!(v.record_b, 4);
        assert_eq!(
            v.levels,
            vec![ComparisonLevel::Close, ComparisonLevel::Exact]
        );
    }

    #[test]
    fn batch_from_vectors_round_trips() {
        let vectors = vec![
            ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact, ComparisonLevel::None]),
            ComparisonVector::new(3, 4, vec![ComparisonLevel::Partial, ComparisonLevel::Close]),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        assert_eq!(batch.n_pairs, 2);
        assert_eq!(batch.n_fields, 2);

        let back = batch.into_vectors();
        for (orig, got) in vectors.iter().zip(back.iter()) {
            assert_eq!(orig.record_a, got.record_a);
            assert_eq!(orig.record_b, got.record_b);
            assert_eq!(orig.levels, got.levels);
        }
    }

    #[test]
    fn batch_empty_is_valid() {
        let batch = ComparisonBatch::from_vectors(&[]);
        assert_eq!(batch.n_pairs, 0);
        assert_eq!(batch.n_fields, 0);
        assert!(batch.levels.is_empty());
    }

    /// Concatenation must interleave per field, not append whole chunks:
    /// field 0 of every chunk first, then field 1 of every chunk.
    #[test]
    fn batch_concat_merges_field_major() {
        let a = ComparisonBatch::from_vectors(&[ComparisonVector::new(
            1,
            2,
            vec![ComparisonLevel::Exact, ComparisonLevel::None],
        )]);
        let b = ComparisonBatch::from_vectors(&[
            ComparisonVector::new(3, 4, vec![ComparisonLevel::Partial, ComparisonLevel::Close]),
            ComparisonVector::new(5, 6, vec![ComparisonLevel::Null, ComparisonLevel::Exact]),
        ]);
        let empty = ComparisonBatch::new(0, 2, vec![]);

        let merged = ComparisonBatch::concat(&[a, empty, b]);
        assert_eq!(merged.n_pairs, 3);
        assert_eq!(merged.n_fields, 2);
        assert_eq!(merged.pair_ids, vec![(1, 2), (3, 4), (5, 6)]);
        assert_eq!(merged.levels, vec![3u8, 1, 255, 0, 2, 3]);
        assert_eq!(merged.level(1, 2), ComparisonLevel::Exact);
        assert_eq!(merged.level(0, 2), ComparisonLevel::Null);
    }

    #[test]
    fn batch_concat_of_empty_chunks_is_empty() {
        let merged = ComparisonBatch::concat(&[ComparisonBatch::new(0, 2, vec![])]);
        assert_eq!(merged.n_pairs, 0);
        assert!(merged.levels.is_empty());
        assert!(merged.pair_ids.is_empty());

        let none = ComparisonBatch::concat(&[]);
        assert_eq!(none.n_pairs, 0);
        assert_eq!(none.n_fields, 0);
    }
}