Skip to main content

zer_core/
traits.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::RwLock;
4
5use crate::{
6    comparison::{ComparisonBatch, ComparisonVector},
7    entity::{Entity, EntityId},
8    error::ZerError,
9    record::{Record, RecordId},
10    record_pool::RecordPool,
11    schema::Schema,
12    scoring::{ModelParams, ScoredPair},
13};
14
/// Crate-wide result alias: a standard `Result` whose error type is fixed to
/// [`ZerError`].
pub type Result<T> = std::result::Result<T, ZerError>;
16
17// ── RecordStore ───────────────────────────────────────────────────────────────
18
19/// Backing store for records used during ingestion and batch runs.
20pub trait RecordStore: Send + Sync {
21    /// Persist a record.  Must be callable from the ingester background task.
22    fn insert(&self, record: Record);
23
24    /// Retrieve a single record by ID.  Returns `None` if not present.
25    fn get(&self, id: RecordId) -> Option<Cow<'_, Record>>;
26
27    /// Retrieve multiple records in one call (allows batch I/O optimisation).
28    /// The default impl calls `get` in a loop; override for bulk reads.
29    fn get_many(&self, ids: &[RecordId]) -> Vec<Option<Cow<'_, Record>>> {
30        ids.iter().map(|id| self.get(*id)).collect()
31    }
32
33    /// Total number of records held.
34    fn len(&self) -> usize;
35
36    fn is_empty(&self) -> bool {
37        self.len() == 0
38    }
39}
40
41// ── VecRecordStore (default in-memory impl) ───────────────────────────────────
42
/// State of a [`VecRecordStore`], kept together so a single lock guards both
/// fields.
struct VecRecordStoreInner {
    /// The stored records.
    records: Vec<Record>,
    /// Position in `records` of the record with a given ID.
    id_to_idx: HashMap<RecordId, usize>,
}
47
/// Default in-memory [`RecordStore`] backed by a `Vec`, zero-config.
///
/// All state lives behind one `RwLock`, which is what lets the `&self`
/// methods of [`RecordStore`] mutate it.
pub struct VecRecordStore {
    /// Record vector plus ID index, guarded by a single reader-writer lock.
    inner: RwLock<VecRecordStoreInner>,
}
52
53impl VecRecordStore {
54    pub fn new() -> Self {
55        Self {
56            inner: RwLock::new(VecRecordStoreInner {
57                records: Vec::new(),
58                id_to_idx: HashMap::new(),
59            }),
60        }
61    }
62}
63
64impl Default for VecRecordStore {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl RecordStore for VecRecordStore {
71    fn insert(&self, record: Record) {
72        let mut inner = self.inner.write().unwrap();
73        let idx = inner.records.len();
74        inner.id_to_idx.insert(record.id, idx);
75        inner.records.push(record);
76    }
77
78    fn get(&self, id: RecordId) -> Option<Cow<'_, Record>> {
79        let inner = self.inner.read().unwrap();
80        let idx = *inner.id_to_idx.get(&id)?;
81        Some(Cow::Owned(inner.records[idx].clone()))
82    }
83
84    fn len(&self) -> usize {
85        self.inner.read().unwrap().records.len()
86    }
87}
88
89// ── BlockIndex ────────────────────────────────────────────────────────────────
90
/// Opaque blocking index.
///
/// The `as_any` / `as_any_mut` escape hatches allow access to concrete fields
/// not covered by the trait, such as index statistics.
pub trait BlockIndex: Send + Sync {
    /// Index `record_id` under the given set of blocking keys.
    fn insert(&mut self, record_id: RecordId, keys: Vec<String>);

    /// Return all record IDs sharing at least one key with `keys`, excluding
    /// `exclude` (the querying record itself).  Result must be deduplicated.
    fn lookup_union(&self, keys: &[String], exclude: RecordId) -> Vec<RecordId>;

    /// Remove all index entries for `record_id`.
    fn remove(&mut self, record_id: RecordId);

    /// Expose the index as `&dyn Any` so callers can downcast to the
    /// concrete index type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Mutable counterpart of [`BlockIndex::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
109
/// Extracts blocking keys from records and looks up candidates in an index.
pub trait Blocker: Send + Sync {
    /// Compute the blocking keys of `record` under `schema`.
    fn blocking_keys(&self, record: &Record, schema: &Schema) -> Vec<String>;

    /// Add `record` to `index`.
    ///
    /// NOTE(review): presumably indexes the record under the keys from
    /// [`Blocker::blocking_keys`]; the trait itself does not enforce this.
    fn index_record(&self, record: &Record, schema: &Schema, index: &mut dyn BlockIndex);

    /// Return the IDs of records in `index` that are candidate matches for
    /// `record`.
    fn candidates(&self, record: &Record, schema: &Schema, index: &dyn BlockIndex)
        -> Vec<RecordId>;
}
117
118pub trait Comparator: Send + Sync {
119    /// Compare a single pair, always CPU, returns an individual vector.
120    fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector;
121
122    /// Pool-native batch comparison, the primary hot path.
123    ///
124    /// Reads `RecordPool` columns directly: zero HashMap lookups, no
125    /// `Record::clone()`.  Implementors SHOULD override this method.
126    /// The default falls back to `compare` per pair, which is correct but
127    /// slower than a native pool implementation.
128    fn compare_batch_from_pool(
129        &self,
130        pool: &RecordPool,
131        indices: &[(usize, usize)],
132        schema: &Schema,
133    ) -> ComparisonBatch {
134        let n_pairs = indices.len();
135        let n_fields = schema.fields.len();
136        if n_pairs == 0 {
137            return ComparisonBatch::new(0, n_fields, vec![]);
138        }
139        let pair_ids: Vec<(u64, u64)> = indices
140            .iter()
141            .map(|&(i, j)| (pool.ids[i], pool.ids[j]))
142            .collect();
143        let mut batch = ComparisonBatch::new(n_pairs, n_fields, pair_ids);
144        for (p, &(i, j)) in indices.iter().enumerate() {
145            use crate::record::FieldValue;
146            let mut a = Record::new(pool.ids[i]);
147            let mut b = Record::new(pool.ids[j]);
148            for (f, field) in schema.fields.iter().enumerate() {
149                let va = pool.get(f, i);
150                let vb = pool.get(f, j);
151                if !va.is_empty() {
152                    a = a.insert(&field.name, FieldValue::Text(va.to_string()));
153                }
154                if !vb.is_empty() {
155                    b = b.insert(&field.name, FieldValue::Text(vb.to_string()));
156                }
157            }
158            let v = self.compare(&a, &b, schema);
159            for (f, &level) in v.levels.iter().enumerate() {
160                batch.set_level(f, p, level);
161            }
162        }
163        batch
164    }
165}
166
167pub trait Scorer: Send + Sync {
168    /// Score a single pair, always CPU, cheap dot product.
169    fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair;
170
171    /// Score a batch using the field-major `ComparisonBatch`.
172    fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
173        (0..batch.n_pairs)
174            .map(|p| self.score(&batch.pair_as_vector(p), params))
175            .collect()
176    }
177
178    fn estimate_params(
179        &self,
180        batch: &ComparisonBatch,
181        init: Option<ModelParams>,
182        max_iter: usize,
183    ) -> Result<ModelParams>;
184}
185
/// Groups scored pairs into entity clusters.
pub trait Clusterer: Send + Sync {
    /// Build entities from `pairs`, using `params` as the model parameters.
    fn cluster(&self, pairs: &[ScoredPair], params: &ModelParams) -> Vec<Entity>;
}
190
/// Persistent store for resolved entities.
///
/// Every method returns the crate [`Result`], so implementations can report
/// failures through it.
pub trait EntityStore: Send + Sync {
    /// Store `entity` and return its [`EntityId`].
    ///
    /// NOTE(review): the name suggests insert-or-update semantics; confirm
    /// against implementors.
    fn upsert_entity(&self, entity: &Entity) -> Result<EntityId>;

    /// Fetch the entity with the given `id`.
    fn get_entity(&self, id: EntityId) -> Result<Entity>;

    /// Find the entity associated with `record_id`.  The `Option` leaves room
    /// for a successful lookup with no entity (presumably an unassigned
    /// record; verify against implementors).
    fn record_to_entity(&self, record_id: RecordId) -> Result<Option<EntityId>>;

    /// Return every stored entity.
    fn all_entities(&self) -> Result<Vec<Entity>>;
}
198
/// Convert an external row type into a [`Record`].
///
/// Implement this in an adapter crate (e.g. `zer-adapters`) for foreign
/// row types such as a Polars `LazyFrame` row or an Arrow `RecordBatch` row.
/// The `id` parameter lets callers assign a stable [`RecordId`].
pub trait IntoRecord {
    /// Consume `self` and produce a [`Record`]; `id` is the caller-chosen
    /// record ID.
    fn into_record(self, id: RecordId) -> Record;
}
207
/// Identity conversion: a [`Record`] is returned unchanged.
impl IntoRecord for Record {
    /// Returns `self` as-is.  The `_id` argument is ignored, so the record
    /// keeps whatever ID it already carries.
    fn into_record(self, _id: RecordId) -> Record {
        self
    }
}
213
/// Neural re-ranker that adjudicates borderline record pairs.
pub trait Judge: Send + Sync {
    /// Review `pairs` and return verdicts for them.
    ///
    /// NOTE(review): presumably one [`JudgeVerdict`] per input pair, in input
    /// order; the signature does not enforce this, so confirm with
    /// implementors.
    fn adjudicate(&self, pairs: &[ScoredPair]) -> Result<Vec<JudgeVerdict>>;
}
218
219impl<J: Judge + ?Sized> Judge for Box<J> {
220    fn adjudicate(&self, pairs: &[ScoredPair]) -> Result<Vec<JudgeVerdict>> {
221        (**self).adjudicate(pairs)
222    }
223}
224
/// Outcome returned by a [`Judge`] for an adjudicated pair.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names; confirm against the code that consumes verdicts.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum JudgeVerdict {
    /// Presumably: confidence in the pair should be raised.
    IncreaseConfidence,
    /// Presumably: confidence in the pair should be lowered.
    DecreaseConfidence,
    /// Presumably: the existing score stands.
    NoChange,
}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: each trait must be usable as a trait object.
    #[test]
    fn traits_are_object_safe() {
        fn assert_object_safe<T: ?Sized>() {}

        assert_object_safe::<dyn BlockIndex>();
        assert_object_safe::<dyn Blocker>();
        assert_object_safe::<dyn Comparator>();
        assert_object_safe::<dyn Scorer>();
        assert_object_safe::<dyn Clusterer>();
        assert_object_safe::<dyn EntityStore>();
        assert_object_safe::<dyn Judge>();
        assert_object_safe::<dyn RecordStore>();
    }

    /// A freshly inserted record is counted and can be fetched by its ID.
    #[test]
    fn vec_record_store_insert_and_get() {
        use crate::record::{FieldValue, Record};

        let store = VecRecordStore::new();
        assert!(store.is_empty());

        store.insert(Record::new(42).insert("name", FieldValue::Text("Alice".into())));
        assert_eq!(store.len(), 1);

        let fetched = store.get(42).expect("record 42 must exist");
        assert_eq!(fetched.id, 42);
    }

    /// Looking up an unknown ID yields `None`.
    #[test]
    fn vec_record_store_get_missing_returns_none() {
        let store = VecRecordStore::new();
        let missing = store.get(999);
        assert!(missing.is_none());
    }

    /// `get_many` reports hits and misses positionally.
    #[test]
    fn vec_record_store_get_many() {
        use crate::record::Record;

        let store = VecRecordStore::new();
        for id in [1, 2, 3] {
            store.insert(Record::new(id));
        }

        let found: Vec<bool> = store
            .get_many(&[1, 3, 99])
            .iter()
            .map(Option::is_some)
            .collect();
        assert_eq!(found, [true, true, false]);
    }

    /// The store can be shared across threads behind `Arc<dyn RecordStore>`.
    #[test]
    fn vec_record_store_is_sendable() {
        use std::sync::Arc;

        let shared: Arc<dyn RecordStore> = Arc::new(VecRecordStore::new());
        let worker = {
            let store = Arc::clone(&shared);
            std::thread::spawn(move || {
                store.insert(crate::record::Record::new(7));
            })
        };
        worker.join().unwrap();
        assert_eq!(shared.len(), 1);
    }
}