sync_engine/merkle/
sql_store.rs

1// Copyright (c) 2025-2026 Adrian Robinson. Licensed under the AGPL-3.0.
2// See LICENSE file in the project root for full license text.
3
4//! SQL storage for Merkle tree nodes (ground truth).
5//!
6//! The SQL merkle store is the authoritative source for merkle hashes.
7//! On startup, we trust SQL merkle root over Redis.
8//!
//! # Schema
//!
//! ```sql
//! CREATE TABLE merkle_nodes (
//!     path VARCHAR(255) PRIMARY KEY,    -- "" = root, "uk", "uk.nhs", etc
//!     parent_path VARCHAR(255),          -- Parent node path (NULL for root)
//!     merkle_hash VARCHAR(64) NOT NULL,  -- SHA256 hash, stored as 64-char hex
//!     object_count INT NOT NULL DEFAULT 0,
//!     is_leaf BOOLEAN DEFAULT FALSE,
//!     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
//! );
//! ```
21
22use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
23use crate::StorageError;
24use sqlx::{AnyPool, Row, any::AnyPoolOptions};
25use std::collections::BTreeMap;
26use std::time::Duration;
27use tracing::{debug, info, instrument};
28
/// Path key of the root node row: the empty prefix "".
const ROOT_PATH: &str = "";
31
/// SQL-backed Merkle tree storage (ground truth).
#[derive(Clone)]
pub struct SqlMerkleStore {
    /// SQLx `Any` connection pool (may be shared with other stores via `from_pool`).
    pool: AnyPool,
    /// True when the backend is SQLite; selects the SQL dialect
    /// (upsert syntax, timestamp expressions) used in queries.
    is_sqlite: bool,
}
38
39impl SqlMerkleStore {
40    /// Create a new SQL merkle store, initializing schema if needed.
41    pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
42        // Ensure drivers are installed
43        sqlx::any::install_default_drivers();
44        
45        let is_sqlite = connection_string.starts_with("sqlite:");
46        
47        let pool = AnyPoolOptions::new()
48            .max_connections(5)
49            .acquire_timeout(Duration::from_secs(10))
50            .connect(connection_string)
51            .await
52            .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;
53
54        let store = Self { pool, is_sqlite };
55        store.init_schema().await?;
56        
57        info!("SQL merkle store initialized");
58        Ok(store)
59    }
60
61    /// Create from existing pool (e.g., share with SqlStore).
62    pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
63        Self { pool, is_sqlite }
64    }
65
66    /// Initialize the schema (creates tables if not exist).
67    /// Called automatically by `new()`, but must be called manually after `from_pool()`.
68    pub async fn init_schema(&self) -> Result<(), StorageError> {
69        // Use TEXT/VARCHAR for merkle_hash (stored as 64-char hex) for SQLx Any driver compatibility
70        let sql = if self.is_sqlite {
71            r#"
72            CREATE TABLE IF NOT EXISTS merkle_nodes (
73                path TEXT PRIMARY KEY,
74                parent_path TEXT,
75                merkle_hash TEXT NOT NULL,
76                object_count INTEGER NOT NULL DEFAULT 0,
77                is_leaf INTEGER DEFAULT 0,
78                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
79            );
80            CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
81            "#
82        } else {
83            // Use VARCHAR for merkle_hash to avoid SQLx Any driver binary type mapping issues
84            // Hash is stored as 64-char hex string
85            r#"
86            CREATE TABLE IF NOT EXISTS merkle_nodes (
87                path VARCHAR(255) PRIMARY KEY,
88                parent_path VARCHAR(255),
89                merkle_hash VARCHAR(64) NOT NULL,
90                object_count INT NOT NULL DEFAULT 0,
91                is_leaf INT DEFAULT 0,
92                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
93                INDEX idx_merkle_parent (parent_path)
94            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
95            "#
96        };
97
98        // SQLite needs separate statements
99        if self.is_sqlite {
100            for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
101                sqlx::query(stmt)
102                    .execute(&self.pool)
103                    .await
104                    .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
105            }
106        } else {
107            sqlx::query(sql)
108                .execute(&self.pool)
109                .await
110                .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
111        }
112
113        Ok(())
114    }
115
116    /// Get the root hash (ground truth).
117    #[instrument(skip(self))]
118    pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
119        self.get_hash(ROOT_PATH).await
120    }
121
122    /// Get hash for a specific path.
123    pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
124        let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
125            .bind(path)
126            .fetch_optional(&self.pool)
127            .await
128            .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;
129
130        match result {
131            Some(row) => {
132                // For MySQL we store as hex VARCHAR, for SQLite as BLOB
133                // Try reading as string first (works for both), then decode
134                let hash_str: String = row.try_get("merkle_hash")
135                    .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;
136                
137                // If it looks like hex (64 chars), decode it
138                // Otherwise it might be raw bytes that got stringified
139                let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
140                    hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
141                        "Invalid merkle hash hex: {}", e
142                    )))?
143                } else {
144                    // SQLite might return raw bytes
145                    hash_str.into_bytes()
146                };
147                
148                if bytes.len() != 32 {
149                    return Err(StorageError::Backend(format!(
150                        "Invalid merkle hash length: {}", bytes.len()
151                    )));
152                }
153                let mut hash = [0u8; 32];
154                hash.copy_from_slice(&bytes);
155                Ok(Some(hash))
156            }
157            None => Ok(None),
158        }
159    }
160
161    /// Get children of an interior node.
162    pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
163        let rows = sqlx::query(
164            "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
165        )
166            .bind(path)
167            .fetch_all(&self.pool)
168            .await
169            .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;
170
171        let mut children = BTreeMap::new();
172        for row in rows {
173            let child_path: String = row.try_get("path")
174                .map_err(|e| StorageError::Backend(e.to_string()))?;
175            
176            // Read hash as string (works for both MySQL VARCHAR and SQLite BLOB-as-text)
177            let hash_str: String = row.try_get("merkle_hash")
178                .map_err(|e| StorageError::Backend(e.to_string()))?;
179            
180            // Decode from hex if it looks like hex
181            let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
182                hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
183            } else {
184                hash_str.into_bytes()
185            };
186            
187            if bytes.len() == 32 {
188                // Extract just the segment name from path
189                let segment = if path.is_empty() {
190                    child_path.clone()
191                } else {
192                    child_path.strip_prefix(&format!("{}.", path))
193                        .unwrap_or(&child_path)
194                        .to_string()
195                };
196                
197                let mut hash = [0u8; 32];
198                hash.copy_from_slice(&bytes);
199                children.insert(segment, hash);
200            }
201        }
202
203        Ok(children)
204    }
205
206    /// Get a full node (hash + children).
207    pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
208        let hash = self.get_hash(path).await?;
209        
210        match hash {
211            Some(h) => {
212                let children = self.get_children(path).await?;
213                Ok(Some(if children.is_empty() {
214                    MerkleNode::leaf(h)
215                } else {
216                    MerkleNode {
217                        hash: h,
218                        children,
219                        is_leaf: false,
220                    }
221                }))
222            }
223            None => Ok(None),
224        }
225    }
226
227    /// Apply a batch of merkle updates atomically.
228    ///
229    /// This stores leaf hashes and recomputes affected interior nodes.
230    #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
231    pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
232        if batch.is_empty() {
233            return Ok(());
234        }
235
236        // Start transaction
237        let mut tx = self.pool.begin().await
238            .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;
239
240        // Step 1: Apply leaf updates
241        for (object_id, maybe_hash) in &batch.leaves {
242            let parent = PathMerkle::parent_prefix(object_id);
243            
244            match maybe_hash {
245                Some(hash) => {
246                    // Store hash as hex string for cross-DB compatibility
247                    let hash_hex = hex::encode(hash);
248                    let sql = if self.is_sqlite {
249                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
250                         VALUES (?, ?, ?, 1, 1)
251                         ON CONFLICT(path) DO UPDATE SET 
252                            merkle_hash = excluded.merkle_hash,
253                            updated_at = strftime('%s', 'now')"
254                    } else {
255                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
256                         VALUES (?, ?, ?, 1, 1)
257                         ON DUPLICATE KEY UPDATE 
258                            merkle_hash = VALUES(merkle_hash),
259                            updated_at = CURRENT_TIMESTAMP"
260                    };
261                    
262                    sqlx::query(sql)
263                        .bind(object_id)
264                        .bind(parent)
265                        .bind(&hash_hex)
266                        .execute(&mut *tx)
267                        .await
268                        .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
269                }
270                None => {
271                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
272                        .bind(object_id)
273                        .execute(&mut *tx)
274                        .await
275                        .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
276                }
277            }
278        }
279
280        // Commit leaf changes first
281        tx.commit().await
282            .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;
283
284        // Step 2: Recompute affected interior nodes (bottom-up)
285        let affected_prefixes = batch.affected_prefixes();
286        for prefix in affected_prefixes {
287            self.recompute_interior_node(&prefix).await?;
288        }
289
290        // Step 3: Recompute root
291        self.recompute_interior_node(ROOT_PATH).await?;
292
293        debug!(updates = batch.len(), "SQL merkle batch applied");
294        Ok(())
295    }
296
297    /// Recompute an interior node's hash from its children.
298    async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
299        // Get all direct children
300        let children = self.get_children(path).await?;
301        
302        if children.is_empty() {
303            // No children, remove this interior node (unless it's a leaf)
304            let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
305                .bind(path)
306                .fetch_optional(&self.pool)
307                .await
308                .map_err(|e| StorageError::Backend(e.to_string()))?;
309            
310            if let Some(row) = result {
311                // Both SQLite and MySQL now store is_leaf as INT
312                let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);
313                
314                if is_leaf == 0 {
315                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
316                        .bind(path)
317                        .execute(&self.pool)
318                        .await
319                        .map_err(|e| StorageError::Backend(e.to_string()))?;
320                }
321            }
322            return Ok(());
323        }
324
325        // Compute new interior hash
326        let node = MerkleNode::interior(children);
327        let hash_hex = hex::encode(node.hash); // Store as hex string
328        let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
329        let object_count = node.children.len() as i32;
330
331        let sql = if self.is_sqlite {
332            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
333             VALUES (?, ?, ?, 0, ?)
334             ON CONFLICT(path) DO UPDATE SET 
335                merkle_hash = excluded.merkle_hash,
336                object_count = excluded.object_count,
337                updated_at = strftime('%s', 'now')"
338        } else {
339            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
340             VALUES (?, ?, ?, 0, ?)
341             ON DUPLICATE KEY UPDATE 
342                merkle_hash = VALUES(merkle_hash),
343                object_count = VALUES(object_count),
344                updated_at = CURRENT_TIMESTAMP"
345        };
346
347        sqlx::query(sql)
348            .bind(path)
349            .bind(parent)
350            .bind(&hash_hex)
351            .bind(object_count)
352            .execute(&self.pool)
353            .await
354            .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;
355
356        debug!(path = %path, children = node.children.len(), "Recomputed interior node");
357        Ok(())
358    }
359
360    /// Compare with another merkle store and find differing branches.
361    ///
362    /// Returns prefixes where hashes differ (for sync).
363    #[instrument(skip(self, their_children))]
364    pub async fn diff_children(
365        &self,
366        path: &str,
367        their_children: &BTreeMap<String, [u8; 32]>,
368    ) -> Result<Vec<String>, StorageError> {
369        let our_children = self.get_children(path).await?;
370        let mut diffs = Vec::new();
371
372        let prefix_with_dot = if path.is_empty() {
373            String::new()
374        } else {
375            format!("{}.", path)
376        };
377
378        // Find segments where hashes differ or we have but they don't
379        for (segment, our_hash) in &our_children {
380            match their_children.get(segment) {
381                Some(their_hash) if their_hash != our_hash => {
382                    diffs.push(format!("{}{}", prefix_with_dot, segment));
383                }
384                None => {
385                    // We have it, they don't
386                    diffs.push(format!("{}{}", prefix_with_dot, segment));
387                }
388                _ => {} // Hashes match
389            }
390        }
391
392        // Find segments they have but we don't
393        for segment in their_children.keys() {
394            if !our_children.contains_key(segment) {
395                diffs.push(format!("{}{}", prefix_with_dot, segment));
396            }
397        }
398
399        Ok(diffs)
400    }
401
402    /// Get all leaf paths under a prefix (for sync).
403    pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
404        let pattern = if prefix.is_empty() {
405            "%".to_string()
406        } else {
407            format!("{}%", prefix)
408        };
409
410        // Both SQLite and MySQL now use INT for is_leaf
411        let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";
412
413        let rows = sqlx::query(sql)
414            .bind(&pattern)
415            .fetch_all(&self.pool)
416            .await
417            .map_err(|e| StorageError::Backend(e.to_string()))?;
418
419        let mut leaves = Vec::with_capacity(rows.len());
420        for row in rows {
421            let path: String = row.try_get("path")
422                .map_err(|e| StorageError::Backend(e.to_string()))?;
423            leaves.push(path);
424        }
425
426        Ok(leaves)
427    }
428
429    /// Count total objects (leaves) in the tree.
430    pub async fn count_leaves(&self) -> Result<u64, StorageError> {
431        // Both SQLite and MySQL now use INT for is_leaf
432        let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";
433
434        let row = sqlx::query(sql)
435            .fetch_one(&self.pool)
436            .await
437            .map_err(|e| StorageError::Backend(e.to_string()))?;
438
439        let count: i64 = row.try_get("cnt")
440            .map_err(|e| StorageError::Backend(e.to_string()))?;
441
442        Ok(count as u64)
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Per-test temp-file path so concurrent tests never share a database.
    fn temp_db_path(name: &str) -> PathBuf {
        let mut p = std::env::temp_dir();
        p.push(format!("merkle_test_{}.db", name));
        p
    }

    #[tokio::test]
    async fn test_sql_merkle_basic() {
        let db = temp_db_path("basic");
        // Remove any database left over from a previous run.
        let _ = std::fs::remove_file(&db);

        let conn = format!("sqlite://{}?mode=rwc", db.display());
        let store = SqlMerkleStore::new(&conn).await.unwrap();

        // A freshly created store has no root hash yet.
        assert!(store.root_hash().await.unwrap().is_none());

        // Apply a single-leaf batch.
        let mut updates = MerkleBatch::new();
        updates.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
        store.apply_batch(&updates).await.unwrap();

        // Root must now exist, and the leaf must round-trip exactly.
        assert!(store.root_hash().await.unwrap().is_some());
        assert_eq!(
            store.get_hash("uk.nhs.patient.123").await.unwrap(),
            Some([1u8; 32])
        );

        let _ = std::fs::remove_file(&db);
    }

    #[tokio::test]
    async fn test_sql_merkle_diff() {
        let db = temp_db_path("diff");
        // Remove any database left over from a previous run.
        let _ = std::fs::remove_file(&db);

        let conn = format!("sqlite://{}?mode=rwc", db.display());
        let store = SqlMerkleStore::new(&conn).await.unwrap();

        // Seed two leaves under distinct children of "uk".
        let mut updates = MerkleBatch::new();
        updates.insert("uk.a.1".to_string(), [1u8; 32]);
        updates.insert("uk.b.2".to_string(), [2u8; 32]);
        store.apply_batch(&updates).await.unwrap();

        // Remote view: "a" disagrees with us, "b" matches our stored hash.
        let mut remote = BTreeMap::new();
        remote.insert("a".to_string(), [99u8; 32]);
        remote.insert("b".to_string(), store.get_hash("uk.b").await.unwrap().unwrap());

        let diffs = store.diff_children("uk", &remote).await.unwrap();
        assert!(diffs.contains(&"uk.a".to_string()));
        assert!(!diffs.contains(&"uk.b".to_string()));

        let _ = std::fs::remove_file(&db);
    }
}