wesichain_core/
checkpoint.rs1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use uuid::Uuid;
4
5use chrono::Utc;
6use serde::{Deserialize, Serialize};
7
8use crate::state::{GraphState, StateSchema};
9use crate::WesichainError;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12#[serde(bound = "S: StateSchema")]
13pub struct Checkpoint<S: StateSchema> {
14 pub thread_id: String,
15 pub state: GraphState<S>,
16 pub step: u64,
17 pub node: String,
18 pub queue: Vec<(String, u64)>,
19 pub created_at: String,
20}
21
22impl<S: StateSchema> Checkpoint<S> {
23 pub fn new(
24 thread_id: String,
25 state: GraphState<S>,
26 step: u64,
27 node: String,
28 queue: Vec<(String, u64)>,
29 ) -> Self {
30 Self {
31 thread_id,
32 state,
33 step,
34 node,
35 queue,
36 created_at: Utc::now().to_rfc3339(),
37 }
38 }
39}
40
41#[async_trait::async_trait]
42pub trait Checkpointer<S: StateSchema>: Send + Sync {
43 async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), WesichainError>;
44 async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint<S>>, WesichainError>;
45}
46
/// Lightweight summary of one stored checkpoint, as returned by
/// [`HistoryCheckpointer::list_checkpoints`].
///
/// All fields are plain owned data, so the type is `Eq` and `Hash` and can be
/// used in sets or as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CheckpointMetadata {
    /// Sequence number of the checkpoint within its thread; the in-memory
    /// backend fills it from `Checkpoint::step`, and `fork` accepts it as `at_seq`.
    pub seq: u64,
    /// Creation timestamp copied from the checkpoint (RFC 3339 for checkpoints
    /// built with `Checkpoint::new`).
    pub created_at: String,
}
52
53#[async_trait::async_trait]
54pub trait HistoryCheckpointer<S: StateSchema>: Send + Sync {
55 async fn list_checkpoints(
56 &self,
57 thread_id: &str,
58 ) -> Result<Vec<CheckpointMetadata>, WesichainError>;
59
60 async fn fork(
66 &self,
67 _thread_id: &str,
68 _at_seq: u64,
69 ) -> Result<String, WesichainError> {
70 Err(WesichainError::CheckpointFailed(
71 "fork() not implemented for this checkpointer".into(),
72 ))
73 }
74}
75
76#[derive(Default, Clone)]
77pub struct InMemoryCheckpointer<S: StateSchema> {
78 inner: Arc<RwLock<HashMap<String, Vec<Checkpoint<S>>>>>,
79}
80
81#[async_trait::async_trait]
82impl<S: StateSchema> Checkpointer<S> for InMemoryCheckpointer<S> {
83 async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), WesichainError> {
84 let mut guard = self
85 .inner
86 .write()
87 .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
88 guard
89 .entry(checkpoint.thread_id.clone())
90 .or_default()
91 .push(checkpoint.clone());
92 Ok(())
93 }
94
95 async fn load(&self, thread_id: &str) -> Result<Option<Checkpoint<S>>, WesichainError> {
96 let guard = self
97 .inner
98 .read()
99 .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
100 Ok(guard
101 .get(thread_id)
102 .and_then(|history| history.last().cloned()))
103 }
104}
105#[async_trait::async_trait]
106impl<S: StateSchema> HistoryCheckpointer<S> for InMemoryCheckpointer<S> {
107 async fn list_checkpoints(
108 &self,
109 thread_id: &str,
110 ) -> Result<Vec<CheckpointMetadata>, WesichainError> {
111 let guard = self
112 .inner
113 .read()
114 .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
115 let history = guard.get(thread_id).cloned().unwrap_or_default();
116 let metadata = history
117 .into_iter()
118 .map(|cp| CheckpointMetadata {
119 seq: cp.step,
120 created_at: cp.created_at,
121 })
122 .collect();
123 Ok(metadata)
124 }
125
126 async fn fork(&self, thread_id: &str, at_seq: u64) -> Result<String, WesichainError> {
127 let history = {
128 let guard = self
129 .inner
130 .read()
131 .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
132 guard.get(thread_id).cloned().unwrap_or_default()
133 };
134
135 if !history.iter().any(|cp| cp.step == at_seq) {
137 return Err(WesichainError::CheckpointFailed(format!(
138 "no checkpoint at seq {at_seq} in thread '{thread_id}'"
139 )));
140 }
141
142 let prefix: Vec<Checkpoint<S>> = history
144 .into_iter()
145 .filter(|cp| cp.step <= at_seq)
146 .collect();
147
148 let new_thread_id = Uuid::new_v4().to_string();
149
150 let mut guard = self
151 .inner
152 .write()
153 .map_err(|_| WesichainError::CheckpointFailed("lock".into()))?;
154
155 let forked: Vec<Checkpoint<S>> = prefix
157 .into_iter()
158 .map(|mut cp| {
159 cp.thread_id = new_thread_id.clone();
160 cp
161 })
162 .collect();
163
164 guard.insert(new_thread_id.clone(), forked);
165 Ok(new_thread_id)
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::{GraphState, StateSchema};
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
    struct Counter {
        n: u32,
    }

    impl StateSchema for Counter {
        type Update = u32;
        fn apply(current: &Self, update: u32) -> Self {
            Self { n: current.n + update }
        }
    }

    /// Builds a checkpoint for `thread_id` at `step` whose state mirrors the
    /// step (test steps are tiny, so the cast cannot truncate).
    fn make_cp(thread_id: &str, step: u64) -> Checkpoint<Counter> {
        Checkpoint::new(
            thread_id.to_string(),
            GraphState { data: Counter { n: step as u32 } },
            step,
            "node".to_string(),
            vec![],
        )
    }

    #[tokio::test]
    async fn fork_creates_new_thread_up_to_seq() {
        let cp: InMemoryCheckpointer<Counter> = InMemoryCheckpointer::default();
        for step in 0..5u64 {
            cp.save(&make_cp("main", step)).await.unwrap();
        }

        let fork_id = cp.fork("main", 2).await.unwrap();
        assert_ne!(fork_id, "main");

        let meta = cp.list_checkpoints(&fork_id).await.unwrap();
        let seqs: Vec<u64> = meta.iter().map(|m| m.seq).collect();
        assert_eq!(seqs, vec![0, 1, 2]);

        let latest = cp.load(&fork_id).await.unwrap().unwrap();
        assert_eq!(latest.step, 2);
        // Copied checkpoints must be re-homed onto the fork's thread id.
        assert_eq!(latest.thread_id, fork_id);
        assert_eq!(latest.state.data, Counter { n: 2 });
    }

    #[tokio::test]
    async fn fork_missing_seq_errors() {
        let cp: InMemoryCheckpointer<Counter> = InMemoryCheckpointer::default();
        cp.save(&make_cp("main", 0)).await.unwrap();
        assert!(cp.fork("main", 99).await.is_err());
        // An unknown thread has no checkpoint at any seq.
        assert!(cp.fork("absent", 0).await.is_err());
        assert!(cp.list_checkpoints("absent").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn fork_independent_of_origin() {
        let cp: InMemoryCheckpointer<Counter> = InMemoryCheckpointer::default();
        for step in 0..3u64 {
            cp.save(&make_cp("main", step)).await.unwrap();
        }
        let fork_id = cp.fork("main", 1).await.unwrap();

        cp.save(&make_cp("main", 3)).await.unwrap();

        let fork_meta = cp.list_checkpoints(&fork_id).await.unwrap();
        assert_eq!(fork_meta.len(), 2);
        assert_eq!(cp.load(&fork_id).await.unwrap().unwrap().step, 1);

        // The origin keeps its full history plus the later save.
        let main_meta = cp.list_checkpoints("main").await.unwrap();
        assert_eq!(main_meta.len(), 4);
        assert_eq!(cp.load("main").await.unwrap().unwrap().step, 3);
    }
}