1use std::collections::HashMap;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13use futures::stream::FuturesUnordered;
14use futures::{Stream, StreamExt as FuturesStreamExt};
15use serde_json::json;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18use tracing::{debug, error, warn};
19use uuid::Uuid;
20
21use crate::client::{
22 ApiContentBlock, ApiMessage, ApiUsage, CacheControl, ContentDelta, CreateMessageRequest,
23 ImageSource, MessageResponse, StreamEvent as ClientStreamEvent, SystemBlock, ThinkingParam,
24 ToolDefinition,
25};
26use crate::compact;
27use crate::error::{AgentError, Result};
28use crate::hooks::HookRegistry;
29use crate::options::{Options, PermissionMode, ThinkingConfig};
30use crate::permissions::{PermissionEvaluator, PermissionVerdict};
31use crate::provider::LlmProvider;
32use crate::providers::AnthropicProvider;
33use crate::sanitize;
34use crate::session::Session;
35use crate::tools::definitions::get_tool_definitions;
36use crate::tools::executor::{ToolExecutor, ToolResult};
37use crate::types::messages::*;
38
/// Model used when `Options::model` is not set.
const DEFAULT_MODEL: &str = "claude-haiku-4-5";
/// Response token limit used when `Options::max_tokens` is not set.
const DEFAULT_MAX_TOKENS: u32 = 16_384;
43
44pub struct Query {
48 receiver: UnboundedReceiverStream<Result<Message>>,
49 session_id: Option<String>,
50 cancel_token: tokio_util::sync::CancellationToken,
51}
52
53impl Query {
54 pub async fn interrupt(&self) -> Result<()> {
56 self.cancel_token.cancel();
57 Ok(())
58 }
59
60 pub fn session_id(&self) -> Option<&str> {
62 self.session_id.as_deref()
63 }
64
65 pub async fn set_permission_mode(&self, _mode: PermissionMode) -> Result<()> {
67 Ok(())
69 }
70
71 pub async fn set_model(&self, _model: &str) -> Result<()> {
73 Ok(())
75 }
76
77 pub fn close(&self) {
79 self.cancel_token.cancel();
80 }
81}
82
83impl Stream for Query {
84 type Item = Result<Message>;
85
86 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87 Pin::new(&mut self.receiver).poll_next(cx)
88 }
89}
90
91pub fn query(prompt: &str, options: Options) -> Query {
126 let (tx, rx) = mpsc::unbounded_channel();
127 let cancel_token = tokio_util::sync::CancellationToken::new();
128 let cancel = cancel_token.clone();
129
130 let prompt = prompt.to_string();
131
132 tokio::spawn(async move {
133 let result = run_agent_loop(prompt, options, tx.clone(), cancel).await;
134 if let Err(e) = result {
135 let _ = tx.send(Err(e));
136 }
137 });
138
139 Query {
140 receiver: UnboundedReceiverStream::new(rx),
141 session_id: None,
142 cancel_token,
143 }
144}
145
146async fn run_agent_loop(
156 prompt: String,
157 mut options: Options,
158 tx: mpsc::UnboundedSender<Result<Message>>,
159 cancel: tokio_util::sync::CancellationToken,
160) -> Result<()> {
161 let start_time = Instant::now();
162 let mut api_time_ms: u64 = 0;
163
164 let cwd = options.cwd.clone().unwrap_or_else(|| {
166 std::env::current_dir()
167 .unwrap_or_else(|_| PathBuf::from("."))
168 .to_string_lossy()
169 .to_string()
170 });
171
172 let session = if let Some(ref resume_id) = options.resume {
174 Session::with_id(resume_id, &cwd)
175 } else if options.continue_session {
176 match crate::session::find_most_recent_session(Some(&cwd)).await? {
178 Some(info) => Session::with_id(&info.session_id, &cwd),
179 None => Session::new(&cwd),
180 }
181 } else {
182 match &options.session_id {
183 Some(id) => Session::with_id(id, &cwd),
184 None => Session::new(&cwd),
185 }
186 };
187
188 let session_id = session.id.clone();
189 let model = options
190 .model
191 .clone()
192 .unwrap_or_else(|| DEFAULT_MODEL.to_string());
193
194 let tool_names: Vec<String> = if options.output_format.is_some() {
197 Vec::new()
198 } else if options.allowed_tools.is_empty() {
199 vec![
201 "Read".into(),
202 "Write".into(),
203 "Edit".into(),
204 "Bash".into(),
205 "Glob".into(),
206 "Grep".into(),
207 ]
208 } else {
209 options.allowed_tools.clone()
210 };
211
212 let raw_defs: Vec<_> = get_tool_definitions(&tool_names);
213
214 let mut all_defs: Vec<ToolDefinition> = raw_defs
216 .into_iter()
217 .map(|td| ToolDefinition {
218 name: td.name.to_string(),
219 description: td.description.to_string(),
220 input_schema: td.input_schema,
221 cache_control: None,
222 })
223 .collect();
224
225 for ctd in &options.custom_tool_definitions {
227 all_defs.push(ToolDefinition {
228 name: ctd.name.clone(),
229 description: ctd.description.clone(),
230 input_schema: ctd.input_schema.clone(),
231 cache_control: None,
232 });
233 }
234
235 if let Some(last) = all_defs.last_mut() {
237 last.cache_control = Some(CacheControl::ephemeral());
238 }
239
240 let tool_defs = all_defs;
241
242 let init_msg = Message::System(SystemMessage {
244 subtype: SystemSubtype::Init,
245 uuid: Uuid::new_v4(),
246 session_id: session_id.clone(),
247 agents: if options.agents.is_empty() {
248 None
249 } else {
250 Some(options.agents.keys().cloned().collect())
251 },
252 claude_code_version: Some(env!("CARGO_PKG_VERSION").to_string()),
253 cwd: Some(cwd.clone()),
254 tools: Some(tool_names.clone()),
255 mcp_servers: if options.mcp_servers.is_empty() {
256 None
257 } else {
258 Some(
259 options
260 .mcp_servers
261 .keys()
262 .map(|name| McpServerStatus {
263 name: name.clone(),
264 status: "connected".to_string(),
265 })
266 .collect(),
267 )
268 },
269 model: Some(model.clone()),
270 permission_mode: Some(options.permission_mode.to_string()),
271 compact_metadata: None,
272 });
273
274 if options.persist_session {
276 let _ = session
277 .append_message(&serde_json::to_value(&init_msg).unwrap_or_default())
278 .await;
279 }
280 if tx.send(Ok(init_msg)).is_err() {
281 return Ok(());
282 }
283
284 let provider: Box<dyn LlmProvider> = match options.provider.take() {
286 Some(p) => p,
287 None => Box::new(AnthropicProvider::from_env()?),
288 };
289
290 let additional_dirs: Vec<PathBuf> = options
292 .additional_directories
293 .iter()
294 .map(PathBuf::from)
295 .collect();
296 let env_blocklist = std::mem::take(&mut options.env_blocklist);
297 let tool_executor = if additional_dirs.is_empty() {
298 ToolExecutor::new(PathBuf::from(&cwd))
299 } else {
300 ToolExecutor::with_allowed_dirs(PathBuf::from(&cwd), additional_dirs)
301 }
302 .with_env_blocklist(env_blocklist);
303
304 let mut hook_registry = HookRegistry::from_map(std::mem::take(&mut options.hooks));
306 if !options.hook_dirs.is_empty() {
307 let dirs: Vec<&std::path::Path> = options.hook_dirs.iter().map(|p| p.as_path()).collect();
308 match crate::hooks::HookDiscovery::discover(&dirs) {
309 Ok(discovered) => hook_registry.merge(discovered),
310 Err(e) => tracing::warn!("Failed to discover hooks from dirs: {}", e),
311 }
312 }
313
314 let mut followup_rx = options.followup_rx.take();
316
317 let permission_eval = PermissionEvaluator::new(&options);
319
320 let system_prompt: Option<Vec<SystemBlock>> = {
322 let text = match &options.system_prompt {
323 Some(crate::options::SystemPrompt::Custom(s)) => s.clone(),
324 Some(crate::options::SystemPrompt::Preset { append, .. }) => {
325 let base = "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.";
326 match append {
327 Some(extra) => format!("{}\n\n{}", base, extra),
328 None => base.to_string(),
329 }
330 }
331 None => "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.".to_string(),
332 };
333 Some(vec![SystemBlock {
334 kind: "text".to_string(),
335 text,
336 cache_control: Some(CacheControl::ephemeral()),
337 }])
338 };
339
340 let mut conversation: Vec<ApiMessage> = Vec::new();
342
343 if options.resume.is_some() || options.continue_session {
345 let prev_messages = session.load_messages().await?;
346 for msg_value in prev_messages {
347 if let Some(api_msg) = value_to_api_message(&msg_value) {
348 conversation.push(api_msg);
349 }
350 }
351 }
352
353 {
355 let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
356
357 for att in &options.attachments {
359 let is_image = matches!(
360 att.mime_type.as_str(),
361 "image/png" | "image/jpeg" | "image/gif" | "image/webp"
362 );
363 if is_image {
364 content_blocks.push(ApiContentBlock::Image {
365 source: ImageSource {
366 kind: "base64".to_string(),
367 media_type: att.mime_type.clone(),
368 data: att.base64_data.clone(),
369 },
370 });
371 }
372 }
373
374 content_blocks.push(ApiContentBlock::Text {
376 text: prompt.clone(),
377 cache_control: None,
378 });
379
380 conversation.push(ApiMessage {
381 role: "user".to_string(),
382 content: content_blocks,
383 });
384 }
385
386 if options.persist_session {
388 let user_msg = json!({
389 "type": "user",
390 "uuid": Uuid::new_v4().to_string(),
391 "session_id": &session_id,
392 "content": [{"type": "text", "text": &prompt}]
393 });
394 let _ = session.append_message(&user_msg).await;
395 }
396
397 let mut num_turns: u32 = 0;
399 let mut total_usage = Usage::default();
400 let mut total_cost: f64 = 0.0;
401 let mut model_usage: HashMap<String, ModelUsage> = HashMap::new();
402 let mut permission_denials: Vec<PermissionDenial> = Vec::new();
403
404 loop {
405 if cancel.is_cancelled() {
407 return Err(AgentError::Cancelled);
408 }
409
410 if let Some(max_turns) = options.max_turns {
412 if num_turns >= max_turns {
413 let result_msg = build_result_message(
414 ResultSubtype::ErrorMaxTurns,
415 &session_id,
416 None,
417 start_time,
418 api_time_ms,
419 num_turns,
420 total_cost,
421 &total_usage,
422 &model_usage,
423 &permission_denials,
424 );
425 let _ = tx.send(Ok(result_msg));
426 return Ok(());
427 }
428 }
429
430 if let Some(max_budget) = options.max_budget_usd {
432 if total_cost >= max_budget {
433 let result_msg = build_result_message(
434 ResultSubtype::ErrorMaxBudgetUsd,
435 &session_id,
436 None,
437 start_time,
438 api_time_ms,
439 num_turns,
440 total_cost,
441 &total_usage,
442 &model_usage,
443 &permission_denials,
444 );
445 let _ = tx.send(Ok(result_msg));
446 return Ok(());
447 }
448 }
449
450 if let Some(ref mut followup_rx) = followup_rx {
454 let mut followups: Vec<String> = Vec::new();
455 while let Ok(msg) = followup_rx.try_recv() {
456 followups.push(msg);
457 }
458 if !followups.is_empty() {
459 let combined = followups.join("\n\n");
460 debug!(
461 count = followups.len(),
462 "Injecting followup messages into agent loop"
463 );
464
465 conversation.push(ApiMessage {
466 role: "user".to_string(),
467 content: vec![ApiContentBlock::Text {
468 text: combined.clone(),
469 cache_control: None,
470 }],
471 });
472
473 let followup_msg = Message::User(UserMessage {
475 uuid: Some(Uuid::new_v4()),
476 session_id: session_id.clone(),
477 content: vec![ContentBlock::Text { text: combined }],
478 parent_tool_use_id: None,
479 is_synthetic: false,
480 tool_use_result: None,
481 });
482
483 if options.persist_session {
484 let _ = session
485 .append_message(&serde_json::to_value(&followup_msg).unwrap_or_default())
486 .await;
487 }
488 if tx.send(Ok(followup_msg)).is_err() {
489 return Ok(());
490 }
491 }
492 }
493
494 apply_cache_breakpoint(&mut conversation);
498
499 let thinking_param = options.thinking.as_ref().map(|tc| match tc {
501 ThinkingConfig::Adaptive => ThinkingParam {
502 kind: "enabled".into(),
503 budget_tokens: Some(10240),
504 },
505 ThinkingConfig::Disabled => ThinkingParam {
506 kind: "disabled".into(),
507 budget_tokens: None,
508 },
509 ThinkingConfig::Enabled { budget_tokens } => ThinkingParam {
510 kind: "enabled".into(),
511 budget_tokens: Some(*budget_tokens),
512 },
513 });
514
515 let base_max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
517 let max_tokens = if let Some(ref tp) = thinking_param {
518 if let Some(budget) = tp.budget_tokens {
519 base_max_tokens.max(budget as u32 + 8192)
520 } else {
521 base_max_tokens
522 }
523 } else {
524 base_max_tokens
525 };
526
527 let use_streaming = options.include_partial_messages;
529 let request = CreateMessageRequest {
530 model: model.clone(),
531 max_tokens,
532 messages: conversation.clone(),
533 system: system_prompt.clone(),
534 tools: if tool_defs.is_empty() {
535 None
536 } else {
537 Some(tool_defs.clone())
538 },
539 stream: use_streaming,
540 metadata: None,
541 thinking: thinking_param,
542 };
543
544 let api_start = Instant::now();
546 let response = if use_streaming {
547 match provider.create_message_stream(&request).await {
549 Ok(mut event_stream) => {
550 match accumulate_stream(&mut event_stream, &tx, &session_id).await {
551 Ok(resp) => resp,
552 Err(e) => {
553 error!("Stream accumulation failed: {}", e);
554 let result_msg = build_error_result_message(
555 &session_id,
556 &format!("Stream error: {}", e),
557 start_time,
558 api_time_ms,
559 num_turns,
560 total_cost,
561 &total_usage,
562 &model_usage,
563 &permission_denials,
564 );
565 let _ = tx.send(Ok(result_msg));
566 return Ok(());
567 }
568 }
569 }
570 Err(e) => {
571 error!("API stream call failed: {}", e);
572 let result_msg = build_error_result_message(
573 &session_id,
574 &format!("API error: {}", e),
575 start_time,
576 api_time_ms,
577 num_turns,
578 total_cost,
579 &total_usage,
580 &model_usage,
581 &permission_denials,
582 );
583 let _ = tx.send(Ok(result_msg));
584 return Ok(());
585 }
586 }
587 } else {
588 match provider.create_message(&request).await {
590 Ok(resp) => resp,
591 Err(e) => {
592 error!("API call failed: {}", e);
593 let result_msg = build_error_result_message(
594 &session_id,
595 &format!("API error: {}", e),
596 start_time,
597 api_time_ms,
598 num_turns,
599 total_cost,
600 &total_usage,
601 &model_usage,
602 &permission_denials,
603 );
604 let _ = tx.send(Ok(result_msg));
605 return Ok(());
606 }
607 }
608 };
609 api_time_ms += api_start.elapsed().as_millis() as u64;
610
611 total_usage.input_tokens += response.usage.input_tokens;
613 total_usage.output_tokens += response.usage.output_tokens;
614 total_usage.cache_creation_input_tokens +=
615 response.usage.cache_creation_input_tokens.unwrap_or(0);
616 total_usage.cache_read_input_tokens += response.usage.cache_read_input_tokens.unwrap_or(0);
617
618 let rates = provider.cost_rates(&model);
620 let turn_cost = rates.compute_with_cache(
621 response.usage.input_tokens,
622 response.usage.output_tokens,
623 response.usage.cache_read_input_tokens.unwrap_or(0),
624 response.usage.cache_creation_input_tokens.unwrap_or(0),
625 );
626 total_cost += turn_cost;
627
628 let model_entry = model_usage.entry(model.clone()).or_default();
630 model_entry.input_tokens += response.usage.input_tokens;
631 model_entry.output_tokens += response.usage.output_tokens;
632 model_entry.cost_usd += turn_cost;
633
634 let content_blocks: Vec<ContentBlock> = response
636 .content
637 .iter()
638 .map(api_block_to_content_block)
639 .collect();
640
641 let assistant_msg = Message::Assistant(AssistantMessage {
643 uuid: Uuid::new_v4(),
644 session_id: session_id.clone(),
645 content: content_blocks.clone(),
646 model: response.model.clone(),
647 stop_reason: response.stop_reason.clone(),
648 parent_tool_use_id: None,
649 usage: Some(Usage {
650 input_tokens: response.usage.input_tokens,
651 output_tokens: response.usage.output_tokens,
652 cache_creation_input_tokens: response
653 .usage
654 .cache_creation_input_tokens
655 .unwrap_or(0),
656 cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
657 }),
658 error: None,
659 });
660
661 if options.persist_session {
662 let _ = session
663 .append_message(&serde_json::to_value(&assistant_msg).unwrap_or_default())
664 .await;
665 }
666 if tx.send(Ok(assistant_msg)).is_err() {
667 return Ok(());
668 }
669
670 conversation.push(ApiMessage {
672 role: "assistant".to_string(),
673 content: response.content.clone(),
674 });
675
676 let tool_uses: Vec<_> = response
678 .content
679 .iter()
680 .filter_map(|block| match block {
681 ApiContentBlock::ToolUse { id, name, input } => {
682 Some((id.clone(), name.clone(), input.clone()))
683 }
684 _ => None,
685 })
686 .collect();
687
688 if tool_uses.is_empty() {
690 let final_text = response
692 .content
693 .iter()
694 .filter_map(|block| match block {
695 ApiContentBlock::Text { text, .. } => Some(text.as_str()),
696 _ => None,
697 })
698 .collect::<Vec<_>>()
699 .join("");
700
701 let result_msg = build_result_message(
702 ResultSubtype::Success,
703 &session_id,
704 Some(final_text),
705 start_time,
706 api_time_ms,
707 num_turns,
708 total_cost,
709 &total_usage,
710 &model_usage,
711 &permission_denials,
712 );
713
714 if options.persist_session {
715 let _ = session
716 .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
717 .await;
718 }
719 let _ = tx.send(Ok(result_msg));
720 return Ok(());
721 }
722
723 num_turns += 1;
725 let mut tool_results: Vec<ApiContentBlock> = Vec::new();
726
727 let known_tool_names: std::collections::HashSet<&str> =
730 tool_defs.iter().map(|td| td.name.as_str()).collect();
731
732 let mut valid_tool_uses: Vec<&(String, String, serde_json::Value)> = Vec::new();
733 for tu in &tool_uses {
734 let (tool_use_id, tool_name, _tool_input) = tu;
735 if known_tool_names.contains(tool_name.as_str()) {
736 valid_tool_uses.push(tu);
737 } else {
738 warn!(tool = %tool_name, "model invoked unknown tool, returning error");
739 let available: Vec<&str> = tool_defs.iter().map(|td| td.name.as_str()).collect();
740 let error_msg = format!(
741 "Error: '{}' is not a valid tool. You MUST use one of the following tools: {}",
742 tool_name,
743 available.join(", ")
744 );
745 let api_block = ApiContentBlock::ToolResult {
746 tool_use_id: tool_use_id.clone(),
747 content: json!(error_msg),
748 is_error: Some(true),
749 cache_control: None,
750 name: Some(tool_name.clone()),
751 };
752
753 let result_msg = Message::User(UserMessage {
755 uuid: Some(Uuid::new_v4()),
756 session_id: session_id.clone(),
757 content: vec![api_block_to_content_block(&api_block)],
758 parent_tool_use_id: None,
759 is_synthetic: true,
760 tool_use_result: None,
761 });
762 if options.persist_session {
763 let _ = session
764 .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
765 .await;
766 }
767 if tx.send(Ok(result_msg)).is_err() {
768 return Ok(());
769 }
770
771 tool_results.push(api_block);
772 }
773 }
774
775 struct PermittedTool {
777 tool_use_id: String,
778 tool_name: String,
779 actual_input: serde_json::Value,
780 }
781 let mut permitted_tools: Vec<PermittedTool> = Vec::new();
782
783 for (tool_use_id, tool_name, tool_input) in valid_tool_uses.iter().map(|t| &**t) {
784 let verdict = permission_eval
785 .evaluate(tool_name, tool_input, tool_use_id, &session_id, &cwd)
786 .await?;
787
788 let actual_input = match &verdict {
789 PermissionVerdict::AllowWithUpdatedInput(new_input) => new_input.clone(),
790 _ => tool_input.clone(),
791 };
792
793 match verdict {
794 PermissionVerdict::Allow | PermissionVerdict::AllowWithUpdatedInput(_) => {
795 permitted_tools.push(PermittedTool {
796 tool_use_id: tool_use_id.clone(),
797 tool_name: tool_name.clone(),
798 actual_input,
799 });
800 }
801 PermissionVerdict::Deny { reason } => {
802 debug!(tool = %tool_name, reason = %reason, "Tool denied");
803 permission_denials.push(PermissionDenial {
804 tool_name: tool_name.clone(),
805 tool_use_id: tool_use_id.clone(),
806 tool_input: tool_input.clone(),
807 });
808
809 let api_block = ApiContentBlock::ToolResult {
810 tool_use_id: tool_use_id.clone(),
811 content: json!(format!("Permission denied: {}", reason)),
812 is_error: Some(true),
813 cache_control: None,
814 name: Some(tool_name.clone()),
815 };
816
817 let denial_msg = Message::User(UserMessage {
819 uuid: Some(Uuid::new_v4()),
820 session_id: session_id.clone(),
821 content: vec![api_block_to_content_block(&api_block)],
822 parent_tool_use_id: None,
823 is_synthetic: true,
824 tool_use_result: None,
825 });
826 if options.persist_session {
827 let _ = session
828 .append_message(&serde_json::to_value(&denial_msg).unwrap_or_default())
829 .await;
830 }
831 if tx.send(Ok(denial_msg)).is_err() {
832 return Ok(());
833 }
834
835 tool_results.push(api_block);
836 }
837 }
838 }
839
840 let mut futs: FuturesUnordered<_> = permitted_tools
842 .iter()
843 .map(|pt| {
844 let handler = &options.external_tool_handler;
845 let executor = &tool_executor;
846 let name = &pt.tool_name;
847 let input = &pt.actual_input;
848 let id = &pt.tool_use_id;
849 async move {
850 debug!(tool = %name, "Executing tool");
851
852 let tool_result = if let Some(ref handler) = handler {
853 let ext_result = handler(name.clone(), input.clone()).await;
854 if let Some(tr) = ext_result {
855 tr
856 } else {
857 match executor.execute(name, input.clone()).await {
858 Ok(tr) => tr,
859 Err(e) => ToolResult {
860 content: format!("{}", e),
861 is_error: true,
862 raw_content: None,
863 },
864 }
865 }
866 } else {
867 match executor.execute(name, input.clone()).await {
868 Ok(tr) => tr,
869 Err(e) => ToolResult {
870 content: format!("{}", e),
871 is_error: true,
872 raw_content: None,
873 },
874 }
875 };
876 (id.as_str(), name.as_str(), input, tool_result)
877 }
878 })
879 .collect();
880
881 while let Some((tool_use_id, tool_name, actual_input, mut tool_result)) = futs.next().await
882 {
883 let max_result_bytes = options
885 .max_tool_result_bytes
886 .unwrap_or(sanitize::DEFAULT_MAX_TOOL_RESULT_BYTES);
887 tool_result.content =
888 sanitize::sanitize_tool_result(&tool_result.content, max_result_bytes);
889
890 hook_registry
892 .run_post_tool_use(
893 tool_name,
894 actual_input,
895 &serde_json::to_value(&tool_result.content).unwrap_or_default(),
896 tool_use_id,
897 &session_id,
898 &cwd,
899 )
900 .await;
901
902 let result_content = tool_result
903 .raw_content
904 .unwrap_or_else(|| json!(tool_result.content));
905
906 let api_block = ApiContentBlock::ToolResult {
907 tool_use_id: tool_use_id.to_string(),
908 content: result_content,
909 is_error: if tool_result.is_error {
910 Some(true)
911 } else {
912 None
913 },
914 cache_control: None,
915 name: Some(tool_name.to_string()),
916 };
917
918 let result_msg = Message::User(UserMessage {
920 uuid: Some(Uuid::new_v4()),
921 session_id: session_id.clone(),
922 content: vec![api_block_to_content_block(&api_block)],
923 parent_tool_use_id: None,
924 is_synthetic: true,
925 tool_use_result: None,
926 });
927 if options.persist_session {
928 let _ = session
929 .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
930 .await;
931 }
932 if tx.send(Ok(result_msg)).is_err() {
933 return Ok(());
934 }
935
936 tool_results.push(api_block);
937 }
938
939 conversation.push(ApiMessage {
941 role: "user".to_string(),
942 content: tool_results,
943 });
944
945 if let Some(context_budget) = options.context_budget {
947 let prune_pct = options
948 .prune_threshold_pct
949 .unwrap_or(compact::DEFAULT_PRUNE_THRESHOLD_PCT);
950 if compact::should_prune(response.usage.input_tokens, context_budget, prune_pct) {
951 let max_chars = options
952 .prune_tool_result_max_chars
953 .unwrap_or(compact::DEFAULT_PRUNE_TOOL_RESULT_MAX_CHARS);
954 let min_keep = options.min_keep_messages.unwrap_or(4);
955 let removed = compact::prune_tool_results(&mut conversation, max_chars, min_keep);
956 if removed > 0 {
957 debug!(
958 chars_removed = removed,
959 input_tokens = response.usage.input_tokens,
960 "Pruned oversized tool results to free context space"
961 );
962 }
963 }
964 }
965
966 if let Some(context_budget) = options.context_budget {
968 if compact::should_compact(response.usage.input_tokens, context_budget) {
969 let min_keep = options.min_keep_messages.unwrap_or(4);
970 let split_point = compact::find_split_point(&conversation, min_keep);
971 if split_point > 0 {
972 debug!(
973 input_tokens = response.usage.input_tokens,
974 context_budget,
975 split_point,
976 "Context budget exceeded, compacting conversation"
977 );
978
979 let compaction_model = options
980 .compaction_model
981 .as_deref()
982 .unwrap_or(compact::DEFAULT_COMPACTION_MODEL);
983
984 if let Some(ref handler) = options.pre_compact_handler {
986 let msgs_to_compact = conversation[..split_point].to_vec();
987 handler(msgs_to_compact).await;
988 }
989
990 let summary_prompt =
991 compact::build_summary_prompt(&conversation[..split_point]);
992
993 let summary_max_tokens = options.summary_max_tokens.unwrap_or(4096);
994 let compact_provider: &dyn LlmProvider = match &options.compaction_provider {
995 Some(cp) => cp.as_ref(),
996 None => provider.as_ref(),
997 };
998 let fallback_provider: Option<&dyn LlmProvider> =
999 if options.compaction_provider.is_some() {
1000 Some(provider.as_ref())
1001 } else {
1002 None
1003 };
1004 match compact::call_summarizer(
1005 compact_provider,
1006 &summary_prompt,
1007 compaction_model,
1008 fallback_provider,
1009 &model,
1010 summary_max_tokens,
1011 )
1012 .await
1013 {
1014 Ok(summary) => {
1015 let pre_tokens = response.usage.input_tokens;
1016 let messages_compacted = split_point;
1017
1018 compact::splice_conversation(&mut conversation, split_point, &summary);
1019
1020 let compact_msg = Message::System(SystemMessage {
1022 subtype: SystemSubtype::CompactBoundary,
1023 uuid: Uuid::new_v4(),
1024 session_id: session_id.clone(),
1025 agents: None,
1026 claude_code_version: None,
1027 cwd: None,
1028 tools: None,
1029 mcp_servers: None,
1030 model: None,
1031 permission_mode: None,
1032 compact_metadata: Some(CompactMetadata {
1033 trigger: CompactTrigger::Auto,
1034 pre_tokens,
1035 }),
1036 });
1037
1038 if options.persist_session {
1039 let _ = session
1040 .append_message(
1041 &serde_json::to_value(&compact_msg).unwrap_or_default(),
1042 )
1043 .await;
1044 }
1045 let _ = tx.send(Ok(compact_msg));
1046
1047 debug!(
1048 pre_tokens,
1049 messages_compacted,
1050 summary_len = summary.len(),
1051 "Conversation compacted"
1052 );
1053 }
1054 Err(e) => {
1055 warn!("Compaction failed, continuing without compaction: {}", e);
1056 }
1057 }
1058 }
1059 }
1060 }
1061 }
1062}
1063
1064async fn accumulate_stream(
1067 event_stream: &mut std::pin::Pin<
1068 Box<dyn futures::Stream<Item = Result<ClientStreamEvent>> + Send>,
1069 >,
1070 tx: &mpsc::UnboundedSender<Result<Message>>,
1071 session_id: &str,
1072) -> Result<MessageResponse> {
1073 use crate::client::StreamEvent as SE;
1074
1075 let mut message_id = String::new();
1077 let mut model = String::new();
1078 let mut role = String::from("assistant");
1079 let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
1080 let mut stop_reason: Option<String> = None;
1081 let mut usage = ApiUsage::default();
1082
1083 let mut block_texts: Vec<String> = Vec::new();
1086 let mut block_types: Vec<String> = Vec::new(); let mut block_tool_ids: Vec<String> = Vec::new();
1088 let mut block_tool_names: Vec<String> = Vec::new();
1089
1090 while let Some(event_result) = FuturesStreamExt::next(event_stream).await {
1091 let event = event_result?;
1092 match event {
1093 SE::MessageStart { message } => {
1094 message_id = message.id;
1095 model = message.model;
1096 role = message.role;
1097 usage = message.usage;
1098 }
1099 SE::ContentBlockStart {
1100 index,
1101 content_block,
1102 } => {
1103 while block_texts.len() <= index {
1105 block_texts.push(String::new());
1106 block_types.push(String::new());
1107 block_tool_ids.push(String::new());
1108 block_tool_names.push(String::new());
1109 }
1110 match &content_block {
1111 ApiContentBlock::Text { .. } => {
1112 block_types[index] = "text".to_string();
1113 }
1114 ApiContentBlock::ToolUse { id, name, input } => {
1115 block_types[index] = "tool_use".to_string();
1116 block_tool_ids[index] = id.clone();
1117 block_tool_names[index] = name.clone();
1118 let input_str = input.to_string();
1122 if input_str != "{}" {
1123 block_texts[index] = input_str;
1124 }
1125 }
1126 ApiContentBlock::Thinking { .. } => {
1127 block_types[index] = "thinking".to_string();
1128 }
1129 _ => {}
1130 }
1131 }
1132 SE::ContentBlockDelta { index, delta } => {
1133 while block_texts.len() <= index {
1134 block_texts.push(String::new());
1135 block_types.push(String::new());
1136 block_tool_ids.push(String::new());
1137 block_tool_names.push(String::new());
1138 }
1139 match &delta {
1140 ContentDelta::TextDelta { text } => {
1141 block_texts[index].push_str(text);
1142 let stream_event = Message::StreamEvent(StreamEventMessage {
1144 event: serde_json::json!({
1145 "type": "content_block_delta",
1146 "index": index,
1147 "delta": { "type": "text_delta", "text": text }
1148 }),
1149 parent_tool_use_id: None,
1150 uuid: Uuid::new_v4(),
1151 session_id: session_id.to_string(),
1152 });
1153 if tx.send(Ok(stream_event)).is_err() {
1154 return Err(AgentError::Cancelled);
1155 }
1156 }
1157 ContentDelta::InputJsonDelta { partial_json } => {
1158 block_texts[index].push_str(partial_json);
1159 }
1160 ContentDelta::ThinkingDelta { thinking } => {
1161 block_texts[index].push_str(thinking);
1162 }
1163 }
1164 }
1165 SE::ContentBlockStop { index } => {
1166 if index < block_types.len() {
1167 let block = match block_types[index].as_str() {
1168 "text" => ApiContentBlock::Text {
1169 text: std::mem::take(&mut block_texts[index]),
1170 cache_control: None,
1171 },
1172 "tool_use" => {
1173 let input: serde_json::Value =
1174 serde_json::from_str(&block_texts[index])
1175 .unwrap_or(serde_json::Value::Object(Default::default()));
1176 ApiContentBlock::ToolUse {
1177 id: std::mem::take(&mut block_tool_ids[index]),
1178 name: std::mem::take(&mut block_tool_names[index]),
1179 input,
1180 }
1181 }
1182 "thinking" => ApiContentBlock::Thinking {
1183 thinking: std::mem::take(&mut block_texts[index]),
1184 },
1185 _ => continue,
1186 };
1187 while content_blocks.len() <= index {
1189 content_blocks.push(ApiContentBlock::Text {
1190 text: String::new(),
1191 cache_control: None,
1192 });
1193 }
1194 content_blocks[index] = block;
1195 }
1196 }
1197 SE::MessageDelta {
1198 delta,
1199 usage: delta_usage,
1200 } => {
1201 stop_reason = delta.stop_reason;
1202 usage.output_tokens = delta_usage.output_tokens;
1204 }
1205 SE::MessageStop => {
1206 break;
1207 }
1208 SE::Error { error } => {
1209 return Err(AgentError::Api(error.message));
1210 }
1211 SE::Ping => {}
1212 }
1213 }
1214
1215 Ok(MessageResponse {
1216 id: message_id,
1217 role,
1218 content: content_blocks,
1219 model,
1220 stop_reason,
1221 usage,
1222 })
1223}
1224
1225fn apply_cache_breakpoint(conversation: &mut [ApiMessage]) {
1230 for msg in conversation.iter_mut() {
1232 for block in msg.content.iter_mut() {
1233 match block {
1234 ApiContentBlock::Text { cache_control, .. }
1235 | ApiContentBlock::ToolResult { cache_control, .. } => {
1236 *cache_control = None;
1237 }
1238 ApiContentBlock::Image { .. }
1239 | ApiContentBlock::ToolUse { .. }
1240 | ApiContentBlock::Thinking { .. } => {}
1241 }
1242 }
1243 }
1244
1245 if let Some(last_user) = conversation.iter_mut().rev().find(|m| m.role == "user") {
1247 if let Some(last_block) = last_user.content.last_mut() {
1248 match last_block {
1249 ApiContentBlock::Text { cache_control, .. }
1250 | ApiContentBlock::ToolResult { cache_control, .. } => {
1251 *cache_control = Some(CacheControl::ephemeral());
1252 }
1253 ApiContentBlock::Image { .. }
1254 | ApiContentBlock::ToolUse { .. }
1255 | ApiContentBlock::Thinking { .. } => {}
1256 }
1257 }
1258 }
1259}
1260
1261fn api_block_to_content_block(block: &ApiContentBlock) -> ContentBlock {
1263 match block {
1264 ApiContentBlock::Text { text, .. } => ContentBlock::Text { text: text.clone() },
1265 ApiContentBlock::Image { .. } => ContentBlock::Text {
1266 text: "[image]".to_string(),
1267 },
1268 ApiContentBlock::ToolUse { id, name, input } => ContentBlock::ToolUse {
1269 id: id.clone(),
1270 name: name.clone(),
1271 input: input.clone(),
1272 },
1273 ApiContentBlock::ToolResult {
1274 tool_use_id,
1275 content,
1276 is_error,
1277 ..
1278 } => ContentBlock::ToolResult {
1279 tool_use_id: tool_use_id.clone(),
1280 content: content.clone(),
1281 is_error: *is_error,
1282 },
1283 ApiContentBlock::Thinking { thinking } => ContentBlock::Thinking {
1284 thinking: thinking.clone(),
1285 },
1286 }
1287}
1288
1289fn value_to_api_message(value: &serde_json::Value) -> Option<ApiMessage> {
1291 let msg_type = value.get("type")?.as_str()?;
1292
1293 match msg_type {
1294 "assistant" => {
1295 let content = value.get("content")?;
1296 let blocks = parse_content_blocks(content)?;
1297 Some(ApiMessage {
1298 role: "assistant".to_string(),
1299 content: blocks,
1300 })
1301 }
1302 "user" => {
1303 let content = value.get("content")?;
1304 let blocks = parse_content_blocks(content)?;
1305 Some(ApiMessage {
1306 role: "user".to_string(),
1307 content: blocks,
1308 })
1309 }
1310 _ => None,
1311 }
1312}
1313
1314fn parse_content_blocks(content: &serde_json::Value) -> Option<Vec<ApiContentBlock>> {
1316 if let Some(text) = content.as_str() {
1317 return Some(vec![ApiContentBlock::Text {
1318 text: text.to_string(),
1319 cache_control: None,
1320 }]);
1321 }
1322
1323 if let Some(blocks) = content.as_array() {
1324 let parsed: Vec<ApiContentBlock> = blocks
1325 .iter()
1326 .filter_map(|b| serde_json::from_value(b.clone()).ok())
1327 .collect();
1328 if !parsed.is_empty() {
1329 return Some(parsed);
1330 }
1331 }
1332
1333 None
1334}
1335
1336#[allow(clippy::too_many_arguments)]
1338fn build_result_message(
1339 subtype: ResultSubtype,
1340 session_id: &str,
1341 result_text: Option<String>,
1342 start_time: Instant,
1343 api_time_ms: u64,
1344 num_turns: u32,
1345 total_cost: f64,
1346 usage: &Usage,
1347 model_usage: &HashMap<String, ModelUsage>,
1348 permission_denials: &[PermissionDenial],
1349) -> Message {
1350 Message::Result(ResultMessage {
1351 subtype,
1352 uuid: Uuid::new_v4(),
1353 session_id: session_id.to_string(),
1354 duration_ms: start_time.elapsed().as_millis() as u64,
1355 duration_api_ms: api_time_ms,
1356 is_error: result_text.is_none(),
1357 num_turns,
1358 result: result_text,
1359 stop_reason: Some("end_turn".to_string()),
1360 total_cost_usd: total_cost,
1361 usage: Some(usage.clone()),
1362 model_usage: model_usage.clone(),
1363 permission_denials: permission_denials.to_vec(),
1364 structured_output: None,
1365 errors: Vec::new(),
1366 })
1367}
1368
1369#[allow(clippy::too_many_arguments)]
1371fn build_error_result_message(
1372 session_id: &str,
1373 error_msg: &str,
1374 start_time: Instant,
1375 api_time_ms: u64,
1376 num_turns: u32,
1377 total_cost: f64,
1378 usage: &Usage,
1379 model_usage: &HashMap<String, ModelUsage>,
1380 permission_denials: &[PermissionDenial],
1381) -> Message {
1382 Message::Result(ResultMessage {
1383 subtype: ResultSubtype::ErrorDuringExecution,
1384 uuid: Uuid::new_v4(),
1385 session_id: session_id.to_string(),
1386 duration_ms: start_time.elapsed().as_millis() as u64,
1387 duration_api_ms: api_time_ms,
1388 is_error: true,
1389 num_turns,
1390 result: None,
1391 stop_reason: None,
1392 total_cost_usd: total_cost,
1393 usage: Some(usage.clone()),
1394 model_usage: model_usage.clone(),
1395 permission_denials: permission_denials.to_vec(),
1396 structured_output: None,
1397 errors: vec![error_msg.to_string()],
1398 })
1399}
1400
//! Unit tests for concurrent tool execution, tool-result conversion and
//! streaming accumulation in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Runs every `(tool_use_id, tool_name, input)` in `tools` concurrently
    /// through `handler` and returns `(tool_use_id, content, completion_seq)`
    /// per tool.
    ///
    /// - Results come back in the order the futures completed, because they are
    ///   drained from a `FuturesUnordered`.
    /// - `completion_seq` comes from a shared counter incremented after each
    ///   handler future resolves, so a smaller value means the tool finished
    ///   earlier.
    /// - A handler that returns `None` is reported with the content
    ///   `"no handler"`.
    ///
    /// NOTE(review): this helper drives `handler` directly and never calls a
    /// non-test function from the parent module. The tests built on it exercise
    /// the concurrency pattern, not the production query loop.
    async fn run_concurrent_tools(
        tools: Vec<(String, String, serde_json::Value)>,
        handler: impl Fn(
            String,
            serde_json::Value,
        ) -> Pin<Box<dyn futures::Future<Output = Option<ToolResult>> + Send>>,
    ) -> Vec<(String, String, usize)> {
        // Shared completion counter: each finished future claims the next value.
        let order = Arc::new(AtomicUsize::new(0));
        // Wrapped in Arc so each per-tool future can own a handle to the handler.
        let handler = Arc::new(handler);

        struct PermittedTool {
            tool_use_id: String,
            tool_name: String,
            actual_input: serde_json::Value,
        }

        let permitted: Vec<PermittedTool> = tools
            .into_iter()
            .map(|(id, name, input)| PermittedTool {
                tool_use_id: id,
                tool_name: name,
                actual_input: input,
            })
            .collect();

        // One future per tool. All of them are polled concurrently by FuturesUnordered.
        let mut futs: FuturesUnordered<_> = permitted
            .iter()
            .map(|pt| {
                let handler = handler.clone();
                let order = order.clone();
                let name = pt.tool_name.clone();
                let input = pt.actual_input.clone();
                let id = pt.tool_use_id.clone();
                async move {
                    let result = handler(name, input).await;
                    // Taken after the await, so the value reflects completion order.
                    let seq = order.fetch_add(1, Ordering::SeqCst);
                    (id, result, seq)
                }
            })
            .collect();

        // Drain in completion order. A `None` from the handler becomes "no handler".
        let mut results = Vec::new();
        while let Some((id, result, seq)) = futs.next().await {
            let content = result
                .map(|r| r.content)
                .unwrap_or_else(|| "no handler".into());
            results.push((id, content, seq));
        }
        results
    }

    /// Every submitted tool yields exactly one result, identified by its id.
    #[tokio::test]
    async fn concurrent_tools_all_complete() {
        let results = run_concurrent_tools(
            vec![
                ("t1".into(), "Read".into(), json!({"path": "a.txt"})),
                ("t2".into(), "Read".into(), json!({"path": "b.txt"})),
                ("t3".into(), "Read".into(), json!({"path": "c.txt"})),
            ],
            |name, input| {
                Box::pin(async move {
                    let path = input["path"].as_str().unwrap_or("?");
                    Some(ToolResult {
                        content: format!("{}: {}", name, path),
                        is_error: false,
                        raw_content: None,
                    })
                })
            },
        )
        .await;

        assert_eq!(results.len(), 3);
        // Completion order is unspecified, so only membership is checked.
        let ids: Vec<&str> = results.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(ids.contains(&"t1"));
        assert!(ids.contains(&"t2"));
        assert!(ids.contains(&"t3"));
    }

    /// A 200 ms tool submitted first must not delay tools that finish immediately.
    #[tokio::test]
    async fn slow_tool_does_not_block_fast_tools() {
        let start = Instant::now();

        let results = run_concurrent_tools(
            vec![
                ("slow".into(), "Bash".into(), json!({})),
                ("fast1".into(), "Read".into(), json!({})),
                ("fast2".into(), "Read".into(), json!({})),
            ],
            |name, _input| {
                Box::pin(async move {
                    if name == "Bash" {
                        tokio::time::sleep(Duration::from_millis(200)).await;
                        Some(ToolResult {
                            content: "slow done".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    } else {
                        // Non-Bash tools resolve without awaiting anything.
                        Some(ToolResult {
                            content: "fast done".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    }
                })
            },
        )
        .await;

        let elapsed = start.elapsed();

        assert_eq!(results.len(), 3);

        // Field `.2` is the completion sequence number recorded by the harness.
        let slow = results.iter().find(|(id, _, _)| id == "slow").unwrap();
        let fast1 = results.iter().find(|(id, _, _)| id == "fast1").unwrap();
        let fast2 = results.iter().find(|(id, _, _)| id == "fast2").unwrap();

        assert!(fast1.2 < slow.2, "fast1 should complete before slow");
        assert!(fast2.2 < slow.2, "fast2 should complete before slow");

        // NOTE(review): only one tool sleeps, so sequential execution would also
        // finish in about 200 ms. The completion-order checks above are what show
        // the slow tool does not block the others. This bound only catches stalls.
        assert!(
            elapsed < Duration::from_millis(400),
            "elapsed {:?} should be under 400ms (concurrent execution)",
            elapsed
        );
    }

    /// Results pushed to a channel as each future completes arrive in completion
    /// order, not submission order.
    #[tokio::test]
    async fn results_streamed_individually_as_they_complete() {
        // Stands in for the outgoing message channel: one send per finished tool.
        let (tx, mut rx) = mpsc::unbounded_channel::<(String, String)>();

        let tools = vec![
            ("t_slow".into(), "Slow".into(), json!({})),
            ("t_fast".into(), "Fast".into(), json!({})),
        ];

        struct PT {
            tool_use_id: String,
            tool_name: String,
        }

        let permitted: Vec<PT> = tools
            .into_iter()
            .map(|(id, name, _)| PT {
                tool_use_id: id,
                tool_name: name,
            })
            .collect();

        let mut futs: FuturesUnordered<_> = permitted
            .iter()
            .map(|pt| {
                let name = pt.tool_name.clone();
                let id = pt.tool_use_id.clone();
                async move {
                    if name == "Slow" {
                        tokio::time::sleep(Duration::from_millis(100)).await;
                    }
                    let result = ToolResult {
                        content: format!("{} result", name),
                        is_error: false,
                        raw_content: None,
                    };
                    (id, result)
                }
            })
            .collect();

        while let Some((id, result)) = futs.next().await {
            // Forward each result as soon as it completes instead of batching.
            tx.send((id, result.content)).unwrap();
        }
        drop(tx);

        // Dropping the sender above lets this drain loop terminate.
        let mut streamed = Vec::new();
        while let Some(item) = rx.recv().await {
            streamed.push(item);
        }

        assert_eq!(streamed.len(), 2);
        // "Fast" was submitted second but never sleeps, so it must arrive first.
        assert_eq!(streamed[0].0, "t_fast");
        assert_eq!(streamed[0].1, "Fast result");
        assert_eq!(streamed[1].0, "t_slow");
        assert_eq!(streamed[1].1, "Slow result");
    }

    /// A tool that reports `is_error: true` still yields a result, and its
    /// sibling still completes.
    #[tokio::test]
    async fn error_tool_does_not_prevent_other_tools() {
        let results = run_concurrent_tools(
            vec![
                ("t_ok".into(), "Read".into(), json!({})),
                ("t_err".into(), "Fail".into(), json!({})),
            ],
            |name, _input| {
                Box::pin(async move {
                    if name == "Fail" {
                        Some(ToolResult {
                            content: "something went wrong".into(),
                            is_error: true,
                            raw_content: None,
                        })
                    } else {
                        Some(ToolResult {
                            content: "ok".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    }
                })
            },
        )
        .await;

        assert_eq!(results.len(), 2);
        let ok = results.iter().find(|(id, _, _)| id == "t_ok").unwrap();
        let err = results.iter().find(|(id, _, _)| id == "t_err").unwrap();
        assert_eq!(ok.1, "ok");
        assert_eq!(err.1, "something went wrong");
    }

    /// A handler returning `None` for one tool does not affect tools it does
    /// handle. The unhandled tool surfaces as the harness's "no handler" marker.
    #[tokio::test]
    async fn external_handler_none_falls_through_correctly() {
        let results = run_concurrent_tools(
            // "MyTool" is handled by the closure. "Read" gets `None` (not handled).
            vec![
                ("t_custom".into(), "MyTool".into(), json!({"x": 1})),
                ("t_builtin".into(), "Read".into(), json!({"path": "/tmp"})),
            ],
            |name, _input| {
                Box::pin(async move {
                    if name == "MyTool" {
                        Some(ToolResult {
                            content: "custom handled".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    } else {
                        None
                    }
                })
            },
        )
        .await;

        assert_eq!(results.len(), 2);
        let custom = results.iter().find(|(id, _, _)| id == "t_custom").unwrap();
        let builtin = results.iter().find(|(id, _, _)| id == "t_builtin").unwrap();
        assert_eq!(custom.1, "custom handled");
        // `None` from the handler is mapped to "no handler" by run_concurrent_tools.
        assert_eq!(builtin.1, "no handler");
    }

    /// A single tool behaves like the non-concurrent case: one result, sequence 0.
    #[tokio::test]
    async fn single_tool_works_same_as_before() {
        let results = run_concurrent_tools(
            vec![("t1".into(), "Read".into(), json!({"path": "file.txt"}))],
            |_name, _input| {
                Box::pin(async move {
                    Some(ToolResult {
                        content: "file contents".into(),
                        is_error: false,
                        raw_content: None,
                    })
                })
            },
        )
        .await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "t1");
        assert_eq!(results[0].1, "file contents");
        // The first (and only) completion claims counter value 0.
        assert_eq!(results[0].2, 0);
    }

    /// No tools in, no results out. The drain loop must terminate immediately.
    #[tokio::test]
    async fn empty_tool_list_produces_no_results() {
        let results =
            run_concurrent_tools(vec![], |_name, _input| Box::pin(async move { None })).await;

        assert_eq!(results.len(), 0);
    }

    /// Each result stays paired with its own tool_use_id even when the tools
    /// finish in a different order than they were submitted.
    #[tokio::test]
    async fn tool_use_ids_preserved_through_concurrent_execution() {
        let results = run_concurrent_tools(
            vec![
                ("toolu_abc123".into(), "Read".into(), json!({})),
                ("toolu_def456".into(), "Write".into(), json!({})),
                ("toolu_ghi789".into(), "Bash".into(), json!({})),
            ],
            |name, _input| {
                Box::pin(async move {
                    match name.as_str() {
                        // Staggered delays: expected finish order is Write, Read, Bash.
                        "Read" => tokio::time::sleep(Duration::from_millis(30)).await,
                        "Write" => tokio::time::sleep(Duration::from_millis(10)).await,
                        _ => tokio::time::sleep(Duration::from_millis(50)).await,
                    }
                    Some(ToolResult {
                        content: format!("{} result", name),
                        is_error: false,
                        raw_content: None,
                    })
                })
            },
        )
        .await;

        assert_eq!(results.len(), 3);

        for (id, content, _) in &results {
            // Each id must map back to the content produced for its own tool name.
            match id.as_str() {
                "toolu_abc123" => assert_eq!(content, "Read result"),
                "toolu_def456" => assert_eq!(content, "Write result"),
                "toolu_ghi789" => assert_eq!(content, "Bash result"),
                other => panic!("unexpected tool_use_id: {}", other),
            }
        }
    }

    /// Five 50 ms tools finish in well under their 250 ms sequential total.
    #[tokio::test]
    async fn concurrent_execution_timing_is_parallel() {
        // Five identical tools named t0..t4.
        let tools: Vec<_> = (0..5)
            .map(|i| (format!("t{}", i), "Tool".into(), json!({})))
            .collect();

        let start = Instant::now();

        let results = run_concurrent_tools(tools, |_name, _input| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Some(ToolResult {
                    content: "done".into(),
                    is_error: false,
                    raw_content: None,
                })
            })
        })
        .await;

        let elapsed = start.elapsed();

        assert_eq!(results.len(), 5);
        // Sequential execution would need about 250 ms. The 200 ms bound leaves
        // headroom for scheduler jitter while still ruling that out.
        assert!(
            elapsed < Duration::from_millis(200),
            "5 x 50ms tools took {:?} — should be ~50ms if concurrent",
            elapsed
        );
    }

    /// `api_block_to_content_block` copies `tool_use_id`, `content` and
    /// `is_error` of a tool-result block unchanged.
    ///
    /// NOTE(review): the body never awaits. A plain `#[test]` would suffice.
    #[tokio::test]
    async fn api_block_to_content_block_preserves_tool_result_fields() {
        let block = ApiContentBlock::ToolResult {
            tool_use_id: "toolu_abc".into(),
            content: json!("result text"),
            is_error: Some(true),
            cache_control: None,
            name: None,
        };

        let content = api_block_to_content_block(&block);
        match content {
            ContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => {
                assert_eq!(tool_use_id, "toolu_abc");
                assert_eq!(content, json!("result text"));
                assert_eq!(is_error, Some(true));
            }
            _ => panic!("expected ToolResult content block"),
        }
    }

    /// Each converted tool result sent as its own synthetic user message arrives
    /// as a separate message with exactly one block, in send order.
    #[tokio::test]
    async fn streamed_messages_each_contain_single_tool_result() {
        // Same item type as the Query stream: Result<Message>.
        let (tx, mut rx) = mpsc::unbounded_channel::<Result<Message>>();
        let session_id = "test-session".to_string();

        // One synthetic user message per tool result, sent in this order.
        let tool_ids = vec!["t1", "t2", "t3"];
        for id in &tool_ids {
            let api_block = ApiContentBlock::ToolResult {
                tool_use_id: id.to_string(),
                content: json!(format!("result for {}", id)),
                is_error: None,
                cache_control: None,
                name: None,
            };

            let result_msg = Message::User(UserMessage {
                uuid: Some(Uuid::new_v4()),
                session_id: session_id.clone(),
                content: vec![api_block_to_content_block(&api_block)],
                parent_tool_use_id: None,
                is_synthetic: true,
                tool_use_result: None,
            });
            tx.send(Ok(result_msg)).unwrap();
        }
        drop(tx);

        // An `Err` item would end this loop early and show up as a count mismatch below.
        let mut messages = Vec::new();
        while let Some(Ok(msg)) = rx.recv().await {
            messages.push(msg);
        }

        assert_eq!(messages.len(), 3, "should have 3 individual messages");

        for (i, msg) in messages.iter().enumerate() {
            if let Message::User(user) = msg {
                assert_eq!(
                    user.content.len(),
                    1,
                    "each message should have exactly 1 content block"
                );
                assert!(user.is_synthetic);
                if let ContentBlock::ToolResult { tool_use_id, .. } = &user.content[0] {
                    assert_eq!(tool_use_id, tool_ids[i]);
                } else {
                    panic!("expected ToolResult block");
                }
            } else {
                panic!("expected User message");
            }
        }
    }

    /// `accumulate_stream` behavior on a text-only response:
    /// - it concatenates text deltas into one text block;
    /// - it takes the id and model from `MessageStart`;
    /// - it takes the stop reason and output tokens from `MessageDelta`;
    /// - it forwards one `StreamEvent` message per text delta to `tx`.
    #[tokio::test]
    async fn accumulate_stream_emits_text_deltas_and_builds_response() {
        use crate::client::{
            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
        };

        // One text block streamed as three deltas, then the message delta and stop.
        let events: Vec<Result<SE>> = vec![
            Ok(SE::MessageStart {
                message: MessageResponse {
                    id: "msg_123".into(),
                    role: "assistant".into(),
                    content: vec![],
                    model: "claude-test".into(),
                    stop_reason: None,
                    usage: ApiUsage {
                        input_tokens: 100,
                        output_tokens: 0,
                        cache_creation_input_tokens: None,
                        cache_read_input_tokens: None,
                    },
                },
            }),
            Ok(SE::ContentBlockStart {
                index: 0,
                content_block: ApiContentBlock::Text {
                    text: String::new(),
                    cache_control: None,
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: "Hello".into(),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: " world".into(),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta { text: "!".into() },
            }),
            Ok(SE::ContentBlockStop { index: 0 }),
            Ok(SE::MessageDelta {
                delta: crate::client::MessageDelta {
                    stop_reason: Some("end_turn".into()),
                },
                usage: ApiUsage {
                    input_tokens: 0,
                    output_tokens: 15,
                    cache_creation_input_tokens: None,
                    cache_read_input_tokens: None,
                },
            }),
            Ok(SE::MessageStop),
        ];

        let stream = futures::stream::iter(events);
        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
            Box::pin(stream);

        let (tx, mut rx) = mpsc::unbounded_channel();

        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
            .await
            .expect("accumulate_stream should succeed");

        // Final response assembled from the event sequence.
        assert_eq!(response.id, "msg_123");
        assert_eq!(response.model, "claude-test");
        assert_eq!(response.stop_reason, Some("end_turn".into()));
        assert_eq!(response.usage.output_tokens, 15);
        assert_eq!(response.content.len(), 1);
        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
            assert_eq!(text, "Hello world!");
        } else {
            panic!("expected Text content block");
        }

        // Non-blocking drain: accumulate_stream has returned, so everything sent is queued.
        let mut stream_events = Vec::new();
        while let Ok(msg) = rx.try_recv() {
            stream_events.push(msg.unwrap());
        }
        assert_eq!(stream_events.len(), 3);

        // One forwarded event per text delta, in order, tagged with the session id.
        let expected_texts = ["Hello", " world", "!"];
        for (i, msg) in stream_events.iter().enumerate() {
            if let Message::StreamEvent(se) = msg {
                let delta = se.event.get("delta").unwrap();
                let text = delta.get("text").unwrap().as_str().unwrap();
                assert_eq!(text, expected_texts[i]);
                assert_eq!(se.session_id, "test-session");
            } else {
                panic!("expected StreamEvent message at index {}", i);
            }
        }
    }

    /// `accumulate_stream` builds a tool-use block from a partial-JSON
    /// `InputJsonDelta`. The tool-use block follows a text block, and both
    /// keep their stream indices.
    #[tokio::test]
    async fn accumulate_stream_handles_tool_use() {
        use crate::client::{
            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
        };

        let events: Vec<Result<SE>> = vec![
            Ok(SE::MessageStart {
                message: MessageResponse {
                    id: "msg_456".into(),
                    role: "assistant".into(),
                    content: vec![],
                    model: "claude-test".into(),
                    stop_reason: None,
                    usage: ApiUsage::default(),
                },
            }),
            // Index 0: a short text block.
            Ok(SE::ContentBlockStart {
                index: 0,
                content_block: ApiContentBlock::Text {
                    text: String::new(),
                    cache_control: None,
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: "Let me check.".into(),
                },
            }),
            Ok(SE::ContentBlockStop { index: 0 }),
            // Index 1: a tool-use block that starts with empty input, filled by a JSON delta.
            Ok(SE::ContentBlockStart {
                index: 1,
                content_block: ApiContentBlock::ToolUse {
                    id: "toolu_abc".into(),
                    name: "Read".into(),
                    input: serde_json::json!({}),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 1,
                delta: ContentDelta::InputJsonDelta {
                    partial_json: r#"{"path":"/tmp/f.txt"}"#.into(),
                },
            }),
            Ok(SE::ContentBlockStop { index: 1 }),
            Ok(SE::MessageDelta {
                delta: crate::client::MessageDelta {
                    stop_reason: Some("tool_use".into()),
                },
                usage: ApiUsage {
                    input_tokens: 0,
                    output_tokens: 20,
                    ..Default::default()
                },
            }),
            Ok(SE::MessageStop),
        ];

        let stream = futures::stream::iter(events);
        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
            Box::pin(stream);

        let (tx, _rx) = mpsc::unbounded_channel();
        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
            .await
            .expect("should succeed");

        assert_eq!(response.content.len(), 2);
        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
            assert_eq!(text, "Let me check.");
        } else {
            panic!("expected Text block at index 0");
        }
        // The partial JSON must have been parsed into the tool's input object.
        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[1] {
            assert_eq!(id, "toolu_abc");
            assert_eq!(name, "Read");
            assert_eq!(input["path"], "/tmp/f.txt");
        } else {
            panic!("expected ToolUse block at index 1");
        }
        assert_eq!(response.stop_reason, Some("tool_use".into()));
    }

    /// Covers streams where the complete tool input is already present in
    /// `ContentBlockStart`. This test uses an OpenAI-style `call_` id and a
    /// `qwen3:8b` model name, and sends no `InputJsonDelta` events.
    /// `accumulate_stream` must keep that input as-is instead of replacing it.
    #[tokio::test]
    async fn accumulate_stream_preserves_openai_tool_input() {
        use crate::client::{ApiContentBlock, ApiUsage, StreamEvent as SE};

        let events: Vec<Result<SE>> = vec![
            Ok(SE::MessageStart {
                message: MessageResponse {
                    id: "msg_oai".into(),
                    role: "assistant".into(),
                    content: vec![],
                    model: "qwen3:8b".into(),
                    stop_reason: None,
                    usage: ApiUsage::default(),
                },
            }),
            // Tool input arrives fully formed in the start event.
            Ok(SE::ContentBlockStart {
                index: 0,
                content_block: ApiContentBlock::ToolUse {
                    id: "call_123".into(),
                    name: "Bash".into(),
                    input: serde_json::json!({"command": "ls -la", "timeout": 5000}),
                },
            }),
            Ok(SE::ContentBlockStop { index: 0 }),
            // The block is stopped with no deltas in between.
            Ok(SE::MessageDelta {
                delta: crate::client::MessageDelta {
                    stop_reason: Some("tool_use".into()),
                },
                usage: ApiUsage {
                    input_tokens: 0,
                    output_tokens: 10,
                    ..Default::default()
                },
            }),
            Ok(SE::MessageStop),
        ];

        let stream = futures::stream::iter(events);
        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
            Box::pin(stream);

        let (tx, _rx) = mpsc::unbounded_channel();
        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
            .await
            .expect("should succeed");

        assert_eq!(response.content.len(), 1);
        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[0] {
            assert_eq!(id, "call_123");
            assert_eq!(name, "Bash");
            assert_eq!(input["command"], "ls -la");
            assert_eq!(input["timeout"], 5000);
        } else {
            panic!("expected ToolUse block");
        }
    }
}