sync_engine/merkle/sql_store.rs

//! SQL storage for Merkle tree nodes (ground truth).
//!
//! The SQL merkle store is the authoritative source for merkle hashes.
//! On startup, we trust the SQL merkle root over Redis.
//!
//! # Schema
//!
//! ```sql
//! CREATE TABLE merkle_nodes (
//!     path VARCHAR(255) PRIMARY KEY,     -- "" = root, "uk", "uk.nhs", etc.
//!     parent_path VARCHAR(255),          -- Parent node path (NULL for root)
//!     merkle_hash VARCHAR(64) NOT NULL,  -- SHA-256 hash, hex-encoded
//!     object_count INT NOT NULL DEFAULT 0,
//!     is_leaf INT DEFAULT 0,             -- 1 = leaf (object) node
//!     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
//! );
//! ```
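//!
//! # Usage (illustrative)
//!
//! A minimal sketch of typical use, assuming an SQLite connection string and
//! objects keyed by dot-separated paths (not compiled as a doctest):
//!
//! ```ignore
//! let store = SqlMerkleStore::new("sqlite://merkle.db?mode=rwc").await?;
//!
//! // Record an object's leaf hash; affected interior nodes are recomputed.
//! let mut batch = MerkleBatch::new();
//! batch.insert("uk.nhs.patient.123".to_string(), [0u8; 32]);
//! store.apply_batch(&batch).await?;
//!
//! // The root hash is the ground truth compared against Redis on startup.
//! let root = store.root_hash().await?;
//! ```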

use super::path_tree::{MerkleBatch, MerkleNode, PathMerkle};
use crate::StorageError;
use sqlx::{AnyPool, Row, any::AnyPoolOptions};
use std::collections::BTreeMap;
use std::time::Duration;
use tracing::{debug, info, instrument};

/// SQL key for root hash
const ROOT_PATH: &str = "";

/// SQL-backed Merkle tree storage (ground truth).
#[derive(Clone)]
pub struct SqlMerkleStore {
    pool: AnyPool,
    is_sqlite: bool,
}

impl SqlMerkleStore {
    /// Create a new SQL merkle store, initializing schema if needed.
    pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
        // Ensure drivers are installed
        sqlx::any::install_default_drivers();

        let is_sqlite = connection_string.starts_with("sqlite:");

        let pool = AnyPoolOptions::new()
            .max_connections(5)
            .acquire_timeout(Duration::from_secs(10))
            .connect(connection_string)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to connect to SQL merkle store: {}", e)))?;

        let store = Self { pool, is_sqlite };
        store.init_schema().await?;

        info!("SQL merkle store initialized");
        Ok(store)
    }

    /// Create from an existing pool (e.g., to share a pool with SqlStore).
    pub fn from_pool(pool: AnyPool, is_sqlite: bool) -> Self {
        Self { pool, is_sqlite }
    }

    /// Initialize the schema (creates tables if they do not exist).
    /// Called automatically by `new()`, but must be called manually after `from_pool()`.
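    ///
    /// A minimal sketch of the `from_pool` + `init_schema` pairing (illustrative,
    /// not a doctest; assumes `shared_pool` is an `AnyPool` created elsewhere):
    ///
    /// ```ignore
    /// let store = SqlMerkleStore::from_pool(shared_pool, /* is_sqlite */ true);
    /// store.init_schema().await?;
    /// ```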
    pub async fn init_schema(&self) -> Result<(), StorageError> {
        // Use TEXT/VARCHAR for merkle_hash (stored as 64-char hex) for SQLx Any driver compatibility
        let sql = if self.is_sqlite {
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path TEXT PRIMARY KEY,
                parent_path TEXT,
                merkle_hash TEXT NOT NULL,
                object_count INTEGER NOT NULL DEFAULT 0,
                is_leaf INTEGER DEFAULT 0,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            );
            CREATE INDEX IF NOT EXISTS idx_merkle_parent ON merkle_nodes(parent_path);
            "#
        } else {
            // Use VARCHAR for merkle_hash to avoid SQLx Any driver binary type mapping issues
            // Hash is stored as 64-char hex string
            r#"
            CREATE TABLE IF NOT EXISTS merkle_nodes (
                path VARCHAR(255) PRIMARY KEY,
                parent_path VARCHAR(255),
                merkle_hash VARCHAR(64) NOT NULL,
                object_count INT NOT NULL DEFAULT 0,
                is_leaf INT DEFAULT 0,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_merkle_parent (parent_path)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            "#
        };

        // SQLite needs separate statements
        if self.is_sqlite {
            for stmt in sql.split(';').filter(|s| !s.trim().is_empty()) {
                sqlx::query(stmt)
                    .execute(&self.pool)
                    .await
                    .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
            }
        } else {
            sqlx::query(sql)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(format!("Failed to init merkle schema: {}", e)))?;
        }

        Ok(())
    }

    /// Get the root hash (ground truth).
    #[instrument(skip(self))]
    pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
        self.get_hash(ROOT_PATH).await
    }

    /// Get hash for a specific path.
    pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
        let result = sqlx::query("SELECT merkle_hash FROM merkle_nodes WHERE path = ?")
            .bind(path)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to get merkle hash: {}", e)))?;

        match result {
            Some(row) => {
                // Both backends store the hash as a 64-char hex string.
                // Read it as a string, then decode (with a defensive fallback below).
                let hash_str: String = row.try_get("merkle_hash")
                    .map_err(|e| StorageError::Backend(format!("Failed to read merkle hash: {}", e)))?;

                // If it looks like hex (64 chars), decode it
                // Otherwise it might be raw bytes that got stringified
                let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
                    hex::decode(&hash_str).map_err(|e| StorageError::Backend(format!(
                        "Invalid merkle hash hex: {}", e
                    )))?
                } else {
                    // SQLite might return raw bytes
                    hash_str.into_bytes()
                };

                if bytes.len() != 32 {
                    return Err(StorageError::Backend(format!(
                        "Invalid merkle hash length: {}", bytes.len()
                    )));
                }
                let mut hash = [0u8; 32];
                hash.copy_from_slice(&bytes);
                Ok(Some(hash))
            }
            None => Ok(None),
        }
    }

    /// Get children of an interior node.
    pub async fn get_children(&self, path: &str) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
        let rows = sqlx::query(
            "SELECT path, merkle_hash FROM merkle_nodes WHERE parent_path = ?"
        )
            .bind(path)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to get merkle children: {}", e)))?;

        let mut children = BTreeMap::new();
        for row in rows {
            let child_path: String = row.try_get("path")
                .map_err(|e| StorageError::Backend(e.to_string()))?;

            // Hash is stored as hex text in both backends; read it as a string
            let hash_str: String = row.try_get("merkle_hash")
                .map_err(|e| StorageError::Backend(e.to_string()))?;

            // Decode from hex if it looks like hex
            let bytes = if hash_str.len() == 64 && hash_str.chars().all(|c| c.is_ascii_hexdigit()) {
                hex::decode(&hash_str).unwrap_or_else(|_| hash_str.into_bytes())
            } else {
                hash_str.into_bytes()
            };

            if bytes.len() == 32 {
                // Extract just the segment name from path
                let segment = if path.is_empty() {
                    child_path.clone()
                } else {
                    child_path.strip_prefix(&format!("{}.", path))
                        .unwrap_or(&child_path)
                        .to_string()
                };

                let mut hash = [0u8; 32];
                hash.copy_from_slice(&bytes);
                children.insert(segment, hash);
            }
        }

        Ok(children)
    }

    /// Get a full node (hash + children).
    pub async fn get_node(&self, path: &str) -> Result<Option<MerkleNode>, StorageError> {
        let hash = self.get_hash(path).await?;

        match hash {
            Some(h) => {
                let children = self.get_children(path).await?;
                Ok(Some(if children.is_empty() {
                    MerkleNode::leaf(h)
                } else {
                    MerkleNode {
                        hash: h,
                        children,
                        is_leaf: false,
                    }
                }))
            }
            None => Ok(None),
        }
    }

    /// Apply a batch of merkle updates.
    ///
    /// Leaf changes are written in a single transaction; affected interior nodes
    /// and the root are then recomputed bottom-up.
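    ///
    /// A minimal usage sketch (illustrative, not a doctest; assumes `store` is an
    /// initialized `SqlMerkleStore`):
    ///
    /// ```ignore
    /// let mut batch = MerkleBatch::new();
    /// batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
    /// store.apply_batch(&batch).await?;
    /// // Affected interior nodes and the root now reflect the new leaf.
    /// assert!(store.root_hash().await?.is_some());
    /// ```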
    #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
    pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
        if batch.is_empty() {
            return Ok(());
        }

        // Start transaction
        let mut tx = self.pool.begin().await
            .map_err(|e| StorageError::Backend(format!("Failed to begin transaction: {}", e)))?;

        // Step 1: Apply leaf updates
        for (object_id, maybe_hash) in &batch.leaves {
            let parent = PathMerkle::parent_prefix(object_id);

            match maybe_hash {
                Some(hash) => {
                    // Store hash as hex string for cross-DB compatibility
                    let hash_hex = hex::encode(hash);
                    let sql = if self.is_sqlite {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
                         VALUES (?, ?, ?, 1, 1)
                         ON CONFLICT(path) DO UPDATE SET 
                            merkle_hash = excluded.merkle_hash,
                            updated_at = strftime('%s', 'now')"
                    } else {
                        "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
                         VALUES (?, ?, ?, 1, 1)
                         ON DUPLICATE KEY UPDATE 
                            merkle_hash = VALUES(merkle_hash),
                            updated_at = CURRENT_TIMESTAMP"
                    };

                    sqlx::query(sql)
                        .bind(object_id)
                        .bind(parent)
                        .bind(&hash_hex)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to upsert merkle leaf: {}", e)))?;
                }
                None => {
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(object_id)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| StorageError::Backend(format!("Failed to delete merkle leaf: {}", e)))?;
                }
            }
        }

        // Commit leaf changes first
        tx.commit().await
            .map_err(|e| StorageError::Backend(format!("Failed to commit merkle leaves: {}", e)))?;

        // Step 2: Recompute affected interior nodes (bottom-up)
        let affected_prefixes = batch.affected_prefixes();
        for prefix in affected_prefixes {
            self.recompute_interior_node(&prefix).await?;
        }

        // Step 3: Recompute root
        self.recompute_interior_node(ROOT_PATH).await?;

        debug!(updates = batch.len(), "SQL merkle batch applied");
        Ok(())
    }

    /// Recompute an interior node's hash from its children.
    async fn recompute_interior_node(&self, path: &str) -> Result<(), StorageError> {
        // Get all direct children
        let children = self.get_children(path).await?;

        if children.is_empty() {
            // No children, remove this interior node (unless it's a leaf)
            let result = sqlx::query("SELECT is_leaf FROM merkle_nodes WHERE path = ?")
                .bind(path)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))?;

            if let Some(row) = result {
                // Both SQLite and MySQL store is_leaf as INT
                let is_leaf: i32 = row.try_get("is_leaf").unwrap_or(0);

                if is_leaf == 0 {
                    sqlx::query("DELETE FROM merkle_nodes WHERE path = ?")
                        .bind(path)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| StorageError::Backend(e.to_string()))?;
                }
            }
            return Ok(());
        }

        // Compute new interior hash
        let node = MerkleNode::interior(children);
        let hash_hex = hex::encode(node.hash); // Store as hex string
        let parent = if path.is_empty() { None } else { Some(PathMerkle::parent_prefix(path)) };
        let object_count = node.children.len() as i32;

        let sql = if self.is_sqlite {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
             VALUES (?, ?, ?, 0, ?)
             ON CONFLICT(path) DO UPDATE SET 
                merkle_hash = excluded.merkle_hash,
                object_count = excluded.object_count,
                updated_at = strftime('%s', 'now')"
        } else {
            "INSERT INTO merkle_nodes (path, parent_path, merkle_hash, is_leaf, object_count) 
             VALUES (?, ?, ?, 0, ?)
             ON DUPLICATE KEY UPDATE 
                merkle_hash = VALUES(merkle_hash),
                object_count = VALUES(object_count),
                updated_at = CURRENT_TIMESTAMP"
        };

        sqlx::query(sql)
            .bind(path)
            .bind(parent)
            .bind(&hash_hex)
            .bind(object_count)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to update interior node: {}", e)))?;

        debug!(path = %path, children = node.children.len(), "Recomputed interior node");
        Ok(())
    }

    /// Compare with another merkle store and find differing branches.
    ///
    /// Returns prefixes where hashes differ (for sync).
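    ///
    /// A minimal sketch (illustrative, not a doctest; assumes `store` already
    /// holds leaves under "uk"):
    ///
    /// ```ignore
    /// // "Their" view of the children under "uk", keyed by segment name.
    /// let mut theirs = BTreeMap::new();
    /// theirs.insert("a".to_string(), [9u8; 32]);
    /// // Segments whose hashes differ, or that exist on only one side, come back
    /// // as full prefixes such as "uk.a", ready to be compared one level deeper.
    /// let diverged = store.diff_children("uk", &theirs).await?;
    /// ```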
    #[instrument(skip(self, their_children))]
    pub async fn diff_children(
        &self,
        path: &str,
        their_children: &BTreeMap<String, [u8; 32]>,
    ) -> Result<Vec<String>, StorageError> {
        let our_children = self.get_children(path).await?;
        let mut diffs = Vec::new();

        let prefix_with_dot = if path.is_empty() {
            String::new()
        } else {
            format!("{}.", path)
        };

        // Find segments where hashes differ or we have but they don't
        for (segment, our_hash) in &our_children {
            match their_children.get(segment) {
                Some(their_hash) if their_hash != our_hash => {
                    diffs.push(format!("{}{}", prefix_with_dot, segment));
                }
                None => {
                    // We have it, they don't
                    diffs.push(format!("{}{}", prefix_with_dot, segment));
                }
                _ => {} // Hashes match
            }
        }

        // Find segments they have but we don't
        for segment in their_children.keys() {
            if !our_children.contains_key(segment) {
                diffs.push(format!("{}{}", prefix_with_dot, segment));
            }
        }

        Ok(diffs)
    }

    /// Get all leaf paths under a prefix (for sync).
    pub async fn get_leaves_under(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
        let pattern = if prefix.is_empty() {
            "%".to_string()
        } else {
            format!("{}%", prefix)
        };

        // Both SQLite and MySQL use INT for is_leaf
        let sql = "SELECT path FROM merkle_nodes WHERE path LIKE ? AND is_leaf = 1";

        let rows = sqlx::query(sql)
            .bind(&pattern)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let mut leaves = Vec::with_capacity(rows.len());
        for row in rows {
            let path: String = row.try_get("path")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            leaves.push(path);
        }

        Ok(leaves)
    }

    /// Count total objects (leaves) in the tree.
    pub async fn count_leaves(&self) -> Result<u64, StorageError> {
        // Both SQLite and MySQL use INT for is_leaf
        let sql = "SELECT COUNT(*) as cnt FROM merkle_nodes WHERE is_leaf = 1";

        let row = sqlx::query(sql)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let count: i64 = row.try_get("cnt")
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(count as u64)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn temp_db_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("merkle_test_{}.db", name))
    }

    #[tokio::test]
    async fn test_sql_merkle_basic() {
        let db_path = temp_db_path("basic");
        let _ = std::fs::remove_file(&db_path); // Clean up any old test

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        // Initially empty
        assert!(store.root_hash().await.unwrap().is_none());

        // Insert a leaf
        let mut batch = MerkleBatch::new();
        batch.insert("uk.nhs.patient.123".to_string(), [1u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // Root should now exist
        let root = store.root_hash().await.unwrap();
        assert!(root.is_some());

        // Leaf should exist
        let leaf = store.get_hash("uk.nhs.patient.123").await.unwrap();
        assert_eq!(leaf, Some([1u8; 32]));

        let _ = std::fs::remove_file(&db_path); // Clean up
    }

    #[tokio::test]
    async fn test_sql_merkle_diff() {
        let db_path = temp_db_path("diff");
        let _ = std::fs::remove_file(&db_path); // Clean up any old test

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        // Insert some leaves
        let mut batch = MerkleBatch::new();
        batch.insert("uk.a.1".to_string(), [1u8; 32]);
        batch.insert("uk.b.2".to_string(), [2u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // Compare with "their" children where uk.a differs
        let mut their_children = BTreeMap::new();
        their_children.insert("a".to_string(), [99u8; 32]); // Different!
        their_children.insert("b".to_string(), store.get_hash("uk.b").await.unwrap().unwrap()); // Same

        let diffs = store.diff_children("uk", &their_children).await.unwrap();
        assert!(diffs.contains(&"uk.a".to_string()));
        assert!(!diffs.contains(&"uk.b".to_string()));

        let _ = std::fs::remove_file(&db_path); // Clean up
    }
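
    // Additional sketch: leaf enumeration and counting under a prefix.
    // Illustrative; uses only APIs defined in this module.
    #[tokio::test]
    async fn test_sql_merkle_leaves_under() {
        let db_path = temp_db_path("leaves_under");
        let _ = std::fs::remove_file(&db_path); // Clean up any old test

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlMerkleStore::new(&url).await.unwrap();

        let mut batch = MerkleBatch::new();
        batch.insert("uk.a.1".to_string(), [1u8; 32]);
        batch.insert("uk.a.2".to_string(), [2u8; 32]);
        batch.insert("fr.b.3".to_string(), [3u8; 32]);
        store.apply_batch(&batch).await.unwrap();

        // Leaves under "uk." should exclude the "fr." subtree.
        let uk_leaves = store.get_leaves_under("uk.").await.unwrap();
        assert_eq!(uk_leaves.len(), 2);
        assert!(uk_leaves.contains(&"uk.a.1".to_string()));

        // All three objects are leaves.
        assert_eq!(store.count_leaves().await.unwrap(), 3);

        let _ = std::fs::remove_file(&db_path); // Clean up
    }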
}