use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
use crate::StorageError;
use sqlx::{AnyPool, Row, any::AnyPoolOptions};
use std::collections::BTreeMap;
use std::time::Duration;
use tracing::{debug, info, instrument};

const ROOT_PATH: &str = "";

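/// Merkle-tree persistence over a SQL backend (`sqlx::AnyPool`), storing one
/// row per path-tree node in a `merkle_nodes` table. Supports SQLite and
/// MySQL-compatible databases, selected by the connection string.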
#[derive(Clone)]
pub struct SqlMerkleStore {
    pool: AnyPool,
    is_sqlite: bool,
}

impl SqlMerkleStore {
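    /// Connects to the database at `connection_string`, initialises the
    /// `merkle_nodes` schema, and returns a ready-to-use store.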
    pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
        sqlx::any::install_default_drivers();

        let is_sqlite = connection_string.starts_with("sqlite:");

        let pool = AnyPoolOptions::new()
            .max_connections(5)
            .acquire_timeout(Duration::from_secs(10))
            .connect(connection_string)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;

        let store = Self { pool, is_sqlite };
        store.init_schema().await?;

        info!("SQL merkle store initialized");
        Ok(store)
    }

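    /// Wraps an existing pool without running schema initialisation; the
    /// caller is expected to ensure the `merkle_nodes` table exists
    /// (e.g. via `init_schema`).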
    pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
        Self { pool, is_sqlite }
    }

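    /// Creates the `merkle_nodes` table and its parent-path index if they do
    /// not already exist, using SQLite or MySQL DDL depending on the backend.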
    pub async fn init_schema(&self) -> Result<(), StorageError> {
        let sql = if self.is_sqlite {
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path TEXT PRIMARY KEY,
                parent_path TEXT,
                merkle_hash TEXT NOT NULL,
                object_count INTEGER NOT NULL DEFAULT 0,
                is_leaf INTEGER DEFAULT 0,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            );
            CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
            "#
        } else {
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path VARCHAR(255) PRIMARY KEY,
                parent_path VARCHAR(255),
                merkle_hash VARCHAR(64) NOT NULL,
                object_count INT NOT NULL DEFAULT 0,
                is_leaf INT DEFAULT 0,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_merkle_parent (parent_path)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            "#
        };

        if self.is_sqlite {
            // The SQLite script contains several statements; run them one at a time.
            for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
                sqlx::query(stmt)
                    .execute(&self.pool)
                    .await
                    .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
            }
        } else {
            sqlx::query(sql)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
        }

        Ok(())
    }

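    /// Returns the hash stored at the tree root (the empty path), if any.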
    #[instrument(skip(self))]
    pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
        self.get_hash(ROOT_PATH).await
    }

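    /// Fetches the 32-byte hash stored for `path`, decoding the hex (or raw)
    /// representation persisted in the `merkle_hash` column.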
    pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
        let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
            .bind(path)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;

        match result {
            Some(row) => {
                let hash_str: String = row.try_get("merkle_hash")
                    .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;

                // Hashes are normally stored as 64-char hex; otherwise treat the value as raw bytes.
                let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
                    hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
                        "Invalid merkle hash hex: {}", e
                    )))?
                } else {
                    hash_str.into_bytes()
                };

                if bytes.len() != 32 {
                    return Err(StorageError::Backend(format!(
                        "Invalid merkle hash length: {}", bytes.len()
                    )));
                }
                let mut hash = [0u8; 32];
                hash.copy_from_slice(&bytes);
                Ok(Some(hash))
            }
            None => Ok(None),
        }
    }

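    /// Returns the direct children of `path` as a map from the child's final
    /// path segment to its hash; entries with malformed hashes are skipped.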
    pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
        let rows = sqlx::query(
            "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
        )
        .bind(path)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;

        let mut children = BTreeMap::new();
        for row in rows {
            let child_path: String = row.try_get("path")
                .map_err(|e| StorageError::Backend(e.to_string()))?;

            let hash_str: String = row.try_get("merkle_hash")
                .map_err(|e| StorageError::Backend(e.to_string()))?;

            let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
                hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
            } else {
                hash_str.into_bytes()
            };

            if bytes.len() == 32 {
                // Key each entry by the child's path segment relative to `path`.
                let segment = if path.is_empty() {
                    child_path.clone()
                } else {
                    child_path.strip_prefix(&format!("{}.", path))
                        .unwrap_or(&child_path)
                        .to_string()
                };

                let mut hash = [0u8; 32];
                hash.copy_from_slice(&bytes);
                children.insert(segment, hash);
            }
        }

        Ok(children)
    }

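    /// Loads the node at `path`, returning a leaf when it has no children and
    /// an interior node otherwise, or `None` if the path is unknown.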
    pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
        let hash = self.get_hash(path).await?;

        match hash {
            Some(h) => {
                let children = self.get_children(path).await?;
                Ok(Some(if children.is_empty() {
                    MerkleNode::leaf(h)
                } else {
                    MerkleNode {
                        hash: h,
                        children,
                        is_leaf: false,
                    }
                }))
            }
            None => Ok(None),
        }
    }

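    /// Applies a batch of leaf updates: upserts or deletes each leaf inside a
    /// single transaction, then recomputes the affected interior nodes and
    /// finally the root.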
    #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
    pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
        if batch.is_empty() {
            return Ok(());
        }

        let mut tx = self.pool.begin().await
            .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;

        // Apply all leaf upserts and deletes inside one transaction.
        for (object_id, maybe_hash) in &batch.leaves {
            let parent = PathMerkle::parent_prefix(object_id);

            match maybe_hash {
                Some(hash) => {
                    let hash_hex = hex::encode(hash);
                    let sql = if self.is_sqlite {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
                         VALUES (?, ?, ?, 1, 1)
                         ON CONFLICT(path) DO UPDATE SET
                             merkle_hash = excluded.merkle_hash,
                             updated_at = strftime('%s', 'now')"
                    } else {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
                         VALUES (?, ?, ?, 1, 1)
                         ON DUPLICATE KEY UPDATE
                             merkle_hash = VALUES(merkle_hash),
                             updated_at = CURRENT_TIMESTAMP"
                    };

                    sqlx::query(sql)
                        .bind(object_id)
                        .bind(parent)
                        .bind(&hash_hex)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
                }
                None => {
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(object_id)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
                }
            }
        }

        tx.commit().await
            .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;

        // Recompute interior nodes for every affected prefix, then the root.
        let affected_prefixes = batch.affected_prefixes();
        for prefix in affected_prefixes {
            self.recompute_interior_node(&prefix).await?;
        }

        self.recompute_interior_node(ROOT_PATH).await?;

        debug!(updates = batch.len(), "SQL merkle batch applied");
        Ok(())
    }

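    /// Rebuilds the hash and object count for the interior node at `path` from
    /// its current children, deleting the row when no children remain and the
    /// node is not itself a leaf.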
    async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
        let children = self.get_children(path).await?;

        if children.is_empty() {
            // No children remain: remove the node unless it is itself a leaf.
            let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
                .bind(path)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))?;

            if let Some(row) = result {
                let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);

                if is_leaf == 0 {
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(path)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| StorageError::Backend(e.to_string()))?;
                }
            }
            return Ok(());
        }

        let node = MerkleNode::interior(children);
        let hash_hex = hex::encode(node.hash);
        let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
        let object_count = node.children.len() as i32;

        let sql = if self.is_sqlite {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
             VALUES (?, ?, ?, 0, ?)
             ON CONFLICT(path) DO UPDATE SET
                 merkle_hash = excluded.merkle_hash,
                 object_count = excluded.object_count,
                 updated_at = strftime('%s', 'now')"
        } else {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
             VALUES (?, ?, ?, 0, ?)
             ON DUPLICATE KEY UPDATE
                 merkle_hash = VALUES(merkle_hash),
                 object_count = VALUES(object_count),
                 updated_at = CURRENT_TIMESTAMP"
        };

        sqlx::query(sql)
            .bind(path)
            .bind(parent)
            .bind(&hash_hex)
            .bind(object_count)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;

        debug!(path = %path, children = node.children.len(), "Recomputed interior node");
        Ok(())
    }

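    /// Compares our children of `path` against a remote peer's view and
    /// returns the full paths of children that differ or exist on only one
    /// side.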
    #[instrument(skip(self, their_children))]
    pub async fn diff_children(
        &self,
        path: &str,
        their_children: &BTreeMap<String, [u8; 32]>,
    ) -> Result<Vec<String>, StorageError> {
        let our_children = self.get_children(path).await?;
        let mut diffs = Vec::new();

        let prefix_with_dot = if path.is_empty() {
            String::new()
        } else {
            format!("{}.", path)
        };

        // Segments we hold that the peer is missing or disagrees on.
        for (segment, our_hash) in &our_children {
            match their_children.get(segment) {
                Some(their_hash) if their_hash != our_hash => {
                    diffs.push(format!("{}{}", prefix_with_dot, segment));
                }
                None => {
                    diffs.push(format!("{}{}", prefix_with_dot, segment));
                }
                _ => {}
            }
        }

        // Segments the peer holds that we are missing.
        for segment in their_children.keys() {
            if !our_children.contains_key(segment) {
                diffs.push(format!("{}{}", prefix_with_dot, segment));
            }
        }

        Ok(diffs)
    }

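    /// Lists the paths of all leaf nodes whose path starts with `prefix`
    /// (all leaves when the prefix is empty).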
    pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
        let pattern = if prefix.is_empty() {
            "%".to_string()
        } else {
            format!("{}%", prefix)
        };

        let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";

        let rows = sqlx::query(sql)
            .bind(&pattern)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let mut leaves = Vec::with_capacity(rows.len());
        for row in rows {
            let path: String = row.try_get("path")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            leaves.push(path);
        }

        Ok(leaves)
    }

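    /// Counts the leaf nodes currently stored in the table.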
    pub async fn count_leaves(&self) -> Result<u64, StorageError> {
        let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";

        let row = sqlx::query(sql)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let count: i64 = row.try_get("cnt")
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(count as u64)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn temp_db_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("merkle_test_{}.db", name))
    }

    #[tokio::test]
    async fn test_sql_merkle_basic() {
        let db_path = temp_db_path("basic");
        let _ = std::fs::remove_file(&db_path);
        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        assert!(store.root_hash().await.unwrap().is_none());

        let mut batch = MerkleBatch::new();
        batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        let root = store.root_hash().await.unwrap();
        assert!(root.is_some());

        let leaf = store.get_hash("uk.nhs.patient.123").await.unwrap();
        assert_eq!(leaf, Some([1u8; 32]));

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_sql_merkle_diff() {
        let db_path = temp_db_path("diff");
        let _ = std::fs::remove_file(&db_path);
        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        let mut batch = MerkleBatch::new();
        batch.insert("uk.a.1".to_string(), [1u8; 32]);
        batch.insert("uk.b.2".to_string(), [2u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // The peer disagrees on "a" but matches our stored hash for "b".
        let mut their_children = BTreeMap::new();
        their_children.insert("a".to_string(), [99u8; 32]);
        their_children.insert("b".to_string(), store.get_hash("uk.b").await.unwrap().unwrap());

        let diffs = store.diff_children("uk", &their_children).await.unwrap();
        assert!(diffs.contains(&"uk.a".to_string()));
        assert!(!diffs.contains(&"uk.b".to_string()));

        let _ = std::fs::remove_file(&db_path);
    }
}