Skip to main content

wesichain_checkpoint_redis/
lib.rs

1mod history;
2mod keys;
3mod script;
4
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use crate::keys::{index_key, safe_thread_id, ThreadKeys};
9use crate::script::LUA_SAVE;
10use fred::interfaces::{KeysInterface, LuaInterface, SortedSetsInterface};
11use fred::prelude::*;
12use tokio::sync::RwLock;
13use wesichain_core::checkpoint::{Checkpoint, Checkpointer};
14use wesichain_core::state::StateSchema;
15use wesichain_core::WesichainError;
16
17pub use keys::{index_key as redis_index_key, safe_thread_id as validate_thread_id};
18
19#[derive(Clone)]
20pub struct RedisCheckpointer {
21    client: RedisClient,
22    namespace: String,
23    ttl_seconds: Option<u64>,
24    script_sha: Arc<RwLock<String>>,
25}
26
27impl std::fmt::Debug for RedisCheckpointer {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("RedisCheckpointer")
30            .field("namespace", &self.namespace)
31            .field("ttl_seconds", &self.ttl_seconds)
32            .finish()
33    }
34}
35
36pub(crate) fn checkpoint_error(message: impl Into<String>) -> WesichainError {
37    WesichainError::CheckpointFailed(message.into())
38}
39
40pub(crate) fn map_redis_error(error: RedisError) -> WesichainError {
41    checkpoint_error(error.to_string())
42}
43
44impl RedisCheckpointer {
45    pub async fn new(url: &str, namespace: impl Into<String>) -> Result<Self, WesichainError> {
46        let config = RedisConfig::from_url(url).map_err(map_redis_error)?;
47        let client = RedisClient::new(config, None, None, None);
48        client.init().await.map_err(map_redis_error)?;
49
50        let script_sha = client
51            .script_load::<String, _>(LUA_SAVE)
52            .await
53            .map_err(map_redis_error)?;
54
55        Ok(Self {
56            client,
57            namespace: namespace.into(),
58            ttl_seconds: None,
59            script_sha: Arc::new(RwLock::new(script_sha)),
60        })
61    }
62
63    pub fn with_ttl(mut self, seconds: u64) -> Self {
64        self.ttl_seconds = Some(seconds);
65        self
66    }
67
68    async fn eval_save(&self, keys: Vec<String>, args: Vec<String>) -> Result<u64, WesichainError> {
69        let existing_sha = self.script_sha.read().await.clone();
70
71        match self
72            .client
73            .evalsha::<u64, _, _, _>(existing_sha, keys.clone(), args.clone())
74            .await
75        {
76            Ok(seq) => Ok(seq),
77            Err(error) if error.to_string().to_ascii_uppercase().contains("NOSCRIPT") => {
78                let new_sha = self
79                    .client
80                    .script_load::<String, _>(LUA_SAVE)
81                    .await
82                    .map_err(map_redis_error)?;
83                *self.script_sha.write().await = new_sha.clone();
84
85                self.client
86                    .evalsha::<u64, _, _, _>(new_sha, keys, args)
87                    .await
88                    .map_err(map_redis_error)
89            }
90            Err(error) => Err(map_redis_error(error)),
91        }
92    }
93}
94
95#[async_trait::async_trait]
96impl<S> Checkpointer<S> for RedisCheckpointer
97where
98    S: StateSchema,
99{
100    async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), WesichainError> {
101        let thread_id = safe_thread_id(&checkpoint.thread_id)?;
102        let keys = ThreadKeys::new(&self.namespace, thread_id);
103
104        let payload = serde_json::to_string(checkpoint).map_err(|error| {
105            checkpoint_error(format!("failed to serialize checkpoint: {error}"))
106        })?;
107        let ttl = self.ttl_seconds.unwrap_or(0).to_string();
108
109        let _seq = self
110            .eval_save(
111                vec![keys.seq, keys.latest, keys.hist_prefix],
112                vec![payload, ttl],
113            )
114            .await?;
115
116        let now_ms = SystemTime::now()
117            .duration_since(UNIX_EPOCH)
118            .unwrap_or_default()
119            .as_millis() as f64;
120
121        let _ = self
122            .client
123            .zadd::<(), _, _>(
124                index_key(&self.namespace),
125                None,
126                None,
127                false,
128                false,
129                vec![(now_ms, thread_id.to_string())],
130            )
131            .await;
132
133        Ok(())
134    }
135
136    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint<S>>, WesichainError> {
137        let thread_id = safe_thread_id(thread_id)?;
138        let keys = ThreadKeys::new(&self.namespace, thread_id);
139
140        let payload: Option<String> = self
141            .client
142            .get(&keys.latest)
143            .await
144            .map_err(map_redis_error)?;
145
146        let Some(payload) = payload else {
147            return Ok(None);
148        };
149
150        let checkpoint = serde_json::from_str::<Checkpoint<S>>(&payload).map_err(|error| {
151            checkpoint_error(format!("failed to deserialize checkpoint payload: {error}"))
152        })?;
153
154        Ok(Some(checkpoint))
155    }
156}