Skip to main content

swink_agent/agent/
checkpointing.rs

1use std::pin::Pin;
2use std::sync::atomic::Ordering;
3
4use futures::Stream;
5
6use crate::checkpoint::{Checkpoint, CheckpointStore};
7use crate::error::AgentError;
8use crate::loop_::AgentEvent;
9
10use super::Agent;
11use super::queueing::drain_messages_from_queue;
12
13fn invalid_state_snapshot(error: &serde_json::Error) -> std::io::Error {
14    std::io::Error::new(
15        std::io::ErrorKind::InvalidData,
16        format!("corrupted session state snapshot: {error}"),
17    )
18}
19
20fn restore_session_state(
21    snapshot: Option<&serde_json::Value>,
22) -> Result<crate::SessionState, std::io::Error> {
23    snapshot.map_or_else(
24        || Ok(crate::SessionState::new()),
25        |state_val| {
26            crate::SessionState::restore_from_snapshot(state_val.clone())
27                .map_err(|e| invalid_state_snapshot(&e))
28        },
29    )
30}
31
32impl Agent {
33    /// Rebind `self.stream_fn` if the current model's `provider`/`model_id`
34    /// matches one of the registered `model_stream_fns`.
35    fn rebind_stream_fn_for_current_model(&mut self) {
36        if let Some((_, stream_fn)) = self.model_stream_fns.iter().find(|(m, _)| {
37            m.provider == self.state.model.provider && m.model_id == self.state.model.model_id
38        }) {
39            self.stream_fn = std::sync::Arc::clone(stream_fn);
40        }
41    }
42
43    // ── Checkpointing ────────────────────────────────────────────────────
44
45    /// Create a checkpoint of the current agent state.
46    ///
47    /// If a [`CheckpointStore`] is configured, the checkpoint is also persisted.
48    /// Returns the checkpoint regardless of whether a store is configured.
49    pub async fn save_checkpoint(
50        &self,
51        id: impl Into<String>,
52    ) -> Result<Checkpoint, std::io::Error> {
53        let mut checkpoint = Checkpoint::new(
54            id,
55            &self.state.system_prompt,
56            &self.state.model.provider,
57            &self.state.model.model_id,
58            &self.state.messages,
59        );
60
61        {
62            let s = self
63                .session_state
64                .read()
65                .unwrap_or_else(std::sync::PoisonError::into_inner);
66            if !s.is_empty() {
67                checkpoint.state = Some(s.snapshot());
68            }
69        }
70
71        if let Some(ref store) = self.checkpoint_store {
72            store.save_checkpoint(checkpoint.clone()).await?;
73        }
74
75        Ok(checkpoint)
76    }
77
78    fn ensure_idle_for_checkpoint_restore(&mut self) -> Result<(), std::io::Error> {
79        self.check_not_running().map_err(|_| {
80            std::io::Error::new(
81                std::io::ErrorKind::WouldBlock,
82                "cannot restore checkpoint while agent is running",
83            )
84        })
85    }
86
87    /// Restore agent message history from a checkpoint.
88    ///
89    /// Replaces the current messages with those from the checkpoint and
90    /// updates the system prompt to match. If the checkpoint's model
91    /// matches one of the [`available_models`](crate::AgentOptions::with_available_models),
92    /// the stream function is rebound automatically; otherwise the current
93    /// stream function is left in place. Persisted custom messages are
94    /// restored when a [`CustomMessageRegistry`](crate::types::CustomMessageRegistry)
95    /// has been configured on [`AgentOptions`](crate::AgentOptions) via
96    /// [`with_custom_message_registry`](crate::AgentOptions::with_custom_message_registry);
97    /// otherwise they are dropped. Returns [`std::io::ErrorKind::WouldBlock`]
98    /// if a loop is still active; callers must wait for the agent to become
99    /// idle before restoring a checkpoint into it.
100    pub fn restore_from_checkpoint(
101        &mut self,
102        checkpoint: &Checkpoint,
103    ) -> Result<(), std::io::Error> {
104        self.ensure_idle_for_checkpoint_restore()?;
105        let restored_messages =
106            checkpoint.restore_messages(self.custom_message_registry.as_deref());
107        let restored_state = restore_session_state(checkpoint.state.as_ref())?;
108
109        self.clear_transient_runtime_state();
110        self.state.messages = restored_messages;
111        self.state
112            .system_prompt
113            .clone_from(&checkpoint.system_prompt);
114        self.state.model.provider.clone_from(&checkpoint.provider);
115        self.state.model.model_id.clone_from(&checkpoint.model_id);
116        self.rebind_stream_fn_for_current_model();
117        *self
118            .session_state
119            .write()
120            .unwrap_or_else(std::sync::PoisonError::into_inner) = restored_state;
121
122        Ok(())
123    }
124
125    /// Load a checkpoint from the configured store and restore state from it.
126    ///
127    /// Returns the loaded checkpoint, or `None` if not found.
128    /// Returns an error if no checkpoint store is configured. Returns
129    /// [`std::io::ErrorKind::WouldBlock`] if the agent is still running.
130    pub async fn load_and_restore_checkpoint(
131        &mut self,
132        id: &str,
133    ) -> Result<Option<Checkpoint>, std::io::Error> {
134        self.ensure_idle_for_checkpoint_restore()?;
135        let store = self
136            .checkpoint_store
137            .as_ref()
138            .ok_or_else(|| std::io::Error::other("no checkpoint store configured"))?;
139
140        let maybe = store.load_checkpoint(id).await?;
141        if let Some(ref checkpoint) = maybe {
142            self.restore_from_checkpoint(checkpoint)?;
143        }
144        Ok(maybe)
145    }
146
147    /// Access the checkpoint store, if configured.
148    #[must_use]
149    pub fn checkpoint_store(&self) -> Option<&dyn CheckpointStore> {
150        self.checkpoint_store.as_deref()
151    }
152
153    /// Pause the currently running loop and capture its state as a [`crate::checkpoint::LoopCheckpoint`].
154    ///
155    /// Signals the loop to stop via the cancellation token and snapshots the
156    /// agent's messages, system prompt, and queued LLM messages into a serializable
157    /// checkpoint. The checkpoint can later be passed to [`resume`](Self::resume)
158    /// to continue the loop from where it left off.
159    ///
160    /// The agent remains in the *running* state after this call. It becomes idle
161    /// when the caller either drains the event stream to completion or drops the
162    /// stream returned by [`prompt_stream`](Self::prompt_stream). This prevents a
163    /// new run from starting while the previous loop is still tearing down.
164    ///
165    /// Returns `None` if the agent is not currently running.
166    pub fn pause(&mut self) -> Option<crate::checkpoint::LoopCheckpoint> {
167        if !self.loop_active.load(Ordering::Acquire) {
168            return None;
169        }
170
171        if let Some(ref token) = self.abort_controller {
172            tracing::info!("pausing agent loop");
173            token.cancel();
174        }
175
176        let mut pending_messages = self.pending_message_snapshot.snapshot();
177        pending_messages.extend(drain_messages_from_queue(&self.follow_up_queue));
178
179        // Prefer the loop_context_snapshot when available: it is updated
180        // immediately after pending messages are drained into loop-local
181        // context_messages, closing the window where a concurrent pause() would
182        // miss those messages (they've left the shared pending queue but haven't
183        // yet been delivered back via a TurnEnd event that updates
184        // in_flight_messages).
185        let loop_ctx = self.loop_context_snapshot.snapshot();
186        let checkpoint_messages: &[crate::types::AgentMessage] = if let Some(ref ctx) = loop_ctx {
187            ctx.as_slice()
188        } else {
189            self.in_flight_messages
190                .as_deref()
191                .unwrap_or(&self.state.messages)
192        };
193
194        let mut checkpoint = crate::checkpoint::LoopCheckpoint::new(
195            &self.state.system_prompt,
196            &self.state.model.provider,
197            &self.state.model.model_id,
198            checkpoint_messages,
199        )
200        .with_pending_message_batch(&pending_messages)
201        .with_pending_steering_message_batch(&drain_messages_from_queue(&self.steering_queue));
202
203        let s = self
204            .session_state
205            .read()
206            .unwrap_or_else(std::sync::PoisonError::into_inner);
207        if !s.is_empty() {
208            checkpoint.state = Some(s.snapshot());
209        }
210        drop(s);
211
212        // Do NOT clear is_running / abort_controller / notify idle here.
213        // The agent stays "running" until the LoopGuardStream is dropped or
214        // the stream is drained to AgentEnd, which guarantees the spawned loop
215        // task has finished using the channel before a new run can start.
216
217        Some(checkpoint)
218    }
219
220    /// Resume the agent loop from a previously captured [`crate::checkpoint::LoopCheckpoint`].
221    pub async fn resume(
222        &mut self,
223        checkpoint: &crate::checkpoint::LoopCheckpoint,
224    ) -> Result<crate::types::AgentResult, AgentError> {
225        self.check_not_running()?;
226        self.restore_from_loop_checkpoint(checkpoint)?;
227        self.continue_async().await
228    }
229
230    /// Resume the agent loop from a checkpoint, returning an event stream.
231    pub fn resume_stream(
232        &mut self,
233        checkpoint: &crate::checkpoint::LoopCheckpoint,
234    ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
235        self.check_not_running()?;
236        self.restore_from_loop_checkpoint(checkpoint)?;
237        self.continue_stream()
238    }
239
240    fn restore_from_loop_checkpoint(
241        &mut self,
242        checkpoint: &crate::checkpoint::LoopCheckpoint,
243    ) -> Result<(), AgentError> {
244        let restored_messages =
245            checkpoint.restore_messages(self.custom_message_registry.as_deref());
246        if restored_messages.is_empty() {
247            return Err(AgentError::NoMessages);
248        }
249        let restored_state =
250            restore_session_state(checkpoint.state.as_ref()).map_err(AgentError::stream)?;
251
252        self.clear_transient_runtime_state();
253        self.state.messages = restored_messages;
254        self.state
255            .system_prompt
256            .clone_from(&checkpoint.system_prompt);
257        self.state.model.provider.clone_from(&checkpoint.provider);
258        self.state.model.model_id.clone_from(&checkpoint.model_id);
259        self.rebind_stream_fn_for_current_model();
260        {
261            let mut s = self
262                .session_state
263                .write()
264                .unwrap_or_else(std::sync::PoisonError::into_inner);
265            *s = restored_state;
266        }
267
268        // Clear live queues before re-enqueueing from the checkpoint so that
269        // an in-process pause→resume cycle does not duplicate pending work.
270        self.clear_queues();
271
272        for msg in checkpoint.restore_pending_messages(self.custom_message_registry.as_deref()) {
273            self.follow_up(msg);
274        }
275        for msg in
276            checkpoint.restore_pending_steering_messages(self.custom_message_registry.as_deref())
277        {
278            self.steer(msg);
279        }
280
281        tracing::info!(
282            messages = self.state.messages.len(),
283            "resuming agent loop from checkpoint"
284        );
285
286        Ok(())
287    }
288}
289
290#[cfg(all(test, feature = "testkit"))]
291mod tests {
292    use std::collections::HashMap;
293    use std::sync::Arc;
294    use std::sync::Mutex;
295
296    use tokio_util::sync::CancellationToken;
297
298    use crate::agent::Agent;
299    use crate::agent_options::AgentOptions;
300    use crate::checkpoint::{CheckpointFuture, CheckpointStore, LoopCheckpoint};
301    use crate::testing::SimpleMockStreamFn;
302    use crate::types::{
303        AgentMessage, CustomMessage, CustomMessageRegistry, LlmMessage, ModelSpec, UserMessage,
304    };
305    use crate::{AgentError, Checkpoint};
306
307    #[derive(Debug, Clone, PartialEq)]
308    struct Tagged {
309        value: String,
310    }
311
312    impl CustomMessage for Tagged {
313        fn as_any(&self) -> &dyn std::any::Any {
314            self
315        }
316        fn type_name(&self) -> Option<&str> {
317            Some("Tagged")
318        }
319        fn to_json(&self) -> Option<serde_json::Value> {
320            Some(serde_json::json!({ "value": self.value }))
321        }
322        fn clone_box(&self) -> Option<Box<dyn CustomMessage>> {
323            Some(Box::new(self.clone()))
324        }
325    }
326
327    fn tagged_registry() -> CustomMessageRegistry {
328        let mut reg = CustomMessageRegistry::new();
329        reg.register(
330            "Tagged",
331            Box::new(|val: serde_json::Value| {
332                let value = val
333                    .get("value")
334                    .and_then(|v| v.as_str())
335                    .ok_or_else(|| "missing value".to_string())?;
336                Ok(Box::new(Tagged {
337                    value: value.to_string(),
338                }) as Box<dyn CustomMessage>)
339            }),
340        );
341        reg
342    }
343
344    fn make_agent(registry: Option<CustomMessageRegistry>) -> Agent {
345        let stream_fn = Arc::new(SimpleMockStreamFn::from_text("ok"));
346        let mut opts =
347            AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn);
348        if let Some(reg) = registry {
349            opts = opts.with_custom_message_registry(reg);
350        }
351        Agent::new(opts)
352    }
353
354    fn user_msg(text: &str) -> AgentMessage {
355        AgentMessage::Llm(LlmMessage::User(UserMessage {
356            content: vec![crate::types::ContentBlock::Text {
357                text: text.to_string(),
358            }],
359            timestamp: 0,
360            cache_hint: None,
361        }))
362    }
363
364    fn seed_transient_runtime_state(agent: &mut Agent) {
365        agent.state.is_running = true;
366        agent.state.stream_message = Some(user_msg("streaming"));
367        agent
368            .state
369            .pending_tool_calls
370            .insert("tool-call-1".to_string());
371        agent.state.error = Some("stale error".to_string());
372        agent.abort_controller = Some(CancellationToken::new());
373        agent.in_flight_llm_messages = Some(vec![user_msg("in-flight-llm")]);
374        agent.in_flight_messages = Some(vec![user_msg("in-flight-checkpoint")]);
375    }
376
377    #[derive(Default)]
378    struct TestCheckpointStore {
379        data: Mutex<HashMap<String, String>>,
380    }
381
382    impl CheckpointStore for TestCheckpointStore {
383        fn save_checkpoint(&self, checkpoint: Checkpoint) -> CheckpointFuture<'_, ()> {
384            let json = serde_json::to_string(&checkpoint).unwrap();
385            let id = checkpoint.id;
386            Box::pin(async move {
387                self.data
388                    .lock()
389                    .unwrap_or_else(std::sync::PoisonError::into_inner)
390                    .insert(id, json);
391                Ok(())
392            })
393        }
394
395        fn load_checkpoint(&self, id: &str) -> CheckpointFuture<'_, Option<Checkpoint>> {
396            let id = id.to_string();
397            Box::pin(async move {
398                self.data
399                    .lock()
400                    .unwrap_or_else(std::sync::PoisonError::into_inner)
401                    .get(&id)
402                    .map(|json| serde_json::from_str(json).map_err(std::io::Error::other))
403                    .transpose()
404            })
405        }
406
407        fn list_checkpoints(&self) -> CheckpointFuture<'_, Vec<String>> {
408            Box::pin(async move {
409                Ok(self
410                    .data
411                    .lock()
412                    .unwrap_or_else(std::sync::PoisonError::into_inner)
413                    .keys()
414                    .cloned()
415                    .collect())
416            })
417        }
418
419        fn delete_checkpoint(&self, id: &str) -> CheckpointFuture<'_, ()> {
420            let id = id.to_string();
421            Box::pin(async move {
422                self.data
423                    .lock()
424                    .unwrap_or_else(std::sync::PoisonError::into_inner)
425                    .remove(&id);
426                Ok(())
427            })
428        }
429    }
430
431    #[tokio::test]
432    async fn restore_from_checkpoint_rehydrates_custom_messages_via_registry() {
433        let mut source = make_agent(None);
434        source
435            .state
436            .messages
437            .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
438                content: vec![crate::types::ContentBlock::Text {
439                    text: "hi".to_string(),
440                }],
441                timestamp: 0,
442                cache_hint: None,
443            })));
444        source
445            .state
446            .messages
447            .push(AgentMessage::Custom(Box::new(Tagged {
448                value: "preserved".to_string(),
449            })));
450
451        let checkpoint = source.save_checkpoint("cp-1").await.unwrap();
452        let json = serde_json::to_string(&checkpoint).unwrap();
453        let loaded: crate::checkpoint::Checkpoint = serde_json::from_str(&json).unwrap();
454        assert_eq!(loaded.custom_messages.len(), 1);
455
456        // Without a registry the custom message is dropped (legacy behavior).
457        let mut no_reg = make_agent(None);
458        no_reg.restore_from_checkpoint(&loaded).unwrap();
459        assert_eq!(no_reg.state.messages.len(), 1);
460
461        // With a registry configured on AgentOptions, the custom message
462        // survives restoration through the public API.
463        let mut with_reg = make_agent(Some(tagged_registry()));
464        with_reg.restore_from_checkpoint(&loaded).unwrap();
465        assert_eq!(with_reg.state.messages.len(), 2);
466        let restored = with_reg.state.messages[1]
467            .downcast_ref::<Tagged>()
468            .expect("custom message should be restored via registry");
469        assert_eq!(restored.value, "preserved");
470    }
471
472    #[tokio::test]
473    async fn pause_captures_both_steering_and_follow_up_queues() {
474        use crate::types::ContentBlock;
475
476        let mut agent = make_agent(None);
477        // Give the agent a message so it's valid to resume
478        agent
479            .state
480            .messages
481            .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
482                content: vec![ContentBlock::Text {
483                    text: "hi".to_string(),
484                }],
485                timestamp: 0,
486                cache_hint: None,
487            })));
488
489        // Queue a steering message and a follow-up message
490        agent.steer(AgentMessage::Llm(LlmMessage::User(UserMessage {
491            content: vec![ContentBlock::Text {
492                text: "steering-msg".to_string(),
493            }],
494            timestamp: 1,
495            cache_hint: None,
496        })));
497        agent.follow_up(AgentMessage::Llm(LlmMessage::User(UserMessage {
498            content: vec![ContentBlock::Text {
499                text: "followup-msg".to_string(),
500            }],
501            timestamp: 2,
502            cache_hint: None,
503        })));
504
505        // Simulate a running loop so pause() doesn't return None
506        agent
507            .loop_active
508            .store(true, std::sync::atomic::Ordering::Release);
509
510        let checkpoint = agent.pause().expect("agent should be running");
511
512        // Verify both queues are captured separately
513        assert_eq!(
514            checkpoint.pending_messages.len(),
515            1,
516            "follow-up queue should be captured"
517        );
518        assert_eq!(
519            checkpoint.pending_steering_messages.len(),
520            1,
521            "steering queue should be captured"
522        );
523
524        // Verify the content is correct
525        match &checkpoint.pending_messages[0] {
526            LlmMessage::User(u) => match &u.content[0] {
527                ContentBlock::Text { text } => assert_eq!(text, "followup-msg"),
528                _ => panic!("expected text content"),
529            },
530            _ => panic!("expected user message"),
531        }
532        match &checkpoint.pending_steering_messages[0] {
533            LlmMessage::User(u) => match &u.content[0] {
534                ContentBlock::Text { text } => assert_eq!(text, "steering-msg"),
535                _ => panic!("expected text content"),
536            },
537            _ => panic!("expected user message"),
538        }
539
540        // After pause, live queues must be drained (#337).
541        assert!(
542            !agent.has_pending_messages(),
543            "queues should be empty after pause drains them"
544        );
545    }
546
547    #[tokio::test]
548    async fn restore_from_loop_checkpoint_routes_steering_to_steering_queue() {
549        use crate::checkpoint::LoopCheckpoint;
550        use crate::types::ContentBlock;
551
552        let messages = vec![AgentMessage::Llm(LlmMessage::User(UserMessage {
553            content: vec![ContentBlock::Text {
554                text: "hi".to_string(),
555            }],
556            timestamp: 0,
557            cache_hint: None,
558        }))];
559
560        let cp = LoopCheckpoint::new("system", "mock", "mock-model", &messages)
561            .with_pending_messages(vec![LlmMessage::User(UserMessage {
562                content: vec![ContentBlock::Text {
563                    text: "followup".to_string(),
564                }],
565                timestamp: 1,
566                cache_hint: None,
567            })])
568            .with_pending_steering_messages(vec![LlmMessage::User(UserMessage {
569                content: vec![ContentBlock::Text {
570                    text: "steering".to_string(),
571                }],
572                timestamp: 2,
573                cache_hint: None,
574            })]);
575
576        let mut agent = make_agent(None);
577        agent.restore_from_loop_checkpoint(&cp).unwrap();
578
579        // Verify steering went to steering queue, follow-up to follow-up queue
580        let steering = agent.steering_queue.lock().unwrap();
581        let follow_up = agent.follow_up_queue.lock().unwrap();
582
583        assert_eq!(steering.len(), 1, "steering queue should have 1 message");
584        assert_eq!(follow_up.len(), 1, "follow-up queue should have 1 message");
585
586        match &steering[0] {
587            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
588                ContentBlock::Text { text } => assert_eq!(text, "steering"),
589                _ => panic!("expected text"),
590            },
591            _ => panic!("expected user message in steering queue"),
592        }
593        match &follow_up[0] {
594            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
595                ContentBlock::Text { text } => assert_eq!(text, "followup"),
596                _ => panic!("expected text"),
597            },
598            _ => panic!("expected user message in follow-up queue"),
599        }
600    }
601
602    /// Regression test for #337: pause then resume must not duplicate queued
603    /// messages.  Before the fix, `pause()` snapshotted the queues without
604    /// draining them, and `restore_from_loop_checkpoint()` re-enqueued the
605    /// same entries on top of the still-populated live queues.
606    #[tokio::test]
607    async fn pause_drains_queues_so_resume_does_not_duplicate() {
608        use crate::types::ContentBlock;
609
610        let mut agent = make_agent(None);
611        agent
612            .state
613            .messages
614            .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
615                content: vec![ContentBlock::Text {
616                    text: "hi".to_string(),
617                }],
618                timestamp: 0,
619                cache_hint: None,
620            })));
621
622        // Enqueue one steering and one follow-up message.
623        agent.steer(AgentMessage::Llm(LlmMessage::User(UserMessage {
624            content: vec![ContentBlock::Text {
625                text: "steering-1".to_string(),
626            }],
627            timestamp: 1,
628            cache_hint: None,
629        })));
630        agent.follow_up(AgentMessage::Llm(LlmMessage::User(UserMessage {
631            content: vec![ContentBlock::Text {
632                text: "followup-1".to_string(),
633            }],
634            timestamp: 2,
635            cache_hint: None,
636        })));
637
638        // Simulate a running loop so pause() doesn't return None.
639        agent
640            .loop_active
641            .store(true, std::sync::atomic::Ordering::Release);
642
643        let checkpoint = agent.pause().expect("agent should be running");
644
645        // After pause, live queues must be empty (drained into checkpoint).
646        assert!(
647            !agent.has_pending_messages(),
648            "queues should be drained after pause"
649        );
650
651        // Restore from the checkpoint — queues should have exactly 1 each.
652        agent
653            .loop_active
654            .store(false, std::sync::atomic::Ordering::Release);
655        agent.restore_from_loop_checkpoint(&checkpoint).unwrap();
656
657        let steering = agent.steering_queue.lock().unwrap();
658        let follow_up = agent.follow_up_queue.lock().unwrap();
659
660        assert_eq!(
661            steering.len(),
662            1,
663            "steering queue should have exactly 1 message, not duplicated"
664        );
665        assert_eq!(
666            follow_up.len(),
667            1,
668            "follow-up queue should have exactly 1 message, not duplicated"
669        );
670    }
671
672    #[tokio::test]
673    async fn pause_and_resume_preserves_serializable_custom_pending_messages() {
674        use crate::types::ContentBlock;
675
676        let mut agent = make_agent(Some(tagged_registry()));
677        agent
678            .state
679            .messages
680            .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
681                content: vec![ContentBlock::Text {
682                    text: "hi".to_string(),
683                }],
684                timestamp: 0,
685                cache_hint: None,
686            })));
687
688        agent.follow_up(AgentMessage::Llm(LlmMessage::User(UserMessage {
689            content: vec![ContentBlock::Text {
690                text: "followup-1".to_string(),
691            }],
692            timestamp: 1,
693            cache_hint: None,
694        })));
695        agent.follow_up(AgentMessage::Custom(Box::new(Tagged {
696            value: "followup-custom".to_string(),
697        })));
698        agent.steer(AgentMessage::Custom(Box::new(Tagged {
699            value: "steering-custom".to_string(),
700        })));
701        agent.steer(AgentMessage::Llm(LlmMessage::User(UserMessage {
702            content: vec![ContentBlock::Text {
703                text: "steering-1".to_string(),
704            }],
705            timestamp: 2,
706            cache_hint: None,
707        })));
708
709        agent
710            .loop_active
711            .store(true, std::sync::atomic::Ordering::Release);
712
713        let checkpoint = agent.pause().expect("agent should be running");
714        assert!(
715            !agent.has_pending_messages(),
716            "queues should be drained after pause"
717        );
718
719        let json = serde_json::to_string(&checkpoint).unwrap();
720        let loaded: LoopCheckpoint = serde_json::from_str(&json).unwrap();
721
722        agent
723            .loop_active
724            .store(false, std::sync::atomic::Ordering::Release);
725        agent.restore_from_loop_checkpoint(&loaded).unwrap();
726
727        let steering = agent.steering_queue.lock().unwrap();
728        let follow_up = agent.follow_up_queue.lock().unwrap();
729
730        assert_eq!(
731            follow_up.len(),
732            2,
733            "follow-up queue should keep mixed messages"
734        );
735        assert_eq!(
736            steering.len(),
737            2,
738            "steering queue should keep mixed messages"
739        );
740
741        match &follow_up[0] {
742            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
743                ContentBlock::Text { text } => assert_eq!(text, "followup-1"),
744                _ => panic!("expected text content"),
745            },
746            _ => panic!("expected llm follow-up message"),
747        }
748        let follow_up_custom = follow_up[1]
749            .downcast_ref::<Tagged>()
750            .expect("custom follow-up should be restored");
751        assert_eq!(follow_up_custom.value, "followup-custom");
752
753        let steering_custom = steering[0]
754            .downcast_ref::<Tagged>()
755            .expect("custom steering should be restored");
756        assert_eq!(steering_custom.value, "steering-custom");
757        match &steering[1] {
758            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
759                ContentBlock::Text { text } => assert_eq!(text, "steering-1"),
760                _ => panic!("expected text content"),
761            },
762            _ => panic!("expected llm steering message"),
763        }
764    }
765
766    #[tokio::test]
767    async fn pause_captures_messages_already_moved_into_loop_local_pending_state() {
768        let mut agent = make_agent(Some(tagged_registry()));
769        agent.state.messages.push(user_msg("hi"));
770        agent.pending_message_snapshot.replace(&[
771            AgentMessage::Llm(LlmMessage::User(UserMessage {
772                content: vec![crate::types::ContentBlock::Text {
773                    text: "polled-follow-up".to_string(),
774                }],
775                timestamp: 1,
776                cache_hint: None,
777            })),
778            AgentMessage::Custom(Box::new(Tagged {
779                value: "polled-custom".to_string(),
780            })),
781        ]);
782
783        agent
784            .loop_active
785            .store(true, std::sync::atomic::Ordering::Release);
786
787        let checkpoint = agent.pause().expect("agent should be running");
788        let pending = checkpoint.restore_pending_messages(agent.custom_message_registry.as_deref());
789
790        assert_eq!(
791            pending.len(),
792            2,
793            "pause should include loop-local pending messages even when the shared queue is already empty"
794        );
795        match &pending[0] {
796            AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
797                crate::types::ContentBlock::Text { text } => {
798                    assert_eq!(text, "polled-follow-up");
799                }
800                other => panic!("expected text content, got {other:?}"),
801            },
802            other => panic!("expected user message, got {other:?}"),
803        }
804        let restored_custom = pending[1]
805            .downcast_ref::<Tagged>()
806            .expect("custom pending message should be preserved");
807        assert_eq!(restored_custom.value, "polled-custom");
808    }
809
    /// Pausing while a prompt stream is still open must checkpoint the full
    /// conversation: the pre-existing custom history entry followed by the
    /// newly prompted user message.
    #[tokio::test]
    async fn pause_preserves_in_flight_custom_messages_during_streamed_runs() {
        use futures::future::pending;

        // Stream function whose only event awaits a never-resolving future, so
        // it never yields anything and the run cannot finish on its own.
        struct PendingStreamFn;

        impl crate::stream::StreamFn for PendingStreamFn {
            fn stream<'a>(
                &'a self,
                _model: &'a crate::ModelSpec,
                _context: &'a crate::AgentContext,
                _options: &'a crate::StreamOptions,
                _cancellation_token: tokio_util::sync::CancellationToken,
            ) -> std::pin::Pin<
                Box<dyn futures::Stream<Item = crate::AssistantMessageEvent> + Send + 'a>,
            > {
                Box::pin(futures::stream::once(async {
                    pending::<()>().await;
                    // Never reached: the `pending` future above never completes.
                    crate::AssistantMessageEvent::error("unreachable")
                }))
            }
        }

        let stream_fn = Arc::new(PendingStreamFn);
        let opts =
            AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn)
                .with_custom_message_registry(tagged_registry());
        let mut agent = Agent::new(opts);
        // Seed history with a custom message that the checkpoint must round-trip.
        agent
            .state
            .messages
            .push(AgentMessage::Custom(Box::new(Tagged {
                value: "history-custom".to_string(),
            })));

        // Keep the stream alive so the agent is still running when paused.
        let _stream = agent.prompt_stream(vec![user_msg("start")]).unwrap();
        let checkpoint = agent.pause().expect("agent should be running");
        // Restore using the agent's registry so the custom entry can be downcast.
        let restored = checkpoint.restore_messages(agent.custom_message_registry.as_deref());

        // Expected order: existing custom history, then the prompted message.
        assert_eq!(
            restored.len(),
            2,
            "pause should keep custom history in checkpoint"
        );

        let restored_custom = restored[0]
            .downcast_ref::<Tagged>()
            .expect("custom history should be restored from the paused checkpoint");
        assert_eq!(restored_custom.value, "history-custom");

        match &restored[1] {
            AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
                crate::types::ContentBlock::Text { text } => assert_eq!(text, "start"),
                other => panic!("expected text content, got {other:?}"),
            },
            other => panic!("expected user message, got {other:?}"),
        }
    }
868
869    #[tokio::test]
870    async fn restore_from_checkpoint_rebinds_stream_fn_for_matching_model() {
871        use crate::stream::StreamFn;
872        use crate::types::ContentBlock;
873
874        let model_a = ModelSpec::new("provider-a", "model-a");
875        let model_b = ModelSpec::new("provider-b", "model-b");
876        let stream_a = Arc::new(SimpleMockStreamFn::from_text("from-a"));
877        let stream_b = Arc::new(SimpleMockStreamFn::from_text("from-b"));
878
879        // Agent starts on model_a, with model_b registered as available.
880        let opts = AgentOptions::new_simple("system", model_a.clone(), stream_a.clone())
881            .with_available_models(vec![(model_b.clone(), stream_b.clone())]);
882        let mut agent = Agent::new(opts);
883
884        // Confirm initial stream_fn points to stream_a.
885        assert!(
886            Arc::ptr_eq(&agent.stream_fn, &(stream_a.clone() as Arc<dyn StreamFn>)),
887            "initial stream_fn should be stream_a"
888        );
889
890        // Build a checkpoint from a source agent that uses model_b.
891        let source_opts = AgentOptions::new_simple("system", model_b.clone(), stream_b.clone());
892        let mut source = Agent::new(source_opts);
893        source
894            .state
895            .messages
896            .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
897                content: vec![ContentBlock::Text {
898                    text: "hello".to_string(),
899                }],
900                timestamp: 0,
901                cache_hint: None,
902            })));
903        let checkpoint = source.save_checkpoint("cp-rebind").await.unwrap();
904
905        // Restore into agent (currently on model_a).
906        agent.restore_from_checkpoint(&checkpoint).unwrap();
907
908        // Model metadata should reflect model_b.
909        assert_eq!(agent.state.model.provider, "provider-b");
910        assert_eq!(agent.state.model.model_id, "model-b");
911
912        // Stream function should now be rebound to stream_b.
913        assert!(
914            Arc::ptr_eq(&agent.stream_fn, &(stream_b.clone() as Arc<dyn StreamFn>)),
915            "stream_fn should be rebound to stream_b after checkpoint restore"
916        );
917    }
918
919    #[tokio::test]
920    async fn restore_from_checkpoint_clears_transient_runtime_state() {
921        let mut source = make_agent(None);
922        source.state.messages.push(user_msg("restored"));
923        let checkpoint = source.save_checkpoint("cp-clear-runtime").await.unwrap();
924
925        let mut agent = make_agent(None);
926        seed_transient_runtime_state(&mut agent);
927
928        agent.restore_from_checkpoint(&checkpoint).unwrap();
929
930        assert!(!agent.state.is_running);
931        assert!(agent.state.stream_message.is_none());
932        assert!(agent.state.pending_tool_calls.is_empty());
933        assert!(agent.state.error.is_none());
934        assert!(agent.abort_controller.is_none());
935        assert!(agent.in_flight_llm_messages.is_none());
936        assert!(agent.in_flight_messages.is_none());
937    }
938
    /// `restore_from_checkpoint` must refuse to run while a prompt stream is
    /// active: it fails with `WouldBlock` and the agent keeps running.
    #[tokio::test]
    async fn restore_from_checkpoint_rejects_restore_while_running() {
        let mut source = make_agent(None);
        source.state.messages.push(user_msg("restored"));
        let checkpoint = source.save_checkpoint("cp-running-guard").await.unwrap();

        let mut agent = make_agent(None);
        // While the returned stream is alive the agent reports `is_running()`.
        let stream = agent.prompt_stream(vec![user_msg("hi")]).unwrap();

        let err = agent.restore_from_checkpoint(&checkpoint).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
        assert!(
            err.to_string()
                .contains("cannot restore checkpoint while agent is running")
        );
        // The rejected restore must not have stopped the active run.
        assert!(agent.is_running());

        // Release the stream and wait for the agent to settle before the test ends.
        drop(stream);
        agent.wait_for_idle().await;
    }
959
960    #[tokio::test]
961    async fn restore_from_loop_checkpoint_rebinds_stream_fn_for_matching_model() {
962        use crate::checkpoint::LoopCheckpoint;
963        use crate::stream::StreamFn;
964        use crate::types::ContentBlock;
965
966        let model_a = ModelSpec::new("provider-a", "model-a");
967        let model_b = ModelSpec::new("provider-b", "model-b");
968        let stream_a = Arc::new(SimpleMockStreamFn::from_text("from-a"));
969        let stream_b = Arc::new(SimpleMockStreamFn::from_text("from-b"));
970
971        let opts = AgentOptions::new_simple("system", model_a.clone(), stream_a.clone())
972            .with_available_models(vec![(model_b.clone(), stream_b.clone())]);
973        let mut agent = Agent::new(opts);
974
975        assert!(
976            Arc::ptr_eq(&agent.stream_fn, &(stream_a.clone() as Arc<dyn StreamFn>)),
977            "initial stream_fn should be stream_a"
978        );
979
980        // Build a LoopCheckpoint for model_b.
981        let messages = vec![AgentMessage::Llm(LlmMessage::User(UserMessage {
982            content: vec![ContentBlock::Text {
983                text: "hello".to_string(),
984            }],
985            timestamp: 0,
986            cache_hint: None,
987        }))];
988        let cp = LoopCheckpoint::new("system", "provider-b", "model-b", &messages);
989
990        agent.restore_from_loop_checkpoint(&cp).unwrap();
991
992        assert_eq!(agent.state.model.provider, "provider-b");
993        assert_eq!(agent.state.model.model_id, "model-b");
994        assert!(
995            Arc::ptr_eq(&agent.stream_fn, &(stream_b.clone() as Arc<dyn StreamFn>)),
996            "stream_fn should be rebound to stream_b after loop checkpoint restore"
997        );
998    }
999
1000    #[tokio::test]
1001    async fn restore_from_loop_checkpoint_clears_transient_runtime_state() {
1002        let checkpoint = LoopCheckpoint::new("system", "mock", "mock-model", &[user_msg("hi")]);
1003        let mut agent = make_agent(None);
1004        seed_transient_runtime_state(&mut agent);
1005
1006        agent.restore_from_loop_checkpoint(&checkpoint).unwrap();
1007
1008        assert!(!agent.state.is_running);
1009        assert!(agent.state.stream_message.is_none());
1010        assert!(agent.state.pending_tool_calls.is_empty());
1011        assert!(agent.state.error.is_none());
1012        assert!(agent.abort_controller.is_none());
1013        assert!(agent.in_flight_llm_messages.is_none());
1014        assert!(agent.in_flight_messages.is_none());
1015    }
1016
1017    #[tokio::test]
1018    async fn loop_checkpoint_resume_rehydrates_custom_messages_via_registry() {
1019        let messages = vec![
1020            AgentMessage::Llm(LlmMessage::User(UserMessage {
1021                content: vec![crate::types::ContentBlock::Text {
1022                    text: "hi".to_string(),
1023                }],
1024                timestamp: 0,
1025                cache_hint: None,
1026            })),
1027            AgentMessage::Custom(Box::new(Tagged {
1028                value: "resumed".to_string(),
1029            })),
1030        ];
1031        let cp = LoopCheckpoint::new("system", "mock", "mock-model", &messages);
1032        let json = serde_json::to_string(&cp).unwrap();
1033        let loaded: LoopCheckpoint = serde_json::from_str(&json).unwrap();
1034
1035        let mut agent = make_agent(Some(tagged_registry()));
1036        agent.restore_from_loop_checkpoint(&loaded).unwrap();
1037        assert_eq!(agent.state.messages.len(), 2);
1038        let restored = agent.state.messages[1]
1039            .downcast_ref::<Tagged>()
1040            .expect("custom message should be restored via registry");
1041        assert_eq!(restored.value, "resumed");
1042    }
1043
1044    #[tokio::test]
1045    async fn load_and_restore_checkpoint_rejects_corrupt_state_snapshot() {
1046        let store = TestCheckpointStore::default();
1047        let checkpoint = Checkpoint::new(
1048            "bad-state",
1049            "system",
1050            "mock",
1051            "mock-model",
1052            &[user_msg("hi")],
1053        )
1054        .with_state(serde_json::json!(["bad"]));
1055        store.save_checkpoint(checkpoint).await.unwrap();
1056
1057        let stream_fn = Arc::new(SimpleMockStreamFn::from_text("ok"));
1058        let agent_options =
1059            AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn)
1060                .with_checkpoint_store(store);
1061        let mut agent = Agent::new(agent_options);
1062
1063        let err = agent
1064            .load_and_restore_checkpoint("bad-state")
1065            .await
1066            .unwrap_err();
1067        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1068        assert!(err.to_string().contains("corrupted session state snapshot"));
1069    }
1070
    /// `load_and_restore_checkpoint` must refuse to apply a stored checkpoint
    /// while a prompt stream is active: it fails with `WouldBlock` and the
    /// agent keeps running.
    #[tokio::test]
    async fn load_and_restore_checkpoint_rejects_restore_while_running() {
        // Persist a valid checkpoint so the only reason to fail is the running agent.
        let store = TestCheckpointStore::default();
        let checkpoint = Checkpoint::new(
            "running-guard",
            "system",
            "mock",
            "mock-model",
            &[user_msg("hi")],
        );
        store.save_checkpoint(checkpoint).await.unwrap();

        let stream_fn = Arc::new(SimpleMockStreamFn::from_text("ok"));
        let agent_options =
            AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn)
                .with_checkpoint_store(store);
        let mut agent = Agent::new(agent_options);
        // While the returned stream is alive the agent reports `is_running()`.
        let stream = agent.prompt_stream(vec![user_msg("start")]).unwrap();

        let err = agent
            .load_and_restore_checkpoint("running-guard")
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
        assert!(
            err.to_string()
                .contains("cannot restore checkpoint while agent is running")
        );
        // The rejected restore must not have stopped the active run.
        assert!(agent.is_running());

        // Release the stream and wait for the agent to settle before the test ends.
        drop(stream);
        agent.wait_for_idle().await;
    }
1104
1105    #[tokio::test]
1106    async fn resume_rejects_corrupt_loop_checkpoint_state_snapshot() {
1107        let checkpoint = LoopCheckpoint::new("system", "mock", "mock-model", &[user_msg("hi")])
1108            .with_state(serde_json::json!(["bad"]));
1109        let mut agent = make_agent(None);
1110
1111        let err = agent.resume(&checkpoint).await.unwrap_err();
1112        match err {
1113            AgentError::StreamError { source } => {
1114                let io = source
1115                    .downcast_ref::<std::io::Error>()
1116                    .expect("expected io::Error source");
1117                assert_eq!(io.kind(), std::io::ErrorKind::InvalidData);
1118                assert!(io.to_string().contains("corrupted session state snapshot"));
1119            }
1120            other => panic!("expected StreamError, got {other:?}"),
1121        }
1122    }
1123
1124    #[tokio::test]
1125    async fn restore_from_checkpoint_keeps_live_state_when_snapshot_is_corrupt() {
1126        let checkpoint = Checkpoint::new(
1127            "bad-state",
1128            "restored-system",
1129            "restored",
1130            "restored-model",
1131            &[user_msg("restored")],
1132        )
1133        .with_state(serde_json::json!(["bad"]));
1134        let mut agent = make_agent(None);
1135        agent.state.messages.push(user_msg("existing"));
1136        agent.state.system_prompt = "live-system".to_string();
1137        agent.state.model = ModelSpec::new("live-provider", "live-model");
1138        {
1139            let mut state = agent
1140                .session_state()
1141                .write()
1142                .unwrap_or_else(std::sync::PoisonError::into_inner);
1143            state.set("live", 7_i64).unwrap();
1144        }
1145
1146        let err = agent.restore_from_checkpoint(&checkpoint).unwrap_err();
1147        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1148
1149        assert_eq!(agent.state.messages.len(), 1);
1150        match &agent.state.messages[0] {
1151            AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
1152                crate::types::ContentBlock::Text { text } => assert_eq!(text, "existing"),
1153                other => panic!("expected text content, got {other:?}"),
1154            },
1155            other => panic!("expected user message, got {other:?}"),
1156        }
1157        assert_eq!(agent.state.system_prompt, "live-system");
1158        assert_eq!(agent.state.model.provider, "live-provider");
1159        assert_eq!(agent.state.model.model_id, "live-model");
1160
1161        let state = agent
1162            .session_state()
1163            .read()
1164            .unwrap_or_else(std::sync::PoisonError::into_inner);
1165        assert_eq!(state.get::<i64>("live"), Some(7));
1166    }
1167
1168    #[tokio::test]
1169    async fn restore_from_loop_checkpoint_keeps_live_state_when_snapshot_is_corrupt() {
1170        let checkpoint = LoopCheckpoint::new(
1171            "restored-system",
1172            "restored",
1173            "restored-model",
1174            &[user_msg("restored")],
1175        )
1176        .with_state(serde_json::json!(["bad"]));
1177        let mut agent = make_agent(None);
1178        agent.state.messages.push(user_msg("existing"));
1179        agent.state.system_prompt = "live-system".to_string();
1180        agent.state.model = ModelSpec::new("live-provider", "live-model");
1181        agent.follow_up(user_msg("live-follow-up"));
1182        agent.steer(user_msg("live-steering"));
1183        {
1184            let mut state = agent
1185                .session_state()
1186                .write()
1187                .unwrap_or_else(std::sync::PoisonError::into_inner);
1188            state.set("live", 9_i64).unwrap();
1189        }
1190
1191        let err = agent.resume(&checkpoint).await.unwrap_err();
1192        match err {
1193            AgentError::StreamError { source } => {
1194                let io = source
1195                    .downcast_ref::<std::io::Error>()
1196                    .expect("expected io::Error source");
1197                assert_eq!(io.kind(), std::io::ErrorKind::InvalidData);
1198            }
1199            other => panic!("expected StreamError, got {other:?}"),
1200        }
1201
1202        assert_eq!(agent.state.messages.len(), 1);
1203        match &agent.state.messages[0] {
1204            AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
1205                crate::types::ContentBlock::Text { text } => assert_eq!(text, "existing"),
1206                other => panic!("expected text content, got {other:?}"),
1207            },
1208            other => panic!("expected user message, got {other:?}"),
1209        }
1210        assert_eq!(agent.state.system_prompt, "live-system");
1211        assert_eq!(agent.state.model.provider, "live-provider");
1212        assert_eq!(agent.state.model.model_id, "live-model");
1213
1214        let state = agent
1215            .session_state()
1216            .read()
1217            .unwrap_or_else(std::sync::PoisonError::into_inner);
1218        assert_eq!(state.get::<i64>("live"), Some(9));
1219        drop(state);
1220
1221        let follow_up = agent
1222            .follow_up_queue
1223            .lock()
1224            .unwrap_or_else(std::sync::PoisonError::into_inner);
1225        let steering = agent
1226            .steering_queue
1227            .lock()
1228            .unwrap_or_else(std::sync::PoisonError::into_inner);
1229        assert_eq!(
1230            follow_up.len(),
1231            1,
1232            "failed restore should not clear follow-up queue"
1233        );
1234        assert_eq!(
1235            steering.len(),
1236            1,
1237            "failed restore should not clear steering queue"
1238        );
1239    }
1240
1241    #[tokio::test]
1242    async fn restore_from_checkpoint_clears_session_state_when_snapshot_missing() {
1243        let mut source = make_agent(None);
1244        source
1245            .state
1246            .messages
1247            .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
1248                content: vec![crate::types::ContentBlock::Text {
1249                    text: "hi".to_string(),
1250                }],
1251                timestamp: 0,
1252                cache_hint: None,
1253            })));
1254
1255        let mut checkpoint = source.save_checkpoint("cp-empty-state").await.unwrap();
1256        checkpoint.state = None;
1257
1258        let mut agent = make_agent(None);
1259        {
1260            let mut state = agent
1261                .session_state()
1262                .write()
1263                .unwrap_or_else(std::sync::PoisonError::into_inner);
1264            state.set("stale", 42_i64).unwrap();
1265        }
1266
1267        agent.restore_from_checkpoint(&checkpoint).unwrap();
1268
1269        let state = agent
1270            .session_state()
1271            .read()
1272            .unwrap_or_else(std::sync::PoisonError::into_inner);
1273        assert!(
1274            state.is_empty(),
1275            "missing snapshot should clear stale state"
1276        );
1277    }
1278
1279    #[tokio::test]
1280    async fn restore_from_loop_checkpoint_clears_session_state_when_snapshot_missing() {
1281        use crate::checkpoint::LoopCheckpoint;
1282
1283        let messages = vec![AgentMessage::Llm(LlmMessage::User(UserMessage {
1284            content: vec![crate::types::ContentBlock::Text {
1285                text: "hi".to_string(),
1286            }],
1287            timestamp: 0,
1288            cache_hint: None,
1289        }))];
1290        let mut checkpoint = LoopCheckpoint::new("system", "mock", "mock-model", &messages);
1291        checkpoint.state = None;
1292
1293        let mut agent = make_agent(None);
1294        {
1295            let mut state = agent
1296                .session_state()
1297                .write()
1298                .unwrap_or_else(std::sync::PoisonError::into_inner);
1299            state.set("stale", 99_i64).unwrap();
1300        }
1301
1302        agent.restore_from_loop_checkpoint(&checkpoint).unwrap();
1303
1304        let state = agent
1305            .session_state()
1306            .read()
1307            .unwrap_or_else(std::sync::PoisonError::into_inner);
1308        assert!(
1309            state.is_empty(),
1310            "missing snapshot should clear stale state"
1311        );
1312    }
1313
1314    /// Regression test for issue #557: `run_single_turn` drains pending messages
1315    /// into loop-local `context_messages` and then clears the shared
1316    /// `pending_message_snapshot`. A concurrent `pause()` in that window would
1317    /// previously miss those messages. The fix syncs `context_messages` to
1318    /// `loop_context_snapshot` immediately after the drain, and `pause()` now
1319    /// prefers that snapshot over `in_flight_messages`.
1320    #[tokio::test]
1321    async fn pause_captures_messages_drained_from_pending_into_loop_context() {
1322        let mut agent = make_agent(None);
1323        // Simulate the agent being mid-turn: in_flight_messages holds the
1324        // original messages (before any pending drain), and loop_context_snapshot
1325        // holds the expanded context after the drain.
1326
1327        // in_flight_messages = original message only (set at loop start).
1328        agent.in_flight_messages = Some(vec![user_msg("original")]);
1329        // pending_message_snapshot is cleared (run_single_turn has already drained it).
1330        agent.pending_message_snapshot.clear();
1331        // loop_context_snapshot = original + consumed pending (synced just after drain).
1332        // replace() uses the internal clone_messages helper which handles AgentMessage variants.
1333        agent
1334            .loop_context_snapshot
1335            .replace(&[user_msg("original"), user_msg("consumed-pending")]);
1336
1337        agent
1338            .loop_active
1339            .store(true, std::sync::atomic::Ordering::Release);
1340
1341        let checkpoint = agent.pause().expect("agent should be paused");
1342        let restored = checkpoint.restore_messages(agent.custom_message_registry.as_deref());
1343
1344        assert_eq!(
1345            restored.len(),
1346            2,
1347            "pause snapshot must include messages already consumed from the pending queue \
1348             into loop context, not just in_flight_messages"
1349        );
1350        match &restored[0] {
1351            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
1352                crate::types::ContentBlock::Text { text } => {
1353                    assert_eq!(text, "original");
1354                }
1355                other => panic!("expected text content, got {other:?}"),
1356            },
1357            other => panic!("expected user message, got {other:?}"),
1358        }
1359        match &restored[1] {
1360            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
1361                crate::types::ContentBlock::Text { text } => {
1362                    assert_eq!(text, "consumed-pending");
1363                }
1364                other => panic!("expected text content, got {other:?}"),
1365            },
1366            other => panic!("expected user message, got {other:?}"),
1367        }
1368    }
1369
1370    /// When `loop_context_snapshot` is not set (loop has not yet started its
1371    /// first turn), `pause()` must fall back to `in_flight_messages` as before.
1372    #[tokio::test]
1373    async fn pause_falls_back_to_in_flight_messages_when_context_snapshot_absent() {
1374        let mut agent = make_agent(None);
1375
1376        // in_flight_messages = message set at loop start.
1377        agent.in_flight_messages = Some(vec![user_msg("in-flight")]);
1378        // loop_context_snapshot is empty (not yet set — pre-first-turn).
1379        // (default state after Agent::new)
1380
1381        agent
1382            .loop_active
1383            .store(true, std::sync::atomic::Ordering::Release);
1384
1385        let checkpoint = agent.pause().expect("agent should be paused");
1386        let restored = checkpoint.restore_messages(agent.custom_message_registry.as_deref());
1387
1388        assert_eq!(
1389            restored.len(),
1390            1,
1391            "pause must fall back to in_flight_messages when loop_context_snapshot is absent"
1392        );
1393        match &restored[0] {
1394            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
1395                crate::types::ContentBlock::Text { text } => {
1396                    assert_eq!(text, "in-flight");
1397                }
1398                other => panic!("expected text content, got {other:?}"),
1399            },
1400            other => panic!("expected user message, got {other:?}"),
1401        }
1402    }
1403}