// weavegraph/runtimes/persistence.rs

/*!
Persistence primitives for serializing/deserializing Weavegraph runtime
state and checkpoints (used by the SQLite checkpointer and any future
persistent backends).

Design Goals:
- Provide explicit serde-friendly structs decoupled from internal
  in-memory representations.
- Keep conversion logic localized (From / TryFrom impls) so the
  checkpointer code is lean and declarative.
- Allow forward compatibility (unknown NodeKind encodings round-trip
  as `NodeKind::Custom(encoded_string)`).

This module intentionally does NOT perform I/O. It is pure data
transformation and (de)serialization glue.
*/
17
18use chrono::Utc;
19use rustc_hash::FxHashMap;
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22
23use crate::{
24    channels::{Channel, ExtrasChannel, MessagesChannel},
25    message::Message,
26    runtimes::checkpointer::Checkpoint,
27    state::VersionedState,
28    types::NodeKind,
29    utils::json_ext::JsonSerializable,
30};
31
32/// Blanket implementation of JsonSerializable for all suitable types using PersistenceError.
33impl<T> JsonSerializable<PersistenceError> for T
34where
35    T: serde::Serialize + for<'de> serde::de::DeserializeOwned,
36{
37    fn to_json_string(&self) -> std::result::Result<String, PersistenceError> {
38        serde_json::to_string(self).map_err(|e| PersistenceError::Serde { source: e })
39    }
40
41    fn from_json_str(s: &str) -> std::result::Result<Self, PersistenceError> {
42        serde_json::from_str(s).map_err(|e| PersistenceError::Serde { source: e })
43    }
44}
45
46/// Channel that stores a vector collection (e.g., messages) with version metadata.
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
48pub struct PersistedVecChannel<T> {
49    pub version: u32,
50    #[serde(default)]
51    pub items: Vec<T>,
52}
53
54impl<T> Default for PersistedVecChannel<T> {
55    fn default() -> Self {
56        Self {
57            version: 1,
58            items: Vec::new(),
59        }
60    }
61}
62
63/// Channel that stores a map collection (e.g., extra) with version metadata.
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65pub struct PersistedMapChannel<V> {
66    pub version: u32,
67    #[serde(default)]
68    pub map: FxHashMap<String, V>,
69}
70
71impl<V> Default for PersistedMapChannel<V> {
72    fn default() -> Self {
73        Self {
74            version: 1,
75            map: FxHashMap::default(),
76        }
77    }
78}
79
80/// Complete persisted shape of the in‑memory VersionedState.
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
82pub struct PersistedState {
83    pub messages: PersistedVecChannel<Message>,
84    pub extra: PersistedMapChannel<Value>,
85    #[serde(default)]
86    pub errors: PersistedVecChannel<crate::channels::errors::ErrorEvent>,
87}
88
89/// Wrapper for the scheduler versions_seen structure.
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
91pub struct PersistedVersionsSeen(pub FxHashMap<String, FxHashMap<String, u64>>);
92
93/// Full persisted checkpoint representation.
94/// (Step history tables may store multiple instances of this shape.)
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
96pub struct PersistedCheckpoint {
97    pub session_id: String,
98    pub step: u64,
99    pub state: PersistedState,
100    /// Frontier encoded as string vector using NodeKind::encode().
101    pub frontier: Vec<String>,
102    pub versions_seen: PersistedVersionsSeen,
103    pub concurrency_limit: usize,
104    /// RFC3339 string form of creation time (keeps chrono::DateTime out of serialized shape).
105    pub created_at: String,
106    /// Nodes that executed in this step, encoded as strings
107    #[serde(default)]
108    pub ran_nodes: Vec<String>,
109    /// Nodes that were skipped in this step, encoded as strings
110    #[serde(default)]
111    pub skipped_nodes: Vec<String>,
112    /// Channels that were updated in this step
113    #[serde(default)]
114    pub updated_channels: Vec<String>,
115}
116
117use thiserror::Error;
118
119/// Bidirectional conversion and serialization errors for persistence models.
120#[derive(Debug, Error)]
121#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
122pub enum PersistenceError {
123    #[error("missing field: {0}")]
124    #[cfg_attr(
125        feature = "diagnostics",
126        diagnostic(
127            code(weavegraph::persistence::missing_field),
128            help("Populate the field in the persisted JSON before conversion.")
129        )
130    )]
131    MissingField(&'static str),
132
133    #[error("JSON serialization/deserialization failed: {source}")]
134    #[cfg_attr(
135        feature = "diagnostics",
136        diagnostic(
137            code(weavegraph::persistence::serde),
138            help("Ensure the JSON structure matches Persisted* types; serde error: {source}.")
139        )
140    )]
141    Serde {
142        #[source]
143        source: serde_json::Error,
144    },
145
146    #[error("persistence error: {0}")]
147    #[cfg_attr(
148        feature = "diagnostics",
149        diagnostic(code(weavegraph::persistence::other))
150    )]
151    Other(String),
152}
153
154pub type Result<T> = std::result::Result<T, PersistenceError>;
155
/* ---------- VersionedState <-> PersistedState Conversions ---------- */
157
158impl From<&VersionedState> for PersistedState {
159    fn from(s: &VersionedState) -> Self {
160        PersistedState {
161            messages: PersistedVecChannel {
162                version: s.messages.version(),
163                items: s.messages.snapshot(),
164            },
165            extra: PersistedMapChannel {
166                version: s.extra.version(),
167                map: s.extra.snapshot(),
168            },
169            errors: PersistedVecChannel {
170                version: s.errors.version(),
171                items: s.errors.snapshot(),
172            },
173        }
174    }
175}
176
177impl TryFrom<PersistedState> for VersionedState {
178    type Error = PersistenceError;
179
180    fn try_from(p: PersistedState) -> Result<Self> {
181        Ok(VersionedState {
182            messages: MessagesChannel::new(p.messages.items, p.messages.version),
183            extra: ExtrasChannel::new(p.extra.map, p.extra.version),
184            errors: crate::channels::ErrorsChannel::new(p.errors.items, p.errors.version),
185        })
186    }
187}
188
/* ---------- versions_seen conversions ---------- */
190
191impl From<&FxHashMap<String, FxHashMap<String, u64>>> for PersistedVersionsSeen {
192    fn from(v: &FxHashMap<String, FxHashMap<String, u64>>) -> Self {
193        PersistedVersionsSeen(v.clone())
194    }
195}
196
197impl From<PersistedVersionsSeen> for FxHashMap<String, FxHashMap<String, u64>> {
198    fn from(p: PersistedVersionsSeen) -> Self {
199        p.0
200    }
201}
202
/* ---------- Checkpoint <-> PersistedCheckpoint Conversions ---------- */
204
205impl From<&Checkpoint> for PersistedCheckpoint {
206    fn from(cp: &Checkpoint) -> Self {
207        PersistedCheckpoint {
208            session_id: cp.session_id.clone(),
209            step: cp.step,
210            state: PersistedState::from(&cp.state),
211            frontier: cp.frontier.iter().map(|k| k.encode()).collect(),
212            versions_seen: PersistedVersionsSeen(cp.versions_seen.clone()),
213            concurrency_limit: cp.concurrency_limit,
214            created_at: cp.created_at.to_rfc3339(),
215            ran_nodes: cp.ran_nodes.iter().map(|k| k.encode()).collect(),
216            skipped_nodes: cp.skipped_nodes.iter().map(|k| k.encode()).collect(),
217            updated_channels: cp.updated_channels.clone(),
218        }
219    }
220}
221
222impl TryFrom<PersistedCheckpoint> for Checkpoint {
223    type Error = PersistenceError;
224
225    fn try_from(p: PersistedCheckpoint) -> Result<Self> {
226        let state = VersionedState::try_from(p.state)?;
227        let frontier: Vec<NodeKind> = p.frontier.iter().map(|s| NodeKind::decode(s)).collect();
228        let ran_nodes: Vec<NodeKind> = p.ran_nodes.iter().map(|s| NodeKind::decode(s)).collect();
229        let skipped_nodes: Vec<NodeKind> = p
230            .skipped_nodes
231            .iter()
232            .map(|s| NodeKind::decode(s))
233            .collect();
234        let parsed_dt = chrono::DateTime::parse_from_rfc3339(&p.created_at)
235            .map(|dt| dt.with_timezone(&Utc))
236            .unwrap_or_else(|_| Utc::now());
237        Ok(Checkpoint {
238            session_id: p.session_id,
239            step: p.step,
240            state,
241            frontier,
242            versions_seen: p.versions_seen.0,
243            concurrency_limit: p.concurrency_limit,
244            created_at: parsed_dt,
245            ran_nodes,
246            skipped_nodes,
247            updated_channels: p.updated_channels,
248        })
249    }
250}
251
/* ---------- Convenience JSON helpers (using JsonSerializable trait from utils::json_ext) ---------- */

// Both PersistedState and PersistedCheckpoint automatically implement JsonSerializable
// through the blanket implementation earlier in this file, providing the
// to_json_string() and from_json_str() methods.