wesichain_checkpoint_redis/
lib.rs1mod history;
2mod keys;
3mod script;
4
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use crate::keys::{index_key, safe_thread_id, ThreadKeys};
9use crate::script::LUA_SAVE;
10use fred::interfaces::{KeysInterface, LuaInterface, SortedSetsInterface};
11use fred::prelude::*;
12use tokio::sync::RwLock;
13use wesichain_core::checkpoint::{Checkpoint, Checkpointer};
14use wesichain_core::state::StateSchema;
15use wesichain_core::WesichainError;
16
17pub use keys::{index_key as redis_index_key, safe_thread_id as validate_thread_id};
18
19#[derive(Clone)]
20pub struct RedisCheckpointer {
21 client: RedisClient,
22 namespace: String,
23 ttl_seconds: Option<u64>,
24 script_sha: Arc<RwLock<String>>,
25}
26
27impl std::fmt::Debug for RedisCheckpointer {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("RedisCheckpointer")
30 .field("namespace", &self.namespace)
31 .field("ttl_seconds", &self.ttl_seconds)
32 .finish()
33 }
34}
35
36pub(crate) fn checkpoint_error(message: impl Into<String>) -> WesichainError {
37 WesichainError::CheckpointFailed(message.into())
38}
39
40pub(crate) fn map_redis_error(error: RedisError) -> WesichainError {
41 checkpoint_error(error.to_string())
42}
43
44impl RedisCheckpointer {
45 pub async fn new(url: &str, namespace: impl Into<String>) -> Result<Self, WesichainError> {
46 let config = RedisConfig::from_url(url).map_err(map_redis_error)?;
47 let client = RedisClient::new(config, None, None, None);
48 client.init().await.map_err(map_redis_error)?;
49
50 let script_sha = client
51 .script_load::<String, _>(LUA_SAVE)
52 .await
53 .map_err(map_redis_error)?;
54
55 Ok(Self {
56 client,
57 namespace: namespace.into(),
58 ttl_seconds: None,
59 script_sha: Arc::new(RwLock::new(script_sha)),
60 })
61 }
62
63 pub fn with_ttl(mut self, seconds: u64) -> Self {
64 self.ttl_seconds = Some(seconds);
65 self
66 }
67
68 async fn eval_save(&self, keys: Vec<String>, args: Vec<String>) -> Result<u64, WesichainError> {
69 let existing_sha = self.script_sha.read().await.clone();
70
71 match self
72 .client
73 .evalsha::<u64, _, _, _>(existing_sha, keys.clone(), args.clone())
74 .await
75 {
76 Ok(seq) => Ok(seq),
77 Err(error) if error.to_string().to_ascii_uppercase().contains("NOSCRIPT") => {
78 let new_sha = self
79 .client
80 .script_load::<String, _>(LUA_SAVE)
81 .await
82 .map_err(map_redis_error)?;
83 *self.script_sha.write().await = new_sha.clone();
84
85 self.client
86 .evalsha::<u64, _, _, _>(new_sha, keys, args)
87 .await
88 .map_err(map_redis_error)
89 }
90 Err(error) => Err(map_redis_error(error)),
91 }
92 }
93}
94
95#[async_trait::async_trait]
96impl<S> Checkpointer<S> for RedisCheckpointer
97where
98 S: StateSchema,
99{
100 async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), WesichainError> {
101 let thread_id = safe_thread_id(&checkpoint.thread_id)?;
102 let keys = ThreadKeys::new(&self.namespace, thread_id);
103
104 let payload = serde_json::to_string(checkpoint).map_err(|error| {
105 checkpoint_error(format!("failed to serialize checkpoint: {error}"))
106 })?;
107 let ttl = self.ttl_seconds.unwrap_or(0).to_string();
108
109 let _seq = self
110 .eval_save(
111 vec![keys.seq, keys.latest, keys.hist_prefix],
112 vec![payload, ttl],
113 )
114 .await?;
115
116 let now_ms = SystemTime::now()
117 .duration_since(UNIX_EPOCH)
118 .unwrap_or_default()
119 .as_millis() as f64;
120
121 let _ = self
122 .client
123 .zadd::<(), _, _>(
124 index_key(&self.namespace),
125 None,
126 None,
127 false,
128 false,
129 vec![(now_ms, thread_id.to_string())],
130 )
131 .await;
132
133 Ok(())
134 }
135
136 async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint<S>>, WesichainError> {
137 let thread_id = safe_thread_id(thread_id)?;
138 let keys = ThreadKeys::new(&self.namespace, thread_id);
139
140 let payload: Option<String> = self
141 .client
142 .get(&keys.latest)
143 .await
144 .map_err(map_redis_error)?;
145
146 let Some(payload) = payload else {
147 return Ok(None);
148 };
149
150 let checkpoint = serde_json::from_str::<Checkpoint<S>>(&payload).map_err(|error| {
151 checkpoint_error(format!("failed to deserialize checkpoint payload: {error}"))
152 })?;
153
154 Ok(Some(checkpoint))
155 }
156}