1use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
23use crate::StorageError;
24use sqlx::{AnyPool, Row, any::AnyPoolOptions};
25use std::collections::BTreeMap;
26use std::time::Duration;
27use tracing::{debug, info, instrument};
28
29const ROOT_PATH: &str = "";
31
/// Merkle path-tree persisted in SQL through a `sqlx` `AnyPool`.
///
/// Supports two dialects: SQLite and a MySQL-flavoured backend; the
/// `is_sqlite` flag (detected from the connection URL in [`SqlMerkleStore::new`])
/// selects which SQL text is emitted for schema creation and upserts.
#[derive(Clone)]
pub struct SqlMerkleStore {
    // Shared connection pool used by every query in this store.
    pool: AnyPool,
    // True when the connection string started with "sqlite:"; switches dialect.
    is_sqlite: bool,
}
38
39impl SqlMerkleStore {
40 pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
42 sqlx::any::install_default_drivers();
44
45 let is_sqlite = connection_string.starts_with("sqlite:");
46
47 let pool = AnyPoolOptions::new()
48 .max_connections(5)
49 .acquire_timeout(Duration::from_secs(10))
50 .connect(connection_string)
51 .await
52 .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;
53
54 let store = Self { pool, is_sqlite };
55 store.init_schema().await?;
56
57 info!("SQL merkle store initialized");
58 Ok(store)
59 }
60
61 pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
63 Self { pool, is_sqlite }
64 }
65
    /// Create the `merkle_nodes` table and its parent-path index if absent.
    ///
    /// Two dialects are emitted: the SQLite branch stores `updated_at` as
    /// epoch seconds and issues a separate `CREATE INDEX`; the MySQL-flavoured
    /// branch uses an inline `INDEX` clause and `ON UPDATE CURRENT_TIMESTAMP`.
    /// Idempotent (`IF NOT EXISTS`), so it is safe to call on every startup.
    pub async fn init_schema(&self) -> Result<(), StorageError> {
        let sql = if self.is_sqlite {
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path TEXT PRIMARY KEY,
                parent_path TEXT,
                merkle_hash TEXT NOT NULL,
                object_count INTEGER NOT NULL DEFAULT 0,
                is_leaf INTEGER DEFAULT 0,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            );
            CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
            "#
        } else {
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path VARCHAR(255) PRIMARY KEY,
                parent_path VARCHAR(255),
                merkle_hash VARCHAR(64) NOT NULL,
                object_count INT NOT NULL DEFAULT 0,
                is_leaf INT DEFAULT 0,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_merkle_parent (parent_path)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            "#
        };

        if self.is_sqlite {
            // The SQLite script holds two statements, so execute them one at a
            // time. Splitting on ';' is safe here: no statement above contains
            // an embedded semicolon.
            for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
                sqlx::query(stmt)
                    .execute(&self.pool)
                    .await
                    .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
            }
        } else {
            // The MySQL branch is a single statement; run it as-is.
            sqlx::query(sql)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
        }

        Ok(())
    }
115
    /// Hash of the tree root (the node stored under the empty path).
    /// Returns `None` when the tree has no root node yet.
    #[instrument(skip(self))]
    pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
        self.get_hash(ROOT_PATH).await
    }
121
122 pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
124 let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
125 .bind(path)
126 .fetch_optional(&self.pool)
127 .await
128 .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;
129
130 match result {
131 Some(row) => {
132 let hash_str: String = row.try_get("merkle_hash")
135 .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;
136
137 let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
140 hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
141 "Invalid merkle hash hex: {}", e
142 )))?
143 } else {
144 hash_str.into_bytes()
146 };
147
148 if bytes.len() != 32 {
149 return Err(StorageError::Backend(format!(
150 "Invalid merkle hash length: {}", bytes.len()
151 )));
152 }
153 let mut hash = [0u8; 32];
154 hash.copy_from_slice(&bytes);
155 Ok(Some(hash))
156 }
157 None => Ok(None),
158 }
159 }
160
161 pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
163 let rows = sqlx::query(
164 "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
165 )
166 .bind(path)
167 .fetch_all(&self.pool)
168 .await
169 .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;
170
171 let mut children = BTreeMap::new();
172 for row in rows {
173 let child_path: String = row.try_get("path")
174 .map_err(|e| StorageError::Backend(e.to_string()))?;
175
176 let hash_str: String = row.try_get("merkle_hash")
178 .map_err(|e| StorageError::Backend(e.to_string()))?;
179
180 let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
182 hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
183 } else {
184 hash_str.into_bytes()
185 };
186
187 if bytes.len() == 32 {
188 let segment = if path.is_empty() {
190 child_path.clone()
191 } else {
192 child_path.strip_prefix(&format!("{}.", path))
193 .unwrap_or(&child_path)
194 .to_string()
195 };
196
197 let mut hash = [0u8; 32];
198 hash.copy_from_slice(&bytes);
199 children.insert(segment, hash);
200 }
201 }
202
203 Ok(children)
204 }
205
206 pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
208 let hash = self.get_hash(path).await?;
209
210 match hash {
211 Some(h) => {
212 let children = self.get_children(path).await?;
213 Ok(Some(if children.is_empty() {
214 MerkleNode::leaf(h)
215 } else {
216 MerkleNode {
217 hash: h,
218 children,
219 is_leaf: false,
220 }
221 }))
222 }
223 None => Ok(None),
224 }
225 }
226
    /// Apply a batch of leaf updates, then rebuild the affected interior
    /// nodes and finally the root.
    ///
    /// Leaf upserts/deletes run inside a single transaction; the interior
    /// recomputation afterwards runs outside it, so a crash between commit
    /// and recompute can leave interior hashes stale until the next batch.
    #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
    pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
        if batch.is_empty() {
            return Ok(());
        }

        let mut tx = self.pool.begin().await
            .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;

        for (object_id, maybe_hash) in &batch.leaves {
            let parent = PathMerkle::parent_prefix(object_id);

            match maybe_hash {
                // Some(hash) => upsert the leaf row; None => the leaf was removed.
                Some(hash) => {
                    let hash_hex = hex::encode(hash);
                    let sql = if self.is_sqlite {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
                         VALUES (?, ?, ?, 1, 1)
                         ON CONFLICT(path) DO UPDATE SET
                            merkle_hash = excluded.merkle_hash,
                            updated_at = strftime('%s', 'now')"
                    } else {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
                         VALUES (?, ?, ?, 1, 1)
                         ON DUPLICATE KEY UPDATE
                            merkle_hash = VALUES(merkle_hash),
                            updated_at = CURRENT_TIMESTAMP"
                    };

                    sqlx::query(sql)
                        .bind(object_id)
                        .bind(parent)
                        .bind(&hash_hex)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
                }
                None => {
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(object_id)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
                }
            }
        }

        tx.commit().await
            .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;

        // Recompute every affected interior prefix, then the root last.
        // NOTE(review): correctness of multi-level trees assumes
        // `affected_prefixes()` yields deeper prefixes before shallower ones
        // (or that each recompute only reads direct children already updated)
        // — confirm against path_tree's iteration order.
        let affected_prefixes = batch.affected_prefixes();
        for prefix in affected_prefixes {
            self.recompute_interior_node(&prefix).await?;
        }

        self.recompute_interior_node(ROOT_PATH).await?;

        debug!(updates = batch.len(), "SQL merkle batch applied");
        Ok(())
    }
297
298 async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
300 let children = self.get_children(path).await?;
302
303 if children.is_empty() {
304 let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
306 .bind(path)
307 .fetch_optional(&self.pool)
308 .await
309 .map_err(|e| StorageError::Backend(e.to_string()))?;
310
311 if let Some(row) = result {
312 let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);
314
315 if is_leaf == 0 {
316 sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
317 .bind(path)
318 .execute(&self.pool)
319 .await
320 .map_err(|e| StorageError::Backend(e.to_string()))?;
321 }
322 }
323 return Ok(());
324 }
325
326 let node = MerkleNode::interior(children);
328 let hash_hex = hex::encode(node.hash); let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
330 let object_count = node.children.len() as i32;
331
332 let sql = if self.is_sqlite {
333 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
334 VALUES (?, ?, ?, 0, ?)
335 ON CONFLICT(path) DO UPDATE SET
336 merkle_hash = excluded.merkle_hash,
337 object_count = excluded.object_count,
338 updated_at = strftime('%s', 'now')"
339 } else {
340 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
341 VALUES (?, ?, ?, 0, ?)
342 ON DUPLICATE KEY UPDATE
343 merkle_hash = VALUES(merkle_hash),
344 object_count = VALUES(object_count),
345 updated_at = CURRENT_TIMESTAMP"
346 };
347
348 sqlx::query(sql)
349 .bind(path)
350 .bind(parent)
351 .bind(&hash_hex)
352 .bind(object_count)
353 .execute(&self.pool)
354 .await
355 .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;
356
357 debug!(path = %path, children = node.children.len(), "Recomputed interior node");
358 Ok(())
359 }
360
361 #[instrument(skip(self, their_children))]
365 pub async fn diff_children(
366 &self,
367 path: &str,
368 their_children: &BTreeMap<String, [u8; 32]>,
369 ) -> Result<Vec<String>, StorageError> {
370 let our_children = self.get_children(path).await?;
371 let mut diffs = Vec::new();
372
373 let prefix_with_dot = if path.is_empty() {
374 String::new()
375 } else {
376 format!("{}.", path)
377 };
378
379 for (segment, our_hash) in &our_children {
381 match their_children.get(segment) {
382 Some(their_hash) if their_hash != our_hash => {
383 diffs.push(format!("{}{}", prefix_with_dot, segment));
384 }
385 None => {
386 diffs.push(format!("{}{}", prefix_with_dot, segment));
388 }
389 _ => {} }
391 }
392
393 for segment in their_children.keys() {
395 if !our_children.contains_key(segment) {
396 diffs.push(format!("{}{}", prefix_with_dot, segment));
397 }
398 }
399
400 Ok(diffs)
401 }
402
403 pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
405 let pattern = if prefix.is_empty() {
406 "%".to_string()
407 } else {
408 format!("{}%", prefix)
409 };
410
411 let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";
413
414 let rows = sqlx::query(sql)
415 .bind(&pattern)
416 .fetch_all(&self.pool)
417 .await
418 .map_err(|e| StorageError::Backend(e.to_string()))?;
419
420 let mut leaves = Vec::with_capacity(rows.len());
421 for row in rows {
422 let path: String = row.try_get("path")
423 .map_err(|e| StorageError::Backend(e.to_string()))?;
424 leaves.push(path);
425 }
426
427 Ok(leaves)
428 }
429
430 pub async fn count_leaves(&self) -> Result<u64, StorageError> {
432 let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";
434
435 let row = sqlx::query(sql)
436 .fetch_one(&self.pool)
437 .await
438 .map_err(|e| StorageError::Backend(e.to_string()))?;
439
440 let count: i64 = row.try_get("cnt")
441 .map_err(|e| StorageError::Backend(e.to_string()))?;
442
443 Ok(count as u64)
444 }
445
446 pub async fn get_all_nodes(&self) -> Result<Vec<(String, [u8; 32], BTreeMap<String, [u8; 32]>)>, StorageError> {
450 let rows = sqlx::query("SELECT path, merkle_hash FROM merkle_nodes ORDER BY path")
451 .fetch_all(&self.pool)
452 .await
453 .map_err(|e| StorageError::Backend(format!("Failed to get all merkle nodes: {}", e)))?;
454
455 let mut nodes = Vec::with_capacity(rows.len());
456
457 for row in rows {
458 let path: String = row.try_get("path")
459 .map_err(|e| StorageError::Backend(e.to_string()))?;
460
461 let hash_str: String = row.try_get("merkle_hash")
462 .map_err(|e| StorageError::Backend(e.to_string()))?;
463
464 let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
465 hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
466 } else {
467 hash_str.into_bytes()
468 };
469
470 if bytes.len() == 32 {
471 let mut hash = [0u8; 32];
472 hash.copy_from_slice(&bytes);
473
474 let children = self.get_children(&path).await?;
476
477 nodes.push((path, hash, children));
478 }
479 }
480
481 Ok(nodes)
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Throwaway sqlite database file, unique per test name.
    fn temp_db_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("merkle_test_{}.db", name))
    }

    #[tokio::test]
    async fn test_sql_merkle_basic() {
        let db_path = temp_db_path("basic");
        let _ = std::fs::remove_file(&db_path);
        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        // A fresh store has no root yet.
        assert!(store.root_hash().await.unwrap().is_none());

        let mut batch = MerkleBatch::new();
        batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // One leaf is enough to materialize a root hash.
        assert!(store.root_hash().await.unwrap().is_some());
        assert_eq!(
            store.get_hash("uk.nhs.patient.123").await.unwrap(),
            Some([1u8; 32])
        );

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_sql_merkle_diff() {
        let db_path = temp_db_path("diff");
        let _ = std::fs::remove_file(&db_path);
        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        let mut batch = MerkleBatch::new();
        batch.insert("uk.a.1".to_string(), [1u8; 32]);
        batch.insert("uk.b.2".to_string(), [2u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // The peer agrees on "b" but holds a different hash for "a".
        let mut their_children = BTreeMap::new();
        their_children.insert("a".to_string(), [99u8; 32]);
        their_children.insert(
            "b".to_string(),
            store.get_hash("uk.b").await.unwrap().unwrap(),
        );

        let diffs = store.diff_children("uk", &their_children).await.unwrap();
        assert!(diffs.contains(&"uk.a".to_string()));
        assert!(!diffs.contains(&"uk.b".to_string()));

        let _ = std::fs::remove_file(&db_path);
    }
}