1use serde::{Deserialize, Serialize};
24
25use super::{
26 AgentMessage, CustomMessageRegistry, LlmMessage, deserialize_custom_message,
27 serialize_custom_message,
28};
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "kind")]
40pub enum MessageSlot {
41 Llm { index: usize },
43 Custom { index: usize },
45}
46
/// Flattened form of a conversation produced by [`serialize_messages`].
///
/// LLM messages and custom-message envelopes live in separate vectors;
/// `message_order` records how they were interleaved so that
/// [`restore_messages`] can rebuild the original sequence.
#[derive(Debug, Clone)]
pub struct SerializedMessages {
    /// LLM messages, in the order they appeared in the input.
    pub llm_messages: Vec<LlmMessage>,
    /// Envelopes returned by `serialize_custom_message`, in input order.
    pub custom_messages: Vec<serde_json::Value>,
    /// One slot per retained message (non-serializable custom messages are
    /// omitted), each pointing into one of the two vectors above.
    pub message_order: Vec<MessageSlot>,
}
60
61pub fn serialize_messages(messages: &[AgentMessage], kind: &str) -> SerializedMessages {
72 let mut llm_messages = Vec::new();
73 let mut custom_messages = Vec::new();
74 let mut message_order = Vec::new();
75
76 for message in messages {
77 match message {
78 AgentMessage::Llm(llm) => {
79 message_order.push(MessageSlot::Llm {
80 index: llm_messages.len(),
81 });
82 llm_messages.push(llm.clone());
83 }
84 AgentMessage::Custom(custom) => {
85 if let Some(envelope) = serialize_custom_message(custom.as_ref()) {
86 message_order.push(MessageSlot::Custom {
87 index: custom_messages.len(),
88 });
89 custom_messages.push(envelope);
90 } else {
91 tracing::warn!(
92 "skipping non-serializable CustomMessage in {kind}: {:?}",
93 custom
94 );
95 }
96 }
97 }
98 }
99
100 SerializedMessages {
101 llm_messages,
102 custom_messages,
103 message_order,
104 }
105}
106
107pub fn restore_messages(
118 llm_messages: &[LlmMessage],
119 custom_messages: &[serde_json::Value],
120 message_order: &[MessageSlot],
121 registry: Option<&CustomMessageRegistry>,
122 kind: &str,
123) -> Vec<AgentMessage> {
124 if !message_order.is_empty() {
125 let mut result = Vec::with_capacity(message_order.len());
126 for slot in message_order {
127 match slot {
128 MessageSlot::Llm { index } => {
129 if let Some(llm) = llm_messages.get(*index) {
130 result.push(AgentMessage::Llm(llm.clone()));
131 }
132 }
133 MessageSlot::Custom { index } => {
134 if let Some(reg) = registry
135 && let Some(envelope) = custom_messages.get(*index)
136 {
137 match deserialize_custom_message(reg, envelope) {
138 Ok(custom) => result.push(AgentMessage::Custom(custom)),
139 Err(error) => {
140 tracing::warn!(
141 "failed to deserialize custom message from {kind}: {error}"
142 );
143 }
144 }
145 }
146 }
147 }
148 }
149 return result;
150 }
151
152 let mut result: Vec<AgentMessage> = llm_messages
154 .iter()
155 .cloned()
156 .map(AgentMessage::Llm)
157 .collect();
158
159 if let Some(reg) = registry {
160 for envelope in custom_messages {
161 match deserialize_custom_message(reg, envelope) {
162 Ok(custom) => result.push(AgentMessage::Custom(custom)),
163 Err(error) => {
164 tracing::warn!("failed to deserialize custom message from {kind}: {error}");
165 }
166 }
167 }
168 }
169
170 result
171}
172
173pub fn restore_single_custom(
180 registry: Option<&CustomMessageRegistry>,
181 envelope: &serde_json::Value,
182) -> Result<Option<Box<dyn super::CustomMessage>>, String> {
183 registry.map_or_else(
184 || Ok(None),
185 |reg| deserialize_custom_message(reg, envelope).map(Some),
186 )
187}
188
189#[derive(Debug, Clone)]
202pub struct SerializedCustomMessage {
203 name: String,
204 json: serde_json::Value,
205}
206
207impl SerializedCustomMessage {
208 #[must_use]
210 pub fn new(name: impl Into<String>, json: serde_json::Value) -> Self {
211 Self {
212 name: name.into(),
213 json,
214 }
215 }
216
217 #[must_use]
221 pub fn from_custom(msg: &dyn super::CustomMessage) -> Option<Self> {
222 Some(Self {
223 name: msg.type_name()?.to_string(),
224 json: msg.to_json()?,
225 })
226 }
227}
228
229impl super::CustomMessage for SerializedCustomMessage {
230 fn as_any(&self) -> &dyn std::any::Any {
231 self
232 }
233 fn type_name(&self) -> Option<&str> {
234 Some(&self.name)
235 }
236 fn to_json(&self) -> Option<serde_json::Value> {
237 Some(self.json.clone())
238 }
239 fn clone_box(&self) -> Option<Box<dyn super::CustomMessage>> {
240 Some(Box::new(self.clone()))
241 }
242}
243
244pub fn clone_messages_for_send(messages: &[AgentMessage]) -> Vec<AgentMessage> {
255 messages
256 .iter()
257 .filter_map(|m| match m {
258 AgentMessage::Llm(llm) => Some(AgentMessage::Llm(llm.clone())),
259 AgentMessage::Custom(custom) => {
260 let snapshot = SerializedCustomMessage::from_custom(custom.as_ref())?;
261 Some(AgentMessage::Custom(Box::new(snapshot)))
262 }
263 })
264 .collect()
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        AssistantMessage, ContentBlock, Cost, CustomMessage, StopReason, Usage, UserMessage,
    };

    /// Custom message that overrides only `as_any`.
    /// `SerializedCustomMessage::from_custom` yields `None` for it; see
    /// `serialized_custom_message_from_non_serializable`.
    #[derive(Debug)]
    struct NonSerializableCustom;

    impl CustomMessage for NonSerializableCustom {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Serializable custom message carrying a single `tag` string.
    #[derive(Debug, Clone, PartialEq)]
    struct TaggedCustom {
        tag: String,
    }

    impl CustomMessage for TaggedCustom {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn type_name(&self) -> Option<&str> {
            Some("TaggedCustom")
        }
        fn to_json(&self) -> Option<serde_json::Value> {
            Some(serde_json::json!({ "tag": self.tag }))
        }
    }

    /// Registry that rebuilds `TaggedCustom` from its `{"tag": ...}` payload.
    fn tagged_registry() -> CustomMessageRegistry {
        let mut reg = CustomMessageRegistry::new();
        reg.register(
            "TaggedCustom",
            Box::new(|val: serde_json::Value| {
                let tag = val
                    .get("tag")
                    .and_then(|v| v.as_str())
                    .ok_or_else(|| "missing tag".to_string())?;
                Ok(Box::new(TaggedCustom {
                    tag: tag.to_string(),
                }) as Box<dyn CustomMessage>)
            }),
        );
        reg
    }

    /// Bare user `LlmMessage` with a single text block.
    fn user_llm(text: &str) -> LlmMessage {
        LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: text.to_string(),
            }],
            timestamp: 0,
            cache_hint: None,
        })
    }

    fn user_msg(text: &str) -> AgentMessage {
        AgentMessage::Llm(user_llm(text))
    }

    fn assistant_msg(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::Assistant(AssistantMessage {
            content: vec![ContentBlock::Text {
                text: text.to_string(),
            }],
            provider: "test".to_string(),
            model_id: "m".to_string(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        }))
    }

    fn custom_msg(tag: &str) -> AgentMessage {
        AgentMessage::Custom(Box::new(TaggedCustom {
            tag: tag.to_string(),
        }))
    }

    /// Short `role:text` label used to compare message sequences in assertions.
    fn message_label(msg: &AgentMessage) -> String {
        match msg {
            AgentMessage::Llm(LlmMessage::User(u)) => {
                format!("user:{}", ContentBlock::extract_text(&u.content))
            }
            AgentMessage::Llm(LlmMessage::Assistant(a)) => {
                format!("assistant:{}", ContentBlock::extract_text(&a.content))
            }
            AgentMessage::Custom(c) => {
                if let Some(json) = c.to_json() {
                    format!("custom:{}", json["tag"].as_str().unwrap_or("?"))
                } else {
                    "custom:?".to_string()
                }
            }
            _ => "other".to_string(),
        }
    }

    #[test]
    fn serialize_skips_non_serializable_custom() {
        let messages = vec![
            user_msg("hi"),
            AgentMessage::Custom(Box::new(NonSerializableCustom)),
            assistant_msg("hello"),
        ];

        let result = serialize_messages(&messages, "test");
        assert_eq!(result.llm_messages.len(), 2);
        assert!(result.custom_messages.is_empty());
        assert_eq!(result.message_order.len(), 2);
    }

    #[test]
    fn serialize_preserves_interleaved_order() {
        let messages = vec![
            user_msg("hello"),
            custom_msg("A"),
            assistant_msg("hi"),
            custom_msg("B"),
            user_msg("thanks"),
        ];

        let result = serialize_messages(&messages, "test");
        assert_eq!(result.llm_messages.len(), 3);
        assert_eq!(result.custom_messages.len(), 2);
        assert_eq!(result.message_order.len(), 5);

        assert!(matches!(result.message_order[0], MessageSlot::Llm { index: 0 }));
        assert!(matches!(result.message_order[1], MessageSlot::Custom { index: 0 }));
        assert!(matches!(result.message_order[2], MessageSlot::Llm { index: 1 }));
        assert!(matches!(result.message_order[3], MessageSlot::Custom { index: 1 }));
        assert!(matches!(result.message_order[4], MessageSlot::Llm { index: 2 }));

        assert_eq!(result.custom_messages[0]["type"], "TaggedCustom");
        assert_eq!(result.custom_messages[0]["data"]["tag"], "A");
        assert_eq!(result.custom_messages[1]["data"]["tag"], "B");
    }

    #[test]
    fn roundtrip_preserves_order() {
        let registry = tagged_registry();
        let messages = vec![
            user_msg("hello"),
            custom_msg("A"),
            assistant_msg("hi"),
            custom_msg("B"),
            user_msg("thanks"),
        ];

        let serialized = serialize_messages(&messages, "test");
        let restored = restore_messages(
            &serialized.llm_messages,
            &serialized.custom_messages,
            &serialized.message_order,
            Some(&registry),
            "test",
        );

        let labels: Vec<String> = restored.iter().map(message_label).collect();
        assert_eq!(
            labels,
            vec![
                "user:hello",
                "custom:A",
                "assistant:hi",
                "custom:B",
                "user:thanks",
            ]
        );
    }

    #[test]
    fn restore_without_registry_skips_custom() {
        let messages = vec![user_msg("hi"), custom_msg("skipped"), assistant_msg("ok")];

        let serialized = serialize_messages(&messages, "test");
        let restored = restore_messages(
            &serialized.llm_messages,
            &serialized.custom_messages,
            &serialized.message_order,
            None,
            "test",
        );

        assert_eq!(restored.len(), 2);
        let labels: Vec<String> = restored.iter().map(message_label).collect();
        assert_eq!(labels, vec!["user:hi", "assistant:ok"]);
    }

    #[test]
    fn restore_skips_out_of_range_slots() {
        let registry = tagged_registry();
        let llm = vec![user_llm("hi")];
        let order = [
            MessageSlot::Llm { index: 0 },
            MessageSlot::Llm { index: 5 },
            MessageSlot::Custom { index: 3 },
        ];

        let restored = restore_messages(&llm, &[], &order, Some(&registry), "test");
        let labels: Vec<String> = restored.iter().map(message_label).collect();
        assert_eq!(labels, vec!["user:hi"]);
    }

    #[test]
    fn legacy_fallback_no_ordering() {
        let registry = tagged_registry();
        let llm = vec![user_llm("hi")];
        let custom = vec![serde_json::json!({
            "type": "TaggedCustom",
            "data": { "tag": "legacy" }
        })];

        let restored = restore_messages(&llm, &custom, &[], Some(&registry), "test");
        assert_eq!(restored.len(), 2);
        let labels: Vec<String> = restored.iter().map(message_label).collect();
        assert_eq!(labels, vec!["user:hi", "custom:legacy"]);
    }

    #[test]
    fn restore_single_custom_with_registry() {
        let registry = tagged_registry();
        let envelope = serde_json::json!({
            "type": "TaggedCustom",
            "data": { "tag": "single" }
        });

        let result = restore_single_custom(Some(&registry), &envelope).unwrap();
        assert!(result.is_some());
        let custom = result.unwrap();
        assert_eq!(custom.type_name(), Some("TaggedCustom"));
    }

    #[test]
    fn restore_single_custom_without_registry() {
        let envelope = serde_json::json!({ "type": "X", "data": {} });
        let result = restore_single_custom(None, &envelope).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn serialized_custom_message_from_custom() {
        let original = TaggedCustom {
            tag: "hello".to_string(),
        };
        let snapshot = SerializedCustomMessage::from_custom(&original).unwrap();
        assert_eq!(snapshot.type_name(), Some("TaggedCustom"));
        assert_eq!(snapshot.to_json().unwrap()["tag"], "hello");
    }

    #[test]
    fn serialized_custom_message_from_non_serializable() {
        let bare = NonSerializableCustom;
        assert!(SerializedCustomMessage::from_custom(&bare).is_none());
    }

    #[test]
    fn clone_for_send_preserves_all_serializable() {
        let messages = vec![
            user_msg("hello"),
            custom_msg("kept"),
            AgentMessage::Custom(Box::new(NonSerializableCustom)),
            assistant_msg("world"),
        ];

        let cloned = clone_messages_for_send(&messages);
        assert_eq!(cloned.len(), 3);
        let labels: Vec<String> = cloned.iter().map(message_label).collect();
        assert_eq!(labels, vec!["user:hello", "custom:kept", "assistant:world"]);
    }

    #[test]
    fn clone_for_send_custom_roundtrips_through_registry() {
        let registry = tagged_registry();
        let messages = vec![custom_msg("roundtrip")];
        let cloned = clone_messages_for_send(&messages);
        assert_eq!(cloned.len(), 1);

        let envelope =
            serialize_custom_message(cloned[0].downcast_ref::<SerializedCustomMessage>().unwrap())
                .unwrap();
        let restored = deserialize_custom_message(&registry, &envelope).unwrap();
        assert_eq!(
            restored
                .as_any()
                .downcast_ref::<TaggedCustom>()
                .unwrap()
                .tag,
            "roundtrip"
        );
    }

    #[test]
    fn message_slot_serde_uses_kind_tag() {
        let value = serde_json::to_value(MessageSlot::Custom { index: 2 }).unwrap();
        assert_eq!(value, serde_json::json!({ "kind": "Custom", "index": 2 }));

        let back: MessageSlot = serde_json::from_value(value).unwrap();
        assert!(matches!(back, MessageSlot::Custom { index: 2 }));
    }
}