Skip to main content

swink_agent/types/
message_codec.rs

1//! Shared codec for serializing and deserializing [`AgentMessage`] batches.
2//!
3//! Consolidates the message-envelope logic previously duplicated across
4//! checkpoints, JSONL session storage, and blocking async adapters into a
5//! single module.
6//!
7//! ## Provided functionality
8//!
9//! - [`MessageSlot`] — records the original position of each message in an
10//!   interleaved LLM/custom sequence.
11//! - [`SerializedMessages`] — the result of splitting a `&[AgentMessage]`
12//!   into separate LLM and custom vectors with ordering metadata.
13//! - [`serialize_messages`] / [`restore_messages`] — batch serialization and
14//!   deserialization with interleaved ordering.
15//! - [`restore_single_custom`] — restore one custom-message envelope via a
16//!   registry (useful for line-oriented formats like JSONL).
17//! - [`SerializedCustomMessage`] — a lightweight [`CustomMessage`](super::CustomMessage)
18//!   implementation that holds pre-serialized `type_name` + `to_json` data,
19//!   enabling transfer across thread or process boundaries.
20//! - [`clone_messages_for_send`] — snapshot a slice of `AgentMessage` into
21//!   fully `Send + Clone`-safe values for crossing `spawn_blocking` or IPC.
22
23use serde::{Deserialize, Serialize};
24
25use super::{
26    AgentMessage, CustomMessageRegistry, LlmMessage, deserialize_custom_message,
27    serialize_custom_message,
28};
29
30// ─── MessageSlot ────────────────────────────────────────────────────────────
31
/// Tracks the original position of each message in the sequence.
///
/// During serialization, LLM and custom messages are stored in separate
/// vectors for backward compatibility. `MessageSlot` records the original
/// ordering so that [`restore_messages`] can reconstruct the interleaved
/// sequence faithfully.
///
/// The enum uses serde's internally tagged representation: the variant name
/// is written under a `"kind"` key alongside the variant's `index` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum MessageSlot {
    /// An LLM message; `index` is its position in the LLM message vector
    /// (`llm_messages` in [`SerializedMessages`]).
    Llm { index: usize },
    /// A custom message; `index` is its position in the custom-message
    /// envelope vector (`custom_messages` in [`SerializedMessages`]).
    Custom { index: usize },
}
46
47// ─── SerializedMessages ─────────────────────────────────────────────────────
48
/// The result of splitting an `AgentMessage` slice into LLM and custom
/// vectors, plus ordering metadata.
///
/// Produced by [`serialize_messages`]. The three fields line up with the
/// first three arguments of [`restore_messages`], which rebuilds the
/// interleaved sequence from them.
#[derive(Debug, Clone)]
pub struct SerializedMessages {
    /// LLM messages in insertion order.
    pub llm_messages: Vec<LlmMessage>,
    /// Custom message envelopes (`{"type": "…", "data": {…}}`), in insertion
    /// order.
    pub custom_messages: Vec<serde_json::Value>,
    /// Records the original interleaved order of LLM and custom messages.
    /// Each slot's `index` points into `llm_messages` or `custom_messages`
    /// depending on its variant.
    pub message_order: Vec<MessageSlot>,
}
60
61// ─── Batch serialize / restore ──────────────────────────────────────────────
62
63/// Split a slice of [`AgentMessage`] into separate LLM and custom vectors
64/// with ordering metadata.
65///
66/// Custom messages that do not support serialization (`type_name()` or
67/// `to_json()` returns `None`) are skipped with a `tracing::warn`.
68///
69/// `kind` is a human-readable label used in log messages (e.g. "checkpoint",
70/// "session").
71pub fn serialize_messages(messages: &[AgentMessage], kind: &str) -> SerializedMessages {
72    let mut llm_messages = Vec::new();
73    let mut custom_messages = Vec::new();
74    let mut message_order = Vec::new();
75
76    for message in messages {
77        match message {
78            AgentMessage::Llm(llm) => {
79                message_order.push(MessageSlot::Llm {
80                    index: llm_messages.len(),
81                });
82                llm_messages.push(llm.clone());
83            }
84            AgentMessage::Custom(custom) => {
85                if let Some(envelope) = serialize_custom_message(custom.as_ref()) {
86                    message_order.push(MessageSlot::Custom {
87                        index: custom_messages.len(),
88                    });
89                    custom_messages.push(envelope);
90                } else {
91                    tracing::warn!(
92                        "skipping non-serializable CustomMessage in {kind}: {:?}",
93                        custom
94                    );
95                }
96            }
97        }
98    }
99
100    SerializedMessages {
101        llm_messages,
102        custom_messages,
103        message_order,
104    }
105}
106
107/// Reconstruct an interleaved `Vec<AgentMessage>` from separate LLM and
108/// custom vectors, using [`MessageSlot`] ordering metadata.
109///
110/// If `message_order` is empty (legacy data created before ordering support),
111/// falls back to LLM messages first, then custom messages appended.
112///
113/// If `registry` is `None`, custom messages are silently skipped.
114/// Deserialization failures are logged as warnings.
115///
116/// `kind` is used in log messages (e.g. "checkpoint", "session").
117pub fn restore_messages(
118    llm_messages: &[LlmMessage],
119    custom_messages: &[serde_json::Value],
120    message_order: &[MessageSlot],
121    registry: Option<&CustomMessageRegistry>,
122    kind: &str,
123) -> Vec<AgentMessage> {
124    if !message_order.is_empty() {
125        let mut result = Vec::with_capacity(message_order.len());
126        for slot in message_order {
127            match slot {
128                MessageSlot::Llm { index } => {
129                    if let Some(llm) = llm_messages.get(*index) {
130                        result.push(AgentMessage::Llm(llm.clone()));
131                    }
132                }
133                MessageSlot::Custom { index } => {
134                    if let Some(reg) = registry
135                        && let Some(envelope) = custom_messages.get(*index)
136                    {
137                        match deserialize_custom_message(reg, envelope) {
138                            Ok(custom) => result.push(AgentMessage::Custom(custom)),
139                            Err(error) => {
140                                tracing::warn!(
141                                    "failed to deserialize custom message from {kind}: {error}"
142                                );
143                            }
144                        }
145                    }
146                }
147            }
148        }
149        return result;
150    }
151
152    // Legacy fallback: LLM messages first, then custom messages appended.
153    let mut result: Vec<AgentMessage> = llm_messages
154        .iter()
155        .cloned()
156        .map(AgentMessage::Llm)
157        .collect();
158
159    if let Some(reg) = registry {
160        for envelope in custom_messages {
161            match deserialize_custom_message(reg, envelope) {
162                Ok(custom) => result.push(AgentMessage::Custom(custom)),
163                Err(error) => {
164                    tracing::warn!("failed to deserialize custom message from {kind}: {error}");
165                }
166            }
167        }
168    }
169
170    result
171}
172
173// ─── Single-envelope restore ────────────────────────────────────────────────
174
175/// Restore a single custom-message envelope via a registry.
176///
177/// Returns `Ok(Some(msg))` on success, `Ok(None)` if the registry is `None`,
178/// or `Err(reason)` if deserialization fails.
179pub fn restore_single_custom(
180    registry: Option<&CustomMessageRegistry>,
181    envelope: &serde_json::Value,
182) -> Result<Option<Box<dyn super::CustomMessage>>, String> {
183    registry.map_or_else(
184        || Ok(None),
185        |reg| deserialize_custom_message(reg, envelope).map(Some),
186    )
187}
188
189// ─── SerializedCustomMessage ────────────────────────────────────────────────
190
/// A lightweight [`CustomMessage`](super::CustomMessage) stand-in that holds pre-serialized data.
///
/// Useful for ferrying custom messages across `spawn_blocking` boundaries or
/// other contexts where the original `Box<dyn CustomMessage>` (which is
/// neither `Clone` nor necessarily transferable) must be replaced with a
/// plain-data snapshot.
///
/// Implements `CustomMessage` so it can be stored in `AgentMessage::Custom`
/// and round-trips faithfully through `serialize_custom_message` /
/// `deserialize_custom_message`.
///
/// Both fields are private; construct values with
/// [`SerializedCustomMessage::new`] or [`SerializedCustomMessage::from_custom`].
#[derive(Debug, Clone)]
pub struct SerializedCustomMessage {
    // Type name reported back through `CustomMessage::type_name`.
    name: String,
    // JSON payload reported back through `CustomMessage::to_json`.
    json: serde_json::Value,
}
206
207impl SerializedCustomMessage {
208    /// Create a new serialized custom message from a name and JSON payload.
209    #[must_use]
210    pub fn new(name: impl Into<String>, json: serde_json::Value) -> Self {
211        Self {
212            name: name.into(),
213            json,
214        }
215    }
216
217    /// Attempt to create a `SerializedCustomMessage` from a `dyn CustomMessage`.
218    ///
219    /// Returns `None` if the custom message does not support serialization.
220    #[must_use]
221    pub fn from_custom(msg: &dyn super::CustomMessage) -> Option<Self> {
222        Some(Self {
223            name: msg.type_name()?.to_string(),
224            json: msg.to_json()?,
225        })
226    }
227}
228
229impl super::CustomMessage for SerializedCustomMessage {
230    fn as_any(&self) -> &dyn std::any::Any {
231        self
232    }
233    fn type_name(&self) -> Option<&str> {
234        Some(&self.name)
235    }
236    fn to_json(&self) -> Option<serde_json::Value> {
237        Some(self.json.clone())
238    }
239    fn clone_box(&self) -> Option<Box<dyn super::CustomMessage>> {
240        Some(Box::new(self.clone()))
241    }
242}
243
244// ─── clone_messages_for_send ────────────────────────────────────────────────
245
246/// Snapshot a slice of [`AgentMessage`] into fully `Send + Clone`-safe values.
247///
248/// `Llm` variants are cloned directly. `Custom` variants are
249/// snapshot-serialized into [`SerializedCustomMessage`] wrappers so they can
250/// cross thread (`spawn_blocking`) or process (IPC) boundaries faithfully.
251///
252/// Custom messages that lack `type_name()` or `to_json()` are silently
253/// dropped — matching the existing behavior of `serialize_custom_message`.
254pub fn clone_messages_for_send(messages: &[AgentMessage]) -> Vec<AgentMessage> {
255    messages
256        .iter()
257        .filter_map(|m| match m {
258            AgentMessage::Llm(llm) => Some(AgentMessage::Llm(llm.clone())),
259            AgentMessage::Custom(custom) => {
260                let snapshot = SerializedCustomMessage::from_custom(custom.as_ref())?;
261                Some(AgentMessage::Custom(Box::new(snapshot)))
262            }
263        })
264        .collect()
265}
266
267// ─── Tests ──────────────────────────────────────────────────────────────────
268
// Unit tests for the message codec: batch serialize/restore, single-envelope
// restore, `SerializedCustomMessage` snapshots, and `clone_messages_for_send`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        AssistantMessage, ContentBlock, Cost, CustomMessage, StopReason, Usage, UserMessage,
    };

    // ── Test helpers ────────────────────────────────────────────────────────

    // Custom message that only implements `as_any`. It relies on the trait's
    // default `type_name`/`to_json`, which the tests below expect to make it
    // non-serializable.
    #[derive(Debug)]
    struct NonSerializableCustom;

    impl CustomMessage for NonSerializableCustom {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    // Serializable custom message carrying a single string tag.
    #[derive(Debug, Clone, PartialEq)]
    struct TaggedCustom {
        tag: String,
    }

    impl CustomMessage for TaggedCustom {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn type_name(&self) -> Option<&str> {
            Some("TaggedCustom")
        }
        fn to_json(&self) -> Option<serde_json::Value> {
            Some(serde_json::json!({ "tag": self.tag }))
        }
    }

    // Registry that knows how to rebuild `TaggedCustom` from its JSON payload;
    // a payload without a string `tag` field is rejected with "missing tag".
    fn tagged_registry() -> CustomMessageRegistry {
        let mut reg = CustomMessageRegistry::new();
        reg.register(
            "TaggedCustom",
            Box::new(|val: serde_json::Value| {
                let tag = val
                    .get("tag")
                    .and_then(|v| v.as_str())
                    .ok_or_else(|| "missing tag".to_string())?;
                Ok(Box::new(TaggedCustom {
                    tag: tag.to_string(),
                }) as Box<dyn CustomMessage>)
            }),
        );
        reg
    }

    // Build a user LLM message with a single text block.
    fn user_msg(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: text.to_string(),
            }],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    // Build an assistant LLM message with a single text block and default
    // usage/cost metadata.
    fn assistant_msg(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::Assistant(AssistantMessage {
            content: vec![ContentBlock::Text {
                text: text.to_string(),
            }],
            provider: "test".to_string(),
            model_id: "m".to_string(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        }))
    }

    // Build a serializable custom message carrying `tag`.
    fn custom_msg(tag: &str) -> AgentMessage {
        AgentMessage::Custom(Box::new(TaggedCustom {
            tag: tag.to_string(),
        }))
    }

    // Render a message as a short "<role>:<text-or-tag>" label so that
    // sequences can be compared with a single `assert_eq!`.
    fn message_label(msg: &AgentMessage) -> String {
        match msg {
            AgentMessage::Llm(LlmMessage::User(u)) => {
                format!("user:{}", ContentBlock::extract_text(&u.content))
            }
            AgentMessage::Llm(LlmMessage::Assistant(a)) => {
                format!("assistant:{}", ContentBlock::extract_text(&a.content))
            }
            AgentMessage::Custom(c) => {
                if let Some(json) = c.to_json() {
                    format!("custom:{}", json["tag"].as_str().unwrap_or("?"))
                } else {
                    "custom:?".to_string()
                }
            }
            // Any other LLM message variant is not used by these tests.
            _ => "other".to_string(),
        }
    }

    // ── serialize_messages ──────────────────────────────────────────────────

    // A non-serializable custom message contributes neither an envelope nor
    // an ordering slot.
    #[test]
    fn serialize_skips_non_serializable_custom() {
        let messages = vec![
            user_msg("hi"),
            AgentMessage::Custom(Box::new(NonSerializableCustom)),
            assistant_msg("hello"),
        ];

        let result = serialize_messages(&messages, "test");
        assert_eq!(result.llm_messages.len(), 2);
        assert!(result.custom_messages.is_empty());
        assert_eq!(result.message_order.len(), 2);
    }

    // Interleaved input yields one slot per message and envelopes in the
    // order the custom messages appeared.
    #[test]
    fn serialize_preserves_interleaved_order() {
        let messages = vec![
            user_msg("hello"),
            custom_msg("A"),
            assistant_msg("hi"),
            custom_msg("B"),
            user_msg("thanks"),
        ];

        let result = serialize_messages(&messages, "test");
        assert_eq!(result.llm_messages.len(), 3);
        assert_eq!(result.custom_messages.len(), 2);
        assert_eq!(result.message_order.len(), 5);

        // Verify envelope content
        assert_eq!(result.custom_messages[0]["type"], "TaggedCustom");
        assert_eq!(result.custom_messages[0]["data"]["tag"], "A");
        assert_eq!(result.custom_messages[1]["data"]["tag"], "B");
    }

    // ── restore_messages ───────────────────────────────────────────────────

    // Serialize followed by restore reproduces the original interleaving.
    #[test]
    fn roundtrip_preserves_order() {
        let registry = tagged_registry();
        let messages = vec![
            user_msg("hello"),
            custom_msg("A"),
            assistant_msg("hi"),
            custom_msg("B"),
            user_msg("thanks"),
        ];

        let serialized = serialize_messages(&messages, "test");
        let restored = restore_messages(
            &serialized.llm_messages,
            &serialized.custom_messages,
            &serialized.message_order,
            Some(&registry),
            "test",
        );

        let labels: Vec<String> = restored.iter().map(message_label).collect();
        assert_eq!(
            labels,
            vec![
                "user:hello",
                "custom:A",
                "assistant:hi",
                "custom:B",
                "user:thanks",
            ]
        );
    }

    // With no registry, custom slots are skipped and only LLM messages return.
    #[test]
    fn restore_without_registry_skips_custom() {
        let messages = vec![user_msg("hi"), custom_msg("skipped"), assistant_msg("ok")];

        let serialized = serialize_messages(&messages, "test");
        let restored = restore_messages(
            &serialized.llm_messages,
            &serialized.custom_messages,
            &serialized.message_order,
            None,
            "test",
        );

        assert_eq!(restored.len(), 2);
        let labels: Vec<String> = restored.iter().map(message_label).collect();
        assert_eq!(labels, vec!["user:hi", "assistant:ok"]);
    }

    // Empty ordering metadata triggers the legacy layout: LLM messages first,
    // custom messages appended afterwards.
    #[test]
    fn legacy_fallback_no_ordering() {
        let registry = tagged_registry();
        let llm = vec![LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: "hi".to_string(),
            }],
            timestamp: 0,
            cache_hint: None,
        })];
        let custom = vec![serde_json::json!({
            "type": "TaggedCustom",
            "data": { "tag": "legacy" }
        })];

        let restored = restore_messages(&llm, &custom, &[], Some(&registry), "test");
        assert_eq!(restored.len(), 2);
        let labels: Vec<String> = restored.iter().map(message_label).collect();
        assert_eq!(labels, vec!["user:hi", "custom:legacy"]);
    }

    // ── restore_single_custom ──────────────────────────────────────────────

    // A registered envelope is restored to a message reporting its type name.
    #[test]
    fn restore_single_custom_with_registry() {
        let registry = tagged_registry();
        let envelope = serde_json::json!({
            "type": "TaggedCustom",
            "data": { "tag": "single" }
        });

        let result = restore_single_custom(Some(&registry), &envelope).unwrap();
        assert!(result.is_some());
        let custom = result.unwrap();
        assert_eq!(custom.type_name(), Some("TaggedCustom"));
    }

    // Without a registry the call succeeds but yields no message.
    #[test]
    fn restore_single_custom_without_registry() {
        let envelope = serde_json::json!({ "type": "X", "data": {} });
        let result = restore_single_custom(None, &envelope).unwrap();
        assert!(result.is_none());
    }

    // ── SerializedCustomMessage ────────────────────────────────────────────

    // The snapshot mirrors the source message's type name and JSON payload.
    #[test]
    fn serialized_custom_message_from_custom() {
        let original = TaggedCustom {
            tag: "hello".to_string(),
        };
        let snapshot = SerializedCustomMessage::from_custom(&original).unwrap();
        assert_eq!(snapshot.type_name(), Some("TaggedCustom"));
        assert_eq!(snapshot.to_json().unwrap()["tag"], "hello");
    }

    // A message without serialization support cannot be snapshotted.
    #[test]
    fn serialized_custom_message_from_non_serializable() {
        let bare = NonSerializableCustom;
        assert!(SerializedCustomMessage::from_custom(&bare).is_none());
    }

    // ── clone_messages_for_send ────────────────────────────────────────────

    // LLM and serializable custom messages survive in order; the
    // non-serializable custom message is dropped.
    #[test]
    fn clone_for_send_preserves_all_serializable() {
        let messages = vec![
            user_msg("hello"),
            custom_msg("kept"),
            AgentMessage::Custom(Box::new(NonSerializableCustom)),
            assistant_msg("world"),
        ];

        let cloned = clone_messages_for_send(&messages);
        assert_eq!(cloned.len(), 3); // non-serializable custom dropped
        let labels: Vec<String> = cloned.iter().map(message_label).collect();
        assert_eq!(labels, vec!["user:hello", "custom:kept", "assistant:world"]);
    }

    // A snapshot produced by `clone_messages_for_send` can be serialized to
    // an envelope and restored to the original concrete type via a registry.
    #[test]
    fn clone_for_send_custom_roundtrips_through_registry() {
        let registry = tagged_registry();
        let messages = vec![custom_msg("roundtrip")];
        let cloned = clone_messages_for_send(&messages);
        assert_eq!(cloned.len(), 1);

        // The cloned custom message can be serialized and restored
        let envelope =
            serialize_custom_message(cloned[0].downcast_ref::<SerializedCustomMessage>().unwrap())
                .unwrap();
        let restored = deserialize_custom_message(&registry, &envelope).unwrap();
        assert_eq!(
            restored
                .as_any()
                .downcast_ref::<TaggedCustom>()
                .unwrap()
                .tag,
            "roundtrip"
        );
    }
}
559                .tag,
560            "roundtrip"
561        );
562    }
563}