1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::RwLock;
4
5use crate::{
6 comparison::{ComparisonBatch, ComparisonVector},
7 entity::{Entity, EntityId},
8 error::ZerError,
9 record::{Record, RecordId},
10 record_pool::RecordPool,
11 schema::Schema,
12 scoring::{ModelParams, ScoredPair},
13};
14
15pub type Result<T> = std::result::Result<T, ZerError>;
16
17pub trait RecordStore: Send + Sync {
21 fn insert(&self, record: Record);
23
24 fn get(&self, id: RecordId) -> Option<Cow<'_, Record>>;
26
27 fn get_many(&self, ids: &[RecordId]) -> Vec<Option<Cow<'_, Record>>> {
30 ids.iter().map(|id| self.get(*id)).collect()
31 }
32
33 fn len(&self) -> usize;
35
36 fn is_empty(&self) -> bool {
37 self.len() == 0
38 }
39}
40
/// Lock-protected state of [`VecRecordStore`].
struct VecRecordStoreInner {
    /// Stored records, in the order they were pushed.
    records: Vec<Record>,
    /// Maps a record id to its position in `records`.
    id_to_idx: HashMap<RecordId, usize>,
}
47
/// In-memory [`RecordStore`] backed by a `Vec` plus an id → index `HashMap`.
///
/// Both collections sit behind a single `RwLock`, which is what lets the
/// `&self` methods of [`RecordStore`] mutate the store.
pub struct VecRecordStore {
    /// All mutable state; read paths take the read lock, `insert` takes the write lock.
    inner: RwLock<VecRecordStoreInner>,
}
52
53impl VecRecordStore {
54 pub fn new() -> Self {
55 Self {
56 inner: RwLock::new(VecRecordStoreInner {
57 records: Vec::new(),
58 id_to_idx: HashMap::new(),
59 }),
60 }
61 }
62}
63
64impl Default for VecRecordStore {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl RecordStore for VecRecordStore {
71 fn insert(&self, record: Record) {
72 let mut inner = self.inner.write().unwrap();
73 let idx = inner.records.len();
74 inner.id_to_idx.insert(record.id, idx);
75 inner.records.push(record);
76 }
77
78 fn get(&self, id: RecordId) -> Option<Cow<'_, Record>> {
79 let inner = self.inner.read().unwrap();
80 let idx = *inner.id_to_idx.get(&id)?;
81 Some(Cow::Owned(inner.records[idx].clone()))
82 }
83
84 fn len(&self) -> usize {
85 self.inner.read().unwrap().records.len()
86 }
87}
88
/// Index from blocking keys to record ids, populated and queried by a [`Blocker`].
pub trait BlockIndex: Send + Sync {
    /// Registers `record_id` under each of `keys`.
    fn insert(&mut self, record_id: RecordId, keys: Vec<String>);

    /// Returns ids registered under any of `keys`, omitting `exclude`.
    ///
    /// NOTE(review): by its name this is the union across keys; confirm whether
    /// implementations deduplicate ids and whether the order is meaningful.
    fn lookup_union(&self, keys: &[String], exclude: RecordId) -> Vec<RecordId>;

    /// Removes `record_id` from the index.
    fn remove(&mut self, record_id: RecordId);

    /// Exposes the concrete index as [`std::any::Any`] so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any;
    /// Mutable counterpart of [`BlockIndex::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
109
/// Strategy that derives blocking keys from records and uses a [`BlockIndex`]
/// to find candidate pairs.
pub trait Blocker: Send + Sync {
    /// Computes the blocking keys for `record` under `schema`.
    fn blocking_keys(&self, record: &Record, schema: &Schema) -> Vec<String>;
    /// Adds `record` to `index` (presumably under its blocking keys — confirm per impl).
    fn index_record(&self, record: &Record, schema: &Schema, index: &mut dyn BlockIndex);
    /// Returns ids from `index` that `record` should be compared against.
    fn candidates(&self, record: &Record, schema: &Schema, index: &dyn BlockIndex)
        -> Vec<RecordId>;
}
117
118pub trait Comparator: Send + Sync {
119 fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector;
121
122 fn compare_batch_from_pool(
129 &self,
130 pool: &RecordPool,
131 indices: &[(usize, usize)],
132 schema: &Schema,
133 ) -> ComparisonBatch {
134 let n_pairs = indices.len();
135 let n_fields = schema.fields.len();
136 if n_pairs == 0 {
137 return ComparisonBatch::new(0, n_fields, vec![]);
138 }
139 let pair_ids: Vec<(u64, u64)> = indices
140 .iter()
141 .map(|&(i, j)| (pool.ids[i], pool.ids[j]))
142 .collect();
143 let mut batch = ComparisonBatch::new(n_pairs, n_fields, pair_ids);
144 for (p, &(i, j)) in indices.iter().enumerate() {
145 use crate::record::FieldValue;
146 let mut a = Record::new(pool.ids[i]);
147 let mut b = Record::new(pool.ids[j]);
148 for (f, field) in schema.fields.iter().enumerate() {
149 let va = pool.get(f, i);
150 let vb = pool.get(f, j);
151 if !va.is_empty() {
152 a = a.insert(&field.name, FieldValue::Text(va.to_string()));
153 }
154 if !vb.is_empty() {
155 b = b.insert(&field.name, FieldValue::Text(vb.to_string()));
156 }
157 }
158 let v = self.compare(&a, &b, schema);
159 for (f, &level) in v.levels.iter().enumerate() {
160 batch.set_level(f, p, level);
161 }
162 }
163 batch
164 }
165}
166
167pub trait Scorer: Send + Sync {
168 fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair;
170
171 fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
173 (0..batch.n_pairs)
174 .map(|p| self.score(&batch.pair_as_vector(p), params))
175 .collect()
176 }
177
178 fn estimate_params(
179 &self,
180 batch: &ComparisonBatch,
181 init: Option<ModelParams>,
182 max_iter: usize,
183 ) -> Result<ModelParams>;
184}
185
/// Groups scored pairs into entities.
pub trait Clusterer: Send + Sync {
    /// Builds entities from `pairs`, using `params` (e.g. thresholds — confirm per impl).
    fn cluster(&self, pairs: &[ScoredPair], params: &ModelParams) -> Vec<Entity>;
}
190
/// Persistent storage for resolved [`Entity`] values.
pub trait EntityStore: Send + Sync {
    /// Inserts or updates `entity`, returning its id.
    fn upsert_entity(&self, entity: &Entity) -> Result<EntityId>;
    /// Fetches the entity with `id`.
    ///
    /// NOTE(review): whether a missing id yields `Err` is implementation-defined here.
    fn get_entity(&self, id: EntityId) -> Result<Entity>;
    /// Returns the entity `record_id` belongs to; presumably `Ok(None)` when it is
    /// not assigned to any entity.
    fn record_to_entity(&self, record_id: RecordId) -> Result<Option<EntityId>>;
    /// Returns every stored entity.
    fn all_entities(&self) -> Result<Vec<Entity>>;
}
198
/// Conversion of caller-supplied values into a [`Record`].
pub trait IntoRecord {
    /// Consumes `self` and produces a [`Record`]; `id` is the identifier offered
    /// by the caller, and each implementation decides how to use it.
    fn into_record(self, id: RecordId) -> Record;
}
207
208impl IntoRecord for Record {
209 fn into_record(self, _id: RecordId) -> Record {
210 self
211 }
212}
213
/// Secondary reviewer that issues a [`JudgeVerdict`] on scored pairs.
pub trait Judge: Send + Sync {
    /// Reviews `pairs`; presumably returns one verdict per pair, in order (confirm per impl).
    fn adjudicate(&self, pairs: &[ScoredPair]) -> Result<Vec<JudgeVerdict>>;
}
218
219impl<J: Judge + ?Sized> Judge for Box<J> {
220 fn adjudicate(&self, pairs: &[ScoredPair]) -> Result<Vec<JudgeVerdict>> {
221 (**self).adjudicate(pairs)
222 }
223}
224
225#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
226pub enum JudgeVerdict {
227 IncreaseConfidence,
228 DecreaseConfidence,
229 NoChange,
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::{FieldValue, Record};
    use std::sync::Arc;

    /// Compile-time check: naming `dyn Trait` fails to compile unless the trait
    /// is object safe.
    #[test]
    fn traits_are_object_safe() {
        fn object_safe<T: ?Sized>() {}
        object_safe::<dyn BlockIndex>();
        object_safe::<dyn Blocker>();
        object_safe::<dyn Comparator>();
        object_safe::<dyn Scorer>();
        object_safe::<dyn Clusterer>();
        object_safe::<dyn EntityStore>();
        object_safe::<dyn Judge>();
        object_safe::<dyn RecordStore>();
    }

    #[test]
    fn vec_record_store_insert_and_get() {
        let store = VecRecordStore::new();
        assert!(store.is_empty());

        store.insert(Record::new(42).insert("name", FieldValue::Text("Alice".into())));
        assert_eq!(store.len(), 1);

        let fetched = store.get(42).expect("record 42 must exist");
        assert_eq!(fetched.id, 42);
    }

    #[test]
    fn vec_record_store_get_missing_returns_none() {
        assert!(VecRecordStore::new().get(999).is_none());
    }

    #[test]
    fn vec_record_store_get_many() {
        let store = VecRecordStore::new();
        for id in [1, 2, 3] {
            store.insert(Record::new(id));
        }

        let present: Vec<bool> = store
            .get_many(&[1, 3, 99])
            .iter()
            .map(Option::is_some)
            .collect();
        assert_eq!(present, [true, true, false]);
    }

    #[test]
    fn vec_record_store_is_sendable() {
        let store: Arc<dyn RecordStore> = Arc::new(VecRecordStore::new());
        let worker_store = Arc::clone(&store);
        std::thread::spawn(move || worker_store.insert(Record::new(7)))
            .join()
            .unwrap();
        assert_eq!(store.len(), 1);
    }
}