Skip to main content

stakpak_agent_core/
checkpoint.rs

1use serde::{Deserialize, Serialize};
2use serde_json::json;
3use thiserror::Error;
4use uuid::Uuid;
5
/// Schema version stamped on envelopes built by [`CheckpointEnvelopeV1::new`].
/// [`deserialize_checkpoint`] rejects payloads that declare any other version.
pub const CHECKPOINT_VERSION_V1: u16 = 1;
/// Format tag stamped on V1 envelopes (messages stored as `stakai::Message`).
/// [`deserialize_checkpoint`] rejects envelopes carrying any other tag.
pub const CHECKPOINT_FORMAT_V1: &str = "stakai_message_v1";
8
/// Versioned container persisted for a checkpoint (schema version 1).
///
/// [`serialize_checkpoint`] encodes it as JSON and [`deserialize_checkpoint`]
/// decodes it, migrating legacy layouts into this shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointEnvelopeV1 {
    /// Schema version; [`CheckpointEnvelopeV1::new`] sets [`CHECKPOINT_VERSION_V1`].
    pub version: u16,
    /// Encoding tag; [`CheckpointEnvelopeV1::new`] sets [`CHECKPOINT_FORMAT_V1`].
    pub format: String,
    /// Optional run identifier; `None` for checkpoints migrated from a bare legacy
    /// message array, or whose legacy `run_id` could not be parsed.
    pub run_id: Option<Uuid>,
    /// Messages captured by the checkpoint.
    pub messages: Vec<stakai::Message>,
    /// Arbitrary JSON metadata stored alongside the messages.
    pub metadata: serde_json::Value,
}
17
18impl CheckpointEnvelopeV1 {
19    pub fn new(
20        run_id: Option<Uuid>,
21        messages: Vec<stakai::Message>,
22        metadata: serde_json::Value,
23    ) -> Self {
24        Self {
25            version: CHECKPOINT_VERSION_V1,
26            format: CHECKPOINT_FORMAT_V1.to_string(),
27            run_id,
28            messages,
29            metadata,
30        }
31    }
32}
33
/// Errors produced while encoding or decoding checkpoint payloads.
#[derive(Debug, Error)]
pub enum CheckpointError {
    /// JSON (de)serialization failed. `#[from]` lets `?` convert a
    /// `serde_json::Error` into this variant.
    #[error("invalid checkpoint payload: {0}")]
    InvalidPayload(#[from] serde_json::Error),

    /// The payload has no usable `version` field and could not be migrated as a
    /// legacy checkpoint.
    #[error("checkpoint payload is missing version")]
    MissingVersion,

    /// The payload declares a version other than [`CHECKPOINT_VERSION_V1`].
    #[error("unsupported checkpoint version: {0}")]
    UnsupportedVersion(u16),

    /// The envelope's `format` tag is not [`CHECKPOINT_FORMAT_V1`].
    #[error("unsupported checkpoint format: {0}")]
    UnsupportedFormat(String),
}
48
49pub fn serialize_checkpoint(envelope: &CheckpointEnvelopeV1) -> Result<Vec<u8>, CheckpointError> {
50    serde_json::to_vec(envelope).map_err(CheckpointError::InvalidPayload)
51}
52
53pub fn deserialize_checkpoint(payload: &[u8]) -> Result<CheckpointEnvelopeV1, CheckpointError> {
54    let value: serde_json::Value = serde_json::from_slice(payload)?;
55
56    let Some(version) = value.get("version").and_then(serde_json::Value::as_u64) else {
57        if let Some(migrated) = migrate_legacy_checkpoint(&value) {
58            return Ok(migrated);
59        }
60        return Err(CheckpointError::MissingVersion);
61    };
62
63    let version = version as u16;
64
65    if version != CHECKPOINT_VERSION_V1 {
66        return Err(CheckpointError::UnsupportedVersion(version));
67    }
68
69    let envelope: CheckpointEnvelopeV1 = serde_json::from_value(value)?;
70
71    if envelope.format != CHECKPOINT_FORMAT_V1 {
72        return Err(CheckpointError::UnsupportedFormat(envelope.format));
73    }
74
75    Ok(envelope)
76}
77
78fn migrate_legacy_checkpoint(value: &serde_json::Value) -> Option<CheckpointEnvelopeV1> {
79    if value.is_array() {
80        let messages: Vec<stakai::Message> = serde_json::from_value(value.clone()).ok()?;
81        return Some(CheckpointEnvelopeV1::new(
82            None,
83            messages,
84            json!({"migrated_from": "legacy_messages_array"}),
85        ));
86    }
87
88    let object = value.as_object()?;
89    let messages_value = object.get("messages")?;
90    let messages: Vec<stakai::Message> = serde_json::from_value(messages_value.clone()).ok()?;
91
92    let run_id = object
93        .get("run_id")
94        .and_then(|value| serde_json::from_value::<Uuid>(value.clone()).ok());
95
96    let metadata = object.get("metadata").cloned().unwrap_or_else(|| json!({}));
97
98    Some(CheckpointEnvelopeV1::new(run_id, messages, metadata))
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use stakai::{Message, Role};

    #[test]
    fn roundtrip_v1_envelope() {
        let run_id = Some(Uuid::new_v4());
        let envelope = CheckpointEnvelopeV1::new(
            run_id,
            vec![Message::new(Role::User, "hello")],
            json!({"cwd":"/workspace"}),
        );

        let payload = match serialize_checkpoint(&envelope) {
            Ok(payload) => payload,
            Err(error) => panic!("serialization should succeed, got: {error}"),
        };

        let parsed = match deserialize_checkpoint(&payload) {
            Ok(parsed) => parsed,
            Err(error) => panic!("deserialization should succeed, got: {error}"),
        };

        assert_eq!(parsed.version, envelope.version);
        assert_eq!(parsed.format, envelope.format);
        assert_eq!(parsed.run_id, envelope.run_id);
        assert_eq!(parsed.metadata, envelope.metadata);

        let first_message_text = parsed.messages.first().and_then(stakai::Message::text);
        assert_eq!(first_message_text, Some("hello".to_string()));
    }

    #[test]
    fn serialized_envelope_carries_version_and_format_tags() {
        let envelope = CheckpointEnvelopeV1::new(None, Vec::new(), json!({}));

        let payload = match serialize_checkpoint(&envelope) {
            Ok(payload) => payload,
            Err(error) => panic!("serialization should succeed, got: {error}"),
        };

        let raw: serde_json::Value = match serde_json::from_slice(&payload) {
            Ok(raw) => raw,
            Err(error) => panic!("serialized checkpoint should be valid JSON: {error}"),
        };

        assert_eq!(raw["version"], json!(CHECKPOINT_VERSION_V1));
        assert_eq!(raw["format"], json!(CHECKPOINT_FORMAT_V1));
        assert_eq!(raw["run_id"], serde_json::Value::Null);
    }

    #[test]
    fn migrates_legacy_messages_array() {
        let payload = json!([
            {
                "role": "user",
                "content": "legacy"
            }
        ]);

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        let envelope = match result {
            Ok(envelope) => envelope,
            Err(error) => panic!("legacy checkpoint should migrate: {error}"),
        };

        assert_eq!(envelope.version, CHECKPOINT_VERSION_V1);
        assert_eq!(envelope.format, CHECKPOINT_FORMAT_V1);
        assert_eq!(envelope.run_id, None);
        assert_eq!(
            envelope.metadata,
            json!({"migrated_from": "legacy_messages_array"})
        );
        assert_eq!(
            envelope.messages.first().and_then(stakai::Message::text),
            Some("legacy".to_string())
        );
    }

    #[test]
    fn migrates_legacy_messages_object_with_run_id() {
        let run_id = Uuid::new_v4();
        let payload = json!({
            "run_id": run_id,
            "messages": [
                {
                    "role": "assistant",
                    "content": "legacy object"
                }
            ],
            "metadata": {"legacy": true}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        let envelope = match result {
            Ok(envelope) => envelope,
            Err(error) => panic!("legacy object checkpoint should migrate: {error}"),
        };

        assert_eq!(envelope.run_id, Some(run_id));
        assert_eq!(
            envelope.messages.first().and_then(stakai::Message::text),
            Some("legacy object".to_string())
        );
        assert_eq!(envelope.metadata, json!({"legacy": true}));
    }

    #[test]
    fn migrates_legacy_object_with_bad_run_id_and_no_metadata() {
        let payload = json!({
            "run_id": "not-a-uuid",
            "messages": [
                {
                    "role": "user",
                    "content": "legacy object"
                }
            ]
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        let envelope = match result {
            Ok(envelope) => envelope,
            Err(error) => panic!("legacy object checkpoint should migrate: {error}"),
        };

        assert_eq!(envelope.version, CHECKPOINT_VERSION_V1);
        assert_eq!(envelope.format, CHECKPOINT_FORMAT_V1);
        assert_eq!(envelope.run_id, None);
        assert_eq!(envelope.metadata, json!({}));
        assert_eq!(
            envelope.messages.first().and_then(stakai::Message::text),
            Some("legacy object".to_string())
        );
    }

    #[test]
    fn reject_unsupported_version() {
        let payload = json!({
            "version": 2,
            "format": CHECKPOINT_FORMAT_V1,
            "run_id": null,
            "messages": [],
            "metadata": {}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("unsupported checkpoint version: 2".to_string())
        );
    }

    #[test]
    fn reject_wrong_format() {
        let payload = json!({
            "version": 1,
            "format": "legacy",
            "run_id": null,
            "messages": [],
            "metadata": {}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("unsupported checkpoint format: legacy".to_string())
        );
    }

    #[test]
    fn reject_payload_without_version() {
        let payload = json!({
            "format": CHECKPOINT_FORMAT_V1,
            "run_id": null,
            "metadata": {}
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("checkpoint payload is missing version".to_string())
        );
    }

    #[test]
    fn reject_unversioned_payload_with_unparseable_messages() {
        let payload = json!({
            "messages": "not a list"
        });

        let result = deserialize_checkpoint(payload.to_string().as_bytes());
        assert_eq!(
            result.err().map(|e| e.to_string()),
            Some("checkpoint payload is missing version".to_string())
        );
    }

    #[test]
    fn reject_non_json_payload() {
        let result = deserialize_checkpoint(b"not json");
        assert!(matches!(result, Err(CheckpointError::InvalidPayload(_))));
    }
}