1use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10
11use crate::types::message_codec::{self, MessageSlot};
12use crate::types::{AgentMessage, Cost, CustomMessageRegistry, LlmMessage, Usage};
13
14mod store;
15
16pub use store::{CheckpointFuture, CheckpointStore};
17
/// A serializable snapshot of an agent conversation.
///
/// Built with [`Checkpoint::new`] from a slice of [`AgentMessage`]s via
/// `message_codec::serialize_messages`:
/// - plain LLM messages go into `messages`;
/// - custom messages that the codec can encode go into `custom_messages` as JSON
///   (custom messages it cannot encode are skipped — see the tests);
/// - a private `message_order` table records how the two lists interleave, so that
///   [`Checkpoint::restore_messages`] can rebuild the original sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Caller-supplied identifier for this checkpoint.
    pub id: String,
    /// System prompt recorded with the checkpoint.
    pub system_prompt: String,
    /// Provider name recorded with the checkpoint.
    pub provider: String,
    /// Model identifier recorded with the checkpoint.
    pub model_id: String,
    /// LLM messages from the conversation history, in their relative order.
    pub messages: Vec<LlmMessage>,
    /// Custom messages encoded as JSON. In the tests each entry is an object
    /// with `"type"` and `"data"` keys.
    /// Omitted from the output when empty; defaults to empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub custom_messages: Vec<serde_json::Value>,
    /// How `messages` and `custom_messages` interleave in the original history.
    /// Omitted when empty; payloads written without it still deserialize.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    message_order: Vec<MessageSlot>,
    /// Turn counter; `0` unless set with [`Checkpoint::with_turn_count`].
    pub turn_count: usize,
    /// Token usage; [`Usage::default`] unless set with [`Checkpoint::with_usage`].
    pub usage: Usage,
    /// Cost; [`Cost::default`] unless set with [`Checkpoint::with_cost`].
    pub cost: Cost,
    /// Creation time taken from `crate::util::now_timestamp()`.
    /// NOTE(review): the unit is not visible here — presumably Unix time; confirm.
    pub created_at: u64,
    /// Arbitrary caller metadata.
    /// Omitted from the output when empty; defaults to empty when absent.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
    /// Optional opaque state blob attached with [`Checkpoint::with_state`].
    /// Omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub state: Option<serde_json::Value>,
}
61
62impl Checkpoint {
63 #[must_use]
72 pub fn new(
73 id: impl Into<String>,
74 system_prompt: impl Into<String>,
75 provider: impl Into<String>,
76 model_id: impl Into<String>,
77 messages: &[AgentMessage],
78 ) -> Self {
79 let serialized = message_codec::serialize_messages(messages, "checkpoint");
80
81 Self {
82 id: id.into(),
83 system_prompt: system_prompt.into(),
84 provider: provider.into(),
85 model_id: model_id.into(),
86 messages: serialized.llm_messages,
87 custom_messages: serialized.custom_messages,
88 message_order: serialized.message_order,
89 turn_count: 0,
90 usage: Usage::default(),
91 cost: Cost::default(),
92 created_at: crate::util::now_timestamp(),
93 metadata: HashMap::new(),
94 state: None,
95 }
96 }
97
98 #[must_use]
100 pub fn with_state(mut self, state: serde_json::Value) -> Self {
101 self.state = Some(state);
102 self
103 }
104
105 #[must_use]
107 pub const fn with_turn_count(mut self, turn_count: usize) -> Self {
108 self.turn_count = turn_count;
109 self
110 }
111
112 #[must_use]
114 pub fn with_usage(mut self, usage: Usage) -> Self {
115 self.usage = usage;
116 self
117 }
118
119 #[must_use]
121 pub fn with_cost(mut self, cost: Cost) -> Self {
122 self.cost = cost;
123 self
124 }
125
126 #[must_use]
128 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
129 self.metadata.insert(key.into(), value);
130 self
131 }
132
133 #[must_use]
142 pub fn restore_messages(&self, registry: Option<&CustomMessageRegistry>) -> Vec<AgentMessage> {
143 message_codec::restore_messages(
144 &self.messages,
145 &self.custom_messages,
146 &self.message_order,
147 registry,
148 "checkpoint",
149 )
150 }
151}
152
/// A serializable snapshot of the agent loop.
///
/// It holds the conversation history plus two pending batches, "pending" and
/// "pending steering". Each of the three message groups is encoded the same way
/// as in [`Checkpoint`]:
/// - LLM messages;
/// - JSON-encoded custom messages;
/// - a private ordering table.
///
/// [`LoopCheckpoint::to_checkpoint`] converts it into a [`Checkpoint`], dropping
/// both pending batches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoopCheckpoint {
    /// LLM messages from the conversation history.
    pub messages: Vec<LlmMessage>,
    /// Custom history messages encoded as JSON; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub custom_messages: Vec<serde_json::Value>,
    /// Interleaving of `messages` and `custom_messages`; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    message_order: Vec<MessageSlot>,
    /// LLM messages of the pending batch. Always serialized and, lacking
    /// `#[serde(default)]`, required when deserializing.
    pub pending_messages: Vec<LlmMessage>,
    /// Custom messages of the pending batch (JSON); omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_custom_messages: Vec<serde_json::Value>,
    /// Interleaving for the pending batch; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_message_order: Vec<MessageSlot>,
    /// LLM messages of the pending steering batch. Omitted when empty and
    /// defaults to empty, so payloads without this key still deserialize.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub pending_steering_messages: Vec<LlmMessage>,
    /// Custom messages of the pending steering batch (JSON); omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_steering_custom_messages: Vec<serde_json::Value>,
    /// Interleaving for the pending steering batch; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pending_steering_message_order: Vec<MessageSlot>,
    /// System prompt recorded with the checkpoint.
    pub system_prompt: String,
    /// Provider name recorded with the checkpoint.
    pub provider: String,
    /// Model identifier recorded with the checkpoint.
    pub model_id: String,
    /// Creation time taken from `crate::util::now_timestamp()`.
    /// NOTE(review): the unit is not visible here — presumably Unix time; confirm.
    pub created_at: u64,
    /// Arbitrary caller metadata; omitted when empty, defaults to empty when absent.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
    /// Optional opaque state blob; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub state: Option<serde_json::Value>,
}
207
208impl LoopCheckpoint {
209 #[must_use]
214 pub fn new(
215 system_prompt: impl Into<String>,
216 provider: impl Into<String>,
217 model_id: impl Into<String>,
218 messages: &[AgentMessage],
219 ) -> Self {
220 let serialized = message_codec::serialize_messages(messages, "loop checkpoint");
221
222 Self {
223 messages: serialized.llm_messages,
224 custom_messages: serialized.custom_messages,
225 message_order: serialized.message_order,
226 pending_messages: Vec::new(),
227 pending_custom_messages: Vec::new(),
228 pending_message_order: Vec::new(),
229 pending_steering_messages: Vec::new(),
230 pending_steering_custom_messages: Vec::new(),
231 pending_steering_message_order: Vec::new(),
232 system_prompt: system_prompt.into(),
233 provider: provider.into(),
234 model_id: model_id.into(),
235 created_at: crate::util::now_timestamp(),
236 metadata: HashMap::new(),
237 state: None,
238 }
239 }
240
241 #[must_use]
243 pub fn with_state(mut self, state: serde_json::Value) -> Self {
244 self.state = Some(state);
245 self
246 }
247
248 #[must_use]
250 pub fn with_pending_messages(mut self, pending: Vec<LlmMessage>) -> Self {
251 self.pending_messages = pending;
252 self.pending_custom_messages.clear();
253 self.pending_message_order.clear();
254 self
255 }
256
257 #[must_use]
259 pub fn with_pending_steering_messages(mut self, pending: Vec<LlmMessage>) -> Self {
260 self.pending_steering_messages = pending;
261 self.pending_steering_custom_messages.clear();
262 self.pending_steering_message_order.clear();
263 self
264 }
265
266 #[must_use]
268 pub fn with_pending_message_batch(mut self, pending: &[AgentMessage]) -> Self {
269 let serialized = message_codec::serialize_messages(pending, "loop checkpoint pending");
270 self.pending_messages = serialized.llm_messages;
271 self.pending_custom_messages = serialized.custom_messages;
272 self.pending_message_order = serialized.message_order;
273 self
274 }
275
276 #[must_use]
278 pub fn with_pending_steering_message_batch(mut self, pending: &[AgentMessage]) -> Self {
279 let serialized =
280 message_codec::serialize_messages(pending, "loop checkpoint steering pending");
281 self.pending_steering_messages = serialized.llm_messages;
282 self.pending_steering_custom_messages = serialized.custom_messages;
283 self.pending_steering_message_order = serialized.message_order;
284 self
285 }
286
287 #[must_use]
289 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
290 self.metadata.insert(key.into(), value);
291 self
292 }
293
294 #[must_use]
299 pub fn restore_messages(&self, registry: Option<&CustomMessageRegistry>) -> Vec<AgentMessage> {
300 message_codec::restore_messages(
301 &self.messages,
302 &self.custom_messages,
303 &self.message_order,
304 registry,
305 "loop checkpoint",
306 )
307 }
308
309 #[must_use]
311 pub fn restore_pending_messages(
312 &self,
313 registry: Option<&CustomMessageRegistry>,
314 ) -> Vec<AgentMessage> {
315 message_codec::restore_messages(
316 &self.pending_messages,
317 &self.pending_custom_messages,
318 &self.pending_message_order,
319 registry,
320 "loop checkpoint pending",
321 )
322 }
323
324 #[must_use]
326 pub fn restore_pending_steering_messages(
327 &self,
328 registry: Option<&CustomMessageRegistry>,
329 ) -> Vec<AgentMessage> {
330 message_codec::restore_messages(
331 &self.pending_steering_messages,
332 &self.pending_steering_custom_messages,
333 &self.pending_steering_message_order,
334 registry,
335 "loop checkpoint steering pending",
336 )
337 }
338
339 #[must_use]
341 pub fn to_checkpoint(&self, id: impl Into<String>) -> Checkpoint {
342 Checkpoint {
343 id: id.into(),
344 system_prompt: self.system_prompt.clone(),
345 provider: self.provider.clone(),
346 model_id: self.model_id.clone(),
347 messages: self.messages.clone(),
348 custom_messages: self.custom_messages.clone(),
349 message_order: self.message_order.clone(),
350 turn_count: 0,
351 usage: Usage::default(),
352 cost: Cost::default(),
353 created_at: self.created_at,
354 metadata: self.metadata.clone(),
355 state: self.state.clone(),
356 }
357 }
358}
359
360const _: () = {
363 const fn assert_send_sync<T: Send + Sync>() {}
364 assert_send_sync::<Checkpoint>();
365 assert_send_sync::<LoopCheckpoint>();
366};
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, UserMessage};

    /// Custom message that does not override `type_name`/`to_json`; the tests
    /// below rely on it being skipped when a checkpoint is created.
    #[derive(Debug)]
    struct TestCustom;

    impl crate::types::CustomMessage for TestCustom {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// A two-message history: one user message followed by one assistant reply.
    fn sample_messages() -> Vec<AgentMessage> {
        vec![
            AgentMessage::Llm(LlmMessage::User(UserMessage {
                content: vec![ContentBlock::Text {
                    text: "Hello".to_string(),
                }],
                timestamp: 100,
                cache_hint: None,
            })),
            AgentMessage::Llm(LlmMessage::Assistant(crate::types::AssistantMessage {
                content: vec![ContentBlock::Text {
                    text: "Hi there!".to_string(),
                }],
                provider: "test".to_string(),
                model_id: "test-model".to_string(),
                usage: Usage::default(),
                cost: Cost::default(),
                stop_reason: crate::types::StopReason::Stop,
                error_message: None,
                error_kind: None,
                timestamp: 101,
                cache_hint: None,
            })),
        ]
    }

    #[test]
    fn checkpoint_creation_skips_non_serializable_custom_messages() {
        let mut messages = sample_messages();
        messages.push(AgentMessage::Custom(Box::new(TestCustom)));

        let checkpoint = Checkpoint::new(
            "cp-1",
            "Be helpful.",
            "anthropic",
            "claude-sonnet",
            &messages,
        )
        .with_turn_count(3);

        assert_eq!(checkpoint.id, "cp-1");
        assert_eq!(checkpoint.system_prompt, "Be helpful.");
        assert_eq!(checkpoint.provider, "anthropic");
        assert_eq!(checkpoint.model_id, "claude-sonnet");
        assert_eq!(checkpoint.messages.len(), 2);
        assert!(checkpoint.custom_messages.is_empty());
        assert_eq!(checkpoint.turn_count, 3);
    }

    #[test]
    fn checkpoint_custom_message_roundtrip() {
        #[derive(Debug, Clone, PartialEq)]
        struct SerializableCustom {
            value: String,
        }

        impl crate::types::CustomMessage for SerializableCustom {
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
            fn type_name(&self) -> Option<&str> {
                Some("SerializableCustom")
            }
            fn to_json(&self) -> Option<serde_json::Value> {
                Some(serde_json::json!({ "value": self.value }))
            }
        }

        let mut messages = sample_messages();
        messages.push(AgentMessage::Custom(Box::new(SerializableCustom {
            value: "hello".to_string(),
        })));

        let checkpoint = Checkpoint::new("cp-custom", "prompt", "p", "m", &messages);

        assert_eq!(checkpoint.messages.len(), 2);
        assert_eq!(checkpoint.custom_messages.len(), 1);
        assert_eq!(checkpoint.custom_messages[0]["type"], "SerializableCustom");
        assert_eq!(checkpoint.custom_messages[0]["data"]["value"], "hello");

        // The encoded custom message must survive a JSON round-trip.
        let json = serde_json::to_string(&checkpoint).unwrap();
        let restored_cp: Checkpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(restored_cp.custom_messages.len(), 1);

        let mut registry = CustomMessageRegistry::new();
        registry.register(
            "SerializableCustom",
            Box::new(|val: serde_json::Value| {
                let value = val
                    .get("value")
                    .and_then(|v| v.as_str())
                    .ok_or_else(|| "missing value".to_string())?;
                Ok(Box::new(SerializableCustom {
                    value: value.to_string(),
                }) as Box<dyn crate::types::CustomMessage>)
            }),
        );

        let restored = restored_cp.restore_messages(Some(&registry));
        assert_eq!(restored.len(), 3);
        assert!(matches!(
            restored[0],
            AgentMessage::Llm(LlmMessage::User(_))
        ));
        assert!(matches!(
            restored[1],
            AgentMessage::Llm(LlmMessage::Assistant(_))
        ));
        let custom = restored[2].downcast_ref::<SerializableCustom>().unwrap();
        assert_eq!(custom.value, "hello");

        // Without a registry the custom message cannot be rebuilt and is left out.
        let restored_no_reg = restored_cp.restore_messages(None);
        assert_eq!(restored_no_reg.len(), 2);
    }

    #[test]
    fn checkpoint_serde_roundtrip() {
        let messages = sample_messages();
        let checkpoint = Checkpoint::new(
            "cp-roundtrip",
            "System prompt",
            "openai",
            "gpt-4",
            &messages,
        )
        .with_turn_count(5)
        .with_usage(Usage {
            input: 100,
            output: 50,
            ..Default::default()
        })
        .with_cost(Cost {
            input: 0.01,
            output: 0.005,
            ..Default::default()
        })
        .with_metadata("session_id", serde_json::json!("sess-abc"));

        let json = serde_json::to_string(&checkpoint).unwrap();
        let restored: Checkpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.id, "cp-roundtrip");
        assert_eq!(restored.system_prompt, "System prompt");
        assert_eq!(restored.provider, "openai");
        assert_eq!(restored.model_id, "gpt-4");
        assert_eq!(restored.messages.len(), 2);
        assert_eq!(restored.turn_count, 5);
        assert_eq!(restored.usage.input, 100);
        assert_eq!(restored.usage.output, 50);
        assert!((restored.cost.input - 0.01).abs() < 1e-12);
        assert!((restored.cost.output - 0.005).abs() < 1e-12);
        assert_eq!(restored.metadata["session_id"], "sess-abc");
    }

    #[test]
    fn restore_messages_wraps_in_agent_message() {
        let messages = sample_messages();
        let checkpoint =
            Checkpoint::new("cp-restore", "prompt", "p", "m", &messages).with_turn_count(1);

        let restored = checkpoint.restore_messages(None);
        assert_eq!(restored.len(), 2);
        assert!(matches!(
            restored[0],
            AgentMessage::Llm(LlmMessage::User(_))
        ));
        assert!(matches!(
            restored[1],
            AgentMessage::Llm(LlmMessage::Assistant(_))
        ));
    }

    #[test]
    fn checkpoint_with_metadata_builder() {
        let checkpoint = Checkpoint::new("cp-meta", "p", "p", "m", &[])
            .with_metadata("key1", serde_json::json!("value1"))
            .with_metadata("key2", serde_json::json!(42));

        assert_eq!(checkpoint.metadata.len(), 2);
        assert_eq!(checkpoint.metadata["key1"], "value1");
        assert_eq!(checkpoint.metadata["key2"], 42);
    }

    #[test]
    fn checkpoint_backward_compat_no_metadata() {
        // Payload shaped like one written before `metadata` / `custom_messages` existed.
        let json = r#"{
            "id": "cp-compat",
            "system_prompt": "hello",
            "provider": "p",
            "model_id": "m",
            "messages": [],
            "turn_count": 0,
            "usage": {"input": 0, "output": 0, "cache_read": 0, "cache_write": 0, "total": 0},
            "cost": {"input": 0.0, "output": 0.0, "cache_read": 0.0, "cache_write": 0.0, "total": 0.0},
            "created_at": 100
        }"#;

        let checkpoint: Checkpoint = serde_json::from_str(json).unwrap();
        assert!(checkpoint.metadata.is_empty());
        assert!(checkpoint.custom_messages.is_empty());
    }

    #[test]
    fn loop_checkpoint_creation_skips_non_serializable_custom_messages() {
        let mut messages = sample_messages();
        messages.push(AgentMessage::Custom(Box::new(TestCustom)));

        let cp = LoopCheckpoint::new("Be helpful.", "anthropic", "claude-sonnet", &messages)
            .with_pending_messages(vec![LlmMessage::User(UserMessage {
                content: vec![ContentBlock::Text {
                    text: "continue".to_string(),
                }],
                timestamp: 123,
                cache_hint: None,
            })]);

        assert_eq!(cp.messages.len(), 2);
        assert!(cp.custom_messages.is_empty());
        assert_eq!(cp.pending_messages.len(), 1);
        assert_eq!(cp.system_prompt, "Be helpful.");
        assert_eq!(cp.provider, "anthropic");
        assert_eq!(cp.model_id, "claude-sonnet");
    }

    #[test]
    fn loop_checkpoint_serde_roundtrip() {
        let messages = sample_messages();
        let cp = LoopCheckpoint::new("System prompt", "openai", "gpt-4", &messages)
            .with_pending_messages(vec![LlmMessage::User(UserMessage {
                content: vec![ContentBlock::Text {
                    text: "follow-up".to_string(),
                }],
                timestamp: 200,
                cache_hint: None,
            })])
            .with_metadata("workflow_id", serde_json::json!("wf-123"));

        let json = serde_json::to_string(&cp).unwrap();
        let restored: LoopCheckpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.messages.len(), 2);
        assert_eq!(restored.pending_messages.len(), 1);
        assert_eq!(restored.system_prompt, "System prompt");
        assert_eq!(restored.metadata["workflow_id"], "wf-123");
    }

    #[test]
    fn loop_checkpoint_restore_messages() {
        let messages = sample_messages();
        let cp = LoopCheckpoint::new("p", "p", "m", &messages);

        let restored = cp.restore_messages(None);
        assert_eq!(restored.len(), 2);
        assert!(matches!(
            restored[0],
            AgentMessage::Llm(LlmMessage::User(_))
        ));
        assert!(matches!(
            restored[1],
            AgentMessage::Llm(LlmMessage::Assistant(_))
        ));
    }

    #[test]
    fn loop_checkpoint_steering_messages_roundtrip() {
        let steering = vec![LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: "steer-me".to_string(),
            }],
            timestamp: 300,
            cache_hint: None,
        })];
        let follow_up = vec![LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: "follow-up".to_string(),
            }],
            timestamp: 301,
            cache_hint: None,
        })];

        let cp = LoopCheckpoint::new("p", "p", "m", &[])
            .with_pending_messages(follow_up)
            .with_pending_steering_messages(steering);

        let json = serde_json::to_string(&cp).unwrap();
        let restored: LoopCheckpoint = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.pending_messages.len(), 1);
        assert_eq!(restored.pending_steering_messages.len(), 1);

        let restored_steering = restored.restore_pending_steering_messages(None);
        assert_eq!(restored_steering.len(), 1);
        assert!(matches!(
            restored_steering[0],
            AgentMessage::Llm(LlmMessage::User(_))
        ));
    }

    #[test]
    fn loop_checkpoint_backward_compat_no_steering_field() {
        // Populate the steering batch so its key is really serialized: an empty
        // Vec is skipped, and removing an absent key would test nothing.
        let cp = LoopCheckpoint::new("p", "p", "m", &[])
            .with_pending_messages(vec![LlmMessage::User(UserMessage {
                content: vec![ContentBlock::Text {
                    text: "old-follow-up".to_string(),
                }],
                timestamp: 100,
                cache_hint: None,
            })])
            .with_pending_steering_messages(vec![LlmMessage::User(UserMessage {
                content: vec![ContentBlock::Text {
                    text: "steer".to_string(),
                }],
                timestamp: 101,
                cache_hint: None,
            })]);

        // Strip the steering key to mimic a payload written before it existed.
        let mut json_val = serde_json::to_value(&cp).unwrap();
        let removed = json_val
            .as_object_mut()
            .unwrap()
            .remove("pending_steering_messages");
        assert!(
            removed.is_some(),
            "steering field should be serialized before it is stripped"
        );
        let legacy: LoopCheckpoint = serde_json::from_value(json_val).unwrap();

        assert!(
            legacy.pending_steering_messages.is_empty(),
            "missing steering field should default to empty"
        );
        assert_eq!(legacy.pending_messages.len(), 1);
    }

    #[test]
    fn loop_checkpoint_pending_messages_roundtrip() {
        let pending = vec![LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: "follow-up".to_string(),
            }],
            timestamp: 200,
            cache_hint: None,
        })];

        let cp = LoopCheckpoint::new("p", "p", "m", &[]).with_pending_messages(pending);

        let restored_pending = cp.restore_pending_messages(None);
        assert_eq!(restored_pending.len(), 1);
        assert!(matches!(
            restored_pending[0],
            AgentMessage::Llm(LlmMessage::User(_))
        ));
    }

    #[test]
    fn loop_checkpoint_pending_messages_restore_custom_with_registry() {
        let (registry, factory) = make_registry_and_custom();
        let pending = vec![
            user_msg("follow-up"),
            AgentMessage::Custom(factory("pending-custom")),
        ];

        let cp = LoopCheckpoint::new("p", "p", "m", &[]).with_pending_message_batch(&pending);
        let restored_pending = cp.restore_pending_messages(Some(&registry));

        assert_eq!(restored_pending.len(), 2);
        let order: Vec<String> = restored_pending.iter().map(message_text).collect();
        assert_eq!(order, vec!["user:follow-up", "custom:pending-custom"]);
    }

    #[test]
    fn loop_checkpoint_to_standard_checkpoint() {
        let messages = sample_messages();
        let cp = LoopCheckpoint::new("prompt", "anthropic", "claude", &messages)
            .with_metadata("key", serde_json::json!("val"));

        let standard = cp.to_checkpoint("cp-from-loop");
        assert_eq!(standard.id, "cp-from-loop");
        assert_eq!(standard.system_prompt, "prompt");
        assert_eq!(standard.turn_count, 0);
        assert_eq!(standard.usage.input, 0);
        assert_eq!(standard.messages.len(), 2);
        assert_eq!(standard.metadata["key"], "val");
    }

    /// Builds a registry that can decode `TaggedCustom` messages, plus a factory
    /// producing such messages with a given tag.
    fn make_registry_and_custom() -> (
        CustomMessageRegistry,
        impl Fn(&str) -> Box<dyn crate::types::CustomMessage>,
    ) {
        #[derive(Debug, Clone, PartialEq)]
        struct TaggedCustom {
            tag: String,
        }

        impl crate::types::CustomMessage for TaggedCustom {
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
            fn type_name(&self) -> Option<&str> {
                Some("TaggedCustom")
            }
            fn to_json(&self) -> Option<serde_json::Value> {
                Some(serde_json::json!({ "tag": self.tag }))
            }
        }

        let mut registry = CustomMessageRegistry::new();
        registry.register(
            "TaggedCustom",
            Box::new(|val: serde_json::Value| {
                let tag = val
                    .get("tag")
                    .and_then(|v| v.as_str())
                    .ok_or_else(|| "missing tag".to_string())?;
                Ok(Box::new(TaggedCustom {
                    tag: tag.to_string(),
                }) as Box<dyn crate::types::CustomMessage>)
            }),
        );

        let factory = |tag: &str| -> Box<dyn crate::types::CustomMessage> {
            Box::new(TaggedCustom {
                tag: tag.to_string(),
            })
        };

        (registry, factory)
    }

    fn user_msg(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: text.to_string(),
            }],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    fn assistant_msg(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::Assistant(crate::types::AssistantMessage {
            content: vec![ContentBlock::Text {
                text: text.to_string(),
            }],
            provider: "test".to_string(),
            model_id: "test-model".to_string(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: crate::types::StopReason::Stop,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Renders a message as `"<kind>:<text-or-tag>"` so tests can compare orderings.
    fn message_text(msg: &AgentMessage) -> String {
        match msg {
            AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
                ContentBlock::Text { text } => format!("user:{text}"),
                _ => "user:?".to_string(),
            },
            AgentMessage::Llm(LlmMessage::Assistant(a)) => match &a.content[0] {
                ContentBlock::Text { text } => format!("assistant:{text}"),
                _ => "assistant:?".to_string(),
            },
            AgentMessage::Custom(c) => {
                if let Some(json) = c.to_json() {
                    format!("custom:{}", json["tag"].as_str().unwrap_or("?"))
                } else {
                    "custom:?".to_string()
                }
            }
            _ => "other".to_string(),
        }
    }

    #[test]
    fn checkpoint_preserves_interleaved_custom_message_order() {
        let (registry, factory) = make_registry_and_custom();

        let messages = vec![
            user_msg("hello"),
            AgentMessage::Custom(factory("A")),
            assistant_msg("hi"),
            AgentMessage::Custom(factory("B")),
            user_msg("thanks"),
        ];

        let checkpoint = Checkpoint::new("cp-order", "prompt", "p", "m", &messages);

        let json = serde_json::to_string(&checkpoint).unwrap();
        let restored_cp: Checkpoint = serde_json::from_str(&json).unwrap();

        let restored = restored_cp.restore_messages(Some(&registry));
        let order: Vec<String> = restored.iter().map(message_text).collect();

        assert_eq!(
            order,
            vec![
                "user:hello",
                "custom:A",
                "assistant:hi",
                "custom:B",
                "user:thanks",
            ],
            "interleaved custom messages must preserve their original position"
        );
    }

    #[test]
    fn loop_checkpoint_preserves_interleaved_custom_message_order() {
        let (registry, factory) = make_registry_and_custom();

        let messages = vec![
            user_msg("q1"),
            AgentMessage::Custom(factory("mid")),
            assistant_msg("a1"),
        ];

        let cp = LoopCheckpoint::new("prompt", "p", "m", &messages);

        let json = serde_json::to_string(&cp).unwrap();
        let restored_cp: LoopCheckpoint = serde_json::from_str(&json).unwrap();

        let restored = restored_cp.restore_messages(Some(&registry));
        let order: Vec<String> = restored.iter().map(message_text).collect();

        assert_eq!(
            order,
            vec!["user:q1", "custom:mid", "assistant:a1"],
            "LoopCheckpoint must preserve interleaved custom message order"
        );
    }

    #[test]
    fn loop_checkpoint_to_checkpoint_preserves_order() {
        let (registry, factory) = make_registry_and_custom();

        let messages = vec![
            AgentMessage::Custom(factory("first")),
            user_msg("hello"),
            AgentMessage::Custom(factory("second")),
        ];

        let loop_cp = LoopCheckpoint::new("prompt", "p", "m", &messages);
        let standard = loop_cp.to_checkpoint("cp-conv");

        let restored = standard.restore_messages(Some(&registry));
        let order: Vec<String> = restored.iter().map(message_text).collect();

        assert_eq!(
            order,
            vec!["custom:first", "user:hello", "custom:second"],
            "to_checkpoint conversion must preserve message order"
        );
    }

    #[test]
    fn backward_compat_no_message_order_field() {
        let (registry, factory) = make_registry_and_custom();

        let messages = vec![user_msg("hi"), AgentMessage::Custom(factory("legacy"))];

        // Strip `message_order` to mimic a payload written before ordering was recorded.
        let checkpoint = Checkpoint::new("cp-legacy", "hello", "p", "m", &messages);
        let mut json_val = serde_json::to_value(&checkpoint).unwrap();
        json_val.as_object_mut().unwrap().remove("message_order");
        let legacy_cp: Checkpoint = serde_json::from_value(json_val).unwrap();

        assert!(legacy_cp.message_order.is_empty());

        let restored = legacy_cp.restore_messages(Some(&registry));
        assert_eq!(restored.len(), 2);
        assert!(matches!(
            restored[0],
            AgentMessage::Llm(LlmMessage::User(_))
        ));
        let order: Vec<String> = restored.iter().map(message_text).collect();
        assert_eq!(order, vec!["user:hi", "custom:legacy"]);
    }

    #[test]
    fn restore_without_registry_skips_custom_in_ordered_mode() {
        let (_registry, factory) = make_registry_and_custom();

        let messages = vec![
            user_msg("hello"),
            AgentMessage::Custom(factory("skipped")),
            assistant_msg("world"),
        ];

        let checkpoint = Checkpoint::new("cp-no-reg", "prompt", "p", "m", &messages);
        let restored = checkpoint.restore_messages(None);

        // The custom slot is dropped while the surrounding LLM messages keep their order.
        assert_eq!(restored.len(), 2);
        let order: Vec<String> = restored.iter().map(message_text).collect();
        assert_eq!(order, vec!["user:hello", "assistant:world"]);
    }
}