// agent_sdk/query.rs

1//! The query function and agent loop implementation.
2//!
3//! This module contains the core `query()` function that creates an async stream
4//! of messages, driving Claude through the agentic loop of prompt → response →
5//! tool calls → tool results → repeat.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13use futures::stream::FuturesUnordered;
14use futures::{Stream, StreamExt as FuturesStreamExt};
15use serde_json::json;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18use tracing::{debug, error, warn};
19use uuid::Uuid;
20
21use crate::client::{
22    ApiContentBlock, ApiMessage, ApiUsage, CacheControl, ContentDelta, CreateMessageRequest,
23    ImageSource, MessageResponse, StreamEvent as ClientStreamEvent, SystemBlock, ThinkingParam,
24    ToolDefinition,
25};
26use crate::compact;
27use crate::error::{AgentError, Result};
28use crate::hooks::HookRegistry;
29use crate::options::{Options, PermissionMode, ThinkingConfig};
30use crate::permissions::{PermissionEvaluator, PermissionVerdict};
31use crate::provider::LlmProvider;
32use crate::providers::AnthropicProvider;
33use crate::sanitize;
34use crate::session::Session;
35use crate::tools::definitions::get_tool_definitions;
36use crate::tools::executor::{ToolExecutor, ToolResult};
37use crate::types::messages::*;
38
/// Default model used when `Options::model` is not set (see the
/// `unwrap_or_else` in `run_agent_loop`).
const DEFAULT_MODEL: &str = "claude-haiku-4-5";
/// Default `max_tokens` for API responses; overridden by
/// `Options::max_tokens` when present, and raised further when extended
/// thinking is enabled.
const DEFAULT_MAX_TOKENS: u32 = 16384;
43
/// A handle to a running query that streams messages.
///
/// Implements `Stream<Item = Result<Message>>` for async iteration.
pub struct Query {
    /// Receives messages produced by the background agent-loop task
    /// spawned in [`query`].
    receiver: UnboundedReceiverStream<Result<Message>>,
    /// Session ID of the underlying session.
    /// NOTE(review): `query()` initializes this to `None` and nothing in
    /// this file ever sets it, so `session_id()` currently always returns
    /// `None` — confirm intended wiring.
    session_id: Option<String>,
    /// Cancellation token shared with the agent-loop task; cancelled by
    /// `interrupt()` / `close()`.
    cancel_token: tokio_util::sync::CancellationToken,
}
52
53impl Query {
54    /// Interrupt the current query.
55    pub async fn interrupt(&self) -> Result<()> {
56        self.cancel_token.cancel();
57        Ok(())
58    }
59
60    /// Get the session ID (available after the init message).
61    pub fn session_id(&self) -> Option<&str> {
62        self.session_id.as_deref()
63    }
64
65    /// Change the permission mode mid-session.
66    pub async fn set_permission_mode(&self, _mode: PermissionMode) -> Result<()> {
67        // TODO: Send control message to the running agent loop
68        Ok(())
69    }
70
71    /// Change the model mid-session.
72    pub async fn set_model(&self, _model: &str) -> Result<()> {
73        // TODO: Send control message to the running agent loop
74        Ok(())
75    }
76
77    /// Close the query and terminate the underlying process.
78    pub fn close(&self) {
79        self.cancel_token.cancel();
80    }
81}
82
impl Stream for Query {
    type Item = Result<Message>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Delegate to the inner receiver stream. `UnboundedReceiverStream`
        // is `Unpin`, so re-pinning the field with `Pin::new` is sound.
        Pin::new(&mut self.receiver).poll_next(cx)
    }
}
90
91/// Create a query that streams messages from Claude.
92///
93/// This is the primary function for interacting with the Claude Agent SDK.
94/// Returns a [`Query`] stream that yields [`Message`] items as the agent loop
95/// progresses.
96///
97/// # Arguments
98///
99/// * `prompt` - The input prompt string
100/// * `options` - Configuration options for the query
101///
102/// # Example
103///
104/// ```rust,no_run
105/// use agent_sdk::{query, Options, Message};
106/// use tokio_stream::StreamExt;
107///
108/// # async fn example() -> anyhow::Result<()> {
109/// let mut stream = query(
110///     "What files are in this directory?",
111///     Options::builder()
112///         .allowed_tools(vec!["Bash".into(), "Glob".into()])
113///         .build(),
114/// );
115///
116/// while let Some(message) = stream.next().await {
117///     let message = message?;
118///     if let Message::Result(result) = &message {
119///         println!("{}", result.result.as_deref().unwrap_or(""));
120///     }
121/// }
122/// # Ok(())
123/// # }
124/// ```
125pub fn query(prompt: &str, options: Options) -> Query {
126    let (tx, rx) = mpsc::unbounded_channel();
127    let cancel_token = tokio_util::sync::CancellationToken::new();
128    let cancel = cancel_token.clone();
129
130    let prompt = prompt.to_string();
131
132    tokio::spawn(async move {
133        let result = run_agent_loop(prompt, options, tx.clone(), cancel).await;
134        if let Err(e) = result {
135            let _ = tx.send(Err(e));
136        }
137    });
138
139    Query {
140        receiver: UnboundedReceiverStream::new(rx),
141        session_id: None,
142        cancel_token,
143    }
144}
145
146/// The main agent loop.
147///
148/// This implements the core cycle:
149/// 1. Receive prompt
150/// 2. Send to Claude
151/// 3. Process response (text + tool calls)
152/// 4. Execute tools
153/// 5. Feed results back
154/// 6. Repeat until done or limits hit
155async fn run_agent_loop(
156    prompt: String,
157    mut options: Options,
158    tx: mpsc::UnboundedSender<Result<Message>>,
159    cancel: tokio_util::sync::CancellationToken,
160) -> Result<()> {
161    let start_time = Instant::now();
162    let mut api_time_ms: u64 = 0;
163
164    // Resolve working directory
165    let cwd = options.cwd.clone().unwrap_or_else(|| {
166        std::env::current_dir()
167            .unwrap_or_else(|_| PathBuf::from("."))
168            .to_string_lossy()
169            .to_string()
170    });
171
172    // Create or resume session
173    let session = if let Some(ref resume_id) = options.resume {
174        Session::with_id(resume_id, &cwd)
175    } else if options.continue_session {
176        // Find most recent session
177        match crate::session::find_most_recent_session(Some(&cwd)).await? {
178            Some(info) => Session::with_id(&info.session_id, &cwd),
179            None => Session::new(&cwd),
180        }
181    } else {
182        match &options.session_id {
183            Some(id) => Session::with_id(id, &cwd),
184            None => Session::new(&cwd),
185        }
186    };
187
188    let session_id = session.id.clone();
189    let model = options
190        .model
191        .clone()
192        .unwrap_or_else(|| DEFAULT_MODEL.to_string());
193
194    // Build tool definitions (skip tools entirely when output_format is set —
195    // structured-output queries should not use tools).
196    let tool_names: Vec<String> = if options.output_format.is_some() {
197        Vec::new()
198    } else if options.allowed_tools.is_empty() {
199        // Default set of tools
200        vec![
201            "Read".into(),
202            "Write".into(),
203            "Edit".into(),
204            "Bash".into(),
205            "Glob".into(),
206            "Grep".into(),
207        ]
208    } else {
209        options.allowed_tools.clone()
210    };
211
212    let raw_defs: Vec<_> = get_tool_definitions(&tool_names);
213
214    // Combine built-in + custom tool definitions
215    let mut all_defs: Vec<ToolDefinition> = raw_defs
216        .into_iter()
217        .map(|td| ToolDefinition {
218            name: td.name.to_string(),
219            description: td.description.to_string(),
220            input_schema: td.input_schema,
221            cache_control: None,
222        })
223        .collect();
224
225    // Append custom tool definitions
226    for ctd in &options.custom_tool_definitions {
227        all_defs.push(ToolDefinition {
228            name: ctd.name.clone(),
229            description: ctd.description.clone(),
230            input_schema: ctd.input_schema.clone(),
231            cache_control: None,
232        });
233    }
234
235    // Mark the last tool with cache_control so the tools block is cached
236    if let Some(last) = all_defs.last_mut() {
237        last.cache_control = Some(CacheControl::ephemeral());
238    }
239
240    let tool_defs = all_defs;
241
242    // Emit init system message
243    let init_msg = Message::System(SystemMessage {
244        subtype: SystemSubtype::Init,
245        uuid: Uuid::new_v4(),
246        session_id: session_id.clone(),
247        agents: if options.agents.is_empty() {
248            None
249        } else {
250            Some(options.agents.keys().cloned().collect())
251        },
252        claude_code_version: Some(env!("CARGO_PKG_VERSION").to_string()),
253        cwd: Some(cwd.clone()),
254        tools: Some(tool_names.clone()),
255        mcp_servers: if options.mcp_servers.is_empty() {
256            None
257        } else {
258            Some(
259                options
260                    .mcp_servers
261                    .keys()
262                    .map(|name| McpServerStatus {
263                        name: name.clone(),
264                        status: "connected".to_string(),
265                    })
266                    .collect(),
267            )
268        },
269        model: Some(model.clone()),
270        permission_mode: Some(options.permission_mode.to_string()),
271        compact_metadata: None,
272    });
273
274    // Persist and emit init message
275    if options.persist_session {
276        let _ = session
277            .append_message(&serde_json::to_value(&init_msg).unwrap_or_default())
278            .await;
279    }
280    if tx.send(Ok(init_msg)).is_err() {
281        return Ok(());
282    }
283
284    // Initialize LLM provider
285    let provider: Box<dyn LlmProvider> = match options.provider.take() {
286        Some(p) => p,
287        None => Box::new(AnthropicProvider::from_env()?),
288    };
289
290    // Initialize tool executor with optional path boundary
291    let additional_dirs: Vec<PathBuf> = options
292        .additional_directories
293        .iter()
294        .map(PathBuf::from)
295        .collect();
296    let env_blocklist = std::mem::take(&mut options.env_blocklist);
297    let tool_executor = if additional_dirs.is_empty() {
298        ToolExecutor::new(PathBuf::from(&cwd))
299    } else {
300        ToolExecutor::with_allowed_dirs(PathBuf::from(&cwd), additional_dirs)
301    }
302    .with_env_blocklist(env_blocklist);
303
304    // Build hook registry from options, merging file-discovered hooks
305    let mut hook_registry = HookRegistry::from_map(std::mem::take(&mut options.hooks));
306    if !options.hook_dirs.is_empty() {
307        let dirs: Vec<&std::path::Path> = options.hook_dirs.iter().map(|p| p.as_path()).collect();
308        match crate::hooks::HookDiscovery::discover(&dirs) {
309            Ok(discovered) => hook_registry.merge(discovered),
310            Err(e) => tracing::warn!("Failed to discover hooks from dirs: {}", e),
311        }
312    }
313
314    // Take followup_rx out of options before borrowing options immutably
315    let mut followup_rx = options.followup_rx.take();
316
317    // Initialize permission evaluator
318    let permission_eval = PermissionEvaluator::new(&options);
319
320    // Build the system prompt as SystemBlock(s) with prompt caching
321    let system_prompt: Option<Vec<SystemBlock>> = {
322        let text = match &options.system_prompt {
323            Some(crate::options::SystemPrompt::Custom(s)) => s.clone(),
324            Some(crate::options::SystemPrompt::Preset { append, .. }) => {
325                let base = "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.";
326                match append {
327                    Some(extra) => format!("{}\n\n{}", base, extra),
328                    None => base.to_string(),
329                }
330            }
331            None => "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.".to_string(),
332        };
333        Some(vec![SystemBlock {
334            kind: "text".to_string(),
335            text,
336            cache_control: Some(CacheControl::ephemeral()),
337        }])
338    };
339
340    // Build initial conversation from prompt
341    let mut conversation: Vec<ApiMessage> = Vec::new();
342
343    // Load previous messages if resuming
344    if options.resume.is_some() || options.continue_session {
345        let prev_messages = session.load_messages().await?;
346        for msg_value in prev_messages {
347            if let Some(api_msg) = value_to_api_message(&msg_value) {
348                conversation.push(api_msg);
349            }
350        }
351    }
352
353    // Add the user prompt (with optional image attachments)
354    {
355        let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
356
357        // Add image attachments as Image content blocks
358        for att in &options.attachments {
359            let is_image = matches!(
360                att.mime_type.as_str(),
361                "image/png" | "image/jpeg" | "image/gif" | "image/webp"
362            );
363            if is_image {
364                content_blocks.push(ApiContentBlock::Image {
365                    source: ImageSource {
366                        kind: "base64".to_string(),
367                        media_type: att.mime_type.clone(),
368                        data: att.base64_data.clone(),
369                    },
370                });
371            }
372        }
373
374        // Add the text prompt
375        content_blocks.push(ApiContentBlock::Text {
376            text: prompt.clone(),
377            cache_control: None,
378        });
379
380        conversation.push(ApiMessage {
381            role: "user".to_string(),
382            content: content_blocks,
383        });
384    }
385
386    // Persist user message
387    if options.persist_session {
388        let user_msg = json!({
389            "type": "user",
390            "uuid": Uuid::new_v4().to_string(),
391            "session_id": &session_id,
392            "content": [{"type": "text", "text": &prompt}]
393        });
394        let _ = session.append_message(&user_msg).await;
395    }
396
397    // Agent loop
398    let mut num_turns: u32 = 0;
399    let mut total_usage = Usage::default();
400    let mut total_cost: f64 = 0.0;
401    let mut model_usage: HashMap<String, ModelUsage> = HashMap::new();
402    let mut permission_denials: Vec<PermissionDenial> = Vec::new();
403
404    loop {
405        // Check cancellation
406        if cancel.is_cancelled() {
407            return Err(AgentError::Cancelled);
408        }
409
410        // Check turn limit
411        if let Some(max_turns) = options.max_turns {
412            if num_turns >= max_turns {
413                let result_msg = build_result_message(
414                    ResultSubtype::ErrorMaxTurns,
415                    &session_id,
416                    None,
417                    start_time,
418                    api_time_ms,
419                    num_turns,
420                    total_cost,
421                    &total_usage,
422                    &model_usage,
423                    &permission_denials,
424                );
425                let _ = tx.send(Ok(result_msg));
426                return Ok(());
427            }
428        }
429
430        // Check budget limit
431        if let Some(max_budget) = options.max_budget_usd {
432            if total_cost >= max_budget {
433                let result_msg = build_result_message(
434                    ResultSubtype::ErrorMaxBudgetUsd,
435                    &session_id,
436                    None,
437                    start_time,
438                    api_time_ms,
439                    num_turns,
440                    total_cost,
441                    &total_usage,
442                    &model_usage,
443                    &permission_denials,
444                );
445                let _ = tx.send(Ok(result_msg));
446                return Ok(());
447            }
448        }
449
450        // Drain any followup messages that arrived while we were processing.
451        // These are batched into a single user message appended to the conversation
452        // so the model sees them on the next API call.
453        if let Some(ref mut followup_rx) = followup_rx {
454            let mut followups: Vec<String> = Vec::new();
455            while let Ok(msg) = followup_rx.try_recv() {
456                followups.push(msg);
457            }
458            if !followups.is_empty() {
459                let combined = followups.join("\n\n");
460                debug!(
461                    count = followups.len(),
462                    "Injecting followup messages into agent loop"
463                );
464
465                conversation.push(ApiMessage {
466                    role: "user".to_string(),
467                    content: vec![ApiContentBlock::Text {
468                        text: combined.clone(),
469                        cache_control: None,
470                    }],
471                });
472
473                // Emit a user message so downstream consumers know about the injection
474                let followup_msg = Message::User(UserMessage {
475                    uuid: Some(Uuid::new_v4()),
476                    session_id: session_id.clone(),
477                    content: vec![ContentBlock::Text { text: combined }],
478                    parent_tool_use_id: None,
479                    is_synthetic: false,
480                    tool_use_result: None,
481                });
482
483                if options.persist_session {
484                    let _ = session
485                        .append_message(&serde_json::to_value(&followup_msg).unwrap_or_default())
486                        .await;
487                }
488                if tx.send(Ok(followup_msg)).is_err() {
489                    return Ok(());
490                }
491            }
492        }
493
494        // Set a cache breakpoint on the last content block of the last user
495        // message. This keeps the total breakpoints at 3 (system + tools + last
496        // user turn), well within the API limit of 4.
497        apply_cache_breakpoint(&mut conversation);
498
499        // Build thinking param from options
500        let thinking_param = options.thinking.as_ref().map(|tc| match tc {
501            ThinkingConfig::Adaptive => ThinkingParam {
502                kind: "enabled".into(),
503                budget_tokens: Some(10240),
504            },
505            ThinkingConfig::Disabled => ThinkingParam {
506                kind: "disabled".into(),
507                budget_tokens: None,
508            },
509            ThinkingConfig::Enabled { budget_tokens } => ThinkingParam {
510                kind: "enabled".into(),
511                budget_tokens: Some(*budget_tokens),
512            },
513        });
514
515        // Increase max_tokens when thinking is enabled
516        let base_max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
517        let max_tokens = if let Some(ref tp) = thinking_param {
518            if let Some(budget) = tp.budget_tokens {
519                base_max_tokens.max(budget as u32 + 8192)
520            } else {
521                base_max_tokens
522            }
523        } else {
524            base_max_tokens
525        };
526
527        // Build the API request
528        let use_streaming = options.include_partial_messages;
529        let request = CreateMessageRequest {
530            model: model.clone(),
531            max_tokens,
532            messages: conversation.clone(),
533            system: system_prompt.clone(),
534            tools: if tool_defs.is_empty() {
535                None
536            } else {
537                Some(tool_defs.clone())
538            },
539            stream: use_streaming,
540            metadata: None,
541            thinking: thinking_param,
542        };
543
544        // Call LLM provider
545        let api_start = Instant::now();
546        let response = if use_streaming {
547            // Streaming mode: consume SSE events, emit text deltas, accumulate full response
548            match provider.create_message_stream(&request).await {
549                Ok(mut event_stream) => {
550                    match accumulate_stream(&mut event_stream, &tx, &session_id).await {
551                        Ok(resp) => resp,
552                        Err(e) => {
553                            error!("Stream accumulation failed: {}", e);
554                            let result_msg = build_error_result_message(
555                                &session_id,
556                                &format!("Stream error: {}", e),
557                                start_time,
558                                api_time_ms,
559                                num_turns,
560                                total_cost,
561                                &total_usage,
562                                &model_usage,
563                                &permission_denials,
564                            );
565                            let _ = tx.send(Ok(result_msg));
566                            return Ok(());
567                        }
568                    }
569                }
570                Err(e) => {
571                    error!("API stream call failed: {}", e);
572                    let result_msg = build_error_result_message(
573                        &session_id,
574                        &format!("API error: {}", e),
575                        start_time,
576                        api_time_ms,
577                        num_turns,
578                        total_cost,
579                        &total_usage,
580                        &model_usage,
581                        &permission_denials,
582                    );
583                    let _ = tx.send(Ok(result_msg));
584                    return Ok(());
585                }
586            }
587        } else {
588            // Non-streaming mode: single request/response
589            match provider.create_message(&request).await {
590                Ok(resp) => resp,
591                Err(e) => {
592                    error!("API call failed: {}", e);
593                    let result_msg = build_error_result_message(
594                        &session_id,
595                        &format!("API error: {}", e),
596                        start_time,
597                        api_time_ms,
598                        num_turns,
599                        total_cost,
600                        &total_usage,
601                        &model_usage,
602                        &permission_denials,
603                    );
604                    let _ = tx.send(Ok(result_msg));
605                    return Ok(());
606                }
607            }
608        };
609        api_time_ms += api_start.elapsed().as_millis() as u64;
610
611        // Update usage
612        total_usage.input_tokens += response.usage.input_tokens;
613        total_usage.output_tokens += response.usage.output_tokens;
614        total_usage.cache_creation_input_tokens +=
615            response.usage.cache_creation_input_tokens.unwrap_or(0);
616        total_usage.cache_read_input_tokens += response.usage.cache_read_input_tokens.unwrap_or(0);
617
618        // Estimate cost using provider-specific rates (with cache-aware pricing)
619        let rates = provider.cost_rates(&model);
620        let turn_cost = rates.compute_with_cache(
621            response.usage.input_tokens,
622            response.usage.output_tokens,
623            response.usage.cache_read_input_tokens.unwrap_or(0),
624            response.usage.cache_creation_input_tokens.unwrap_or(0),
625        );
626        total_cost += turn_cost;
627
628        // Update model usage
629        let model_entry = model_usage.entry(model.clone()).or_default();
630        model_entry.input_tokens += response.usage.input_tokens;
631        model_entry.output_tokens += response.usage.output_tokens;
632        model_entry.cost_usd += turn_cost;
633
634        // Convert response to our message types
635        let content_blocks: Vec<ContentBlock> = response
636            .content
637            .iter()
638            .map(api_block_to_content_block)
639            .collect();
640
641        // Emit assistant message
642        let assistant_msg = Message::Assistant(AssistantMessage {
643            uuid: Uuid::new_v4(),
644            session_id: session_id.clone(),
645            content: content_blocks.clone(),
646            model: response.model.clone(),
647            stop_reason: response.stop_reason.clone(),
648            parent_tool_use_id: None,
649            usage: Some(Usage {
650                input_tokens: response.usage.input_tokens,
651                output_tokens: response.usage.output_tokens,
652                cache_creation_input_tokens: response
653                    .usage
654                    .cache_creation_input_tokens
655                    .unwrap_or(0),
656                cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
657            }),
658            error: None,
659        });
660
661        if options.persist_session {
662            let _ = session
663                .append_message(&serde_json::to_value(&assistant_msg).unwrap_or_default())
664                .await;
665        }
666        if tx.send(Ok(assistant_msg)).is_err() {
667            return Ok(());
668        }
669
670        // Add assistant response to conversation
671        conversation.push(ApiMessage {
672            role: "assistant".to_string(),
673            content: response.content.clone(),
674        });
675
676        // Check if there are tool calls
677        let tool_uses: Vec<_> = response
678            .content
679            .iter()
680            .filter_map(|block| match block {
681                ApiContentBlock::ToolUse { id, name, input } => {
682                    Some((id.clone(), name.clone(), input.clone()))
683                }
684                _ => None,
685            })
686            .collect();
687
688        // If no tool calls, we're done
689        if tool_uses.is_empty() {
690            // Extract final text
691            let final_text = response
692                .content
693                .iter()
694                .filter_map(|block| match block {
695                    ApiContentBlock::Text { text, .. } => Some(text.as_str()),
696                    _ => None,
697                })
698                .collect::<Vec<_>>()
699                .join("");
700
701            let result_msg = build_result_message(
702                ResultSubtype::Success,
703                &session_id,
704                Some(final_text),
705                start_time,
706                api_time_ms,
707                num_turns,
708                total_cost,
709                &total_usage,
710                &model_usage,
711                &permission_denials,
712            );
713
714            if options.persist_session {
715                let _ = session
716                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
717                    .await;
718            }
719            let _ = tx.send(Ok(result_msg));
720            return Ok(());
721        }
722
723        // Execute tool calls
724        num_turns += 1;
725        let mut tool_results: Vec<ApiContentBlock> = Vec::new();
726
727        // Phase 0: Reject hallucinated tool names immediately with a helpful error.
728        // Collect known tool names from the definitions we sent to the model.
729        let known_tool_names: std::collections::HashSet<&str> =
730            tool_defs.iter().map(|td| td.name.as_str()).collect();
731
732        let mut valid_tool_uses: Vec<&(String, String, serde_json::Value)> = Vec::new();
733        for tu in &tool_uses {
734            let (tool_use_id, tool_name, _tool_input) = tu;
735            if known_tool_names.contains(tool_name.as_str()) {
736                valid_tool_uses.push(tu);
737            } else {
738                warn!(tool = %tool_name, "model invoked unknown tool, returning error");
739                let available: Vec<&str> = tool_defs.iter().map(|td| td.name.as_str()).collect();
740                let error_msg = format!(
741                    "Error: '{}' is not a valid tool. You MUST use one of the following tools: {}",
742                    tool_name,
743                    available.join(", ")
744                );
745                let api_block = ApiContentBlock::ToolResult {
746                    tool_use_id: tool_use_id.clone(),
747                    content: json!(error_msg),
748                    is_error: Some(true),
749                    cache_control: None,
750                    name: Some(tool_name.clone()),
751                };
752
753                // Stream the error to the frontend
754                let result_msg = Message::User(UserMessage {
755                    uuid: Some(Uuid::new_v4()),
756                    session_id: session_id.clone(),
757                    content: vec![api_block_to_content_block(&api_block)],
758                    parent_tool_use_id: None,
759                    is_synthetic: true,
760                    tool_use_result: None,
761                });
762                if options.persist_session {
763                    let _ = session
764                        .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
765                        .await;
766                }
767                if tx.send(Ok(result_msg)).is_err() {
768                    return Ok(());
769                }
770
771                tool_results.push(api_block);
772            }
773        }
774
775        // Phase 1: Evaluate permissions sequentially (may involve user interaction)
776        struct PermittedTool {
777            tool_use_id: String,
778            tool_name: String,
779            actual_input: serde_json::Value,
780        }
781        let mut permitted_tools: Vec<PermittedTool> = Vec::new();
782
783        for (tool_use_id, tool_name, tool_input) in valid_tool_uses.iter().map(|t| &**t) {
784            let verdict = permission_eval
785                .evaluate(tool_name, tool_input, tool_use_id, &session_id, &cwd)
786                .await?;
787
788            let actual_input = match &verdict {
789                PermissionVerdict::AllowWithUpdatedInput(new_input) => new_input.clone(),
790                _ => tool_input.clone(),
791            };
792
793            match verdict {
794                PermissionVerdict::Allow | PermissionVerdict::AllowWithUpdatedInput(_) => {
795                    permitted_tools.push(PermittedTool {
796                        tool_use_id: tool_use_id.clone(),
797                        tool_name: tool_name.clone(),
798                        actual_input,
799                    });
800                }
801                PermissionVerdict::Deny { reason } => {
802                    debug!(tool = %tool_name, reason = %reason, "Tool denied");
803                    permission_denials.push(PermissionDenial {
804                        tool_name: tool_name.clone(),
805                        tool_use_id: tool_use_id.clone(),
806                        tool_input: tool_input.clone(),
807                    });
808
809                    let api_block = ApiContentBlock::ToolResult {
810                        tool_use_id: tool_use_id.clone(),
811                        content: json!(format!("Permission denied: {}", reason)),
812                        is_error: Some(true),
813                        cache_control: None,
814                        name: Some(tool_name.clone()),
815                    };
816
817                    // Stream denial result to frontend immediately
818                    let denial_msg = Message::User(UserMessage {
819                        uuid: Some(Uuid::new_v4()),
820                        session_id: session_id.clone(),
821                        content: vec![api_block_to_content_block(&api_block)],
822                        parent_tool_use_id: None,
823                        is_synthetic: true,
824                        tool_use_result: None,
825                    });
826                    if options.persist_session {
827                        let _ = session
828                            .append_message(&serde_json::to_value(&denial_msg).unwrap_or_default())
829                            .await;
830                    }
831                    if tx.send(Ok(denial_msg)).is_err() {
832                        return Ok(());
833                    }
834
835                    tool_results.push(api_block);
836                }
837            }
838        }
839
840        // Phase 2: Execute permitted tools concurrently, stream results as they complete
841        let mut futs: FuturesUnordered<_> = permitted_tools
842            .iter()
843            .map(|pt| {
844                let handler = &options.external_tool_handler;
845                let executor = &tool_executor;
846                let name = &pt.tool_name;
847                let input = &pt.actual_input;
848                let id = &pt.tool_use_id;
849                async move {
850                    debug!(tool = %name, "Executing tool");
851
852                    let tool_result = if let Some(ref handler) = handler {
853                        let ext_result = handler(name.clone(), input.clone()).await;
854                        if let Some(tr) = ext_result {
855                            tr
856                        } else {
857                            match executor.execute(name, input.clone()).await {
858                                Ok(tr) => tr,
859                                Err(e) => ToolResult {
860                                    content: format!("{}", e),
861                                    is_error: true,
862                                    raw_content: None,
863                                },
864                            }
865                        }
866                    } else {
867                        match executor.execute(name, input.clone()).await {
868                            Ok(tr) => tr,
869                            Err(e) => ToolResult {
870                                content: format!("{}", e),
871                                is_error: true,
872                                raw_content: None,
873                            },
874                        }
875                    };
876                    (id.as_str(), name.as_str(), input, tool_result)
877                }
878            })
879            .collect();
880
881        while let Some((tool_use_id, tool_name, actual_input, mut tool_result)) = futs.next().await
882        {
883            // Sanitize tool result: strip blobs, enforce byte limit.
884            let max_result_bytes = options
885                .max_tool_result_bytes
886                .unwrap_or(sanitize::DEFAULT_MAX_TOOL_RESULT_BYTES);
887            tool_result.content =
888                sanitize::sanitize_tool_result(&tool_result.content, max_result_bytes);
889
890            // Run PostToolUse hooks
891            hook_registry
892                .run_post_tool_use(
893                    tool_name,
894                    actual_input,
895                    &serde_json::to_value(&tool_result.content).unwrap_or_default(),
896                    tool_use_id,
897                    &session_id,
898                    &cwd,
899                )
900                .await;
901
902            let result_content = tool_result
903                .raw_content
904                .unwrap_or_else(|| json!(tool_result.content));
905
906            let api_block = ApiContentBlock::ToolResult {
907                tool_use_id: tool_use_id.to_string(),
908                content: result_content,
909                is_error: if tool_result.is_error {
910                    Some(true)
911                } else {
912                    None
913                },
914                cache_control: None,
915                name: Some(tool_name.to_string()),
916            };
917
918            // Stream this individual result to the frontend immediately
919            let result_msg = Message::User(UserMessage {
920                uuid: Some(Uuid::new_v4()),
921                session_id: session_id.clone(),
922                content: vec![api_block_to_content_block(&api_block)],
923                parent_tool_use_id: None,
924                is_synthetic: true,
925                tool_use_result: None,
926            });
927            if options.persist_session {
928                let _ = session
929                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
930                    .await;
931            }
932            if tx.send(Ok(result_msg)).is_err() {
933                return Ok(());
934            }
935
936            tool_results.push(api_block);
937        }
938
939        // Add all tool results to conversation for the next API call
940        conversation.push(ApiMessage {
941            role: "user".to_string(),
942            content: tool_results,
943        });
944
945        // --- Lightweight pruning (between turns, before full compaction) ---
946        if let Some(context_budget) = options.context_budget {
947            let prune_pct = options
948                .prune_threshold_pct
949                .unwrap_or(compact::DEFAULT_PRUNE_THRESHOLD_PCT);
950            if compact::should_prune(response.usage.input_tokens, context_budget, prune_pct) {
951                let max_chars = options
952                    .prune_tool_result_max_chars
953                    .unwrap_or(compact::DEFAULT_PRUNE_TOOL_RESULT_MAX_CHARS);
954                let min_keep = options.min_keep_messages.unwrap_or(4);
955                let removed = compact::prune_tool_results(&mut conversation, max_chars, min_keep);
956                if removed > 0 {
957                    debug!(
958                        chars_removed = removed,
959                        input_tokens = response.usage.input_tokens,
960                        "Pruned oversized tool results to free context space"
961                    );
962                }
963            }
964        }
965
966        // --- Compaction check (between turns) ---
967        if let Some(context_budget) = options.context_budget {
968            if compact::should_compact(response.usage.input_tokens, context_budget) {
969                let min_keep = options.min_keep_messages.unwrap_or(4);
970                let split_point = compact::find_split_point(&conversation, min_keep);
971                if split_point > 0 {
972                    debug!(
973                        input_tokens = response.usage.input_tokens,
974                        context_budget,
975                        split_point,
976                        "Context budget exceeded, compacting conversation"
977                    );
978
979                    let compaction_model = options
980                        .compaction_model
981                        .as_deref()
982                        .unwrap_or(compact::DEFAULT_COMPACTION_MODEL);
983
984                    // Fire pre-compact handler so the host can persist key facts
985                    if let Some(ref handler) = options.pre_compact_handler {
986                        let msgs_to_compact = conversation[..split_point].to_vec();
987                        handler(msgs_to_compact).await;
988                    }
989
990                    let summary_prompt =
991                        compact::build_summary_prompt(&conversation[..split_point]);
992
993                    let summary_max_tokens = options.summary_max_tokens.unwrap_or(4096);
994                    let compact_provider: &dyn LlmProvider = match &options.compaction_provider {
995                        Some(cp) => cp.as_ref(),
996                        None => provider.as_ref(),
997                    };
998                    let fallback_provider: Option<&dyn LlmProvider> =
999                        if options.compaction_provider.is_some() {
1000                            Some(provider.as_ref())
1001                        } else {
1002                            None
1003                        };
1004                    match compact::call_summarizer(
1005                        compact_provider,
1006                        &summary_prompt,
1007                        compaction_model,
1008                        fallback_provider,
1009                        &model,
1010                        summary_max_tokens,
1011                    )
1012                    .await
1013                    {
1014                        Ok(summary) => {
1015                            let pre_tokens = response.usage.input_tokens;
1016                            let messages_compacted = split_point;
1017
1018                            compact::splice_conversation(&mut conversation, split_point, &summary);
1019
1020                            // Emit CompactBoundary system message
1021                            let compact_msg = Message::System(SystemMessage {
1022                                subtype: SystemSubtype::CompactBoundary,
1023                                uuid: Uuid::new_v4(),
1024                                session_id: session_id.clone(),
1025                                agents: None,
1026                                claude_code_version: None,
1027                                cwd: None,
1028                                tools: None,
1029                                mcp_servers: None,
1030                                model: None,
1031                                permission_mode: None,
1032                                compact_metadata: Some(CompactMetadata {
1033                                    trigger: CompactTrigger::Auto,
1034                                    pre_tokens,
1035                                }),
1036                            });
1037
1038                            if options.persist_session {
1039                                let _ = session
1040                                    .append_message(
1041                                        &serde_json::to_value(&compact_msg).unwrap_or_default(),
1042                                    )
1043                                    .await;
1044                            }
1045                            let _ = tx.send(Ok(compact_msg));
1046
1047                            debug!(
1048                                pre_tokens,
1049                                messages_compacted,
1050                                summary_len = summary.len(),
1051                                "Conversation compacted"
1052                            );
1053                        }
1054                        Err(e) => {
1055                            warn!("Compaction failed, continuing without compaction: {}", e);
1056                        }
1057                    }
1058                }
1059            }
1060        }
1061    }
1062}
1063
/// Consume a streaming response, emitting `Message::StreamEvent` for each text
/// delta, and accumulate the full `MessageResponse` for the agent loop.
///
/// Only text deltas are forwarded to `tx` as they arrive; tool-use JSON and
/// thinking deltas are buffered silently and only materialize as complete
/// content blocks on `ContentBlockStop`. A dropped receiver on `tx` is
/// reported as `AgentError::Cancelled`.
async fn accumulate_stream(
    event_stream: &mut std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<ClientStreamEvent>> + Send>,
    >,
    tx: &mpsc::UnboundedSender<Result<Message>>,
    session_id: &str,
) -> Result<MessageResponse> {
    use crate::client::StreamEvent as SE;

    // Accumulated state
    let mut message_id = String::new();
    let mut model = String::new();
    let mut role = String::from("assistant");
    let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
    let mut stop_reason: Option<String> = None;
    let mut usage = ApiUsage::default();

    // Track in-progress content blocks by index
    // For text blocks: accumulate text. For tool_use: accumulate JSON string.
    // These four vectors are parallel arrays keyed by the provider-assigned
    // block index; both the Start and Delta branches grow them on demand, so
    // a delta arriving before (or without) its start event cannot panic.
    let mut block_texts: Vec<String> = Vec::new();
    let mut block_types: Vec<String> = Vec::new(); // "text", "tool_use", "thinking"
    let mut block_tool_ids: Vec<String> = Vec::new();
    let mut block_tool_names: Vec<String> = Vec::new();

    while let Some(event_result) = FuturesStreamExt::next(event_stream).await {
        let event = event_result?;
        match event {
            // Message envelope: identity, model, role, and initial usage.
            SE::MessageStart { message } => {
                message_id = message.id;
                model = message.model;
                role = message.role;
                usage = message.usage;
            }
            SE::ContentBlockStart {
                index,
                content_block,
            } => {
                // Ensure vectors are large enough
                while block_texts.len() <= index {
                    block_texts.push(String::new());
                    block_types.push(String::new());
                    block_tool_ids.push(String::new());
                    block_tool_names.push(String::new());
                }
                match &content_block {
                    ApiContentBlock::Text { .. } => {
                        block_types[index] = "text".to_string();
                    }
                    ApiContentBlock::ToolUse { id, name, input } => {
                        block_types[index] = "tool_use".to_string();
                        block_tool_ids[index] = id.clone();
                        block_tool_names[index] = name.clone();
                        // OpenAI/Ollama streaming delivers the complete input
                        // in ContentBlockStart (not via InputJsonDelta like
                        // Anthropic). Store it so ContentBlockStop can parse it.
                        let input_str = input.to_string();
                        if input_str != "{}" {
                            block_texts[index] = input_str;
                        }
                    }
                    ApiContentBlock::Thinking { .. } => {
                        block_types[index] = "thinking".to_string();
                    }
                    _ => {}
                }
            }
            SE::ContentBlockDelta { index, delta } => {
                while block_texts.len() <= index {
                    block_texts.push(String::new());
                    block_types.push(String::new());
                    block_tool_ids.push(String::new());
                    block_tool_names.push(String::new());
                }
                match &delta {
                    ContentDelta::TextDelta { text } => {
                        block_texts[index].push_str(text);
                        // Emit streaming event so downstream consumers get per-token updates
                        let stream_event = Message::StreamEvent(StreamEventMessage {
                            event: serde_json::json!({
                                "type": "content_block_delta",
                                "index": index,
                                "delta": { "type": "text_delta", "text": text }
                            }),
                            parent_tool_use_id: None,
                            uuid: Uuid::new_v4(),
                            session_id: session_id.to_string(),
                        });
                        if tx.send(Ok(stream_event)).is_err() {
                            // Receiver gone: the consumer dropped the Query stream.
                            return Err(AgentError::Cancelled);
                        }
                    }
                    ContentDelta::InputJsonDelta { partial_json } => {
                        // Tool-use arguments stream as JSON fragments; buffer
                        // them until ContentBlockStop parses the whole string.
                        block_texts[index].push_str(partial_json);
                    }
                    ContentDelta::ThinkingDelta { thinking } => {
                        block_texts[index].push_str(thinking);
                    }
                }
            }
            // A block finished: materialize it from the buffered per-index state.
            SE::ContentBlockStop { index } => {
                if index < block_types.len() {
                    let block = match block_types[index].as_str() {
                        "text" => ApiContentBlock::Text {
                            text: std::mem::take(&mut block_texts[index]),
                            cache_control: None,
                        },
                        "tool_use" => {
                            // Malformed or partial JSON degrades to an empty
                            // object instead of failing the whole stream.
                            let input: serde_json::Value =
                                serde_json::from_str(&block_texts[index])
                                    .unwrap_or(serde_json::Value::Object(Default::default()));
                            ApiContentBlock::ToolUse {
                                id: std::mem::take(&mut block_tool_ids[index]),
                                name: std::mem::take(&mut block_tool_names[index]),
                                input,
                            }
                        }
                        "thinking" => ApiContentBlock::Thinking {
                            thinking: std::mem::take(&mut block_texts[index]),
                        },
                        // Unknown type (e.g. a Stop with no matching Start): skip.
                        _ => continue,
                    };
                    // Place blocks at the correct index
                    while content_blocks.len() <= index {
                        content_blocks.push(ApiContentBlock::Text {
                            text: String::new(),
                            cache_control: None,
                        });
                    }
                    content_blocks[index] = block;
                }
            }
            SE::MessageDelta {
                delta,
                usage: delta_usage,
            } => {
                stop_reason = delta.stop_reason;
                // MessageDelta carries output_tokens for the whole message
                usage.output_tokens = delta_usage.output_tokens;
            }
            SE::MessageStop => {
                break;
            }
            SE::Error { error } => {
                return Err(AgentError::Api(error.message));
            }
            // Keep-alive; nothing to accumulate.
            SE::Ping => {}
        }
    }

    Ok(MessageResponse {
        id: message_id,
        role,
        content: content_blocks,
        model,
        stop_reason,
        usage,
    })
}
1224
1225/// Apply a single cache breakpoint to the last content block of the last user
1226/// message in the conversation. Clears any previous breakpoints from messages
1227/// so we stay within the API limit of 4 cache_control blocks (system + tools +
1228/// this one = 3 total).
1229fn apply_cache_breakpoint(conversation: &mut [ApiMessage]) {
1230    // First, clear all existing cache_control from messages
1231    for msg in conversation.iter_mut() {
1232        for block in msg.content.iter_mut() {
1233            match block {
1234                ApiContentBlock::Text { cache_control, .. }
1235                | ApiContentBlock::ToolResult { cache_control, .. } => {
1236                    *cache_control = None;
1237                }
1238                ApiContentBlock::Image { .. }
1239                | ApiContentBlock::ToolUse { .. }
1240                | ApiContentBlock::Thinking { .. } => {}
1241            }
1242        }
1243    }
1244
1245    // Set cache_control on the last content block of the last user message
1246    if let Some(last_user) = conversation.iter_mut().rev().find(|m| m.role == "user") {
1247        if let Some(last_block) = last_user.content.last_mut() {
1248            match last_block {
1249                ApiContentBlock::Text { cache_control, .. }
1250                | ApiContentBlock::ToolResult { cache_control, .. } => {
1251                    *cache_control = Some(CacheControl::ephemeral());
1252                }
1253                ApiContentBlock::Image { .. }
1254                | ApiContentBlock::ToolUse { .. }
1255                | ApiContentBlock::Thinking { .. } => {}
1256            }
1257        }
1258    }
1259}
1260
1261/// Convert an API content block to our ContentBlock type.
1262fn api_block_to_content_block(block: &ApiContentBlock) -> ContentBlock {
1263    match block {
1264        ApiContentBlock::Text { text, .. } => ContentBlock::Text { text: text.clone() },
1265        ApiContentBlock::Image { .. } => ContentBlock::Text {
1266            text: "[image]".to_string(),
1267        },
1268        ApiContentBlock::ToolUse { id, name, input } => ContentBlock::ToolUse {
1269            id: id.clone(),
1270            name: name.clone(),
1271            input: input.clone(),
1272        },
1273        ApiContentBlock::ToolResult {
1274            tool_use_id,
1275            content,
1276            is_error,
1277            ..
1278        } => ContentBlock::ToolResult {
1279            tool_use_id: tool_use_id.clone(),
1280            content: content.clone(),
1281            is_error: *is_error,
1282        },
1283        ApiContentBlock::Thinking { thinking } => ContentBlock::Thinking {
1284            thinking: thinking.clone(),
1285        },
1286    }
1287}
1288
1289/// Try to convert a stored JSON value to an API message.
1290fn value_to_api_message(value: &serde_json::Value) -> Option<ApiMessage> {
1291    let msg_type = value.get("type")?.as_str()?;
1292
1293    match msg_type {
1294        "assistant" => {
1295            let content = value.get("content")?;
1296            let blocks = parse_content_blocks(content)?;
1297            Some(ApiMessage {
1298                role: "assistant".to_string(),
1299                content: blocks,
1300            })
1301        }
1302        "user" => {
1303            let content = value.get("content")?;
1304            let blocks = parse_content_blocks(content)?;
1305            Some(ApiMessage {
1306                role: "user".to_string(),
1307                content: blocks,
1308            })
1309        }
1310        _ => None,
1311    }
1312}
1313
1314/// Parse content blocks from a JSON value.
1315fn parse_content_blocks(content: &serde_json::Value) -> Option<Vec<ApiContentBlock>> {
1316    if let Some(text) = content.as_str() {
1317        return Some(vec![ApiContentBlock::Text {
1318            text: text.to_string(),
1319            cache_control: None,
1320        }]);
1321    }
1322
1323    if let Some(blocks) = content.as_array() {
1324        let parsed: Vec<ApiContentBlock> = blocks
1325            .iter()
1326            .filter_map(|b| serde_json::from_value(b.clone()).ok())
1327            .collect();
1328        if !parsed.is_empty() {
1329            return Some(parsed);
1330        }
1331    }
1332
1333    None
1334}
1335
1336/// Build a ResultMessage.
1337#[allow(clippy::too_many_arguments)]
1338fn build_result_message(
1339    subtype: ResultSubtype,
1340    session_id: &str,
1341    result_text: Option<String>,
1342    start_time: Instant,
1343    api_time_ms: u64,
1344    num_turns: u32,
1345    total_cost: f64,
1346    usage: &Usage,
1347    model_usage: &HashMap<String, ModelUsage>,
1348    permission_denials: &[PermissionDenial],
1349) -> Message {
1350    Message::Result(ResultMessage {
1351        subtype,
1352        uuid: Uuid::new_v4(),
1353        session_id: session_id.to_string(),
1354        duration_ms: start_time.elapsed().as_millis() as u64,
1355        duration_api_ms: api_time_ms,
1356        is_error: result_text.is_none(),
1357        num_turns,
1358        result: result_text,
1359        stop_reason: Some("end_turn".to_string()),
1360        total_cost_usd: total_cost,
1361        usage: Some(usage.clone()),
1362        model_usage: model_usage.clone(),
1363        permission_denials: permission_denials.to_vec(),
1364        structured_output: None,
1365        errors: Vec::new(),
1366    })
1367}
1368
1369/// Build an error ResultMessage.
1370#[allow(clippy::too_many_arguments)]
1371fn build_error_result_message(
1372    session_id: &str,
1373    error_msg: &str,
1374    start_time: Instant,
1375    api_time_ms: u64,
1376    num_turns: u32,
1377    total_cost: f64,
1378    usage: &Usage,
1379    model_usage: &HashMap<String, ModelUsage>,
1380    permission_denials: &[PermissionDenial],
1381) -> Message {
1382    Message::Result(ResultMessage {
1383        subtype: ResultSubtype::ErrorDuringExecution,
1384        uuid: Uuid::new_v4(),
1385        session_id: session_id.to_string(),
1386        duration_ms: start_time.elapsed().as_millis() as u64,
1387        duration_api_ms: api_time_ms,
1388        is_error: true,
1389        num_turns,
1390        result: None,
1391        stop_reason: None,
1392        total_cost_usd: total_cost,
1393        usage: Some(usage.clone()),
1394        model_usage: model_usage.clone(),
1395        permission_denials: permission_denials.to_vec(),
1396        structured_output: None,
1397        errors: vec![error_msg.to_string()],
1398    })
1399}
1400
1401#[cfg(test)]
1402mod tests {
1403    use super::*;
1404    use std::sync::atomic::{AtomicUsize, Ordering};
1405    use std::sync::Arc;
1406    use std::time::Duration;
1407
    /// Helper: execute tools concurrently using the same FuturesUnordered pattern
    /// as the production code, collecting (tool_use_id, content, completion_order).
    ///
    /// `completion_order` is a 0-based sequence number taken from a shared
    /// atomic counter the moment each handler future finishes, so tests can
    /// assert on relative completion order without relying on wall-clock time.
    async fn run_concurrent_tools(
        tools: Vec<(String, String, serde_json::Value)>,
        handler: impl Fn(
            String,
            serde_json::Value,
        ) -> Pin<Box<dyn futures::Future<Output = Option<ToolResult>> + Send>>,
    ) -> Vec<(String, String, usize)> {
        // Monotonic completion counter shared by all tool futures.
        let order = Arc::new(AtomicUsize::new(0));
        let handler = Arc::new(handler);

        // Local stand-in mirroring the production PermittedTool struct.
        struct PermittedTool {
            tool_use_id: String,
            tool_name: String,
            actual_input: serde_json::Value,
        }

        let permitted: Vec<PermittedTool> = tools
            .into_iter()
            .map(|(id, name, input)| PermittedTool {
                tool_use_id: id,
                tool_name: name,
                actual_input: input,
            })
            .collect();

        // Mirror of the production phase-2 loop: all futures run concurrently
        // and are drained in completion (not submission) order.
        let mut futs: FuturesUnordered<_> = permitted
            .iter()
            .map(|pt| {
                let handler = handler.clone();
                let order = order.clone();
                let name = pt.tool_name.clone();
                let input = pt.actual_input.clone();
                let id = pt.tool_use_id.clone();
                async move {
                    let result = handler(name, input).await;
                    // fetch_add returns the pre-increment value: 0 for the
                    // first future to complete, 1 for the second, and so on.
                    let seq = order.fetch_add(1, Ordering::SeqCst);
                    (id, result, seq)
                }
            })
            .collect();

        let mut results = Vec::new();
        while let Some((id, result, seq)) = futs.next().await {
            // A handler returning None means "no result"; record a marker so
            // the caller still sees one entry per submitted tool.
            let content = result
                .map(|r| r.content)
                .unwrap_or_else(|| "no handler".into());
            results.push((id, content, seq));
        }
        results
    }
1460
1461    #[tokio::test]
1462    async fn concurrent_tools_all_complete() {
1463        let results = run_concurrent_tools(
1464            vec![
1465                ("t1".into(), "Read".into(), json!({"path": "a.txt"})),
1466                ("t2".into(), "Read".into(), json!({"path": "b.txt"})),
1467                ("t3".into(), "Read".into(), json!({"path": "c.txt"})),
1468            ],
1469            |name, input| {
1470                Box::pin(async move {
1471                    let path = input["path"].as_str().unwrap_or("?");
1472                    Some(ToolResult {
1473                        content: format!("{}: {}", name, path),
1474                        is_error: false,
1475                        raw_content: None,
1476                    })
1477                })
1478            },
1479        )
1480        .await;
1481
1482        assert_eq!(results.len(), 3);
1483        let ids: Vec<&str> = results.iter().map(|(id, _, _)| id.as_str()).collect();
1484        assert!(ids.contains(&"t1"));
1485        assert!(ids.contains(&"t2"));
1486        assert!(ids.contains(&"t3"));
1487    }
1488
1489    #[tokio::test]
1490    async fn slow_tool_does_not_block_fast_tools() {
1491        let start = Instant::now();
1492
1493        let results = run_concurrent_tools(
1494            vec![
1495                ("slow".into(), "Bash".into(), json!({})),
1496                ("fast1".into(), "Read".into(), json!({})),
1497                ("fast2".into(), "Read".into(), json!({})),
1498            ],
1499            |name, _input| {
1500                Box::pin(async move {
1501                    if name == "Bash" {
1502                        tokio::time::sleep(Duration::from_millis(200)).await;
1503                        Some(ToolResult {
1504                            content: "slow done".into(),
1505                            is_error: false,
1506                            raw_content: None,
1507                        })
1508                    } else {
1509                        // Fast tools complete immediately
1510                        Some(ToolResult {
1511                            content: "fast done".into(),
1512                            is_error: false,
1513                            raw_content: None,
1514                        })
1515                    }
1516                })
1517            },
1518        )
1519        .await;
1520
1521        let elapsed = start.elapsed();
1522
1523        // All three should complete
1524        assert_eq!(results.len(), 3);
1525
1526        // Fast tools should complete before the slow tool (lower order index)
1527        let slow = results.iter().find(|(id, _, _)| id == "slow").unwrap();
1528        let fast1 = results.iter().find(|(id, _, _)| id == "fast1").unwrap();
1529        let fast2 = results.iter().find(|(id, _, _)| id == "fast2").unwrap();
1530
1531        assert!(fast1.2 < slow.2, "fast1 should complete before slow");
1532        assert!(fast2.2 < slow.2, "fast2 should complete before slow");
1533
1534        // Total time should be ~200ms (concurrent), not ~400ms+ (sequential)
1535        assert!(
1536            elapsed < Duration::from_millis(400),
1537            "elapsed {:?} should be under 400ms (concurrent execution)",
1538            elapsed
1539        );
1540    }
1541
    /// Results must be forwarded on the channel one-by-one in completion
    /// order, rather than batched after all tools finish.
    #[tokio::test]
    async fn results_streamed_individually_as_they_complete() {
        // Simulate the streaming pattern from the production code:
        // each tool result is sent to the channel as it completes.
        let (tx, mut rx) = mpsc::unbounded_channel::<(String, String)>();

        let tools = vec![
            ("t_slow".into(), "Slow".into(), json!({})),
            ("t_fast".into(), "Fast".into(), json!({})),
        ];

        // Minimal stand-in for the production PermittedTool struct.
        struct PT {
            tool_use_id: String,
            tool_name: String,
        }

        let permitted: Vec<PT> = tools
            .into_iter()
            .map(|(id, name, _)| PT {
                tool_use_id: id,
                tool_name: name,
            })
            .collect();

        let mut futs: FuturesUnordered<_> = permitted
            .iter()
            .map(|pt| {
                let name = pt.tool_name.clone();
                let id = pt.tool_use_id.clone();
                async move {
                    // The slow tool finishes ~100ms after the fast one.
                    if name == "Slow" {
                        tokio::time::sleep(Duration::from_millis(100)).await;
                    }
                    let result = ToolResult {
                        content: format!("{} result", name),
                        is_error: false,
                        raw_content: None,
                    };
                    (id, result)
                }
            })
            .collect();

        // Process results as they complete (like production code)
        while let Some((id, result)) = futs.next().await {
            tx.send((id, result.content)).unwrap();
        }
        drop(tx);

        // Collect what was streamed
        let mut streamed = Vec::new();
        while let Some(item) = rx.recv().await {
            streamed.push(item);
        }

        assert_eq!(streamed.len(), 2);
        // Fast should arrive first
        assert_eq!(streamed[0].0, "t_fast");
        assert_eq!(streamed[0].1, "Fast result");
        assert_eq!(streamed[1].0, "t_slow");
        assert_eq!(streamed[1].1, "Slow result");
    }
1604
1605    #[tokio::test]
1606    async fn error_tool_does_not_prevent_other_tools() {
1607        let results = run_concurrent_tools(
1608            vec![
1609                ("t_ok".into(), "Read".into(), json!({})),
1610                ("t_err".into(), "Fail".into(), json!({})),
1611            ],
1612            |name, _input| {
1613                Box::pin(async move {
1614                    if name == "Fail" {
1615                        Some(ToolResult {
1616                            content: "something went wrong".into(),
1617                            is_error: true,
1618                            raw_content: None,
1619                        })
1620                    } else {
1621                        Some(ToolResult {
1622                            content: "ok".into(),
1623                            is_error: false,
1624                            raw_content: None,
1625                        })
1626                    }
1627                })
1628            },
1629        )
1630        .await;
1631
1632        assert_eq!(results.len(), 2);
1633        let ok = results.iter().find(|(id, _, _)| id == "t_ok").unwrap();
1634        let err = results.iter().find(|(id, _, _)| id == "t_err").unwrap();
1635        assert_eq!(ok.1, "ok");
1636        assert_eq!(err.1, "something went wrong");
1637    }
1638
1639    #[tokio::test]
1640    async fn external_handler_none_falls_through_correctly() {
1641        // When handler returns None for a tool, the production code falls through
1642        // to the built-in executor. Test that the pattern works.
1643        let results = run_concurrent_tools(
1644            vec![
1645                ("t_custom".into(), "MyTool".into(), json!({"x": 1})),
1646                ("t_builtin".into(), "Read".into(), json!({"path": "/tmp"})),
1647            ],
1648            |name, _input| {
1649                Box::pin(async move {
1650                    if name == "MyTool" {
1651                        Some(ToolResult {
1652                            content: "custom handled".into(),
1653                            is_error: false,
1654                            raw_content: None,
1655                        })
1656                    } else {
1657                        // Returns None => would fall through to built-in executor
1658                        None
1659                    }
1660                })
1661            },
1662        )
1663        .await;
1664
1665        assert_eq!(results.len(), 2);
1666        let custom = results.iter().find(|(id, _, _)| id == "t_custom").unwrap();
1667        let builtin = results.iter().find(|(id, _, _)| id == "t_builtin").unwrap();
1668        assert_eq!(custom.1, "custom handled");
1669        assert_eq!(builtin.1, "no handler"); // our test helper treats None as "no handler"
1670    }
1671
1672    #[tokio::test]
1673    async fn single_tool_works_same_as_before() {
1674        let results = run_concurrent_tools(
1675            vec![("t1".into(), "Read".into(), json!({"path": "file.txt"}))],
1676            |_name, _input| {
1677                Box::pin(async move {
1678                    Some(ToolResult {
1679                        content: "file contents".into(),
1680                        is_error: false,
1681                        raw_content: None,
1682                    })
1683                })
1684            },
1685        )
1686        .await;
1687
1688        assert_eq!(results.len(), 1);
1689        assert_eq!(results[0].0, "t1");
1690        assert_eq!(results[0].1, "file contents");
1691        assert_eq!(results[0].2, 0); // first (and only) completion
1692    }
1693
1694    #[tokio::test]
1695    async fn empty_tool_list_produces_no_results() {
1696        let results =
1697            run_concurrent_tools(vec![], |_name, _input| Box::pin(async move { None })).await;
1698
1699        assert_eq!(results.len(), 0);
1700    }
1701
1702    #[tokio::test]
1703    async fn tool_use_ids_preserved_through_concurrent_execution() {
1704        let results = run_concurrent_tools(
1705            vec![
1706                ("toolu_abc123".into(), "Read".into(), json!({})),
1707                ("toolu_def456".into(), "Write".into(), json!({})),
1708                ("toolu_ghi789".into(), "Bash".into(), json!({})),
1709            ],
1710            |name, _input| {
1711                Box::pin(async move {
1712                    // Add varying delays to shuffle completion order
1713                    match name.as_str() {
1714                        "Read" => tokio::time::sleep(Duration::from_millis(30)).await,
1715                        "Write" => tokio::time::sleep(Duration::from_millis(10)).await,
1716                        _ => tokio::time::sleep(Duration::from_millis(50)).await,
1717                    }
1718                    Some(ToolResult {
1719                        content: format!("{} result", name),
1720                        is_error: false,
1721                        raw_content: None,
1722                    })
1723                })
1724            },
1725        )
1726        .await;
1727
1728        assert_eq!(results.len(), 3);
1729
1730        // Regardless of completion order, IDs must match their tools
1731        for (id, content, _) in &results {
1732            match id.as_str() {
1733                "toolu_abc123" => assert_eq!(content, "Read result"),
1734                "toolu_def456" => assert_eq!(content, "Write result"),
1735                "toolu_ghi789" => assert_eq!(content, "Bash result"),
1736                other => panic!("unexpected tool_use_id: {}", other),
1737            }
1738        }
1739    }
1740
1741    #[tokio::test]
1742    async fn concurrent_execution_timing_is_parallel() {
1743        // 5 tools each taking 50ms should complete in ~50ms total, not 250ms
1744        let tools: Vec<_> = (0..5)
1745            .map(|i| (format!("t{}", i), "Tool".into(), json!({})))
1746            .collect();
1747
1748        let start = Instant::now();
1749
1750        let results = run_concurrent_tools(tools, |_name, _input| {
1751            Box::pin(async move {
1752                tokio::time::sleep(Duration::from_millis(50)).await;
1753                Some(ToolResult {
1754                    content: "done".into(),
1755                    is_error: false,
1756                    raw_content: None,
1757                })
1758            })
1759        })
1760        .await;
1761
1762        let elapsed = start.elapsed();
1763
1764        assert_eq!(results.len(), 5);
1765        // Should complete in roughly 50ms, definitely under 200ms
1766        assert!(
1767            elapsed < Duration::from_millis(200),
1768            "5 x 50ms tools took {:?} — should be ~50ms if concurrent",
1769            elapsed
1770        );
1771    }
1772
1773    #[tokio::test]
1774    async fn api_block_to_content_block_preserves_tool_result_fields() {
1775        let block = ApiContentBlock::ToolResult {
1776            tool_use_id: "toolu_abc".into(),
1777            content: json!("result text"),
1778            is_error: Some(true),
1779            cache_control: None,
1780            name: None,
1781        };
1782
1783        let content = api_block_to_content_block(&block);
1784        match content {
1785            ContentBlock::ToolResult {
1786                tool_use_id,
1787                content,
1788                is_error,
1789            } => {
1790                assert_eq!(tool_use_id, "toolu_abc");
1791                assert_eq!(content, json!("result text"));
1792                assert_eq!(is_error, Some(true));
1793            }
1794            _ => panic!("expected ToolResult content block"),
1795        }
1796    }
1797
1798    #[tokio::test]
1799    async fn streamed_messages_each_contain_single_tool_result() {
1800        // Verify that the streaming pattern produces one User message per tool result
1801        let (tx, mut rx) = mpsc::unbounded_channel::<Result<Message>>();
1802        let session_id = "test-session".to_string();
1803
1804        // Simulate what the production code does
1805        let tool_ids = vec!["t1", "t2", "t3"];
1806        for id in &tool_ids {
1807            let api_block = ApiContentBlock::ToolResult {
1808                tool_use_id: id.to_string(),
1809                content: json!(format!("result for {}", id)),
1810                is_error: None,
1811                cache_control: None,
1812                name: None,
1813            };
1814
1815            let result_msg = Message::User(UserMessage {
1816                uuid: Some(Uuid::new_v4()),
1817                session_id: session_id.clone(),
1818                content: vec![api_block_to_content_block(&api_block)],
1819                parent_tool_use_id: None,
1820                is_synthetic: true,
1821                tool_use_result: None,
1822            });
1823            tx.send(Ok(result_msg)).unwrap();
1824        }
1825        drop(tx);
1826
1827        let mut messages = Vec::new();
1828        while let Some(Ok(msg)) = rx.recv().await {
1829            messages.push(msg);
1830        }
1831
1832        assert_eq!(messages.len(), 3, "should have 3 individual messages");
1833
1834        for (i, msg) in messages.iter().enumerate() {
1835            if let Message::User(user) = msg {
1836                assert_eq!(
1837                    user.content.len(),
1838                    1,
1839                    "each message should have exactly 1 content block"
1840                );
1841                assert!(user.is_synthetic);
1842                if let ContentBlock::ToolResult { tool_use_id, .. } = &user.content[0] {
1843                    assert_eq!(tool_use_id, tool_ids[i]);
1844                } else {
1845                    panic!("expected ToolResult block");
1846                }
1847            } else {
1848                panic!("expected User message");
1849            }
1850        }
1851    }
1852
1853    #[tokio::test]
1854    async fn accumulate_stream_emits_text_deltas_and_builds_response() {
1855        use crate::client::{
1856            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
1857        };
1858
1859        // Build a fake stream of SSE events
1860        let events: Vec<Result<SE>> = vec![
1861            Ok(SE::MessageStart {
1862                message: MessageResponse {
1863                    id: "msg_123".into(),
1864                    role: "assistant".into(),
1865                    content: vec![],
1866                    model: "claude-test".into(),
1867                    stop_reason: None,
1868                    usage: ApiUsage {
1869                        input_tokens: 100,
1870                        output_tokens: 0,
1871                        cache_creation_input_tokens: None,
1872                        cache_read_input_tokens: None,
1873                    },
1874                },
1875            }),
1876            Ok(SE::ContentBlockStart {
1877                index: 0,
1878                content_block: ApiContentBlock::Text {
1879                    text: String::new(),
1880                    cache_control: None,
1881                },
1882            }),
1883            Ok(SE::ContentBlockDelta {
1884                index: 0,
1885                delta: ContentDelta::TextDelta {
1886                    text: "Hello".into(),
1887                },
1888            }),
1889            Ok(SE::ContentBlockDelta {
1890                index: 0,
1891                delta: ContentDelta::TextDelta {
1892                    text: " world".into(),
1893                },
1894            }),
1895            Ok(SE::ContentBlockDelta {
1896                index: 0,
1897                delta: ContentDelta::TextDelta { text: "!".into() },
1898            }),
1899            Ok(SE::ContentBlockStop { index: 0 }),
1900            Ok(SE::MessageDelta {
1901                delta: crate::client::MessageDelta {
1902                    stop_reason: Some("end_turn".into()),
1903                },
1904                usage: ApiUsage {
1905                    input_tokens: 0,
1906                    output_tokens: 15,
1907                    cache_creation_input_tokens: None,
1908                    cache_read_input_tokens: None,
1909                },
1910            }),
1911            Ok(SE::MessageStop),
1912        ];
1913
1914        let stream = futures::stream::iter(events);
1915        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
1916            Box::pin(stream);
1917
1918        let (tx, mut rx) = mpsc::unbounded_channel();
1919
1920        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
1921            .await
1922            .expect("accumulate_stream should succeed");
1923
1924        // Verify accumulated response
1925        assert_eq!(response.id, "msg_123");
1926        assert_eq!(response.model, "claude-test");
1927        assert_eq!(response.stop_reason, Some("end_turn".into()));
1928        assert_eq!(response.usage.output_tokens, 15);
1929        assert_eq!(response.content.len(), 1);
1930        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
1931            assert_eq!(text, "Hello world!");
1932        } else {
1933            panic!("expected Text content block");
1934        }
1935
1936        // Verify 3 StreamEvent messages were emitted (one per text delta)
1937        let mut stream_events = Vec::new();
1938        while let Ok(msg) = rx.try_recv() {
1939            stream_events.push(msg.unwrap());
1940        }
1941        assert_eq!(stream_events.len(), 3);
1942
1943        // Verify each is a StreamEvent with the correct text
1944        let expected_texts = ["Hello", " world", "!"];
1945        for (i, msg) in stream_events.iter().enumerate() {
1946            if let Message::StreamEvent(se) = msg {
1947                let delta = se.event.get("delta").unwrap();
1948                let text = delta.get("text").unwrap().as_str().unwrap();
1949                assert_eq!(text, expected_texts[i]);
1950                assert_eq!(se.session_id, "test-session");
1951            } else {
1952                panic!("expected StreamEvent message at index {}", i);
1953            }
1954        }
1955    }
1956
1957    #[tokio::test]
1958    async fn accumulate_stream_handles_tool_use() {
1959        use crate::client::{
1960            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
1961        };
1962
1963        let events: Vec<Result<SE>> = vec![
1964            Ok(SE::MessageStart {
1965                message: MessageResponse {
1966                    id: "msg_456".into(),
1967                    role: "assistant".into(),
1968                    content: vec![],
1969                    model: "claude-test".into(),
1970                    stop_reason: None,
1971                    usage: ApiUsage::default(),
1972                },
1973            }),
1974            // Text block
1975            Ok(SE::ContentBlockStart {
1976                index: 0,
1977                content_block: ApiContentBlock::Text {
1978                    text: String::new(),
1979                    cache_control: None,
1980                },
1981            }),
1982            Ok(SE::ContentBlockDelta {
1983                index: 0,
1984                delta: ContentDelta::TextDelta {
1985                    text: "Let me check.".into(),
1986                },
1987            }),
1988            Ok(SE::ContentBlockStop { index: 0 }),
1989            // Tool use block
1990            Ok(SE::ContentBlockStart {
1991                index: 1,
1992                content_block: ApiContentBlock::ToolUse {
1993                    id: "toolu_abc".into(),
1994                    name: "Read".into(),
1995                    input: serde_json::json!({}),
1996                },
1997            }),
1998            Ok(SE::ContentBlockDelta {
1999                index: 1,
2000                delta: ContentDelta::InputJsonDelta {
2001                    partial_json: r#"{"path":"/tmp/f.txt"}"#.into(),
2002                },
2003            }),
2004            Ok(SE::ContentBlockStop { index: 1 }),
2005            Ok(SE::MessageDelta {
2006                delta: crate::client::MessageDelta {
2007                    stop_reason: Some("tool_use".into()),
2008                },
2009                usage: ApiUsage {
2010                    input_tokens: 0,
2011                    output_tokens: 20,
2012                    ..Default::default()
2013                },
2014            }),
2015            Ok(SE::MessageStop),
2016        ];
2017
2018        let stream = futures::stream::iter(events);
2019        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
2020            Box::pin(stream);
2021
2022        let (tx, _rx) = mpsc::unbounded_channel();
2023        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
2024            .await
2025            .expect("should succeed");
2026
2027        assert_eq!(response.content.len(), 2);
2028        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
2029            assert_eq!(text, "Let me check.");
2030        } else {
2031            panic!("expected Text block at index 0");
2032        }
2033        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[1] {
2034            assert_eq!(id, "toolu_abc");
2035            assert_eq!(name, "Read");
2036            assert_eq!(input["path"], "/tmp/f.txt");
2037        } else {
2038            panic!("expected ToolUse block at index 1");
2039        }
2040        assert_eq!(response.stop_reason, Some("tool_use".into()));
2041    }
2042
2043    /// OpenAI/Ollama streaming delivers the complete tool input inside
2044    /// `ContentBlockStart` (no `InputJsonDelta` follows). Verify that
2045    /// `accumulate_stream` preserves that input instead of defaulting to `{}`.
2046    #[tokio::test]
2047    async fn accumulate_stream_preserves_openai_tool_input() {
2048        use crate::client::{ApiContentBlock, ApiUsage, StreamEvent as SE};
2049
2050        let events: Vec<Result<SE>> = vec![
2051            Ok(SE::MessageStart {
2052                message: MessageResponse {
2053                    id: "msg_oai".into(),
2054                    role: "assistant".into(),
2055                    content: vec![],
2056                    model: "qwen3:8b".into(),
2057                    stop_reason: None,
2058                    usage: ApiUsage::default(),
2059                },
2060            }),
2061            // Tool use with full input in ContentBlockStart (OpenAI/Ollama pattern)
2062            Ok(SE::ContentBlockStart {
2063                index: 0,
2064                content_block: ApiContentBlock::ToolUse {
2065                    id: "call_123".into(),
2066                    name: "Bash".into(),
2067                    input: serde_json::json!({"command": "ls -la", "timeout": 5000}),
2068                },
2069            }),
2070            // No InputJsonDelta — OpenAI/Ollama doesn't send one
2071            Ok(SE::ContentBlockStop { index: 0 }),
2072            Ok(SE::MessageDelta {
2073                delta: crate::client::MessageDelta {
2074                    stop_reason: Some("tool_use".into()),
2075                },
2076                usage: ApiUsage {
2077                    input_tokens: 0,
2078                    output_tokens: 10,
2079                    ..Default::default()
2080                },
2081            }),
2082            Ok(SE::MessageStop),
2083        ];
2084
2085        let stream = futures::stream::iter(events);
2086        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
2087            Box::pin(stream);
2088
2089        let (tx, _rx) = mpsc::unbounded_channel();
2090        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
2091            .await
2092            .expect("should succeed");
2093
2094        assert_eq!(response.content.len(), 1);
2095        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[0] {
2096            assert_eq!(id, "call_123");
2097            assert_eq!(name, "Bash");
2098            assert_eq!(input["command"], "ls -la");
2099            assert_eq!(input["timeout"], 5000);
2100        } else {
2101            panic!("expected ToolUse block");
2102        }
2103    }
2104}