1use super::path_tree::{MerkleBatch, MerkleNode};
10use crate::StorageError;
11use redis::aio::ConnectionManager;
12use redis::AsyncCommands;
13use std::collections::BTreeMap;
14use tracing::{debug, instrument};
15
/// Key namespace for node hashes: `merkle:hash:<path>` holds a hex-encoded 32-byte hash.
const MERKLE_HASH_PREFIX: &'static str = "merkle:hash:";
/// Key namespace for interior-node child listings: a sorted set of `<segment>:<hex hash>` members.
const MERKLE_CHILDREN_PREFIX: &'static str = "merkle:children:";
19
20#[derive(Clone)]
28pub struct RedisMerkleStore {
29 conn: ConnectionManager,
30 prefix: String,
32}
33
34impl RedisMerkleStore {
35 pub fn new(conn: ConnectionManager) -> Self {
37 Self::with_prefix(conn, None)
38 }
39
40 pub fn with_prefix(conn: ConnectionManager, prefix: Option<&str>) -> Self {
42 Self {
43 conn,
44 prefix: prefix.unwrap_or("").to_string(),
45 }
46 }
47
48 #[inline]
50 fn prefixed_key(&self, suffix: &str) -> String {
51 if self.prefix.is_empty() {
52 suffix.to_string()
53 } else {
54 format!("{}{}", self.prefix, suffix)
55 }
56 }
57
58 pub fn key_prefix(&self) -> &str {
60 &self.prefix
61 }
62
63 #[instrument(skip(self))]
65 pub async fn get_hash(&self, path: &str) -> Result<Option<[u8; 32]>, StorageError> {
66 let key = self.prefixed_key(&format!("{}{}", MERKLE_HASH_PREFIX, path));
67 let mut conn = self.conn.clone();
68
69 let result: Option<String> = conn.get(&key).await.map_err(|e| {
70 StorageError::Backend(format!("Failed to get merkle hash: {}", e))
71 })?;
72
73 match result {
74 Some(hex_str) => {
75 let bytes = hex::decode(&hex_str).map_err(|e| {
76 StorageError::Backend(format!("Invalid merkle hash hex: {}", e))
77 })?;
78 if bytes.len() != 32 {
79 return Err(StorageError::Backend(format!(
80 "Invalid merkle hash length: {}",
81 bytes.len()
82 )));
83 }
84 let mut hash = [0u8; 32];
85 hash.copy_from_slice(&bytes);
86 Ok(Some(hash))
87 }
88 None => Ok(None),
89 }
90 }
91
92 #[instrument(skip(self))]
94 pub async fn get_children(
95 &self,
96 path: &str,
97 ) -> Result<BTreeMap<String, [u8; 32]>, StorageError> {
98 let key = self.prefixed_key(&format!("{}{}", MERKLE_CHILDREN_PREFIX, path));
99 let mut conn = self.conn.clone();
100
101 let members: Vec<String> = conn.zrange(&key, 0, -1).await.map_err(|e| {
103 StorageError::Backend(format!("Failed to get merkle children: {}", e))
104 })?;
105
106 let mut children: BTreeMap<String, [u8; 32]> = BTreeMap::new();
107 for member in &members {
108 let member_str: &str = member.as_str();
110 if let Some((segment, hash_hex)) = member_str.split_once(':') {
111 let bytes = hex::decode(hash_hex).map_err(|e| {
112 StorageError::Backend(format!("Invalid child hash hex: {}", e))
113 })?;
114 if bytes.len() == 32 {
115 let mut hash = [0u8; 32];
116 hash.copy_from_slice(&bytes);
117 children.insert(segment.to_string(), hash);
118 }
119 }
120 }
121
122 Ok(children)
123 }
124
125 pub async fn get_node(&self, prefix: &str) -> Result<Option<MerkleNode>, StorageError> {
127 let hash = self.get_hash(prefix).await?;
128
129 match hash {
130 Some(h) => {
131 let children: BTreeMap<String, [u8; 32]> = self.get_children(prefix).await?;
132 Ok(Some(if children.is_empty() {
133 MerkleNode::leaf(h)
134 } else {
135 MerkleNode {
136 hash: h,
137 children,
138 is_leaf: false,
139 }
140 }))
141 }
142 None => Ok(None),
143 }
144 }
145
146 #[instrument(skip(self, batch), fields(batch_size = batch.len()))]
151 pub async fn apply_batch(&self, batch: &MerkleBatch) -> Result<(), StorageError> {
152 if batch.is_empty() {
153 return Ok(());
154 }
155
156 let mut conn = self.conn.clone();
157 let mut pipe = redis::pipe();
158 pipe.atomic();
159
160 for (object_id, maybe_hash) in &batch.leaves {
162 let hash_key = self.prefixed_key(&format!("{}{}", MERKLE_HASH_PREFIX, object_id));
163
164 match maybe_hash {
165 Some(hash) => {
166 let hex_str = hex::encode(hash);
167 pipe.set(&hash_key, &hex_str);
168 debug!(object_id = %object_id, "Setting leaf hash");
169 }
170 None => {
171 pipe.del(&hash_key);
172 debug!(object_id = %object_id, "Deleting leaf hash");
173 }
174 }
175 }
176
177 pipe.query_async::<()>(&mut conn).await.map_err(|e| {
179 StorageError::Backend(format!("Failed to apply merkle leaf updates: {}", e))
180 })?;
181
182 let affected_prefixes = batch.affected_prefixes();
184
185 for prefix in affected_prefixes {
186 self.recompute_interior_node(&prefix).await?;
187 }
188
189 Ok(())
190 }
191
192 #[instrument(skip(self))]
194 async fn recompute_interior_node(&self, prefix: &str) -> Result<(), StorageError> {
195 let mut conn = self.conn.clone();
196
197 let prefix_with_dot = if prefix.is_empty() {
199 String::new()
200 } else {
201 format!("{}.", prefix)
202 };
203
204 let scan_pattern = if prefix.is_empty() {
206 self.prefixed_key(&format!("{}*", MERKLE_HASH_PREFIX))
207 } else {
208 self.prefixed_key(&format!("{}{}.*", MERKLE_HASH_PREFIX, prefix))
209 };
210
211 let full_hash_prefix = self.prefixed_key(MERKLE_HASH_PREFIX);
213
214 let mut keys: Vec<String> = Vec::new();
215 let mut cursor = 0u64;
216
217 loop {
218 let (new_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
219 .arg(cursor)
220 .arg("MATCH")
221 .arg(&scan_pattern)
222 .arg("COUNT")
223 .arg(100) .query_async(&mut conn)
225 .await
226 .map_err(|e| StorageError::Backend(format!("Failed to scan merkle keys: {}", e)))?;
227
228 keys.extend(batch);
229 cursor = new_cursor;
230
231 if cursor == 0 {
232 break;
233 }
234 }
235
236 let mut children: BTreeMap<String, [u8; 32]> = BTreeMap::new();
237
238 for key in &keys {
239 let path: &str = key.strip_prefix(&full_hash_prefix).unwrap_or(key.as_str());
241
242 let suffix: &str = if prefix.is_empty() {
244 path
245 } else {
246 match path.strip_prefix(&prefix_with_dot) {
247 Some(s) => s,
248 None => continue,
249 }
250 };
251
252 if let Some(segment) = suffix.split('.').next() {
254 if segment == suffix || !suffix.contains('.') {
256 let child_path = if prefix.is_empty() {
258 segment.to_string()
259 } else {
260 format!("{}.{}", prefix, segment)
261 };
262
263 if let Some(hash) = self.get_hash(&child_path).await? {
264 children.insert(segment.to_string(), hash);
265 }
266 }
267 }
268 }
269
270 if children.is_empty() {
271 return Ok(());
273 }
274
275 let node = MerkleNode::interior(children.clone());
277 let hash_hex = hex::encode(node.hash);
278
279 let hash_key = self.prefixed_key(&format!("{}{}", MERKLE_HASH_PREFIX, prefix));
281 let children_key = self.prefixed_key(&format!("{}{}", MERKLE_CHILDREN_PREFIX, prefix));
282
283 let mut pipe = redis::pipe();
284 pipe.atomic();
285 pipe.set(&hash_key, &hash_hex);
286
287 pipe.del(&children_key);
289 for (segment, hash) in &children {
290 let member = format!("{}:{}", segment, hex::encode(hash));
291 pipe.zadd(&children_key, &member, 0i64);
292 }
293
294 pipe.query_async::<()>(&mut conn).await.map_err(|e| {
295 StorageError::Backend(format!("Failed to update interior node: {}", e))
296 })?;
297
298 debug!(prefix = %prefix, children_count = children.len(), "Recomputed interior node");
299
300 Ok(())
301 }
302
303 pub async fn root_hash(&self) -> Result<Option<[u8; 32]>, StorageError> {
305 self.recompute_interior_node("").await?;
307
308 let key = self.prefixed_key(MERKLE_HASH_PREFIX);
310 let mut conn = self.conn.clone();
311
312 let result: Option<String> = conn.get(&key).await.map_err(|e| {
313 StorageError::Backend(format!("Failed to get root hash: {}", e))
314 })?;
315
316 match result {
317 Some(hex_str) => {
318 let bytes = hex::decode(&hex_str).map_err(|e| {
319 StorageError::Backend(format!("Invalid root hash hex: {}", e))
320 })?;
321 if bytes.len() != 32 {
322 return Err(StorageError::Backend(format!(
323 "Invalid root hash length: {}",
324 bytes.len()
325 )));
326 }
327 let mut hash = [0u8; 32];
328 hash.copy_from_slice(&bytes);
329 Ok(Some(hash))
330 }
331 None => Ok(None),
332 }
333 }
334
335 #[instrument(skip(self, their_children))]
339 pub async fn diff_children(
340 &self,
341 prefix: &str,
342 their_children: &BTreeMap<String, [u8; 32]>,
343 ) -> Result<Vec<String>, StorageError> {
344 let our_children: BTreeMap<String, [u8; 32]> = self.get_children(prefix).await?;
345 let mut diffs = Vec::new();
346
347 let prefix_with_dot = if prefix.is_empty() {
348 String::new()
349 } else {
350 format!("{}.", prefix)
351 };
352
353 for (segment, our_hash) in &our_children {
355 match their_children.get(segment) {
356 Some(their_hash) if their_hash != our_hash => {
357 diffs.push(format!("{}{}", prefix_with_dot, segment));
358 }
359 None => {
360 diffs.push(format!("{}{}", prefix_with_dot, segment));
362 }
363 _ => {} }
365 }
366
367 for segment in their_children.keys() {
369 if !our_children.contains_key(segment) {
370 diffs.push(format!("{}{}", prefix_with_dot, segment));
371 }
372 }
373
374 Ok(diffs)
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact key layout produced by the namespace constants.
    #[test]
    fn test_key_prefixes() {
        let hash_key = [MERKLE_HASH_PREFIX, "uk.nhs.patient"].concat();
        assert_eq!(hash_key, "merkle:hash:uk.nhs.patient");

        let children_key = [MERKLE_CHILDREN_PREFIX, "uk.nhs"].concat();
        assert_eq!(children_key, "merkle:children:uk.nhs");
    }
}