1use chrono::Utc;
19use rustc_hash::FxHashMap;
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22
23use crate::{
24 channels::{Channel, ExtrasChannel, MessagesChannel},
25 message::Message,
26 runtimes::checkpointer::Checkpoint,
27 state::VersionedState,
28 types::NodeKind,
29 utils::json_ext::JsonSerializable,
30};
31
32impl<T> JsonSerializable<PersistenceError> for T
34where
35 T: serde::Serialize + for<'de> serde::de::DeserializeOwned,
36{
37 fn to_json_string(&self) -> std::result::Result<String, PersistenceError> {
38 serde_json::to_string(self).map_err(|e| PersistenceError::Serde { source: e })
39 }
40
41 fn from_json_str(s: &str) -> std::result::Result<Self, PersistenceError> {
42 serde_json::from_str(s).map_err(|e| PersistenceError::Serde { source: e })
43 }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
48pub struct PersistedVecChannel<T> {
49 pub version: u32,
50 #[serde(default)]
51 pub items: Vec<T>,
52}
53
54impl<T> Default for PersistedVecChannel<T> {
55 fn default() -> Self {
56 Self {
57 version: 1,
58 items: Vec::new(),
59 }
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65pub struct PersistedMapChannel<V> {
66 pub version: u32,
67 #[serde(default)]
68 pub map: FxHashMap<String, V>,
69}
70
71impl<V> Default for PersistedMapChannel<V> {
72 fn default() -> Self {
73 Self {
74 version: 1,
75 map: FxHashMap::default(),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
82pub struct PersistedState {
83 pub messages: PersistedVecChannel<Message>,
84 pub extra: PersistedMapChannel<Value>,
85 #[serde(default)]
86 pub errors: PersistedVecChannel<crate::channels::errors::ErrorEvent>,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
91pub struct PersistedVersionsSeen(pub FxHashMap<String, FxHashMap<String, u64>>);
92
93#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96pub struct PersistedCheckpoint {
97 pub session_id: String,
98 pub step: u64,
99 pub state: PersistedState,
100 pub frontier: Vec<String>,
102 pub versions_seen: PersistedVersionsSeen,
103 pub concurrency_limit: usize,
104 pub created_at: String,
106 #[serde(default)]
108 pub ran_nodes: Vec<String>,
109 #[serde(default)]
111 pub skipped_nodes: Vec<String>,
112 #[serde(default)]
114 pub updated_channels: Vec<String>,
115}
116
117use thiserror::Error;
118
119#[derive(Debug, Error)]
121#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
122pub enum PersistenceError {
123 #[error("missing field: {0}")]
124 #[cfg_attr(
125 feature = "diagnostics",
126 diagnostic(
127 code(weavegraph::persistence::missing_field),
128 help("Populate the field in the persisted JSON before conversion.")
129 )
130 )]
131 MissingField(&'static str),
132
133 #[error("JSON serialization/deserialization failed: {source}")]
134 #[cfg_attr(
135 feature = "diagnostics",
136 diagnostic(
137 code(weavegraph::persistence::serde),
138 help("Ensure the JSON structure matches Persisted* types; serde error: {source}.")
139 )
140 )]
141 Serde {
142 #[source]
143 source: serde_json::Error,
144 },
145
146 #[error("persistence error: {0}")]
147 #[cfg_attr(
148 feature = "diagnostics",
149 diagnostic(code(weavegraph::persistence::other))
150 )]
151 Other(String),
152}
153
154pub type Result<T> = std::result::Result<T, PersistenceError>;
155
156impl From<&VersionedState> for PersistedState {
159 fn from(s: &VersionedState) -> Self {
160 PersistedState {
161 messages: PersistedVecChannel {
162 version: s.messages.version(),
163 items: s.messages.snapshot(),
164 },
165 extra: PersistedMapChannel {
166 version: s.extra.version(),
167 map: s.extra.snapshot(),
168 },
169 errors: PersistedVecChannel {
170 version: s.errors.version(),
171 items: s.errors.snapshot(),
172 },
173 }
174 }
175}
176
177impl TryFrom<PersistedState> for VersionedState {
178 type Error = PersistenceError;
179
180 fn try_from(p: PersistedState) -> Result<Self> {
181 Ok(VersionedState {
182 messages: MessagesChannel::new(p.messages.items, p.messages.version),
183 extra: ExtrasChannel::new(p.extra.map, p.extra.version),
184 errors: crate::channels::ErrorsChannel::new(p.errors.items, p.errors.version),
185 })
186 }
187}
188
189impl From<&FxHashMap<String, FxHashMap<String, u64>>> for PersistedVersionsSeen {
192 fn from(v: &FxHashMap<String, FxHashMap<String, u64>>) -> Self {
193 PersistedVersionsSeen(v.clone())
194 }
195}
196
197impl From<PersistedVersionsSeen> for FxHashMap<String, FxHashMap<String, u64>> {
198 fn from(p: PersistedVersionsSeen) -> Self {
199 p.0
200 }
201}
202
203impl From<&Checkpoint> for PersistedCheckpoint {
206 fn from(cp: &Checkpoint) -> Self {
207 PersistedCheckpoint {
208 session_id: cp.session_id.clone(),
209 step: cp.step,
210 state: PersistedState::from(&cp.state),
211 frontier: cp.frontier.iter().map(|k| k.encode()).collect(),
212 versions_seen: PersistedVersionsSeen(cp.versions_seen.clone()),
213 concurrency_limit: cp.concurrency_limit,
214 created_at: cp.created_at.to_rfc3339(),
215 ran_nodes: cp.ran_nodes.iter().map(|k| k.encode()).collect(),
216 skipped_nodes: cp.skipped_nodes.iter().map(|k| k.encode()).collect(),
217 updated_channels: cp.updated_channels.clone(),
218 }
219 }
220}
221
222impl TryFrom<PersistedCheckpoint> for Checkpoint {
223 type Error = PersistenceError;
224
225 fn try_from(p: PersistedCheckpoint) -> Result<Self> {
226 let state = VersionedState::try_from(p.state)?;
227 let frontier: Vec<NodeKind> = p.frontier.iter().map(|s| NodeKind::decode(s)).collect();
228 let ran_nodes: Vec<NodeKind> = p.ran_nodes.iter().map(|s| NodeKind::decode(s)).collect();
229 let skipped_nodes: Vec<NodeKind> = p
230 .skipped_nodes
231 .iter()
232 .map(|s| NodeKind::decode(s))
233 .collect();
234 let parsed_dt = chrono::DateTime::parse_from_rfc3339(&p.created_at)
235 .map(|dt| dt.with_timezone(&Utc))
236 .unwrap_or_else(|_| Utc::now());
237 Ok(Checkpoint {
238 session_id: p.session_id,
239 step: p.step,
240 state,
241 frontier,
242 versions_seen: p.versions_seen.0,
243 concurrency_limit: p.concurrency_limit,
244 created_at: parsed_dt,
245 ran_nodes,
246 skipped_nodes,
247 updated_channels: p.updated_channels,
248 })
249 }
250}
251
252