swink_agent/types/
custom_message.rs1use std::collections::HashMap;
2use std::fmt;
3
4use serde::{Deserialize, Serialize};
5
6use super::LlmMessage;
7
8pub trait CustomMessage: Send + Sync + fmt::Debug + std::any::Any {
22 fn as_any(&self) -> &dyn std::any::Any;
24
25 fn type_name(&self) -> Option<&str> {
31 None
32 }
33
34 fn to_json(&self) -> Option<serde_json::Value> {
38 None
39 }
40
41 fn clone_box(&self) -> Option<Box<dyn CustomMessage>> {
47 None
48 }
49}
50
51pub type CustomMessageDeserializer =
53 Box<dyn Fn(serde_json::Value) -> Result<Box<dyn CustomMessage>, String> + Send + Sync>;
54
55pub struct CustomMessageRegistry {
60 deserializers: HashMap<String, CustomMessageDeserializer>,
61}
62
63impl CustomMessageRegistry {
64 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 deserializers: HashMap::new(),
69 }
70 }
71
72 pub fn register(
77 &mut self,
78 type_name: impl Into<String>,
79 deserializer: CustomMessageDeserializer,
80 ) {
81 self.deserializers.insert(type_name.into(), deserializer);
82 }
83
84 pub fn register_type<T>(&mut self, type_name: impl Into<String>)
89 where
90 T: CustomMessage + serde::de::DeserializeOwned + 'static,
91 {
92 self.deserializers.insert(
93 type_name.into(),
94 Box::new(|value| {
95 serde_json::from_value::<T>(value)
96 .map(|v| Box::new(v) as Box<dyn CustomMessage>)
97 .map_err(|e| e.to_string())
98 }),
99 );
100 }
101
102 pub fn deserialize(
109 &self,
110 type_name: &str,
111 value: serde_json::Value,
112 ) -> Result<Box<dyn CustomMessage>, String> {
113 let deser = self.deserializers.get(type_name).ok_or_else(|| {
114 format!("no deserializer registered for custom message type: {type_name}")
115 })?;
116 deser(value)
117 }
118
119 #[must_use]
121 pub fn has_type_name(&self, type_name: &str) -> bool {
122 self.deserializers.contains_key(type_name)
123 }
124}
125
126impl Default for CustomMessageRegistry {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132impl fmt::Debug for CustomMessageRegistry {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("CustomMessageRegistry")
135 .field(
136 "registered_types",
137 &self.deserializers.keys().collect::<Vec<_>>(),
138 )
139 .finish()
140 }
141}
142
143#[must_use]
148pub fn serialize_custom_message(msg: &dyn CustomMessage) -> Option<serde_json::Value> {
149 let type_name = msg.type_name()?;
150 let payload = msg.to_json()?;
151 Some(serde_json::json!({
152 "type": type_name,
153 "data": payload,
154 }))
155}
156
157pub fn deserialize_custom_message(
165 registry: &CustomMessageRegistry,
166 envelope: &serde_json::Value,
167) -> Result<Box<dyn CustomMessage>, String> {
168 let type_name = envelope
169 .get("type")
170 .and_then(|v| v.as_str())
171 .ok_or_else(|| "missing 'type' field in custom message envelope".to_string())?;
172 let data = envelope
173 .get("data")
174 .cloned()
175 .ok_or_else(|| "missing 'data' field in custom message envelope".to_string())?;
176 registry.deserialize(type_name, data)
177}
178
179#[allow(clippy::large_enum_variant)]
187pub enum AgentMessage {
188 Llm(LlmMessage),
190
191 Custom(Box<dyn CustomMessage>),
193}
194
195impl AgentMessage {
196 pub const fn cache_hint(&self) -> Option<&crate::context_cache::CacheHint> {
200 match self {
201 Self::Llm(msg) => match msg {
202 LlmMessage::User(m) => m.cache_hint.as_ref(),
203 LlmMessage::Assistant(m) => m.cache_hint.as_ref(),
204 LlmMessage::ToolResult(m) => m.cache_hint.as_ref(),
205 },
206 Self::Custom(_) => None,
207 }
208 }
209
210 pub const fn set_cache_hint(&mut self, hint: crate::context_cache::CacheHint) {
214 match self {
215 Self::Llm(msg) => match msg {
216 LlmMessage::User(m) => m.cache_hint = Some(hint),
217 LlmMessage::Assistant(m) => m.cache_hint = Some(hint),
218 LlmMessage::ToolResult(m) => m.cache_hint = Some(hint),
219 },
220 Self::Custom(_) => {}
221 }
222 }
223
224 pub const fn clear_cache_hint(&mut self) {
226 match self {
227 Self::Llm(msg) => match msg {
228 LlmMessage::User(m) => m.cache_hint = None,
229 LlmMessage::Assistant(m) => m.cache_hint = None,
230 LlmMessage::ToolResult(m) => m.cache_hint = None,
231 },
232 Self::Custom(_) => {}
233 }
234 }
235
236 pub fn downcast_ref<T: 'static>(&self) -> Result<&T, crate::error::DowncastError> {
241 match self {
242 Self::Custom(msg) => {
243 msg.as_any()
244 .downcast_ref::<T>()
245 .ok_or_else(|| crate::error::DowncastError {
246 expected: std::any::type_name::<T>(),
247 actual: msg
248 .type_name()
249 .map_or_else(|| format!("{msg:?}"), ToString::to_string),
250 })
251 }
252 Self::Llm(_) => Err(crate::error::DowncastError {
253 expected: std::any::type_name::<T>(),
254 actual: "LlmMessage".to_string(),
255 }),
256 }
257 }
258}
259
260impl fmt::Debug for AgentMessage {
261 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262 match self {
263 Self::Llm(msg) => f.debug_tuple("Llm").field(msg).finish(),
264 Self::Custom(msg) => f.debug_tuple("Custom").field(msg).finish(),
265 }
266 }
267}
268
269impl Serialize for AgentMessage {
270 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
271 match self {
272 Self::Llm(msg) => {
273 use serde::ser::SerializeMap;
274 let mut map = serializer.serialize_map(Some(2))?;
275 map.serialize_entry("kind", "llm")?;
276 map.serialize_entry("message", msg)?;
277 map.end()
278 }
279 Self::Custom(msg) => {
280 use serde::ser::SerializeMap;
281 let mut map = serializer.serialize_map(Some(2))?;
282 map.serialize_entry("kind", "custom")?;
283 let envelope = serialize_custom_message(msg.as_ref());
285 map.serialize_entry("message", &envelope)?;
286 map.end()
287 }
288 }
289 }
290}
291
292impl<'de> Deserialize<'de> for AgentMessage {
293 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
294 #[derive(Deserialize)]
295 struct Tagged {
296 kind: String,
297 message: serde_json::Value,
298 }
299
300 let tagged = Tagged::deserialize(deserializer)?;
301 match tagged.kind.as_str() {
302 "llm" => {
303 let msg: LlmMessage =
304 serde_json::from_value(tagged.message).map_err(serde::de::Error::custom)?;
305 Ok(Self::Llm(msg))
306 }
307 "custom" => Err(serde::de::Error::custom(
308 "cannot deserialize AgentMessage::Custom (requires runtime type info)",
309 )),
310 other => Err(serde::de::Error::unknown_variant(other, &["llm", "custom"])),
311 }
312 }
313}