Skip to main content

swink_agent/
agent.rs

1//! Stateful public API wrapper over the agent loop.
2//!
3//! The [`Agent`] struct owns conversation state, configuration, and queue
4//! management. It provides three invocation modes (`prompt_stream`,
5//! `prompt_async`, `prompt_sync`), structured output extraction, steering and
6//! follow-up queues, and an observer/subscriber pattern for event dispatch.
7//!
8//! Configuration is split into [`AgentOptions`] (defined in [`crate::agent_options`])
9//! and subscription management is in [`crate::agent_subscriptions`].
10
11#[path = "agent/checkpointing.rs"]
12mod checkpointing;
13#[path = "agent/control.rs"]
14mod control;
15#[path = "agent/events.rs"]
16mod events;
17#[path = "agent/invoke.rs"]
18mod invoke;
19#[path = "agent/mutation.rs"]
20mod mutation;
21#[path = "agent/queueing.rs"]
22mod queueing;
23#[path = "agent/state_updates.rs"]
24mod state_updates;
25#[path = "agent/structured_output.rs"]
26mod structured_output;
27
28use std::collections::{HashSet, VecDeque};
29use std::sync::atomic::{AtomicBool, AtomicU64};
30use std::sync::{Arc, Mutex};
31
32use tokio::sync::Notify;
33use tokio_util::sync::CancellationToken;
34
35use crate::agent_id::AgentId;
36use crate::agent_options::{
37    ApproveToolArc, AsyncTransformContextArc, CheckpointStoreArc, ConvertToLlmFn, GetApiKeyArc,
38    TransformContextArc,
39};
40use crate::agent_subscriptions::ListenerRegistry;
41use crate::error::AgentError;
42use crate::message_provider::MessageProvider;
43use crate::retry::RetryStrategy;
44use crate::stream::{StreamFn, StreamOptions};
45use crate::tool::{AgentTool, ApprovalMode};
46use crate::types::{AgentMessage, LlmMessage, ModelSpec};
47
48// Re-export so `lib.rs` can still do `pub use agent::{AgentOptions, SubscriptionId, ...}`.
49pub use crate::agent_options::{AgentOptions, DEFAULT_PLAN_MODE_ADDENDUM};
50pub use crate::agent_subscriptions::SubscriptionId;
51
52// ─── Enums / modes ───────────────────────────────────────────────────────────
53
/// Strategy for draining queued steering messages.
///
/// The default is [`SteeringMode::OneAtATime`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SteeringMode {
    /// Take every pending steering message in a single drain.
    All,
    /// Take a single steering message on each poll (the default).
    #[default]
    OneAtATime,
}
64
/// Strategy for draining queued follow-up messages.
///
/// The default is [`FollowUpMode::OneAtATime`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FollowUpMode {
    /// Take every pending follow-up message in a single drain.
    All,
    /// Take a single follow-up message on each poll (the default).
    #[default]
    OneAtATime,
}
75
76// ─── AgentState ──────────────────────────────────────────────────────────────
77
78/// Observable state of the agent.
79pub struct AgentState {
80    /// The system prompt sent to the LLM.
81    pub system_prompt: String,
82    /// The model specification.
83    pub model: ModelSpec,
84    /// Available tools.
85    pub tools: Vec<Arc<dyn AgentTool>>,
86    /// Full conversation history.
87    pub messages: Vec<AgentMessage>,
88    /// Whether the agent loop is currently executing.
89    pub is_running: bool,
90    /// The message currently being streamed (if any).
91    pub stream_message: Option<AgentMessage>,
92    /// Tool call IDs that are currently executing.
93    pub pending_tool_calls: HashSet<String>,
94    /// Last error from a run, if any.
95    pub error: Option<String>,
96    /// Available model specifications for model cycling.
97    pub available_models: Vec<ModelSpec>,
98}
99
100impl std::fmt::Debug for AgentState {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct("AgentState")
103            .field("system_prompt", &self.system_prompt)
104            .field("model", &self.model)
105            .field("tools", &format_args!("[{} tool(s)]", self.tools.len()))
106            .field("messages", &self.messages)
107            .field("is_running", &self.is_running)
108            .field("stream_message", &self.stream_message)
109            .field("pending_tool_calls", &self.pending_tool_calls)
110            .field("error", &self.error)
111            .field("available_models", &self.available_models)
112            .finish()
113    }
114}
115
116// ─── Helpers ─────────────────────────────────────────────────────────────────
117
118/// Default message converter: pass LLM messages through, drop custom messages.
119///
120/// This is the standard converter for most use cases. Custom messages are
121/// filtered out since they are not meant to be sent to the LLM provider.
122pub fn default_convert(msg: &AgentMessage) -> Option<LlmMessage> {
123    match msg {
124        AgentMessage::Llm(llm) => Some(llm.clone()),
125        AgentMessage::Custom(_) => None,
126    }
127}
128
129type ModelStreamRegistry = Vec<(ModelSpec, Arc<dyn StreamFn>)>;
130
131fn available_models_and_stream_fns(
132    options: &AgentOptions,
133) -> (Vec<ModelSpec>, ModelStreamRegistry) {
134    let primary_model = options.model.clone();
135    let primary_stream_fn = Arc::clone(&options.stream_fn);
136    let mut available_models = vec![options.model.clone()];
137    available_models.extend(
138        options
139            .available_models
140            .iter()
141            .map(|(model, _): &(ModelSpec, _)| model.clone()),
142    );
143    let mut model_stream_fns = vec![(primary_model, primary_stream_fn)];
144    model_stream_fns.extend(
145        options
146            .available_models
147            .iter()
148            .map(|(model, stream_fn): &(ModelSpec, _)| (model.clone(), Arc::clone(stream_fn))),
149    );
150
151    (available_models, model_stream_fns)
152}
153
154fn assert_unique_tool_names(tools: &[Arc<dyn AgentTool>]) {
155    let mut seen = HashSet::with_capacity(tools.len());
156    let mut duplicates = Vec::new();
157
158    for tool in tools {
159        let name = tool.name();
160        if !seen.insert(name.to_owned()) {
161            duplicates.push(name.to_owned());
162        }
163    }
164
165    if !duplicates.is_empty() {
166        duplicates.sort();
167        duplicates.dedup();
168        panic!(
169            "duplicate tool names are not allowed after composition: {}",
170            duplicates.join(", ")
171        );
172    }
173}
174
175#[cfg(feature = "plugins")]
176fn dispatch_plugin_on_init(agent: &Agent) {
177    // Dispatch on_init to each plugin in priority order (already sorted).
178    // Panics are caught and logged — the plugin's other contributions
179    // (policies, tools, event observers) remain active.
180    for plugin in &agent.plugins {
181        let name = plugin.name().to_owned();
182        let plugin_ref = Arc::clone(plugin);
183        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
184            plugin_ref.on_init(agent);
185        }));
186        if let Err(cause) = result {
187            let msg = cause
188                .downcast_ref::<&str>()
189                .map(|s| (*s).to_owned())
190                .or_else(|| cause.downcast_ref::<String>().cloned())
191                .unwrap_or_else(|| "unknown panic".to_owned());
192            tracing::warn!(plugin = %name, error = %msg, "plugin on_init panicked");
193        }
194    }
195}
196
197#[cfg(not(feature = "plugins"))]
198const fn dispatch_plugin_on_init(_agent: &Agent) {}
199// ─── Agent ───────────────────────────────────────────────────────────────────
200
201/// Stateful wrapper over the agent loop.
202///
203/// Owns conversation history, configuration, steering/follow-up queues, and
204/// subscriber callbacks. Provides prompt, continue, and structured output
205/// invocation modes.
206pub struct Agent {
207    // ── Identity ──
208    id: AgentId,
209
210    // ── Public-facing state ──
211    state: AgentState,
212
213    // ── Private fields ──
214    steering_queue: Arc<Mutex<VecDeque<AgentMessage>>>,
215    follow_up_queue: Arc<Mutex<VecDeque<AgentMessage>>>,
216    listeners: ListenerRegistry,
217    abort_controller: Option<CancellationToken>,
218    steering_mode: SteeringMode,
219    follow_up_mode: FollowUpMode,
220    stream_fn: Arc<dyn StreamFn>,
221    convert_to_llm: ConvertToLlmFn,
222    transform_context: Option<TransformContextArc>,
223    get_api_key: Option<GetApiKeyArc>,
224    retry_strategy: Arc<dyn RetryStrategy>,
225    stream_options: StreamOptions,
226    structured_output_max_retries: usize,
227    idle_notify: Arc<Notify>,
228    in_flight_llm_messages: Option<Vec<AgentMessage>>,
229    in_flight_messages: Option<Vec<AgentMessage>>,
230    pending_message_snapshot: Arc<crate::pause_state::PendingMessageSnapshot>,
231    loop_context_snapshot: Arc<crate::pause_state::LoopContextSnapshot>,
232    approve_tool: Option<ApproveToolArc>,
233    approval_mode: ApprovalMode,
234    pre_turn_policies: Vec<Arc<dyn crate::policy::PreTurnPolicy>>,
235    pre_dispatch_policies: Vec<Arc<dyn crate::policy::PreDispatchPolicy>>,
236    post_turn_policies: Vec<Arc<dyn crate::policy::PostTurnPolicy>>,
237    post_loop_policies: Vec<Arc<dyn crate::policy::PostLoopPolicy>>,
238    /// Extra `model/stream_fn` pairs for model cycling.
239    model_stream_fns: Vec<(ModelSpec, Arc<dyn StreamFn>)>,
240    /// Event forwarders that receive cloned events after listener dispatch.
241    event_forwarders: Vec<crate::event_forwarder::EventForwarderFn>,
242    /// Optional async context transformer.
243    async_transform_context: Option<AsyncTransformContextArc>,
244    /// Optional checkpoint store.
245    checkpoint_store: Option<CheckpointStoreArc>,
246    /// Optional registry for decoding persisted custom messages on restore.
247    pub(crate) custom_message_registry: Option<Arc<crate::types::CustomMessageRegistry>>,
248    /// Optional metrics collector.
249    metrics_collector: Option<Arc<dyn crate::metrics::MetricsCollector>>,
250    /// Optional model fallback chain.
251    fallback: Option<crate::fallback::ModelFallback>,
252    /// Optional external message provider.
253    external_message_provider: Option<Arc<dyn MessageProvider>>,
254    /// Tool execution policy.
255    tool_execution_policy: crate::tool_execution_policy::ToolExecutionPolicy,
256    /// Optional plan mode addendum (falls back to `DEFAULT_PLAN_MODE_ADDENDUM`).
257    plan_mode_addendum: Option<String>,
258    /// Session key-value state store shared with tools and policies.
259    session_state: Arc<std::sync::RwLock<crate::SessionState>>,
260    /// Optional credential resolver for tool authentication.
261    credential_resolver: Option<Arc<dyn crate::credential::CredentialResolver>>,
262    /// Optional context caching configuration.
263    cache_config: Option<crate::context_cache::CacheConfig>,
264    /// Optional dynamic system prompt.
265    dynamic_system_prompt: Option<Arc<dyn Fn() -> String + Send + Sync>>,
266    /// Shared flag: true while a spawned loop task is active. Set to false by
267    /// the `LoopGuardStream` wrapper on drop or by `collect_stream`/`AgentEnd`.
268    /// Used by `check_not_running` and `wait_for_idle` as the ground truth for
269    /// whether a loop is still in progress (instead of `state.is_running` which
270    /// may lag on the stream-drop path).
271    loop_active: Arc<AtomicBool>,
272    /// Monotonically increasing counter bumped on each `start_loop`. Prevents a
273    /// stale `LoopGuardStream` from clearing `loop_active` for a newer run.
274    loop_generation: Arc<AtomicU64>,
275    /// Registered plugins retained for runtime introspection (priority-sorted).
276    #[cfg(feature = "plugins")]
277    plugins: Vec<Arc<dyn crate::plugin::Plugin>>,
278    /// Optional agent name for transfer chain safety enforcement.
279    #[allow(clippy::struct_field_names)]
280    agent_name: Option<String>,
281    /// Optional transfer chain state carried from a previous handoff.
282    transfer_chain: Option<crate::transfer::TransferChain>,
283}
284
285impl Agent {
286    /// Create a new agent from the given options.
287    #[must_use]
288    pub fn new(options: AgentOptions) -> Self {
289        // Merge plugin contributions (policies, tools, event observers) into options.
290        #[cfg(feature = "plugins")]
291        let options = merge_plugin_contributions(options);
292
293        assert_unique_tool_names(&options.tools);
294
295        // Compute the effective system prompt before partial moves.
296        let effective_prompt = options.effective_system_prompt().to_owned();
297        let (available_models, model_stream_fns) = available_models_and_stream_fns(&options);
298
299        // If a custom token counter is provided and no custom transform_context
300        // was set, rebuild the default SlidingWindowTransformer with the counter.
301        let transform_context = match (options.token_counter, options.transform_context) {
302            (Some(counter), None) => Some(Arc::new(
303                crate::context_transformer::SlidingWindowTransformer::new(100_000, 50_000, 2)
304                    .with_token_counter(counter),
305            ) as TransformContextArc),
306            (_, tc) => tc,
307        };
308
309        let agent = Self {
310            id: AgentId::next(),
311            state: AgentState {
312                system_prompt: effective_prompt,
313                model: options.model,
314                tools: options.tools,
315                messages: Vec::new(),
316                is_running: false,
317                stream_message: None,
318                pending_tool_calls: HashSet::new(),
319                error: None,
320                available_models,
321            },
322            steering_queue: Arc::new(Mutex::new(VecDeque::new())),
323            follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
324            listeners: ListenerRegistry::new(),
325            abort_controller: None,
326            steering_mode: options.steering_mode,
327            follow_up_mode: options.follow_up_mode,
328            stream_fn: options.stream_fn,
329            convert_to_llm: options.convert_to_llm,
330            transform_context,
331            get_api_key: options.get_api_key,
332            retry_strategy: Arc::from(options.retry_strategy),
333            stream_options: options.stream_options,
334            structured_output_max_retries: options.structured_output_max_retries,
335            idle_notify: Arc::new(Notify::new()),
336            in_flight_llm_messages: None,
337            in_flight_messages: None,
338            pending_message_snapshot: Arc::new(
339                crate::pause_state::PendingMessageSnapshot::default(),
340            ),
341            loop_context_snapshot: Arc::new(crate::pause_state::LoopContextSnapshot::default()),
342            approve_tool: options.approve_tool,
343            approval_mode: options.approval_mode,
344            pre_turn_policies: options.pre_turn_policies,
345            pre_dispatch_policies: options.pre_dispatch_policies,
346            post_turn_policies: options.post_turn_policies,
347            post_loop_policies: options.post_loop_policies,
348            model_stream_fns,
349            event_forwarders: options.event_forwarders,
350            async_transform_context: options.async_transform_context,
351            checkpoint_store: options.checkpoint_store,
352            custom_message_registry: options.custom_message_registry,
353            metrics_collector: options.metrics_collector,
354            fallback: options.fallback,
355            external_message_provider: options.external_message_provider,
356            tool_execution_policy: options.tool_execution_policy,
357            plan_mode_addendum: options.plan_mode_addendum,
358            session_state: Arc::new(std::sync::RwLock::new(
359                options.session_state.unwrap_or_default(),
360            )),
361            credential_resolver: options.credential_resolver,
362            cache_config: options.cache_config,
363            dynamic_system_prompt: options.dynamic_system_prompt.map(Arc::from),
364            loop_active: Arc::new(AtomicBool::new(false)),
365            loop_generation: Arc::new(AtomicU64::new(0)),
366            #[cfg(feature = "plugins")]
367            plugins: options.plugins,
368            agent_name: options.agent_name,
369            transfer_chain: options.transfer_chain,
370        };
371
372        dispatch_plugin_on_init(&agent);
373
374        agent
375    }
376
377    /// Returns this agent's unique identifier.
378    #[must_use]
379    pub const fn id(&self) -> AgentId {
380        self.id
381    }
382
383    /// Access the current agent state.
384    ///
385    /// Note: [`AgentState::is_running`] may lag behind the true loop lifecycle
386    /// after a stream is dropped. Use [`Agent::is_running`] for an accurate
387    /// real-time check.
388    #[must_use]
389    pub const fn state(&self) -> &AgentState {
390        &self.state
391    }
392
393    /// Returns whether a loop is currently active.
394    ///
395    /// This reads an atomic flag that is cleared immediately when the event
396    /// stream is dropped or drained to `AgentEnd`, making it accurate even in
397    /// the window between dropping a stream and the next `&mut self` call.
398    #[must_use]
399    pub fn is_running(&self) -> bool {
400        self.loop_active.load(std::sync::atomic::Ordering::Acquire)
401    }
402
403    /// Access the session key-value state (thread-safe, shared reference).
404    #[must_use]
405    pub const fn session_state(&self) -> &Arc<std::sync::RwLock<crate::SessionState>> {
406        &self.session_state
407    }
408
409    /// Access the custom message registry, if one was configured.
410    ///
411    /// Useful for passing to `SessionStore::load` (from `swink-agent-memory`) so that
412    /// persisted custom messages are deserialized instead of silently dropped.
413    #[must_use]
414    pub fn custom_message_registry(&self) -> Option<&crate::types::CustomMessageRegistry> {
415        self.custom_message_registry.as_deref()
416    }
417
418    /// Returns all registered plugins sorted by priority (highest first).
419    #[cfg(feature = "plugins")]
420    #[must_use]
421    pub fn plugins(&self) -> &[Arc<dyn crate::plugin::Plugin>] {
422        &self.plugins
423    }
424
425    /// Look up a registered plugin by name.
426    #[cfg(feature = "plugins")]
427    #[must_use]
428    pub fn plugin(&self, name: &str) -> Option<&Arc<dyn crate::plugin::Plugin>> {
429        self.plugins.iter().find(|p| p.name() == name)
430    }
431}
432
433impl std::fmt::Debug for Agent {
434    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435        f.debug_struct("Agent")
436            .field("state", &self.state)
437            .field("steering_mode", &self.steering_mode)
438            .field("follow_up_mode", &self.follow_up_mode)
439            .field(
440                "listeners",
441                &format_args!("{} listener(s)", self.listeners.len()),
442            )
443            .field("is_abort_active", &self.abort_controller.is_some())
444            .finish_non_exhaustive()
445    }
446}
447
448// ─── Plugin merge ───────────────────────────────────────────────────────────
449
/// Fold plugin-provided policies, tools and event observers into `options`.
///
/// Plugins are first sorted by descending priority. The sort is stable, so
/// plugins with equal priority keep their registration order. Walking the
/// plugins in that order:
/// - each policy kind gets the plugin policies placed ahead of the
///   directly-registered ones;
/// - plugin tools are wrapped in `crate::plugin::NamespacedTool` and appended
///   after the directly-registered tools;
/// - each plugin's `on_event` is wrapped as an event forwarder placed ahead of
///   the directly-registered forwarders.
#[cfg(feature = "plugins")]
fn merge_plugin_contributions(mut options: AgentOptions) -> AgentOptions {
    /// Move `front` ahead of whatever `target` already holds.
    fn prepend<T>(target: &mut Vec<T>, mut front: Vec<T>) {
        front.append(target);
        *target = front;
    }

    options
        .plugins
        .sort_by_key(|plugin| std::cmp::Reverse(plugin.priority()));

    let mut pre_turn: Vec<Arc<dyn crate::policy::PreTurnPolicy>> = Vec::new();
    let mut pre_dispatch: Vec<Arc<dyn crate::policy::PreDispatchPolicy>> = Vec::new();
    let mut post_turn: Vec<Arc<dyn crate::policy::PostTurnPolicy>> = Vec::new();
    let mut post_loop: Vec<Arc<dyn crate::policy::PostLoopPolicy>> = Vec::new();
    let mut namespaced_tools: Vec<Arc<dyn AgentTool>> = Vec::new();
    let mut forwarders: Vec<crate::event_forwarder::EventForwarderFn> = Vec::new();

    for plugin in &options.plugins {
        pre_turn.extend(plugin.pre_turn_policies());
        pre_dispatch.extend(plugin.pre_dispatch_policies());
        post_turn.extend(plugin.post_turn_policies());
        post_loop.extend(plugin.post_loop_policies());

        let namespace = plugin.name().to_owned();
        for tool in plugin.tools() {
            namespaced_tools.push(Arc::new(crate::plugin::NamespacedTool::new(
                &namespace,
                tool,
            )));
        }

        // The forwarder owns its own handle so it can outlive this borrow.
        let observer = Arc::clone(plugin);
        forwarders.push(Arc::new(move |event: crate::loop_::AgentEvent| {
            observer.on_event(&event);
        }));
    }

    prepend(&mut options.pre_turn_policies, pre_turn);
    prepend(&mut options.pre_dispatch_policies, pre_dispatch);
    prepend(&mut options.post_turn_policies, post_turn);
    prepend(&mut options.post_loop_policies, post_loop);
    options.tools.extend(namespaced_tools);
    prepend(&mut options.event_forwarders, forwarders);

    options
}
512
513// ─── SharedRetryStrategy ─────────────────────────────────────────────────────
514
515/// Wrapper that delegates to an `Arc<dyn RetryStrategy>`, allowing
516/// the agent to share its retry strategy with each loop config.
517struct SharedRetryStrategy(Arc<dyn RetryStrategy>);
518
519impl RetryStrategy for SharedRetryStrategy {
520    fn should_retry(&self, error: &AgentError, attempt: u32) -> bool {
521        self.0.should_retry(error, attempt)
522    }
523
524    fn delay(&self, attempt: u32) -> std::time::Duration {
525        self.0.delay(attempt)
526    }
527
528    fn as_any(&self) -> &dyn std::any::Any {
529        self
530    }
531}
532
#[cfg(all(test, feature = "plugins"))]
mod tests {
    use std::sync::Arc;

    use crate::testing::{MockPlugin, MockTool, SimpleMockStreamFn};

    use super::*;

    /// Checks that `Agent::new` rejects a name collision created by plugin
    /// composition.
    ///
    /// The plugin `my-web` contributes a `search` tool. Once plugin tools are
    /// namespaced, that tool must collide with the directly-registered
    /// `my_web_search` tool.
    #[test]
    #[should_panic(expected = "duplicate tool names are not allowed after composition")]
    fn agent_new_rejects_duplicate_names_after_plugin_composition() {
        let direct_tool: Arc<dyn AgentTool> = Arc::new(MockTool::new("my_web_search"));
        let web_plugin = Arc::new(MockPlugin::new("my-web").with_tools(&["search"]));

        let options = AgentOptions::new(
            "test",
            crate::testing::default_model(),
            Arc::new(SimpleMockStreamFn::from_text("ok")),
            crate::testing::default_convert,
        )
        .with_tools(vec![direct_tool])
        .with_plugin(web_plugin);

        let _agent = Agent::new(options);
    }
}