Skip to main content

wesichain_core/
checkpoint.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use uuid::Uuid;
4
5use chrono::Utc;
6use serde::{Deserialize, Serialize};
7
8use crate::state::{GraphState, StateSchema};
9use crate::WesichainError;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12#[serde(bound = "S: StateSchema")]
13pub struct Checkpoint<S: StateSchema> {
14    pub thread_id: String,
15    pub state: GraphState<S>,
16    pub step: u64,
17    pub node: String,
18    pub queue: Vec<(String, u64)>,
19    pub created_at: String,
20}
21
22impl<S: StateSchema> Checkpoint<S> {
23    pub fn new(
24        thread_id: String,
25        state: GraphState<S>,
26        step: u64,
27        node: String,
28        queue: Vec<(String, u64)>,
29    ) -> Self {
30        Self {
31            thread_id,
32            state,
33            step,
34            node,
35            queue,
36            created_at: Utc::now().to_rfc3339(),
37        }
38    }
39}
40
41#[async_trait::async_trait]
42pub trait Checkpointer<S: StateSchema>: Send + Sync {
43    async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), WesichainError>;
44    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint<S>>, WesichainError>;
45}
46
/// Lightweight description of a stored checkpoint, without its state.
///
/// Both fields are totally comparable, so the type derives `Eq` in addition to
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointMetadata {
    /// Sequence number of the checkpoint. The in-memory checkpointer in this
    /// file fills it with the checkpoint's `step`.
    pub seq: u64,
    /// Creation timestamp of the checkpoint, as stored on the checkpoint.
    pub created_at: String,
}
52
53#[async_trait::async_trait]
54pub trait HistoryCheckpointer<S: StateSchema>: Send + Sync {
55    async fn list_checkpoints(
56        &self,
57        thread_id: &str,
58    ) -> Result<Vec<CheckpointMetadata>, WesichainError>;
59
60    /// Fork execution from a historical checkpoint.
61    ///
62    /// Copies all state up to and including `at_seq` from `thread_id` into a
63    /// new thread and returns the new thread id.  The caller can then resume
64    /// from the forked thread, creating a separate branch of execution.
65    async fn fork(
66        &self,
67        _thread_id: &str,
68        _at_seq: u64,
69    ) -> Result<String, WesichainError> {
70        Err(WesichainError::CheckpointFailed(
71            "fork() not implemented for this checkpointer".into(),
72        ))
73    }
74}
75
76#[derive(Default, Clone)]
77pub struct InMemoryCheckpointer<S: StateSchema> {
78    inner: Arc<RwLock<HashMap<String, Vec<Checkpoint<S>>>>>,
79}
80
81#[async_trait::async_trait]
82impl<S: StateSchema> Checkpointer<S> for InMemoryCheckpointer<S> {
83    async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), WesichainError> {
84        let mut guard = self
85            .inner
86            .write()
87            .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
88        guard
89            .entry(checkpoint.thread_id.clone())
90            .or_default()
91            .push(checkpoint.clone());
92        Ok(())
93    }
94
95    async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint<S>>, WesichainError> {
96        let guard = self
97            .inner
98            .read()
99            .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
100        Ok(guard
101            .get(thread_id)
102            .and_then(|history| history.last().cloned()))
103    }
104}
105#[async_trait::async_trait]
106impl<S: StateSchema> HistoryCheckpointer<S> for InMemoryCheckpointer<S> {
107    async fn list_checkpoints(
108        &self,
109        thread_id: &str,
110    ) -> Result<Vec<CheckpointMetadata>, WesichainError> {
111        let guard = self
112            .inner
113            .read()
114            .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
115        let history = guard.get(thread_id).cloned().unwrap_or_default();
116        let metadata = history
117            .into_iter()
118            .map(|cp| CheckpointMetadata {
119                seq: cp.step,
120                created_at: cp.created_at,
121            })
122            .collect();
123        Ok(metadata)
124    }
125
126    async fn fork(&self, thread_id: &str, at_seq: u64) -> Result<String, WesichainError> {
127        let history = {
128            let guard = self
129                .inner
130                .read()
131                .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
132            guard.get(thread_id).cloned().unwrap_or_default()
133        };
134
135        // Require that the requested seq exists before forking
136        if !history.iter().any(|cp| cp.step == at_seq) {
137            return Err(WesichainError::CheckpointFailed(format!(
138                "no checkpoint at seq {at_seq} in thread '{thread_id}'"
139            )));
140        }
141
142        // Collect all checkpoints up to and including at_seq
143        let prefix: Vec<Checkpoint<S>> = history
144            .into_iter()
145            .filter(|cp| cp.step <= at_seq)
146            .collect();
147
148        let new_thread_id = Uuid::new_v4().to_string();
149
150        let mut guard = self
151            .inner
152            .write()
153            .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
154
155        // Re-stamp cloned checkpoints with the new thread id
156        let forked: Vec<Checkpoint<S>> = prefix
157            .into_iter()
158            .map(|mut cp| {
159                cp.thread_id = new_thread_id.clone();
160                cp
161            })
162            .collect();
163
164        guard.insert(new_thread_id.clone(), forked);
165        Ok(new_thread_id)
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::{GraphState, StateSchema};
    use serde::{Deserialize, Serialize};

    /// Minimal state schema used to exercise the in-memory checkpointer.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
    struct Counter {
        n: u32,
    }

    impl StateSchema for Counter {
        type Update = u32;

        fn apply(current: &Self, update: u32) -> Self {
            let n = current.n + update;
            Self { n }
        }
    }

    /// Builds a checkpoint for `thread_id` whose counter value mirrors `step`.
    fn make_cp(thread_id: &str, step: u64) -> Checkpoint<Counter> {
        let state = GraphState {
            data: Counter { n: step as u32 },
        };
        Checkpoint::new(
            thread_id.to_string(),
            state,
            step,
            "node".to_string(),
            vec![],
        )
    }

    /// Returns a store whose "main" thread holds one checkpoint per step in
    /// `0..steps`.
    async fn store_with_main(steps: u64) -> InMemoryCheckpointer<Counter> {
        let store: InMemoryCheckpointer<Counter> = InMemoryCheckpointer::default();
        for step in 0..steps {
            store.save(&make_cp("main", step)).await.unwrap();
        }
        store
    }

    #[tokio::test]
    async fn fork_creates_new_thread_up_to_seq() {
        let store = store_with_main(5).await;

        let fork_id = store.fork("main", 2).await.unwrap();
        assert_ne!(fork_id, "main");

        // The fork holds steps 0, 1 and 2.
        let copied = store.list_checkpoints(&fork_id).await.unwrap();
        assert_eq!(copied.len(), 3);

        let newest = store.load(&fork_id).await.unwrap().unwrap();
        assert_eq!(newest.step, 2);
    }

    #[tokio::test]
    async fn fork_missing_seq_errors() {
        let store = store_with_main(1).await;

        let outcome = store.fork("main", 99).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn fork_independent_of_origin() {
        let store = store_with_main(3).await;
        let fork_id = store.fork("main", 1).await.unwrap();

        // Growing the origin thread afterwards must leave the fork untouched.
        store.save(&make_cp("main", 3)).await.unwrap();

        // The fork still holds only steps 0 and 1.
        let fork_meta = store.list_checkpoints(&fork_id).await.unwrap();
        assert_eq!(fork_meta.len(), 2);
    }
}