// tiy_core/agent/agent.rs
1//! Agent implementation with full conversation loop.
2
3use crate::agent::{
4    AfterToolCallContext, AfterToolCallFn, AgentConfig, AgentEvent, AgentHooks, AgentMessage,
5    AgentState, AgentStateSnapshot, AgentTool, AgentToolResult, BeforeToolCallContext,
6    BeforeToolCallFn, BeforeToolCallResult, QueueMode, ThinkingBudgets, ToolExecutionMode,
7    ToolExecutor, ToolUpdateCallback, Transport,
8};
9use crate::provider::{get_provider, ArcProtocol};
10use crate::stream::AssistantMessageEventStream;
11use crate::thinking::ThinkingLevel;
12use crate::types::*;
13use futures::StreamExt;
14use parking_lot::{Mutex, RwLock};
15use std::collections::{HashMap, VecDeque};
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18
/// Default maximum number of turns (LLM calls) per prompt.
const DEFAULT_MAX_TURNS: usize = 25;

/// Opaque subscriber handle; pass it back to `unsubscribe` to remove a callback.
pub type SubscriberId = u64;

/// Callback type for event subscribers.
/// `Send + Sync` because events may be emitted from any thread driving the agent.
type SubscriberCallback = Arc<dyn Fn(&AgentEvent) + Send + Sync>;
27
/// Thread-safe subscriber storage.
///
/// Uses a `HashMap` keyed by a monotonically increasing id (rather than a
/// `Vec` with `None` tombstones), so unsubscribing actually frees the slot.
struct Subscribers {
    /// Registered callbacks, keyed by the id handed out at subscribe time.
    callbacks: RwLock<HashMap<u64, SubscriberCallback>>,
    /// Next id to assign; incremented atomically on each subscribe.
    next_id: AtomicU64,
}
33
34impl Subscribers {
35    fn new() -> Self {
36        Self {
37            callbacks: RwLock::new(HashMap::new()),
38            next_id: AtomicU64::new(0),
39        }
40    }
41
42    fn subscribe(&self, callback: SubscriberCallback) -> SubscriberId {
43        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
44        self.callbacks.write().insert(id, callback);
45        id
46    }
47
48    fn unsubscribe(&self, id: SubscriberId) {
49        self.callbacks.write().remove(&id);
50    }
51
52    /// Emit an event to all subscribers.
53    /// Clones Arcs under read lock, then calls callbacks outside the lock
54    /// to prevent blocking subscribe/unsubscribe operations.
55    fn emit(&self, event: &AgentEvent) {
56        let snapshot: Vec<SubscriberCallback> =
57            { self.callbacks.read().values().cloned().collect() };
58        for cb in &snapshot {
59            cb(event);
60        }
61    }
62}
63
/// Agent for managing stateful conversations with LLM providers.
///
/// All mutators take `&self`: every field lives behind a lock or atomic, so
/// the agent can be shared (e.g. via `Arc`) between the driving loop and
/// steering/UI code.
pub struct Agent {
    /// Conversation state: system prompt, messages, tools, streaming partials.
    state: Arc<AgentState>,
    /// Configuration (model, thinking, security, queue modes, transport).
    config: RwLock<AgentConfig>,
    /// Explicit provider; when `None`, resolved from the registry by model provider type.
    provider: RwLock<Option<ArcProtocol>>,
    /// Aggregated hooks (tool executor, before/after hooks, converters, etc.).
    hooks: RwLock<AgentHooks>,
    /// Maximum turns per prompt.
    max_turns: RwLock<usize>,
    /// Steering message queue (interrupts the current turn when non-empty).
    steering_queue: Mutex<VecDeque<AgentMessage>>,
    /// Follow-up message queue (processed after current work completes).
    follow_up_queue: Mutex<VecDeque<AgentMessage>>,
    /// Event subscribers (HashMap-based, no tombstone leak).
    subscribers: Arc<Subscribers>,
    /// Abort flag, checked per stream event in `run_turn`.
    abort_flag: Arc<AtomicBool>,
    /// Static API key for the provider (fallback when no dynamic resolver is set).
    api_key: RwLock<Option<String>>,
    /// Session ID for caching.
    session_id: RwLock<Option<String>>,
}
89
90impl Agent {
91    /// Create a new agent with default configuration.
92    pub fn new() -> Self {
93        Self {
94            state: Arc::new(AgentState::new()),
95            config: RwLock::new(AgentConfig::new(
96                Model::builder()
97                    .id("gpt-4o-mini")
98                    .name("GPT-4o Mini")
99                    .provider(Provider::OpenAI)
100                    .base_url("https://api.openai.com/v1")
101                    .context_window(128000)
102                    .max_tokens(16384)
103                    .build()
104                    .unwrap(),
105            )),
106            provider: RwLock::new(None),
107            hooks: RwLock::new(AgentHooks::default()),
108            max_turns: RwLock::new(DEFAULT_MAX_TURNS),
109            steering_queue: Mutex::new(VecDeque::new()),
110            follow_up_queue: Mutex::new(VecDeque::new()),
111            subscribers: Arc::new(Subscribers::new()),
112            abort_flag: Arc::new(AtomicBool::new(false)),
113            api_key: RwLock::new(None),
114            session_id: RwLock::new(None),
115        }
116    }
117
118    /// Create an agent with a model.
119    pub fn with_model(model: Model) -> Self {
120        let agent = Self::new();
121        agent.set_model(model.clone());
122        *agent.config.write() = AgentConfig::new(model);
123        agent
124    }
125
126    // ============================================================================
127    // Provider & API Key
128    // ============================================================================
129
    /// Set the LLM provider explicitly.
    ///
    /// An explicit provider takes precedence over registry lookup in
    /// `resolve_provider`.
    pub fn set_provider(&self, provider: ArcProtocol) {
        *self.provider.write() = Some(provider);
    }

    /// Set a static API key.
    ///
    /// A dynamic resolver installed via `set_get_api_key` takes precedence;
    /// this value is the fallback when no resolver is set or it returns `None`.
    pub fn set_api_key(&self, key: impl Into<String>) {
        *self.api_key.write() = Some(key.into());
    }
139
140    /// Set a dynamic API key resolver.
141    ///
142    /// Called before each LLM request. Useful for short-lived OAuth tokens
143    /// that may expire during long-running tool execution phases.
144    pub fn set_get_api_key<F, Fut>(&self, resolver: F)
145    where
146        F: Fn(&str) -> Fut + Send + Sync + 'static,
147        Fut: std::future::Future<Output = Option<String>> + Send + 'static,
148    {
149        let resolver = Arc::new(move |provider: &str| {
150            let fut = resolver(provider);
151            Box::pin(fut)
152                as std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send>>
153        });
154        self.hooks.write().get_api_key = Some(resolver);
155    }
156
157    // ============================================================================
158    // Tool Executor & Hooks
159    // ============================================================================
160
    /// Set the tool executor callback.
    ///
    /// The executor receives `(tool_name, tool_call_id, arguments, update_callback)`.
    /// The `update_callback` can be called during execution to push streaming
    /// partial results (emitted as `ToolExecutionUpdate` events).
    pub fn set_tool_executor<F, Fut>(&self, executor: F)
    where
        F: Fn(&str, &str, &serde_json::Value, Option<ToolUpdateCallback>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: std::future::Future<Output = AgentToolResult> + Send + 'static,
    {
        // Adapt the user's generic async closure into the type-erased,
        // pinned-boxed-future form stored in `AgentHooks`.
        let executor = Arc::new(
            move |name: &str,
                  id: &str,
                  args: &serde_json::Value,
                  update_cb: Option<ToolUpdateCallback>| {
                let fut = executor(name, id, args, update_cb);
                Box::pin(fut)
                    as std::pin::Pin<Box<dyn std::future::Future<Output = AgentToolResult> + Send>>
            },
        );
        self.hooks.write().tool_executor = Some(executor);
    }
186
    /// Set the tool executor callback (simple version without update callback).
    ///
    /// Convenience method for tools that don't need streaming updates.
    pub fn set_tool_executor_simple<F, Fut>(&self, executor: F)
    where
        F: Fn(&str, &str, &serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = AgentToolResult> + Send + 'static,
    {
        // Same adapter shape as `set_tool_executor`, but the update callback
        // is accepted and silently ignored.
        let executor = Arc::new(
            move |name: &str,
                  id: &str,
                  args: &serde_json::Value,
                  _update_cb: Option<ToolUpdateCallback>| {
                let fut = executor(name, id, args);
                Box::pin(fut)
                    as std::pin::Pin<Box<dyn std::future::Future<Output = AgentToolResult> + Send>>
            },
        );
        self.hooks.write().tool_executor = Some(executor);
    }
207
    /// Set the `before_tool_call` hook.
    ///
    /// Called after arguments are validated but before tool execution.
    /// Return `BeforeToolCallResult { block: true, .. }` to prevent execution.
    pub fn set_before_tool_call<F, Fut>(&self, hook: F)
    where
        F: Fn(BeforeToolCallContext) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Option<BeforeToolCallResult>> + Send + 'static,
    {
        // Type-erase the user's future into the boxed form stored in `AgentHooks`.
        let hook = Arc::new(move |ctx: BeforeToolCallContext| {
            let fut = hook(ctx);
            Box::pin(fut)
                as std::pin::Pin<
                    Box<dyn std::future::Future<Output = Option<BeforeToolCallResult>> + Send>,
                >
        });
        self.hooks.write().before_tool_call = Some(hook);
    }
226
    /// Set the `after_tool_call` hook.
    ///
    /// Called after tool execution, before the result is committed.
    /// Return `AfterToolCallResult` to override content, details, or is_error.
    pub fn set_after_tool_call<F, Fut>(&self, hook: F)
    where
        F: Fn(AfterToolCallContext) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Option<crate::agent::AfterToolCallResult>>
            + Send
            + 'static,
    {
        // Type-erase the user's future into the boxed form stored in `AgentHooks`.
        let hook = Arc::new(move |ctx: AfterToolCallContext| {
            let fut = hook(ctx);
            Box::pin(fut)
                as std::pin::Pin<
                    Box<
                        dyn std::future::Future<Output = Option<crate::agent::AfterToolCallResult>>
                            + Send,
                    >,
                >
        });
        self.hooks.write().after_tool_call = Some(hook);
    }
250
251    // ============================================================================
252    // Context Pipeline
253    // ============================================================================
254
255    /// Set the custom `AgentMessage[]` → `Message[]` conversion function.
256    ///
257    /// Called before each LLM request. The default filters out `Custom` messages
258    /// and maps User/Assistant/ToolResult directly.
259    pub fn set_convert_to_llm<F, Fut>(&self, converter: F)
260    where
261        F: Fn(Vec<AgentMessage>) -> Fut + Send + Sync + 'static,
262        Fut: std::future::Future<Output = Vec<Message>> + Send + 'static,
263    {
264        let converter = Arc::new(move |msgs: Vec<AgentMessage>| {
265            let fut = converter(msgs);
266            Box::pin(fut)
267                as std::pin::Pin<Box<dyn std::future::Future<Output = Vec<Message>> + Send>>
268        });
269        self.hooks.write().convert_to_llm = Some(converter);
270    }
271
    /// Set the context transformation function (applied BEFORE `convert_to_llm`).
    ///
    /// Use this for context window management, message pruning, injecting
    /// external context, etc.
    pub fn set_transform_context<F, Fut>(&self, transform: F)
    where
        F: Fn(Vec<AgentMessage>) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Vec<AgentMessage>> + Send + 'static,
    {
        // Type-erase the user's future into the boxed form stored in `AgentHooks`.
        let transform = Arc::new(move |msgs: Vec<AgentMessage>| {
            let fut = transform(msgs);
            Box::pin(fut)
                as std::pin::Pin<Box<dyn std::future::Future<Output = Vec<AgentMessage>> + Send>>
        });
        self.hooks.write().transform_context = Some(transform);
    }
288
289    // ============================================================================
290    // Payload & Stream Hooks
291    // ============================================================================
292
    /// Set the payload inspection / replacement hook.
    ///
    /// Called with the serialized request body before it is sent to the provider.
    /// Returning `Some(value)` replaces the payload; `None` leaves it unchanged
    /// — presumably; confirm against the provider-side caller of `on_payload`.
    pub fn set_on_payload<F, Fut>(&self, hook: F)
    where
        F: Fn(serde_json::Value, Model) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Option<serde_json::Value>> + Send + 'static,
    {
        // Type-erase the user's future into the boxed form stored in `AgentHooks`.
        let hook = Arc::new(move |payload: serde_json::Value, model: Model| {
            let fut = hook(payload, model);
            Box::pin(fut)
                as std::pin::Pin<
                    Box<dyn std::future::Future<Output = Option<serde_json::Value>> + Send>,
                >
        });
        self.hooks.write().on_payload = Some(hook);
    }
310
    /// Set a custom stream function to replace the default provider streaming.
    ///
    /// Useful for proxy backends, custom routing, etc.
    ///
    /// Note: a custom stream function receives the base `StreamOptions` only;
    /// thinking-level options are applied on the default provider path
    /// (see `run_turn`, which passes `options.base` here).
    pub fn set_stream_fn<F, Fut>(&self, stream_fn: F)
    where
        F: Fn(&Model, &Context, StreamOptions) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = AssistantMessageEventStream> + Send + 'static,
    {
        // Type-erase the user's future into the boxed form stored in `AgentHooks`.
        let stream_fn = Arc::new(
            move |model: &Model, context: &Context, options: StreamOptions| {
                let fut = stream_fn(model, context, options);
                Box::pin(fut)
                    as std::pin::Pin<
                        Box<dyn std::future::Future<Output = AssistantMessageEventStream> + Send>,
                    >
            },
        );
        self.hooks.write().stream_fn = Some(stream_fn);
    }
330
331    // ============================================================================
332    // Configuration Setters
333    // ============================================================================
334
    /// Set maximum turns (LLM calls) per prompt.
    pub fn set_max_turns(&self, max: usize) {
        *self.max_turns.write() = max;
    }

    /// Set the security configuration.
    pub fn set_security_config(&self, config: crate::types::SecurityConfig) {
        self.config.write().security = config;
    }

    /// Get the current security configuration (cloned snapshot).
    pub fn security_config(&self) -> crate::types::SecurityConfig {
        self.config.read().security.clone()
    }

    /// Set tool execution mode (sequential vs. bounded-parallel).
    pub fn set_tool_execution(&self, mode: ToolExecutionMode) {
        self.config.write().tool_execution = mode;
    }

    /// Set the steering queue mode (drain all vs. one at a time).
    pub fn set_steering_mode(&self, mode: QueueMode) {
        self.config.write().steering_mode = mode;
    }

    /// Get the steering queue mode.
    pub fn steering_mode(&self) -> QueueMode {
        self.config.read().steering_mode
    }

    /// Set the follow-up queue mode (drain all vs. one at a time).
    pub fn set_follow_up_mode(&self, mode: QueueMode) {
        self.config.write().follow_up_mode = mode;
    }

    /// Get the follow-up queue mode.
    pub fn follow_up_mode(&self) -> QueueMode {
        self.config.read().follow_up_mode
    }

    /// Set custom thinking budgets (per-level token budgets).
    pub fn set_thinking_budgets(&self, budgets: ThinkingBudgets) {
        self.config.write().thinking_budgets = Some(budgets);
    }

    /// Get the current thinking budgets, if any.
    pub fn thinking_budgets(&self) -> Option<ThinkingBudgets> {
        self.config.read().thinking_budgets.clone()
    }

    /// Set the preferred transport.
    pub fn set_transport(&self, transport: Transport) {
        self.config.write().transport = transport;
    }

    /// Get the preferred transport.
    pub fn transport(&self) -> Transport {
        self.config.read().transport
    }

    /// Set the maximum retry delay in milliseconds.
    ///
    /// If the server requests a retry delay exceeding this value, the request
    /// fails immediately so higher-level retry logic can handle it with user
    /// visibility. `None` = use default (60_000ms). `Some(0)` = disable cap.
    pub fn set_max_retry_delay_ms(&self, ms: Option<u64>) {
        self.config.write().max_retry_delay_ms = ms;
    }

    /// Get the current max retry delay (see `set_max_retry_delay_ms` for semantics).
    pub fn max_retry_delay_ms(&self) -> Option<u64> {
        self.config.read().max_retry_delay_ms
    }

    /// Set the session ID for caching.
    pub fn set_session_id(&self, id: impl Into<String>) {
        *self.session_id.write() = Some(id.into());
    }

    /// Get the current session ID.
    pub fn session_id(&self) -> Option<String> {
        self.session_id.read().clone()
    }

    /// Clear the session ID.
    pub fn clear_session_id(&self) {
        *self.session_id.write() = None;
    }
423
424    // ============================================================================
425    // Event Subscription
426    // ============================================================================
427
428    /// Subscribe to agent events. Returns an unsubscribe closure.
429    pub fn subscribe<F>(&self, callback: F) -> impl Fn()
430    where
431        F: Fn(&AgentEvent) + Send + Sync + 'static,
432    {
433        let id = self.subscribers.subscribe(Arc::new(callback));
434        let subs = Arc::clone(&self.subscribers);
435        move || {
436            subs.unsubscribe(id);
437        }
438    }
439
    /// Emit an event to all subscribers.
    ///
    /// Takes the event by value; subscribers receive it by reference.
    fn emit(&self, event: AgentEvent) {
        self.subscribers.emit(&event);
    }
444
445    // ============================================================================
446    // State Management
447    // ============================================================================
448
    /// Set the system prompt.
    pub fn set_system_prompt(&self, prompt: impl Into<String>) {
        self.state.set_system_prompt(prompt);
    }

    /// Set the model.
    pub fn set_model(&self, model: Model) {
        self.config.write().model = model;
    }

    /// Set the thinking level.
    pub fn set_thinking_level(&self, level: ThinkingLevel) {
        self.config.write().thinking_level = level;
    }

    /// Set the tools available to the agent.
    pub fn set_tools(&self, tools: Vec<AgentTool>) {
        self.state.set_tools(tools);
    }

    /// Replace all messages in the conversation.
    pub fn replace_messages(&self, messages: Vec<AgentMessage>) {
        self.state.replace_messages(messages);
    }

    /// Append a message to the conversation.
    pub fn append_message(&self, message: AgentMessage) {
        self.state.add_message(message);
    }

    /// Clear all messages.
    pub fn clear_messages(&self) {
        self.state.clear_messages();
    }

    /// Reset the agent: resets state, clears both queues, and drops the session id.
    ///
    /// NOTE(review): `abort_flag` is not cleared here — confirm whether a
    /// reset after an abort should also re-arm the agent.
    pub fn reset(&self) {
        self.state.reset();
        self.steering_queue.lock().clear();
        self.follow_up_queue.lock().clear();
        *self.session_id.write() = None;
    }
491
492    // ============================================================================
493    // Steering and Follow-up
494    // ============================================================================
495
    /// Add a steering message (interrupts current work).
    ///
    /// Steering messages are drained inside `run_turn`; when any are present
    /// the current turn is aborted and they are committed to state first.
    pub fn steer(&self, message: AgentMessage) {
        self.steering_queue.lock().push_back(message);
    }

    /// Add a follow-up message (processed after current work completes).
    pub fn follow_up(&self, message: AgentMessage) {
        self.follow_up_queue.lock().push_back(message);
    }

    /// Clear steering queue.
    pub fn clear_steering_queue(&self) {
        self.steering_queue.lock().clear();
    }

    /// Clear follow-up queue.
    pub fn clear_follow_up_queue(&self) {
        self.follow_up_queue.lock().clear();
    }

    /// Clear both the steering and follow-up queues.
    pub fn clear_all_queues(&self) {
        self.clear_steering_queue();
        self.clear_follow_up_queue();
    }

    /// Check if either queue has pending messages.
    pub fn has_queued_messages(&self) -> bool {
        !self.steering_queue.lock().is_empty() || !self.follow_up_queue.lock().is_empty()
    }
526
527    /// Dequeue steering messages respecting the configured mode.
528    fn dequeue_steering_messages(&self) -> Vec<AgentMessage> {
529        let mode = self.config.read().steering_mode;
530        let mut queue = self.steering_queue.lock();
531        match mode {
532            QueueMode::All => queue.drain(..).collect(),
533            QueueMode::OneAtATime => {
534                if let Some(first) = queue.pop_front() {
535                    vec![first]
536                } else {
537                    Vec::new()
538                }
539            }
540        }
541    }
542
543    /// Dequeue follow-up messages respecting the configured mode.
544    fn dequeue_follow_up_messages(&self) -> Vec<AgentMessage> {
545        let mode = self.config.read().follow_up_mode;
546        let mut queue = self.follow_up_queue.lock();
547        match mode {
548            QueueMode::All => queue.drain(..).collect(),
549            QueueMode::OneAtATime => {
550                if let Some(first) = queue.pop_front() {
551                    vec![first]
552                } else {
553                    Vec::new()
554                }
555            }
556        }
557    }
558
559    // ============================================================================
560    // Core Agent Loop
561    // ============================================================================
562
563    /// Default `convert_to_llm`: filters out Custom messages and maps directly.
564    fn default_convert_to_llm(messages: Vec<AgentMessage>) -> Vec<Message> {
565        messages
566            .into_iter()
567            .filter_map(|m| {
568                let opt: Option<Message> = m.into();
569                opt
570            })
571            .collect()
572    }
573
574    /// Build the context from current agent state using the full pipeline:
575    /// `messages → transform_context → convert_to_llm → Context`
576    async fn build_context(&self) -> Context {
577        let system_prompt = self.state.system_prompt.read().clone();
578        let messages = self.state.messages.read().clone();
579        let tools = self.state.tools.read().clone();
580
581        // Step 1: transform_context (if set)
582        let transform = self.hooks.read().transform_context.clone();
583        let messages = if let Some(ref transform) = transform {
584            transform(messages).await
585        } else {
586            messages
587        };
588
589        // Step 2: convert_to_llm
590        let converter = self.hooks.read().convert_to_llm.clone();
591        let llm_messages = if let Some(ref converter) = converter {
592            converter(messages).await
593        } else {
594            Self::default_convert_to_llm(messages)
595        };
596
597        // Step 3: Build Context
598        let mut context = if system_prompt.is_empty() {
599            Context::new()
600        } else {
601            Context::with_system_prompt(&system_prompt)
602        };
603
604        for msg in llm_messages {
605            context.add_message(msg);
606        }
607
608        // Add tools
609        if !tools.is_empty() {
610            let tool_defs: Vec<Tool> = tools.iter().map(|t| t.as_tool()).collect();
611            context.set_tools(tool_defs);
612        }
613
614        context
615    }
616
617    /// Resolve the provider to use.
618    fn resolve_provider(&self) -> Result<ArcProtocol, AgentError> {
619        // First check explicit provider
620        if let Some(ref provider) = *self.provider.read() {
621            return Ok(provider.clone());
622        }
623
624        // Then try registry by Provider type
625        let model = self.config.read().model.clone();
626        if let Some(provider) = get_provider(&model.provider) {
627            return Ok(provider);
628        }
629
630        Err(AgentError::ProviderError(format!(
631            "No provider registered for provider type: {}",
632            model.provider.as_str()
633        )))
634    }
635
636    /// Build stream options, resolving API key dynamically if configured.
637    async fn build_stream_options(&self) -> StreamOptions {
638        let security = self.config.read().security.clone();
639        let model = self.config.read().model.clone();
640        let on_payload = self.hooks.read().on_payload.clone();
641        let transport = self.config.read().transport;
642        let max_retry_delay_ms = self.config.read().max_retry_delay_ms;
643        let session_id = self.session_id.read().clone();
644
645        // Dynamic API key resolution: getApiKey > static api_key
646        let get_api_key = self.hooks.read().get_api_key.clone();
647        let api_key = if let Some(ref resolver) = get_api_key {
648            let dynamic = resolver(model.provider.as_str()).await;
649            dynamic.or_else(|| self.api_key.read().clone())
650        } else {
651            self.api_key.read().clone()
652        };
653
654        StreamOptions {
655            api_key,
656            security: Some(security),
657            session_id,
658            on_payload,
659            transport: Some(transport),
660            max_retry_delay_ms,
661            ..Default::default()
662        }
663    }
664
665    /// Build SimpleStreamOptions with thinking level/budget resolution.
666    async fn build_simple_stream_options(&self) -> SimpleStreamOptions {
667        let base = self.build_stream_options().await;
668        let thinking_level = self.config.read().thinking_level;
669
670        let (reasoning, thinking_budget_tokens) = if thinking_level != ThinkingLevel::Off {
671            let budget = self
672                .config
673                .read()
674                .thinking_budgets
675                .as_ref()
676                .and_then(|b| b.budget_for(thinking_level))
677                .or_else(|| {
678                    Some(crate::thinking::ThinkingConfig::default_budget(
679                        thinking_level,
680                    ))
681                });
682            (Some(thinking_level), budget)
683        } else {
684            (None, None)
685        };
686
687        SimpleStreamOptions {
688            base,
689            reasoning,
690            thinking_budget_tokens,
691        }
692    }
693
    /// Run a single LLM turn: call provider, consume stream, return AssistantMessage.
    ///
    /// Emits `MessageUpdate` events to subscribers as stream events arrive and
    /// mirrors the latest partial into `state.stream_message`. Returns an error
    /// when aborted, when steered (steering messages are committed to state
    /// before returning so the caller can restart the turn), when the final
    /// result times out, or when the provider reports `StopReason::Error`.
    async fn run_turn(&self, provider: &ArcProtocol) -> Result<AssistantMessage, AgentError> {
        let context = self.build_context().await;
        let model = self.config.read().model.clone();
        let options = self.build_simple_stream_options().await;
        let stream_timeout = self.config.read().security.stream.result_timeout();

        // Create the stream (custom stream_fn or default provider via stream_simple).
        // A custom stream_fn receives only `options.base` — the thinking-level
        // fields are applied on the default provider path only.
        let stream_fn = self.hooks.read().stream_fn.clone();
        let mut stream: AssistantMessageEventStream = if let Some(ref custom_stream) = stream_fn {
            custom_stream(&model, &context, options.base).await
        } else {
            provider.stream_simple(&model, &context, options)
        };

        // Process stream events
        while let Some(event) = stream.next().await {
            // Check abort (per event, so aborts take effect mid-stream).
            // NOTE(review): `state.stream_message` is not cleared on this early
            // return (nor on the Steered/timeout paths below) — confirm the
            // caller resets it, or a stale partial may linger.
            if self.abort_flag.load(Ordering::SeqCst) {
                return Err(AgentError::Other("Aborted".to_string()));
            }

            // Check for steering messages
            let steering = self.dequeue_steering_messages();
            if !steering.is_empty() {
                // Apply steering: add steering messages to state
                for steer_msg in steering {
                    self.state.add_message(steer_msg);
                }
                // Abort current turn and restart
                return Err(AgentError::Other("Steered".to_string()));
            }

            // Forward stream event to subscribers
            match &event {
                // First event: seed the streaming partial and notify.
                AssistantMessageEvent::Start { partial } => {
                    *self.state.stream_message.write() =
                        Some(AgentMessage::Assistant(partial.clone()));
                    self.emit(AgentEvent::MessageUpdate {
                        message: AgentMessage::Assistant(partial.clone()),
                        assistant_event: Box::new(event.clone()),
                    });
                }
                // Content deltas: update the stored partial and notify.
                AssistantMessageEvent::TextDelta { .. }
                | AssistantMessageEvent::ThinkingDelta { .. }
                | AssistantMessageEvent::ToolCallDelta { .. } => {
                    if let Some(partial) = event.partial_message() {
                        *self.state.stream_message.write() =
                            Some(AgentMessage::Assistant(partial.clone()));
                        self.emit(AgentEvent::MessageUpdate {
                            message: AgentMessage::Assistant(partial.clone()),
                            assistant_event: Box::new(event.clone()),
                        });
                    }
                }
                // Other events: notify without updating the stored partial.
                _ => {
                    if let Some(partial) = event.partial_message() {
                        self.emit(AgentEvent::MessageUpdate {
                            message: AgentMessage::Assistant(partial.clone()),
                            assistant_event: Box::new(event.clone()),
                        });
                    }
                }
            }
        }

        // Get the final result with timeout to prevent infinite blocking
        let result = match stream.try_result(stream_timeout).await {
            Some(r) => r,
            None => {
                return Err(AgentError::Other(format!(
                    "Stream result timed out after {:?}",
                    stream_timeout
                )));
            }
        };

        // Clear streaming message (only reached on the success path).
        *self.state.stream_message.write() = None;

        // Surface provider-reported failures as ProviderError.
        if result.stop_reason == StopReason::Error {
            let error_msg = result
                .error_message
                .clone()
                .unwrap_or_else(|| "Unknown error".to_string());
            return Err(AgentError::ProviderError(error_msg));
        }

        Ok(result)
    }
784
    /// Execute tool calls from an assistant message.
    ///
    /// Supports: beforeToolCall/afterToolCall hooks, tool validation,
    /// streaming ToolExecutionUpdate events, bounded parallel exec + timeout,
    /// and abort-aware execution.
    ///
    /// Per-call event protocol (both modes): `ToolExecutionStart` →
    /// optional `ToolExecutionUpdate`s → `ToolExecutionEnd`, with the call id
    /// held in `state.pending_tool_calls` between start and end.
    ///
    /// Returns one `ToolResultMessage` per executed (or rejected) call.
    /// In sequential mode a queued steering message stops the remaining
    /// calls; in parallel mode only the abort flag cuts execution short.
    async fn execute_tool_calls(
        &self,
        assistant_msg: &AssistantMessage,
        context: &Context,
    ) -> Vec<ToolResultMessage> {
        let tool_calls = assistant_msg.tool_calls();
        if tool_calls.is_empty() {
            return Vec::new();
        }

        // Snapshot hooks/config up front so each lock is taken briefly and
        // never held across an await point.
        let executor = self.hooks.read().tool_executor.clone();
        let execution_mode = self.config.read().tool_execution;
        let security = self.config.read().security.clone();
        let tool_timeout = security.agent.tool_execution_timeout();
        let before_hook = self.hooks.read().before_tool_call.clone();
        let after_hook = self.hooks.read().after_tool_call.clone();

        // Build Tool list for validation
        let agent_tools = self.state.tools.read().clone();
        let tool_defs: Vec<Tool> = agent_tools.iter().map(|t| t.as_tool()).collect();

        let mut results = Vec::new();

        match execution_mode {
            ToolExecutionMode::Parallel => {
                let max_parallel = security.agent.max_parallel_tool_calls;
                let abort_flag = Arc::clone(&self.abort_flag);

                let mut tool_futures = Vec::new();

                for tc in &tool_calls {
                    let tc_id = tc.id.clone();
                    let tc_name = tc.name.clone();
                    let tc_args = tc.arguments.clone();
                    let tc_clone = (*tc).clone();

                    self.emit(AgentEvent::ToolExecutionStart {
                        tool_call_id: tc_id.clone(),
                        tool_name: tc_name.clone(),
                        args: tc_args.clone(),
                    });

                    self.state.pending_tool_calls.write().insert(tc_id.clone());

                    // Validate tool call before execution
                    // (rejected calls still emit a matching ToolExecutionEnd).
                    if let Some(result) = validate_tool_call_or_error(
                        &tc_id, &tc_name, &tc_args, &tool_defs, &security,
                    ) {
                        self.emit(AgentEvent::ToolExecutionEnd {
                            tool_call_id: tc_id.clone(),
                            tool_name: tc_name.clone(),
                            result: serde_json::json!({"error": result.text_content()}),
                            is_error: true,
                        });
                        self.state.pending_tool_calls.write().remove(&tc_id);
                        results.push(result);
                        continue;
                    }

                    // beforeToolCall hook
                    // Note: awaited inline (not inside the buffered future),
                    // so before-hooks run serially even in parallel mode.
                    if let Some(result) = run_before_hook(
                        &before_hook,
                        assistant_msg,
                        &tc_clone,
                        &tc_args,
                        context,
                        &tc_id,
                        &tc_name,
                    )
                    .await
                    {
                        self.emit(AgentEvent::ToolExecutionEnd {
                            tool_call_id: tc_id.clone(),
                            tool_name: tc_name.clone(),
                            result: serde_json::json!({"error": result.text_content()}),
                            is_error: true,
                        });
                        self.state.pending_tool_calls.write().remove(&tc_id);
                        results.push(result);
                        continue;
                    }

                    // Owned clones per future: each future is buffered and
                    // polled after this loop iteration ends, so it cannot
                    // borrow loop-local data.
                    let executor = executor.clone();
                    let abort = abort_flag.clone();
                    let after_hook = after_hook.clone();
                    let assistant_msg_clone = assistant_msg.clone();
                    let context_clone = context.clone();
                    let subscribers = Arc::clone(&self.subscribers);

                    tool_futures.push(async move {
                        let (final_content, final_is_error) =
                            execute_and_apply_after_hook(ToolExecCtx {
                                executor: &executor,
                                after_hook: &after_hook,
                                subscribers: &subscribers,
                                tc_id: &tc_id,
                                tc_name: &tc_name,
                                tc_args: &tc_args,
                                tc: &tc_clone,
                                assistant_msg: &assistant_msg_clone,
                                context: &context_clone,
                                tool_timeout,
                                abort_flag: abort,
                            })
                            .await;

                        (tc_id, tc_name, final_content, final_is_error)
                    });
                }

                // Use buffer_unordered for bounded parallel execution:
                // at most `max_parallel` futures in flight, and completions
                // arrive in arbitrary order — `results` therefore may not
                // match the original tool-call order in this mode.
                let mut buffered =
                    futures::stream::iter(tool_futures).buffer_unordered(max_parallel);

                while let Some((tc_id, tc_name, content, is_error)) = buffered.next().await {
                    let result_json =
                        serde_json::to_value(&content).unwrap_or(serde_json::Value::Null);
                    self.emit(AgentEvent::ToolExecutionEnd {
                        tool_call_id: tc_id.clone(),
                        tool_name: tc_name.clone(),
                        result: result_json,
                        is_error,
                    });

                    self.state.pending_tool_calls.write().remove(&tc_id);

                    results.push(ToolResultMessage::new(tc_id, tc_name, content, is_error));
                }
            }
            ToolExecutionMode::Sequential => {
                for tc in &tool_calls {
                    // Abort stops before starting the next tool; results from
                    // already-completed tools are kept.
                    if self.abort_flag.load(Ordering::SeqCst) {
                        break;
                    }

                    let tc_id = tc.id.clone();
                    let tc_name = tc.name.clone();
                    let tc_args = tc.arguments.clone();
                    let tc_clone = (*tc).clone();

                    self.emit(AgentEvent::ToolExecutionStart {
                        tool_call_id: tc_id.clone(),
                        tool_name: tc_name.clone(),
                        args: tc_args.clone(),
                    });

                    self.state.pending_tool_calls.write().insert(tc_id.clone());

                    // Validate tool call before execution
                    if let Some(result) = validate_tool_call_or_error(
                        &tc_id, &tc_name, &tc_args, &tool_defs, &security,
                    ) {
                        self.emit(AgentEvent::ToolExecutionEnd {
                            tool_call_id: tc_id.clone(),
                            tool_name: tc_name.clone(),
                            result: serde_json::json!({"error": result.text_content()}),
                            is_error: true,
                        });
                        self.state.pending_tool_calls.write().remove(&tc_id);
                        results.push(result);
                        continue;
                    }

                    // beforeToolCall hook
                    if let Some(result) = run_before_hook(
                        &before_hook,
                        assistant_msg,
                        &tc_clone,
                        &tc_args,
                        context,
                        &tc_id,
                        &tc_name,
                    )
                    .await
                    {
                        self.emit(AgentEvent::ToolExecutionEnd {
                            tool_call_id: tc_id.clone(),
                            tool_name: tc_name.clone(),
                            result: serde_json::json!({"error": result.text_content()}),
                            is_error: true,
                        });
                        self.state.pending_tool_calls.write().remove(&tc_id);
                        results.push(result);
                        continue;
                    }

                    let abort_flag = Arc::clone(&self.abort_flag);
                    let (final_content, final_is_error) =
                        execute_and_apply_after_hook(ToolExecCtx {
                            executor: &executor,
                            after_hook: &after_hook,
                            subscribers: &self.subscribers,
                            tc_id: &tc_id,
                            tc_name: &tc_name,
                            tc_args: &tc_args,
                            tc: &tc_clone,
                            assistant_msg,
                            context,
                            tool_timeout,
                            abort_flag,
                        })
                        .await;

                    let result_json =
                        serde_json::to_value(&final_content).unwrap_or(serde_json::Value::Null);
                    self.emit(AgentEvent::ToolExecutionEnd {
                        tool_call_id: tc_id.clone(),
                        tool_name: tc_name.clone(),
                        result: result_json,
                        is_error: final_is_error,
                    });

                    self.state.pending_tool_calls.write().remove(&tc_id);

                    results.push(ToolResultMessage::new(
                        tc_id,
                        tc_name,
                        final_content,
                        final_is_error,
                    ));

                    // Check for steering messages after each sequential tool
                    // (user interjection takes priority over remaining calls).
                    let steering = self.dequeue_steering_messages();
                    if !steering.is_empty() {
                        for steer_msg in steering {
                            self.state.add_message(steer_msg);
                        }
                        // Break out of remaining tool calls
                        break;
                    }
                }
            }
        }

        results
    }
1026
1027    /// Run the agent loop: stream LLM → check tool calls → execute → loop.
1028    async fn run_loop(&self) -> Result<Vec<AgentMessage>, AgentError> {
1029        let provider = if self.hooks.read().stream_fn.is_some() {
1030            // When a custom stream function is set, we don't need a provider.
1031            // Create a dummy Arc for the loop (won't be used).
1032            None
1033        } else {
1034            Some(self.resolve_provider()?)
1035        };
1036
1037        let max_turns = *self.max_turns.read();
1038        let mut new_messages = Vec::new();
1039        let mut turn_count = 0;
1040
1041        // Sync message limit from security config
1042        let max_messages = self.config.read().security.agent.max_messages;
1043        self.state.set_max_messages(max_messages);
1044
1045        loop {
1046            // Check abort
1047            if self.abort_flag.load(Ordering::SeqCst) {
1048                self.emit(AgentEvent::AgentEnd {
1049                    messages: new_messages.clone(),
1050                });
1051                return Err(AgentError::Other("Aborted".to_string()));
1052            }
1053
1054            // Check max turns
1055            if turn_count >= max_turns {
1056                break;
1057            }
1058
1059            self.emit(AgentEvent::TurnStart);
1060
1061            // Run one LLM turn
1062            let dummy_provider: ArcProtocol = Arc::new(DummyProvider);
1063            let active_provider = provider.as_ref().unwrap_or(&dummy_provider);
1064            let assistant_result = self.run_turn(active_provider).await;
1065
1066            match assistant_result {
1067                Ok(assistant_msg) => {
1068                    // Build context snapshot for tool hook use
1069                    let context = self.build_context().await;
1070
1071                    // Add assistant message to state and new_messages
1072                    let agent_msg = AgentMessage::Assistant(assistant_msg.clone());
1073                    self.state.add_message(agent_msg.clone());
1074                    new_messages.push(agent_msg.clone());
1075
1076                    self.emit(AgentEvent::MessageStart {
1077                        message: agent_msg.clone(),
1078                    });
1079                    self.emit(AgentEvent::MessageEnd {
1080                        message: agent_msg.clone(),
1081                    });
1082
1083                    // Check if there are tool calls
1084                    if assistant_msg.has_tool_calls()
1085                        && assistant_msg.stop_reason == StopReason::ToolUse
1086                    {
1087                        let tool_results = self.execute_tool_calls(&assistant_msg, &context).await;
1088
1089                        for result in &tool_results {
1090                            let result_msg = AgentMessage::ToolResult(result.clone());
1091                            self.state.add_message(result_msg.clone());
1092                            new_messages.push(result_msg);
1093                        }
1094
1095                        self.emit(AgentEvent::TurnEnd {
1096                            message: agent_msg,
1097                            tool_results,
1098                        });
1099
1100                        // Check for follow-up messages
1101                        let follow_ups = self.dequeue_follow_up_messages();
1102                        for msg in follow_ups {
1103                            self.state.add_message(msg.clone());
1104                            new_messages.push(msg);
1105                        }
1106
1107                        turn_count += 1;
1108                        continue;
1109                    } else {
1110                        // No tool calls — conversation turn is complete
1111                        self.emit(AgentEvent::TurnEnd {
1112                            message: agent_msg,
1113                            tool_results: Vec::new(),
1114                        });
1115
1116                        // Check for follow-up messages
1117                        let follow_ups = self.dequeue_follow_up_messages();
1118                        if !follow_ups.is_empty() {
1119                            for msg in follow_ups {
1120                                self.state.add_message(msg.clone());
1121                                new_messages.push(msg);
1122                            }
1123                            turn_count += 1;
1124                            continue;
1125                        }
1126
1127                        break;
1128                    }
1129                }
1130                Err(AgentError::Other(ref msg)) if msg == "Steered" => {
1131                    turn_count += 1;
1132                    continue;
1133                }
1134                Err(e) => {
1135                    *self.state.error.write() = Some(e.to_string());
1136                    return Err(e);
1137                }
1138            }
1139        }
1140
1141        Ok(new_messages)
1142    }
1143
1144    // ============================================================================
1145    // Prompt methods
1146    // ============================================================================
1147
1148    /// Send a prompt to the agent.
1149    ///
1150    /// Uses atomic compare_exchange to prevent TOCTOU race condition.
1151    pub async fn prompt(
1152        &self,
1153        message: impl Into<AgentMessage>,
1154    ) -> Result<Vec<AgentMessage>, AgentError> {
1155        // Atomic CAS: only one caller wins the race
1156        if self
1157            .state
1158            .is_streaming
1159            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
1160            .is_err()
1161        {
1162            return Err(AgentError::AlreadyStreaming);
1163        }
1164
1165        let message = message.into();
1166        self.abort_flag.store(false, Ordering::SeqCst);
1167
1168        // Add user message to state
1169        self.state.add_message(message.clone());
1170
1171        // Emit start event
1172        self.emit(AgentEvent::AgentStart);
1173
1174        // Run the agent loop
1175        let result = self.run_loop().await;
1176
1177        self.state.set_streaming(false);
1178
1179        match result {
1180            Ok(messages) => {
1181                self.emit(AgentEvent::AgentEnd {
1182                    messages: messages.clone(),
1183                });
1184                Ok(messages)
1185            }
1186            Err(e) => {
1187                self.emit(AgentEvent::AgentEnd {
1188                    messages: Vec::new(),
1189                });
1190                Err(e)
1191            }
1192        }
1193    }
1194
1195    /// Continue from current state (e.g., after adding tool results externally).
1196    ///
1197    /// Uses atomic compare_exchange to prevent TOCTOU race condition.
1198    pub async fn continue_(&self) -> Result<Vec<AgentMessage>, AgentError> {
1199        if self
1200            .state
1201            .is_streaming
1202            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
1203            .is_err()
1204        {
1205            return Err(AgentError::AlreadyStreaming);
1206        }
1207
1208        {
1209            let messages = self.state.messages.read();
1210            if messages.is_empty() {
1211                self.state.set_streaming(false);
1212                return Err(AgentError::NoMessages);
1213            }
1214            if let Some(AgentMessage::Assistant(_)) = messages.last() {
1215                self.state.set_streaming(false);
1216                return Err(AgentError::CannotContinueFromAssistant);
1217            }
1218        }
1219
1220        self.abort_flag.store(false, Ordering::SeqCst);
1221
1222        self.emit(AgentEvent::AgentStart);
1223
1224        let result = self.run_loop().await;
1225
1226        self.state.set_streaming(false);
1227
1228        match result {
1229            Ok(messages) => {
1230                self.emit(AgentEvent::AgentEnd {
1231                    messages: messages.clone(),
1232                });
1233                Ok(messages)
1234            }
1235            Err(e) => {
1236                self.emit(AgentEvent::AgentEnd {
1237                    messages: Vec::new(),
1238                });
1239                Err(e)
1240            }
1241        }
1242    }
1243
    /// Abort current operation.
    ///
    /// Sets the abort flag (observed by the run loop between turns and by
    /// in-flight tool executions via `wait_for_abort`), marks the agent as
    /// no longer streaming, and drops all queued steering/follow-up messages.
    ///
    /// NOTE(review): `is_streaming` is cleared here eagerly even though the
    /// run loop may still be mid-turn when this returns — confirm observers
    /// (e.g. `wait_for_idle`) tolerate that window.
    pub fn abort(&self) {
        self.abort_flag.store(true, Ordering::SeqCst);
        self.state.set_streaming(false);
        self.clear_all_queues();
    }
1250
1251    /// Wait for the agent to become idle.
1252    pub async fn wait_for_idle(&self) {
1253        while self.state.is_streaming() {
1254            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1255        }
1256    }
1257
    /// Get the current state.
    ///
    /// Returns a borrow of the shared [`AgentState`] handle; callers can
    /// clone the `Arc` to keep the state alive beyond the agent borrow.
    pub fn state(&self) -> &Arc<AgentState> {
        &self.state
    }
1262
1263    /// Take a consistent point-in-time snapshot of the agent's full state.
1264    ///
1265    /// Combines runtime state from [`AgentState`] with configuration
1266    /// (model, thinking_level) from [`AgentConfig`].
1267    pub fn snapshot(&self) -> AgentStateSnapshot {
1268        let config = self.config.read();
1269        let system_prompt = self.state.system_prompt.read().clone();
1270        let messages = self.state.messages.read().clone();
1271        let is_streaming = self.state.is_streaming();
1272        let stream_message = self.state.stream_message.read().clone();
1273        let pending_tool_calls = self.state.pending_tool_calls.read().clone();
1274        let error = self.state.error.read().clone();
1275        let max_messages = self.state.get_max_messages();
1276        let message_count = messages.len();
1277
1278        AgentStateSnapshot {
1279            system_prompt,
1280            model: config.model.clone(),
1281            thinking_level: config.thinking_level,
1282            messages,
1283            is_streaming,
1284            stream_message,
1285            pending_tool_calls,
1286            error,
1287            message_count,
1288            max_messages,
1289        }
1290    }
1291}
1292
1293/// Helper: wait until the abort flag is set.
1294async fn wait_for_abort(flag: Arc<AtomicBool>) {
1295    loop {
1296        if flag.load(Ordering::SeqCst) {
1297            return;
1298        }
1299        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1300    }
1301}
1302
1303// ============================================================================
1304// Extracted helpers for execute_tool_calls deduplication
1305// ============================================================================
1306
1307/// Validate a tool call against the tool definitions and security config.
1308///
1309/// Returns `Some(ToolResultMessage)` (an error result) if validation fails,
1310/// or `None` if the tool call is valid and execution should proceed.
1311/// Skips validation when disabled in config or when no tools are registered.
1312fn validate_tool_call_or_error(
1313    tc_id: &str,
1314    tc_name: &str,
1315    tc_args: &serde_json::Value,
1316    tool_defs: &[Tool],
1317    security: &SecurityConfig,
1318) -> Option<ToolResultMessage> {
1319    if !security.agent.validate_tool_calls || tool_defs.is_empty() {
1320        return None;
1321    }
1322
1323    let tc = ToolCall::new(tc_id, tc_name, tc_args.clone());
1324    match crate::validation::validate_tool_call(tool_defs, &tc) {
1325        Ok(_) => None,
1326        Err(e) => Some(ToolResultMessage::error(tc_id, tc_name, e.to_string())),
1327    }
1328}
1329
1330/// Run the `before_tool_call` hook if set.
1331///
1332/// Returns `Some(ToolResultMessage)` (a blocked/error result) if the hook
1333/// blocked execution, or `None` if execution should proceed.
1334async fn run_before_hook(
1335    before_hook: &Option<BeforeToolCallFn>,
1336    assistant_msg: &AssistantMessage,
1337    tc: &ToolCall,
1338    tc_args: &serde_json::Value,
1339    context: &Context,
1340    tc_id: &str,
1341    tc_name: &str,
1342) -> Option<ToolResultMessage> {
1343    let hook = before_hook.as_ref()?;
1344    let ctx = BeforeToolCallContext {
1345        assistant_message: assistant_msg.clone(),
1346        tool_call: tc.clone(),
1347        args: tc_args.clone(),
1348        context: context.clone(),
1349    };
1350    match hook(ctx).await {
1351        Some(result) if result.block => {
1352            let reason = result
1353                .reason
1354                .unwrap_or_else(|| "Tool call blocked by before_tool_call hook".to_string());
1355            Some(ToolResultMessage::error(tc_id, tc_name, reason))
1356        }
1357        _ => None,
1358    }
1359}
1360
/// Context for a single tool call execution.
///
/// Groups the parameters needed by `execute_and_apply_after_hook` to avoid
/// exceeding clippy's `too_many_arguments` limit.
struct ToolExecCtx<'a> {
    /// Executor that actually runs the tool; `None` produces an error result.
    executor: &'a Option<ToolExecutor>,
    /// Optional hook that may override content / error flag after execution.
    after_hook: &'a Option<AfterToolCallFn>,
    /// Event subscribers, used to stream `ToolExecutionUpdate` events.
    subscribers: &'a Arc<Subscribers>,
    /// ID of the tool call being executed.
    tc_id: &'a str,
    /// Name of the tool being executed.
    tc_name: &'a str,
    /// JSON arguments of the tool call.
    tc_args: &'a serde_json::Value,
    /// The full tool call (forwarded to the after hook).
    tc: &'a ToolCall,
    /// Assistant message that requested this call (hook context).
    assistant_msg: &'a AssistantMessage,
    /// Conversation context snapshot (hook context).
    context: &'a Context,
    /// Maximum wall-clock time the tool may run before timing out.
    tool_timeout: std::time::Duration,
    /// Shared flag; when set, in-flight execution aborts with an error result.
    abort_flag: Arc<AtomicBool>,
}
1378
1379/// Execute a tool call and apply the `after_tool_call` hook if set.
1380///
1381/// Handles: executor invocation with timeout, abort-awareness,
1382/// streaming `ToolExecutionUpdate` events, error detection,
1383/// and after-hook overrides.
1384///
1385/// Returns `(final_content, final_is_error)`.
1386async fn execute_and_apply_after_hook(ctx: ToolExecCtx<'_>) -> (Vec<ContentBlock>, bool) {
1387    let ToolExecCtx {
1388        executor,
1389        after_hook,
1390        subscribers,
1391        tc_id,
1392        tc_name,
1393        tc_args,
1394        tc,
1395        assistant_msg,
1396        context,
1397        tool_timeout,
1398        abort_flag,
1399    } = ctx;
1400    // Execute the tool
1401    let tool_result = if let Some(ref exec) = executor {
1402        // Build update callback for streaming partial results
1403        let subs = Arc::clone(subscribers);
1404        let update_tc_id = tc_id.to_string();
1405        let update_tc_name = tc_name.to_string();
1406        let update_cb: ToolUpdateCallback = Arc::new(move |partial: serde_json::Value| {
1407            subs.emit(&AgentEvent::ToolExecutionUpdate {
1408                tool_call_id: update_tc_id.clone(),
1409                tool_name: update_tc_name.clone(),
1410                partial_result: partial,
1411            });
1412        });
1413
1414        let exec_future = exec(tc_name, tc_id, tc_args, Some(update_cb));
1415
1416        // Race: tool execution vs timeout vs abort
1417        tokio::select! {
1418            result = exec_future => result,
1419            _ = tokio::time::sleep(tool_timeout) => {
1420                AgentToolResult::error(format!(
1421                    "Tool '{}' timed out after {:?}",
1422                    tc_name, tool_timeout
1423                ))
1424            }
1425            _ = wait_for_abort(abort_flag) => {
1426                AgentToolResult::error(format!("Tool '{}' aborted", tc_name))
1427            }
1428        }
1429    } else {
1430        AgentToolResult::error(format!(
1431            "No tool executor configured for tool '{}'",
1432            tc_name
1433        ))
1434    };
1435
1436    // Detect is_error from content
1437    let mut is_error = tool_result.content.iter().any(|block| {
1438        if let Some(text) = block.as_text() {
1439            text.text.starts_with("Error:") || text.text.starts_with("error:")
1440        } else {
1441            false
1442        }
1443    });
1444
1445    let mut final_content = tool_result.content.clone();
1446
1447    // Apply after_tool_call hook
1448    if let Some(ref hook) = after_hook {
1449        let after_ctx = AfterToolCallContext {
1450            assistant_message: assistant_msg.clone(),
1451            tool_call: tc.clone(),
1452            args: tc_args.clone(),
1453            result: tool_result,
1454            is_error,
1455            context: context.clone(),
1456        };
1457        if let Some(overrides) = hook(after_ctx).await {
1458            if let Some(content_override) = overrides.content {
1459                final_content = content_override;
1460            }
1461            if let Some(error_override) = overrides.is_error {
1462                is_error = error_override;
1463            }
1464        }
1465    }
1466
1467    (final_content, is_error)
1468}
1469
impl Default for Agent {
    /// Equivalent to [`Agent::new`].
    fn default() -> Self {
        Self::new()
    }
}
1475
/// Minimal dummy provider used when a custom `stream_fn` is set.
/// This should never actually be called; if it is, its stream immediately
/// yields a single error event and ends.
struct DummyProvider;
1479
1480#[async_trait::async_trait]
1481impl crate::provider::LLMProtocol for DummyProvider {
1482    fn provider_type(&self) -> Provider {
1483        Provider::Custom("dummy".to_string())
1484    }
1485
1486    fn stream(
1487        &self,
1488        _model: &Model,
1489        _context: &Context,
1490        _options: StreamOptions,
1491    ) -> AssistantMessageEventStream {
1492        let stream = AssistantMessageEventStream::new_assistant_stream();
1493        let error_msg = AssistantMessage::builder()
1494            .provider(Provider::Custom("dummy".to_string()))
1495            .model("dummy")
1496            .stop_reason(StopReason::Error)
1497            .error_message("DummyProvider should not be called when stream_fn is set")
1498            .build()
1499            .unwrap();
1500        stream.push(AssistantMessageEvent::Error {
1501            reason: StopReason::Error,
1502            error: error_msg,
1503        });
1504        stream.end(None);
1505        stream
1506    }
1507
1508    fn stream_simple(
1509        &self,
1510        model: &Model,
1511        context: &Context,
1512        options: SimpleStreamOptions,
1513    ) -> AssistantMessageEventStream {
1514        self.stream(model, context, options.base)
1515    }
1516}
1517
/// Agent error type.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
    /// Returned when `prompt`/`continue_` lose the CAS on `is_streaming`
    /// (another run is already in progress).
    #[error("Agent is already streaming")]
    AlreadyStreaming,

    /// `continue_` was called with an empty message history.
    #[error("No messages in context")]
    NoMessages,

    /// `continue_` was called while the history ends with an assistant message.
    #[error("Cannot continue from assistant message")]
    CannotContinueFromAssistant,

    /// A requested tool could not be resolved by name.
    #[error("Tool not found: {0}")]
    ToolNotFound(String),

    /// The LLM stream terminated with an error stop reason.
    #[error("Provider error: {0}")]
    ProviderError(String),

    /// Catch-all. Also carries control signals as strings: "Aborted"
    /// (abort flag) and "Steered" (turn preempted by a steering message).
    #[error("{0}")]
    Other(String),
}