1use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
23use crate::StorageError;
24use sqlx::{AnyPool, Row, any::AnyPoolOptions};
25use std::collections::BTreeMap;
26use std::time::Duration;
27use tracing::{debug, info, instrument};
28
29const ROOT_PATH: &str = "";
31
/// SQL-backed persistence for a path-keyed merkle tree: one row per node in
/// the `merkle_nodes` table, supporting both SQLite and MySQL-compatible
/// servers (the dialect is chosen per-query via `is_sqlite`).
#[derive(Clone)]
pub struct SqlMerkleStore {
    // sqlx connection pool used for all queries issued by this store.
    pool: AnyPool,
    // True when the connection string targets SQLite; selects which SQL
    // dialect (SQLite vs MySQL) is used for DDL and upserts.
    is_sqlite: bool,
}
38
39impl SqlMerkleStore {
40 pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
42 sqlx::any::install_default_drivers();
44
45 let is_sqlite = connection_string.starts_with("sqlite:");
46
47 let pool = AnyPoolOptions::new()
48 .max_connections(5)
49 .acquire_timeout(Duration::from_secs(10))
50 .connect(connection_string)
51 .await
52 .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;
53
54 let store = Self { pool, is_sqlite };
55 store.init_schema().await?;
56
57 info!("SQL merkle store initialized");
58 Ok(store)
59 }
60
61 pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
63 Self { pool, is_sqlite }
64 }
65
66 pub async fn init_schema(&self) -> Result<(), StorageError> {
69 let sql = if self.is_sqlite {
71 r#"
72 CREATE TABLE IF NOT EXISTS merkle_nodes (
73 path TEXT PRIMARY KEY,
74 parent_path TEXT,
75 merkle_hash TEXT NOT NULL,
76 object_count INTEGER NOT NULL DEFAULT 0,
77 is_leaf INTEGER DEFAULT 0,
78 updated_at INTEGER DEFAULT (strftime('%s', 'now'))
79 );
80 CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
81 "#
82 } else {
83 r#"
86 CREATE TABLE IF NOT EXISTS merkle_nodes (
87 path VARCHAR(255) PRIMARY KEY,
88 parent_path VARCHAR(255),
89 merkle_hash VARCHAR(64) NOT NULL,
90 object_count INT NOT NULL DEFAULT 0,
91 is_leaf INT DEFAULT 0,
92 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
93 INDEX idx_merkle_parent (parent_path)
94 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
95 "#
96 };
97
98 if self.is_sqlite {
100 for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
101 sqlx::query(stmt)
102 .execute(&self.pool)
103 .await
104 .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
105 }
106 } else {
107 sqlx::query(sql)
108 .execute(&self.pool)
109 .await
110 .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
111 }
112
113 Ok(())
114 }
115
116 #[instrument(skip(self))]
118 pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
119 self.get_hash(ROOT_PATH).await
120 }
121
122 pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
124 let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
125 .bind(path)
126 .fetch_optional(&self.pool)
127 .await
128 .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;
129
130 match result {
131 Some(row) => {
132 let hash_str: String = row.try_get("merkle_hash")
135 .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;
136
137 let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
140 hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
141 "Invalid merkle hash hex: {}", e
142 )))?
143 } else {
144 hash_str.into_bytes()
146 };
147
148 if bytes.len() != 32 {
149 return Err(StorageError::Backend(format!(
150 "Invalid merkle hash length: {}", bytes.len()
151 )));
152 }
153 let mut hash = [0u8; 32];
154 hash.copy_from_slice(&bytes);
155 Ok(Some(hash))
156 }
157 None => Ok(None),
158 }
159 }
160
161 pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
163 let rows = sqlx::query(
164 "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
165 )
166 .bind(path)
167 .fetch_all(&self.pool)
168 .await
169 .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;
170
171 let mut children = BTreeMap::new();
172 for row in rows {
173 let child_path: String = row.try_get("path")
174 .map_err(|e| StorageError::Backend(e.to_string()))?;
175
176 let hash_str: String = row.try_get("merkle_hash")
178 .map_err(|e| StorageError::Backend(e.to_string()))?;
179
180 let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
182 hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
183 } else {
184 hash_str.into_bytes()
185 };
186
187 if bytes.len() == 32 {
188 let segment = if path.is_empty() {
190 child_path.clone()
191 } else {
192 child_path.strip_prefix(&format!("{}.", path))
193 .unwrap_or(&child_path)
194 .to_string()
195 };
196
197 let mut hash = [0u8; 32];
198 hash.copy_from_slice(&bytes);
199 children.insert(segment, hash);
200 }
201 }
202
203 Ok(children)
204 }
205
206 pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
208 let hash = self.get_hash(path).await?;
209
210 match hash {
211 Some(h) => {
212 let children = self.get_children(path).await?;
213 Ok(Some(if children.is_empty() {
214 MerkleNode::leaf(h)
215 } else {
216 MerkleNode {
217 hash: h,
218 children,
219 is_leaf: false,
220 }
221 }))
222 }
223 None => Ok(None),
224 }
225 }
226
227 #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
231 pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
232 if batch.is_empty() {
233 return Ok(());
234 }
235
236 let mut tx = self.pool.begin().await
238 .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;
239
240 if !self.is_sqlite {
242 sqlx::query("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")
243 .execute(&mut *tx)
244 .await
245 .map_err(|e| StorageError::Backend(format!("Failed to set isolation level: {}", e)))?;
246 }
247
248 for (object_id, maybe_hash) in &batch.leaves {
250 let parent = PathMerkle::parent_prefix(object_id);
251
252 match maybe_hash {
253 Some(hash) => {
254 let hash_hex = hex::encode(hash);
256 let sql = if self.is_sqlite {
257 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
258 VALUES (?, ?, ?, 1, 1)
259 ON CONFLICT(path) DO UPDATE SET
260 merkle_hash = excluded.merkle_hash,
261 updated_at = strftime('%s', 'now')"
262 } else {
263 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
264 VALUES (?, ?, ?, 1, 1)
265 ON DUPLICATE KEY UPDATE
266 merkle_hash = VALUES(merkle_hash),
267 updated_at = CURRENT_TIMESTAMP"
268 };
269
270 sqlx::query(sql)
271 .bind(object_id)
272 .bind(parent)
273 .bind(&hash_hex)
274 .execute(&mut *tx)
275 .await
276 .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
277 }
278 None => {
279 sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
280 .bind(object_id)
281 .execute(&mut *tx)
282 .await
283 .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
284 }
285 }
286 }
287
288 tx.commit().await
290 .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;
291
292 let affected_prefixes = batch.affected_prefixes();
294 for prefix in affected_prefixes {
295 self.recompute_interior_node(&prefix).await?;
296 }
297
298 self.recompute_interior_node(ROOT_PATH).await?;
300
301 debug!(updates = batch.len(), "SQL merkle batch applied");
302 Ok(())
303 }
304
305 async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
307 let children = self.get_children(path).await?;
309
310 if children.is_empty() {
311 let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
313 .bind(path)
314 .fetch_optional(&self.pool)
315 .await
316 .map_err(|e| StorageError::Backend(e.to_string()))?;
317
318 if let Some(row) = result {
319 let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);
321
322 if is_leaf == 0 {
323 sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
324 .bind(path)
325 .execute(&self.pool)
326 .await
327 .map_err(|e| StorageError::Backend(e.to_string()))?;
328 }
329 }
330 return Ok(());
331 }
332
333 let node = MerkleNode::interior(children);
335 let hash_hex = hex::encode(node.hash); let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
337 let object_count = node.children.len() as i32;
338
339 let sql = if self.is_sqlite {
340 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
341 VALUES (?, ?, ?, 0, ?)
342 ON CONFLICT(path) DO UPDATE SET
343 merkle_hash = excluded.merkle_hash,
344 object_count = excluded.object_count,
345 updated_at = strftime('%s', 'now')"
346 } else {
347 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
348 VALUES (?, ?, ?, 0, ?)
349 ON DUPLICATE KEY UPDATE
350 merkle_hash = VALUES(merkle_hash),
351 object_count = VALUES(object_count),
352 updated_at = CURRENT_TIMESTAMP"
353 };
354
355 sqlx::query(sql)
356 .bind(path)
357 .bind(parent)
358 .bind(&hash_hex)
359 .bind(object_count)
360 .execute(&self.pool)
361 .await
362 .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;
363
364 debug!(path = %path, children = node.children.len(), "Recomputed interior node");
365 Ok(())
366 }
367
368 #[instrument(skip(self, their_children))]
372 pub async fn diff_children(
373 &self,
374 path: &str,
375 their_children: &BTreeMap<String, [u8; 32]>,
376 ) -> Result<Vec<String>, StorageError> {
377 let our_children = self.get_children(path).await?;
378 let mut diffs = Vec::new();
379
380 let prefix_with_dot = if path.is_empty() {
381 String::new()
382 } else {
383 format!("{}.", path)
384 };
385
386 for (segment, our_hash) in &our_children {
388 match their_children.get(segment) {
389 Some(their_hash) if their_hash != our_hash => {
390 diffs.push(format!("{}{}", prefix_with_dot, segment));
391 }
392 None => {
393 diffs.push(format!("{}{}", prefix_with_dot, segment));
395 }
396 _ => {} }
398 }
399
400 for segment in their_children.keys() {
402 if !our_children.contains_key(segment) {
403 diffs.push(format!("{}{}", prefix_with_dot, segment));
404 }
405 }
406
407 Ok(diffs)
408 }
409
410 pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
412 let pattern = if prefix.is_empty() {
413 "%".to_string()
414 } else {
415 format!("{}%", prefix)
416 };
417
418 let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";
420
421 let rows = sqlx::query(sql)
422 .bind(&pattern)
423 .fetch_all(&self.pool)
424 .await
425 .map_err(|e| StorageError::Backend(e.to_string()))?;
426
427 let mut leaves = Vec::with_capacity(rows.len());
428 for row in rows {
429 let path: String = row.try_get("path")
430 .map_err(|e| StorageError::Backend(e.to_string()))?;
431 leaves.push(path);
432 }
433
434 Ok(leaves)
435 }
436
437 pub async fn count_leaves(&self) -> Result<u64, StorageError> {
439 let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";
441
442 let row = sqlx::query(sql)
443 .fetch_one(&self.pool)
444 .await
445 .map_err(|e| StorageError::Backend(e.to_string()))?;
446
447 let count: i64 = row.try_get("cnt")
448 .map_err(|e| StorageError::Backend(e.to_string()))?;
449
450 Ok(count as u64)
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Location of a throwaway SQLite file for one named test.
    fn temp_db_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("merkle_test_{}.db", name))
    }

    #[tokio::test]
    async fn test_sql_merkle_basic() {
        let db_file = temp_db_path("basic");
        let _ = std::fs::remove_file(&db_file);

        let dsn = format!("sqlite://{}?mode=rwc", db_file.display());
        let store = SqlMerkleStore::new(&dsn).await.unwrap();

        // A freshly created store has no nodes, hence no root hash.
        assert!(store.root_hash().await.unwrap().is_none());

        // Insert a single leaf via a batch.
        let mut batch = MerkleBatch::new();
        batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // One leaf is enough to materialize a root hash.
        assert!(store.root_hash().await.unwrap().is_some());

        // The leaf hash round-trips exactly.
        let leaf_hash = store.get_hash("uk.nhs.patient.123").await.unwrap();
        assert_eq!(leaf_hash, Some([1u8; 32]));

        let _ = std::fs::remove_file(&db_file);
    }

    #[tokio::test]
    async fn test_sql_merkle_diff() {
        let db_file = temp_db_path("diff");
        let _ = std::fs::remove_file(&db_file);

        let dsn = format!("sqlite://{}?mode=rwc", db_file.display());
        let store = SqlMerkleStore::new(&dsn).await.unwrap();

        // Two leaves under distinct interior nodes "uk.a" and "uk.b".
        let mut batch = MerkleBatch::new();
        batch.insert("uk.a.1".to_string(), [1u8; 32]);
        batch.insert("uk.b.2".to_string(), [2u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // Simulated peer: disagrees on "a", matches "b" exactly.
        let matching_b = store.get_hash("uk.b").await.unwrap().unwrap();
        let mut their_children = BTreeMap::new();
        their_children.insert("a".to_string(), [99u8; 32]);
        their_children.insert("b".to_string(), matching_b);

        let diffs = store.diff_children("uk", &their_children).await.unwrap();
        assert!(diffs.contains(&"uk.a".to_string()));
        assert!(!diffs.contains(&"uk.b".to_string()));

        let _ = std::fs::remove_file(&db_file);
    }
}