1use crate::record::RecordId;
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash,
4 serde::Serialize, serde::Deserialize)]
5#[repr(u8)]
6pub enum ComparisonLevel {
7 None = 0,
8 Partial = 1,
9 Close = 2,
10 Exact = 3,
11 Null = 255,
14}
15
16impl ComparisonLevel {
17 pub fn as_u8(self) -> u8 {
18 self as u8
19 }
20
21 #[inline]
22 pub fn from_u8(v: u8) -> Self {
23 match v {
24 1 => Self::Partial,
25 2 => Self::Close,
26 3 => Self::Exact,
27 255 => Self::Null,
28 _ => Self::None,
29 }
30 }
31}
32
33impl PartialOrd for ComparisonLevel {
34 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
35 Some(self.cmp(other))
36 }
37}
38
39impl Ord for ComparisonLevel {
40 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
41 self.as_u8().cmp(&other.as_u8())
42 }
43}
44
45#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53pub struct ComparisonVector {
54 pub record_a: RecordId,
55 pub record_b: RecordId,
56 pub levels: Vec<ComparisonLevel>,
57}
58
59impl ComparisonVector {
60 pub fn new(record_a: RecordId, record_b: RecordId, levels: Vec<ComparisonLevel>) -> Self {
61 Self { record_a, record_b, levels }
62 }
63}
64
65#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
77pub struct ComparisonBatch {
78 pub n_pairs: usize,
79 pub n_fields: usize,
80 pub pair_ids: Vec<(RecordId, RecordId)>,
82 pub levels: Vec<u8>,
84}
85
86impl ComparisonBatch {
87 pub fn new(n_pairs: usize, n_fields: usize, pair_ids: Vec<(RecordId, RecordId)>) -> Self {
89 Self {
90 n_pairs,
91 n_fields,
92 pair_ids,
93 levels: vec![0u8; n_fields * n_pairs],
94 }
95 }
96
97 #[inline]
99 pub fn level(&self, field: usize, pair: usize) -> ComparisonLevel {
100 ComparisonLevel::from_u8(self.levels[field * self.n_pairs + pair])
101 }
102
103 #[inline]
105 pub fn set_level(&mut self, field: usize, pair: usize, level: ComparisonLevel) {
106 self.levels[field * self.n_pairs + pair] = level as u8;
107 }
108
109 pub fn pair_as_vector(&self, pair_idx: usize) -> ComparisonVector {
111 let (a, b) = self.pair_ids[pair_idx];
112 let levels = (0..self.n_fields)
113 .map(|f| self.level(f, pair_idx))
114 .collect();
115 ComparisonVector::new(a, b, levels)
116 }
117
118 pub fn from_vectors(vectors: &[ComparisonVector]) -> Self {
120 if vectors.is_empty() {
121 return Self::new(0, 0, vec![]);
122 }
123 let n_pairs = vectors.len();
124 let n_fields = vectors[0].levels.len();
125 let pair_ids = vectors.iter().map(|v| (v.record_a, v.record_b)).collect();
126 let mut batch = Self::new(n_pairs, n_fields, pair_ids);
127 for (p, v) in vectors.iter().enumerate() {
128 for (f, &level) in v.levels.iter().enumerate() {
129 batch.set_level(f, p, level);
130 }
131 }
132 batch
133 }
134
135 pub fn into_vectors(&self) -> Vec<ComparisonVector> {
137 (0..self.n_pairs).map(|p| self.pair_as_vector(p)).collect()
138 }
139
140 pub fn concat(chunks: &[Self]) -> Self {
145 let chunks: Vec<&Self> = chunks.iter().filter(|c| c.n_pairs > 0).collect();
146 if chunks.is_empty() {
147 return Self::new(0, 0, vec![]);
148 }
149 let n_fields = chunks[0].n_fields;
150 let n_total: usize = chunks.iter().map(|c| c.n_pairs).sum();
151
152 let mut pair_ids = Vec::with_capacity(n_total);
153 let mut levels = vec![0u8; n_fields * n_total];
154
155 let mut offset = 0usize;
156 for chunk in &chunks {
157 pair_ids.extend_from_slice(&chunk.pair_ids);
158 for f in 0..n_fields {
159 let dst = f * n_total + offset;
160 let src = f * chunk.n_pairs;
161 levels[dst..dst + chunk.n_pairs]
162 .copy_from_slice(&chunk.levels[src..src + chunk.n_pairs]);
163 }
164 offset += chunk.n_pairs;
165 }
166
167 Self { n_pairs: n_total, n_fields, pair_ids, levels }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn comparison_level_ordering() {
        assert!(ComparisonLevel::Exact > ComparisonLevel::Close);
        assert!(ComparisonLevel::Close > ComparisonLevel::Partial);
        assert!(ComparisonLevel::Partial > ComparisonLevel::None);
    }

    #[test]
    fn comparison_level_repr_values() {
        assert_eq!(ComparisonLevel::Exact.as_u8(), 3);
        assert_eq!(ComparisonLevel::Close.as_u8(), 2);
        assert_eq!(ComparisonLevel::Partial.as_u8(), 1);
        assert_eq!(ComparisonLevel::None.as_u8(), 0);
    }

    #[test]
    fn comparison_level_round_trip() {
        for &l in &[
            ComparisonLevel::None,
            ComparisonLevel::Partial,
            ComparisonLevel::Close,
            ComparisonLevel::Exact,
            ComparisonLevel::Null,
        ] {
            assert_eq!(ComparisonLevel::from_u8(l.as_u8()), l);
        }
        // Unknown bytes decode leniently to `None`.
        assert_eq!(ComparisonLevel::from_u8(99), ComparisonLevel::None);
    }

    #[test]
    fn batch_field_major_layout() {
        let pair_ids = vec![(1, 2), (3, 4), (5, 6)];
        let mut batch = ComparisonBatch::new(3, 2, pair_ids);

        // Field 0 for every pair, then field 1 for every pair.
        batch.set_level(0, 0, ComparisonLevel::Exact);
        batch.set_level(0, 1, ComparisonLevel::Close);
        batch.set_level(0, 2, ComparisonLevel::Partial);
        batch.set_level(1, 0, ComparisonLevel::None);
        batch.set_level(1, 1, ComparisonLevel::Exact);
        batch.set_level(1, 2, ComparisonLevel::Close);

        // Each field's column is contiguous in `levels`.
        assert_eq!(batch.levels[0], ComparisonLevel::Exact as u8);
        assert_eq!(batch.levels[1], ComparisonLevel::Close as u8);
        assert_eq!(batch.levels[2], ComparisonLevel::Partial as u8);
        assert_eq!(batch.levels[3], ComparisonLevel::None as u8);
        assert_eq!(batch.levels[4], ComparisonLevel::Exact as u8);
        assert_eq!(batch.levels[5], ComparisonLevel::Close as u8);

        let v = batch.pair_as_vector(1);
        assert_eq!(v.record_a, 3);
        assert_eq!(v.record_b, 4);
        assert_eq!(v.levels, vec![ComparisonLevel::Close, ComparisonLevel::Exact]);
    }

    #[test]
    fn batch_from_vectors_round_trips() {
        let vectors = vec![
            ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact, ComparisonLevel::None]),
            ComparisonVector::new(3, 4, vec![ComparisonLevel::Partial, ComparisonLevel::Close]),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        assert_eq!(batch.n_pairs, 2);
        assert_eq!(batch.n_fields, 2);

        let back = batch.into_vectors();
        // `zip` stops at the shorter side, so pin the length first.
        assert_eq!(back.len(), vectors.len());
        for (orig, got) in vectors.iter().zip(back.iter()) {
            assert_eq!(orig.record_a, got.record_a);
            assert_eq!(orig.record_b, got.record_b);
            assert_eq!(orig.levels, got.levels);
        }
    }

    #[test]
    fn batch_empty_is_valid() {
        let batch = ComparisonBatch::from_vectors(&[]);
        assert_eq!(batch.n_pairs, 0);
        assert_eq!(batch.n_fields, 0);
        assert!(batch.levels.is_empty());
    }

    #[test]
    fn batch_concat_joins_columns_and_skips_empty_chunks() {
        let first = ComparisonBatch::from_vectors(&[ComparisonVector::new(
            1,
            2,
            vec![ComparisonLevel::Exact, ComparisonLevel::None],
        )]);
        let second = ComparisonBatch::from_vectors(&[
            ComparisonVector::new(3, 4, vec![ComparisonLevel::Close, ComparisonLevel::Partial]),
            ComparisonVector::new(5, 6, vec![ComparisonLevel::Null, ComparisonLevel::Exact]),
        ]);
        // A zero-pair, zero-field chunk must not disturb the merged field count.
        let empty = ComparisonBatch::from_vectors(&[]);

        let merged = ComparisonBatch::concat(&[empty, first, second]);
        assert_eq!(merged.n_pairs, 3);
        assert_eq!(merged.n_fields, 2);
        assert_eq!(merged.pair_ids, vec![(1, 2), (3, 4), (5, 6)]);
        // Field 0: Exact, Close, Null; field 1: None, Partial, Exact.
        assert_eq!(merged.levels, vec![3, 2, 255, 0, 1, 3]);
        assert_eq!(
            merged.pair_as_vector(2).levels,
            vec![ComparisonLevel::Null, ComparisonLevel::Exact]
        );
    }

    #[test]
    fn batch_concat_of_no_chunks_is_empty() {
        let merged = ComparisonBatch::concat(&[]);
        assert_eq!(merged.n_pairs, 0);
        assert_eq!(merged.n_fields, 0);
        assert!(merged.pair_ids.is_empty());
        assert!(merged.levels.is_empty());
    }
}