Skip to main content

zer_cluster/store/
zes.rs

1use std::{path::Path, sync::Mutex};
2
3use rusqlite::{Connection, OptionalExtension};
4use zer_core::{
5    entity::{Entity, EntityId, EntityMember, ResolutionMethod},
6    error::ZerError,
7    record::RecordId,
8    traits::{EntityStore, Result},
9};
10
11use crate::provenance::{append_event, unix_now, ResolutionEvent};
12
/// SQLite-backed entity store persisted as a single `.zes` file.
///
/// Uses `rusqlite/bundled` so no system SQLite installation is required.
/// The connection lives behind a `Mutex`; every `EntityStore` method locks it
/// for the whole call, so operations on one store are serialized. That makes
/// the type suitable for single-threaded or lightly-concurrent use.
pub struct ZalEntityStore {
    // Sole handle to the database; locked by each store operation.
    conn: Mutex<Connection>,
}
21
22impl ZalEntityStore {
23    /// Open (or create) a `.zes` store at the given path.
24    pub fn open(path: &Path) -> Result<Self> {
25        let conn = Connection::open(path).map_err(|e| ZerError::Store(e.to_string()))?;
26        init_schema(&conn)?;
27        Ok(Self {
28            conn: Mutex::new(conn),
29        })
30    }
31
32    /// Open an in-memory store. No file is created; data is lost on drop.
33    pub fn open_in_memory() -> Result<Self> {
34        let conn = Connection::open_in_memory().map_err(|e| ZerError::Store(e.to_string()))?;
35        init_schema(&conn)?;
36        Ok(Self {
37            conn: Mutex::new(conn),
38        })
39    }
40}
41
42fn init_schema(conn: &Connection) -> Result<()> {
43    conn.execute_batch(
44        "CREATE TABLE IF NOT EXISTS entities (
45            entity_id  INTEGER PRIMARY KEY AUTOINCREMENT,
46            created_at INTEGER NOT NULL,
47            updated_at INTEGER NOT NULL
48        );
49
50        CREATE TABLE IF NOT EXISTS entity_members (
51            id         INTEGER PRIMARY KEY AUTOINCREMENT,
52            entity_id  INTEGER NOT NULL REFERENCES entities(entity_id),
53            record_id  INTEGER NOT NULL,
54            record_key TEXT    NOT NULL DEFAULT '',
55            score      REAL    NOT NULL,
56            method     TEXT    NOT NULL,
57            source     TEXT,
58            added_at   INTEGER NOT NULL
59        );
60        CREATE UNIQUE INDEX IF NOT EXISTS idx_record_entity ON entity_members(record_id);
61
62        CREATE TABLE IF NOT EXISTS resolution_events (
63            id            INTEGER PRIMARY KEY AUTOINCREMENT,
64            event_type    TEXT    NOT NULL,
65            entity_id     INTEGER NOT NULL,
66            record_ids    TEXT    NOT NULL,
67            score         REAL,
68            judge_verdict TEXT,
69            occurred_at   INTEGER NOT NULL
70        );",
71    )
72    .map_err(|e| ZerError::Store(e.to_string()))
73}
74
75impl EntityStore for ZalEntityStore {
76    fn upsert_entity(&self, entity: &Entity) -> Result<EntityId> {
77        let conn = self.conn.lock().unwrap();
78        let now = unix_now();
79
80        // Find if any member already belongs to an existing entity.
81        let mut existing_id: Option<EntityId> = None;
82        for member in &entity.members {
83            let id: Option<i64> = conn
84                .query_row(
85                    "SELECT entity_id FROM entity_members WHERE record_id = ?1",
86                    [member.record_id as i64],
87                    |row| row.get(0),
88                )
89                .optional()
90                .map_err(|e| ZerError::Store(e.to_string()))?;
91
92            if let Some(eid) = id {
93                existing_id = Some(eid as EntityId);
94                break;
95            }
96        }
97
98        if let Some(eid) = existing_id {
99            // Entity already exists, merge new members in.
100            conn.execute(
101                "UPDATE entities SET updated_at = ?1 WHERE entity_id = ?2",
102                rusqlite::params![now, eid as i64],
103            )
104            .map_err(|e| ZerError::Store(e.to_string()))?;
105
106            let new_record_ids: Vec<RecordId> =
107                entity.members.iter().map(|m| m.record_id).collect();
108            for member in &entity.members {
109                conn.execute(
110                    "INSERT OR IGNORE INTO entity_members
111                         (entity_id, record_id, record_key, score, method, source, added_at)
112                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
113                    rusqlite::params![
114                        eid as i64,
115                        member.record_id as i64,
116                        &member.record_key,
117                        member.score,
118                        method_to_str(member.method),
119                        member.source.as_deref(),
120                        now,
121                    ],
122                )
123                .map_err(|e| ZerError::Store(e.to_string()))?;
124            }
125
126            append_event(
127                &conn,
128                &ResolutionEvent::RecordsAdded {
129                    entity_id: eid,
130                    record_ids: new_record_ids,
131                    method: entity
132                        .members
133                        .first()
134                        .map(|m| m.method)
135                        .unwrap_or(ResolutionMethod::AutoMatch),
136                },
137            )?;
138
139            Ok(eid)
140        } else {
141            // Brand-new entity.
142            conn.execute(
143                "INSERT INTO entities (created_at, updated_at) VALUES (?1, ?2)",
144                rusqlite::params![now, now],
145            )
146            .map_err(|e| ZerError::Store(e.to_string()))?;
147
148            let eid = conn.last_insert_rowid() as EntityId;
149
150            let record_ids: Vec<RecordId> = entity.members.iter().map(|m| m.record_id).collect();
151            for member in &entity.members {
152                conn.execute(
153                    "INSERT INTO entity_members
154                         (entity_id, record_id, record_key, score, method, source, added_at)
155                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
156                    rusqlite::params![
157                        eid as i64,
158                        member.record_id as i64,
159                        &member.record_key,
160                        member.score,
161                        method_to_str(member.method),
162                        member.source.as_deref(),
163                        now,
164                    ],
165                )
166                .map_err(|e| ZerError::Store(e.to_string()))?;
167            }
168
169            append_event(
170                &conn,
171                &ResolutionEvent::EntityCreated {
172                    entity_id: eid,
173                    record_ids,
174                },
175            )?;
176
177            Ok(eid)
178        }
179    }
180
181    fn get_entity(&self, id: EntityId) -> Result<Entity> {
182        let conn = self.conn.lock().unwrap();
183        let mut stmt = conn
184            .prepare(
185                "SELECT record_id, record_key, score, method, source
186                 FROM entity_members WHERE entity_id = ?1",
187            )
188            .map_err(|e| ZerError::Store(e.to_string()))?;
189
190        let members: Vec<EntityMember> = stmt
191            .query_map([id as i64], |row| {
192                Ok((
193                    row.get::<_, i64>(0)? as RecordId,
194                    row.get::<_, String>(1)?,
195                    row.get::<_, f32>(2)?,
196                    row.get::<_, String>(3)?,
197                    row.get::<_, Option<String>>(4)?,
198                ))
199            })
200            .map_err(|e| ZerError::Store(e.to_string()))?
201            .map(|r| {
202                r.map_err(|e| ZerError::Store(e.to_string())).map(
203                    |(rid, record_key, score, method, source)| EntityMember {
204                        record_id: rid,
205                        record_key,
206                        score,
207                        method: method_from_str(&method),
208                        source,
209                    },
210                )
211            })
212            .collect::<Result<_>>()?;
213
214        Ok(Entity { id, members })
215    }
216
217    fn record_to_entity(&self, record_id: RecordId) -> Result<Option<EntityId>> {
218        let conn = self.conn.lock().unwrap();
219        let id: Option<i64> = conn
220            .query_row(
221                "SELECT entity_id FROM entity_members WHERE record_id = ?1",
222                [record_id as i64],
223                |row| row.get(0),
224            )
225            .optional()
226            .map_err(|e| ZerError::Store(e.to_string()))?;
227        Ok(id.map(|i| i as EntityId))
228    }
229
230    fn all_entities(&self) -> Result<Vec<Entity>> {
231        let conn = self.conn.lock().unwrap();
232        let mut stmt = conn
233            .prepare(
234                "SELECT em.entity_id, em.record_id, em.record_key, em.score, em.method, em.source
235                 FROM entity_members em
236                 ORDER BY em.entity_id",
237            )
238            .map_err(|e| ZerError::Store(e.to_string()))?;
239
240        let rows = stmt
241            .query_map([], |row| {
242                Ok((
243                    row.get::<_, i64>(0)? as EntityId,
244                    row.get::<_, i64>(1)? as RecordId,
245                    row.get::<_, String>(2)?,
246                    row.get::<_, f32>(3)?,
247                    row.get::<_, String>(4)?,
248                    row.get::<_, Option<String>>(5)?,
249                ))
250            })
251            .map_err(|e| ZerError::Store(e.to_string()))?;
252
253        let mut entities: Vec<Entity> = Vec::new();
254
255        for row in rows {
256            let (eid, rid, record_key, score, method, source) =
257                row.map_err(|e| ZerError::Store(e.to_string()))?;
258            let member = EntityMember {
259                record_id: rid,
260                record_key,
261                score,
262                method: method_from_str(&method),
263                source,
264            };
265            match entities.last_mut() {
266                Some(e) if e.id == eid => e.members.push(member),
267                _ => entities.push(Entity {
268                    id: eid,
269                    members: vec![member],
270                }),
271            }
272        }
273
274        Ok(entities)
275    }
276}
277
278// ── Method round-trip helpers ─────────────────────────────────────────────────
279
280fn method_to_str(method: ResolutionMethod) -> &'static str {
281    match method {
282        ResolutionMethod::AutoMatch => "AutoMatch",
283        ResolutionMethod::JudgePromoted => "JudgePromoted",
284        ResolutionMethod::JudgeDemoted => "JudgeDemoted",
285        ResolutionMethod::Manual => "Manual",
286    }
287}
288
289fn method_from_str(s: &str) -> ResolutionMethod {
290    match s {
291        "JudgePromoted" => ResolutionMethod::JudgePromoted,
292        "JudgeDemoted" => ResolutionMethod::JudgeDemoted,
293        "Manual" => ResolutionMethod::Manual,
294        _ => ResolutionMethod::AutoMatch,
295    }
296}
297
298// ── Unit tests ────────────────────────────────────────────────────────────────
299
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{entity::ResolutionMethod, traits::EntityStore};

    /// Fresh in-memory store; panics if it cannot be created.
    fn fresh_store() -> ZalEntityStore {
        ZalEntityStore::open_in_memory().unwrap()
    }

    /// Entity whose members are auto-matched records with the given ids; each
    /// record key is the decimal form of its id.
    fn make_entity(id: EntityId, record_ids: &[RecordId]) -> Entity {
        let mut members = Vec::with_capacity(record_ids.len());
        for &record_id in record_ids {
            members.push(EntityMember {
                record_id,
                record_key: record_id.to_string(),
                score: 0.95,
                method: ResolutionMethod::AutoMatch,
                source: None,
            });
        }
        Entity { id, members }
    }

    #[test]
    fn open_in_memory_creates_schema() {
        let _store = fresh_store();
    }

    #[test]
    fn upsert_new_entity_returns_id() {
        let store = fresh_store();
        let eid = store.upsert_entity(&make_entity(0, &[1, 2, 3])).unwrap();
        assert!(eid >= 1, "autoincrement id must be ≥ 1");
    }

    #[test]
    fn upsert_same_entity_merges_members() {
        let store = fresh_store();

        let first_id = store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();
        // The second batch overlaps on record 2, so it must land in the same entity.
        let second_id = store.upsert_entity(&make_entity(0, &[2, 3])).unwrap();

        assert_eq!(first_id, second_id, "same entity_id must be returned on merge");

        let stored: Vec<RecordId> = store
            .get_entity(first_id)
            .unwrap()
            .members
            .iter()
            .map(|member| member.record_id)
            .collect();
        for expected in [1, 2, 3] {
            assert!(stored.contains(&expected));
        }
    }

    #[test]
    fn record_to_entity_returns_correct_id() {
        let store = fresh_store();
        let eid = store.upsert_entity(&make_entity(0, &[10, 20])).unwrap();

        for rid in [10, 20] {
            assert_eq!(store.record_to_entity(rid).unwrap(), Some(eid));
        }
    }

    #[test]
    fn record_to_entity_missing_returns_none() {
        let store = fresh_store();
        let lookup = store.record_to_entity(999).unwrap();
        assert!(lookup.is_none());
    }

    #[test]
    fn all_entities_returns_all() {
        let store = fresh_store();
        for ids in [[1, 2], [3, 4], [5, 6]] {
            store.upsert_entity(&make_entity(0, &ids)).unwrap();
        }

        assert_eq!(store.all_entities().unwrap().len(), 3);
    }

    #[test]
    fn record_key_survives_round_trip() {
        let store = fresh_store();
        let member = EntityMember {
            record_id: 42,
            record_key: "893479421".to_string(),
            score: 0.99,
            method: ResolutionMethod::AutoMatch,
            source: Some("brp".to_string()),
        };
        let eid = store
            .upsert_entity(&Entity {
                id: 0,
                members: vec![member],
            })
            .unwrap();

        let loaded = store.get_entity(eid).unwrap();
        let first = &loaded.members[0];
        assert_eq!(first.record_key, "893479421");
        assert_eq!(first.source.as_deref(), Some("brp"));
    }

    #[test]
    fn provenance_event_written_on_create() {
        let store = fresh_store();
        store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();

        // Inspect the provenance table directly through the store's connection.
        let conn = store.conn.lock().unwrap();
        let created_events: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM resolution_events WHERE event_type = 'EntityCreated'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(created_events, 1);
    }
}