Skip to main content

synwire_core/agents/
plugin.rs

1//! Plugin system with isolated state.
2
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::marker::PhantomData;
6
7use serde_json::Value;
8
9use std::sync::Arc;
10
11use crate::BoxFuture;
12use crate::agents::directive::Directive;
13use crate::agents::streaming::AgentEvent;
14use crate::tools::Tool;
15
/// Compile-time key naming one plugin's slot in a [`PluginStateMap`].
///
/// An implementor chooses the concrete [`State`](PluginStateKey::State) type
/// kept under it and the string [`KEY`](PluginStateKey::KEY) that labels that
/// state when the map is serialized.
pub trait PluginStateKey: 'static + Send + Sync {
    /// String label for this plugin's state in serialized output; expected to
    /// be unique across plugins.
    const KEY: &'static str;

    /// Concrete value stored in the map under this key.
    type State: 'static + Send + Sync;
}
26
27/// Zero-sized proof token returned when a plugin state is registered.
28///
29/// Holding a `PluginHandle<P>` proves that the plugin `P` has been registered
30/// in the associated `PluginStateMap`.
31pub struct PluginHandle<P: PluginStateKey> {
32    _marker: PhantomData<P>,
33}
34
35impl<P: PluginStateKey> std::fmt::Debug for PluginHandle<P> {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("PluginHandle")
38            .field("key", &P::KEY)
39            .finish()
40    }
41}
42
43impl<P: PluginStateKey> Clone for PluginHandle<P> {
44    fn clone(&self) -> Self {
45        *self
46    }
47}
48
49impl<P: PluginStateKey> Copy for PluginHandle<P> {}
50
51/// Type-erased serializer stored alongside plugin state.
52struct PluginStateMeta {
53    value: Box<dyn Any + Send + Sync>,
54    serialize: fn(&dyn Any) -> Option<Value>,
55    key: &'static str,
56}
57
58fn serialize_fn<T: serde::Serialize + 'static>(v: &dyn Any) -> Option<Value> {
59    v.downcast_ref::<T>()
60        .and_then(|t| serde_json::to_value(t).ok())
61}
62
63/// Type-keyed map for plugin state with serialization support.
64///
65/// Provides type-safe access keyed by [`PluginStateKey`] implementations.
66/// Plugins cannot access each other's state — the type key enforces isolation.
67#[derive(Default)]
68pub struct PluginStateMap {
69    entries: HashMap<TypeId, PluginStateMeta>,
70}
71
72impl std::fmt::Debug for PluginStateMap {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("PluginStateMap")
75            .field("len", &self.entries.len())
76            .finish()
77    }
78}
79
80impl PluginStateMap {
81    /// Create an empty map.
82    #[must_use]
83    pub fn new() -> Self {
84        Self::default()
85    }
86
87    /// Register plugin state, returning a proof handle.
88    ///
89    /// # Errors
90    ///
91    /// Returns the key string if a plugin with the same `TypeId` is already registered.
92    pub fn register<P>(&mut self, state: P::State) -> Result<PluginHandle<P>, &'static str>
93    where
94        P: PluginStateKey,
95        P::State: serde::Serialize + 'static,
96    {
97        let id = TypeId::of::<P>();
98        if self.entries.contains_key(&id) {
99            return Err(P::KEY);
100        }
101
102        let _ = self.entries.insert(
103            id,
104            PluginStateMeta {
105                value: Box::new(state),
106                serialize: serialize_fn::<P::State>,
107                key: P::KEY,
108            },
109        );
110
111        Ok(PluginHandle {
112            _marker: PhantomData,
113        })
114    }
115
116    /// Get an immutable reference to plugin state.
117    #[must_use]
118    pub fn get<P: PluginStateKey>(&self) -> Option<&P::State> {
119        self.entries
120            .get(&TypeId::of::<P>())
121            .and_then(|m| m.value.downcast_ref::<P::State>())
122    }
123
124    /// Get a mutable reference to plugin state.
125    pub fn get_mut<P: PluginStateKey>(&mut self) -> Option<&mut P::State> {
126        self.entries
127            .get_mut(&TypeId::of::<P>())
128            .and_then(|m| m.value.downcast_mut::<P::State>())
129    }
130
131    /// Insert or replace plugin state.
132    pub fn insert<P: PluginStateKey>(&mut self, state: P::State)
133    where
134        P::State: serde::Serialize + 'static,
135    {
136        let _ = self.entries.insert(
137            TypeId::of::<P>(),
138            PluginStateMeta {
139                value: Box::new(state),
140                serialize: serialize_fn::<P::State>,
141                key: P::KEY,
142            },
143        );
144    }
145
146    /// Serialize all plugin state to a JSON object keyed by plugin key strings.
147    #[must_use]
148    pub fn serialize_all(&self) -> Value {
149        let mut map = serde_json::Map::new();
150        for meta in self.entries.values() {
151            if let Some(v) = (meta.serialize)(meta.value.as_ref()) {
152                let _ = map.insert(meta.key.to_string(), v);
153            }
154        }
155        Value::Object(map)
156    }
157}
158
/// Arguments handed to plugin lifecycle hooks such as
/// [`Plugin::on_user_message`].
#[derive(Clone, Debug)]
pub struct PluginInput {
    /// Index of the current conversation turn.
    pub turn: u32,
    /// Text of the user's message, if one is present.
    pub message: Option<String>,
}
167
168/// Plugin lifecycle trait.
169///
170/// All methods have default no-op implementations so plugins only need to
171/// override the hooks they care about.
172pub trait Plugin: Send + Sync {
173    /// Plugin name (used for debugging and logging).
174    fn name(&self) -> &str;
175
176    /// Called when a user message arrives.
177    fn on_user_message<'a>(
178        &'a self,
179        _input: &'a PluginInput,
180        _state: &'a PluginStateMap,
181    ) -> BoxFuture<'a, Vec<Directive>> {
182        Box::pin(async { Vec::new() })
183    }
184
185    /// Called when an agent event is emitted.
186    fn on_event<'a>(
187        &'a self,
188        _event: &'a AgentEvent,
189        _state: &'a PluginStateMap,
190    ) -> BoxFuture<'a, Vec<Directive>> {
191        Box::pin(async { Vec::new() })
192    }
193
194    /// Called before each agent run loop iteration.
195    fn before_run<'a>(&'a self, _state: &'a PluginStateMap) -> BoxFuture<'a, Vec<Directive>> {
196        Box::pin(async { Vec::new() })
197    }
198
199    /// Called after each agent run loop iteration.
200    fn after_run<'a>(&'a self, _state: &'a PluginStateMap) -> BoxFuture<'a, Vec<Directive>> {
201        Box::pin(async { Vec::new() })
202    }
203
204    /// Signal routes contributed by this plugin.
205    fn signal_routes(&self) -> Vec<crate::agents::signal::SignalRoute> {
206        Vec::new()
207    }
208
209    /// Tools contributed by this plugin.
210    ///
211    /// Called once during agent construction. The returned tools are merged
212    /// into the agent's tool registry alongside any tools provided directly
213    /// via `Agent::with_tools`. Tool names must not conflict.
214    fn tools(&self) -> Vec<Arc<dyn Tool>> {
215        Vec::new()
216    }
217}
218
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct CounterState {
        count: u32,
    }

    struct CounterKey;

    impl PluginStateKey for CounterKey {
        type State = CounterState;
        const KEY: &'static str = "counter";
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct FlagState {
        enabled: bool,
    }

    struct FlagKey;

    impl PluginStateKey for FlagKey {
        type State = FlagState;
        const KEY: &'static str = "flag";
    }

    #[test]
    fn test_type_safe_access() {
        let mut map = PluginStateMap::new();
        let _handle = map
            .register::<CounterKey>(CounterState { count: 0 })
            .expect("register");

        let state = map.get::<CounterKey>().expect("get");
        assert_eq!(state.count, 0);

        map.get_mut::<CounterKey>().expect("get_mut").count = 42;
        assert_eq!(map.get::<CounterKey>().expect("get after mut").count, 42);
    }

    #[test]
    fn test_cross_plugin_isolation() {
        let mut map = PluginStateMap::new();
        // Fail at setup rather than letting a rejected registration surface
        // later as a confusing `None`.
        let _counter = map
            .register::<CounterKey>(CounterState { count: 10 })
            .expect("register counter");
        let _flag = map
            .register::<FlagKey>(FlagState { enabled: true })
            .expect("register flag");

        assert!(map.get::<CounterKey>().is_some());
        assert!(map.get::<FlagKey>().is_some());

        map.get_mut::<CounterKey>().expect("mut").count = 99;
        assert!(map.get::<FlagKey>().expect("flag").enabled);
        assert_eq!(map.get::<CounterKey>().expect("counter").count, 99);
    }

    #[test]
    fn test_key_collision_detection() {
        let mut map = PluginStateMap::new();
        let _ = map
            .register::<CounterKey>(CounterState { count: 0 })
            .expect("first register");

        let err = map
            .register::<CounterKey>(CounterState { count: 1 })
            .expect_err("second register should fail");
        assert_eq!(err, CounterKey::KEY);

        // A rejected registration must leave the existing state untouched.
        assert_eq!(map.get::<CounterKey>().expect("get").count, 0);
    }

    #[test]
    fn test_serialization_round_trip() {
        let mut map = PluginStateMap::new();
        let _counter = map
            .register::<CounterKey>(CounterState { count: 7 })
            .expect("register counter");
        let _flag = map
            .register::<FlagKey>(FlagState { enabled: false })
            .expect("register flag");

        let serialized = map.serialize_all();
        assert_eq!(serialized["counter"]["count"], 7);
        assert_eq!(serialized["flag"]["enabled"], false);
    }

    #[test]
    fn test_insert_replaces_state() {
        let mut map = PluginStateMap::new();
        map.insert::<CounterKey>(CounterState { count: 1 });
        map.insert::<CounterKey>(CounterState { count: 2 });

        assert_eq!(map.get::<CounterKey>().expect("get").count, 2);
        assert_eq!(map.serialize_all()["counter"]["count"], 2);
    }

    #[test]
    fn test_missing_state_is_none() {
        let mut map = PluginStateMap::new();
        assert!(map.get::<CounterKey>().is_none());
        assert!(map.get_mut::<FlagKey>().is_none());
        assert_eq!(
            map.serialize_all(),
            serde_json::Value::Object(serde_json::Map::new())
        );
    }
}