stakpak_agent_core/
checkpoint.rs1use serde::{Deserialize, Serialize};
2use serde_json::json;
3use thiserror::Error;
4use uuid::Uuid;
5
/// Schema version stamped into every envelope built by [`CheckpointEnvelopeV1::new`];
/// `deserialize_checkpoint` rejects any other version.
pub const CHECKPOINT_VERSION_V1: u16 = 1;
/// Tag stored in a v1 envelope's `format` field; `deserialize_checkpoint` rejects
/// envelopes carrying any other tag.
pub const CHECKPOINT_FORMAT_V1: &str = "stakai_message_v1";
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct CheckpointEnvelopeV1 {
11 pub version: u16,
12 pub format: String,
13 pub run_id: Option<Uuid>,
14 pub messages: Vec<stakai::Message>,
15 pub metadata: serde_json::Value,
16}
17
18impl CheckpointEnvelopeV1 {
19 pub fn new(
20 run_id: Option<Uuid>,
21 messages: Vec<stakai::Message>,
22 metadata: serde_json::Value,
23 ) -> Self {
24 Self {
25 version: CHECKPOINT_VERSION_V1,
26 format: CHECKPOINT_FORMAT_V1.to_string(),
27 run_id,
28 messages,
29 metadata,
30 }
31 }
32}
33
34#[derive(Debug, Error)]
35pub enum CheckpointError {
36 #[error("invalid checkpoint payload: {0}")]
37 InvalidPayload(#[from] serde_json::Error),
38
39 #[error("checkpoint payload is missing version")]
40 MissingVersion,
41
42 #[error("unsupported checkpoint version: {0}")]
43 UnsupportedVersion(u16),
44
45 #[error("unsupported checkpoint format: {0}")]
46 UnsupportedFormat(String),
47}
48
49pub fn serialize_checkpoint(envelope: &CheckpointEnvelopeV1) -> Result<Vec<u8>, CheckpointError> {
50 serde_json::to_vec(envelope).map_err(CheckpointError::InvalidPayload)
51}
52
53pub fn deserialize_checkpoint(payload: &[u8]) -> Result<CheckpointEnvelopeV1, CheckpointError> {
54 let value: serde_json::Value = serde_json::from_slice(payload)?;
55
56 let Some(version) = value.get("version").and_then(serde_json::Value::as_u64) else {
57 if let Some(migrated) = migrate_legacy_checkpoint(&value) {
58 return Ok(migrated);
59 }
60 return Err(CheckpointError::MissingVersion);
61 };
62
63 let version = version as u16;
64
65 if version != CHECKPOINT_VERSION_V1 {
66 return Err(CheckpointError::UnsupportedVersion(version));
67 }
68
69 let envelope: CheckpointEnvelopeV1 = serde_json::from_value(value)?;
70
71 if envelope.format != CHECKPOINT_FORMAT_V1 {
72 return Err(CheckpointError::UnsupportedFormat(envelope.format));
73 }
74
75 Ok(envelope)
76}
77
78fn migrate_legacy_checkpoint(value: &serde_json::Value) -> Option<CheckpointEnvelopeV1> {
79 if value.is_array() {
80 let messages: Vec<stakai::Message> = serde_json::from_value(value.clone()).ok()?;
81 return Some(CheckpointEnvelopeV1::new(
82 None,
83 messages,
84 json!({"migrated_from": "legacy_messages_array"}),
85 ));
86 }
87
88 let object = value.as_object()?;
89 let messages_value = object.get("messages")?;
90 let messages: Vec<stakai::Message> = serde_json::from_value(messages_value.clone()).ok()?;
91
92 let run_id = object
93 .get("run_id")
94 .and_then(|value| serde_json::from_value::<Uuid>(value.clone()).ok());
95
96 let metadata = object.get("metadata").cloned().unwrap_or_else(|| json!({}));
97
98 Some(CheckpointEnvelopeV1::new(run_id, messages, metadata))
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use stakai::{Message, Role};

    #[test]
    fn roundtrip_v1_envelope() {
        let run_id = Some(Uuid::new_v4());
        let envelope = CheckpointEnvelopeV1::new(
            run_id,
            vec![Message::new(Role::User, "hello")],
            json!({"cwd":"/workspace"}),
        );

        let payload = match serialize_checkpoint(&envelope) {
            Ok(payload) => payload,
            Err(error) => panic!("serialization should succeed, got: {error}"),
        };

        let parsed = match deserialize_checkpoint(&payload) {
            Ok(parsed) => parsed,
            Err(error) => panic!("deserialization should succeed, got: {error}"),
        };

        assert_eq!(parsed.version, envelope.version);
        assert_eq!(parsed.format, envelope.format);
        assert_eq!(parsed.run_id, envelope.run_id);
        assert_eq!(parsed.metadata, envelope.metadata);

        let first_message_text = parsed.messages.first().and_then(stakai::Message::text);
        assert_eq!(first_message_text, Some("hello".to_string()));
    }

    #[test]
    fn migrates_legacy_messages_array() {
        let payload = json!([
            {
                "role": "user",
                "content": "legacy"
            }
        ]);

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        let envelope = match result {
            Ok(envelope) => envelope,
            Err(error) => panic!("legacy checkpoint should migrate: {error}"),
        };

        assert_eq!(envelope.version, CHECKPOINT_VERSION_V1);
        assert_eq!(envelope.format, CHECKPOINT_FORMAT_V1);
        assert_eq!(envelope.run_id, None);
        assert_eq!(
            envelope.messages.first().and_then(stakai::Message::text),
            Some("legacy".to_string())
        );
    }

    #[test]
    fn migrates_legacy_messages_object_with_run_id() {
        let run_id = Uuid::new_v4();
        let payload = json!({
            "run_id": run_id,
            "messages": [
                {
                    "role": "assistant",
                    "content": "legacy object"
                }
            ],
            "metadata": {"legacy": true}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        let envelope = match result {
            Ok(envelope) => envelope,
            Err(error) => panic!("legacy object checkpoint should migrate: {error}"),
        };

        assert_eq!(envelope.run_id, Some(run_id));
        assert_eq!(
            envelope.messages.first().and_then(stakai::Message::text),
            Some("legacy object".to_string())
        );
        assert_eq!(envelope.metadata, json!({"legacy": true}));
    }

    #[test]
    fn legacy_object_defaults_missing_metadata_and_drops_invalid_run_id() {
        let payload = json!({
            "run_id": "not-a-uuid",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                }
            ]
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        let envelope = match result {
            Ok(envelope) => envelope,
            Err(error) => panic!("legacy object without metadata should migrate: {error}"),
        };

        assert_eq!(envelope.run_id, None);
        assert_eq!(envelope.metadata, json!({}));
        assert_eq!(envelope.messages.len(), 1);
    }

    #[test]
    fn legacy_array_with_invalid_messages_reports_missing_version() {
        let payload = json!([1, 2, 3]);

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("checkpoint payload is missing version".to_string())
        );
    }

    #[test]
    fn reject_unsupported_version() {
        let payload = json!({
            "version": 2,
            "format": CHECKPOINT_FORMAT_V1,
            "run_id": null,
            "messages": [],
            "metadata": {}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("unsupported checkpoint version: 2".to_string())
        );
    }

    #[test]
    fn reject_out_of_range_version() {
        // 65_537 does not fit in a u16; it must never be read back as version 1.
        let payload = json!({
            "version": 65_537,
            "format": CHECKPOINT_FORMAT_V1,
            "run_id": null,
            "messages": [],
            "metadata": {}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert!(
            matches!(&result, Err(CheckpointError::InvalidPayload(_))),
            "expected InvalidPayload, got: {result:?}"
        );
    }

    #[test]
    fn reject_wrong_format() {
        let payload = json!({
            "version": 1,
            "format": "legacy",
            "run_id": null,
            "messages": [],
            "metadata": {}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("unsupported checkpoint format: legacy".to_string())
        );
    }

    #[test]
    fn reject_payload_without_version() {
        let payload = json!({
            "format": CHECKPOINT_FORMAT_V1,
            "run_id": null,
            "metadata": {}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("checkpoint payload is missing version".to_string())
        );
    }
}