1use std::{path::Path, sync::Mutex};
2
3use rusqlite::{Connection, OptionalExtension};
4use zer_core::{
5 entity::{Entity, EntityId, EntityMember, ResolutionMethod},
6 error::ZerError,
7 record::RecordId,
8 traits::{EntityStore, Result},
9};
10
11use crate::provenance::{append_event, unix_now, ResolutionEvent};
12
13pub struct ZalEntityStore {
19 conn: Mutex<Connection>,
20}
21
22impl ZalEntityStore {
23 pub fn open(path: &Path) -> Result<Self> {
25 let conn = Connection::open(path).map_err(|e| ZerError::Store(e.to_string()))?;
26 init_schema(&conn)?;
27 Ok(Self {
28 conn: Mutex::new(conn),
29 })
30 }
31
32 pub fn open_in_memory() -> Result<Self> {
34 let conn = Connection::open_in_memory().map_err(|e| ZerError::Store(e.to_string()))?;
35 init_schema(&conn)?;
36 Ok(Self {
37 conn: Mutex::new(conn),
38 })
39 }
40}
41
42fn init_schema(conn: &Connection) -> Result<()> {
43 conn.execute_batch(
44 "CREATE TABLE IF NOT EXISTS entities (
45 entity_id INTEGER PRIMARY KEY AUTOINCREMENT,
46 created_at INTEGER NOT NULL,
47 updated_at INTEGER NOT NULL
48 );
49
50 CREATE TABLE IF NOT EXISTS entity_members (
51 id INTEGER PRIMARY KEY AUTOINCREMENT,
52 entity_id INTEGER NOT NULL REFERENCES entities(entity_id),
53 record_id INTEGER NOT NULL,
54 record_key TEXT NOT NULL DEFAULT '',
55 score REAL NOT NULL,
56 method TEXT NOT NULL,
57 source TEXT,
58 added_at INTEGER NOT NULL
59 );
60 CREATE UNIQUE INDEX IF NOT EXISTS idx_record_entity ON entity_members(record_id);
61
62 CREATE TABLE IF NOT EXISTS resolution_events (
63 id INTEGER PRIMARY KEY AUTOINCREMENT,
64 event_type TEXT NOT NULL,
65 entity_id INTEGER NOT NULL,
66 record_ids TEXT NOT NULL,
67 score REAL,
68 judge_verdict TEXT,
69 occurred_at INTEGER NOT NULL
70 );",
71 )
72 .map_err(|e| ZerError::Store(e.to_string()))
73}
74
75impl EntityStore for ZalEntityStore {
76 fn upsert_entity(&self, entity: &Entity) -> Result<EntityId> {
77 let conn = self.conn.lock().unwrap();
78 let now = unix_now();
79
80 let mut existing_id: Option<EntityId> = None;
82 for member in &entity.members {
83 let id: Option<i64> = conn
84 .query_row(
85 "SELECT entity_id FROM entity_members WHERE record_id = ?1",
86 [member.record_id as i64],
87 |row| row.get(0),
88 )
89 .optional()
90 .map_err(|e| ZerError::Store(e.to_string()))?;
91
92 if let Some(eid) = id {
93 existing_id = Some(eid as EntityId);
94 break;
95 }
96 }
97
98 if let Some(eid) = existing_id {
99 conn.execute(
101 "UPDATE entities SET updated_at = ?1 WHERE entity_id = ?2",
102 rusqlite::params![now, eid as i64],
103 )
104 .map_err(|e| ZerError::Store(e.to_string()))?;
105
106 let new_record_ids: Vec<RecordId> =
107 entity.members.iter().map(|m| m.record_id).collect();
108 for member in &entity.members {
109 conn.execute(
110 "INSERT OR IGNORE INTO entity_members
111 (entity_id, record_id, record_key, score, method, source, added_at)
112 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
113 rusqlite::params![
114 eid as i64,
115 member.record_id as i64,
116 &member.record_key,
117 member.score,
118 method_to_str(member.method),
119 member.source.as_deref(),
120 now,
121 ],
122 )
123 .map_err(|e| ZerError::Store(e.to_string()))?;
124 }
125
126 append_event(
127 &conn,
128 &ResolutionEvent::RecordsAdded {
129 entity_id: eid,
130 record_ids: new_record_ids,
131 method: entity
132 .members
133 .first()
134 .map(|m| m.method)
135 .unwrap_or(ResolutionMethod::AutoMatch),
136 },
137 )?;
138
139 Ok(eid)
140 } else {
141 conn.execute(
143 "INSERT INTO entities (created_at, updated_at) VALUES (?1, ?2)",
144 rusqlite::params![now, now],
145 )
146 .map_err(|e| ZerError::Store(e.to_string()))?;
147
148 let eid = conn.last_insert_rowid() as EntityId;
149
150 let record_ids: Vec<RecordId> = entity.members.iter().map(|m| m.record_id).collect();
151 for member in &entity.members {
152 conn.execute(
153 "INSERT INTO entity_members
154 (entity_id, record_id, record_key, score, method, source, added_at)
155 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
156 rusqlite::params![
157 eid as i64,
158 member.record_id as i64,
159 &member.record_key,
160 member.score,
161 method_to_str(member.method),
162 member.source.as_deref(),
163 now,
164 ],
165 )
166 .map_err(|e| ZerError::Store(e.to_string()))?;
167 }
168
169 append_event(
170 &conn,
171 &ResolutionEvent::EntityCreated {
172 entity_id: eid,
173 record_ids,
174 },
175 )?;
176
177 Ok(eid)
178 }
179 }
180
181 fn get_entity(&self, id: EntityId) -> Result<Entity> {
182 let conn = self.conn.lock().unwrap();
183 let mut stmt = conn
184 .prepare(
185 "SELECT record_id, record_key, score, method, source
186 FROM entity_members WHERE entity_id = ?1",
187 )
188 .map_err(|e| ZerError::Store(e.to_string()))?;
189
190 let members: Vec<EntityMember> = stmt
191 .query_map([id as i64], |row| {
192 Ok((
193 row.get::<_, i64>(0)? as RecordId,
194 row.get::<_, String>(1)?,
195 row.get::<_, f32>(2)?,
196 row.get::<_, String>(3)?,
197 row.get::<_, Option<String>>(4)?,
198 ))
199 })
200 .map_err(|e| ZerError::Store(e.to_string()))?
201 .map(|r| {
202 r.map_err(|e| ZerError::Store(e.to_string())).map(
203 |(rid, record_key, score, method, source)| EntityMember {
204 record_id: rid,
205 record_key,
206 score,
207 method: method_from_str(&method),
208 source,
209 },
210 )
211 })
212 .collect::<Result<_>>()?;
213
214 Ok(Entity { id, members })
215 }
216
217 fn record_to_entity(&self, record_id: RecordId) -> Result<Option<EntityId>> {
218 let conn = self.conn.lock().unwrap();
219 let id: Option<i64> = conn
220 .query_row(
221 "SELECT entity_id FROM entity_members WHERE record_id = ?1",
222 [record_id as i64],
223 |row| row.get(0),
224 )
225 .optional()
226 .map_err(|e| ZerError::Store(e.to_string()))?;
227 Ok(id.map(|i| i as EntityId))
228 }
229
230 fn all_entities(&self) -> Result<Vec<Entity>> {
231 let conn = self.conn.lock().unwrap();
232 let mut stmt = conn
233 .prepare(
234 "SELECT em.entity_id, em.record_id, em.record_key, em.score, em.method, em.source
235 FROM entity_members em
236 ORDER BY em.entity_id",
237 )
238 .map_err(|e| ZerError::Store(e.to_string()))?;
239
240 let rows = stmt
241 .query_map([], |row| {
242 Ok((
243 row.get::<_, i64>(0)? as EntityId,
244 row.get::<_, i64>(1)? as RecordId,
245 row.get::<_, String>(2)?,
246 row.get::<_, f32>(3)?,
247 row.get::<_, String>(4)?,
248 row.get::<_, Option<String>>(5)?,
249 ))
250 })
251 .map_err(|e| ZerError::Store(e.to_string()))?;
252
253 let mut entities: Vec<Entity> = Vec::new();
254
255 for row in rows {
256 let (eid, rid, record_key, score, method, source) =
257 row.map_err(|e| ZerError::Store(e.to_string()))?;
258 let member = EntityMember {
259 record_id: rid,
260 record_key,
261 score,
262 method: method_from_str(&method),
263 source,
264 };
265 match entities.last_mut() {
266 Some(e) if e.id == eid => e.members.push(member),
267 _ => entities.push(Entity {
268 id: eid,
269 members: vec![member],
270 }),
271 }
272 }
273
274 Ok(entities)
275 }
276}
277
278fn method_to_str(method: ResolutionMethod) -> &'static str {
281 match method {
282 ResolutionMethod::AutoMatch => "AutoMatch",
283 ResolutionMethod::JudgePromoted => "JudgePromoted",
284 ResolutionMethod::JudgeDemoted => "JudgeDemoted",
285 ResolutionMethod::Manual => "Manual",
286 }
287}
288
289fn method_from_str(s: &str) -> ResolutionMethod {
290 match s {
291 "JudgePromoted" => ResolutionMethod::JudgePromoted,
292 "JudgeDemoted" => ResolutionMethod::JudgeDemoted,
293 "Manual" => ResolutionMethod::Manual,
294 _ => ResolutionMethod::AutoMatch,
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{entity::ResolutionMethod, traits::EntityStore};

    /// Builds an entity whose members are `record_ids`, all auto-matched at score 0.95.
    fn make_entity(id: EntityId, record_ids: &[RecordId]) -> Entity {
        Entity {
            id,
            members: record_ids
                .iter()
                .map(|&rid| EntityMember {
                    record_id: rid,
                    record_key: rid.to_string(),
                    score: 0.95,
                    method: ResolutionMethod::AutoMatch,
                    source: None,
                })
                .collect(),
        }
    }

    /// Returns the record ids held by `entity`, sorted, for order-independent comparison.
    fn sorted_record_ids(entity: &Entity) -> Vec<RecordId> {
        let mut ids: Vec<RecordId> = entity.members.iter().map(|m| m.record_id).collect();
        ids.sort();
        ids
    }

    #[test]
    fn open_in_memory_creates_schema() {
        ZalEntityStore::open_in_memory().unwrap();
    }

    #[test]
    fn upsert_new_entity_returns_id() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        let entity = make_entity(0, &[1, 2, 3]);
        let eid = store.upsert_entity(&entity).unwrap();
        assert!(eid >= 1, "autoincrement id must be ≥ 1");
    }

    #[test]
    fn upsert_same_entity_merges_members() {
        let store = ZalEntityStore::open_in_memory().unwrap();

        let e1 = make_entity(0, &[1, 2]);
        let eid = store.upsert_entity(&e1).unwrap();

        // Shares record 2 with e1, so it must merge into the same entity.
        let e2 = make_entity(0, &[2, 3]);
        let eid2 = store.upsert_entity(&e2).unwrap();

        assert_eq!(eid, eid2, "same entity_id must be returned on merge");

        let loaded = store.get_entity(eid).unwrap();
        let rids: Vec<RecordId> = loaded.members.iter().map(|m| m.record_id).collect();
        assert!(rids.contains(&1));
        assert!(rids.contains(&2));
        assert!(rids.contains(&3));
    }

    #[test]
    fn repeated_upsert_does_not_duplicate_members() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        let eid = store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();
        let again = store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();

        assert_eq!(eid, again);
        assert_eq!(sorted_record_ids(&store.get_entity(eid).unwrap()), vec![1, 2]);
    }

    #[test]
    fn record_to_entity_returns_correct_id() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        let entity = make_entity(0, &[10, 20]);
        let eid = store.upsert_entity(&entity).unwrap();

        assert_eq!(store.record_to_entity(10).unwrap(), Some(eid));
        assert_eq!(store.record_to_entity(20).unwrap(), Some(eid));
    }

    #[test]
    fn record_to_entity_missing_returns_none() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        assert!(store.record_to_entity(999).unwrap().is_none());
    }

    #[test]
    fn get_entity_unknown_id_has_no_members() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        let eid = store.upsert_entity(&make_entity(0, &[1])).unwrap();
        let unknown = eid + 1000;
        assert!(store.get_entity(unknown).unwrap().members.is_empty());
    }

    #[test]
    fn all_entities_returns_all() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();
        store.upsert_entity(&make_entity(0, &[3, 4])).unwrap();
        store.upsert_entity(&make_entity(0, &[5, 6])).unwrap();

        let all = store.all_entities().unwrap();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn all_entities_groups_members_by_entity() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        let a = store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();
        let b = store.upsert_entity(&make_entity(0, &[3])).unwrap();

        let all = store.all_entities().unwrap();
        assert_eq!(all.len(), 2);
        let ea = all.iter().find(|e| e.id == a).expect("entity a listed");
        let eb = all.iter().find(|e| e.id == b).expect("entity b listed");
        assert_eq!(sorted_record_ids(ea), vec![1, 2]);
        assert_eq!(sorted_record_ids(eb), vec![3]);
    }

    #[test]
    fn record_key_survives_round_trip() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        let entity = Entity {
            id: 0,
            members: vec![EntityMember {
                record_id: 42,
                record_key: "893479421".to_string(),
                score: 0.99,
                method: ResolutionMethod::AutoMatch,
                source: Some("brp".to_string()),
            }],
        };
        let eid = store.upsert_entity(&entity).unwrap();
        let loaded = store.get_entity(eid).unwrap();
        assert_eq!(loaded.members[0].record_key, "893479421");
        assert_eq!(loaded.members[0].source.as_deref(), Some("brp"));
    }

    #[test]
    fn method_survives_round_trip() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        let mut entity = make_entity(0, &[7]);
        entity.members[0].method = ResolutionMethod::Manual;
        let eid = store.upsert_entity(&entity).unwrap();

        let loaded = store.get_entity(eid).unwrap();
        assert!(matches!(loaded.members[0].method, ResolutionMethod::Manual));
    }

    #[test]
    fn method_strings_round_trip() {
        for name in ["AutoMatch", "JudgePromoted", "JudgeDemoted", "Manual"] {
            assert_eq!(method_to_str(method_from_str(name)), name);
        }
    }

    #[test]
    fn unknown_method_string_decodes_as_auto_match() {
        assert!(matches!(method_from_str("bogus"), ResolutionMethod::AutoMatch));
    }

    #[test]
    fn provenance_event_written_on_create() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();

        let conn = store.conn.lock().unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM resolution_events WHERE event_type = 'EntityCreated'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn merge_that_adds_a_record_appends_one_more_event() {
        let store = ZalEntityStore::open_in_memory().unwrap();
        store.upsert_entity(&make_entity(0, &[1, 2])).unwrap();
        // Record 3 is new, so the merge must log exactly one additional event.
        store.upsert_entity(&make_entity(0, &[2, 3])).unwrap();

        let conn = store.conn.lock().unwrap();
        let total: i64 = conn
            .query_row("SELECT COUNT(*) FROM resolution_events", [], |r| r.get(0))
            .unwrap();
        assert_eq!(total, 2);
    }
}