Skip to main content

swink_agent/
checkpoint.rs

//! State persistence and checkpointing for agent conversations.
//!
//! Provides a [`Checkpoint`] struct that captures a snapshot of agent state
//! (messages, system prompt, model, turn count, metadata), a
//! [`LoopCheckpoint`] struct that captures the in-flight state of a paused
//! agent loop, and a [`CheckpointStore`] trait for async save/load of
//! checkpoints.
6
7use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10
11use crate::types::message_codec::{self, MessageSlot};
12use crate::types::{AgentMessage, Cost, CustomMessageRegistry, LlmMessage, Usage};
13
14mod store;
15
16pub use store::{CheckpointFuture, CheckpointStore};
17
18// ─── Checkpoint ──────────────────────────────────────────────────────────────
19
/// A serializable snapshot of agent conversation state.
///
/// Captures everything needed to restore an agent to a previous point:
/// messages, system prompt, model info, turn count, accumulated usage/cost,
/// and arbitrary metadata.
///
/// LLM messages and serializable custom messages are kept in two separate
/// lists. A private ordering record lets
/// [`restore_messages`](Checkpoint::restore_messages) rebuild the original
/// interleaving of the two.
///
/// When deserializing, `custom_messages`, `message_order`, `metadata` and
/// `state` may be absent (they fall back to empty/`None`). Every other field
/// is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Unique identifier for this checkpoint.
    pub id: String,
    /// System prompt at the time of the checkpoint.
    pub system_prompt: String,
    /// Model provider name.
    pub provider: String,
    /// Model identifier.
    pub model_id: String,
    /// Conversation messages (LLM messages only; custom messages live in
    /// `custom_messages`).
    pub messages: Vec<LlmMessage>,
    /// Serialized custom messages (envelopes with `type` and `data` fields).
    ///
    /// Omitted from the serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub custom_messages: Vec<serde_json::Value>,
    /// Records the original interleaved order of LLM and custom messages.
    ///
    /// Empty for checkpoints created before ordering support was added;
    /// `restore_messages` falls back to legacy (LLM-first) behavior in that case.
    /// Not public: it is populated only by the constructors in this module
    /// or by deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    message_order: Vec<MessageSlot>,
    /// Number of completed turns at the time of checkpointing.
    ///
    /// `Checkpoint::new` starts this at `0`; set it with `with_turn_count`.
    pub turn_count: usize,
    /// Accumulated token usage (defaults to `Usage::default()` in `new`).
    pub usage: Usage,
    /// Accumulated cost (defaults to `Cost::default()` in `new`).
    pub cost: Cost,
    /// Unix timestamp when the checkpoint was created.
    ///
    /// Filled from `crate::util::now_timestamp()` in `new`. NOTE(review): the
    /// unit (seconds vs. milliseconds) is whatever that helper returns; confirm
    /// there.
    pub created_at: u64,
    /// Arbitrary metadata for application-specific use.
    ///
    /// Omitted from the serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
    /// Serialized session state snapshot (`SessionState.data`).
    ///
    /// `None` unless set via `with_state`; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub state: Option<serde_json::Value>,
}
61
62impl Checkpoint {
63    /// Create a new checkpoint from the current agent state.
64    ///
65    /// Serializes `CustomMessage` variants that support serialization (i.e.
66    /// `type_name()` and `to_json()` return `Some`). Custom messages that
67    /// cannot be serialized are skipped with a warning.
68    ///
69    /// Use `with_turn_count()`, `with_usage()`, and `with_cost()` to set
70    /// additional fields.
71    #[must_use]
72    pub fn new(
73        id: impl Into<String>,
74        system_prompt: impl Into<String>,
75        provider: impl Into<String>,
76        model_id: impl Into<String>,
77        messages: &[AgentMessage],
78    ) -> Self {
79        let serialized = message_codec::serialize_messages(messages, "checkpoint");
80
81        Self {
82            id: id.into(),
83            system_prompt: system_prompt.into(),
84            provider: provider.into(),
85            model_id: model_id.into(),
86            messages: serialized.llm_messages,
87            custom_messages: serialized.custom_messages,
88            message_order: serialized.message_order,
89            turn_count: 0,
90            usage: Usage::default(),
91            cost: Cost::default(),
92            created_at: crate::util::now_timestamp(),
93            metadata: HashMap::new(),
94            state: None,
95        }
96    }
97
98    /// Set the session state snapshot.
99    #[must_use]
100    pub fn with_state(mut self, state: serde_json::Value) -> Self {
101        self.state = Some(state);
102        self
103    }
104
105    /// Set the turn count.
106    #[must_use]
107    pub const fn with_turn_count(mut self, turn_count: usize) -> Self {
108        self.turn_count = turn_count;
109        self
110    }
111
112    /// Set the accumulated usage.
113    #[must_use]
114    pub fn with_usage(mut self, usage: Usage) -> Self {
115        self.usage = usage;
116        self
117    }
118
119    /// Set the accumulated cost.
120    #[must_use]
121    pub fn with_cost(mut self, cost: Cost) -> Self {
122        self.cost = cost;
123        self
124    }
125
126    /// Add metadata to this checkpoint.
127    #[must_use]
128    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
129        self.metadata.insert(key.into(), value);
130        self
131    }
132
133    /// Restore all messages as `AgentMessage` values, preserving their
134    /// original interleaved order.
135    ///
136    /// If `registry` is `None`, custom messages are silently skipped.
137    /// Deserialization failures are logged as warnings but do not cause errors.
138    ///
139    /// For checkpoints created before ordering support, falls back to
140    /// legacy behavior (LLM messages first, then custom messages appended).
141    #[must_use]
142    pub fn restore_messages(&self, registry: Option<&CustomMessageRegistry>) -> Vec<AgentMessage> {
143        message_codec::restore_messages(
144            &self.messages,
145            &self.custom_messages,
146            &self.message_order,
147            registry,
148            "checkpoint",
149        )
150    }
151}
152
153// ─── LoopCheckpoint ──────────────────────────────────────────────────────
154
/// A serializable snapshot of the agent loop's in-flight state.
///
/// Captures everything needed to pause a running loop and resume it later:
/// messages, pending injections, system prompt, model, and session state.
/// Created by
/// [`Agent::pause`](crate::Agent::pause) and consumed by
/// [`Agent::resume`](crate::Agent::resume).
///
/// Each of the three message lists (context, pending follow-ups, pending
/// steering) is stored as three pieces: plain LLM messages, serialized
/// custom-message envelopes, and an ordering record. `messages`,
/// `custom_messages`, `pending_messages` and `pending_steering_messages` are
/// public. The pending custom lists and all ordering records are private and
/// are populated by the constructor and builders in this module, or by
/// deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoopCheckpoint {
    /// All context messages at the time of pause (LLM messages only).
    pub messages: Vec<LlmMessage>,
    /// Serialized custom messages (envelopes with `type` and `data` fields).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub custom_messages: Vec<serde_json::Value>,
    /// Records the original interleaved order of LLM and custom messages.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    message_order: Vec<MessageSlot>,
    /// Follow-up messages queued for injection into the next turn.
    ///
    /// Unlike the steering queue, this field has no serde default, so it must
    /// be present when deserializing.
    pub pending_messages: Vec<LlmMessage>,
    /// Serialized custom follow-up messages queued for the next turn.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_custom_messages: Vec<serde_json::Value>,
    /// Records the original interleaved order of pending follow-up messages.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_message_order: Vec<MessageSlot>,
    /// Steering messages queued at the time of pause.
    ///
    /// Older checkpoints without this field deserialize with an empty vec
    /// (backward compatible).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub pending_steering_messages: Vec<LlmMessage>,
    /// Serialized custom steering messages queued at the time of pause.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_steering_custom_messages: Vec<serde_json::Value>,
    /// Records the original interleaved order of pending steering messages.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_steering_message_order: Vec<MessageSlot>,
    /// The system prompt active at the time of pause.
    pub system_prompt: String,
    /// Model provider name.
    pub provider: String,
    /// Model identifier.
    pub model_id: String,
    /// Unix timestamp when the checkpoint was created.
    ///
    /// Filled from `crate::util::now_timestamp()` in `new`. NOTE(review): the
    /// unit is whatever that helper returns; confirm there.
    pub created_at: u64,
    /// Arbitrary metadata for application-specific use.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
    /// Serialized session state snapshot (`SessionState.data`).
    ///
    /// `None` unless set via `with_state`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub state: Option<serde_json::Value>,
}
207
208impl LoopCheckpoint {
209    /// Create a loop checkpoint from the current agent state.
210    ///
211    /// Serializes `CustomMessage` variants that support serialization.
212    /// Non-serializable custom messages are skipped with a warning.
213    #[must_use]
214    pub fn new(
215        system_prompt: impl Into<String>,
216        provider: impl Into<String>,
217        model_id: impl Into<String>,
218        messages: &[AgentMessage],
219    ) -> Self {
220        let serialized = message_codec::serialize_messages(messages, "loop checkpoint");
221
222        Self {
223            messages: serialized.llm_messages,
224            custom_messages: serialized.custom_messages,
225            message_order: serialized.message_order,
226            pending_messages: Vec::new(),
227            pending_custom_messages: Vec::new(),
228            pending_message_order: Vec::new(),
229            pending_steering_messages: Vec::new(),
230            pending_steering_custom_messages: Vec::new(),
231            pending_steering_message_order: Vec::new(),
232            system_prompt: system_prompt.into(),
233            provider: provider.into(),
234            model_id: model_id.into(),
235            created_at: crate::util::now_timestamp(),
236            metadata: HashMap::new(),
237            state: None,
238        }
239    }
240
241    /// Set the session state snapshot.
242    #[must_use]
243    pub fn with_state(mut self, state: serde_json::Value) -> Self {
244        self.state = Some(state);
245        self
246    }
247
248    /// Set pending follow-up messages.
249    #[must_use]
250    pub fn with_pending_messages(mut self, pending: Vec<LlmMessage>) -> Self {
251        self.pending_messages = pending;
252        self.pending_custom_messages.clear();
253        self.pending_message_order.clear();
254        self
255    }
256
257    /// Set pending steering messages.
258    #[must_use]
259    pub fn with_pending_steering_messages(mut self, pending: Vec<LlmMessage>) -> Self {
260        self.pending_steering_messages = pending;
261        self.pending_steering_custom_messages.clear();
262        self.pending_steering_message_order.clear();
263        self
264    }
265
266    /// Set pending follow-up messages from a full `AgentMessage` batch.
267    #[must_use]
268    pub fn with_pending_message_batch(mut self, pending: &[AgentMessage]) -> Self {
269        let serialized = message_codec::serialize_messages(pending, "loop checkpoint pending");
270        self.pending_messages = serialized.llm_messages;
271        self.pending_custom_messages = serialized.custom_messages;
272        self.pending_message_order = serialized.message_order;
273        self
274    }
275
276    /// Set pending steering messages from a full `AgentMessage` batch.
277    #[must_use]
278    pub fn with_pending_steering_message_batch(mut self, pending: &[AgentMessage]) -> Self {
279        let serialized =
280            message_codec::serialize_messages(pending, "loop checkpoint steering pending");
281        self.pending_steering_messages = serialized.llm_messages;
282        self.pending_steering_custom_messages = serialized.custom_messages;
283        self.pending_steering_message_order = serialized.message_order;
284        self
285    }
286
287    /// Add metadata.
288    #[must_use]
289    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
290        self.metadata.insert(key.into(), value);
291        self
292    }
293
294    /// Restore all messages as `AgentMessage` values, preserving their
295    /// original interleaved order.
296    ///
297    /// If `registry` is `None`, custom messages are silently skipped.
298    #[must_use]
299    pub fn restore_messages(&self, registry: Option<&CustomMessageRegistry>) -> Vec<AgentMessage> {
300        message_codec::restore_messages(
301            &self.messages,
302            &self.custom_messages,
303            &self.message_order,
304            registry,
305            "loop checkpoint",
306        )
307    }
308
309    /// Restore pending follow-up messages as `AgentMessage` values.
310    #[must_use]
311    pub fn restore_pending_messages(
312        &self,
313        registry: Option<&CustomMessageRegistry>,
314    ) -> Vec<AgentMessage> {
315        message_codec::restore_messages(
316            &self.pending_messages,
317            &self.pending_custom_messages,
318            &self.pending_message_order,
319            registry,
320            "loop checkpoint pending",
321        )
322    }
323
324    /// Restore pending steering messages as `AgentMessage` values.
325    #[must_use]
326    pub fn restore_pending_steering_messages(
327        &self,
328        registry: Option<&CustomMessageRegistry>,
329    ) -> Vec<AgentMessage> {
330        message_codec::restore_messages(
331            &self.pending_steering_messages,
332            &self.pending_steering_custom_messages,
333            &self.pending_steering_message_order,
334            registry,
335            "loop checkpoint steering pending",
336        )
337    }
338
339    /// Convert this loop checkpoint into a standard [`Checkpoint`] for storage.
340    #[must_use]
341    pub fn to_checkpoint(&self, id: impl Into<String>) -> Checkpoint {
342        Checkpoint {
343            id: id.into(),
344            system_prompt: self.system_prompt.clone(),
345            provider: self.provider.clone(),
346            model_id: self.model_id.clone(),
347            messages: self.messages.clone(),
348            custom_messages: self.custom_messages.clone(),
349            message_order: self.message_order.clone(),
350            turn_count: 0,
351            usage: Usage::default(),
352            cost: Cost::default(),
353            created_at: self.created_at,
354            metadata: self.metadata.clone(),
355            state: self.state.clone(),
356        }
357    }
358}
359
360// ─── Send + Sync assertions ─────────────────────────────────────────────────
361
362const _: () = {
363    const fn assert_send_sync<T: Send + Sync>() {}
364    assert_send_sync::<Checkpoint>();
365    assert_send_sync::<LoopCheckpoint>();
366};
367
368// ─── Tests ───────────────────────────────────────────────────────────────────
369
370#[cfg(test)]
371mod tests {
372    use super::*;
373    use crate::types::{ContentBlock, UserMessage};
374
375    #[derive(Debug)]
376    struct TestCustom;
377
378    impl crate::types::CustomMessage for TestCustom {
379        fn as_any(&self) -> &dyn std::any::Any {
380            self
381        }
382    }
383
384    fn sample_messages() -> Vec<AgentMessage> {
385        vec![
386            AgentMessage::Llm(LlmMessage::User(UserMessage {
387                content: vec![ContentBlock::Text {
388                    text: "Hello".to_string(),
389                }],
390                timestamp: 100,
391                cache_hint: None,
392            })),
393            AgentMessage::Llm(LlmMessage::Assistant(crate::types::AssistantMessage {
394                content: vec![ContentBlock::Text {
395                    text: "Hi there!".to_string(),
396                }],
397                provider: "test".to_string(),
398                model_id: "test-model".to_string(),
399                usage: Usage::default(),
400                cost: Cost::default(),
401                stop_reason: crate::types::StopReason::Stop,
402                error_message: None,
403                error_kind: None,
404                timestamp: 101,
405                cache_hint: None,
406            })),
407        ]
408    }
409
410    #[test]
411    fn checkpoint_creation_skips_non_serializable_custom_messages() {
412        let mut messages = sample_messages();
413        // Add a custom message without type_name/to_json — should be skipped
414        messages.push(AgentMessage::Custom(Box::new(TestCustom)));
415
416        let checkpoint = Checkpoint::new(
417            "cp-1",
418            "Be helpful.",
419            "anthropic",
420            "claude-sonnet",
421            &messages,
422        )
423        .with_turn_count(3);
424
425        assert_eq!(checkpoint.id, "cp-1");
426        assert_eq!(checkpoint.system_prompt, "Be helpful.");
427        assert_eq!(checkpoint.provider, "anthropic");
428        assert_eq!(checkpoint.model_id, "claude-sonnet");
429        assert_eq!(checkpoint.messages.len(), 2); // LLM messages only
430        assert!(checkpoint.custom_messages.is_empty()); // non-serializable skipped
431        assert_eq!(checkpoint.turn_count, 3);
432    }
433
434    #[test]
435    fn checkpoint_custom_message_roundtrip() {
436        use crate::types::CustomMessageRegistry;
437
438        #[derive(Debug, Clone, PartialEq)]
439        struct SerializableCustom {
440            value: String,
441        }
442
443        impl crate::types::CustomMessage for SerializableCustom {
444            fn as_any(&self) -> &dyn std::any::Any {
445                self
446            }
447            fn type_name(&self) -> Option<&str> {
448                Some("SerializableCustom")
449            }
450            fn to_json(&self) -> Option<serde_json::Value> {
451                Some(serde_json::json!({ "value": self.value }))
452            }
453        }
454
455        let mut messages = sample_messages();
456        messages.push(AgentMessage::Custom(Box::new(SerializableCustom {
457            value: "hello".to_string(),
458        })));
459
460        let checkpoint = Checkpoint::new("cp-custom", "prompt", "p", "m", &messages);
461
462        assert_eq!(checkpoint.messages.len(), 2);
463        assert_eq!(checkpoint.custom_messages.len(), 1);
464        assert_eq!(checkpoint.custom_messages[0]["type"], "SerializableCustom");
465        assert_eq!(checkpoint.custom_messages[0]["data"]["value"], "hello");
466
467        // Serde roundtrip preserves custom_messages
468        let json = serde_json::to_string(&checkpoint).unwrap();
469        let restored_cp: Checkpoint = serde_json::from_str(&json).unwrap();
470        assert_eq!(restored_cp.custom_messages.len(), 1);
471
472        // Restore with registry
473        let mut registry = CustomMessageRegistry::new();
474        registry.register(
475            "SerializableCustom",
476            Box::new(|val: serde_json::Value| {
477                let value = val
478                    .get("value")
479                    .and_then(|v| v.as_str())
480                    .ok_or_else(|| "missing value".to_string())?;
481                Ok(Box::new(SerializableCustom {
482                    value: value.to_string(),
483                }) as Box<dyn crate::types::CustomMessage>)
484            }),
485        );
486
487        let restored = restored_cp.restore_messages(Some(&registry));
488        assert_eq!(restored.len(), 3);
489        assert!(matches!(
490            restored[0],
491            AgentMessage::Llm(LlmMessage::User(_))
492        ));
493        assert!(matches!(
494            restored[1],
495            AgentMessage::Llm(LlmMessage::Assistant(_))
496        ));
497        let custom = restored[2].downcast_ref::<SerializableCustom>().unwrap();
498        assert_eq!(custom.value, "hello");
499
500        // Restore without registry — custom messages skipped
501        let restored_no_reg = restored_cp.restore_messages(None);
502        assert_eq!(restored_no_reg.len(), 2);
503    }
504
505    #[test]
506    fn checkpoint_serde_roundtrip() {
507        let messages = sample_messages();
508        let checkpoint = Checkpoint::new(
509            "cp-roundtrip",
510            "System prompt",
511            "openai",
512            "gpt-4",
513            &messages,
514        )
515        .with_turn_count(5)
516        .with_usage(Usage {
517            input: 100,
518            output: 50,
519            ..Default::default()
520        })
521        .with_cost(Cost {
522            input: 0.01,
523            output: 0.005,
524            ..Default::default()
525        })
526        .with_metadata("session_id", serde_json::json!("sess-abc"));
527
528        let json = serde_json::to_string(&checkpoint).unwrap();
529        let restored: Checkpoint = serde_json::from_str(&json).unwrap();
530
531        assert_eq!(restored.id, "cp-roundtrip");
532        assert_eq!(restored.system_prompt, "System prompt");
533        assert_eq!(restored.messages.len(), 2);
534        assert_eq!(restored.turn_count, 5);
535        assert_eq!(restored.usage.input, 100);
536        assert_eq!(restored.usage.output, 50);
537        assert_eq!(restored.metadata["session_id"], "sess-abc");
538    }
539
540    #[test]
541    fn restore_messages_wraps_in_agent_message() {
542        let messages = sample_messages();
543        let checkpoint =
544            Checkpoint::new("cp-restore", "prompt", "p", "m", &messages).with_turn_count(1);
545
546        let restored = checkpoint.restore_messages(None);
547        assert_eq!(restored.len(), 2);
548        assert!(matches!(
549            restored[0],
550            AgentMessage::Llm(LlmMessage::User(_))
551        ));
552        assert!(matches!(
553            restored[1],
554            AgentMessage::Llm(LlmMessage::Assistant(_))
555        ));
556    }
557
558    #[test]
559    fn checkpoint_with_metadata_builder() {
560        let checkpoint = Checkpoint::new("cp-meta", "p", "p", "m", &[])
561            .with_metadata("key1", serde_json::json!("value1"))
562            .with_metadata("key2", serde_json::json!(42));
563
564        assert_eq!(checkpoint.metadata.len(), 2);
565        assert_eq!(checkpoint.metadata["key1"], "value1");
566        assert_eq!(checkpoint.metadata["key2"], 42);
567    }
568
569    #[test]
570    fn checkpoint_backward_compat_no_metadata() {
571        // JSON without metadata field should deserialize fine
572        let json = r#"{
573            "id": "cp-compat",
574            "system_prompt": "hello",
575            "provider": "p",
576            "model_id": "m",
577            "messages": [],
578            "turn_count": 0,
579            "usage": {"input": 0, "output": 0, "cache_read": 0, "cache_write": 0, "total": 0},
580            "cost": {"input": 0.0, "output": 0.0, "cache_read": 0.0, "cache_write": 0.0, "total": 0.0},
581            "created_at": 100
582        }"#;
583
584        let checkpoint: Checkpoint = serde_json::from_str(json).unwrap();
585        assert!(checkpoint.metadata.is_empty());
586        assert!(checkpoint.custom_messages.is_empty());
587    }
588
589    // ─── LoopCheckpoint Tests ────────────────────────────────────────────
590
591    #[test]
592    fn loop_checkpoint_creation_skips_non_serializable_custom_messages() {
593        let mut messages = sample_messages();
594        messages.push(AgentMessage::Custom(Box::new(TestCustom)));
595
596        let cp = LoopCheckpoint::new("Be helpful.", "anthropic", "claude-sonnet", &messages)
597            .with_pending_messages(vec![LlmMessage::User(UserMessage {
598                content: vec![ContentBlock::Text {
599                    text: "continue".to_string(),
600                }],
601                timestamp: 123,
602                cache_hint: None,
603            })]);
604
605        assert_eq!(cp.messages.len(), 2);
606        assert!(cp.custom_messages.is_empty());
607        assert_eq!(cp.pending_messages.len(), 1);
608        assert_eq!(cp.system_prompt, "Be helpful.");
609        assert_eq!(cp.provider, "anthropic");
610        assert_eq!(cp.model_id, "claude-sonnet");
611    }
612
613    #[test]
614    fn loop_checkpoint_serde_roundtrip() {
615        let messages = sample_messages();
616        let cp = LoopCheckpoint::new("System prompt", "openai", "gpt-4", &messages)
617            .with_pending_messages(vec![LlmMessage::User(UserMessage {
618                content: vec![ContentBlock::Text {
619                    text: "follow-up".to_string(),
620                }],
621                timestamp: 200,
622                cache_hint: None,
623            })])
624            .with_metadata("workflow_id", serde_json::json!("wf-123"));
625
626        let json = serde_json::to_string(&cp).unwrap();
627        let restored: LoopCheckpoint = serde_json::from_str(&json).unwrap();
628
629        assert_eq!(restored.messages.len(), 2);
630        assert_eq!(restored.pending_messages.len(), 1);
631        assert_eq!(restored.system_prompt, "System prompt");
632        assert_eq!(restored.metadata["workflow_id"], "wf-123");
633    }
634
635    #[test]
636    fn loop_checkpoint_restore_messages() {
637        let messages = sample_messages();
638        let cp = LoopCheckpoint::new("p", "p", "m", &messages);
639
640        let restored = cp.restore_messages(None);
641        assert_eq!(restored.len(), 2);
642        assert!(matches!(
643            restored[0],
644            AgentMessage::Llm(LlmMessage::User(_))
645        ));
646        assert!(matches!(
647            restored[1],
648            AgentMessage::Llm(LlmMessage::Assistant(_))
649        ));
650    }
651
652    #[test]
653    fn loop_checkpoint_steering_messages_roundtrip() {
654        let steering = vec![LlmMessage::User(UserMessage {
655            content: vec![ContentBlock::Text {
656                text: "steer-me".to_string(),
657            }],
658            timestamp: 300,
659            cache_hint: None,
660        })];
661        let follow_up = vec![LlmMessage::User(UserMessage {
662            content: vec![ContentBlock::Text {
663                text: "follow-up".to_string(),
664            }],
665            timestamp: 301,
666            cache_hint: None,
667        })];
668
669        let cp = LoopCheckpoint::new("p", "p", "m", &[])
670            .with_pending_messages(follow_up)
671            .with_pending_steering_messages(steering);
672
673        // Serde roundtrip
674        let json = serde_json::to_string(&cp).unwrap();
675        let restored: LoopCheckpoint = serde_json::from_str(&json).unwrap();
676
677        assert_eq!(restored.pending_messages.len(), 1);
678        assert_eq!(restored.pending_steering_messages.len(), 1);
679
680        let restored_steering = restored.restore_pending_steering_messages(None);
681        assert_eq!(restored_steering.len(), 1);
682        assert!(matches!(
683            restored_steering[0],
684            AgentMessage::Llm(LlmMessage::User(_))
685        ));
686    }
687
688    #[test]
689    fn loop_checkpoint_backward_compat_no_steering_field() {
690        // Simulate a checkpoint created before pending_steering_messages existed
691        let cp =
692            LoopCheckpoint::new("p", "p", "m", &[]).with_pending_messages(vec![LlmMessage::User(
693                UserMessage {
694                    content: vec![ContentBlock::Text {
695                        text: "old-follow-up".to_string(),
696                    }],
697                    timestamp: 100,
698                    cache_hint: None,
699                },
700            )]);
701
702        let mut json_val = serde_json::to_value(&cp).unwrap();
703        // Strip the field to simulate old format
704        json_val
705            .as_object_mut()
706            .unwrap()
707            .remove("pending_steering_messages");
708        let legacy: LoopCheckpoint = serde_json::from_value(json_val).unwrap();
709
710        assert!(
711            legacy.pending_steering_messages.is_empty(),
712            "missing steering field should default to empty"
713        );
714        assert_eq!(legacy.pending_messages.len(), 1);
715    }
716
717    #[test]
718    fn loop_checkpoint_pending_messages_roundtrip() {
719        let pending = vec![LlmMessage::User(UserMessage {
720            content: vec![ContentBlock::Text {
721                text: "follow-up".to_string(),
722            }],
723            timestamp: 200,
724            cache_hint: None,
725        })];
726
727        let cp = LoopCheckpoint::new("p", "p", "m", &[]).with_pending_messages(pending);
728
729        let restored_pending = cp.restore_pending_messages(None);
730        assert_eq!(restored_pending.len(), 1);
731        assert!(matches!(
732            restored_pending[0],
733            AgentMessage::Llm(LlmMessage::User(_))
734        ));
735    }
736
737    #[test]
738    fn loop_checkpoint_pending_messages_restore_custom_with_registry() {
739        let (registry, factory) = make_registry_and_custom("test");
740        let pending = vec![
741            user_msg("follow-up"),
742            AgentMessage::Custom(factory("pending-custom")),
743        ];
744
745        let cp = LoopCheckpoint::new("p", "p", "m", &[]).with_pending_message_batch(&pending);
746        let restored_pending = cp.restore_pending_messages(Some(&registry));
747
748        assert_eq!(restored_pending.len(), 2);
749        let order: Vec<String> = restored_pending.iter().map(message_text).collect();
750        assert_eq!(order, vec!["user:follow-up", "custom:pending-custom"]);
751    }
752
753    #[test]
754    fn loop_checkpoint_to_standard_checkpoint() {
755        let messages = sample_messages();
756        let cp = LoopCheckpoint::new("prompt", "anthropic", "claude", &messages)
757            .with_metadata("key", serde_json::json!("val"));
758
759        let standard = cp.to_checkpoint("cp-from-loop");
760        assert_eq!(standard.id, "cp-from-loop");
761        assert_eq!(standard.system_prompt, "prompt");
762        assert_eq!(standard.turn_count, 0);
763        assert_eq!(standard.usage.input, 0);
764        assert_eq!(standard.messages.len(), 2);
765        assert_eq!(standard.metadata["key"], "val");
766    }
767
768    // ─── Interleaved ordering regression tests (issue #51) ──────────────
769
770    fn make_registry_and_custom(
771        tag: &str,
772    ) -> (
773        CustomMessageRegistry,
774        impl Fn(&str) -> Box<dyn crate::types::CustomMessage>,
775    ) {
776        #[derive(Debug, Clone, PartialEq)]
777        struct TaggedCustom {
778            tag: String,
779        }
780
781        impl crate::types::CustomMessage for TaggedCustom {
782            fn as_any(&self) -> &dyn std::any::Any {
783                self
784            }
785            fn type_name(&self) -> Option<&str> {
786                Some("TaggedCustom")
787            }
788            fn to_json(&self) -> Option<serde_json::Value> {
789                Some(serde_json::json!({ "tag": self.tag }))
790            }
791        }
792
793        let _ = tag; // suppress unused warning
794        let mut registry = CustomMessageRegistry::new();
795        registry.register(
796            "TaggedCustom",
797            Box::new(|val: serde_json::Value| {
798                let tag = val
799                    .get("tag")
800                    .and_then(|v| v.as_str())
801                    .ok_or_else(|| "missing tag".to_string())?;
802                Ok(Box::new(TaggedCustom {
803                    tag: tag.to_string(),
804                }) as Box<dyn crate::types::CustomMessage>)
805            }),
806        );
807
808        let factory = |tag: &str| -> Box<dyn crate::types::CustomMessage> {
809            Box::new(TaggedCustom {
810                tag: tag.to_string(),
811            })
812        };
813
814        (registry, factory)
815    }
816
817    fn user_msg(text: &str) -> AgentMessage {
818        AgentMessage::Llm(LlmMessage::User(UserMessage {
819            content: vec![ContentBlock::Text {
820                text: text.to_string(),
821            }],
822            timestamp: 0,
823            cache_hint: None,
824        }))
825    }
826
827    fn assistant_msg(text: &str) -> AgentMessage {
828        AgentMessage::Llm(LlmMessage::Assistant(crate::types::AssistantMessage {
829            content: vec![ContentBlock::Text {
830                text: text.to_string(),
831            }],
832            provider: "test".to_string(),
833            model_id: "test-model".to_string(),
834            usage: Usage::default(),
835            cost: Cost::default(),
836            stop_reason: crate::types::StopReason::Stop,
837            error_message: None,
838            error_kind: None,
839            timestamp: 0,
840            cache_hint: None,
841        }))
842    }
843
844    /// Extracts text from an `AgentMessage` for assertion purposes.
845    fn message_text(msg: &AgentMessage) -> String {
846        match msg {
847            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
848                ContentBlock::Text { text } => format!("user:{text}"),
849                _ => "user:?".to_string(),
850            },
851            AgentMessage::Llm(LlmMessage::Assistant(a)) => match &a.content[0] {
852                ContentBlock::Text { text } => format!("assistant:{text}"),
853                _ => "assistant:?".to_string(),
854            },
855            AgentMessage::Custom(c) => {
856                if let Some(json) = c.to_json() {
857                    format!("custom:{}", json["tag"].as_str().unwrap_or("?"))
858                } else {
859                    "custom:?".to_string()
860                }
861            }
862            _ => "other".to_string(),
863        }
864    }
865
866    #[test]
867    fn checkpoint_preserves_interleaved_custom_message_order() {
868        let (registry, factory) = make_registry_and_custom("test");
869
870        // Interleaved: User, Custom("A"), Assistant, Custom("B"), User
871        let messages = vec![
872            user_msg("hello"),
873            AgentMessage::Custom(factory("A")),
874            assistant_msg("hi"),
875            AgentMessage::Custom(factory("B")),
876            user_msg("thanks"),
877        ];
878
879        let checkpoint = Checkpoint::new("cp-order", "prompt", "p", "m", &messages);
880
881        // Serde roundtrip
882        let json = serde_json::to_string(&checkpoint).unwrap();
883        let restored_cp: Checkpoint = serde_json::from_str(&json).unwrap();
884
885        let restored = restored_cp.restore_messages(Some(&registry));
886        let order: Vec<String> = restored.iter().map(message_text).collect();
887
888        assert_eq!(
889            order,
890            vec![
891                "user:hello",
892                "custom:A",
893                "assistant:hi",
894                "custom:B",
895                "user:thanks",
896            ],
897            "interleaved custom messages must preserve their original position"
898        );
899    }
900
901    #[test]
902    fn loop_checkpoint_preserves_interleaved_custom_message_order() {
903        let (registry, factory) = make_registry_and_custom("test");
904
905        let messages = vec![
906            user_msg("q1"),
907            AgentMessage::Custom(factory("mid")),
908            assistant_msg("a1"),
909        ];
910
911        let cp = LoopCheckpoint::new("prompt", "p", "m", &messages);
912
913        let json = serde_json::to_string(&cp).unwrap();
914        let restored_cp: LoopCheckpoint = serde_json::from_str(&json).unwrap();
915
916        let restored = restored_cp.restore_messages(Some(&registry));
917        let order: Vec<String> = restored.iter().map(message_text).collect();
918
919        assert_eq!(
920            order,
921            vec!["user:q1", "custom:mid", "assistant:a1"],
922            "LoopCheckpoint must preserve interleaved custom message order"
923        );
924    }
925
926    #[test]
927    fn loop_checkpoint_to_checkpoint_preserves_order() {
928        let (registry, factory) = make_registry_and_custom("test");
929
930        let messages = vec![
931            AgentMessage::Custom(factory("first")),
932            user_msg("hello"),
933            AgentMessage::Custom(factory("second")),
934        ];
935
936        let loop_cp = LoopCheckpoint::new("prompt", "p", "m", &messages);
937        let standard = loop_cp.to_checkpoint("cp-conv");
938
939        let restored = standard.restore_messages(Some(&registry));
940        let order: Vec<String> = restored.iter().map(message_text).collect();
941
942        assert_eq!(
943            order,
944            vec!["custom:first", "user:hello", "custom:second"],
945            "to_checkpoint conversion must preserve message order"
946        );
947    }
948
949    #[test]
950    fn backward_compat_no_message_order_field() {
951        // Create a checkpoint with interleaved messages, then strip the
952        // message_order field to simulate a legacy checkpoint.
953        let (registry, factory) = make_registry_and_custom("test");
954
955        let messages = vec![user_msg("hi"), AgentMessage::Custom(factory("legacy"))];
956
957        let checkpoint = Checkpoint::new("cp-legacy", "hello", "p", "m", &messages);
958        // Serialize, strip message_order, deserialize — simulates old format
959        let mut json_val = serde_json::to_value(&checkpoint).unwrap();
960        json_val.as_object_mut().unwrap().remove("message_order");
961        let legacy_cp: Checkpoint = serde_json::from_value(json_val).unwrap();
962
963        // message_order should be empty (stripped)
964        assert!(legacy_cp.message_order.is_empty());
965
966        let restored = legacy_cp.restore_messages(Some(&registry));
967        // Legacy fallback: LLM first, then custom appended
968        assert_eq!(restored.len(), 2);
969        assert!(matches!(
970            restored[0],
971            AgentMessage::Llm(LlmMessage::User(_))
972        ));
973        let order: Vec<String> = restored.iter().map(message_text).collect();
974        assert_eq!(order, vec!["user:hi", "custom:legacy"]);
975    }
976
977    #[test]
978    fn restore_without_registry_skips_custom_in_ordered_mode() {
979        let (_registry, factory) = make_registry_and_custom("test");
980
981        let messages = vec![
982            user_msg("hello"),
983            AgentMessage::Custom(factory("skipped")),
984            assistant_msg("world"),
985        ];
986
987        let checkpoint = Checkpoint::new("cp-no-reg", "prompt", "p", "m", &messages);
988        let restored = checkpoint.restore_messages(None);
989
990        // Custom messages are skipped when no registry is provided
991        assert_eq!(restored.len(), 2);
992        let order: Vec<String> = restored.iter().map(message_text).collect();
993        assert_eq!(order, vec!["user:hello", "assistant:world"]);
994    }
995}