// agent_sdk/query.rs
1//! The query function and agent loop implementation.
2//!
3//! This module contains the core `query()` function that creates an async stream
4//! of messages, driving Claude through the agentic loop of prompt → response →
5//! tool calls → tool results → repeat.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13use futures::stream::FuturesUnordered;
14use futures::{Stream, StreamExt as FuturesStreamExt};
15use serde_json::json;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18use tracing::{debug, error, warn};
19use uuid::Uuid;
20
21use crate::client::{
22    ApiContentBlock, ApiMessage, ApiUsage, CacheControl, ContentDelta, CreateMessageRequest,
23    ImageSource, MessageResponse, StreamEvent as ClientStreamEvent, SystemBlock, ThinkingParam,
24    ToolDefinition,
25};
26use crate::compact;
27use crate::error::{AgentError, Result};
28use crate::sanitize;
29use crate::hooks::HookRegistry;
30use crate::options::{Options, PermissionMode, ThinkingConfig};
31use crate::permissions::{PermissionEvaluator, PermissionVerdict};
32use crate::provider::LlmProvider;
33use crate::providers::AnthropicProvider;
34use crate::session::Session;
35use crate::tools::definitions::get_tool_definitions;
36use crate::tools::executor::{ToolExecutor, ToolResult};
37use crate::types::messages::*;
38
/// Default model to use when `Options::model` is not specified.
const DEFAULT_MODEL: &str = "claude-haiku-4-5";
/// Default max tokens for API responses when `Options::max_tokens` is unset.
/// May be raised by the agent loop when extended thinking is enabled.
const DEFAULT_MAX_TOKENS: u32 = 16384;
43
/// A handle to a running query that streams messages.
///
/// Implements `Stream<Item = Result<Message>>` for async iteration.
pub struct Query {
    // Receiving end of the channel the spawned agent loop sends messages on.
    receiver: UnboundedReceiverStream<Result<Message>>,
    // NOTE(review): never written after construction anywhere in this file —
    // `session_id()` therefore always returns `None` here; confirm intent.
    session_id: Option<String>,
    // Shared cancellation token; cancelling it signals the agent loop to stop.
    cancel_token: tokio_util::sync::CancellationToken,
}
52
53impl Query {
54    /// Interrupt the current query.
55    pub async fn interrupt(&self) -> Result<()> {
56        self.cancel_token.cancel();
57        Ok(())
58    }
59
60    /// Get the session ID (available after the init message).
61    pub fn session_id(&self) -> Option<&str> {
62        self.session_id.as_deref()
63    }
64
65    /// Change the permission mode mid-session.
66    pub async fn set_permission_mode(&self, _mode: PermissionMode) -> Result<()> {
67        // TODO: Send control message to the running agent loop
68        Ok(())
69    }
70
71    /// Change the model mid-session.
72    pub async fn set_model(&self, _model: &str) -> Result<()> {
73        // TODO: Send control message to the running agent loop
74        Ok(())
75    }
76
77    /// Close the query and terminate the underlying process.
78    pub fn close(&self) {
79        self.cancel_token.cancel();
80    }
81}
82
83impl Stream for Query {
84    type Item = Result<Message>;
85
86    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87        Pin::new(&mut self.receiver).poll_next(cx)
88    }
89}
90
91/// Create a query that streams messages from Claude.
92///
93/// This is the primary function for interacting with the Claude Agent SDK.
94/// Returns a [`Query`] stream that yields [`Message`] items as the agent loop
95/// progresses.
96///
97/// # Arguments
98///
99/// * `prompt` - The input prompt string
100/// * `options` - Configuration options for the query
101///
102/// # Example
103///
104/// ```rust,no_run
105/// use agent_sdk::{query, Options, Message};
106/// use tokio_stream::StreamExt;
107///
108/// # async fn example() -> anyhow::Result<()> {
109/// let mut stream = query(
110///     "What files are in this directory?",
111///     Options::builder()
112///         .allowed_tools(vec!["Bash".into(), "Glob".into()])
113///         .build(),
114/// );
115///
116/// while let Some(message) = stream.next().await {
117///     let message = message?;
118///     if let Message::Result(result) = &message {
119///         println!("{}", result.result.as_deref().unwrap_or(""));
120///     }
121/// }
122/// # Ok(())
123/// # }
124/// ```
125pub fn query(prompt: &str, options: Options) -> Query {
126    let (tx, rx) = mpsc::unbounded_channel();
127    let cancel_token = tokio_util::sync::CancellationToken::new();
128    let cancel = cancel_token.clone();
129
130    let prompt = prompt.to_string();
131
132    tokio::spawn(async move {
133        let result = run_agent_loop(prompt, options, tx.clone(), cancel).await;
134        if let Err(e) = result {
135            let _ = tx.send(Err(e));
136        }
137    });
138
139    Query {
140        receiver: UnboundedReceiverStream::new(rx),
141        session_id: None,
142        cancel_token,
143    }
144}
145
146/// The main agent loop.
147///
148/// This implements the core cycle:
149/// 1. Receive prompt
150/// 2. Send to Claude
151/// 3. Process response (text + tool calls)
152/// 4. Execute tools
153/// 5. Feed results back
154/// 6. Repeat until done or limits hit
155async fn run_agent_loop(
156    prompt: String,
157    mut options: Options,
158    tx: mpsc::UnboundedSender<Result<Message>>,
159    cancel: tokio_util::sync::CancellationToken,
160) -> Result<()> {
161    let start_time = Instant::now();
162    let mut api_time_ms: u64 = 0;
163
164    // Resolve working directory
165    let cwd = options
166        .cwd
167        .clone()
168        .unwrap_or_else(|| {
169            std::env::current_dir()
170                .unwrap_or_else(|_| PathBuf::from("."))
171                .to_string_lossy()
172                .to_string()
173        });
174
175    // Create or resume session
176    let session = if let Some(ref resume_id) = options.resume {
177        Session::with_id(resume_id, &cwd)
178    } else if options.continue_session {
179        // Find most recent session
180        match crate::session::find_most_recent_session(Some(&cwd)).await? {
181            Some(info) => Session::with_id(&info.session_id, &cwd),
182            None => Session::new(&cwd),
183        }
184    } else {
185        match &options.session_id {
186            Some(id) => Session::with_id(id, &cwd),
187            None => Session::new(&cwd),
188        }
189    };
190
191    let session_id = session.id.clone();
192    let model = options
193        .model
194        .clone()
195        .unwrap_or_else(|| DEFAULT_MODEL.to_string());
196
197    // Build tool definitions (skip tools entirely when output_format is set —
198    // structured-output queries should not use tools).
199    let tool_names: Vec<String> = if options.output_format.is_some() {
200        Vec::new()
201    } else if options.allowed_tools.is_empty() {
202        // Default set of tools
203        vec![
204            "Read".into(), "Write".into(), "Edit".into(), "Bash".into(),
205            "Glob".into(), "Grep".into(),
206        ]
207    } else {
208        options.allowed_tools.clone()
209    };
210
211    let raw_defs: Vec<_> = get_tool_definitions(&tool_names);
212
213    // Combine built-in + custom tool definitions
214    let mut all_defs: Vec<ToolDefinition> = raw_defs
215        .into_iter()
216        .map(|td| ToolDefinition {
217            name: td.name.to_string(),
218            description: td.description.to_string(),
219            input_schema: td.input_schema,
220            cache_control: None,
221        })
222        .collect();
223
224    // Append custom tool definitions
225    for ctd in &options.custom_tool_definitions {
226        all_defs.push(ToolDefinition {
227            name: ctd.name.clone(),
228            description: ctd.description.clone(),
229            input_schema: ctd.input_schema.clone(),
230            cache_control: None,
231        });
232    }
233
234    // Mark the last tool with cache_control so the tools block is cached
235    if let Some(last) = all_defs.last_mut() {
236        last.cache_control = Some(CacheControl::ephemeral());
237    }
238
239    let tool_defs = all_defs;
240
241    // Emit init system message
242    let init_msg = Message::System(SystemMessage {
243        subtype: SystemSubtype::Init,
244        uuid: Uuid::new_v4(),
245        session_id: session_id.clone(),
246        agents: if options.agents.is_empty() {
247            None
248        } else {
249            Some(options.agents.keys().cloned().collect())
250        },
251        claude_code_version: Some(env!("CARGO_PKG_VERSION").to_string()),
252        cwd: Some(cwd.clone()),
253        tools: Some(tool_names.clone()),
254        mcp_servers: if options.mcp_servers.is_empty() {
255            None
256        } else {
257            Some(
258                options
259                    .mcp_servers
260                    .keys()
261                    .map(|name| McpServerStatus {
262                        name: name.clone(),
263                        status: "connected".to_string(),
264                    })
265                    .collect(),
266            )
267        },
268        model: Some(model.clone()),
269        permission_mode: Some(options.permission_mode.to_string()),
270        compact_metadata: None,
271    });
272
273    // Persist and emit init message
274    if options.persist_session {
275        let _ = session.append_message(&serde_json::to_value(&init_msg).unwrap_or_default()).await;
276    }
277    if tx.send(Ok(init_msg)).is_err() {
278        return Ok(());
279    }
280
281    // Initialize LLM provider
282    let provider: Box<dyn LlmProvider> = match options.provider.take() {
283        Some(p) => p,
284        None => Box::new(AnthropicProvider::from_env()?),
285    };
286
287    // Initialize tool executor with optional path boundary
288    let additional_dirs: Vec<PathBuf> = options
289        .additional_directories
290        .iter()
291        .map(PathBuf::from)
292        .collect();
293    let env_blocklist = std::mem::take(&mut options.env_blocklist);
294    let tool_executor = if additional_dirs.is_empty() {
295        ToolExecutor::new(PathBuf::from(&cwd))
296    } else {
297        ToolExecutor::with_allowed_dirs(PathBuf::from(&cwd), additional_dirs)
298    }.with_env_blocklist(env_blocklist);
299
300    // Build hook registry from options, merging file-discovered hooks
301    let mut hook_registry = HookRegistry::from_map(std::mem::take(&mut options.hooks));
302    if !options.hook_dirs.is_empty() {
303        let dirs: Vec<&std::path::Path> = options.hook_dirs.iter().map(|p| p.as_path()).collect();
304        match crate::hooks::HookDiscovery::discover(&dirs) {
305            Ok(discovered) => hook_registry.merge(discovered),
306            Err(e) => tracing::warn!("Failed to discover hooks from dirs: {}", e),
307        }
308    }
309
310    // Take followup_rx out of options before borrowing options immutably
311    let mut followup_rx = options.followup_rx.take();
312
313    // Initialize permission evaluator
314    let permission_eval = PermissionEvaluator::new(&options);
315
316    // Build the system prompt as SystemBlock(s) with prompt caching
317    let system_prompt: Option<Vec<SystemBlock>> = {
318        let text = match &options.system_prompt {
319            Some(crate::options::SystemPrompt::Custom(s)) => s.clone(),
320            Some(crate::options::SystemPrompt::Preset { append, .. }) => {
321                let base = "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.";
322                match append {
323                    Some(extra) => format!("{}\n\n{}", base, extra),
324                    None => base.to_string(),
325                }
326            }
327            None => "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.".to_string(),
328        };
329        Some(vec![SystemBlock {
330            kind: "text".to_string(),
331            text,
332            cache_control: Some(CacheControl::ephemeral()),
333        }])
334    };
335
336    // Build initial conversation from prompt
337    let mut conversation: Vec<ApiMessage> = Vec::new();
338
339    // Load previous messages if resuming
340    if options.resume.is_some() || options.continue_session {
341        let prev_messages = session.load_messages().await?;
342        for msg_value in prev_messages {
343            if let Some(api_msg) = value_to_api_message(&msg_value) {
344                conversation.push(api_msg);
345            }
346        }
347    }
348
349    // Add the user prompt (with optional image attachments)
350    {
351        let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
352
353        // Add image attachments as Image content blocks
354        for att in &options.attachments {
355            let is_image = matches!(
356                att.mime_type.as_str(),
357                "image/png" | "image/jpeg" | "image/gif" | "image/webp"
358            );
359            if is_image {
360                content_blocks.push(ApiContentBlock::Image {
361                    source: ImageSource {
362                        kind: "base64".to_string(),
363                        media_type: att.mime_type.clone(),
364                        data: att.base64_data.clone(),
365                    },
366                });
367            }
368        }
369
370        // Add the text prompt
371        content_blocks.push(ApiContentBlock::Text {
372            text: prompt.clone(),
373            cache_control: None,
374        });
375
376        conversation.push(ApiMessage {
377            role: "user".to_string(),
378            content: content_blocks,
379        });
380    }
381
382    // Persist user message
383    if options.persist_session {
384        let user_msg = json!({
385            "type": "user",
386            "uuid": Uuid::new_v4().to_string(),
387            "session_id": &session_id,
388            "content": [{"type": "text", "text": &prompt}]
389        });
390        let _ = session.append_message(&user_msg).await;
391    }
392
393    // Agent loop
394    let mut num_turns: u32 = 0;
395    let mut total_usage = Usage::default();
396    let mut total_cost: f64 = 0.0;
397    let mut model_usage: HashMap<String, ModelUsage> = HashMap::new();
398    let mut permission_denials: Vec<PermissionDenial> = Vec::new();
399
400    loop {
401        // Check cancellation
402        if cancel.is_cancelled() {
403            return Err(AgentError::Cancelled);
404        }
405
406        // Check turn limit
407        if let Some(max_turns) = options.max_turns {
408            if num_turns >= max_turns {
409                let result_msg = build_result_message(
410                    ResultSubtype::ErrorMaxTurns,
411                    &session_id,
412                    None,
413                    start_time,
414                    api_time_ms,
415                    num_turns,
416                    total_cost,
417                    &total_usage,
418                    &model_usage,
419                    &permission_denials,
420                );
421                let _ = tx.send(Ok(result_msg));
422                return Ok(());
423            }
424        }
425
426        // Check budget limit
427        if let Some(max_budget) = options.max_budget_usd {
428            if total_cost >= max_budget {
429                let result_msg = build_result_message(
430                    ResultSubtype::ErrorMaxBudgetUsd,
431                    &session_id,
432                    None,
433                    start_time,
434                    api_time_ms,
435                    num_turns,
436                    total_cost,
437                    &total_usage,
438                    &model_usage,
439                    &permission_denials,
440                );
441                let _ = tx.send(Ok(result_msg));
442                return Ok(());
443            }
444        }
445
446        // Drain any followup messages that arrived while we were processing.
447        // These are batched into a single user message appended to the conversation
448        // so the model sees them on the next API call.
449        if let Some(ref mut followup_rx) = followup_rx {
450            let mut followups: Vec<String> = Vec::new();
451            while let Ok(msg) = followup_rx.try_recv() {
452                followups.push(msg);
453            }
454            if !followups.is_empty() {
455                let combined = followups.join("\n\n");
456                debug!(count = followups.len(), "Injecting followup messages into agent loop");
457
458                conversation.push(ApiMessage {
459                    role: "user".to_string(),
460                    content: vec![ApiContentBlock::Text {
461                        text: combined.clone(),
462                        cache_control: None,
463                    }],
464                });
465
466                // Emit a user message so downstream consumers know about the injection
467                let followup_msg = Message::User(UserMessage {
468                    uuid: Some(Uuid::new_v4()),
469                    session_id: session_id.clone(),
470                    content: vec![ContentBlock::Text { text: combined }],
471                    parent_tool_use_id: None,
472                    is_synthetic: false,
473                    tool_use_result: None,
474                });
475
476                if options.persist_session {
477                    let _ = session
478                        .append_message(&serde_json::to_value(&followup_msg).unwrap_or_default())
479                        .await;
480                }
481                if tx.send(Ok(followup_msg)).is_err() {
482                    return Ok(());
483                }
484            }
485        }
486
487        // Set a cache breakpoint on the last content block of the last user
488        // message. This keeps the total breakpoints at 3 (system + tools + last
489        // user turn), well within the API limit of 4.
490        apply_cache_breakpoint(&mut conversation);
491
492        // Build thinking param from options
493        let thinking_param = options.thinking.as_ref().map(|tc| match tc {
494            ThinkingConfig::Adaptive => ThinkingParam {
495                kind: "enabled".into(),
496                budget_tokens: Some(10240),
497            },
498            ThinkingConfig::Disabled => ThinkingParam {
499                kind: "disabled".into(),
500                budget_tokens: None,
501            },
502            ThinkingConfig::Enabled { budget_tokens } => ThinkingParam {
503                kind: "enabled".into(),
504                budget_tokens: Some(*budget_tokens),
505            },
506        });
507
508        // Increase max_tokens when thinking is enabled
509        let base_max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
510        let max_tokens = if let Some(ref tp) = thinking_param {
511            if let Some(budget) = tp.budget_tokens {
512                base_max_tokens.max(budget as u32 + 8192)
513            } else {
514                base_max_tokens
515            }
516        } else {
517            base_max_tokens
518        };
519
520        // Build the API request
521        let use_streaming = options.include_partial_messages;
522        let request = CreateMessageRequest {
523            model: model.clone(),
524            max_tokens,
525            messages: conversation.clone(),
526            system: system_prompt.clone(),
527            tools: if tool_defs.is_empty() {
528                None
529            } else {
530                Some(tool_defs.clone())
531            },
532            stream: use_streaming,
533            metadata: None,
534            thinking: thinking_param,
535        };
536
537        // Call LLM provider
538        let api_start = Instant::now();
539        let response = if use_streaming {
540            // Streaming mode: consume SSE events, emit text deltas, accumulate full response
541            match provider.create_message_stream(&request).await {
542                Ok(mut event_stream) => {
543                    match accumulate_stream(
544                        &mut event_stream,
545                        &tx,
546                        &session_id,
547                    ).await {
548                        Ok(resp) => resp,
549                        Err(e) => {
550                            error!("Stream accumulation failed: {}", e);
551                            let result_msg = build_error_result_message(
552                                &session_id,
553                                &format!("Stream error: {}", e),
554                                start_time,
555                                api_time_ms,
556                                num_turns,
557                                total_cost,
558                                &total_usage,
559                                &model_usage,
560                                &permission_denials,
561                            );
562                            let _ = tx.send(Ok(result_msg));
563                            return Ok(());
564                        }
565                    }
566                }
567                Err(e) => {
568                    error!("API stream call failed: {}", e);
569                    let result_msg = build_error_result_message(
570                        &session_id,
571                        &format!("API error: {}", e),
572                        start_time,
573                        api_time_ms,
574                        num_turns,
575                        total_cost,
576                        &total_usage,
577                        &model_usage,
578                        &permission_denials,
579                    );
580                    let _ = tx.send(Ok(result_msg));
581                    return Ok(());
582                }
583            }
584        } else {
585            // Non-streaming mode: single request/response
586            match provider.create_message(&request).await {
587                Ok(resp) => resp,
588                Err(e) => {
589                    error!("API call failed: {}", e);
590                    let result_msg = build_error_result_message(
591                        &session_id,
592                        &format!("API error: {}", e),
593                        start_time,
594                        api_time_ms,
595                        num_turns,
596                        total_cost,
597                        &total_usage,
598                        &model_usage,
599                        &permission_denials,
600                    );
601                    let _ = tx.send(Ok(result_msg));
602                    return Ok(());
603                }
604            }
605        };
606        api_time_ms += api_start.elapsed().as_millis() as u64;
607
608        // Update usage
609        total_usage.input_tokens += response.usage.input_tokens;
610        total_usage.output_tokens += response.usage.output_tokens;
611        total_usage.cache_creation_input_tokens +=
612            response.usage.cache_creation_input_tokens.unwrap_or(0);
613        total_usage.cache_read_input_tokens +=
614            response.usage.cache_read_input_tokens.unwrap_or(0);
615
616        // Estimate cost using provider-specific rates (with cache-aware pricing)
617        let rates = provider.cost_rates(&model);
618        let turn_cost = rates.compute_with_cache(
619            response.usage.input_tokens,
620            response.usage.output_tokens,
621            response.usage.cache_read_input_tokens.unwrap_or(0),
622            response.usage.cache_creation_input_tokens.unwrap_or(0),
623        );
624        total_cost += turn_cost;
625
626        // Update model usage
627        let model_entry = model_usage
628            .entry(model.clone())
629            .or_insert_with(ModelUsage::default);
630        model_entry.input_tokens += response.usage.input_tokens;
631        model_entry.output_tokens += response.usage.output_tokens;
632        model_entry.cost_usd += turn_cost;
633
634        // Convert response to our message types
635        let content_blocks: Vec<ContentBlock> = response
636            .content
637            .iter()
638            .map(api_block_to_content_block)
639            .collect();
640
641        // Emit assistant message
642        let assistant_msg = Message::Assistant(AssistantMessage {
643            uuid: Uuid::new_v4(),
644            session_id: session_id.clone(),
645            content: content_blocks.clone(),
646            model: response.model.clone(),
647            stop_reason: response.stop_reason.clone(),
648            parent_tool_use_id: None,
649            usage: Some(Usage {
650                input_tokens: response.usage.input_tokens,
651                output_tokens: response.usage.output_tokens,
652                cache_creation_input_tokens: response
653                    .usage
654                    .cache_creation_input_tokens
655                    .unwrap_or(0),
656                cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
657            }),
658            error: None,
659        });
660
661        if options.persist_session {
662            let _ = session
663                .append_message(&serde_json::to_value(&assistant_msg).unwrap_or_default())
664                .await;
665        }
666        if tx.send(Ok(assistant_msg)).is_err() {
667            return Ok(());
668        }
669
670        // Add assistant response to conversation
671        conversation.push(ApiMessage {
672            role: "assistant".to_string(),
673            content: response.content.clone(),
674        });
675
676        // Check if there are tool calls
677        let tool_uses: Vec<_> = response
678            .content
679            .iter()
680            .filter_map(|block| match block {
681                ApiContentBlock::ToolUse { id, name, input } => {
682                    Some((id.clone(), name.clone(), input.clone()))
683                }
684                _ => None,
685            })
686            .collect();
687
688        // If no tool calls, we're done
689        if tool_uses.is_empty() {
690            // Extract final text
691            let final_text = response
692                .content
693                .iter()
694                .filter_map(|block| match block {
695                    ApiContentBlock::Text { text, .. } => Some(text.as_str()),
696                    _ => None,
697                })
698                .collect::<Vec<_>>()
699                .join("");
700
701            let result_msg = build_result_message(
702                ResultSubtype::Success,
703                &session_id,
704                Some(final_text),
705                start_time,
706                api_time_ms,
707                num_turns,
708                total_cost,
709                &total_usage,
710                &model_usage,
711                &permission_denials,
712            );
713
714            if options.persist_session {
715                let _ = session
716                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
717                    .await;
718            }
719            let _ = tx.send(Ok(result_msg));
720            return Ok(());
721        }
722
723        // Execute tool calls
724        num_turns += 1;
725        let mut tool_results: Vec<ApiContentBlock> = Vec::new();
726
727        // Phase 1: Evaluate permissions sequentially (may involve user interaction)
728        struct PermittedTool {
729            tool_use_id: String,
730            tool_name: String,
731            actual_input: serde_json::Value,
732        }
733        let mut permitted_tools: Vec<PermittedTool> = Vec::new();
734
735        for (tool_use_id, tool_name, tool_input) in &tool_uses {
736            let verdict = permission_eval
737                .evaluate(tool_name, tool_input, tool_use_id, &session_id, &cwd)
738                .await?;
739
740            let actual_input = match &verdict {
741                PermissionVerdict::AllowWithUpdatedInput(new_input) => new_input.clone(),
742                _ => tool_input.clone(),
743            };
744
745            match verdict {
746                PermissionVerdict::Allow | PermissionVerdict::AllowWithUpdatedInput(_) => {
747                    permitted_tools.push(PermittedTool {
748                        tool_use_id: tool_use_id.clone(),
749                        tool_name: tool_name.clone(),
750                        actual_input,
751                    });
752                }
753                PermissionVerdict::Deny { reason } => {
754                    debug!(tool = %tool_name, reason = %reason, "Tool denied");
755                    permission_denials.push(PermissionDenial {
756                        tool_name: tool_name.clone(),
757                        tool_use_id: tool_use_id.clone(),
758                        tool_input: tool_input.clone(),
759                    });
760
761                    let api_block = ApiContentBlock::ToolResult {
762                        tool_use_id: tool_use_id.clone(),
763                        content: json!(format!("Permission denied: {}", reason)),
764                        is_error: Some(true),
765                        cache_control: None,
766                        name: Some(tool_name.clone()),
767                    };
768
769                    // Stream denial result to frontend immediately
770                    let denial_msg = Message::User(UserMessage {
771                        uuid: Some(Uuid::new_v4()),
772                        session_id: session_id.clone(),
773                        content: vec![api_block_to_content_block(&api_block)],
774                        parent_tool_use_id: None,
775                        is_synthetic: true,
776                        tool_use_result: None,
777                    });
778                    if options.persist_session {
779                        let _ = session
780                            .append_message(&serde_json::to_value(&denial_msg).unwrap_or_default())
781                            .await;
782                    }
783                    if tx.send(Ok(denial_msg)).is_err() {
784                        return Ok(());
785                    }
786
787                    tool_results.push(api_block);
788                }
789            }
790        }
791
792        // Phase 2: Execute permitted tools concurrently, stream results as they complete
793        let mut futs: FuturesUnordered<_> = permitted_tools
794            .iter()
795            .map(|pt| {
796                let handler = &options.external_tool_handler;
797                let executor = &tool_executor;
798                let name = &pt.tool_name;
799                let input = &pt.actual_input;
800                let id = &pt.tool_use_id;
801                async move {
802                    debug!(tool = %name, "Executing tool");
803
804                    let tool_result = if let Some(ref handler) = handler {
805                        let ext_result = handler(name.clone(), input.clone()).await;
806                        if let Some(tr) = ext_result {
807                            tr
808                        } else {
809                            match executor.execute(name, input.clone()).await {
810                                Ok(tr) => tr,
811                                Err(e) => ToolResult {
812                                    content: format!("Tool execution error: {}", e),
813                                    is_error: true,
814                                    raw_content: None,
815                                },
816                            }
817                        }
818                    } else {
819                        match executor.execute(name, input.clone()).await {
820                            Ok(tr) => tr,
821                            Err(e) => ToolResult {
822                                content: format!("Tool execution error: {}", e),
823                                is_error: true,
824                                raw_content: None,
825                            },
826                        }
827                    };
828                    (id.as_str(), name.as_str(), input, tool_result)
829                }
830            })
831            .collect();
832
833        while let Some((tool_use_id, tool_name, actual_input, mut tool_result)) = futs.next().await {
834            // Sanitize tool result: strip blobs, enforce byte limit.
835            let max_result_bytes = options
836                .max_tool_result_bytes
837                .unwrap_or(sanitize::DEFAULT_MAX_TOOL_RESULT_BYTES);
838            tool_result.content =
839                sanitize::sanitize_tool_result(&tool_result.content, max_result_bytes);
840
841            // Run PostToolUse hooks
842            hook_registry.run_post_tool_use(
843                tool_name,
844                actual_input,
845                &serde_json::to_value(&tool_result.content).unwrap_or_default(),
846                tool_use_id,
847                &session_id,
848                &cwd,
849            ).await;
850
851            let result_content = tool_result
852                .raw_content
853                .unwrap_or_else(|| json!(tool_result.content));
854
855            let api_block = ApiContentBlock::ToolResult {
856                tool_use_id: tool_use_id.to_string(),
857                content: result_content,
858                is_error: if tool_result.is_error {
859                    Some(true)
860                } else {
861                    None
862                },
863                cache_control: None,
864                name: Some(tool_name.to_string()),
865            };
866
867            // Stream this individual result to the frontend immediately
868            let result_msg = Message::User(UserMessage {
869                uuid: Some(Uuid::new_v4()),
870                session_id: session_id.clone(),
871                content: vec![api_block_to_content_block(&api_block)],
872                parent_tool_use_id: None,
873                is_synthetic: true,
874                tool_use_result: None,
875            });
876            if options.persist_session {
877                let _ = session
878                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
879                    .await;
880            }
881            if tx.send(Ok(result_msg)).is_err() {
882                return Ok(());
883            }
884
885            tool_results.push(api_block);
886        }
887
888        // Add all tool results to conversation for the next API call
889        conversation.push(ApiMessage {
890            role: "user".to_string(),
891            content: tool_results,
892        });
893
894        // --- Lightweight pruning (between turns, before full compaction) ---
895        if let Some(context_budget) = options.context_budget {
896            let prune_pct = options
897                .prune_threshold_pct
898                .unwrap_or(compact::DEFAULT_PRUNE_THRESHOLD_PCT);
899            if compact::should_prune(response.usage.input_tokens, context_budget, prune_pct) {
900                let max_chars = options
901                    .prune_tool_result_max_chars
902                    .unwrap_or(compact::DEFAULT_PRUNE_TOOL_RESULT_MAX_CHARS);
903                let min_keep = options.min_keep_messages.unwrap_or(4);
904                let removed = compact::prune_tool_results(
905                    &mut conversation,
906                    max_chars,
907                    min_keep,
908                );
909                if removed > 0 {
910                    debug!(
911                        chars_removed = removed,
912                        input_tokens = response.usage.input_tokens,
913                        "Pruned oversized tool results to free context space"
914                    );
915                }
916            }
917        }
918
919        // --- Compaction check (between turns) ---
920        if let Some(context_budget) = options.context_budget {
921            if compact::should_compact(response.usage.input_tokens, context_budget) {
922                let min_keep = options.min_keep_messages.unwrap_or(4);
923                let split_point = compact::find_split_point(&conversation, min_keep);
924                if split_point > 0 {
925                    debug!(
926                        input_tokens = response.usage.input_tokens,
927                        context_budget,
928                        split_point,
929                        "Context budget exceeded, compacting conversation"
930                    );
931
932                    let compaction_model = options
933                        .compaction_model
934                        .as_deref()
935                        .unwrap_or(compact::DEFAULT_COMPACTION_MODEL);
936
937                    // Fire pre-compact handler so the host can persist key facts
938                    if let Some(ref handler) = options.pre_compact_handler {
939                        let msgs_to_compact = conversation[..split_point].to_vec();
940                        handler(msgs_to_compact).await;
941                    }
942
943                    let summary_prompt =
944                        compact::build_summary_prompt(&conversation[..split_point]);
945
946                    let summary_max_tokens = options.summary_max_tokens.unwrap_or(4096);
947                    let compact_provider: &dyn LlmProvider = match &options.compaction_provider {
948                        Some(cp) => cp.as_ref(),
949                        None => provider.as_ref(),
950                    };
951                    let fallback_provider: Option<&dyn LlmProvider> = if options.compaction_provider.is_some() {
952                        Some(provider.as_ref())
953                    } else {
954                        None
955                    };
956                    match compact::call_summarizer(
957                        compact_provider,
958                        &summary_prompt,
959                        compaction_model,
960                        fallback_provider,
961                        &model,
962                        summary_max_tokens,
963                    )
964                    .await
965                    {
966                        Ok(summary) => {
967                            let pre_tokens = response.usage.input_tokens;
968                            let messages_compacted = split_point;
969
970                            compact::splice_conversation(
971                                &mut conversation,
972                                split_point,
973                                &summary,
974                            );
975
976                            // Emit CompactBoundary system message
977                            let compact_msg = Message::System(SystemMessage {
978                                subtype: SystemSubtype::CompactBoundary,
979                                uuid: Uuid::new_v4(),
980                                session_id: session_id.clone(),
981                                agents: None,
982                                claude_code_version: None,
983                                cwd: None,
984                                tools: None,
985                                mcp_servers: None,
986                                model: None,
987                                permission_mode: None,
988                                compact_metadata: Some(CompactMetadata {
989                                    trigger: CompactTrigger::Auto,
990                                    pre_tokens,
991                                }),
992                            });
993
994                            if options.persist_session {
995                                let _ = session
996                                    .append_message(
997                                        &serde_json::to_value(&compact_msg)
998                                            .unwrap_or_default(),
999                                    )
1000                                    .await;
1001                            }
1002                            let _ = tx.send(Ok(compact_msg));
1003
1004                            debug!(
1005                                pre_tokens,
1006                                messages_compacted,
1007                                summary_len = summary.len(),
1008                                "Conversation compacted"
1009                            );
1010                        }
1011                        Err(e) => {
1012                            warn!(
1013                                "Compaction failed, continuing without compaction: {}",
1014                                e
1015                            );
1016                        }
1017                    }
1018                }
1019            }
1020        }
1021    }
1022}
1023
1024/// Consume a streaming response, emitting `Message::StreamEvent` for each text
1025/// delta, and accumulate the full `MessageResponse` for the agent loop.
1026async fn accumulate_stream(
1027    event_stream: &mut std::pin::Pin<Box<dyn futures::Stream<Item = Result<ClientStreamEvent>> + Send>>,
1028    tx: &mpsc::UnboundedSender<Result<Message>>,
1029    session_id: &str,
1030) -> Result<MessageResponse> {
1031    use crate::client::StreamEvent as SE;
1032
1033    // Accumulated state
1034    let mut message_id = String::new();
1035    let mut model = String::new();
1036    let mut role = String::from("assistant");
1037    let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
1038    let mut stop_reason: Option<String> = None;
1039    let mut usage = ApiUsage::default();
1040
1041    // Track in-progress content blocks by index
1042    // For text blocks: accumulate text. For tool_use: accumulate JSON string.
1043    let mut block_texts: Vec<String> = Vec::new();
1044    let mut block_types: Vec<String> = Vec::new(); // "text", "tool_use", "thinking"
1045    let mut block_tool_ids: Vec<String> = Vec::new();
1046    let mut block_tool_names: Vec<String> = Vec::new();
1047
1048    while let Some(event_result) = FuturesStreamExt::next(event_stream).await {
1049        let event = event_result?;
1050        match event {
1051            SE::MessageStart { message } => {
1052                message_id = message.id;
1053                model = message.model;
1054                role = message.role;
1055                usage = message.usage;
1056            }
1057            SE::ContentBlockStart { index, content_block } => {
1058                // Ensure vectors are large enough
1059                while block_texts.len() <= index {
1060                    block_texts.push(String::new());
1061                    block_types.push(String::new());
1062                    block_tool_ids.push(String::new());
1063                    block_tool_names.push(String::new());
1064                }
1065                match &content_block {
1066                    ApiContentBlock::Text { .. } => {
1067                        block_types[index] = "text".to_string();
1068                    }
1069                    ApiContentBlock::ToolUse { id, name, .. } => {
1070                        block_types[index] = "tool_use".to_string();
1071                        block_tool_ids[index] = id.clone();
1072                        block_tool_names[index] = name.clone();
1073                    }
1074                    ApiContentBlock::Thinking { .. } => {
1075                        block_types[index] = "thinking".to_string();
1076                    }
1077                    _ => {}
1078                }
1079            }
1080            SE::ContentBlockDelta { index, delta } => {
1081                while block_texts.len() <= index {
1082                    block_texts.push(String::new());
1083                    block_types.push(String::new());
1084                    block_tool_ids.push(String::new());
1085                    block_tool_names.push(String::new());
1086                }
1087                match &delta {
1088                    ContentDelta::TextDelta { text } => {
1089                        block_texts[index].push_str(text);
1090                        // Emit streaming event so downstream consumers get per-token updates
1091                        let stream_event = Message::StreamEvent(StreamEventMessage {
1092                            event: serde_json::json!({
1093                                "type": "content_block_delta",
1094                                "index": index,
1095                                "delta": { "type": "text_delta", "text": text }
1096                            }),
1097                            parent_tool_use_id: None,
1098                            uuid: Uuid::new_v4(),
1099                            session_id: session_id.to_string(),
1100                        });
1101                        if tx.send(Ok(stream_event)).is_err() {
1102                            return Err(AgentError::Cancelled);
1103                        }
1104                    }
1105                    ContentDelta::InputJsonDelta { partial_json } => {
1106                        block_texts[index].push_str(partial_json);
1107                    }
1108                    ContentDelta::ThinkingDelta { thinking } => {
1109                        block_texts[index].push_str(thinking);
1110                    }
1111                }
1112            }
1113            SE::ContentBlockStop { index } => {
1114                if index < block_types.len() {
1115                    let block = match block_types[index].as_str() {
1116                        "text" => ApiContentBlock::Text {
1117                            text: std::mem::take(&mut block_texts[index]),
1118                            cache_control: None,
1119                        },
1120                        "tool_use" => {
1121                            let input: serde_json::Value = serde_json::from_str(
1122                                &block_texts[index],
1123                            )
1124                            .unwrap_or(serde_json::Value::Object(Default::default()));
1125                            ApiContentBlock::ToolUse {
1126                                id: std::mem::take(&mut block_tool_ids[index]),
1127                                name: std::mem::take(&mut block_tool_names[index]),
1128                                input,
1129                            }
1130                        }
1131                        "thinking" => ApiContentBlock::Thinking {
1132                            thinking: std::mem::take(&mut block_texts[index]),
1133                        },
1134                        _ => continue,
1135                    };
1136                    // Place blocks at the correct index
1137                    while content_blocks.len() <= index {
1138                        content_blocks.push(ApiContentBlock::Text {
1139                            text: String::new(),
1140                            cache_control: None,
1141                        });
1142                    }
1143                    content_blocks[index] = block;
1144                }
1145            }
1146            SE::MessageDelta { delta, usage: delta_usage } => {
1147                stop_reason = delta.stop_reason;
1148                // MessageDelta carries output_tokens for the whole message
1149                usage.output_tokens = delta_usage.output_tokens;
1150            }
1151            SE::MessageStop => {
1152                break;
1153            }
1154            SE::Error { error } => {
1155                return Err(AgentError::Api(error.message));
1156            }
1157            SE::Ping => {}
1158        }
1159    }
1160
1161    Ok(MessageResponse {
1162        id: message_id,
1163        role,
1164        content: content_blocks,
1165        model,
1166        stop_reason,
1167        usage,
1168    })
1169}
1170
1171/// Apply a single cache breakpoint to the last content block of the last user
1172/// message in the conversation. Clears any previous breakpoints from messages
1173/// so we stay within the API limit of 4 cache_control blocks (system + tools +
1174/// this one = 3 total).
1175fn apply_cache_breakpoint(conversation: &mut [ApiMessage]) {
1176    // First, clear all existing cache_control from messages
1177    for msg in conversation.iter_mut() {
1178        for block in msg.content.iter_mut() {
1179            match block {
1180                ApiContentBlock::Text { cache_control, .. }
1181                | ApiContentBlock::ToolResult { cache_control, .. } => {
1182                    *cache_control = None;
1183                }
1184                ApiContentBlock::Image { .. }
1185                | ApiContentBlock::ToolUse { .. }
1186                | ApiContentBlock::Thinking { .. } => {}
1187            }
1188        }
1189    }
1190
1191    // Set cache_control on the last content block of the last user message
1192    if let Some(last_user) = conversation.iter_mut().rev().find(|m| m.role == "user") {
1193        if let Some(last_block) = last_user.content.last_mut() {
1194            match last_block {
1195                ApiContentBlock::Text { cache_control, .. }
1196                | ApiContentBlock::ToolResult { cache_control, .. } => {
1197                    *cache_control = Some(CacheControl::ephemeral());
1198                }
1199                ApiContentBlock::Image { .. }
1200                | ApiContentBlock::ToolUse { .. }
1201                | ApiContentBlock::Thinking { .. } => {}
1202            }
1203        }
1204    }
1205}
1206
1207/// Convert an API content block to our ContentBlock type.
1208fn api_block_to_content_block(block: &ApiContentBlock) -> ContentBlock {
1209    match block {
1210        ApiContentBlock::Text { text, .. } => ContentBlock::Text {
1211            text: text.clone(),
1212        },
1213        ApiContentBlock::Image { .. } => ContentBlock::Text {
1214            text: "[image]".to_string(),
1215        },
1216        ApiContentBlock::ToolUse { id, name, input } => ContentBlock::ToolUse {
1217            id: id.clone(),
1218            name: name.clone(),
1219            input: input.clone(),
1220        },
1221        ApiContentBlock::ToolResult {
1222            tool_use_id,
1223            content,
1224            is_error,
1225            ..
1226        } => ContentBlock::ToolResult {
1227            tool_use_id: tool_use_id.clone(),
1228            content: content.clone(),
1229            is_error: *is_error,
1230        },
1231        ApiContentBlock::Thinking { thinking } => ContentBlock::Thinking {
1232            thinking: thinking.clone(),
1233        },
1234    }
1235}
1236
1237/// Try to convert a stored JSON value to an API message.
1238fn value_to_api_message(value: &serde_json::Value) -> Option<ApiMessage> {
1239    let msg_type = value.get("type")?.as_str()?;
1240
1241    match msg_type {
1242        "assistant" => {
1243            let content = value.get("content")?;
1244            let blocks = parse_content_blocks(content)?;
1245            Some(ApiMessage {
1246                role: "assistant".to_string(),
1247                content: blocks,
1248            })
1249        }
1250        "user" => {
1251            let content = value.get("content")?;
1252            let blocks = parse_content_blocks(content)?;
1253            Some(ApiMessage {
1254                role: "user".to_string(),
1255                content: blocks,
1256            })
1257        }
1258        _ => None,
1259    }
1260}
1261
1262/// Parse content blocks from a JSON value.
1263fn parse_content_blocks(content: &serde_json::Value) -> Option<Vec<ApiContentBlock>> {
1264    if let Some(text) = content.as_str() {
1265        return Some(vec![ApiContentBlock::Text {
1266            text: text.to_string(),
1267            cache_control: None,
1268        }]);
1269    }
1270
1271    if let Some(blocks) = content.as_array() {
1272        let parsed: Vec<ApiContentBlock> = blocks
1273            .iter()
1274            .filter_map(|b| serde_json::from_value(b.clone()).ok())
1275            .collect();
1276        if !parsed.is_empty() {
1277            return Some(parsed);
1278        }
1279    }
1280
1281    None
1282}
1283
1284/// Build a ResultMessage.
1285fn build_result_message(
1286    subtype: ResultSubtype,
1287    session_id: &str,
1288    result_text: Option<String>,
1289    start_time: Instant,
1290    api_time_ms: u64,
1291    num_turns: u32,
1292    total_cost: f64,
1293    usage: &Usage,
1294    model_usage: &HashMap<String, ModelUsage>,
1295    permission_denials: &[PermissionDenial],
1296) -> Message {
1297    Message::Result(ResultMessage {
1298        subtype,
1299        uuid: Uuid::new_v4(),
1300        session_id: session_id.to_string(),
1301        duration_ms: start_time.elapsed().as_millis() as u64,
1302        duration_api_ms: api_time_ms,
1303        is_error: result_text.is_none(),
1304        num_turns,
1305        result: result_text,
1306        stop_reason: Some("end_turn".to_string()),
1307        total_cost_usd: total_cost,
1308        usage: Some(usage.clone()),
1309        model_usage: model_usage.clone(),
1310        permission_denials: permission_denials.to_vec(),
1311        structured_output: None,
1312        errors: Vec::new(),
1313    })
1314}
1315
1316/// Build an error ResultMessage.
1317fn build_error_result_message(
1318    session_id: &str,
1319    error_msg: &str,
1320    start_time: Instant,
1321    api_time_ms: u64,
1322    num_turns: u32,
1323    total_cost: f64,
1324    usage: &Usage,
1325    model_usage: &HashMap<String, ModelUsage>,
1326    permission_denials: &[PermissionDenial],
1327) -> Message {
1328    Message::Result(ResultMessage {
1329        subtype: ResultSubtype::ErrorDuringExecution,
1330        uuid: Uuid::new_v4(),
1331        session_id: session_id.to_string(),
1332        duration_ms: start_time.elapsed().as_millis() as u64,
1333        duration_api_ms: api_time_ms,
1334        is_error: true,
1335        num_turns,
1336        result: None,
1337        stop_reason: None,
1338        total_cost_usd: total_cost,
1339        usage: Some(usage.clone()),
1340        model_usage: model_usage.clone(),
1341        permission_denials: permission_denials.to_vec(),
1342        structured_output: None,
1343        errors: vec![error_msg.to_string()],
1344    })
1345}
1346
1347#[cfg(test)]
1348mod tests {
1349    use super::*;
1350    use std::sync::atomic::{AtomicUsize, Ordering};
1351    use std::sync::Arc;
1352    use std::time::Duration;
1353
1354    /// Helper: execute tools concurrently using the same FuturesUnordered pattern
1355    /// as the production code, collecting (tool_use_id, content, completion_order).
1356    async fn run_concurrent_tools(
1357        tools: Vec<(String, String, serde_json::Value)>,
1358        handler: impl Fn(String, serde_json::Value) -> Pin<Box<dyn futures::Future<Output = Option<ToolResult>> + Send>>,
1359    ) -> Vec<(String, String, usize)> {
1360        let order = Arc::new(AtomicUsize::new(0));
1361        let handler = Arc::new(handler);
1362
1363        struct PermittedTool {
1364            tool_use_id: String,
1365            tool_name: String,
1366            actual_input: serde_json::Value,
1367        }
1368
1369        let permitted: Vec<PermittedTool> = tools
1370            .into_iter()
1371            .map(|(id, name, input)| PermittedTool {
1372                tool_use_id: id,
1373                tool_name: name,
1374                actual_input: input,
1375            })
1376            .collect();
1377
1378        let mut futs: FuturesUnordered<_> = permitted
1379            .iter()
1380            .map(|pt| {
1381                let handler = handler.clone();
1382                let order = order.clone();
1383                let name = pt.tool_name.clone();
1384                let input = pt.actual_input.clone();
1385                let id = pt.tool_use_id.clone();
1386                async move {
1387                    let result = handler(name, input).await;
1388                    let seq = order.fetch_add(1, Ordering::SeqCst);
1389                    (id, result, seq)
1390                }
1391            })
1392            .collect();
1393
1394        let mut results = Vec::new();
1395        while let Some((id, result, seq)) = futs.next().await {
1396            let content = result
1397                .map(|r| r.content)
1398                .unwrap_or_else(|| "no handler".into());
1399            results.push((id, content, seq));
1400        }
1401        results
1402    }
1403
1404    #[tokio::test]
1405    async fn concurrent_tools_all_complete() {
1406        let results = run_concurrent_tools(
1407            vec![
1408                ("t1".into(), "Read".into(), json!({"path": "a.txt"})),
1409                ("t2".into(), "Read".into(), json!({"path": "b.txt"})),
1410                ("t3".into(), "Read".into(), json!({"path": "c.txt"})),
1411            ],
1412            |name, input| {
1413                Box::pin(async move {
1414                    let path = input["path"].as_str().unwrap_or("?");
1415                    Some(ToolResult {
1416                        content: format!("{}: {}", name, path),
1417                        is_error: false,
1418                        raw_content: None,
1419                    })
1420                })
1421            },
1422        )
1423        .await;
1424
1425        assert_eq!(results.len(), 3);
1426        let ids: Vec<&str> = results.iter().map(|(id, _, _)| id.as_str()).collect();
1427        assert!(ids.contains(&"t1"));
1428        assert!(ids.contains(&"t2"));
1429        assert!(ids.contains(&"t3"));
1430    }
1431
1432    #[tokio::test]
1433    async fn slow_tool_does_not_block_fast_tools() {
1434        let start = Instant::now();
1435
1436        let results = run_concurrent_tools(
1437            vec![
1438                ("slow".into(), "Bash".into(), json!({})),
1439                ("fast1".into(), "Read".into(), json!({})),
1440                ("fast2".into(), "Read".into(), json!({})),
1441            ],
1442            |name, _input| {
1443                Box::pin(async move {
1444                    if name == "Bash" {
1445                        tokio::time::sleep(Duration::from_millis(200)).await;
1446                        Some(ToolResult {
1447                            content: "slow done".into(),
1448                            is_error: false,
1449                            raw_content: None,
1450                        })
1451                    } else {
1452                        // Fast tools complete immediately
1453                        Some(ToolResult {
1454                            content: "fast done".into(),
1455                            is_error: false,
1456                            raw_content: None,
1457                        })
1458                    }
1459                })
1460            },
1461        )
1462        .await;
1463
1464        let elapsed = start.elapsed();
1465
1466        // All three should complete
1467        assert_eq!(results.len(), 3);
1468
1469        // Fast tools should complete before the slow tool (lower order index)
1470        let slow = results.iter().find(|(id, _, _)| id == "slow").unwrap();
1471        let fast1 = results.iter().find(|(id, _, _)| id == "fast1").unwrap();
1472        let fast2 = results.iter().find(|(id, _, _)| id == "fast2").unwrap();
1473
1474        assert!(fast1.2 < slow.2, "fast1 should complete before slow");
1475        assert!(fast2.2 < slow.2, "fast2 should complete before slow");
1476
1477        // Total time should be ~200ms (concurrent), not ~400ms+ (sequential)
1478        assert!(
1479            elapsed < Duration::from_millis(400),
1480            "elapsed {:?} should be under 400ms (concurrent execution)",
1481            elapsed
1482        );
1483    }
1484
1485    #[tokio::test]
1486    async fn results_streamed_individually_as_they_complete() {
1487        // Simulate the streaming pattern from the production code:
1488        // each tool result is sent to the channel as it completes.
1489        let (tx, mut rx) = mpsc::unbounded_channel::<(String, String)>();
1490
1491        let tools = vec![
1492            ("t_slow".into(), "Slow".into(), json!({})),
1493            ("t_fast".into(), "Fast".into(), json!({})),
1494        ];
1495
1496        struct PT {
1497            tool_use_id: String,
1498            tool_name: String,
1499        }
1500
1501        let permitted: Vec<PT> = tools
1502            .into_iter()
1503            .map(|(id, name, _)| PT {
1504                tool_use_id: id,
1505                tool_name: name,
1506            })
1507            .collect();
1508
1509        let mut futs: FuturesUnordered<_> = permitted
1510            .iter()
1511            .map(|pt| {
1512                let name = pt.tool_name.clone();
1513                let id = pt.tool_use_id.clone();
1514                async move {
1515                    if name == "Slow" {
1516                        tokio::time::sleep(Duration::from_millis(100)).await;
1517                    }
1518                    let result = ToolResult {
1519                        content: format!("{} result", name),
1520                        is_error: false,
1521                        raw_content: None,
1522                    };
1523                    (id, result)
1524                }
1525            })
1526            .collect();
1527
1528        // Process results as they complete (like production code)
1529        while let Some((id, result)) = futs.next().await {
1530            tx.send((id, result.content)).unwrap();
1531        }
1532        drop(tx);
1533
1534        // Collect what was streamed
1535        let mut streamed = Vec::new();
1536        while let Some(item) = rx.recv().await {
1537            streamed.push(item);
1538        }
1539
1540        assert_eq!(streamed.len(), 2);
1541        // Fast should arrive first
1542        assert_eq!(streamed[0].0, "t_fast");
1543        assert_eq!(streamed[0].1, "Fast result");
1544        assert_eq!(streamed[1].0, "t_slow");
1545        assert_eq!(streamed[1].1, "Slow result");
1546    }
1547
1548    #[tokio::test]
1549    async fn error_tool_does_not_prevent_other_tools() {
1550        let results = run_concurrent_tools(
1551            vec![
1552                ("t_ok".into(), "Read".into(), json!({})),
1553                ("t_err".into(), "Fail".into(), json!({})),
1554            ],
1555            |name, _input| {
1556                Box::pin(async move {
1557                    if name == "Fail" {
1558                        Some(ToolResult {
1559                            content: "something went wrong".into(),
1560                            is_error: true,
1561                            raw_content: None,
1562                        })
1563                    } else {
1564                        Some(ToolResult {
1565                            content: "ok".into(),
1566                            is_error: false,
1567                            raw_content: None,
1568                        })
1569                    }
1570                })
1571            },
1572        )
1573        .await;
1574
1575        assert_eq!(results.len(), 2);
1576        let ok = results.iter().find(|(id, _, _)| id == "t_ok").unwrap();
1577        let err = results.iter().find(|(id, _, _)| id == "t_err").unwrap();
1578        assert_eq!(ok.1, "ok");
1579        assert_eq!(err.1, "something went wrong");
1580    }
1581
1582    #[tokio::test]
1583    async fn external_handler_none_falls_through_correctly() {
1584        // When handler returns None for a tool, the production code falls through
1585        // to the built-in executor. Test that the pattern works.
1586        let results = run_concurrent_tools(
1587            vec![
1588                ("t_custom".into(), "MyTool".into(), json!({"x": 1})),
1589                ("t_builtin".into(), "Read".into(), json!({"path": "/tmp"})),
1590            ],
1591            |name, _input| {
1592                Box::pin(async move {
1593                    if name == "MyTool" {
1594                        Some(ToolResult {
1595                            content: "custom handled".into(),
1596                            is_error: false,
1597                            raw_content: None,
1598                        })
1599                    } else {
1600                        // Returns None => would fall through to built-in executor
1601                        None
1602                    }
1603                })
1604            },
1605        )
1606        .await;
1607
1608        assert_eq!(results.len(), 2);
1609        let custom = results.iter().find(|(id, _, _)| id == "t_custom").unwrap();
1610        let builtin = results.iter().find(|(id, _, _)| id == "t_builtin").unwrap();
1611        assert_eq!(custom.1, "custom handled");
1612        assert_eq!(builtin.1, "no handler"); // our test helper treats None as "no handler"
1613    }
1614
1615    #[tokio::test]
1616    async fn single_tool_works_same_as_before() {
1617        let results = run_concurrent_tools(
1618            vec![("t1".into(), "Read".into(), json!({"path": "file.txt"}))],
1619            |_name, _input| {
1620                Box::pin(async move {
1621                    Some(ToolResult {
1622                        content: "file contents".into(),
1623                        is_error: false,
1624                        raw_content: None,
1625                    })
1626                })
1627            },
1628        )
1629        .await;
1630
1631        assert_eq!(results.len(), 1);
1632        assert_eq!(results[0].0, "t1");
1633        assert_eq!(results[0].1, "file contents");
1634        assert_eq!(results[0].2, 0); // first (and only) completion
1635    }
1636
1637    #[tokio::test]
1638    async fn empty_tool_list_produces_no_results() {
1639        let results = run_concurrent_tools(vec![], |_name, _input| {
1640            Box::pin(async move { None })
1641        })
1642        .await;
1643
1644        assert_eq!(results.len(), 0);
1645    }
1646
1647    #[tokio::test]
1648    async fn tool_use_ids_preserved_through_concurrent_execution() {
1649        let results = run_concurrent_tools(
1650            vec![
1651                ("toolu_abc123".into(), "Read".into(), json!({})),
1652                ("toolu_def456".into(), "Write".into(), json!({})),
1653                ("toolu_ghi789".into(), "Bash".into(), json!({})),
1654            ],
1655            |name, _input| {
1656                Box::pin(async move {
1657                    // Add varying delays to shuffle completion order
1658                    match name.as_str() {
1659                        "Read" => tokio::time::sleep(Duration::from_millis(30)).await,
1660                        "Write" => tokio::time::sleep(Duration::from_millis(10)).await,
1661                        _ => tokio::time::sleep(Duration::from_millis(50)).await,
1662                    }
1663                    Some(ToolResult {
1664                        content: format!("{} result", name),
1665                        is_error: false,
1666                        raw_content: None,
1667                    })
1668                })
1669            },
1670        )
1671        .await;
1672
1673        assert_eq!(results.len(), 3);
1674
1675        // Regardless of completion order, IDs must match their tools
1676        for (id, content, _) in &results {
1677            match id.as_str() {
1678                "toolu_abc123" => assert_eq!(content, "Read result"),
1679                "toolu_def456" => assert_eq!(content, "Write result"),
1680                "toolu_ghi789" => assert_eq!(content, "Bash result"),
1681                other => panic!("unexpected tool_use_id: {}", other),
1682            }
1683        }
1684    }
1685
1686    #[tokio::test]
1687    async fn concurrent_execution_timing_is_parallel() {
1688        // 5 tools each taking 50ms should complete in ~50ms total, not 250ms
1689        let tools: Vec<_> = (0..5)
1690            .map(|i| (format!("t{}", i), "Tool".into(), json!({})))
1691            .collect();
1692
1693        let start = Instant::now();
1694
1695        let results = run_concurrent_tools(tools, |_name, _input| {
1696            Box::pin(async move {
1697                tokio::time::sleep(Duration::from_millis(50)).await;
1698                Some(ToolResult {
1699                    content: "done".into(),
1700                    is_error: false,
1701                    raw_content: None,
1702                })
1703            })
1704        })
1705        .await;
1706
1707        let elapsed = start.elapsed();
1708
1709        assert_eq!(results.len(), 5);
1710        // Should complete in roughly 50ms, definitely under 200ms
1711        assert!(
1712            elapsed < Duration::from_millis(200),
1713            "5 x 50ms tools took {:?} — should be ~50ms if concurrent",
1714            elapsed
1715        );
1716    }
1717
1718    #[tokio::test]
1719    async fn api_block_to_content_block_preserves_tool_result_fields() {
1720        let block = ApiContentBlock::ToolResult {
1721            tool_use_id: "toolu_abc".into(),
1722            content: json!("result text"),
1723            is_error: Some(true),
1724            cache_control: None,
1725            name: None,
1726        };
1727
1728        let content = api_block_to_content_block(&block);
1729        match content {
1730            ContentBlock::ToolResult {
1731                tool_use_id,
1732                content,
1733                is_error,
1734            } => {
1735                assert_eq!(tool_use_id, "toolu_abc");
1736                assert_eq!(content, json!("result text"));
1737                assert_eq!(is_error, Some(true));
1738            }
1739            _ => panic!("expected ToolResult content block"),
1740        }
1741    }
1742
1743    #[tokio::test]
1744    async fn streamed_messages_each_contain_single_tool_result() {
1745        // Verify that the streaming pattern produces one User message per tool result
1746        let (tx, mut rx) = mpsc::unbounded_channel::<Result<Message>>();
1747        let session_id = "test-session".to_string();
1748
1749        // Simulate what the production code does
1750        let tool_ids = vec!["t1", "t2", "t3"];
1751        for id in &tool_ids {
1752            let api_block = ApiContentBlock::ToolResult {
1753                tool_use_id: id.to_string(),
1754                content: json!(format!("result for {}", id)),
1755                is_error: None,
1756                cache_control: None,
1757                name: None,
1758            };
1759
1760            let result_msg = Message::User(UserMessage {
1761                uuid: Some(Uuid::new_v4()),
1762                session_id: session_id.clone(),
1763                content: vec![api_block_to_content_block(&api_block)],
1764                parent_tool_use_id: None,
1765                is_synthetic: true,
1766                tool_use_result: None,
1767            });
1768            tx.send(Ok(result_msg)).unwrap();
1769        }
1770        drop(tx);
1771
1772        let mut messages = Vec::new();
1773        while let Some(Ok(msg)) = rx.recv().await {
1774            messages.push(msg);
1775        }
1776
1777        assert_eq!(messages.len(), 3, "should have 3 individual messages");
1778
1779        for (i, msg) in messages.iter().enumerate() {
1780            if let Message::User(user) = msg {
1781                assert_eq!(user.content.len(), 1, "each message should have exactly 1 content block");
1782                assert!(user.is_synthetic);
1783                if let ContentBlock::ToolResult { tool_use_id, .. } = &user.content[0] {
1784                    assert_eq!(tool_use_id, tool_ids[i]);
1785                } else {
1786                    panic!("expected ToolResult block");
1787                }
1788            } else {
1789                panic!("expected User message");
1790            }
1791        }
1792    }
1793
1794    #[tokio::test]
1795    async fn accumulate_stream_emits_text_deltas_and_builds_response() {
1796        use crate::client::{ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE};
1797
1798        // Build a fake stream of SSE events
1799        let events: Vec<Result<SE>> = vec![
1800            Ok(SE::MessageStart {
1801                message: MessageResponse {
1802                    id: "msg_123".into(),
1803                    role: "assistant".into(),
1804                    content: vec![],
1805                    model: "claude-test".into(),
1806                    stop_reason: None,
1807                    usage: ApiUsage {
1808                        input_tokens: 100,
1809                        output_tokens: 0,
1810                        cache_creation_input_tokens: None,
1811                        cache_read_input_tokens: None,
1812                    },
1813                },
1814            }),
1815            Ok(SE::ContentBlockStart {
1816                index: 0,
1817                content_block: ApiContentBlock::Text {
1818                    text: String::new(),
1819                    cache_control: None,
1820                },
1821            }),
1822            Ok(SE::ContentBlockDelta {
1823                index: 0,
1824                delta: ContentDelta::TextDelta { text: "Hello".into() },
1825            }),
1826            Ok(SE::ContentBlockDelta {
1827                index: 0,
1828                delta: ContentDelta::TextDelta { text: " world".into() },
1829            }),
1830            Ok(SE::ContentBlockDelta {
1831                index: 0,
1832                delta: ContentDelta::TextDelta { text: "!".into() },
1833            }),
1834            Ok(SE::ContentBlockStop { index: 0 }),
1835            Ok(SE::MessageDelta {
1836                delta: crate::client::MessageDelta {
1837                    stop_reason: Some("end_turn".into()),
1838                },
1839                usage: ApiUsage {
1840                    input_tokens: 0,
1841                    output_tokens: 15,
1842                    cache_creation_input_tokens: None,
1843                    cache_read_input_tokens: None,
1844                },
1845            }),
1846            Ok(SE::MessageStop),
1847        ];
1848
1849        let stream = futures::stream::iter(events);
1850        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
1851            Box::pin(stream);
1852
1853        let (tx, mut rx) = mpsc::unbounded_channel();
1854
1855        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
1856            .await
1857            .expect("accumulate_stream should succeed");
1858
1859        // Verify accumulated response
1860        assert_eq!(response.id, "msg_123");
1861        assert_eq!(response.model, "claude-test");
1862        assert_eq!(response.stop_reason, Some("end_turn".into()));
1863        assert_eq!(response.usage.output_tokens, 15);
1864        assert_eq!(response.content.len(), 1);
1865        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
1866            assert_eq!(text, "Hello world!");
1867        } else {
1868            panic!("expected Text content block");
1869        }
1870
1871        // Verify 3 StreamEvent messages were emitted (one per text delta)
1872        let mut stream_events = Vec::new();
1873        while let Ok(msg) = rx.try_recv() {
1874            stream_events.push(msg.unwrap());
1875        }
1876        assert_eq!(stream_events.len(), 3);
1877
1878        // Verify each is a StreamEvent with the correct text
1879        let expected_texts = ["Hello", " world", "!"];
1880        for (i, msg) in stream_events.iter().enumerate() {
1881            if let Message::StreamEvent(se) = msg {
1882                let delta = se.event.get("delta").unwrap();
1883                let text = delta.get("text").unwrap().as_str().unwrap();
1884                assert_eq!(text, expected_texts[i]);
1885                assert_eq!(se.session_id, "test-session");
1886            } else {
1887                panic!("expected StreamEvent message at index {}", i);
1888            }
1889        }
1890    }
1891
1892    #[tokio::test]
1893    async fn accumulate_stream_handles_tool_use() {
1894        use crate::client::{ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE};
1895
1896        let events: Vec<Result<SE>> = vec![
1897            Ok(SE::MessageStart {
1898                message: MessageResponse {
1899                    id: "msg_456".into(),
1900                    role: "assistant".into(),
1901                    content: vec![],
1902                    model: "claude-test".into(),
1903                    stop_reason: None,
1904                    usage: ApiUsage::default(),
1905                },
1906            }),
1907            // Text block
1908            Ok(SE::ContentBlockStart {
1909                index: 0,
1910                content_block: ApiContentBlock::Text {
1911                    text: String::new(),
1912                    cache_control: None,
1913                },
1914            }),
1915            Ok(SE::ContentBlockDelta {
1916                index: 0,
1917                delta: ContentDelta::TextDelta { text: "Let me check.".into() },
1918            }),
1919            Ok(SE::ContentBlockStop { index: 0 }),
1920            // Tool use block
1921            Ok(SE::ContentBlockStart {
1922                index: 1,
1923                content_block: ApiContentBlock::ToolUse {
1924                    id: "toolu_abc".into(),
1925                    name: "Read".into(),
1926                    input: serde_json::json!({}),
1927                },
1928            }),
1929            Ok(SE::ContentBlockDelta {
1930                index: 1,
1931                delta: ContentDelta::InputJsonDelta {
1932                    partial_json: r#"{"path":"/tmp/f.txt"}"#.into(),
1933                },
1934            }),
1935            Ok(SE::ContentBlockStop { index: 1 }),
1936            Ok(SE::MessageDelta {
1937                delta: crate::client::MessageDelta {
1938                    stop_reason: Some("tool_use".into()),
1939                },
1940                usage: ApiUsage { input_tokens: 0, output_tokens: 20, ..Default::default() },
1941            }),
1942            Ok(SE::MessageStop),
1943        ];
1944
1945        let stream = futures::stream::iter(events);
1946        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
1947            Box::pin(stream);
1948
1949        let (tx, _rx) = mpsc::unbounded_channel();
1950        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
1951            .await
1952            .expect("should succeed");
1953
1954        assert_eq!(response.content.len(), 2);
1955        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
1956            assert_eq!(text, "Let me check.");
1957        } else {
1958            panic!("expected Text block at index 0");
1959        }
1960        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[1] {
1961            assert_eq!(id, "toolu_abc");
1962            assert_eq!(name, "Read");
1963            assert_eq!(input["path"], "/tmp/f.txt");
1964        } else {
1965            panic!("expected ToolUse block at index 1");
1966        }
1967        assert_eq!(response.stop_reason, Some("tool_use".into()));
1968    }
1969}