1#[path = "agent/checkpointing.rs"]
12mod checkpointing;
13#[path = "agent/control.rs"]
14mod control;
15#[path = "agent/events.rs"]
16mod events;
17#[path = "agent/invoke.rs"]
18mod invoke;
19#[path = "agent/mutation.rs"]
20mod mutation;
21#[path = "agent/queueing.rs"]
22mod queueing;
23#[path = "agent/state_updates.rs"]
24mod state_updates;
25#[path = "agent/structured_output.rs"]
26mod structured_output;
27
28use std::collections::{HashSet, VecDeque};
29use std::sync::atomic::{AtomicBool, AtomicU64};
30use std::sync::{Arc, Mutex};
31
32use tokio::sync::Notify;
33use tokio_util::sync::CancellationToken;
34
35use crate::agent_id::AgentId;
36use crate::agent_options::{
37 ApproveToolArc, AsyncTransformContextArc, CheckpointStoreArc, ConvertToLlmFn, GetApiKeyArc,
38 TransformContextArc,
39};
40use crate::agent_subscriptions::ListenerRegistry;
41use crate::error::AgentError;
42use crate::message_provider::MessageProvider;
43use crate::retry::RetryStrategy;
44use crate::stream::{StreamFn, StreamOptions};
45use crate::tool::{AgentTool, ApprovalMode};
46use crate::types::{AgentMessage, LlmMessage, ModelSpec};
47
48pub use crate::agent_options::{AgentOptions, DEFAULT_PLAN_MODE_ADDENDUM};
50pub use crate::agent_subscriptions::SubscriptionId;
51
/// How messages waiting in the agent's steering queue are released.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names; the draining logic is not in this file — confirm against the queueing code.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SteeringMode {
    /// Presumably releases every queued steering message at once.
    All,
    /// Presumably releases one queued steering message at a time.
    ///
    /// This is the `Default` variant.
    #[default]
    OneAtATime,
}
64
/// How messages waiting in the agent's follow-up queue are released.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names; the draining logic is not in this file — confirm against the queueing code.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FollowUpMode {
    /// Presumably releases every queued follow-up message at once.
    All,
    /// Presumably releases one queued follow-up message at a time.
    ///
    /// This is the `Default` variant.
    #[default]
    OneAtATime,
}
75
/// Observable state of an agent, borrowed through `Agent::state`.
///
/// `Agent::new` starts with no messages, no stream message, no pending tool
/// calls, no error, and `is_running == false`.
pub struct AgentState {
    /// System prompt, initialized from `AgentOptions::effective_system_prompt`.
    pub system_prompt: String,
    /// Model in use; initialized from `AgentOptions::model`.
    pub model: ModelSpec,
    /// Tools exposed to the model. `Agent::new` panics if two share a name.
    pub tools: Vec<Arc<dyn AgentTool>>,
    /// Messages held by the agent; empty after construction.
    pub messages: Vec<AgentMessage>,
    /// Running flag stored in the state snapshot.
    ///
    /// NOTE(review): `Agent::is_running()` reads a separate atomic flag, not this
    /// field — confirm the two are kept in sync elsewhere.
    pub is_running: bool,
    /// Message currently being streamed, if any (presumably partial — confirm).
    pub stream_message: Option<AgentMessage>,
    /// Identifiers of tool calls not yet completed (inferred from the name — confirm).
    pub pending_tool_calls: HashSet<String>,
    /// Last error message, if any.
    pub error: Option<String>,
    /// Selectable models; the primary model from `AgentOptions::model` is first.
    pub available_models: Vec<ModelSpec>,
}
99
100impl std::fmt::Debug for AgentState {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("AgentState")
103 .field("system_prompt", &self.system_prompt)
104 .field("model", &self.model)
105 .field("tools", &format_args!("[{} tool(s)]", self.tools.len()))
106 .field("messages", &self.messages)
107 .field("is_running", &self.is_running)
108 .field("stream_message", &self.stream_message)
109 .field("pending_tool_calls", &self.pending_tool_calls)
110 .field("error", &self.error)
111 .field("available_models", &self.available_models)
112 .finish()
113 }
114}
115
116pub fn default_convert(msg: &AgentMessage) -> Option<LlmMessage> {
123 match msg {
124 AgentMessage::Llm(llm) => Some(llm.clone()),
125 AgentMessage::Custom(_) => None,
126 }
127}
128
129type ModelStreamRegistry = Vec<(ModelSpec, Arc<dyn StreamFn>)>;
130
131fn available_models_and_stream_fns(
132 options: &AgentOptions,
133) -> (Vec<ModelSpec>, ModelStreamRegistry) {
134 let primary_model = options.model.clone();
135 let primary_stream_fn = Arc::clone(&options.stream_fn);
136 let mut available_models = vec![options.model.clone()];
137 available_models.extend(
138 options
139 .available_models
140 .iter()
141 .map(|(model, _): &(ModelSpec, _)| model.clone()),
142 );
143 let mut model_stream_fns = vec![(primary_model, primary_stream_fn)];
144 model_stream_fns.extend(
145 options
146 .available_models
147 .iter()
148 .map(|(model, stream_fn): &(ModelSpec, _)| (model.clone(), Arc::clone(stream_fn))),
149 );
150
151 (available_models, model_stream_fns)
152}
153
154fn assert_unique_tool_names(tools: &[Arc<dyn AgentTool>]) {
155 let mut seen = HashSet::with_capacity(tools.len());
156 let mut duplicates = Vec::new();
157
158 for tool in tools {
159 let name = tool.name();
160 if !seen.insert(name.to_owned()) {
161 duplicates.push(name.to_owned());
162 }
163 }
164
165 if !duplicates.is_empty() {
166 duplicates.sort();
167 duplicates.dedup();
168 panic!(
169 "duplicate tool names are not allowed after composition: {}",
170 duplicates.join(", ")
171 );
172 }
173}
174
/// Invokes `on_init` on every plugin in `agent.plugins`, in order.
///
/// A panic raised by one plugin's `on_init` is caught and logged as a warning,
/// so the remaining plugins still run and `Agent::new` still returns.
#[cfg(feature = "plugins")]
fn dispatch_plugin_on_init(agent: &Agent) {
    for plugin in &agent.plugins {
        // Read the name before the call so logging never depends on a plugin
        // that has just panicked.
        let plugin_name = plugin.name().to_owned();
        let plugin = Arc::clone(plugin);

        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            plugin.on_init(agent);
        }));
        let payload = match outcome {
            Ok(()) => continue,
            Err(payload) => payload,
        };

        // Panic payloads are normally `&str` (literal message) or `String`
        // (formatted message); anything else gets a generic description.
        let reason = match payload.downcast_ref::<&str>() {
            Some(text) => (*text).to_owned(),
            None => match payload.downcast_ref::<String>() {
                Some(text) => text.clone(),
                None => "unknown panic".to_owned(),
            },
        };
        tracing::warn!(plugin = %plugin_name, error = %reason, "plugin on_init panicked");
    }
}
196
/// No-op used when the `plugins` feature is disabled, so `Agent::new` can call
/// `dispatch_plugin_on_init` unconditionally.
#[cfg(not(feature = "plugins"))]
const fn dispatch_plugin_on_init(_agent: &Agent) {}
199pub struct Agent {
207 id: AgentId,
209
210 state: AgentState,
212
213 steering_queue: Arc<Mutex<VecDeque<AgentMessage>>>,
215 follow_up_queue: Arc<Mutex<VecDeque<AgentMessage>>>,
216 listeners: ListenerRegistry,
217 abort_controller: Option<CancellationToken>,
218 steering_mode: SteeringMode,
219 follow_up_mode: FollowUpMode,
220 stream_fn: Arc<dyn StreamFn>,
221 convert_to_llm: ConvertToLlmFn,
222 transform_context: Option<TransformContextArc>,
223 get_api_key: Option<GetApiKeyArc>,
224 retry_strategy: Arc<dyn RetryStrategy>,
225 stream_options: StreamOptions,
226 structured_output_max_retries: usize,
227 idle_notify: Arc<Notify>,
228 in_flight_llm_messages: Option<Vec<AgentMessage>>,
229 in_flight_messages: Option<Vec<AgentMessage>>,
230 pending_message_snapshot: Arc<crate::pause_state::PendingMessageSnapshot>,
231 loop_context_snapshot: Arc<crate::pause_state::LoopContextSnapshot>,
232 approve_tool: Option<ApproveToolArc>,
233 approval_mode: ApprovalMode,
234 pre_turn_policies: Vec<Arc<dyn crate::policy::PreTurnPolicy>>,
235 pre_dispatch_policies: Vec<Arc<dyn crate::policy::PreDispatchPolicy>>,
236 post_turn_policies: Vec<Arc<dyn crate::policy::PostTurnPolicy>>,
237 post_loop_policies: Vec<Arc<dyn crate::policy::PostLoopPolicy>>,
238 model_stream_fns: Vec<(ModelSpec, Arc<dyn StreamFn>)>,
240 event_forwarders: Vec<crate::event_forwarder::EventForwarderFn>,
242 async_transform_context: Option<AsyncTransformContextArc>,
244 checkpoint_store: Option<CheckpointStoreArc>,
246 pub(crate) custom_message_registry: Option<Arc<crate::types::CustomMessageRegistry>>,
248 metrics_collector: Option<Arc<dyn crate::metrics::MetricsCollector>>,
250 fallback: Option<crate::fallback::ModelFallback>,
252 external_message_provider: Option<Arc<dyn MessageProvider>>,
254 tool_execution_policy: crate::tool_execution_policy::ToolExecutionPolicy,
256 plan_mode_addendum: Option<String>,
258 session_state: Arc<std::sync::RwLock<crate::SessionState>>,
260 credential_resolver: Option<Arc<dyn crate::credential::CredentialResolver>>,
262 cache_config: Option<crate::context_cache::CacheConfig>,
264 dynamic_system_prompt: Option<Arc<dyn Fn() -> String + Send + Sync>>,
266 loop_active: Arc<AtomicBool>,
272 loop_generation: Arc<AtomicU64>,
275 #[cfg(feature = "plugins")]
277 plugins: Vec<Arc<dyn crate::plugin::Plugin>>,
278 #[allow(clippy::struct_field_names)]
280 agent_name: Option<String>,
281 transfer_chain: Option<crate::transfer::TransferChain>,
283}
284
285impl Agent {
286 #[must_use]
288 pub fn new(options: AgentOptions) -> Self {
289 #[cfg(feature = "plugins")]
291 let options = merge_plugin_contributions(options);
292
293 assert_unique_tool_names(&options.tools);
294
295 let effective_prompt = options.effective_system_prompt().to_owned();
297 let (available_models, model_stream_fns) = available_models_and_stream_fns(&options);
298
299 let transform_context = match (options.token_counter, options.transform_context) {
302 (Some(counter), None) => Some(Arc::new(
303 crate::context_transformer::SlidingWindowTransformer::new(100_000, 50_000, 2)
304 .with_token_counter(counter),
305 ) as TransformContextArc),
306 (_, tc) => tc,
307 };
308
309 let agent = Self {
310 id: AgentId::next(),
311 state: AgentState {
312 system_prompt: effective_prompt,
313 model: options.model,
314 tools: options.tools,
315 messages: Vec::new(),
316 is_running: false,
317 stream_message: None,
318 pending_tool_calls: HashSet::new(),
319 error: None,
320 available_models,
321 },
322 steering_queue: Arc::new(Mutex::new(VecDeque::new())),
323 follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
324 listeners: ListenerRegistry::new(),
325 abort_controller: None,
326 steering_mode: options.steering_mode,
327 follow_up_mode: options.follow_up_mode,
328 stream_fn: options.stream_fn,
329 convert_to_llm: options.convert_to_llm,
330 transform_context,
331 get_api_key: options.get_api_key,
332 retry_strategy: Arc::from(options.retry_strategy),
333 stream_options: options.stream_options,
334 structured_output_max_retries: options.structured_output_max_retries,
335 idle_notify: Arc::new(Notify::new()),
336 in_flight_llm_messages: None,
337 in_flight_messages: None,
338 pending_message_snapshot: Arc::new(
339 crate::pause_state::PendingMessageSnapshot::default(),
340 ),
341 loop_context_snapshot: Arc::new(crate::pause_state::LoopContextSnapshot::default()),
342 approve_tool: options.approve_tool,
343 approval_mode: options.approval_mode,
344 pre_turn_policies: options.pre_turn_policies,
345 pre_dispatch_policies: options.pre_dispatch_policies,
346 post_turn_policies: options.post_turn_policies,
347 post_loop_policies: options.post_loop_policies,
348 model_stream_fns,
349 event_forwarders: options.event_forwarders,
350 async_transform_context: options.async_transform_context,
351 checkpoint_store: options.checkpoint_store,
352 custom_message_registry: options.custom_message_registry,
353 metrics_collector: options.metrics_collector,
354 fallback: options.fallback,
355 external_message_provider: options.external_message_provider,
356 tool_execution_policy: options.tool_execution_policy,
357 plan_mode_addendum: options.plan_mode_addendum,
358 session_state: Arc::new(std::sync::RwLock::new(
359 options.session_state.unwrap_or_default(),
360 )),
361 credential_resolver: options.credential_resolver,
362 cache_config: options.cache_config,
363 dynamic_system_prompt: options.dynamic_system_prompt.map(Arc::from),
364 loop_active: Arc::new(AtomicBool::new(false)),
365 loop_generation: Arc::new(AtomicU64::new(0)),
366 #[cfg(feature = "plugins")]
367 plugins: options.plugins,
368 agent_name: options.agent_name,
369 transfer_chain: options.transfer_chain,
370 };
371
372 dispatch_plugin_on_init(&agent);
373
374 agent
375 }
376
377 #[must_use]
379 pub const fn id(&self) -> AgentId {
380 self.id
381 }
382
383 #[must_use]
389 pub const fn state(&self) -> &AgentState {
390 &self.state
391 }
392
393 #[must_use]
399 pub fn is_running(&self) -> bool {
400 self.loop_active.load(std::sync::atomic::Ordering::Acquire)
401 }
402
403 #[must_use]
405 pub const fn session_state(&self) -> &Arc<std::sync::RwLock<crate::SessionState>> {
406 &self.session_state
407 }
408
409 #[must_use]
414 pub fn custom_message_registry(&self) -> Option<&crate::types::CustomMessageRegistry> {
415 self.custom_message_registry.as_deref()
416 }
417
418 #[cfg(feature = "plugins")]
420 #[must_use]
421 pub fn plugins(&self) -> &[Arc<dyn crate::plugin::Plugin>] {
422 &self.plugins
423 }
424
425 #[cfg(feature = "plugins")]
427 #[must_use]
428 pub fn plugin(&self, name: &str) -> Option<&Arc<dyn crate::plugin::Plugin>> {
429 self.plugins.iter().find(|p| p.name() == name)
430 }
431}
432
433impl std::fmt::Debug for Agent {
434 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435 f.debug_struct("Agent")
436 .field("state", &self.state)
437 .field("steering_mode", &self.steering_mode)
438 .field("follow_up_mode", &self.follow_up_mode)
439 .field(
440 "listeners",
441 &format_args!("{} listener(s)", self.listeners.len()),
442 )
443 .field("is_abort_active", &self.abort_controller.is_some())
444 .finish_non_exhaustive()
445 }
446}
447
/// Folds every plugin's contributions into `options` and returns it.
///
/// Plugins are first sorted by descending `priority()`. The sort is stable, so
/// equal priorities keep registration order. Then, plugin by plugin:
/// - pre-turn, pre-dispatch, post-turn and post-loop policies are collected and
///   placed *before* the caller-supplied policies of the same kind;
/// - each plugin tool is wrapped in `NamespacedTool` keyed by the plugin's name
///   and appended *after* the caller-supplied tools;
/// - a forwarder calling the plugin's `on_event` is placed *before* the
///   caller-supplied event forwarders.
#[cfg(feature = "plugins")]
fn merge_plugin_contributions(mut options: AgentOptions) -> AgentOptions {
    // Moves all of `front` ahead of the existing contents of `slot`, keeping order.
    fn prepend<T>(mut front: Vec<T>, slot: &mut Vec<T>) {
        front.append(slot);
        *slot = front;
    }

    options
        .plugins
        .sort_by_key(|plugin| std::cmp::Reverse(plugin.priority()));

    let mut pre_turn: Vec<Arc<dyn crate::policy::PreTurnPolicy>> = Vec::new();
    let mut pre_dispatch: Vec<Arc<dyn crate::policy::PreDispatchPolicy>> = Vec::new();
    let mut post_turn: Vec<Arc<dyn crate::policy::PostTurnPolicy>> = Vec::new();
    let mut post_loop: Vec<Arc<dyn crate::policy::PostLoopPolicy>> = Vec::new();
    let mut namespaced_tools: Vec<Arc<dyn AgentTool>> = Vec::new();
    let mut forwarders: Vec<crate::event_forwarder::EventForwarderFn> = Vec::new();

    // One pass per plugin, calling its hooks in a fixed order.
    for plugin in &options.plugins {
        pre_turn.extend(plugin.pre_turn_policies());
        pre_dispatch.extend(plugin.pre_dispatch_policies());
        post_turn.extend(plugin.post_turn_policies());
        post_loop.extend(plugin.post_loop_policies());

        let namespace = plugin.name().to_owned();
        for tool in plugin.tools() {
            namespaced_tools.push(Arc::new(crate::plugin::NamespacedTool::new(
                &namespace, tool,
            )));
        }

        let observer = Arc::clone(plugin);
        forwarders.push(Arc::new(move |event: crate::loop_::AgentEvent| {
            observer.on_event(&event);
        }));
    }

    prepend(pre_turn, &mut options.pre_turn_policies);
    prepend(pre_dispatch, &mut options.pre_dispatch_policies);
    prepend(post_turn, &mut options.post_turn_policies);
    prepend(post_loop, &mut options.post_loop_policies);
    options.tools.extend(namespaced_tools);
    prepend(forwarders, &mut options.event_forwarders);

    options
}
512
513struct SharedRetryStrategy(Arc<dyn RetryStrategy>);
518
519impl RetryStrategy for SharedRetryStrategy {
520 fn should_retry(&self, error: &AgentError, attempt: u32) -> bool {
521 self.0.should_retry(error, attempt)
522 }
523
524 fn delay(&self, attempt: u32) -> std::time::Duration {
525 self.0.delay(attempt)
526 }
527
528 fn as_any(&self) -> &dyn std::any::Any {
529 self
530 }
531}
532
#[cfg(all(test, feature = "plugins"))]
mod tests {
    use std::sync::Arc;

    use crate::testing::{MockPlugin, MockTool, SimpleMockStreamFn};

    use super::*;

    // A caller tool named `my_web_search` plus plugin `my-web` contributing a
    // `search` tool: the test expects the namespaced plugin tool to collide with
    // the caller's tool, so `Agent::new` must panic on the duplicate.
    #[test]
    #[should_panic(expected = "duplicate tool names are not allowed after composition")]
    fn agent_new_rejects_duplicate_names_after_plugin_composition() {
        let caller_tool: Arc<dyn AgentTool> = Arc::new(MockTool::new("my_web_search"));
        let web_plugin = Arc::new(MockPlugin::new("my-web").with_tools(&["search"]));
        let stream = Arc::new(SimpleMockStreamFn::from_text("ok"));

        let options = AgentOptions::new(
            "test",
            crate::testing::default_model(),
            stream,
            crate::testing::default_convert,
        )
        .with_tools(vec![caller_tool])
        .with_plugin(web_plugin);

        let _agent = Agent::new(options);
    }
}