1use super::path_tree::{MerkleBatch, MerkleNode};
7use crate::StorageError;
8use redis::aio::ConnectionManager;
9use redis::AsyncCommands;
10use std::collections::BTreeMap;
11use tracing::{debug, instrument};
12
/// Key namespace for node hashes: `merkle:hash:<path>` holds a hex-encoded
/// 32-byte hash (the store's own key prefix, if any, is prepended).
const MERKLE_HASH_PREFIX: &str = "merkle:hash:";
/// Key namespace for interior-node child sets: `merkle:children:<path>` is a
/// sorted set of `segment:hexhash` members.
const MERKLE_CHILDREN_PREFIX: &str = "merkle:children:";
16
17#[derive(Clone)]
25pub struct RedisMerkleStore {
26 conn: ConnectionManager,
27 prefix: String,
29}
30
31impl RedisMerkleStore {
32 pub fn new(conn: ConnectionManager) -> Self {
34 Self::with_prefix(conn, None)
35 }
36
37 pub fn with_prefix(conn: ConnectionManager, prefix: Option<&str>) -> Self {
39 Self {
40 conn,
41 prefix: prefix.unwrap_or("").to_string(),
42 }
43 }
44
45 #[inline]
47 fn prefixed_key(&self, suffix: &str) -> String {
48 if self.prefix.is_empty() {
49 suffix.to_string()
50 } else {
51 format!("{}{}", self.prefix, suffix)
52 }
53 }
54
55 pub fn key_prefix(&self) -> &str {
57 &self.prefix
58 }
59
60 #[instrument(skip(self))]
62 pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
63 let key = self.prefixed_key(&format!("{}{}", MERKLE_HASH_PREFIX, path));
64 let mut conn = self.conn.clone();
65
66 let result: Option<String> = conn.get(&key).await.map_err(|e| {
67 StorageError::Backend(format!("Failed to get merkle hash: {}", e))
68 })?;
69
70 match result {
71 Some(hex_str) => {
72 let bytes = hex::decode(&hex_str).map_err(|e| {
73 StorageError::Backend(format!("Invalid merkle hash hex: {}", e))
74 })?;
75 if bytes.len() != 32 {
76 return Err(StorageError::Backend(format!(
77 "Invalid merkle hash length: {}",
78 bytes.len()
79 )));
80 }
81 let mut hash = [0u8; 32];
82 hash.copy_from_slice(&bytes);
83 Ok(Some(hash))
84 }
85 None => Ok(None),
86 }
87 }
88
89 #[instrument(skip(self))]
91 pub async fn get_children(
92 &self,
93 path: &str,
94 ) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
95 let key = self.prefixed_key(&format!("{}{}", MERKLE_CHILDREN_PREFIX, path));
96 let mut conn = self.conn.clone();
97
98 let members: Vec<String> = conn.zrange(&key, 0, -1).await.map_err(|e| {
100 StorageError::Backend(format!("Failed to get merkle children: {}", e))
101 })?;
102
103 let mut children: BTreeMap<String, [u8; 32]> = BTreeMap::new();
104 for member in &members {
105 let member_str: &str = member.as_str();
107 if let Some((segment, hash_hex)) = member_str.split_once(':') {
108 let bytes = hex::decode(hash_hex).map_err(|e| {
109 StorageError::Backend(format!("Invalid child hash hex: {}", e))
110 })?;
111 if bytes.len() == 32 {
112 let mut hash = [0u8; 32];
113 hash.copy_from_slice(&bytes);
114 children.insert(segment.to_string(), hash);
115 }
116 }
117 }
118
119 Ok(children)
120 }
121
122 pub async fn get_node(&self, prefix: &str) -> Result<Option<MerkleNode>, StorageError> {
124 let hash = self.get_hash(prefix).await?;
125
126 match hash {
127 Some(h) => {
128 let children: BTreeMap<String, [u8; 32]> = self.get_children(prefix).await?;
129 Ok(Some(if children.is_empty() {
130 MerkleNode::leaf(h)
131 } else {
132 MerkleNode {
133 hash: h,
134 children,
135 is_leaf: false,
136 }
137 }))
138 }
139 None => Ok(None),
140 }
141 }
142
143 #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
148 pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
149 if batch.is_empty() {
150 return Ok(());
151 }
152
153 let mut conn = self.conn.clone();
154 let mut pipe = redis::pipe();
155 pipe.atomic();
156
157 for (object_id, maybe_hash) in &batch.leaves {
159 let hash_key = self.prefixed_key(&format!("{}{}", MERKLE_HASH_PREFIX, object_id));
160
161 match maybe_hash {
162 Some(hash) => {
163 let hex_str = hex::encode(hash);
164 pipe.set(&hash_key, &hex_str);
165 debug!(object_id = %object_id, "Setting leaf hash");
166 }
167 None => {
168 pipe.del(&hash_key);
169 debug!(object_id = %object_id, "Deleting leaf hash");
170 }
171 }
172 }
173
174 pipe.query_async::<()>(&mut conn).await.map_err(|e| {
176 StorageError::Backend(format!("Failed to apply merkle leaf updates: {}", e))
177 })?;
178
179 let affected_prefixes = batch.affected_prefixes();
181
182 for prefix in affected_prefixes {
183 self.recompute_interior_node(&prefix).await?;
184 }
185
186 Ok(())
187 }
188
189 #[instrument(skip(self))]
191 async fn recompute_interior_node(&self, prefix: &str) -> Result<(), StorageError> {
192 let mut conn = self.conn.clone();
193
194 let prefix_with_dot = if prefix.is_empty() {
196 String::new()
197 } else {
198 format!("{}.", prefix)
199 };
200
201 let scan_pattern = if prefix.is_empty() {
203 self.prefixed_key(&format!("{}*", MERKLE_HASH_PREFIX))
204 } else {
205 self.prefixed_key(&format!("{}{}.*", MERKLE_HASH_PREFIX, prefix))
206 };
207
208 let full_hash_prefix = self.prefixed_key(MERKLE_HASH_PREFIX);
210
211 let mut keys: Vec<String> = Vec::new();
212 let mut cursor = 0u64;
213
214 loop {
215 let (new_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
216 .arg(cursor)
217 .arg("MATCH")
218 .arg(&scan_pattern)
219 .arg("COUNT")
220 .arg(100) .query_async(&mut conn)
222 .await
223 .map_err(|e| StorageError::Backend(format!("Failed to scan merkle keys: {}", e)))?;
224
225 keys.extend(batch);
226 cursor = new_cursor;
227
228 if cursor == 0 {
229 break;
230 }
231 }
232
233 let mut children: BTreeMap<String, [u8; 32]> = BTreeMap::new();
234
235 for key in &keys {
236 let path: &str = key.strip_prefix(&full_hash_prefix).unwrap_or(key.as_str());
238
239 let suffix: &str = if prefix.is_empty() {
241 path
242 } else {
243 match path.strip_prefix(&prefix_with_dot) {
244 Some(s) => s,
245 None => continue,
246 }
247 };
248
249 if let Some(segment) = suffix.split('.').next() {
251 if segment == suffix || !suffix.contains('.') {
253 let child_path = if prefix.is_empty() {
255 segment.to_string()
256 } else {
257 format!("{}.{}", prefix, segment)
258 };
259
260 if let Some(hash) = self.get_hash(&child_path).await? {
261 children.insert(segment.to_string(), hash);
262 }
263 }
264 }
265 }
266
267 if children.is_empty() {
268 return Ok(());
270 }
271
272 let node = MerkleNode::interior(children.clone());
274 let hash_hex = hex::encode(node.hash);
275
276 let hash_key = self.prefixed_key(&format!("{}{}", MERKLE_HASH_PREFIX, prefix));
278 let children_key = self.prefixed_key(&format!("{}{}", MERKLE_CHILDREN_PREFIX, prefix));
279
280 let mut pipe = redis::pipe();
281 pipe.atomic();
282 pipe.set(&hash_key, &hash_hex);
283
284 pipe.del(&children_key);
286 for (segment, hash) in &children {
287 let member = format!("{}:{}", segment, hex::encode(hash));
288 pipe.zadd(&children_key, &member, 0i64);
289 }
290
291 pipe.query_async::<()>(&mut conn).await.map_err(|e| {
292 StorageError::Backend(format!("Failed to update interior node: {}", e))
293 })?;
294
295 debug!(prefix = %prefix, children_count = children.len(), "Recomputed interior node");
296
297 Ok(())
298 }
299
300 pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
302 self.recompute_interior_node("").await?;
304
305 let key = self.prefixed_key(MERKLE_HASH_PREFIX);
307 let mut conn = self.conn.clone();
308
309 let result: Option<String> = conn.get(&key).await.map_err(|e| {
310 StorageError::Backend(format!("Failed to get root hash: {}", e))
311 })?;
312
313 match result {
314 Some(hex_str) => {
315 let bytes = hex::decode(&hex_str).map_err(|e| {
316 StorageError::Backend(format!("Invalid root hash hex: {}", e))
317 })?;
318 if bytes.len() != 32 {
319 return Err(StorageError::Backend(format!(
320 "Invalid root hash length: {}",
321 bytes.len()
322 )));
323 }
324 let mut hash = [0u8; 32];
325 hash.copy_from_slice(&bytes);
326 Ok(Some(hash))
327 }
328 None => Ok(None),
329 }
330 }
331
332 #[instrument(skip(self, their_children))]
336 pub async fn diff_children(
337 &self,
338 prefix: &str,
339 their_children: &BTreeMap<String, [u8; 32]>,
340 ) -> Result<Vec<String>, StorageError> {
341 let our_children: BTreeMap<String, [u8; 32]> = self.get_children(prefix).await?;
342 let mut diffs = Vec::new();
343
344 let prefix_with_dot = if prefix.is_empty() {
345 String::new()
346 } else {
347 format!("{}.", prefix)
348 };
349
350 for (segment, our_hash) in &our_children {
352 match their_children.get(segment) {
353 Some(their_hash) if their_hash != our_hash => {
354 diffs.push(format!("{}{}", prefix_with_dot, segment));
355 }
356 None => {
357 diffs.push(format!("{}{}", prefix_with_dot, segment));
359 }
360 _ => {} }
362 }
363
364 for segment in their_children.keys() {
366 if !our_children.contains_key(segment) {
367 diffs.push(format!("{}{}", prefix_with_dot, segment));
368 }
369 }
370
371 Ok(diffs)
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys are formed by plain concatenation of namespace and path.
    #[test]
    fn test_key_prefixes() {
        let hash_key = [MERKLE_HASH_PREFIX, "uk.nhs.patient"].concat();
        assert_eq!(hash_key, "merkle:hash:uk.nhs.patient");

        let children_key = [MERKLE_CHILDREN_PREFIX, "uk.nhs"].concat();
        assert_eq!(children_key, "merkle:children:uk.nhs");
    }
}