sync_engine/merkle/
sql_store.rs

1// Copyright (c) 2025-2026 Adrian Robinson. Licensed under the AGPL-3.0.
2// See LICENSE file in the project root for full license text.
3
4//! SQL storage for Merkle tree nodes (ground truth).
5//!
6//! The SQL merkle store is the authoritative source for merkle hashes.
7//! On startup, we trust SQL merkle root over Redis.
8//!
9//! # Schema
10//!
//! ```sql
//! CREATE TABLE merkle_nodes (
//!     path VARCHAR(255) PRIMARY KEY,     -- "" = root, "uk", "uk.nhs", etc
//!     parent_path VARCHAR(255),          -- Parent node path (NULL for root)
//!     merkle_hash VARCHAR(64) NOT NULL,  -- SHA-256 hash, hex-encoded
//!     object_count INT NOT NULL DEFAULT 0,
//!     is_leaf INT DEFAULT 0,
//!     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
//! );
//! ```
21
22use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
23use crate::StorageError;
24use sqlx::{AnyPool, Row, any::AnyPoolOptions};
25use std::collections::BTreeMap;
26use std::time::Duration;
27use tracing::{debug, info, instrument};
28
/// Path key of the tree root in the `merkle_nodes` table (the empty string;
/// see the schema comment: `"" = root`).
const ROOT_PATH: &str = "";
31
/// SQL-backed Merkle tree storage (ground truth).
///
/// Cheap to clone: wraps a shared connection pool.
#[derive(Clone)]
pub struct SqlMerkleStore {
    // SQLx "Any" pool; backend is either SQLite or MySQL.
    pool: AnyPool,
    // True when the backend is SQLite; selects dialect-specific SQL.
    is_sqlite: bool,
}
38
39impl SqlMerkleStore {
40    /// Create a new SQL merkle store, initializing schema if needed.
41    pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
42        // Ensure drivers are installed
43        sqlx::any::install_default_drivers();
44        
45        let is_sqlite = connection_string.starts_with("sqlite:");
46        
47        let pool = AnyPoolOptions::new()
48            .max_connections(5)
49            .acquire_timeout(Duration::from_secs(10))
50            .connect(connection_string)
51            .await
52            .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;
53
54        let store = Self { pool, is_sqlite };
55        store.init_schema().await?;
56        
57        info!("SQL merkle store initialized");
58        Ok(store)
59    }
60
    /// Create from existing pool (e.g., share with SqlStore).
    ///
    /// `is_sqlite` must describe the pool's actual backend so dialect-specific
    /// SQL is chosen correctly. Does NOT create tables — call
    /// [`Self::init_schema`] yourself after construction.
    pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
        Self { pool, is_sqlite }
    }
65
    /// Initialize the schema (creates tables if not exist).
    /// Called automatically by `new()`, but must be called manually after `from_pool()`.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Backend`] if any DDL statement fails.
    pub async fn init_schema(&self) -> Result<(), StorageError> {
        // Use TEXT/VARCHAR for merkle_hash (stored as 64-char hex) for SQLx Any driver compatibility
        let sql = if self.is_sqlite {
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path TEXT PRIMARY KEY,
                parent_path TEXT,
                merkle_hash TEXT NOT NULL,
                object_count INTEGER NOT NULL DEFAULT 0,
                is_leaf INTEGER DEFAULT 0,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            );
            CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
            "#
        } else {
            // Use VARCHAR for merkle_hash to avoid SQLx Any driver binary type mapping issues
            // Hash is stored as 64-char hex string
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path VARCHAR(255) PRIMARY KEY,
                parent_path VARCHAR(255),
                merkle_hash VARCHAR(64) NOT NULL,
                object_count INT NOT NULL DEFAULT 0,
                is_leaf INT DEFAULT 0,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_merkle_parent (parent_path)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            "#
        };

        // SQLite needs separate statements (its driver executes one statement
        // per query call). Splitting on ';' is safe only while no statement
        // embeds a literal semicolon — true for the DDL above.
        if self.is_sqlite {
            for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
                sqlx::query(stmt)
                    .execute(&self.pool)
                    .await
                    .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
            }
        } else {
            sqlx::query(sql)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
        }

        Ok(())
    }
115
    /// Get the root hash (ground truth).
    ///
    /// Convenience wrapper for [`Self::get_hash`] at [`ROOT_PATH`]; returns
    /// `Ok(None)` when the tree is empty.
    #[instrument(skip(self))]
    pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
        self.get_hash(ROOT_PATH).await
    }
121
122    /// Get hash for a specific path.
123    pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
124        let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
125            .bind(path)
126            .fetch_optional(&self.pool)
127            .await
128            .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;
129
130        match result {
131            Some(row) => {
132                // For MySQL we store as hex VARCHAR, for SQLite as BLOB
133                // Try reading as string first (works for both), then decode
134                let hash_str: String = row.try_get("merkle_hash")
135                    .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;
136                
137                // If it looks like hex (64 chars), decode it
138                // Otherwise it might be raw bytes that got stringified
139                let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
140                    hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
141                        "Invalid merkle hash hex: {}", e
142                    )))?
143                } else {
144                    // SQLite might return raw bytes
145                    hash_str.into_bytes()
146                };
147                
148                if bytes.len() != 32 {
149                    return Err(StorageError::Backend(format!(
150                        "Invalid merkle hash length: {}", bytes.len()
151                    )));
152                }
153                let mut hash = [0u8; 32];
154                hash.copy_from_slice(&bytes);
155                Ok(Some(hash))
156            }
157            None => Ok(None),
158        }
159    }
160
161    /// Get children of an interior node.
162    pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
163        let rows = sqlx::query(
164            "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
165        )
166            .bind(path)
167            .fetch_all(&self.pool)
168            .await
169            .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;
170
171        let mut children = BTreeMap::new();
172        for row in rows {
173            let child_path: String = row.try_get("path")
174                .map_err(|e| StorageError::Backend(e.to_string()))?;
175            
176            // Read hash as string (works for both MySQL VARCHAR and SQLite BLOB-as-text)
177            let hash_str: String = row.try_get("merkle_hash")
178                .map_err(|e| StorageError::Backend(e.to_string()))?;
179            
180            // Decode from hex if it looks like hex
181            let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
182                hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
183            } else {
184                hash_str.into_bytes()
185            };
186            
187            if bytes.len() == 32 {
188                // Extract just the segment name from path
189                let segment = if path.is_empty() {
190                    child_path.clone()
191                } else {
192                    child_path.strip_prefix(&format!("{}.", path))
193                        .unwrap_or(&child_path)
194                        .to_string()
195                };
196                
197                let mut hash = [0u8; 32];
198                hash.copy_from_slice(&bytes);
199                children.insert(segment, hash);
200            }
201        }
202
203        Ok(children)
204    }
205
206    /// Get a full node (hash + children).
207    pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
208        let hash = self.get_hash(path).await?;
209        
210        match hash {
211            Some(h) => {
212                let children = self.get_children(path).await?;
213                Ok(Some(if children.is_empty() {
214                    MerkleNode::leaf(h)
215                } else {
216                    MerkleNode {
217                        hash: h,
218                        children,
219                        is_leaf: false,
220                    }
221                }))
222            }
223            None => Ok(None),
224        }
225    }
226
    /// Apply a batch of merkle updates atomically.
    ///
    /// This stores leaf hashes and recomputes affected interior nodes.
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Backend`] if any SQL statement fails.
    ///
    /// NOTE(review): despite the summary, the operation is not fully atomic —
    /// only the leaf writes (Step 1) share a transaction. Interior nodes are
    /// recomputed after the commit, so a crash between the steps leaves
    /// interior hashes stale until another batch touches the same prefixes.
    /// Confirm callers tolerate this window.
    #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
    pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
        // Nothing to do for an empty batch.
        if batch.is_empty() {
            return Ok(());
        }

        // Start transaction with READ COMMITTED isolation (reduces gap locks on MySQL)
        let mut tx = self.pool.begin().await
            .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;

        // MySQL: Set READ COMMITTED to reduce gap locking and deadlocks
        // NOTE(review): `SET TRANSACTION` configures the *next* transaction
        // and MySQL rejects it while a transaction is in progress
        // (ER_CANT_CHANGE_TX_CHARACTERISTICS). Since `begin()` has already
        // issued BEGIN on this connection, verify against a real MySQL server
        // that this statement actually takes effect rather than erroring.
        if !self.is_sqlite {
            sqlx::query("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")
                .execute(&mut *tx)
                .await
                .map_err(|e| StorageError::Backend(format!("Failed to set isolation level: {}", e)))?;
        }

        // Step 1: Apply leaf updates
        for (object_id, maybe_hash) in &batch.leaves {
            // presumably the dotted-path parent of the object — see
            // PathMerkle::parent_prefix for the exact rule.
            let parent = PathMerkle::parent_prefix(object_id);

            match maybe_hash {
                Some(hash) => {
                    // Store hash as hex string for cross-DB compatibility
                    let hash_hex = hex::encode(hash);
                    let sql = if self.is_sqlite {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
                         VALUES (?, ?, ?, 1, 1)
                         ON CONFLICT(path) DO UPDATE SET 
                            merkle_hash = excluded.merkle_hash,
                            updated_at = strftime('%s', 'now')"
                    } else {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
                         VALUES (?, ?, ?, 1, 1)
                         ON DUPLICATE KEY UPDATE 
                            merkle_hash = VALUES(merkle_hash),
                            updated_at = CURRENT_TIMESTAMP"
                    };

                    sqlx::query(sql)
                        .bind(object_id)
                        .bind(parent)
                        .bind(&hash_hex)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
                }
                None => {
                    // A `None` hash means the object was removed: drop its leaf row.
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(object_id)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
                }
            }
        }

        // Commit leaf changes first
        tx.commit().await
            .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;

        // Step 2: Recompute affected interior nodes (bottom-up)
        let affected_prefixes = batch.affected_prefixes();
        for prefix in affected_prefixes {
            self.recompute_interior_node(&prefix).await?;
        }

        // Step 3: Recompute root
        self.recompute_interior_node(ROOT_PATH).await?;

        debug!(updates = batch.len(), "SQL merkle batch applied");
        Ok(())
    }
304
    /// Recompute an interior node's hash from its children.
    ///
    /// When no children remain the node row is deleted (unless it is a leaf);
    /// otherwise the hash is rebuilt via [`MerkleNode::interior`] and
    /// upserted. Runs outside any transaction.
    async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
        // Get all direct children
        let children = self.get_children(path).await?;

        if children.is_empty() {
            // No children, remove this interior node (unless it's a leaf)
            let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
                .bind(path)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))?;

            if let Some(row) = result {
                // Both SQLite and MySQL now store is_leaf as INT
                // (an unreadable column is treated as "not a leaf").
                let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);

                if is_leaf == 0 {
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(path)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| StorageError::Backend(e.to_string()))?;
                }
            }
            return Ok(());
        }

        // Compute new interior hash
        let node = MerkleNode::interior(children);
        let hash_hex = hex::encode(node.hash); // Store as hex string
        let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
        // NOTE(review): this records the number of *direct children*, not the
        // total objects in the subtree — confirm which the `object_count`
        // column is meant to track.
        let object_count = node.children.len() as i32;

        let sql = if self.is_sqlite {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
             VALUES (?, ?, ?, 0, ?)
             ON CONFLICT(path) DO UPDATE SET 
                merkle_hash = excluded.merkle_hash,
                object_count = excluded.object_count,
                updated_at = strftime('%s', 'now')"
        } else {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
             VALUES (?, ?, ?, 0, ?)
             ON DUPLICATE KEY UPDATE 
                merkle_hash = VALUES(merkle_hash),
                object_count = VALUES(object_count),
                updated_at = CURRENT_TIMESTAMP"
        };

        sqlx::query(sql)
            .bind(path)
            .bind(parent)
            .bind(&hash_hex)
            .bind(object_count)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;

        debug!(path = %path, children = node.children.len(), "Recomputed interior node");
        Ok(())
    }
367
368    /// Compare with another merkle store and find differing branches.
369    ///
370    /// Returns prefixes where hashes differ (for sync).
371    #[instrument(skip(self, their_children))]
372    pub async fn diff_children(
373        &self,
374        path: &str,
375        their_children: &BTreeMap<String, [u8; 32]>,
376    ) -> Result<Vec<String>, StorageError> {
377        let our_children = self.get_children(path).await?;
378        let mut diffs = Vec::new();
379
380        let prefix_with_dot = if path.is_empty() {
381            String::new()
382        } else {
383            format!("{}.", path)
384        };
385
386        // Find segments where hashes differ or we have but they don't
387        for (segment, our_hash) in &our_children {
388            match their_children.get(segment) {
389                Some(their_hash) if their_hash != our_hash => {
390                    diffs.push(format!("{}{}", prefix_with_dot, segment));
391                }
392                None => {
393                    // We have it, they don't
394                    diffs.push(format!("{}{}", prefix_with_dot, segment));
395                }
396                _ => {} // Hashes match
397            }
398        }
399
400        // Find segments they have but we don't
401        for segment in their_children.keys() {
402            if !our_children.contains_key(segment) {
403                diffs.push(format!("{}{}", prefix_with_dot, segment));
404            }
405        }
406
407        Ok(diffs)
408    }
409
410    /// Get all leaf paths under a prefix (for sync).
411    pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
412        let pattern = if prefix.is_empty() {
413            "%".to_string()
414        } else {
415            format!("{}%", prefix)
416        };
417
418        // Both SQLite and MySQL now use INT for is_leaf
419        let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";
420
421        let rows = sqlx::query(sql)
422            .bind(&pattern)
423            .fetch_all(&self.pool)
424            .await
425            .map_err(|e| StorageError::Backend(e.to_string()))?;
426
427        let mut leaves = Vec::with_capacity(rows.len());
428        for row in rows {
429            let path: String = row.try_get("path")
430                .map_err(|e| StorageError::Backend(e.to_string()))?;
431            leaves.push(path);
432        }
433
434        Ok(leaves)
435    }
436
437    /// Count total objects (leaves) in the tree.
438    pub async fn count_leaves(&self) -> Result<u64, StorageError> {
439        // Both SQLite and MySQL now use INT for is_leaf
440        let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";
441
442        let row = sqlx::query(sql)
443            .fetch_one(&self.pool)
444            .await
445            .map_err(|e| StorageError::Backend(e.to_string()))?;
446
447        let count: i64 = row.try_get("cnt")
448            .map_err(|e| StorageError::Backend(e.to_string()))?;
449
450        Ok(count as u64)
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Build a per-test SQLite file path in the OS temp directory.
    ///
    /// NOTE(review): two processes running a test with the same `name`
    /// concurrently would share one database file — confirm tests never run
    /// that way.
    fn temp_db_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("merkle_test_{}.db", name))
    }

    /// Smoke test: an empty store has no root; inserting one leaf creates
    /// both the leaf row and a recomputed root hash.
    #[tokio::test]
    async fn test_sql_merkle_basic() {
        let db_path = temp_db_path("basic");
        let _ = std::fs::remove_file(&db_path); // Clean up any old test

        // `mode=rwc` lets SQLite create the file if it does not exist.
        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        // Initially empty
        assert!(store.root_hash().await.unwrap().is_none());

        // Insert a leaf
        let mut batch = MerkleBatch::new();
        batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // Root should now exist
        let root = store.root_hash().await.unwrap();
        assert!(root.is_some());

        // Leaf should exist
        let leaf = store.get_hash("uk.nhs.patient.123").await.unwrap();
        assert_eq!(leaf, Some([1u8; 32]));

        let _ = std::fs::remove_file(&db_path); // Clean up
    }

    /// `diff_children` must flag segments whose hashes differ and skip
    /// segments whose hashes match.
    #[tokio::test]
    async fn test_sql_merkle_diff() {
        let db_path = temp_db_path("diff");
        let _ = std::fs::remove_file(&db_path); // Clean up any old test

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        // Insert some leaves
        let mut batch = MerkleBatch::new();
        batch.insert("uk.a.1".to_string(), [1u8; 32]);
        batch.insert("uk.b.2".to_string(), [2u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // Compare with "their" children where uk.a differs
        let mut their_children = BTreeMap::new();
        their_children.insert("a".to_string(), [99u8; 32]); // Different!
        their_children.insert("b".to_string(), store.get_hash("uk.b").await.unwrap().unwrap()); // Same

        let diffs = store.diff_children("uk", &their_children).await.unwrap();
        assert!(diffs.contains(&"uk.a".to_string()));
        assert!(!diffs.contains(&"uk.b".to_string()));

        let _ = std::fs::remove_file(&db_path); // Clean up
    }
}