sync_engine/merkle/
sql_store.rs

1// Copyright (c) 2025-2026 Adrian Robinson. Licensed under the AGPL-3.0.
2// See LICENSE file in the project root for full license text.
3
4//! SQL storage for Merkle tree nodes (ground truth).
5//!
6//! The SQL merkle store is the authoritative source for merkle hashes.
7//! On startup, we trust SQL merkle root over Redis.
8//!
9//! # Schema
10//!
11//! ```sql
12//! CREATE TABLE merkle_nodes (
13//!     path VARCHAR(255) PRIMARY KEY,    -- "" = root, "uk", "uk.nhs", etc
14//!     parent_path VARCHAR(255),          -- Parent node path (NULL for root)
15//!     merkle_hash BINARY(32) NOT NULL,   -- SHA256 hash
16//!     object_count INT NOT NULL DEFAULT 0,
17//!     is_leaf BOOLEAN DEFAULT FALSE,
18//!     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
19//! );
20//! ```
21
22use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
23use crate::StorageError;
24use sqlx::{AnyPool, Row, any::AnyPoolOptions};
25use std::collections::BTreeMap;
26use std::time::Duration;
27use tracing::{debug, info, instrument};
28
/// Path key of the tree root in `merkle_nodes` (the empty string).
const ROOT_PATH: &str = "";

/// SQL-backed Merkle tree storage (ground truth).
///
/// Holds a connection pool plus a dialect flag; `Clone` lets the store be
/// shared across tasks.
#[derive(Clone)]
pub struct SqlMerkleStore {
    // Shared connection pool (SQLite or MySQL via the sqlx Any driver).
    pool: AnyPool,
    // True when the connection string is a `sqlite:` URL; selects the
    // SQLite vs MySQL dialect for DDL and upsert statements.
    is_sqlite: bool,
}
38
39impl SqlMerkleStore {
40    /// Create a new SQL merkle store, initializing schema if needed.
41    pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
42        // Ensure drivers are installed
43        sqlx::any::install_default_drivers();
44        
45        let is_sqlite = connection_string.starts_with("sqlite:");
46        
47        let pool = AnyPoolOptions::new()
48            .max_connections(5)
49            .acquire_timeout(Duration::from_secs(10))
50            .connect(connection_string)
51            .await
52            .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;
53
54        let store = Self { pool, is_sqlite };
55        store.init_schema().await?;
56        
57        info!("SQL merkle store initialized");
58        Ok(store)
59    }
60
    /// Create from existing pool (e.g., share with SqlStore).
    ///
    /// Unlike `new()`, this does NOT create the schema — callers must invoke
    /// `init_schema()` themselves before using the store.
    pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
        Self { pool, is_sqlite }
    }
65
    /// Initialize the schema (creates tables if not exist).
    /// Called automatically by `new()`, but must be called manually after `from_pool()`.
    ///
    /// NOTE(review): the non-SQLite branch emits MySQL-specific DDL (inline
    /// INDEX, `ON UPDATE CURRENT_TIMESTAMP`, `ENGINE=InnoDB`) — other SQL
    /// backends would reject this statement; confirm MySQL is the only
    /// non-SQLite target.
    pub async fn init_schema(&self) -> Result<(), StorageError> {
        // Use TEXT/VARCHAR for merkle_hash (stored as 64-char hex) for SQLx Any driver compatibility
        let sql = if self.is_sqlite {
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path TEXT PRIMARY KEY,
                parent_path TEXT,
                merkle_hash TEXT NOT NULL,
                object_count INTEGER NOT NULL DEFAULT 0,
                is_leaf INTEGER DEFAULT 0,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            );
            CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
            "#
        } else {
            // Use VARCHAR for merkle_hash to avoid SQLx Any driver binary type mapping issues
            // Hash is stored as 64-char hex string
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path VARCHAR(255) PRIMARY KEY,
                parent_path VARCHAR(255),
                merkle_hash VARCHAR(64) NOT NULL,
                object_count INT NOT NULL DEFAULT 0,
                is_leaf INT DEFAULT 0,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_merkle_parent (parent_path)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            "#
        };

        // SQLite needs separate statements.
        // Splitting on ';' is safe for the DDL above because none of the
        // statements embed a semicolon inside a string literal.
        if self.is_sqlite {
            for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
                sqlx::query(stmt)
                    .execute(&self.pool)
                    .await
                    .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
            }
        } else {
            sqlx::query(sql)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
        }

        Ok(())
    }
115
    /// Get the root hash (ground truth).
    ///
    /// Returns `None` when the tree has no root row yet (empty tree).
    #[instrument(skip(self))]
    pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
        // The root is stored under the empty path ("").
        self.get_hash(ROOT_PATH).await
    }
121
122    /// Get hash for a specific path.
123    pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
124        let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
125            .bind(path)
126            .fetch_optional(&self.pool)
127            .await
128            .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;
129
130        match result {
131            Some(row) => {
132                // For MySQL we store as hex VARCHAR, for SQLite as BLOB
133                // Try reading as string first (works for both), then decode
134                let hash_str: String = row.try_get("merkle_hash")
135                    .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;
136                
137                // If it looks like hex (64 chars), decode it
138                // Otherwise it might be raw bytes that got stringified
139                let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
140                    hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
141                        "Invalid merkle hash hex: {}", e
142                    )))?
143                } else {
144                    // SQLite might return raw bytes
145                    hash_str.into_bytes()
146                };
147                
148                if bytes.len() != 32 {
149                    return Err(StorageError::Backend(format!(
150                        "Invalid merkle hash length: {}", bytes.len()
151                    )));
152                }
153                let mut hash = [0u8; 32];
154                hash.copy_from_slice(&bytes);
155                Ok(Some(hash))
156            }
157            None => Ok(None),
158        }
159    }
160
161    /// Get children of an interior node.
162    pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
163        let rows = sqlx::query(
164            "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
165        )
166            .bind(path)
167            .fetch_all(&self.pool)
168            .await
169            .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;
170
171        let mut children = BTreeMap::new();
172        for row in rows {
173            let child_path: String = row.try_get("path")
174                .map_err(|e| StorageError::Backend(e.to_string()))?;
175            
176            // Read hash as string (works for both MySQL VARCHAR and SQLite BLOB-as-text)
177            let hash_str: String = row.try_get("merkle_hash")
178                .map_err(|e| StorageError::Backend(e.to_string()))?;
179            
180            // Decode from hex if it looks like hex
181            let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
182                hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
183            } else {
184                hash_str.into_bytes()
185            };
186            
187            if bytes.len() == 32 {
188                // Extract just the segment name from path
189                let segment = if path.is_empty() {
190                    child_path.clone()
191                } else {
192                    child_path.strip_prefix(&format!("{}.", path))
193                        .unwrap_or(&child_path)
194                        .to_string()
195                };
196                
197                let mut hash = [0u8; 32];
198                hash.copy_from_slice(&bytes);
199                children.insert(segment, hash);
200            }
201        }
202
203        Ok(children)
204    }
205
206    /// Get a full node (hash + children).
207    pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
208        let hash = self.get_hash(path).await?;
209        
210        match hash {
211            Some(h) => {
212                let children = self.get_children(path).await?;
213                Ok(Some(if children.is_empty() {
214                    MerkleNode::leaf(h)
215                } else {
216                    MerkleNode {
217                        hash: h,
218                        children,
219                        is_leaf: false,
220                    }
221                }))
222            }
223            None => Ok(None),
224        }
225    }
226
    /// Apply a batch of merkle updates.
    ///
    /// Leaf upserts/deletes are applied inside a single transaction; the
    /// affected interior nodes and the root are then recomputed as separate,
    /// non-transactional statements. The batch as a whole is therefore NOT
    /// atomic: a crash after the leaf commit leaves interior hashes stale
    /// until a later batch touches the same prefixes.
    /// Note: Session-level READ COMMITTED is set in SqlStore::new() for MySQL.
    #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
    pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
        // Nothing to do for an empty batch.
        if batch.is_empty() {
            return Ok(());
        }

        // Start transaction (READ COMMITTED already set at session level for MySQL)
        let mut tx = self.pool.begin().await
            .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;

        // Step 1: Apply leaf updates
        for (object_id, maybe_hash) in &batch.leaves {
            let parent = PathMerkle::parent_prefix(object_id);
            
            match maybe_hash {
                Some(hash) => {
                    // Store hash as hex string for cross-DB compatibility
                    let hash_hex = hex::encode(hash);
                    let sql = if self.is_sqlite {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
                         VALUES (?, ?, ?, 1, 1)
                         ON CONFLICT(path) DO UPDATE SET 
                            merkle_hash = excluded.merkle_hash,
                            updated_at = strftime('%s', 'now')"
                    } else {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
                         VALUES (?, ?, ?, 1, 1)
                         ON DUPLICATE KEY UPDATE 
                            merkle_hash = VALUES(merkle_hash),
                            updated_at = CURRENT_TIMESTAMP"
                    };
                    
                    sqlx::query(sql)
                        .bind(object_id)
                        .bind(parent)
                        .bind(&hash_hex)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
                }
                None => {
                    // A `None` hash marks the leaf for deletion.
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(object_id)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
                }
            }
        }

        // Commit leaf changes first
        tx.commit().await
            .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;

        // Step 2: Recompute affected interior nodes (bottom-up)
        // NOTE(review): assumes affected_prefixes() yields deeper prefixes
        // before shallower ones, so parents see fresh child hashes —
        // confirm in path_tree.
        let affected_prefixes = batch.affected_prefixes();
        for prefix in affected_prefixes {
            self.recompute_interior_node(&prefix).await?;
        }

        // Step 3: Recompute root
        self.recompute_interior_node(ROOT_PATH).await?;

        debug!(updates = batch.len(), "SQL merkle batch applied");
        Ok(())
    }
297
    /// Recompute an interior node's hash from its children.
    ///
    /// If the node has no remaining children it is deleted, unless the row
    /// is itself a leaf. Otherwise the interior hash is derived from the
    /// child hashes and upserted.
    ///
    /// NOTE(review): these reads/writes run on the pool outside any
    /// transaction, so concurrent batches can interleave here.
    async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
        // Get all direct children
        let children = self.get_children(path).await?;
        
        if children.is_empty() {
            // No children, remove this interior node (unless it's a leaf)
            let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
                .bind(path)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            
            if let Some(row) = result {
                // Both SQLite and MySQL now store is_leaf as INT
                let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);
                
                if is_leaf == 0 {
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(path)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| StorageError::Backend(e.to_string()))?;
                }
            }
            return Ok(());
        }

        // Compute new interior hash
        let node = MerkleNode::interior(children);
        let hash_hex = hex::encode(node.hash); // Store as hex string
        // Root ("") has no parent; everything else points at its parent prefix.
        let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
        // NOTE(review): object_count here is the number of DIRECT children,
        // not the total objects in the subtree — confirm intended semantics
        // against the schema comment ("object_count").
        let object_count = node.children.len() as i32;

        let sql = if self.is_sqlite {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
             VALUES (?, ?, ?, 0, ?)
             ON CONFLICT(path) DO UPDATE SET 
                merkle_hash = excluded.merkle_hash,
                object_count = excluded.object_count,
                updated_at = strftime('%s', 'now')"
        } else {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
             VALUES (?, ?, ?, 0, ?)
             ON DUPLICATE KEY UPDATE 
                merkle_hash = VALUES(merkle_hash),
                object_count = VALUES(object_count),
                updated_at = CURRENT_TIMESTAMP"
        };

        sqlx::query(sql)
            .bind(path)
            .bind(parent)
            .bind(&hash_hex)
            .bind(object_count)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;

        debug!(path = %path, children = node.children.len(), "Recomputed interior node");
        Ok(())
    }
360
361    /// Compare with another merkle store and find differing branches.
362    ///
363    /// Returns prefixes where hashes differ (for sync).
364    #[instrument(skip(self, their_children))]
365    pub async fn diff_children(
366        &self,
367        path: &str,
368        their_children: &BTreeMap<String, [u8; 32]>,
369    ) -> Result<Vec<String>, StorageError> {
370        let our_children = self.get_children(path).await?;
371        let mut diffs = Vec::new();
372
373        let prefix_with_dot = if path.is_empty() {
374            String::new()
375        } else {
376            format!("{}.", path)
377        };
378
379        // Find segments where hashes differ or we have but they don't
380        for (segment, our_hash) in &our_children {
381            match their_children.get(segment) {
382                Some(their_hash) if their_hash != our_hash => {
383                    diffs.push(format!("{}{}", prefix_with_dot, segment));
384                }
385                None => {
386                    // We have it, they don't
387                    diffs.push(format!("{}{}", prefix_with_dot, segment));
388                }
389                _ => {} // Hashes match
390            }
391        }
392
393        // Find segments they have but we don't
394        for segment in their_children.keys() {
395            if !our_children.contains_key(segment) {
396                diffs.push(format!("{}{}", prefix_with_dot, segment));
397            }
398        }
399
400        Ok(diffs)
401    }
402
403    /// Get all leaf paths under a prefix (for sync).
404    pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
405        let pattern = if prefix.is_empty() {
406            "%".to_string()
407        } else {
408            format!("{}%", prefix)
409        };
410
411        // Both SQLite and MySQL now use INT for is_leaf
412        let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";
413
414        let rows = sqlx::query(sql)
415            .bind(&pattern)
416            .fetch_all(&self.pool)
417            .await
418            .map_err(|e| StorageError::Backend(e.to_string()))?;
419
420        let mut leaves = Vec::with_capacity(rows.len());
421        for row in rows {
422            let path: String = row.try_get("path")
423                .map_err(|e| StorageError::Backend(e.to_string()))?;
424            leaves.push(path);
425        }
426
427        Ok(leaves)
428    }
429
430    /// Count total objects (leaves) in the tree.
431    pub async fn count_leaves(&self) -> Result<u64, StorageError> {
432        // Both SQLite and MySQL now use INT for is_leaf
433        let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";
434
435        let row = sqlx::query(sql)
436            .fetch_one(&self.pool)
437            .await
438            .map_err(|e| StorageError::Backend(e.to_string()))?;
439
440        let count: i64 = row.try_get("cnt")
441            .map_err(|e| StorageError::Backend(e.to_string()))?;
442
443        Ok(count as u64)
444    }
445
446    /// Get all nodes from the tree (for cache sync).
447    ///
448    /// Returns (path, hash, children) for each node.
449    pub async fn get_all_nodes(&self) -> Result<Vec<(String, [u8; 32], BTreeMap<String, [u8; 32]>)>, StorageError> {
450        let rows = sqlx::query("SELECT path, merkle_hash FROM merkle_nodes ORDER BY path")
451            .fetch_all(&self.pool)
452            .await
453            .map_err(|e| StorageError::Backend(format!("Failed to get all merkle nodes: {}", e)))?;
454
455        let mut nodes = Vec::with_capacity(rows.len());
456        
457        for row in rows {
458            let path: String = row.try_get("path")
459                .map_err(|e| StorageError::Backend(e.to_string()))?;
460            
461            let hash_str: String = row.try_get("merkle_hash")
462                .map_err(|e| StorageError::Backend(e.to_string()))?;
463            
464            let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
465                hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
466            } else {
467                hash_str.into_bytes()
468            };
469            
470            if bytes.len() == 32 {
471                let mut hash = [0u8; 32];
472                hash.copy_from_slice(&bytes);
473                
474                // Get children for this node
475                let children = self.get_children(&path).await?;
476                
477                nodes.push((path, hash, children));
478            }
479        }
480
481        Ok(nodes)
482    }
483}
484
485#[cfg(test)]
486mod tests {
487    use super::*;
488    use std::path::PathBuf;
489    
490    fn temp_db_path(name: &str) -> PathBuf {
491        std::env::temp_dir().join(format!("merkle_test_{}.db", name))
492    }
493
494    #[tokio::test]
495    async fn test_sql_merkle_basic() {
496        let db_path = temp_db_path("basic");
497        let _ = std::fs::remove_file(&db_path); // Clean up any old test
498        
499        let url = format!("sqlite://{}?mode=rwc", db_path.display());
500        let store = SqlMerkleStore::new(&url).await.unwrap();
501        
502        // Initially empty
503        assert!(store.root_hash().await.unwrap().is_none());
504        
505        // Insert a leaf
506        let mut batch = MerkleBatch::new();
507        batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
508        store.apply_batch(&batch).await.unwrap();
509        
510        // Root should now exist
511        let root = store.root_hash().await.unwrap();
512        assert!(root.is_some());
513        
514        // Leaf should exist
515        let leaf = store.get_hash("uk.nhs.patient.123").await.unwrap();
516        assert_eq!(leaf, Some([1u8; 32]));
517        
518        let _ = std::fs::remove_file(&db_path); // Clean up
519    }
520
521    #[tokio::test]
522    async fn test_sql_merkle_diff() {
523        let db_path = temp_db_path("diff");
524        let _ = std::fs::remove_file(&db_path); // Clean up any old test
525        
526        let url = format!("sqlite://{}?mode=rwc", db_path.display());
527        let store = SqlMerkleStore::new(&url).await.unwrap();
528        
529        // Insert some leaves
530        let mut batch = MerkleBatch::new();
531        batch.insert("uk.a.1".to_string(), [1u8; 32]);
532        batch.insert("uk.b.2".to_string(), [2u8; 32]);
533        store.apply_batch(&batch).await.unwrap();
534        
535        // Compare with "their" children where uk.a differs
536        let mut their_children = BTreeMap::new();
537        their_children.insert("a".to_string(), [99u8; 32]); // Different!
538        their_children.insert("b".to_string(), store.get_hash("uk.b").await.unwrap().unwrap()); // Same
539        
540        let diffs = store.diff_children("uk", &their_children).await.unwrap();
541        assert!(diffs.contains(&"uk.a".to_string()));
542        assert!(!diffs.contains(&"uk.b".to_string()));
543        
544        let _ = std::fs::remove_file(&db_path); // Clean up
545    }
546}