1use crate::record::RecordId;
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
4#[repr(u8)]
5pub enum ComparisonLevel {
6 None = 0,
7 Partial = 1,
8 Close = 2,
9 Exact = 3,
10 Null = 255,
13}
14
15impl ComparisonLevel {
16 pub fn as_u8(self) -> u8 {
17 self as u8
18 }
19
20 #[inline]
21 pub fn from_u8(v: u8) -> Self {
22 match v {
23 1 => Self::Partial,
24 2 => Self::Close,
25 3 => Self::Exact,
26 255 => Self::Null,
27 _ => Self::None,
28 }
29 }
30}
31
32impl PartialOrd for ComparisonLevel {
33 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
34 Some(self.cmp(other))
35 }
36}
37
38impl Ord for ComparisonLevel {
39 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
40 self.as_u8().cmp(&other.as_u8())
41 }
42}
43
44#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct ComparisonVector {
53 pub record_a: RecordId,
54 pub record_b: RecordId,
55 pub levels: Vec<ComparisonLevel>,
56}
57
58impl ComparisonVector {
59 pub fn new(record_a: RecordId, record_b: RecordId, levels: Vec<ComparisonLevel>) -> Self {
60 Self {
61 record_a,
62 record_b,
63 levels,
64 }
65 }
66}
67
68#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
80pub struct ComparisonBatch {
81 pub n_pairs: usize,
82 pub n_fields: usize,
83 pub pair_ids: Vec<(RecordId, RecordId)>,
85 pub levels: Vec<u8>,
87}
88
89impl ComparisonBatch {
90 pub fn new(n_pairs: usize, n_fields: usize, pair_ids: Vec<(RecordId, RecordId)>) -> Self {
92 Self {
93 n_pairs,
94 n_fields,
95 pair_ids,
96 levels: vec![0u8; n_fields * n_pairs],
97 }
98 }
99
100 #[inline]
102 pub fn level(&self, field: usize, pair: usize) -> ComparisonLevel {
103 ComparisonLevel::from_u8(self.levels[field * self.n_pairs + pair])
104 }
105
106 #[inline]
108 pub fn set_level(&mut self, field: usize, pair: usize, level: ComparisonLevel) {
109 self.levels[field * self.n_pairs + pair] = level as u8;
110 }
111
112 pub fn pair_as_vector(&self, pair_idx: usize) -> ComparisonVector {
114 let (a, b) = self.pair_ids[pair_idx];
115 let levels = (0..self.n_fields)
116 .map(|f| self.level(f, pair_idx))
117 .collect();
118 ComparisonVector::new(a, b, levels)
119 }
120
121 pub fn from_vectors(vectors: &[ComparisonVector]) -> Self {
123 if vectors.is_empty() {
124 return Self::new(0, 0, vec![]);
125 }
126 let n_pairs = vectors.len();
127 let n_fields = vectors[0].levels.len();
128 let pair_ids = vectors.iter().map(|v| (v.record_a, v.record_b)).collect();
129 let mut batch = Self::new(n_pairs, n_fields, pair_ids);
130 for (p, v) in vectors.iter().enumerate() {
131 for (f, &level) in v.levels.iter().enumerate() {
132 batch.set_level(f, p, level);
133 }
134 }
135 batch
136 }
137
138 pub fn into_vectors(&self) -> Vec<ComparisonVector> {
140 (0..self.n_pairs).map(|p| self.pair_as_vector(p)).collect()
141 }
142
143 pub fn concat(chunks: &[Self]) -> Self {
148 let chunks: Vec<&Self> = chunks.iter().filter(|c| c.n_pairs > 0).collect();
149 if chunks.is_empty() {
150 return Self::new(0, 0, vec![]);
151 }
152 let n_fields = chunks[0].n_fields;
153 let n_total: usize = chunks.iter().map(|c| c.n_pairs).sum();
154
155 let mut pair_ids = Vec::with_capacity(n_total);
156 let mut levels = vec![0u8; n_fields * n_total];
157
158 let mut offset = 0usize;
159 for chunk in &chunks {
160 pair_ids.extend_from_slice(&chunk.pair_ids);
161 for f in 0..n_fields {
162 let dst = f * n_total + offset;
163 let src = f * chunk.n_pairs;
164 levels[dst..dst + chunk.n_pairs]
165 .copy_from_slice(&chunk.levels[src..src + chunk.n_pairs]);
166 }
167 offset += chunk.n_pairs;
168 }
169
170 Self {
171 n_pairs: n_total,
172 n_fields,
173 pair_ids,
174 levels,
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// The graded levels compare in ascending order `None < Partial < Close < Exact`.
    #[test]
    fn comparison_level_ordering() {
        let ascending = [
            ComparisonLevel::None,
            ComparisonLevel::Partial,
            ComparisonLevel::Close,
            ComparisonLevel::Exact,
        ];
        // Each neighbouring pair must be strictly increasing.
        for step in ascending.windows(2) {
            assert!(step[1] > step[0]);
        }
    }

    /// The byte encoding of the graded levels is pinned to `0..=3`.
    #[test]
    fn comparison_level_repr_values() {
        let expected: [(ComparisonLevel, u8); 4] = [
            (ComparisonLevel::Exact, 3),
            (ComparisonLevel::Close, 2),
            (ComparisonLevel::Partial, 1),
            (ComparisonLevel::None, 0),
        ];
        for &(level, byte) in expected.iter() {
            assert_eq!(level.as_u8(), byte);
        }
    }

    /// Every variant survives `as_u8` followed by `from_u8`; unknown bytes decode to `None`.
    #[test]
    fn comparison_level_round_trip() {
        let all = [
            ComparisonLevel::None,
            ComparisonLevel::Partial,
            ComparisonLevel::Close,
            ComparisonLevel::Exact,
            ComparisonLevel::Null,
        ];
        for level in all.iter().copied() {
            assert_eq!(ComparisonLevel::from_u8(level.as_u8()), level);
        }
        assert_eq!(ComparisonLevel::from_u8(99), ComparisonLevel::None);
    }

    /// Cells are stored field-major: all pairs of field 0 first, then all pairs of field 1.
    #[test]
    fn batch_field_major_layout() {
        let mut batch = ComparisonBatch::new(3, 2, vec![(1, 2), (3, 4), (5, 6)]);

        // One row per field, one column per pair.
        let by_field = [
            [
                ComparisonLevel::Exact,
                ComparisonLevel::Close,
                ComparisonLevel::Partial,
            ],
            [
                ComparisonLevel::None,
                ComparisonLevel::Exact,
                ComparisonLevel::Close,
            ],
        ];
        for (field, row) in by_field.iter().enumerate() {
            for (pair, &level) in row.iter().enumerate() {
                batch.set_level(field, pair, level);
            }
        }

        // The raw buffer must hold the rows back to back.
        let flat = [
            ComparisonLevel::Exact,
            ComparisonLevel::Close,
            ComparisonLevel::Partial,
            ComparisonLevel::None,
            ComparisonLevel::Exact,
            ComparisonLevel::Close,
        ];
        for (offset, level) in flat.iter().enumerate() {
            assert_eq!(batch.levels[offset], *level as u8);
        }

        // Pair 1 is the middle column of both rows.
        let v = batch.pair_as_vector(1);
        assert_eq!(v.record_a, 3);
        assert_eq!(v.record_b, 4);
        assert_eq!(
            v.levels,
            vec![ComparisonLevel::Close, ComparisonLevel::Exact]
        );
    }

    /// Converting vectors to a batch and back preserves ids and levels.
    #[test]
    fn batch_from_vectors_round_trips() {
        let vectors = vec![
            ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact, ComparisonLevel::None]),
            ComparisonVector::new(3, 4, vec![ComparisonLevel::Partial, ComparisonLevel::Close]),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        assert_eq!(batch.n_pairs, 2);
        assert_eq!(batch.n_fields, 2);

        let back = batch.into_vectors();
        for (orig, got) in vectors.iter().zip(&back) {
            assert_eq!(orig.record_a, got.record_a);
            assert_eq!(orig.record_b, got.record_b);
            assert_eq!(orig.levels, got.levels);
        }
    }

    /// An empty input slice produces an empty, zero-dimensional batch.
    #[test]
    fn batch_empty_is_valid() {
        let batch = ComparisonBatch::from_vectors(&[]);
        assert_eq!(batch.n_pairs, 0);
        assert_eq!(batch.n_fields, 0);
        assert!(batch.levels.is_empty());
    }
}