1use crate::agent::{
4 AfterToolCallContext, AfterToolCallFn, AgentConfig, AgentEvent, AgentHooks, AgentMessage,
5 AgentState, AgentStateSnapshot, AgentTool, AgentToolResult, BeforeToolCallContext,
6 BeforeToolCallFn, BeforeToolCallResult, QueueMode, ThinkingBudgets, ToolExecutionMode,
7 ToolExecutor, ToolUpdateCallback, Transport,
8};
9use crate::provider::{get_provider, ArcProtocol};
10use crate::stream::AssistantMessageEventStream;
11use crate::thinking::ThinkingLevel;
12use crate::types::*;
13use futures::StreamExt;
14use parking_lot::{Mutex, RwLock};
15use std::collections::{HashMap, VecDeque};
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18
/// Default upper bound on the number of turns a single `prompt`/`continue_` run may take
/// (overridable with `Agent::set_max_turns`).
const DEFAULT_MAX_TURNS: usize = 25;

/// Identifier assigned to a registered event subscriber.
pub type SubscriberId = u64;

/// Callback invoked for every emitted [`AgentEvent`]; shared via `Arc` so emitters can
/// snapshot the registry cheaply.
type SubscriberCallback = Arc<dyn Fn(&AgentEvent) + Send + Sync>;
27
28struct Subscribers {
30 callbacks: RwLock<HashMap<u64, SubscriberCallback>>,
31 next_id: AtomicU64,
32}
33
34impl Subscribers {
35 fn new() -> Self {
36 Self {
37 callbacks: RwLock::new(HashMap::new()),
38 next_id: AtomicU64::new(0),
39 }
40 }
41
42 fn subscribe(&self, callback: SubscriberCallback) -> SubscriberId {
43 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
44 self.callbacks.write().insert(id, callback);
45 id
46 }
47
48 fn unsubscribe(&self, id: SubscriberId) {
49 self.callbacks.write().remove(&id);
50 }
51
52 fn emit(&self, event: &AgentEvent) {
56 let snapshot: Vec<SubscriberCallback> =
57 { self.callbacks.read().values().cloned().collect() };
58 for cb in &snapshot {
59 cb(event);
60 }
61 }
62}
63
/// LLM agent that runs turns against a provider, executes tool calls and broadcasts
/// [`AgentEvent`]s to subscribers.
pub struct Agent {
    /// Shared conversation state (system prompt, messages, tools, streaming flag, ...).
    state: Arc<AgentState>,
    /// Model, thinking, security, queue-mode, transport and retry settings.
    config: RwLock<AgentConfig>,
    /// Explicit provider override; when `None` the provider registered for the model is used.
    provider: RwLock<Option<ArcProtocol>>,
    /// User-supplied async hooks (tool executor, before/after tool call, conversions, ...).
    hooks: RwLock<AgentHooks>,
    /// Maximum number of turns per run.
    max_turns: RwLock<usize>,
    /// Messages that interrupt the current turn when dequeued.
    steering_queue: Mutex<VecDeque<AgentMessage>>,
    /// Messages appended after a turn completes.
    follow_up_queue: Mutex<VecDeque<AgentMessage>>,
    /// Event subscribers; in an `Arc` so tool-update callbacks can emit without `&self`.
    subscribers: Arc<Subscribers>,
    /// Set by `abort()`; polled by the run loop, the stream loop and running tools.
    abort_flag: Arc<AtomicBool>,
    /// Static API key, used when no resolver is set or the resolver yields nothing.
    api_key: RwLock<Option<String>>,
    /// Optional session id forwarded in the stream options.
    session_id: RwLock<Option<String>>,
}
89
90impl Agent {
91 pub fn new() -> Self {
93 Self {
94 state: Arc::new(AgentState::new()),
95 config: RwLock::new(AgentConfig::new(
96 Model::builder()
97 .id("gpt-4o-mini")
98 .name("GPT-4o Mini")
99 .provider(Provider::OpenAI)
100 .base_url("https://api.openai.com/v1")
101 .context_window(128000)
102 .max_tokens(16384)
103 .build()
104 .unwrap(),
105 )),
106 provider: RwLock::new(None),
107 hooks: RwLock::new(AgentHooks::default()),
108 max_turns: RwLock::new(DEFAULT_MAX_TURNS),
109 steering_queue: Mutex::new(VecDeque::new()),
110 follow_up_queue: Mutex::new(VecDeque::new()),
111 subscribers: Arc::new(Subscribers::new()),
112 abort_flag: Arc::new(AtomicBool::new(false)),
113 api_key: RwLock::new(None),
114 session_id: RwLock::new(None),
115 }
116 }
117
118 pub fn with_model(model: Model) -> Self {
120 let agent = Self::new();
121 agent.set_model(model.clone());
122 *agent.config.write() = AgentConfig::new(model);
123 agent
124 }
125
126 pub fn set_provider(&self, provider: ArcProtocol) {
132 *self.provider.write() = Some(provider);
133 }
134
135 pub fn set_api_key(&self, key: impl Into<String>) {
137 *self.api_key.write() = Some(key.into());
138 }
139
140 pub fn set_get_api_key<F, Fut>(&self, resolver: F)
145 where
146 F: Fn(&str) -> Fut + Send + Sync + 'static,
147 Fut: std::future::Future<Output = Option<String>> + Send + 'static,
148 {
149 let resolver = Arc::new(move |provider: &str| {
150 let fut = resolver(provider);
151 Box::pin(fut)
152 as std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send>>
153 });
154 self.hooks.write().get_api_key = Some(resolver);
155 }
156
157 pub fn set_tool_executor<F, Fut>(&self, executor: F)
167 where
168 F: Fn(&str, &str, &serde_json::Value, Option<ToolUpdateCallback>) -> Fut
169 + Send
170 + Sync
171 + 'static,
172 Fut: std::future::Future<Output = AgentToolResult> + Send + 'static,
173 {
174 let executor = Arc::new(
175 move |name: &str,
176 id: &str,
177 args: &serde_json::Value,
178 update_cb: Option<ToolUpdateCallback>| {
179 let fut = executor(name, id, args, update_cb);
180 Box::pin(fut)
181 as std::pin::Pin<Box<dyn std::future::Future<Output = AgentToolResult> + Send>>
182 },
183 );
184 self.hooks.write().tool_executor = Some(executor);
185 }
186
187 pub fn set_tool_executor_simple<F, Fut>(&self, executor: F)
191 where
192 F: Fn(&str, &str, &serde_json::Value) -> Fut + Send + Sync + 'static,
193 Fut: std::future::Future<Output = AgentToolResult> + Send + 'static,
194 {
195 let executor = Arc::new(
196 move |name: &str,
197 id: &str,
198 args: &serde_json::Value,
199 _update_cb: Option<ToolUpdateCallback>| {
200 let fut = executor(name, id, args);
201 Box::pin(fut)
202 as std::pin::Pin<Box<dyn std::future::Future<Output = AgentToolResult> + Send>>
203 },
204 );
205 self.hooks.write().tool_executor = Some(executor);
206 }
207
208 pub fn set_before_tool_call<F, Fut>(&self, hook: F)
213 where
214 F: Fn(BeforeToolCallContext) -> Fut + Send + Sync + 'static,
215 Fut: std::future::Future<Output = Option<BeforeToolCallResult>> + Send + 'static,
216 {
217 let hook = Arc::new(move |ctx: BeforeToolCallContext| {
218 let fut = hook(ctx);
219 Box::pin(fut)
220 as std::pin::Pin<
221 Box<dyn std::future::Future<Output = Option<BeforeToolCallResult>> + Send>,
222 >
223 });
224 self.hooks.write().before_tool_call = Some(hook);
225 }
226
227 pub fn set_after_tool_call<F, Fut>(&self, hook: F)
232 where
233 F: Fn(AfterToolCallContext) -> Fut + Send + Sync + 'static,
234 Fut: std::future::Future<Output = Option<crate::agent::AfterToolCallResult>>
235 + Send
236 + 'static,
237 {
238 let hook = Arc::new(move |ctx: AfterToolCallContext| {
239 let fut = hook(ctx);
240 Box::pin(fut)
241 as std::pin::Pin<
242 Box<
243 dyn std::future::Future<Output = Option<crate::agent::AfterToolCallResult>>
244 + Send,
245 >,
246 >
247 });
248 self.hooks.write().after_tool_call = Some(hook);
249 }
250
251 pub fn set_convert_to_llm<F, Fut>(&self, converter: F)
260 where
261 F: Fn(Vec<AgentMessage>) -> Fut + Send + Sync + 'static,
262 Fut: std::future::Future<Output = Vec<Message>> + Send + 'static,
263 {
264 let converter = Arc::new(move |msgs: Vec<AgentMessage>| {
265 let fut = converter(msgs);
266 Box::pin(fut)
267 as std::pin::Pin<Box<dyn std::future::Future<Output = Vec<Message>> + Send>>
268 });
269 self.hooks.write().convert_to_llm = Some(converter);
270 }
271
272 pub fn set_transform_context<F, Fut>(&self, transform: F)
277 where
278 F: Fn(Vec<AgentMessage>) -> Fut + Send + Sync + 'static,
279 Fut: std::future::Future<Output = Vec<AgentMessage>> + Send + 'static,
280 {
281 let transform = Arc::new(move |msgs: Vec<AgentMessage>| {
282 let fut = transform(msgs);
283 Box::pin(fut)
284 as std::pin::Pin<Box<dyn std::future::Future<Output = Vec<AgentMessage>> + Send>>
285 });
286 self.hooks.write().transform_context = Some(transform);
287 }
288
289 pub fn set_on_payload<F, Fut>(&self, hook: F)
297 where
298 F: Fn(serde_json::Value, Model) -> Fut + Send + Sync + 'static,
299 Fut: std::future::Future<Output = Option<serde_json::Value>> + Send + 'static,
300 {
301 let hook = Arc::new(move |payload: serde_json::Value, model: Model| {
302 let fut = hook(payload, model);
303 Box::pin(fut)
304 as std::pin::Pin<
305 Box<dyn std::future::Future<Output = Option<serde_json::Value>> + Send>,
306 >
307 });
308 self.hooks.write().on_payload = Some(hook);
309 }
310
311 pub fn set_stream_fn<F, Fut>(&self, stream_fn: F)
315 where
316 F: Fn(&Model, &Context, StreamOptions) -> Fut + Send + Sync + 'static,
317 Fut: std::future::Future<Output = AssistantMessageEventStream> + Send + 'static,
318 {
319 let stream_fn = Arc::new(
320 move |model: &Model, context: &Context, options: StreamOptions| {
321 let fut = stream_fn(model, context, options);
322 Box::pin(fut)
323 as std::pin::Pin<
324 Box<dyn std::future::Future<Output = AssistantMessageEventStream> + Send>,
325 >
326 },
327 );
328 self.hooks.write().stream_fn = Some(stream_fn);
329 }
330
331 pub fn set_max_turns(&self, max: usize) {
337 *self.max_turns.write() = max;
338 }
339
340 pub fn set_security_config(&self, config: crate::types::SecurityConfig) {
342 self.config.write().security = config;
343 }
344
345 pub fn security_config(&self) -> crate::types::SecurityConfig {
347 self.config.read().security.clone()
348 }
349
350 pub fn set_tool_execution(&self, mode: ToolExecutionMode) {
352 self.config.write().tool_execution = mode;
353 }
354
355 pub fn set_steering_mode(&self, mode: QueueMode) {
357 self.config.write().steering_mode = mode;
358 }
359
360 pub fn steering_mode(&self) -> QueueMode {
362 self.config.read().steering_mode
363 }
364
365 pub fn set_follow_up_mode(&self, mode: QueueMode) {
367 self.config.write().follow_up_mode = mode;
368 }
369
370 pub fn follow_up_mode(&self) -> QueueMode {
372 self.config.read().follow_up_mode
373 }
374
375 pub fn set_thinking_budgets(&self, budgets: ThinkingBudgets) {
377 self.config.write().thinking_budgets = Some(budgets);
378 }
379
380 pub fn thinking_budgets(&self) -> Option<ThinkingBudgets> {
382 self.config.read().thinking_budgets.clone()
383 }
384
385 pub fn set_transport(&self, transport: Transport) {
387 self.config.write().transport = transport;
388 }
389
390 pub fn transport(&self) -> Transport {
392 self.config.read().transport
393 }
394
395 pub fn set_max_retry_delay_ms(&self, ms: Option<u64>) {
401 self.config.write().max_retry_delay_ms = ms;
402 }
403
404 pub fn max_retry_delay_ms(&self) -> Option<u64> {
406 self.config.read().max_retry_delay_ms
407 }
408
409 pub fn set_session_id(&self, id: impl Into<String>) {
411 *self.session_id.write() = Some(id.into());
412 }
413
414 pub fn session_id(&self) -> Option<String> {
416 self.session_id.read().clone()
417 }
418
419 pub fn clear_session_id(&self) {
421 *self.session_id.write() = None;
422 }
423
424 pub fn subscribe<F>(&self, callback: F) -> impl Fn()
430 where
431 F: Fn(&AgentEvent) + Send + Sync + 'static,
432 {
433 let id = self.subscribers.subscribe(Arc::new(callback));
434 let subs = Arc::clone(&self.subscribers);
435 move || {
436 subs.unsubscribe(id);
437 }
438 }
439
440 fn emit(&self, event: AgentEvent) {
442 self.subscribers.emit(&event);
443 }
444
445 pub fn set_system_prompt(&self, prompt: impl Into<String>) {
451 self.state.set_system_prompt(prompt);
452 }
453
454 pub fn set_model(&self, model: Model) {
456 self.config.write().model = model;
457 }
458
459 pub fn set_thinking_level(&self, level: ThinkingLevel) {
461 self.config.write().thinking_level = level;
462 }
463
464 pub fn set_tools(&self, tools: Vec<AgentTool>) {
466 self.state.set_tools(tools);
467 }
468
469 pub fn replace_messages(&self, messages: Vec<AgentMessage>) {
471 self.state.replace_messages(messages);
472 }
473
474 pub fn append_message(&self, message: AgentMessage) {
476 self.state.add_message(message);
477 }
478
479 pub fn clear_messages(&self) {
481 self.state.clear_messages();
482 }
483
484 pub fn reset(&self) {
486 self.state.reset();
487 self.steering_queue.lock().clear();
488 self.follow_up_queue.lock().clear();
489 *self.session_id.write() = None;
490 }
491
492 pub fn steer(&self, message: AgentMessage) {
498 self.steering_queue.lock().push_back(message);
499 }
500
501 pub fn follow_up(&self, message: AgentMessage) {
503 self.follow_up_queue.lock().push_back(message);
504 }
505
506 pub fn clear_steering_queue(&self) {
508 self.steering_queue.lock().clear();
509 }
510
511 pub fn clear_follow_up_queue(&self) {
513 self.follow_up_queue.lock().clear();
514 }
515
516 pub fn clear_all_queues(&self) {
518 self.clear_steering_queue();
519 self.clear_follow_up_queue();
520 }
521
522 pub fn has_queued_messages(&self) -> bool {
524 !self.steering_queue.lock().is_empty() || !self.follow_up_queue.lock().is_empty()
525 }
526
527 fn dequeue_steering_messages(&self) -> Vec<AgentMessage> {
529 let mode = self.config.read().steering_mode;
530 let mut queue = self.steering_queue.lock();
531 match mode {
532 QueueMode::All => queue.drain(..).collect(),
533 QueueMode::OneAtATime => {
534 if let Some(first) = queue.pop_front() {
535 vec![first]
536 } else {
537 Vec::new()
538 }
539 }
540 }
541 }
542
543 fn dequeue_follow_up_messages(&self) -> Vec<AgentMessage> {
545 let mode = self.config.read().follow_up_mode;
546 let mut queue = self.follow_up_queue.lock();
547 match mode {
548 QueueMode::All => queue.drain(..).collect(),
549 QueueMode::OneAtATime => {
550 if let Some(first) = queue.pop_front() {
551 vec![first]
552 } else {
553 Vec::new()
554 }
555 }
556 }
557 }
558
559 fn default_convert_to_llm(messages: Vec<AgentMessage>) -> Vec<Message> {
565 messages
566 .into_iter()
567 .filter_map(|m| {
568 let opt: Option<Message> = m.into();
569 opt
570 })
571 .collect()
572 }
573
574 async fn build_context(&self) -> Context {
577 let system_prompt = self.state.system_prompt.read().clone();
578 let messages = self.state.messages.read().clone();
579 let tools = self.state.tools.read().clone();
580
581 let transform = self.hooks.read().transform_context.clone();
583 let messages = if let Some(ref transform) = transform {
584 transform(messages).await
585 } else {
586 messages
587 };
588
589 let converter = self.hooks.read().convert_to_llm.clone();
591 let llm_messages = if let Some(ref converter) = converter {
592 converter(messages).await
593 } else {
594 Self::default_convert_to_llm(messages)
595 };
596
597 let mut context = if system_prompt.is_empty() {
599 Context::new()
600 } else {
601 Context::with_system_prompt(&system_prompt)
602 };
603
604 for msg in llm_messages {
605 context.add_message(msg);
606 }
607
608 if !tools.is_empty() {
610 let tool_defs: Vec<Tool> = tools.iter().map(|t| t.as_tool()).collect();
611 context.set_tools(tool_defs);
612 }
613
614 context
615 }
616
617 fn resolve_provider(&self) -> Result<ArcProtocol, AgentError> {
619 if let Some(ref provider) = *self.provider.read() {
621 return Ok(provider.clone());
622 }
623
624 let model = self.config.read().model.clone();
626 if let Some(provider) = get_provider(&model.provider) {
627 return Ok(provider);
628 }
629
630 Err(AgentError::ProviderError(format!(
631 "No provider registered for provider type: {}",
632 model.provider.as_str()
633 )))
634 }
635
636 async fn build_stream_options(&self) -> StreamOptions {
638 let security = self.config.read().security.clone();
639 let model = self.config.read().model.clone();
640 let on_payload = self.hooks.read().on_payload.clone();
641 let transport = self.config.read().transport;
642 let max_retry_delay_ms = self.config.read().max_retry_delay_ms;
643 let session_id = self.session_id.read().clone();
644
645 let get_api_key = self.hooks.read().get_api_key.clone();
647 let api_key = if let Some(ref resolver) = get_api_key {
648 let dynamic = resolver(model.provider.as_str()).await;
649 dynamic.or_else(|| self.api_key.read().clone())
650 } else {
651 self.api_key.read().clone()
652 };
653
654 StreamOptions {
655 api_key,
656 security: Some(security),
657 session_id,
658 on_payload,
659 transport: Some(transport),
660 max_retry_delay_ms,
661 ..Default::default()
662 }
663 }
664
665 async fn build_simple_stream_options(&self) -> SimpleStreamOptions {
667 let base = self.build_stream_options().await;
668 let thinking_level = self.config.read().thinking_level;
669
670 let (reasoning, thinking_budget_tokens) = if thinking_level != ThinkingLevel::Off {
671 let budget = self
672 .config
673 .read()
674 .thinking_budgets
675 .as_ref()
676 .and_then(|b| b.budget_for(thinking_level))
677 .or_else(|| {
678 Some(crate::thinking::ThinkingConfig::default_budget(
679 thinking_level,
680 ))
681 });
682 (Some(thinking_level), budget)
683 } else {
684 (None, None)
685 };
686
687 SimpleStreamOptions {
688 base,
689 reasoning,
690 thinking_budget_tokens,
691 }
692 }
693
694 async fn run_turn(&self, provider: &ArcProtocol) -> Result<AssistantMessage, AgentError> {
696 let context = self.build_context().await;
697 let model = self.config.read().model.clone();
698 let options = self.build_simple_stream_options().await;
699 let stream_timeout = self.config.read().security.stream.result_timeout();
700
701 let stream_fn = self.hooks.read().stream_fn.clone();
703 let mut stream: AssistantMessageEventStream = if let Some(ref custom_stream) = stream_fn {
704 custom_stream(&model, &context, options.base).await
705 } else {
706 provider.stream_simple(&model, &context, options)
707 };
708
709 while let Some(event) = stream.next().await {
711 if self.abort_flag.load(Ordering::SeqCst) {
713 return Err(AgentError::Other("Aborted".to_string()));
714 }
715
716 let steering = self.dequeue_steering_messages();
718 if !steering.is_empty() {
719 for steer_msg in steering {
721 self.state.add_message(steer_msg);
722 }
723 return Err(AgentError::Other("Steered".to_string()));
725 }
726
727 match &event {
729 AssistantMessageEvent::Start { partial } => {
730 *self.state.stream_message.write() =
731 Some(AgentMessage::Assistant(partial.clone()));
732 self.emit(AgentEvent::MessageUpdate {
733 message: AgentMessage::Assistant(partial.clone()),
734 assistant_event: Box::new(event.clone()),
735 });
736 }
737 AssistantMessageEvent::TextDelta { .. }
738 | AssistantMessageEvent::ThinkingDelta { .. }
739 | AssistantMessageEvent::ToolCallDelta { .. } => {
740 if let Some(partial) = event.partial_message() {
741 *self.state.stream_message.write() =
742 Some(AgentMessage::Assistant(partial.clone()));
743 self.emit(AgentEvent::MessageUpdate {
744 message: AgentMessage::Assistant(partial.clone()),
745 assistant_event: Box::new(event.clone()),
746 });
747 }
748 }
749 _ => {
750 if let Some(partial) = event.partial_message() {
751 self.emit(AgentEvent::MessageUpdate {
752 message: AgentMessage::Assistant(partial.clone()),
753 assistant_event: Box::new(event.clone()),
754 });
755 }
756 }
757 }
758 }
759
760 let result = match stream.try_result(stream_timeout).await {
762 Some(r) => r,
763 None => {
764 return Err(AgentError::Other(format!(
765 "Stream result timed out after {:?}",
766 stream_timeout
767 )));
768 }
769 };
770
771 *self.state.stream_message.write() = None;
773
774 if result.stop_reason == StopReason::Error {
775 let error_msg = result
776 .error_message
777 .clone()
778 .unwrap_or_else(|| "Unknown error".to_string());
779 return Err(AgentError::ProviderError(error_msg));
780 }
781
782 Ok(result)
783 }
784
785 async fn execute_tool_calls(
791 &self,
792 assistant_msg: &AssistantMessage,
793 context: &Context,
794 ) -> Vec<ToolResultMessage> {
795 let tool_calls = assistant_msg.tool_calls();
796 if tool_calls.is_empty() {
797 return Vec::new();
798 }
799
800 let executor = self.hooks.read().tool_executor.clone();
801 let execution_mode = self.config.read().tool_execution;
802 let security = self.config.read().security.clone();
803 let tool_timeout = security.agent.tool_execution_timeout();
804 let before_hook = self.hooks.read().before_tool_call.clone();
805 let after_hook = self.hooks.read().after_tool_call.clone();
806
807 let agent_tools = self.state.tools.read().clone();
809 let tool_defs: Vec<Tool> = agent_tools.iter().map(|t| t.as_tool()).collect();
810
811 let mut results = Vec::new();
812
813 match execution_mode {
814 ToolExecutionMode::Parallel => {
815 let max_parallel = security.agent.max_parallel_tool_calls;
816 let abort_flag = Arc::clone(&self.abort_flag);
817
818 let mut tool_futures = Vec::new();
819
820 for tc in &tool_calls {
821 let tc_id = tc.id.clone();
822 let tc_name = tc.name.clone();
823 let tc_args = tc.arguments.clone();
824 let tc_clone = (*tc).clone();
825
826 self.emit(AgentEvent::ToolExecutionStart {
827 tool_call_id: tc_id.clone(),
828 tool_name: tc_name.clone(),
829 args: tc_args.clone(),
830 });
831
832 self.state.pending_tool_calls.write().insert(tc_id.clone());
833
834 if let Some(result) = validate_tool_call_or_error(
836 &tc_id, &tc_name, &tc_args, &tool_defs, &security,
837 ) {
838 self.emit(AgentEvent::ToolExecutionEnd {
839 tool_call_id: tc_id.clone(),
840 tool_name: tc_name.clone(),
841 result: serde_json::json!({"error": result.text_content()}),
842 is_error: true,
843 });
844 self.state.pending_tool_calls.write().remove(&tc_id);
845 results.push(result);
846 continue;
847 }
848
849 if let Some(result) = run_before_hook(
851 &before_hook,
852 assistant_msg,
853 &tc_clone,
854 &tc_args,
855 context,
856 &tc_id,
857 &tc_name,
858 )
859 .await
860 {
861 self.emit(AgentEvent::ToolExecutionEnd {
862 tool_call_id: tc_id.clone(),
863 tool_name: tc_name.clone(),
864 result: serde_json::json!({"error": result.text_content()}),
865 is_error: true,
866 });
867 self.state.pending_tool_calls.write().remove(&tc_id);
868 results.push(result);
869 continue;
870 }
871
872 let executor = executor.clone();
873 let abort = abort_flag.clone();
874 let after_hook = after_hook.clone();
875 let assistant_msg_clone = assistant_msg.clone();
876 let context_clone = context.clone();
877 let subscribers = Arc::clone(&self.subscribers);
878
879 tool_futures.push(async move {
880 let (final_content, final_is_error) =
881 execute_and_apply_after_hook(ToolExecCtx {
882 executor: &executor,
883 after_hook: &after_hook,
884 subscribers: &subscribers,
885 tc_id: &tc_id,
886 tc_name: &tc_name,
887 tc_args: &tc_args,
888 tc: &tc_clone,
889 assistant_msg: &assistant_msg_clone,
890 context: &context_clone,
891 tool_timeout,
892 abort_flag: abort,
893 })
894 .await;
895
896 (tc_id, tc_name, final_content, final_is_error)
897 });
898 }
899
900 let mut buffered =
902 futures::stream::iter(tool_futures).buffer_unordered(max_parallel);
903
904 while let Some((tc_id, tc_name, content, is_error)) = buffered.next().await {
905 let result_json =
906 serde_json::to_value(&content).unwrap_or(serde_json::Value::Null);
907 self.emit(AgentEvent::ToolExecutionEnd {
908 tool_call_id: tc_id.clone(),
909 tool_name: tc_name.clone(),
910 result: result_json,
911 is_error,
912 });
913
914 self.state.pending_tool_calls.write().remove(&tc_id);
915
916 results.push(ToolResultMessage::new(tc_id, tc_name, content, is_error));
917 }
918 }
919 ToolExecutionMode::Sequential => {
920 for tc in &tool_calls {
921 if self.abort_flag.load(Ordering::SeqCst) {
922 break;
923 }
924
925 let tc_id = tc.id.clone();
926 let tc_name = tc.name.clone();
927 let tc_args = tc.arguments.clone();
928 let tc_clone = (*tc).clone();
929
930 self.emit(AgentEvent::ToolExecutionStart {
931 tool_call_id: tc_id.clone(),
932 tool_name: tc_name.clone(),
933 args: tc_args.clone(),
934 });
935
936 self.state.pending_tool_calls.write().insert(tc_id.clone());
937
938 if let Some(result) = validate_tool_call_or_error(
940 &tc_id, &tc_name, &tc_args, &tool_defs, &security,
941 ) {
942 self.emit(AgentEvent::ToolExecutionEnd {
943 tool_call_id: tc_id.clone(),
944 tool_name: tc_name.clone(),
945 result: serde_json::json!({"error": result.text_content()}),
946 is_error: true,
947 });
948 self.state.pending_tool_calls.write().remove(&tc_id);
949 results.push(result);
950 continue;
951 }
952
953 if let Some(result) = run_before_hook(
955 &before_hook,
956 assistant_msg,
957 &tc_clone,
958 &tc_args,
959 context,
960 &tc_id,
961 &tc_name,
962 )
963 .await
964 {
965 self.emit(AgentEvent::ToolExecutionEnd {
966 tool_call_id: tc_id.clone(),
967 tool_name: tc_name.clone(),
968 result: serde_json::json!({"error": result.text_content()}),
969 is_error: true,
970 });
971 self.state.pending_tool_calls.write().remove(&tc_id);
972 results.push(result);
973 continue;
974 }
975
976 let abort_flag = Arc::clone(&self.abort_flag);
977 let (final_content, final_is_error) =
978 execute_and_apply_after_hook(ToolExecCtx {
979 executor: &executor,
980 after_hook: &after_hook,
981 subscribers: &self.subscribers,
982 tc_id: &tc_id,
983 tc_name: &tc_name,
984 tc_args: &tc_args,
985 tc: &tc_clone,
986 assistant_msg,
987 context,
988 tool_timeout,
989 abort_flag,
990 })
991 .await;
992
993 let result_json =
994 serde_json::to_value(&final_content).unwrap_or(serde_json::Value::Null);
995 self.emit(AgentEvent::ToolExecutionEnd {
996 tool_call_id: tc_id.clone(),
997 tool_name: tc_name.clone(),
998 result: result_json,
999 is_error: final_is_error,
1000 });
1001
1002 self.state.pending_tool_calls.write().remove(&tc_id);
1003
1004 results.push(ToolResultMessage::new(
1005 tc_id,
1006 tc_name,
1007 final_content,
1008 final_is_error,
1009 ));
1010
1011 let steering = self.dequeue_steering_messages();
1013 if !steering.is_empty() {
1014 for steer_msg in steering {
1015 self.state.add_message(steer_msg);
1016 }
1017 break;
1019 }
1020 }
1021 }
1022 }
1023
1024 results
1025 }
1026
1027 async fn run_loop(&self) -> Result<Vec<AgentMessage>, AgentError> {
1029 let provider = if self.hooks.read().stream_fn.is_some() {
1030 None
1033 } else {
1034 Some(self.resolve_provider()?)
1035 };
1036
1037 let max_turns = *self.max_turns.read();
1038 let mut new_messages = Vec::new();
1039 let mut turn_count = 0;
1040
1041 let max_messages = self.config.read().security.agent.max_messages;
1043 self.state.set_max_messages(max_messages);
1044
1045 loop {
1046 if self.abort_flag.load(Ordering::SeqCst) {
1048 self.emit(AgentEvent::AgentEnd {
1049 messages: new_messages.clone(),
1050 });
1051 return Err(AgentError::Other("Aborted".to_string()));
1052 }
1053
1054 if turn_count >= max_turns {
1056 break;
1057 }
1058
1059 self.emit(AgentEvent::TurnStart);
1060
1061 let dummy_provider: ArcProtocol = Arc::new(DummyProvider);
1063 let active_provider = provider.as_ref().unwrap_or(&dummy_provider);
1064 let assistant_result = self.run_turn(active_provider).await;
1065
1066 match assistant_result {
1067 Ok(assistant_msg) => {
1068 let context = self.build_context().await;
1070
1071 let agent_msg = AgentMessage::Assistant(assistant_msg.clone());
1073 self.state.add_message(agent_msg.clone());
1074 new_messages.push(agent_msg.clone());
1075
1076 self.emit(AgentEvent::MessageStart {
1077 message: agent_msg.clone(),
1078 });
1079 self.emit(AgentEvent::MessageEnd {
1080 message: agent_msg.clone(),
1081 });
1082
1083 if assistant_msg.has_tool_calls()
1085 && assistant_msg.stop_reason == StopReason::ToolUse
1086 {
1087 let tool_results = self.execute_tool_calls(&assistant_msg, &context).await;
1088
1089 for result in &tool_results {
1090 let result_msg = AgentMessage::ToolResult(result.clone());
1091 self.state.add_message(result_msg.clone());
1092 new_messages.push(result_msg);
1093 }
1094
1095 self.emit(AgentEvent::TurnEnd {
1096 message: agent_msg,
1097 tool_results,
1098 });
1099
1100 let follow_ups = self.dequeue_follow_up_messages();
1102 for msg in follow_ups {
1103 self.state.add_message(msg.clone());
1104 new_messages.push(msg);
1105 }
1106
1107 turn_count += 1;
1108 continue;
1109 } else {
1110 self.emit(AgentEvent::TurnEnd {
1112 message: agent_msg,
1113 tool_results: Vec::new(),
1114 });
1115
1116 let follow_ups = self.dequeue_follow_up_messages();
1118 if !follow_ups.is_empty() {
1119 for msg in follow_ups {
1120 self.state.add_message(msg.clone());
1121 new_messages.push(msg);
1122 }
1123 turn_count += 1;
1124 continue;
1125 }
1126
1127 break;
1128 }
1129 }
1130 Err(AgentError::Other(ref msg)) if msg == "Steered" => {
1131 turn_count += 1;
1132 continue;
1133 }
1134 Err(e) => {
1135 *self.state.error.write() = Some(e.to_string());
1136 return Err(e);
1137 }
1138 }
1139 }
1140
1141 Ok(new_messages)
1142 }
1143
1144 pub async fn prompt(
1152 &self,
1153 message: impl Into<AgentMessage>,
1154 ) -> Result<Vec<AgentMessage>, AgentError> {
1155 if self
1157 .state
1158 .is_streaming
1159 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
1160 .is_err()
1161 {
1162 return Err(AgentError::AlreadyStreaming);
1163 }
1164
1165 let message = message.into();
1166 self.abort_flag.store(false, Ordering::SeqCst);
1167
1168 self.state.add_message(message.clone());
1170
1171 self.emit(AgentEvent::AgentStart);
1173
1174 let result = self.run_loop().await;
1176
1177 self.state.set_streaming(false);
1178
1179 match result {
1180 Ok(messages) => {
1181 self.emit(AgentEvent::AgentEnd {
1182 messages: messages.clone(),
1183 });
1184 Ok(messages)
1185 }
1186 Err(e) => {
1187 self.emit(AgentEvent::AgentEnd {
1188 messages: Vec::new(),
1189 });
1190 Err(e)
1191 }
1192 }
1193 }
1194
1195 pub async fn continue_(&self) -> Result<Vec<AgentMessage>, AgentError> {
1199 if self
1200 .state
1201 .is_streaming
1202 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
1203 .is_err()
1204 {
1205 return Err(AgentError::AlreadyStreaming);
1206 }
1207
1208 {
1209 let messages = self.state.messages.read();
1210 if messages.is_empty() {
1211 self.state.set_streaming(false);
1212 return Err(AgentError::NoMessages);
1213 }
1214 if let Some(AgentMessage::Assistant(_)) = messages.last() {
1215 self.state.set_streaming(false);
1216 return Err(AgentError::CannotContinueFromAssistant);
1217 }
1218 }
1219
1220 self.abort_flag.store(false, Ordering::SeqCst);
1221
1222 self.emit(AgentEvent::AgentStart);
1223
1224 let result = self.run_loop().await;
1225
1226 self.state.set_streaming(false);
1227
1228 match result {
1229 Ok(messages) => {
1230 self.emit(AgentEvent::AgentEnd {
1231 messages: messages.clone(),
1232 });
1233 Ok(messages)
1234 }
1235 Err(e) => {
1236 self.emit(AgentEvent::AgentEnd {
1237 messages: Vec::new(),
1238 });
1239 Err(e)
1240 }
1241 }
1242 }
1243
1244 pub fn abort(&self) {
1246 self.abort_flag.store(true, Ordering::SeqCst);
1247 self.state.set_streaming(false);
1248 self.clear_all_queues();
1249 }
1250
1251 pub async fn wait_for_idle(&self) {
1253 while self.state.is_streaming() {
1254 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1255 }
1256 }
1257
1258 pub fn state(&self) -> &Arc<AgentState> {
1260 &self.state
1261 }
1262
1263 pub fn snapshot(&self) -> AgentStateSnapshot {
1268 let config = self.config.read();
1269 let system_prompt = self.state.system_prompt.read().clone();
1270 let messages = self.state.messages.read().clone();
1271 let is_streaming = self.state.is_streaming();
1272 let stream_message = self.state.stream_message.read().clone();
1273 let pending_tool_calls = self.state.pending_tool_calls.read().clone();
1274 let error = self.state.error.read().clone();
1275 let max_messages = self.state.get_max_messages();
1276 let message_count = messages.len();
1277
1278 AgentStateSnapshot {
1279 system_prompt,
1280 model: config.model.clone(),
1281 thinking_level: config.thinking_level,
1282 messages,
1283 is_streaming,
1284 stream_message,
1285 pending_tool_calls,
1286 error,
1287 message_count,
1288 max_messages,
1289 }
1290 }
1291}
1292
1293async fn wait_for_abort(flag: Arc<AtomicBool>) {
1295 loop {
1296 if flag.load(Ordering::SeqCst) {
1297 return;
1298 }
1299 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1300 }
1301}
1302
1303fn validate_tool_call_or_error(
1313 tc_id: &str,
1314 tc_name: &str,
1315 tc_args: &serde_json::Value,
1316 tool_defs: &[Tool],
1317 security: &SecurityConfig,
1318) -> Option<ToolResultMessage> {
1319 if !security.agent.validate_tool_calls || tool_defs.is_empty() {
1320 return None;
1321 }
1322
1323 let tc = ToolCall::new(tc_id, tc_name, tc_args.clone());
1324 match crate::validation::validate_tool_call(tool_defs, &tc) {
1325 Ok(_) => None,
1326 Err(e) => Some(ToolResultMessage::error(tc_id, tc_name, e.to_string())),
1327 }
1328}
1329
1330async fn run_before_hook(
1335 before_hook: &Option<BeforeToolCallFn>,
1336 assistant_msg: &AssistantMessage,
1337 tc: &ToolCall,
1338 tc_args: &serde_json::Value,
1339 context: &Context,
1340 tc_id: &str,
1341 tc_name: &str,
1342) -> Option<ToolResultMessage> {
1343 let hook = before_hook.as_ref()?;
1344 let ctx = BeforeToolCallContext {
1345 assistant_message: assistant_msg.clone(),
1346 tool_call: tc.clone(),
1347 args: tc_args.clone(),
1348 context: context.clone(),
1349 };
1350 match hook(ctx).await {
1351 Some(result) if result.block => {
1352 let reason = result
1353 .reason
1354 .unwrap_or_else(|| "Tool call blocked by before_tool_call hook".to_string());
1355 Some(ToolResultMessage::error(tc_id, tc_name, reason))
1356 }
1357 _ => None,
1358 }
1359}
1360
1361struct ToolExecCtx<'a> {
1366 executor: &'a Option<ToolExecutor>,
1367 after_hook: &'a Option<AfterToolCallFn>,
1368 subscribers: &'a Arc<Subscribers>,
1369 tc_id: &'a str,
1370 tc_name: &'a str,
1371 tc_args: &'a serde_json::Value,
1372 tc: &'a ToolCall,
1373 assistant_msg: &'a AssistantMessage,
1374 context: &'a Context,
1375 tool_timeout: std::time::Duration,
1376 abort_flag: Arc<AtomicBool>,
1377}
1378
1379async fn execute_and_apply_after_hook(ctx: ToolExecCtx<'_>) -> (Vec<ContentBlock>, bool) {
1387 let ToolExecCtx {
1388 executor,
1389 after_hook,
1390 subscribers,
1391 tc_id,
1392 tc_name,
1393 tc_args,
1394 tc,
1395 assistant_msg,
1396 context,
1397 tool_timeout,
1398 abort_flag,
1399 } = ctx;
1400 let tool_result = if let Some(ref exec) = executor {
1402 let subs = Arc::clone(subscribers);
1404 let update_tc_id = tc_id.to_string();
1405 let update_tc_name = tc_name.to_string();
1406 let update_cb: ToolUpdateCallback = Arc::new(move |partial: serde_json::Value| {
1407 subs.emit(&AgentEvent::ToolExecutionUpdate {
1408 tool_call_id: update_tc_id.clone(),
1409 tool_name: update_tc_name.clone(),
1410 partial_result: partial,
1411 });
1412 });
1413
1414 let exec_future = exec(tc_name, tc_id, tc_args, Some(update_cb));
1415
1416 tokio::select! {
1418 result = exec_future => result,
1419 _ = tokio::time::sleep(tool_timeout) => {
1420 AgentToolResult::error(format!(
1421 "Tool '{}' timed out after {:?}",
1422 tc_name, tool_timeout
1423 ))
1424 }
1425 _ = wait_for_abort(abort_flag) => {
1426 AgentToolResult::error(format!("Tool '{}' aborted", tc_name))
1427 }
1428 }
1429 } else {
1430 AgentToolResult::error(format!(
1431 "No tool executor configured for tool '{}'",
1432 tc_name
1433 ))
1434 };
1435
1436 let mut is_error = tool_result.content.iter().any(|block| {
1438 if let Some(text) = block.as_text() {
1439 text.text.starts_with("Error:") || text.text.starts_with("error:")
1440 } else {
1441 false
1442 }
1443 });
1444
1445 let mut final_content = tool_result.content.clone();
1446
1447 if let Some(ref hook) = after_hook {
1449 let after_ctx = AfterToolCallContext {
1450 assistant_message: assistant_msg.clone(),
1451 tool_call: tc.clone(),
1452 args: tc_args.clone(),
1453 result: tool_result,
1454 is_error,
1455 context: context.clone(),
1456 };
1457 if let Some(overrides) = hook(after_ctx).await {
1458 if let Some(content_override) = overrides.content {
1459 final_content = content_override;
1460 }
1461 if let Some(error_override) = overrides.is_error {
1462 is_error = error_override;
1463 }
1464 }
1465 }
1466
1467 (final_content, is_error)
1468}
1469
1470impl Default for Agent {
1471 fn default() -> Self {
1472 Self::new()
1473 }
1474}
1475
1476struct DummyProvider;
1479
1480#[async_trait::async_trait]
1481impl crate::provider::LLMProtocol for DummyProvider {
1482 fn provider_type(&self) -> Provider {
1483 Provider::Custom("dummy".to_string())
1484 }
1485
1486 fn stream(
1487 &self,
1488 _model: &Model,
1489 _context: &Context,
1490 _options: StreamOptions,
1491 ) -> AssistantMessageEventStream {
1492 let stream = AssistantMessageEventStream::new_assistant_stream();
1493 let error_msg = AssistantMessage::builder()
1494 .provider(Provider::Custom("dummy".to_string()))
1495 .model("dummy")
1496 .stop_reason(StopReason::Error)
1497 .error_message("DummyProvider should not be called when stream_fn is set")
1498 .build()
1499 .unwrap();
1500 stream.push(AssistantMessageEvent::Error {
1501 reason: StopReason::Error,
1502 error: error_msg,
1503 });
1504 stream.end(None);
1505 stream
1506 }
1507
1508 fn stream_simple(
1509 &self,
1510 model: &Model,
1511 context: &Context,
1512 options: SimpleStreamOptions,
1513 ) -> AssistantMessageEventStream {
1514 self.stream(model, context, options.base)
1515 }
1516}
1517
1518#[derive(Debug, thiserror::Error)]
1520pub enum AgentError {
1521 #[error("Agent is already streaming")]
1522 AlreadyStreaming,
1523
1524 #[error("No messages in context")]
1525 NoMessages,
1526
1527 #[error("Cannot continue from assistant message")]
1528 CannotContinueFromAssistant,
1529
1530 #[error("Tool not found: {0}")]
1531 ToolNotFound(String),
1532
1533 #[error("Provider error: {0}")]
1534 ProviderError(String),
1535
1536 #[error("{0}")]
1537 Other(String),
1538}