1use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
20use crate::StorageError;
21use sqlx::{AnyPool, Row, any::AnyPoolOptions};
22use std::collections::BTreeMap;
23use std::time::Duration;
24use tracing::{debug, info, instrument};
25
// Path of the synthetic tree root: the empty string is the parent of every
// top-level path segment.
const ROOT_PATH: &str = "";
28
/// SQL-backed persistent store for a path-keyed Merkle tree.
///
/// Works against SQLite or a MySQL-style backend through sqlx's `Any` driver;
/// `is_sqlite` selects the correct dialect for DDL and upsert statements.
#[derive(Clone)]
pub struct SqlMerkleStore {
    // Shared connection pool; cloning the store clones the pool handle.
    pool: AnyPool,
    // True when the connection string targets SQLite (dialect switch).
    is_sqlite: bool,
}
35
36impl SqlMerkleStore {
37 pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
39 sqlx::any::install_default_drivers();
41
42 let is_sqlite = connection_string.starts_with("sqlite:");
43
44 let pool = AnyPoolOptions::new()
45 .max_connections(5)
46 .acquire_timeout(Duration::from_secs(10))
47 .connect(connection_string)
48 .await
49 .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;
50
51 let store = Self { pool, is_sqlite };
52 store.init_schema().await?;
53
54 info!("SQL merkle store initialized");
55 Ok(store)
56 }
57
58 pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
60 Self { pool, is_sqlite }
61 }
62
63 pub async fn init_schema(&self) -> Result<(), StorageError> {
66 let sql = if self.is_sqlite {
68 r#"
69 CREATE TABLE IF NOT EXISTS merkle_nodes (
70 path TEXT PRIMARY KEY,
71 parent_path TEXT,
72 merkle_hash TEXT NOT NULL,
73 object_count INTEGER NOT NULL DEFAULT 0,
74 is_leaf INTEGER DEFAULT 0,
75 updated_at INTEGER DEFAULT (strftime('%s', 'now'))
76 );
77 CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
78 "#
79 } else {
80 r#"
83 CREATE TABLE IF NOT EXISTS merkle_nodes (
84 path VARCHAR(255) PRIMARY KEY,
85 parent_path VARCHAR(255),
86 merkle_hash VARCHAR(64) NOT NULL,
87 object_count INT NOT NULL DEFAULT 0,
88 is_leaf INT DEFAULT 0,
89 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
90 INDEX idx_merkle_parent (parent_path)
91 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
92 "#
93 };
94
95 if self.is_sqlite {
97 for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
98 sqlx::query(stmt)
99 .execute(&self.pool)
100 .await
101 .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
102 }
103 } else {
104 sqlx::query(sql)
105 .execute(&self.pool)
106 .await
107 .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
108 }
109
110 Ok(())
111 }
112
113 #[instrument(skip(self))]
115 pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
116 self.get_hash(ROOT_PATH).await
117 }
118
119 pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
121 let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
122 .bind(path)
123 .fetch_optional(&self.pool)
124 .await
125 .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;
126
127 match result {
128 Some(row) => {
129 let hash_str: String = row.try_get("merkle_hash")
132 .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;
133
134 let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
137 hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
138 "Invalid merkle hash hex: {}", e
139 )))?
140 } else {
141 hash_str.into_bytes()
143 };
144
145 if bytes.len() != 32 {
146 return Err(StorageError::Backend(format!(
147 "Invalid merkle hash length: {}", bytes.len()
148 )));
149 }
150 let mut hash = [0u8; 32];
151 hash.copy_from_slice(&bytes);
152 Ok(Some(hash))
153 }
154 None => Ok(None),
155 }
156 }
157
158 pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
160 let rows = sqlx::query(
161 "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
162 )
163 .bind(path)
164 .fetch_all(&self.pool)
165 .await
166 .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;
167
168 let mut children = BTreeMap::new();
169 for row in rows {
170 let child_path: String = row.try_get("path")
171 .map_err(|e| StorageError::Backend(e.to_string()))?;
172
173 let hash_str: String = row.try_get("merkle_hash")
175 .map_err(|e| StorageError::Backend(e.to_string()))?;
176
177 let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
179 hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
180 } else {
181 hash_str.into_bytes()
182 };
183
184 if bytes.len() == 32 {
185 let segment = if path.is_empty() {
187 child_path.clone()
188 } else {
189 child_path.strip_prefix(&format!("{}.", path))
190 .unwrap_or(&child_path)
191 .to_string()
192 };
193
194 let mut hash = [0u8; 32];
195 hash.copy_from_slice(&bytes);
196 children.insert(segment, hash);
197 }
198 }
199
200 Ok(children)
201 }
202
203 pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
205 let hash = self.get_hash(path).await?;
206
207 match hash {
208 Some(h) => {
209 let children = self.get_children(path).await?;
210 Ok(Some(if children.is_empty() {
211 MerkleNode::leaf(h)
212 } else {
213 MerkleNode {
214 hash: h,
215 children,
216 is_leaf: false,
217 }
218 }))
219 }
220 None => Ok(None),
221 }
222 }
223
224 #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
228 pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
229 if batch.is_empty() {
230 return Ok(());
231 }
232
233 let mut tx = self.pool.begin().await
235 .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;
236
237 for (object_id, maybe_hash) in &batch.leaves {
239 let parent = PathMerkle::parent_prefix(object_id);
240
241 match maybe_hash {
242 Some(hash) => {
243 let hash_hex = hex::encode(hash);
245 let sql = if self.is_sqlite {
246 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
247 VALUES (?, ?, ?, 1, 1)
248 ON CONFLICT(path) DO UPDATE SET
249 merkle_hash = excluded.merkle_hash,
250 updated_at = strftime('%s', 'now')"
251 } else {
252 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
253 VALUES (?, ?, ?, 1, 1)
254 ON DUPLICATE KEY UPDATE
255 merkle_hash = VALUES(merkle_hash),
256 updated_at = CURRENT_TIMESTAMP"
257 };
258
259 sqlx::query(sql)
260 .bind(object_id)
261 .bind(parent)
262 .bind(&hash_hex)
263 .execute(&mut *tx)
264 .await
265 .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
266 }
267 None => {
268 sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
269 .bind(object_id)
270 .execute(&mut *tx)
271 .await
272 .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
273 }
274 }
275 }
276
277 tx.commit().await
279 .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;
280
281 let affected_prefixes = batch.affected_prefixes();
283 for prefix in affected_prefixes {
284 self.recompute_interior_node(&prefix).await?;
285 }
286
287 self.recompute_interior_node(ROOT_PATH).await?;
289
290 debug!(updates = batch.len(), "SQL merkle batch applied");
291 Ok(())
292 }
293
294 async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
296 let children = self.get_children(path).await?;
298
299 if children.is_empty() {
300 let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
302 .bind(path)
303 .fetch_optional(&self.pool)
304 .await
305 .map_err(|e| StorageError::Backend(e.to_string()))?;
306
307 if let Some(row) = result {
308 let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);
310
311 if is_leaf == 0 {
312 sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
313 .bind(path)
314 .execute(&self.pool)
315 .await
316 .map_err(|e| StorageError::Backend(e.to_string()))?;
317 }
318 }
319 return Ok(());
320 }
321
322 let node = MerkleNode::interior(children);
324 let hash_hex = hex::encode(node.hash); let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
326 let object_count = node.children.len() as i32;
327
328 let sql = if self.is_sqlite {
329 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
330 VALUES (?, ?, ?, 0, ?)
331 ON CONFLICT(path) DO UPDATE SET
332 merkle_hash = excluded.merkle_hash,
333 object_count = excluded.object_count,
334 updated_at = strftime('%s', 'now')"
335 } else {
336 "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count)
337 VALUES (?, ?, ?, 0, ?)
338 ON DUPLICATE KEY UPDATE
339 merkle_hash = VALUES(merkle_hash),
340 object_count = VALUES(object_count),
341 updated_at = CURRENT_TIMESTAMP"
342 };
343
344 sqlx::query(sql)
345 .bind(path)
346 .bind(parent)
347 .bind(&hash_hex)
348 .bind(object_count)
349 .execute(&self.pool)
350 .await
351 .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;
352
353 debug!(path = %path, children = node.children.len(), "Recomputed interior node");
354 Ok(())
355 }
356
357 #[instrument(skip(self, their_children))]
361 pub async fn diff_children(
362 &self,
363 path: &str,
364 their_children: &BTreeMap<String, [u8; 32]>,
365 ) -> Result<Vec<String>, StorageError> {
366 let our_children = self.get_children(path).await?;
367 let mut diffs = Vec::new();
368
369 let prefix_with_dot = if path.is_empty() {
370 String::new()
371 } else {
372 format!("{}.", path)
373 };
374
375 for (segment, our_hash) in &our_children {
377 match their_children.get(segment) {
378 Some(their_hash) if their_hash != our_hash => {
379 diffs.push(format!("{}{}", prefix_with_dot, segment));
380 }
381 None => {
382 diffs.push(format!("{}{}", prefix_with_dot, segment));
384 }
385 _ => {} }
387 }
388
389 for segment in their_children.keys() {
391 if !our_children.contains_key(segment) {
392 diffs.push(format!("{}{}", prefix_with_dot, segment));
393 }
394 }
395
396 Ok(diffs)
397 }
398
399 pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
401 let pattern = if prefix.is_empty() {
402 "%".to_string()
403 } else {
404 format!("{}%", prefix)
405 };
406
407 let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";
409
410 let rows = sqlx::query(sql)
411 .bind(&pattern)
412 .fetch_all(&self.pool)
413 .await
414 .map_err(|e| StorageError::Backend(e.to_string()))?;
415
416 let mut leaves = Vec::with_capacity(rows.len());
417 for row in rows {
418 let path: String = row.try_get("path")
419 .map_err(|e| StorageError::Backend(e.to_string()))?;
420 leaves.push(path);
421 }
422
423 Ok(leaves)
424 }
425
426 pub async fn count_leaves(&self) -> Result<u64, StorageError> {
428 let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";
430
431 let row = sqlx::query(sql)
432 .fetch_one(&self.pool)
433 .await
434 .map_err(|e| StorageError::Backend(e.to_string()))?;
435
436 let count: i64 = row.try_get("cnt")
437 .map_err(|e| StorageError::Backend(e.to_string()))?;
438
439 Ok(count as u64)
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a per-test SQLite file path under the system temp directory.
    fn temp_db_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("merkle_test_{}.db", name))
    }

    #[tokio::test]
    async fn test_sql_merkle_basic() {
        let db_path = temp_db_path("basic");
        let _ = std::fs::remove_file(&db_path);
        let conn_url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&conn_url).await.unwrap();

        // A fresh store has no root.
        assert!(store.root_hash().await.unwrap().is_none());

        // Inserting a single leaf produces a root.
        let mut batch = MerkleBatch::new();
        batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        let root = store.root_hash().await.unwrap();
        assert!(root.is_some());

        // The leaf hash round-trips unchanged.
        let leaf = store.get_hash("uk.nhs.patient.123").await.unwrap();
        assert_eq!(leaf, Some([1u8; 32]));

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_sql_merkle_diff() {
        let db_path = temp_db_path("diff");
        let _ = std::fs::remove_file(&db_path);
        let conn_url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&conn_url).await.unwrap();

        let mut batch = MerkleBatch::new();
        batch.insert("uk.a.1".to_string(), [1u8; 32]);
        batch.insert("uk.b.2".to_string(), [2u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // Peer view: "a" disagrees, "b" matches what we have stored.
        let mut their_children = BTreeMap::new();
        their_children.insert("a".to_string(), [99u8; 32]);
        their_children.insert(
            "b".to_string(),
            store.get_hash("uk.b").await.unwrap().unwrap(),
        );

        let diffs = store.diff_children("uk", &their_children).await.unwrap();
        assert!(diffs.contains(&"uk.a".to_string()));
        assert!(!diffs.contains(&"uk.b".to_string()));

        let _ = std::fs::remove_file(&db_path);
    }
}