1use std::{path::Path, sync::Mutex};
2
3use rusqlite::{Connection, OptionalExtension};
4use zer_core::{
5 entity::{Entity, EntityId, EntityMember, ResolutionMethod},
6 error::ZerError,
7 record::RecordId,
8 traits::{EntityStore, Result},
9};
10
11use crate::provenance::{append_event, unix_now, ResolutionEvent};
12
13pub struct ZalEntityStore {
19 conn: Mutex<Connection>,
20}
21
22impl ZalEntityStore {
23 pub fn open(path: &Path) -> Result<Self> {
25 let conn = Connection::open(path)
26 .map_err(|e| ZerError::Store(e.to_string()))?;
27 init_schema(&conn)?;
28 Ok(Self { conn: Mutex::new(conn) })
29 }
30
31 pub fn open_in_memory() -> Result<Self> {
33 let conn = Connection::open_in_memory()
34 .map_err(|e| ZerError::Store(e.to_string()))?;
35 init_schema(&conn)?;
36 Ok(Self { conn: Mutex::new(conn) })
37 }
38}
39
40fn init_schema(conn: &Connection) -> Result<()> {
41 conn.execute_batch(
42 "CREATE TABLE IF NOT EXISTS entities (
43 entity_id INTEGER PRIMARY KEY AUTOINCREMENT,
44 created_at INTEGER NOT NULL,
45 updated_at INTEGER NOT NULL
46 );
47
48 CREATE TABLE IF NOT EXISTS entity_members (
49 id INTEGER PRIMARY KEY AUTOINCREMENT,
50 entity_id INTEGER NOT NULL REFERENCES entities(entity_id),
51 record_id INTEGER NOT NULL,
52 score REAL NOT NULL,
53 method TEXT NOT NULL,
54 source TEXT,
55 added_at INTEGER NOT NULL
56 );
57 CREATE UNIQUE INDEX IF NOT EXISTS idx_record_entity ON entity_members(record_id);
58
59 CREATE TABLE IF NOT EXISTS resolution_events (
60 id INTEGER PRIMARY KEY AUTOINCREMENT,
61 event_type TEXT NOT NULL,
62 entity_id INTEGER NOT NULL,
63 record_ids TEXT NOT NULL,
64 score REAL,
65 judge_verdict TEXT,
66 occurred_at INTEGER NOT NULL
67 );",
68 )
69 .map_err(|e| ZerError::Store(e.to_string()))
70}
71
72impl EntityStore for ZalEntityStore {
73 fn upsert_entity(&self, entity: &Entity) -> Result<EntityId> {
74 let conn = self.conn.lock().unwrap();
75 let now = unix_now();
76
77 let mut existing_id: Option<EntityId> = None;
79 for member in &entity.members {
80 let id: Option<i64> = conn
81 .query_row(
82 "SELECT entity_id FROM entity_members WHERE record_id = ?1",
83 [member.record_id as i64],
84 |row| row.get(0),
85 )
86 .optional()
87 .map_err(|e| ZerError::Store(e.to_string()))?;
88
89 if let Some(eid) = id {
90 existing_id = Some(eid as EntityId);
91 break;
92 }
93 }
94
95 if let Some(eid) = existing_id {
96 conn.execute(
98 "UPDATE entities SET updated_at = ?1 WHERE entity_id = ?2",
99 rusqlite::params![now, eid as i64],
100 )
101 .map_err(|e| ZerError::Store(e.to_string()))?;
102
103 let new_record_ids: Vec<RecordId> = entity.members.iter().map(|m| m.record_id).collect();
104 for member in &entity.members {
105 conn.execute(
106 "INSERT OR IGNORE INTO entity_members
107 (entity_id, record_id, score, method, source, added_at)
108 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
109 rusqlite::params![
110 eid as i64,
111 member.record_id as i64,
112 member.score,
113 method_to_str(member.method),
114 member.source.as_deref(),
115 now,
116 ],
117 )
118 .map_err(|e| ZerError::Store(e.to_string()))?;
119 }
120
121 append_event(
122 &conn,
123 &ResolutionEvent::RecordsAdded {
124 entity_id: eid,
125 record_ids: new_record_ids,
126 method: entity.members.first().map(|m| m.method).unwrap_or(ResolutionMethod::AutoMatch),
127 },
128 )?;
129
130 Ok(eid)
131 } else {
132 conn.execute(
134 "INSERT INTO entities (created_at, updated_at) VALUES (?1, ?2)",
135 rusqlite::params![now, now],
136 )
137 .map_err(|e| ZerError::Store(e.to_string()))?;
138
139 let eid = conn.last_insert_rowid() as EntityId;
140
141 let record_ids: Vec<RecordId> = entity.members.iter().map(|m| m.record_id).collect();
142 for member in &entity.members {
143 conn.execute(
144 "INSERT INTO entity_members
145 (entity_id, record_id, score, method, source, added_at)
146 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
147 rusqlite::params![
148 eid as i64,
149 member.record_id as i64,
150 member.score,
151 method_to_str(member.method),
152 member.source.as_deref(),
153 now,
154 ],
155 )
156 .map_err(|e| ZerError::Store(e.to_string()))?;
157 }
158
159 append_event(
160 &conn,
161 &ResolutionEvent::EntityCreated { entity_id: eid, record_ids },
162 )?;
163
164 Ok(eid)
165 }
166 }
167
168 fn get_entity(&self, id: EntityId) -> Result<Entity> {
169 let conn = self.conn.lock().unwrap();
170 let mut stmt = conn
171 .prepare(
172 "SELECT record_id, score, method, source
173 FROM entity_members WHERE entity_id = ?1",
174 )
175 .map_err(|e| ZerError::Store(e.to_string()))?;
176
177 let members: Vec<EntityMember> = stmt
178 .query_map([id as i64], |row| {
179 Ok((
180 row.get::<_, i64>(0)? as RecordId,
181 row.get::<_, f32>(1)?,
182 row.get::<_, String>(2)?,
183 row.get::<_, Option<String>>(3)?,
184 ))
185 })
186 .map_err(|e| ZerError::Store(e.to_string()))?
187 .map(|r| {
188 r.map_err(|e| ZerError::Store(e.to_string())).and_then(|(rid, score, method, source)| {
189 Ok(EntityMember {
190 record_id: rid,
191 score,
192 method: method_from_str(&method),
193 source,
194 })
195 })
196 })
197 .collect::<Result<_>>()?;
198
199 Ok(Entity { id, members })
200 }
201
202 fn record_to_entity(&self, record_id: RecordId) -> Result<Option<EntityId>> {
203 let conn = self.conn.lock().unwrap();
204 let id: Option<i64> = conn
205 .query_row(
206 "SELECT entity_id FROM entity_members WHERE record_id = ?1",
207 [record_id as i64],
208 |row| row.get(0),
209 )
210 .optional()
211 .map_err(|e| ZerError::Store(e.to_string()))?;
212 Ok(id.map(|i| i as EntityId))
213 }
214
215 fn all_entities(&self) -> Result<Vec<Entity>> {
216 let conn = self.conn.lock().unwrap();
217 let mut stmt = conn
218 .prepare(
219 "SELECT em.entity_id, em.record_id, em.score, em.method, em.source
220 FROM entity_members em
221 ORDER BY em.entity_id",
222 )
223 .map_err(|e| ZerError::Store(e.to_string()))?;
224
225 let rows = stmt
226 .query_map([], |row| {
227 Ok((
228 row.get::<_, i64>(0)? as EntityId,
229 row.get::<_, i64>(1)? as RecordId,
230 row.get::<_, f32>(2)?,
231 row.get::<_, String>(3)?,
232 row.get::<_, Option<String>>(4)?,
233 ))
234 })
235 .map_err(|e| ZerError::Store(e.to_string()))?;
236
237 let mut entities: Vec<Entity> = Vec::new();
238
239 for row in rows {
240 let (eid, rid, score, method, source) =
241 row.map_err(|e| ZerError::Store(e.to_string()))?;
242 let member = EntityMember {
243 record_id: rid,
244 score,
245 method: method_from_str(&method),
246 source,
247 };
248 match entities.last_mut() {
249 Some(e) if e.id == eid => e.members.push(member),
250 _ => entities.push(Entity { id: eid, members: vec![member] }),
251 }
252 }
253
254 Ok(entities)
255 }
256}
257
258fn method_to_str(method: ResolutionMethod) -> &'static str {
261 match method {
262 ResolutionMethod::AutoMatch => "AutoMatch",
263 ResolutionMethod::JudgePromoted => "JudgePromoted",
264 ResolutionMethod::JudgeDemoted => "JudgeDemoted",
265 ResolutionMethod::Manual => "Manual",
266 }
267}
268
269fn method_from_str(s: &str) -> ResolutionMethod {
270 match s {
271 "JudgePromoted" => ResolutionMethod::JudgePromoted,
272 "JudgeDemoted" => ResolutionMethod::JudgeDemoted,
273 "Manual" => ResolutionMethod::Manual,
274 _ => ResolutionMethod::AutoMatch,
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{entity::ResolutionMethod, traits::EntityStore};

    /// Builds an entity whose members are all auto-matched at score 0.95 with no source.
    fn make_entity(id: EntityId, record_ids: &[RecordId]) -> Entity {
        let mut members = Vec::with_capacity(record_ids.len());
        for &record_id in record_ids {
            members.push(EntityMember {
                record_id,
                score: 0.95,
                method: ResolutionMethod::AutoMatch,
                source: None,
            });
        }
        Entity { id, members }
    }

    /// Opens a throwaway in-memory store, panicking if it cannot be created.
    fn fresh_store() -> ZalEntityStore {
        ZalEntityStore::open_in_memory().unwrap()
    }

    #[test]
    fn open_in_memory_creates_schema() {
        let _store = fresh_store();
    }

    #[test]
    fn upsert_new_entity_returns_id() {
        let store = fresh_store();
        let eid = store.upsert_entity(&make_entity(0, &[1, 2, 3])).unwrap();
        assert!(eid >= 1, "autoincrement id must be ≥ 1");
    }

    #[test]
    fn upsert_same_entity_merges_members() {
        let store = fresh_store();

        // Record 2 overlaps, so the second upsert must land on the first entity.
        let first = store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();
        let second = store.upsert_entity(&make_entity(0, &[2, 3])).unwrap();
        assert_eq!(first, second, "same entity_id must be returned on merge");

        let rids: Vec<RecordId> = store
            .get_entity(first)
            .unwrap()
            .members
            .into_iter()
            .map(|m| m.record_id)
            .collect();
        for expected in [1, 2, 3] {
            assert!(rids.contains(&expected));
        }
    }

    #[test]
    fn record_to_entity_returns_correct_id() {
        let store = fresh_store();
        let eid = store.upsert_entity(&make_entity(0, &[10, 20])).unwrap();
        for rid in [10, 20] {
            assert_eq!(store.record_to_entity(rid).unwrap(), Some(eid));
        }
    }

    #[test]
    fn record_to_entity_missing_returns_none() {
        let lookup = fresh_store().record_to_entity(999).unwrap();
        assert!(lookup.is_none());
    }

    #[test]
    fn all_entities_returns_all() {
        let store = fresh_store();
        for pair in [[1, 2], [3, 4], [5, 6]] {
            store.upsert_entity(&make_entity(0, &pair)).unwrap();
        }
        assert_eq!(store.all_entities().unwrap().len(), 3);
    }

    #[test]
    fn provenance_event_written_on_create() {
        let store = fresh_store();
        store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();

        let created: i64 = store
            .conn
            .lock()
            .unwrap()
            .query_row(
                "SELECT COUNT(*) FROM resolution_events WHERE event_type = 'EntityCreated'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(created, 1);
    }
}
375}