synwire_core/agents/
plugin.rs1use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::marker::PhantomData;
6
7use serde_json::Value;
8
9use std::sync::Arc;
10
11use crate::BoxFuture;
12use crate::agents::directive::Directive;
13use crate::agents::streaming::AgentEvent;
14use crate::tools::Tool;
15
/// Compile-time key that names one plugin's slot in a [`PluginStateMap`].
///
/// Implement it on a marker type (e.g. a unit struct) and pick the concrete
/// state type stored under that slot via [`PluginStateKey::State`].
pub trait PluginStateKey: Send + Sync + 'static {
    /// String name of the slot. [`PluginStateMap::serialize_all`] uses it as the
    /// JSON object key, and [`PluginStateMap::register`] returns it as the error value.
    const KEY: &'static str;

    /// Concrete state type stored under this key.
    type State: Send + Sync + 'static;
}
26
27pub struct PluginHandle<P: PluginStateKey> {
32 _marker: PhantomData<P>,
33}
34
35impl<P: PluginStateKey> std::fmt::Debug for PluginHandle<P> {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("PluginHandle")
38 .field("key", &P::KEY)
39 .finish()
40 }
41}
42
43impl<P: PluginStateKey> Clone for PluginHandle<P> {
44 fn clone(&self) -> Self {
45 *self
46 }
47}
48
49impl<P: PluginStateKey> Copy for PluginHandle<P> {}
50
51struct PluginStateMeta {
53 value: Box<dyn Any + Send + Sync>,
54 serialize: fn(&dyn Any) -> Option<Value>,
55 key: &'static str,
56}
57
58fn serialize_fn<T: serde::Serialize + 'static>(v: &dyn Any) -> Option<Value> {
59 v.downcast_ref::<T>()
60 .and_then(|t| serde_json::to_value(t).ok())
61}
62
63#[derive(Default)]
68pub struct PluginStateMap {
69 entries: HashMap<TypeId, PluginStateMeta>,
70}
71
72impl std::fmt::Debug for PluginStateMap {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("PluginStateMap")
75 .field("len", &self.entries.len())
76 .finish()
77 }
78}
79
80impl PluginStateMap {
81 #[must_use]
83 pub fn new() -> Self {
84 Self::default()
85 }
86
87 pub fn register<P>(&mut self, state: P::State) -> Result<PluginHandle<P>, &'static str>
93 where
94 P: PluginStateKey,
95 P::State: serde::Serialize + 'static,
96 {
97 let id = TypeId::of::<P>();
98 if self.entries.contains_key(&id) {
99 return Err(P::KEY);
100 }
101
102 let _ = self.entries.insert(
103 id,
104 PluginStateMeta {
105 value: Box::new(state),
106 serialize: serialize_fn::<P::State>,
107 key: P::KEY,
108 },
109 );
110
111 Ok(PluginHandle {
112 _marker: PhantomData,
113 })
114 }
115
116 #[must_use]
118 pub fn get<P: PluginStateKey>(&self) -> Option<&P::State> {
119 self.entries
120 .get(&TypeId::of::<P>())
121 .and_then(|m| m.value.downcast_ref::<P::State>())
122 }
123
124 pub fn get_mut<P: PluginStateKey>(&mut self) -> Option<&mut P::State> {
126 self.entries
127 .get_mut(&TypeId::of::<P>())
128 .and_then(|m| m.value.downcast_mut::<P::State>())
129 }
130
131 pub fn insert<P: PluginStateKey>(&mut self, state: P::State)
133 where
134 P::State: serde::Serialize + 'static,
135 {
136 let _ = self.entries.insert(
137 TypeId::of::<P>(),
138 PluginStateMeta {
139 value: Box::new(state),
140 serialize: serialize_fn::<P::State>,
141 key: P::KEY,
142 },
143 );
144 }
145
146 #[must_use]
148 pub fn serialize_all(&self) -> Value {
149 let mut map = serde_json::Map::new();
150 for meta in self.entries.values() {
151 if let Some(v) = (meta.serialize)(meta.value.as_ref()) {
152 let _ = map.insert(meta.key.to_string(), v);
153 }
154 }
155 Value::Object(map)
156 }
157}
158
/// Input handed to [`Plugin::on_user_message`].
#[derive(Debug, Clone)]
pub struct PluginInput {
    /// Turn counter (presumably the index of the current conversation turn — confirm with the caller).
    pub turn: u32,
    /// Message text for this turn, if present.
    pub message: Option<String>,
}
167
168pub trait Plugin: Send + Sync {
173 fn name(&self) -> &str;
175
176 fn on_user_message<'a>(
178 &'a self,
179 _input: &'a PluginInput,
180 _state: &'a PluginStateMap,
181 ) -> BoxFuture<'a, Vec<Directive>> {
182 Box::pin(async { Vec::new() })
183 }
184
185 fn on_event<'a>(
187 &'a self,
188 _event: &'a AgentEvent,
189 _state: &'a PluginStateMap,
190 ) -> BoxFuture<'a, Vec<Directive>> {
191 Box::pin(async { Vec::new() })
192 }
193
194 fn before_run<'a>(&'a self, _state: &'a PluginStateMap) -> BoxFuture<'a, Vec<Directive>> {
196 Box::pin(async { Vec::new() })
197 }
198
199 fn after_run<'a>(&'a self, _state: &'a PluginStateMap) -> BoxFuture<'a, Vec<Directive>> {
201 Box::pin(async { Vec::new() })
202 }
203
204 fn signal_routes(&self) -> Vec<crate::agents::signal::SignalRoute> {
206 Vec::new()
207 }
208
209 fn tools(&self) -> Vec<Arc<dyn Tool>> {
215 Vec::new()
216 }
217}
218
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Sample plugin state: a single counter.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct CounterState {
        count: u32,
    }

    /// Marker key for `CounterState`, serialized under `"counter"`.
    struct CounterKey;

    impl PluginStateKey for CounterKey {
        const KEY: &'static str = "counter";
        type State = CounterState;
    }

    /// Sample plugin state: a single boolean flag.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct FlagState {
        enabled: bool,
    }

    /// Marker key for `FlagState`, serialized under `"flag"`.
    struct FlagKey;

    impl PluginStateKey for FlagKey {
        const KEY: &'static str = "flag";
        type State = FlagState;
    }

    /// `get` / `get_mut` hand back the concrete state type with no manual downcast.
    #[test]
    fn test_type_safe_access() {
        let mut states = PluginStateMap::new();
        let _handle = states
            .register::<CounterKey>(CounterState { count: 0 })
            .expect("register");

        let initial = states.get::<CounterKey>().expect("get");
        assert_eq!(initial.count, 0);

        let counter = states.get_mut::<CounterKey>().expect("get_mut");
        counter.count = 42;

        let updated = states.get::<CounterKey>().expect("get after mut");
        assert_eq!(updated.count, 42);
    }

    /// Mutating one plugin's state leaves the other plugins' state untouched.
    #[test]
    fn test_cross_plugin_isolation() {
        let mut states = PluginStateMap::new();
        let _ = states.register::<CounterKey>(CounterState { count: 10 });
        let _ = states.register::<FlagKey>(FlagState { enabled: true });

        assert!(states.get::<CounterKey>().is_some());
        assert!(states.get::<FlagKey>().is_some());

        states.get_mut::<CounterKey>().expect("mut").count = 99;

        let flag = states.get::<FlagKey>().expect("flag");
        assert!(flag.enabled);
    }

    /// Registering the same key type twice is rejected with that key's name.
    #[test]
    fn test_key_collision_detection() {
        let mut states = PluginStateMap::new();
        let _ = states
            .register::<CounterKey>(CounterState { count: 0 })
            .expect("first register");

        let second = states.register::<CounterKey>(CounterState { count: 1 });
        let rejected_key = second.expect_err("second register should fail");
        assert_eq!(rejected_key, CounterKey::KEY);
    }

    /// `serialize_all` emits each state under its key type's `KEY` string.
    #[test]
    fn test_serialization_round_trip() {
        let mut states = PluginStateMap::new();
        let _ = states.register::<CounterKey>(CounterState { count: 7 });
        let _ = states.register::<FlagKey>(FlagState { enabled: false });

        let json = states.serialize_all();
        assert_eq!(json["counter"]["count"], 7);
        assert_eq!(json["flag"]["enabled"], false);
    }
}