Skip to main content

swink_agent/types/
custom_message.rs

1use std::collections::HashMap;
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6use super::LlmMessage;
7
8// ─── Custom Messages ────────────────────────────────────────────────────────
9
10/// Trait for application-defined custom message types.
11///
12/// Allows downstream code to attach application-specific message types
13/// (e.g. notifications, artifacts) to the message history without modifying
14/// the harness.
15///
16/// ## Serialization
17///
18/// To support store/load of conversations containing custom messages, implement
19/// [`type_name`](Self::type_name) and [`to_json`](Self::to_json), then register
20/// a deserializer with [`CustomMessageRegistry`].
21pub trait CustomMessage: Send + Sync + fmt::Debug + std::any::Any {
22    /// Downcast helper. Returns `self` as `&dyn Any` for type-safe downcasting.
23    fn as_any(&self) -> &dyn std::any::Any;
24
25    /// A unique, stable identifier for this custom message type.
26    ///
27    /// Used as the discriminator when serializing. Must match the key
28    /// registered in [`CustomMessageRegistry`]. Returns `None` if
29    /// serialization is not supported.
30    fn type_name(&self) -> Option<&str> {
31        None
32    }
33
34    /// Serialize this custom message to a JSON value.
35    ///
36    /// Returns `None` if serialization is not supported (the default).
37    fn to_json(&self) -> Option<serde_json::Value> {
38        None
39    }
40
41    /// Clone this custom message into a new boxed trait object.
42    ///
43    /// Returns `None` if the underlying type does not support cloning
44    /// (the default). Implement this to preserve custom messages across
45    /// stream-driven state rebuilds (e.g. `handle_stream_event`).
46    fn clone_box(&self) -> Option<Box<dyn CustomMessage>> {
47        None
48    }
49}
50
51/// A function that deserializes a JSON value into a boxed [`CustomMessage`].
52pub type CustomMessageDeserializer =
53    Box<dyn Fn(serde_json::Value) -> Result<Box<dyn CustomMessage>, String> + Send + Sync>;
54
55/// Registry for deserializing [`CustomMessage`] types from JSON.
56///
57/// Each custom message type that supports serialization must register a
58/// deserializer keyed by its [`CustomMessage::type_name`].
59pub struct CustomMessageRegistry {
60    deserializers: HashMap<String, CustomMessageDeserializer>,
61}
62
63impl CustomMessageRegistry {
64    /// Create an empty registry.
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            deserializers: HashMap::new(),
69        }
70    }
71
72    /// Register a deserializer for a custom message type.
73    ///
74    /// The `type_name` must match the value returned by the corresponding
75    /// [`CustomMessage::type_name`] implementation.
76    pub fn register(
77        &mut self,
78        type_name: impl Into<String>,
79        deserializer: CustomMessageDeserializer,
80    ) {
81        self.deserializers.insert(type_name.into(), deserializer);
82    }
83
84    /// Convenience method: register a type that implements `serde::Deserialize`.
85    ///
86    /// Equivalent to calling [`register`](Self::register) with a closure that
87    /// deserializes via `serde_json::from_value`.
88    pub fn register_type<T>(&mut self, type_name: impl Into<String>)
89    where
90        T: CustomMessage + serde::de::DeserializeOwned + 'static,
91    {
92        self.deserializers.insert(
93            type_name.into(),
94            Box::new(|value| {
95                serde_json::from_value::<T>(value)
96                    .map(|v| Box::new(v) as Box<dyn CustomMessage>)
97                    .map_err(|e| e.to_string())
98            }),
99        );
100    }
101
102    /// Deserialize a custom message from its type name and JSON payload.
103    ///
104    /// # Errors
105    ///
106    /// Returns `Err` if no deserializer is registered for `type_name` or if
107    /// deserialization fails.
108    pub fn deserialize(
109        &self,
110        type_name: &str,
111        value: serde_json::Value,
112    ) -> Result<Box<dyn CustomMessage>, String> {
113        let deser = self.deserializers.get(type_name).ok_or_else(|| {
114            format!("no deserializer registered for custom message type: {type_name}")
115        })?;
116        deser(value)
117    }
118
119    /// Returns `true` if a deserializer is registered for `type_name`.
120    #[must_use]
121    pub fn has_type_name(&self, type_name: &str) -> bool {
122        self.deserializers.contains_key(type_name)
123    }
124}
125
126impl Default for CustomMessageRegistry {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132impl fmt::Debug for CustomMessageRegistry {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("CustomMessageRegistry")
135            .field(
136                "registered_types",
137                &self.deserializers.keys().collect::<Vec<_>>(),
138            )
139            .finish()
140    }
141}
142
143/// Serialize a [`CustomMessage`] into a portable JSON envelope.
144///
145/// Returns `None` if the message does not support serialization (i.e.
146/// `type_name()` or `to_json()` returns `None`).
147#[must_use]
148pub fn serialize_custom_message(msg: &dyn CustomMessage) -> Option<serde_json::Value> {
149    let type_name = msg.type_name()?;
150    let payload = msg.to_json()?;
151    Some(serde_json::json!({
152        "type": type_name,
153        "data": payload,
154    }))
155}
156
157/// Deserialize a [`CustomMessage`] from a JSON envelope produced by
158/// [`serialize_custom_message`].
159///
160/// # Errors
161///
162/// Returns `Err` if the envelope is malformed, the type is unknown, or
163/// deserialization fails.
164pub fn deserialize_custom_message(
165    registry: &CustomMessageRegistry,
166    envelope: &serde_json::Value,
167) -> Result<Box<dyn CustomMessage>, String> {
168    let type_name = envelope
169        .get("type")
170        .and_then(|v| v.as_str())
171        .ok_or_else(|| "missing 'type' field in custom message envelope".to_string())?;
172    let data = envelope
173        .get("data")
174        .cloned()
175        .ok_or_else(|| "missing 'data' field in custom message envelope".to_string())?;
176    registry.deserialize(type_name, data)
177}
178
179/// The top-level message type that wraps either an LLM message or a custom
180/// application-defined message.
181///
182/// Implements [`Serialize`] so events containing messages can be forwarded
183/// across process boundaries. `Llm` variants delegate to the derived impl;
184/// `Custom` variants serialize via [`serialize_custom_message`] (or as
185/// `null` if the custom message does not support serialization).
186#[allow(clippy::large_enum_variant)]
187pub enum AgentMessage {
188    /// A standard LLM message (user, assistant, or tool result).
189    Llm(LlmMessage),
190
191    /// A custom application-defined message.
192    Custom(Box<dyn CustomMessage>),
193}
194
195impl AgentMessage {
196    /// Get the cache hint for this message (if any).
197    ///
198    /// Returns `None` for `Custom` messages (they are never sent to the LLM).
199    pub const fn cache_hint(&self) -> Option<&crate::context_cache::CacheHint> {
200        match self {
201            Self::Llm(msg) => match msg {
202                LlmMessage::User(m) => m.cache_hint.as_ref(),
203                LlmMessage::Assistant(m) => m.cache_hint.as_ref(),
204                LlmMessage::ToolResult(m) => m.cache_hint.as_ref(),
205            },
206            Self::Custom(_) => None,
207        }
208    }
209
210    /// Set the cache hint on this message.
211    ///
212    /// No-op for `Custom` messages.
213    pub const fn set_cache_hint(&mut self, hint: crate::context_cache::CacheHint) {
214        match self {
215            Self::Llm(msg) => match msg {
216                LlmMessage::User(m) => m.cache_hint = Some(hint),
217                LlmMessage::Assistant(m) => m.cache_hint = Some(hint),
218                LlmMessage::ToolResult(m) => m.cache_hint = Some(hint),
219            },
220            Self::Custom(_) => {}
221        }
222    }
223
224    /// Clear the cache hint on this message.
225    pub const fn clear_cache_hint(&mut self) {
226        match self {
227            Self::Llm(msg) => match msg {
228                LlmMessage::User(m) => m.cache_hint = None,
229                LlmMessage::Assistant(m) => m.cache_hint = None,
230                LlmMessage::ToolResult(m) => m.cache_hint = None,
231            },
232            Self::Custom(_) => {}
233        }
234    }
235
236    /// Attempt to downcast the inner custom message to a concrete type.
237    ///
238    /// Returns `Ok(&T)` if this is a `Custom` variant and the inner type matches `T`.
239    /// Returns `Err(DowncastError)` if this is an `Llm` variant or the type does not match.
240    pub fn downcast_ref<T: 'static>(&self) -> Result<&T, crate::error::DowncastError> {
241        match self {
242            Self::Custom(msg) => {
243                msg.as_any()
244                    .downcast_ref::<T>()
245                    .ok_or_else(|| crate::error::DowncastError {
246                        expected: std::any::type_name::<T>(),
247                        actual: msg
248                            .type_name()
249                            .map_or_else(|| format!("{msg:?}"), ToString::to_string),
250                    })
251            }
252            Self::Llm(_) => Err(crate::error::DowncastError {
253                expected: std::any::type_name::<T>(),
254                actual: "LlmMessage".to_string(),
255            }),
256        }
257    }
258}
259
260impl fmt::Debug for AgentMessage {
261    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262        match self {
263            Self::Llm(msg) => f.debug_tuple("Llm").field(msg).finish(),
264            Self::Custom(msg) => f.debug_tuple("Custom").field(msg).finish(),
265        }
266    }
267}
268
269impl Serialize for AgentMessage {
270    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
271        match self {
272            Self::Llm(msg) => {
273                use serde::ser::SerializeMap;
274                let mut map = serializer.serialize_map(Some(2))?;
275                map.serialize_entry("kind", "llm")?;
276                map.serialize_entry("message", msg)?;
277                map.end()
278            }
279            Self::Custom(msg) => {
280                use serde::ser::SerializeMap;
281                let mut map = serializer.serialize_map(Some(2))?;
282                map.serialize_entry("kind", "custom")?;
283                // Use the existing envelope helper; falls back to null.
284                let envelope = serialize_custom_message(msg.as_ref());
285                map.serialize_entry("message", &envelope)?;
286                map.end()
287            }
288        }
289    }
290}
291
292impl<'de> Deserialize<'de> for AgentMessage {
293    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
294        #[derive(Deserialize)]
295        struct Tagged {
296            kind: String,
297            message: serde_json::Value,
298        }
299
300        let tagged = Tagged::deserialize(deserializer)?;
301        match tagged.kind.as_str() {
302            "llm" => {
303                let msg: LlmMessage =
304                    serde_json::from_value(tagged.message).map_err(serde::de::Error::custom)?;
305                Ok(Self::Llm(msg))
306            }
307            "custom" => Err(serde::de::Error::custom(
308                "cannot deserialize AgentMessage::Custom (requires runtime type info)",
309            )),
310            other => Err(serde::de::Error::unknown_variant(other, &["llm", "custom"])),
311        }
312    }
313}