Skip to main content

zer_cluster/store/
zes.rs

1use std::{path::Path, sync::Mutex};
2
3use rusqlite::{Connection, OptionalExtension};
4use zer_core::{
5    entity::{Entity, EntityId, EntityMember, ResolutionMethod},
6    error::ZerError,
7    record::RecordId,
8    traits::{EntityStore, Result},
9};
10
11use crate::provenance::{append_event, unix_now, ResolutionEvent};
12
13/// SQLite-backed entity store persisted as a single `.zes` file.
14///
15/// Uses `rusqlite/bundled` so no system SQLite installation is required.
16/// All mutations hold the connection `Mutex` for the duration of the
17/// transaction, suitable for single-threaded or lightly-concurrent use.
18pub struct ZalEntityStore {
19    conn: Mutex<Connection>,
20}
21
22impl ZalEntityStore {
23    /// Open (or create) a `.zes` store at the given path.
24    pub fn open(path: &Path) -> Result<Self> {
25        let conn = Connection::open(path)
26            .map_err(|e| ZerError::Store(e.to_string()))?;
27        init_schema(&conn)?;
28        Ok(Self { conn: Mutex::new(conn) })
29    }
30
31    /// Open an in-memory store. No file is created; data is lost on drop.
32    pub fn open_in_memory() -> Result<Self> {
33        let conn = Connection::open_in_memory()
34            .map_err(|e| ZerError::Store(e.to_string()))?;
35        init_schema(&conn)?;
36        Ok(Self { conn: Mutex::new(conn) })
37    }
38}
39
40fn init_schema(conn: &Connection) -> Result<()> {
41    conn.execute_batch(
42        "CREATE TABLE IF NOT EXISTS entities (
43            entity_id  INTEGER PRIMARY KEY AUTOINCREMENT,
44            created_at INTEGER NOT NULL,
45            updated_at INTEGER NOT NULL
46        );
47
48        CREATE TABLE IF NOT EXISTS entity_members (
49            id        INTEGER PRIMARY KEY AUTOINCREMENT,
50            entity_id INTEGER NOT NULL REFERENCES entities(entity_id),
51            record_id INTEGER NOT NULL,
52            score     REAL    NOT NULL,
53            method    TEXT    NOT NULL,
54            source    TEXT,
55            added_at  INTEGER NOT NULL
56        );
57        CREATE UNIQUE INDEX IF NOT EXISTS idx_record_entity ON entity_members(record_id);
58
59        CREATE TABLE IF NOT EXISTS resolution_events (
60            id            INTEGER PRIMARY KEY AUTOINCREMENT,
61            event_type    TEXT    NOT NULL,
62            entity_id     INTEGER NOT NULL,
63            record_ids    TEXT    NOT NULL,
64            score         REAL,
65            judge_verdict TEXT,
66            occurred_at   INTEGER NOT NULL
67        );",
68    )
69    .map_err(|e| ZerError::Store(e.to_string()))
70}
71
72impl EntityStore for ZalEntityStore {
73    fn upsert_entity(&self, entity: &Entity) -> Result<EntityId> {
74        let conn = self.conn.lock().unwrap();
75        let now  = unix_now();
76
77        // Find if any member already belongs to an existing entity.
78        let mut existing_id: Option<EntityId> = None;
79        for member in &entity.members {
80            let id: Option<i64> = conn
81                .query_row(
82                    "SELECT entity_id FROM entity_members WHERE record_id = ?1",
83                    [member.record_id as i64],
84                    |row| row.get(0),
85                )
86                .optional()
87                .map_err(|e| ZerError::Store(e.to_string()))?;
88
89            if let Some(eid) = id {
90                existing_id = Some(eid as EntityId);
91                break;
92            }
93        }
94
95        if let Some(eid) = existing_id {
96            // Entity already exists, merge new members in.
97            conn.execute(
98                "UPDATE entities SET updated_at = ?1 WHERE entity_id = ?2",
99                rusqlite::params![now, eid as i64],
100            )
101            .map_err(|e| ZerError::Store(e.to_string()))?;
102
103            let new_record_ids: Vec<RecordId> = entity.members.iter().map(|m| m.record_id).collect();
104            for member in &entity.members {
105                conn.execute(
106                    "INSERT OR IGNORE INTO entity_members
107                         (entity_id, record_id, score, method, source, added_at)
108                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
109                    rusqlite::params![
110                        eid as i64,
111                        member.record_id as i64,
112                        member.score,
113                        method_to_str(member.method),
114                        member.source.as_deref(),
115                        now,
116                    ],
117                )
118                .map_err(|e| ZerError::Store(e.to_string()))?;
119            }
120
121            append_event(
122                &conn,
123                &ResolutionEvent::RecordsAdded {
124                    entity_id: eid,
125                    record_ids: new_record_ids,
126                    method: entity.members.first().map(|m| m.method).unwrap_or(ResolutionMethod::AutoMatch),
127                },
128            )?;
129
130            Ok(eid)
131        } else {
132            // Brand-new entity.
133            conn.execute(
134                "INSERT INTO entities (created_at, updated_at) VALUES (?1, ?2)",
135                rusqlite::params![now, now],
136            )
137            .map_err(|e| ZerError::Store(e.to_string()))?;
138
139            let eid = conn.last_insert_rowid() as EntityId;
140
141            let record_ids: Vec<RecordId> = entity.members.iter().map(|m| m.record_id).collect();
142            for member in &entity.members {
143                conn.execute(
144                    "INSERT INTO entity_members
145                         (entity_id, record_id, score, method, source, added_at)
146                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
147                    rusqlite::params![
148                        eid as i64,
149                        member.record_id as i64,
150                        member.score,
151                        method_to_str(member.method),
152                        member.source.as_deref(),
153                        now,
154                    ],
155                )
156                .map_err(|e| ZerError::Store(e.to_string()))?;
157            }
158
159            append_event(
160                &conn,
161                &ResolutionEvent::EntityCreated { entity_id: eid, record_ids },
162            )?;
163
164            Ok(eid)
165        }
166    }
167
168    fn get_entity(&self, id: EntityId) -> Result<Entity> {
169        let conn = self.conn.lock().unwrap();
170        let mut stmt = conn
171            .prepare(
172                "SELECT record_id, score, method, source
173                 FROM entity_members WHERE entity_id = ?1",
174            )
175            .map_err(|e| ZerError::Store(e.to_string()))?;
176
177        let members: Vec<EntityMember> = stmt
178            .query_map([id as i64], |row| {
179                Ok((
180                    row.get::<_, i64>(0)? as RecordId,
181                    row.get::<_, f32>(1)?,
182                    row.get::<_, String>(2)?,
183                    row.get::<_, Option<String>>(3)?,
184                ))
185            })
186            .map_err(|e| ZerError::Store(e.to_string()))?
187            .map(|r| {
188                r.map_err(|e| ZerError::Store(e.to_string())).and_then(|(rid, score, method, source)| {
189                    Ok(EntityMember {
190                        record_id: rid,
191                        score,
192                        method: method_from_str(&method),
193                        source,
194                    })
195                })
196            })
197            .collect::<Result<_>>()?;
198
199        Ok(Entity { id, members })
200    }
201
202    fn record_to_entity(&self, record_id: RecordId) -> Result<Option<EntityId>> {
203        let conn = self.conn.lock().unwrap();
204        let id: Option<i64> = conn
205            .query_row(
206                "SELECT entity_id FROM entity_members WHERE record_id = ?1",
207                [record_id as i64],
208                |row| row.get(0),
209            )
210            .optional()
211            .map_err(|e| ZerError::Store(e.to_string()))?;
212        Ok(id.map(|i| i as EntityId))
213    }
214
215    fn all_entities(&self) -> Result<Vec<Entity>> {
216        let conn = self.conn.lock().unwrap();
217        let mut stmt = conn
218            .prepare(
219                "SELECT em.entity_id, em.record_id, em.score, em.method, em.source
220                 FROM entity_members em
221                 ORDER BY em.entity_id",
222            )
223            .map_err(|e| ZerError::Store(e.to_string()))?;
224
225        let rows = stmt
226            .query_map([], |row| {
227                Ok((
228                    row.get::<_, i64>(0)? as EntityId,
229                    row.get::<_, i64>(1)? as RecordId,
230                    row.get::<_, f32>(2)?,
231                    row.get::<_, String>(3)?,
232                    row.get::<_, Option<String>>(4)?,
233                ))
234            })
235            .map_err(|e| ZerError::Store(e.to_string()))?;
236
237        let mut entities: Vec<Entity> = Vec::new();
238
239        for row in rows {
240            let (eid, rid, score, method, source) =
241                row.map_err(|e| ZerError::Store(e.to_string()))?;
242            let member = EntityMember {
243                record_id: rid,
244                score,
245                method: method_from_str(&method),
246                source,
247            };
248            match entities.last_mut() {
249                Some(e) if e.id == eid => e.members.push(member),
250                _ => entities.push(Entity { id: eid, members: vec![member] }),
251            }
252        }
253
254        Ok(entities)
255    }
256}
257
258// ── Method round-trip helpers ─────────────────────────────────────────────────
259
260fn method_to_str(method: ResolutionMethod) -> &'static str {
261    match method {
262        ResolutionMethod::AutoMatch    => "AutoMatch",
263        ResolutionMethod::JudgePromoted => "JudgePromoted",
264        ResolutionMethod::JudgeDemoted  => "JudgeDemoted",
265        ResolutionMethod::Manual        => "Manual",
266    }
267}
268
269fn method_from_str(s: &str) -> ResolutionMethod {
270    match s {
271        "JudgePromoted" => ResolutionMethod::JudgePromoted,
272        "JudgeDemoted"  => ResolutionMethod::JudgeDemoted,
273        "Manual"        => ResolutionMethod::Manual,
274        _               => ResolutionMethod::AutoMatch,
275    }
276}
277
278// ── Unit tests ────────────────────────────────────────────────────────────────
279
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{entity::ResolutionMethod, traits::EntityStore};

    /// Builds an entity with one auto-matched member (score 0.95, no source)
    /// per entry of `record_ids`.
    fn make_entity(id: EntityId, record_ids: &[RecordId]) -> Entity {
        let mut members = Vec::with_capacity(record_ids.len());
        for &record_id in record_ids {
            members.push(EntityMember {
                record_id,
                score: 0.95,
                method: ResolutionMethod::AutoMatch,
                source: None,
            });
        }
        Entity { id, members }
    }

    /// Fresh in-memory store; panics if it cannot be created.
    fn fresh_store() -> ZalEntityStore {
        ZalEntityStore::open_in_memory().unwrap()
    }

    #[test]
    fn open_in_memory_creates_schema() {
        let _store = fresh_store();
    }

    #[test]
    fn upsert_new_entity_returns_id() {
        let store = fresh_store();
        let assigned = store.upsert_entity(&make_entity(0, &[1, 2, 3])).unwrap();
        assert!(assigned >= 1, "autoincrement id must be ≥ 1");
    }

    #[test]
    fn upsert_same_entity_merges_members() {
        let store = fresh_store();

        let first = store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();
        // Record 2 is shared with the first upsert, so this one must merge.
        let second = store.upsert_entity(&make_entity(0, &[2, 3])).unwrap();

        assert_eq!(first, second, "same entity_id must be returned on merge");

        let merged = store.get_entity(first).unwrap();
        let expected: [RecordId; 3] = [1, 2, 3];
        for rid in expected.iter() {
            assert!(merged.members.iter().any(|m| m.record_id == *rid));
        }
    }

    #[test]
    fn record_to_entity_returns_correct_id() {
        let store = fresh_store();
        let owner = store.upsert_entity(&make_entity(0, &[10, 20])).unwrap();

        let members: [RecordId; 2] = [10, 20];
        for rid in members.iter() {
            assert_eq!(store.record_to_entity(*rid).unwrap(), Some(owner));
        }
    }

    #[test]
    fn record_to_entity_missing_returns_none() {
        let store = fresh_store();
        let lookup = store.record_to_entity(999).unwrap();
        assert!(lookup.is_none());
    }

    #[test]
    fn all_entities_returns_all() {
        let store = fresh_store();
        let groups: [[RecordId; 2]; 3] = [[1, 2], [3, 4], [5, 6]];
        for group in groups.iter() {
            store.upsert_entity(&make_entity(0, group)).unwrap();
        }

        let listed = store.all_entities().unwrap();
        assert_eq!(listed.len(), 3);
    }

    #[test]
    fn provenance_event_written_on_create() {
        let store = fresh_store();
        store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();

        // Inspect the provenance table directly through the raw connection.
        let guard = store.conn.lock().unwrap();
        let created_events: i64 = guard
            .query_row(
                "SELECT COUNT(*) FROM resolution_events WHERE event_type = 'EntityCreated'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(created_events, 1);
    }
}
375}