1use std::collections::HashMap;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13use futures::stream::FuturesUnordered;
14use futures::{Stream, StreamExt as FuturesStreamExt};
15use serde_json::json;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18use tracing::{debug, error, warn};
19use uuid::Uuid;
20
21use crate::client::{
22 ApiContentBlock, ApiMessage, ApiUsage, CacheControl, ContentDelta, CreateMessageRequest,
23 ImageSource, MessageResponse, StreamEvent as ClientStreamEvent, SystemBlock, ThinkingParam,
24 ToolDefinition,
25};
26use crate::compact;
27use crate::error::{AgentError, Result};
28use crate::sanitize;
29use crate::hooks::HookRegistry;
30use crate::options::{Options, PermissionMode, ThinkingConfig};
31use crate::permissions::{PermissionEvaluator, PermissionVerdict};
32use crate::provider::LlmProvider;
33use crate::providers::AnthropicProvider;
34use crate::session::Session;
35use crate::tools::definitions::get_tool_definitions;
36use crate::tools::executor::{ToolExecutor, ToolResult};
37use crate::types::messages::*;
38
/// Model identifier used by the agent loop when `options.model` is not set.
const DEFAULT_MODEL: &str = "claude-haiku-4-5";
/// Baseline `max_tokens` for each request when `options.max_tokens` is not set.
const DEFAULT_MAX_TOKENS: u32 = 16384;
43
44pub struct Query {
48 receiver: UnboundedReceiverStream<Result<Message>>,
49 session_id: Option<String>,
50 cancel_token: tokio_util::sync::CancellationToken,
51}
52
53impl Query {
54 pub async fn interrupt(&self) -> Result<()> {
56 self.cancel_token.cancel();
57 Ok(())
58 }
59
60 pub fn session_id(&self) -> Option<&str> {
62 self.session_id.as_deref()
63 }
64
65 pub async fn set_permission_mode(&self, _mode: PermissionMode) -> Result<()> {
67 Ok(())
69 }
70
71 pub async fn set_model(&self, _model: &str) -> Result<()> {
73 Ok(())
75 }
76
77 pub fn close(&self) {
79 self.cancel_token.cancel();
80 }
81}
82
83impl Stream for Query {
84 type Item = Result<Message>;
85
86 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87 Pin::new(&mut self.receiver).poll_next(cx)
88 }
89}
90
91pub fn query(prompt: &str, options: Options) -> Query {
126 let (tx, rx) = mpsc::unbounded_channel();
127 let cancel_token = tokio_util::sync::CancellationToken::new();
128 let cancel = cancel_token.clone();
129
130 let prompt = prompt.to_string();
131
132 tokio::spawn(async move {
133 let result = run_agent_loop(prompt, options, tx.clone(), cancel).await;
134 if let Err(e) = result {
135 let _ = tx.send(Err(e));
136 }
137 });
138
139 Query {
140 receiver: UnboundedReceiverStream::new(rx),
141 session_id: None,
142 cancel_token,
143 }
144}
145
/// Runs the agent conversation loop until it produces a result message.
///
/// Steps, in order:
/// 1. Resolve the working directory, session, model and tool definitions, and
///    emit the `Init` system message.
/// 2. Build the provider, tool executor, hook registry, permission evaluator
///    and system prompt.
/// 3. Load prior messages when resuming/continuing, then append the user prompt
///    (image attachments first, prompt text last).
/// 4. Loop: check cancellation and the turn/budget limits, inject queued
///    follow-up messages, call the provider, account usage and cost, emit the
///    assistant message, and either finish (no tool calls) or evaluate
///    permissions, execute the permitted tools concurrently, feed the results
///    back, and prune/compact the conversation when a context budget is set.
///
/// # Parameters
/// * `prompt` - text of the initial user message.
/// * `options` - run configuration; several fields are taken out of it.
/// * `tx` - channel every produced `Message` is sent through.
/// * `cancel` - checked at the top of every loop iteration.
///
/// # Returns
/// `Ok(())` once a result message has been sent or the receiver went away.
/// API and stream failures are reported as an error result message and still
/// return `Ok(())`.
///
/// # Errors
/// `AgentError::Cancelled` when `cancel` is observed cancelled; any error from
/// session lookup/loading, default provider construction, or permission
/// evaluation is propagated.
async fn run_agent_loop(
    prompt: String,
    mut options: Options,
    tx: mpsc::UnboundedSender<Result<Message>>,
    cancel: tokio_util::sync::CancellationToken,
) -> Result<()> {
    // Shared by the preset prompt and the fallback prompt.
    const BASE_SYSTEM_PROMPT: &str =
        "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.";

    let start_time = Instant::now();
    let mut api_time_ms: u64 = 0;

    // Fall back to the process working directory, then to ".".
    let cwd = options.cwd.clone().unwrap_or_else(|| {
        std::env::current_dir()
            .unwrap_or_else(|_| PathBuf::from("."))
            .to_string_lossy()
            .to_string()
    });

    // Session selection: explicit resume > continue most recent > explicit id > new.
    let session = if let Some(ref resume_id) = options.resume {
        Session::with_id(resume_id, &cwd)
    } else if options.continue_session {
        match crate::session::find_most_recent_session(Some(&cwd)).await? {
            Some(info) => Session::with_id(&info.session_id, &cwd),
            None => Session::new(&cwd),
        }
    } else {
        match &options.session_id {
            Some(id) => Session::with_id(id, &cwd),
            None => Session::new(&cwd),
        }
    };

    let session_id = session.id.clone();
    let model = options
        .model
        .clone()
        .unwrap_or_else(|| DEFAULT_MODEL.to_string());

    // No built-in tools when an output format is requested; otherwise the
    // allow-list, or the default set when the allow-list is empty.
    let tool_names: Vec<String> = if options.output_format.is_some() {
        Vec::new()
    } else if options.allowed_tools.is_empty() {
        vec![
            "Read".into(),
            "Write".into(),
            "Edit".into(),
            "Bash".into(),
            "Glob".into(),
            "Grep".into(),
        ]
    } else {
        options.allowed_tools.clone()
    };

    let raw_defs: Vec<_> = get_tool_definitions(&tool_names);

    let mut all_defs: Vec<ToolDefinition> = raw_defs
        .into_iter()
        .map(|td| ToolDefinition {
            name: td.name.to_string(),
            description: td.description.to_string(),
            input_schema: td.input_schema,
            cache_control: None,
        })
        .collect();

    all_defs.extend(options.custom_tool_definitions.iter().map(|ctd| ToolDefinition {
        name: ctd.name.clone(),
        description: ctd.description.clone(),
        input_schema: ctd.input_schema.clone(),
        cache_control: None,
    }));

    // Only the last tool definition carries the cache marker.
    if let Some(last) = all_defs.last_mut() {
        last.cache_control = Some(CacheControl::ephemeral());
    }

    let tool_defs = all_defs;

    let init_msg = Message::System(SystemMessage {
        subtype: SystemSubtype::Init,
        uuid: Uuid::new_v4(),
        session_id: session_id.clone(),
        agents: if options.agents.is_empty() {
            None
        } else {
            Some(options.agents.keys().cloned().collect())
        },
        claude_code_version: Some(env!("CARGO_PKG_VERSION").to_string()),
        cwd: Some(cwd.clone()),
        tools: Some(tool_names.clone()),
        mcp_servers: if options.mcp_servers.is_empty() {
            None
        } else {
            Some(
                options
                    .mcp_servers
                    .keys()
                    .map(|name| McpServerStatus {
                        name: name.clone(),
                        status: "connected".to_string(),
                    })
                    .collect(),
            )
        },
        model: Some(model.clone()),
        permission_mode: Some(options.permission_mode.to_string()),
        compact_metadata: None,
    });

    // Persistence is best-effort throughout: append failures are ignored.
    if options.persist_session {
        let _ = session
            .append_message(&serde_json::to_value(&init_msg).unwrap_or_default())
            .await;
    }
    // A closed channel means nobody is listening any more; stop quietly.
    if tx.send(Ok(init_msg)).is_err() {
        return Ok(());
    }

    let provider: Box<dyn LlmProvider> = match options.provider.take() {
        Some(p) => p,
        None => Box::new(AnthropicProvider::from_env()?),
    };

    let additional_dirs: Vec<PathBuf> = options
        .additional_directories
        .iter()
        .map(PathBuf::from)
        .collect();
    let env_blocklist = std::mem::take(&mut options.env_blocklist);
    let tool_executor = if additional_dirs.is_empty() {
        ToolExecutor::new(PathBuf::from(&cwd))
    } else {
        ToolExecutor::with_allowed_dirs(PathBuf::from(&cwd), additional_dirs)
    }
    .with_env_blocklist(env_blocklist);

    let mut hook_registry = HookRegistry::from_map(std::mem::take(&mut options.hooks));
    if !options.hook_dirs.is_empty() {
        let dirs: Vec<&std::path::Path> = options.hook_dirs.iter().map(|p| p.as_path()).collect();
        match crate::hooks::HookDiscovery::discover(&dirs) {
            Ok(discovered) => hook_registry.merge(discovered),
            Err(e) => warn!("Failed to discover hooks from dirs: {}", e),
        }
    }

    let mut followup_rx = options.followup_rx.take();

    let permission_eval = PermissionEvaluator::new(&options);

    let system_prompt: Option<Vec<SystemBlock>> = {
        let text = match &options.system_prompt {
            Some(crate::options::SystemPrompt::Custom(s)) => s.clone(),
            Some(crate::options::SystemPrompt::Preset { append, .. }) => match append {
                Some(extra) => format!("{}\n\n{}", BASE_SYSTEM_PROMPT, extra),
                None => BASE_SYSTEM_PROMPT.to_string(),
            },
            None => BASE_SYSTEM_PROMPT.to_string(),
        };
        Some(vec![SystemBlock {
            kind: "text".to_string(),
            text,
            cache_control: Some(CacheControl::ephemeral()),
        }])
    };

    let mut conversation: Vec<ApiMessage> = Vec::new();

    // Replay stored history; entries that do not convert are skipped.
    if options.resume.is_some() || options.continue_session {
        let prev_messages = session.load_messages().await?;
        conversation.extend(prev_messages.iter().filter_map(value_to_api_message));
    }

    {
        let mut content_blocks: Vec<ApiContentBlock> = Vec::new();

        // Only attachments with one of these image MIME types are forwarded.
        for att in &options.attachments {
            let is_image = matches!(
                att.mime_type.as_str(),
                "image/png" | "image/jpeg" | "image/gif" | "image/webp"
            );
            if is_image {
                content_blocks.push(ApiContentBlock::Image {
                    source: ImageSource {
                        kind: "base64".to_string(),
                        media_type: att.mime_type.clone(),
                        data: att.base64_data.clone(),
                    },
                });
            }
        }

        content_blocks.push(ApiContentBlock::Text {
            text: prompt.clone(),
            cache_control: None,
        });

        conversation.push(ApiMessage {
            role: "user".to_string(),
            content: content_blocks,
        });
    }

    // The persisted user entry records the prompt text only.
    if options.persist_session {
        let user_msg = json!({
            "type": "user",
            "uuid": Uuid::new_v4().to_string(),
            "session_id": &session_id,
            "content": [{"type": "text", "text": &prompt}]
        });
        let _ = session.append_message(&user_msg).await;
    }

    // `num_turns` is incremented only for turns in which tools were requested.
    let mut num_turns: u32 = 0;
    let mut total_usage = Usage::default();
    let mut total_cost: f64 = 0.0;
    let mut model_usage: HashMap<String, ModelUsage> = HashMap::new();
    let mut permission_denials: Vec<PermissionDenial> = Vec::new();

    loop {
        if cancel.is_cancelled() {
            return Err(AgentError::Cancelled);
        }

        if let Some(max_turns) = options.max_turns {
            if num_turns >= max_turns {
                let result_msg = build_result_message(
                    ResultSubtype::ErrorMaxTurns,
                    &session_id,
                    None,
                    start_time,
                    api_time_ms,
                    num_turns,
                    total_cost,
                    &total_usage,
                    &model_usage,
                    &permission_denials,
                );
                let _ = tx.send(Ok(result_msg));
                return Ok(());
            }
        }

        if let Some(max_budget) = options.max_budget_usd {
            if total_cost >= max_budget {
                let result_msg = build_result_message(
                    ResultSubtype::ErrorMaxBudgetUsd,
                    &session_id,
                    None,
                    start_time,
                    api_time_ms,
                    num_turns,
                    total_cost,
                    &total_usage,
                    &model_usage,
                    &permission_denials,
                );
                let _ = tx.send(Ok(result_msg));
                return Ok(());
            }
        }

        // Drain every follow-up queued so far and inject them as one user turn.
        if let Some(rx) = followup_rx.as_mut() {
            let mut followups: Vec<String> = Vec::new();
            while let Ok(msg) = rx.try_recv() {
                followups.push(msg);
            }
            if !followups.is_empty() {
                let combined = followups.join("\n\n");
                debug!(count = followups.len(), "Injecting followup messages into agent loop");

                conversation.push(ApiMessage {
                    role: "user".to_string(),
                    content: vec![ApiContentBlock::Text {
                        text: combined.clone(),
                        cache_control: None,
                    }],
                });

                let followup_msg = Message::User(UserMessage {
                    uuid: Some(Uuid::new_v4()),
                    session_id: session_id.clone(),
                    content: vec![ContentBlock::Text { text: combined }],
                    parent_tool_use_id: None,
                    is_synthetic: false,
                    tool_use_result: None,
                });

                if options.persist_session {
                    let _ = session
                        .append_message(&serde_json::to_value(&followup_msg).unwrap_or_default())
                        .await;
                }
                if tx.send(Ok(followup_msg)).is_err() {
                    return Ok(());
                }
            }
        }

        apply_cache_breakpoint(&mut conversation);

        let thinking_param = options.thinking.as_ref().map(|tc| match tc {
            ThinkingConfig::Adaptive => ThinkingParam {
                kind: "enabled".into(),
                budget_tokens: Some(10240),
            },
            ThinkingConfig::Disabled => ThinkingParam {
                kind: "disabled".into(),
                budget_tokens: None,
            },
            ThinkingConfig::Enabled { budget_tokens } => ThinkingParam {
                kind: "enabled".into(),
                budget_tokens: Some(*budget_tokens),
            },
        });

        // With a thinking budget, leave at least 8192 tokens on top of it.
        // Saturating add: a very large budget must not overflow the u32.
        let base_max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
        let max_tokens = match thinking_param.as_ref().and_then(|tp| tp.budget_tokens) {
            Some(budget) => base_max_tokens.max((budget as u32).saturating_add(8192)),
            None => base_max_tokens,
        };

        let use_streaming = options.include_partial_messages;
        let request = CreateMessageRequest {
            model: model.clone(),
            max_tokens,
            messages: conversation.clone(),
            system: system_prompt.clone(),
            tools: if tool_defs.is_empty() {
                None
            } else {
                Some(tool_defs.clone())
            },
            stream: use_streaming,
            metadata: None,
            thinking: thinking_param,
        };

        // Run the call, turning any failure into the text of the error result.
        let api_start = Instant::now();
        let api_outcome: std::result::Result<MessageResponse, String> = if use_streaming {
            match provider.create_message_stream(&request).await {
                Ok(mut event_stream) => {
                    match accumulate_stream(&mut event_stream, &tx, &session_id).await {
                        Ok(resp) => Ok(resp),
                        Err(e) => {
                            error!("Stream accumulation failed: {}", e);
                            Err(format!("Stream error: {}", e))
                        }
                    }
                }
                Err(e) => {
                    error!("API stream call failed: {}", e);
                    Err(format!("API error: {}", e))
                }
            }
        } else {
            match provider.create_message(&request).await {
                Ok(resp) => Ok(resp),
                Err(e) => {
                    error!("API call failed: {}", e);
                    Err(format!("API error: {}", e))
                }
            }
        };
        // Count the time of failed calls too, so the error result reports it.
        api_time_ms += api_start.elapsed().as_millis() as u64;

        let response = match api_outcome {
            Ok(resp) => resp,
            Err(error_text) => {
                let result_msg = build_error_result_message(
                    &session_id,
                    &error_text,
                    start_time,
                    api_time_ms,
                    num_turns,
                    total_cost,
                    &total_usage,
                    &model_usage,
                    &permission_denials,
                );
                let _ = tx.send(Ok(result_msg));
                return Ok(());
            }
        };

        total_usage.input_tokens += response.usage.input_tokens;
        total_usage.output_tokens += response.usage.output_tokens;
        total_usage.cache_creation_input_tokens +=
            response.usage.cache_creation_input_tokens.unwrap_or(0);
        total_usage.cache_read_input_tokens +=
            response.usage.cache_read_input_tokens.unwrap_or(0);

        let rates = provider.cost_rates(&model);
        let turn_cost = rates.compute_with_cache(
            response.usage.input_tokens,
            response.usage.output_tokens,
            response.usage.cache_read_input_tokens.unwrap_or(0),
            response.usage.cache_creation_input_tokens.unwrap_or(0),
        );
        total_cost += turn_cost;

        let model_entry = model_usage.entry(model.clone()).or_default();
        model_entry.input_tokens += response.usage.input_tokens;
        model_entry.output_tokens += response.usage.output_tokens;
        model_entry.cost_usd += turn_cost;

        let content_blocks: Vec<ContentBlock> = response
            .content
            .iter()
            .map(api_block_to_content_block)
            .collect();

        let assistant_msg = Message::Assistant(AssistantMessage {
            uuid: Uuid::new_v4(),
            session_id: session_id.clone(),
            content: content_blocks.clone(),
            model: response.model.clone(),
            stop_reason: response.stop_reason.clone(),
            parent_tool_use_id: None,
            usage: Some(Usage {
                input_tokens: response.usage.input_tokens,
                output_tokens: response.usage.output_tokens,
                cache_creation_input_tokens: response
                    .usage
                    .cache_creation_input_tokens
                    .unwrap_or(0),
                cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
            }),
            error: None,
        });

        if options.persist_session {
            let _ = session
                .append_message(&serde_json::to_value(&assistant_msg).unwrap_or_default())
                .await;
        }
        if tx.send(Ok(assistant_msg)).is_err() {
            return Ok(());
        }

        conversation.push(ApiMessage {
            role: "assistant".to_string(),
            content: response.content.clone(),
        });

        let tool_uses: Vec<_> = response
            .content
            .iter()
            .filter_map(|block| match block {
                ApiContentBlock::ToolUse { id, name, input } => {
                    Some((id.clone(), name.clone(), input.clone()))
                }
                _ => None,
            })
            .collect();

        // No tool calls: the concatenated text blocks are the final answer.
        if tool_uses.is_empty() {
            let final_text = response
                .content
                .iter()
                .filter_map(|block| match block {
                    ApiContentBlock::Text { text, .. } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("");

            let result_msg = build_result_message(
                ResultSubtype::Success,
                &session_id,
                Some(final_text),
                start_time,
                api_time_ms,
                num_turns,
                total_cost,
                &total_usage,
                &model_usage,
                &permission_denials,
            );

            if options.persist_session {
                let _ = session
                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
                    .await;
            }
            let _ = tx.send(Ok(result_msg));
            return Ok(());
        }

        num_turns += 1;
        let mut tool_results: Vec<ApiContentBlock> = Vec::new();

        // A tool call that passed permission evaluation, with its final input.
        struct PermittedTool {
            tool_use_id: String,
            tool_name: String,
            actual_input: serde_json::Value,
        }
        let mut permitted_tools: Vec<PermittedTool> = Vec::new();

        // Permissions are evaluated sequentially, in the order the model asked.
        for (tool_use_id, tool_name, tool_input) in &tool_uses {
            let verdict = permission_eval
                .evaluate(tool_name, tool_input, tool_use_id, &session_id, &cwd)
                .await?;

            match verdict {
                PermissionVerdict::Allow => {
                    permitted_tools.push(PermittedTool {
                        tool_use_id: tool_use_id.clone(),
                        tool_name: tool_name.clone(),
                        actual_input: tool_input.clone(),
                    });
                }
                PermissionVerdict::AllowWithUpdatedInput(new_input) => {
                    permitted_tools.push(PermittedTool {
                        tool_use_id: tool_use_id.clone(),
                        tool_name: tool_name.clone(),
                        actual_input: new_input,
                    });
                }
                PermissionVerdict::Deny { reason } => {
                    debug!(tool = %tool_name, reason = %reason, "Tool denied");
                    permission_denials.push(PermissionDenial {
                        tool_name: tool_name.clone(),
                        tool_use_id: tool_use_id.clone(),
                        tool_input: tool_input.clone(),
                    });

                    // The denial is reported back to the model as an error tool result.
                    let api_block = ApiContentBlock::ToolResult {
                        tool_use_id: tool_use_id.clone(),
                        content: json!(format!("Permission denied: {}", reason)),
                        is_error: Some(true),
                        cache_control: None,
                        name: Some(tool_name.clone()),
                    };

                    let denial_msg = Message::User(UserMessage {
                        uuid: Some(Uuid::new_v4()),
                        session_id: session_id.clone(),
                        content: vec![api_block_to_content_block(&api_block)],
                        parent_tool_use_id: None,
                        is_synthetic: true,
                        tool_use_result: None,
                    });
                    if options.persist_session {
                        let _ = session
                            .append_message(&serde_json::to_value(&denial_msg).unwrap_or_default())
                            .await;
                    }
                    if tx.send(Ok(denial_msg)).is_err() {
                        return Ok(());
                    }

                    tool_results.push(api_block);
                }
            }
        }

        // Permitted tools run concurrently; results are handled as they complete.
        let mut futs: FuturesUnordered<_> = permitted_tools
            .iter()
            .map(|pt| {
                let handler = &options.external_tool_handler;
                let executor = &tool_executor;
                let name = &pt.tool_name;
                let input = &pt.actual_input;
                let id = &pt.tool_use_id;
                async move {
                    debug!(tool = %name, "Executing tool");

                    // The external handler gets first refusal; `None` from it
                    // (or no handler at all) falls through to the built-in executor.
                    let external = match handler {
                        Some(h) => h(name.clone(), input.clone()).await,
                        None => None,
                    };
                    let tool_result = match external {
                        Some(tr) => tr,
                        None => match executor.execute(name, input.clone()).await {
                            Ok(tr) => tr,
                            Err(e) => ToolResult {
                                content: format!("Tool execution error: {}", e),
                                is_error: true,
                                raw_content: None,
                            },
                        },
                    };
                    (id.as_str(), name.as_str(), input, tool_result)
                }
            })
            .collect();

        let max_result_bytes = options
            .max_tool_result_bytes
            .unwrap_or(sanitize::DEFAULT_MAX_TOOL_RESULT_BYTES);

        while let Some((tool_use_id, tool_name, actual_input, mut tool_result)) = futs.next().await {
            tool_result.content =
                sanitize::sanitize_tool_result(&tool_result.content, max_result_bytes);

            hook_registry
                .run_post_tool_use(
                    tool_name,
                    actual_input,
                    &serde_json::to_value(&tool_result.content).unwrap_or_default(),
                    tool_use_id,
                    &session_id,
                    &cwd,
                )
                .await;

            // Prefer the raw content when the tool supplied one.
            let result_content = tool_result
                .raw_content
                .unwrap_or_else(|| json!(tool_result.content));

            let api_block = ApiContentBlock::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                content: result_content,
                is_error: if tool_result.is_error {
                    Some(true)
                } else {
                    None
                },
                cache_control: None,
                name: Some(tool_name.to_string()),
            };

            let result_msg = Message::User(UserMessage {
                uuid: Some(Uuid::new_v4()),
                session_id: session_id.clone(),
                content: vec![api_block_to_content_block(&api_block)],
                parent_tool_use_id: None,
                is_synthetic: true,
                tool_use_result: None,
            });
            if options.persist_session {
                let _ = session
                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
                    .await;
            }
            if tx.send(Ok(result_msg)).is_err() {
                return Ok(());
            }

            tool_results.push(api_block);
        }

        // All tool results (denials first, then completions) form one user turn.
        conversation.push(ApiMessage {
            role: "user".to_string(),
            content: tool_results,
        });

        if let Some(context_budget) = options.context_budget {
            let min_keep = options.min_keep_messages.unwrap_or(4);

            // Stage 1: prune oversized tool results.
            let prune_pct = options
                .prune_threshold_pct
                .unwrap_or(compact::DEFAULT_PRUNE_THRESHOLD_PCT);
            if compact::should_prune(response.usage.input_tokens, context_budget, prune_pct) {
                let max_chars = options
                    .prune_tool_result_max_chars
                    .unwrap_or(compact::DEFAULT_PRUNE_TOOL_RESULT_MAX_CHARS);
                let removed = compact::prune_tool_results(&mut conversation, max_chars, min_keep);
                if removed > 0 {
                    debug!(
                        chars_removed = removed,
                        input_tokens = response.usage.input_tokens,
                        "Pruned oversized tool results to free context space"
                    );
                }
            }

            // Stage 2: summarise the older part of the conversation.
            if compact::should_compact(response.usage.input_tokens, context_budget) {
                let split_point = compact::find_split_point(&conversation, min_keep);
                if split_point > 0 {
                    debug!(
                        input_tokens = response.usage.input_tokens,
                        context_budget,
                        split_point,
                        "Context budget exceeded, compacting conversation"
                    );

                    let compaction_model = options
                        .compaction_model
                        .as_deref()
                        .unwrap_or(compact::DEFAULT_COMPACTION_MODEL);

                    // Let the caller see the messages before they are replaced.
                    if let Some(ref handler) = options.pre_compact_handler {
                        let msgs_to_compact = conversation[..split_point].to_vec();
                        handler(msgs_to_compact).await;
                    }

                    let summary_prompt =
                        compact::build_summary_prompt(&conversation[..split_point]);

                    let summary_max_tokens = options.summary_max_tokens.unwrap_or(4096);
                    // A dedicated compaction provider is tried first, with the main
                    // provider passed as fallback; otherwise the main provider is used alone.
                    let compact_provider: &dyn LlmProvider = match &options.compaction_provider {
                        Some(cp) => cp.as_ref(),
                        None => provider.as_ref(),
                    };
                    let fallback_provider: Option<&dyn LlmProvider> =
                        if options.compaction_provider.is_some() {
                            Some(provider.as_ref())
                        } else {
                            None
                        };
                    match compact::call_summarizer(
                        compact_provider,
                        &summary_prompt,
                        compaction_model,
                        fallback_provider,
                        &model,
                        summary_max_tokens,
                    )
                    .await
                    {
                        Ok(summary) => {
                            let pre_tokens = response.usage.input_tokens;
                            let messages_compacted = split_point;

                            compact::splice_conversation(&mut conversation, split_point, &summary);

                            let compact_msg = Message::System(SystemMessage {
                                subtype: SystemSubtype::CompactBoundary,
                                uuid: Uuid::new_v4(),
                                session_id: session_id.clone(),
                                agents: None,
                                claude_code_version: None,
                                cwd: None,
                                tools: None,
                                mcp_servers: None,
                                model: None,
                                permission_mode: None,
                                compact_metadata: Some(CompactMetadata {
                                    trigger: CompactTrigger::Auto,
                                    pre_tokens,
                                }),
                            });

                            if options.persist_session {
                                let _ = session
                                    .append_message(
                                        &serde_json::to_value(&compact_msg).unwrap_or_default(),
                                    )
                                    .await;
                            }
                            let _ = tx.send(Ok(compact_msg));

                            debug!(
                                pre_tokens,
                                messages_compacted,
                                summary_len = summary.len(),
                                "Conversation compacted"
                            );
                        }
                        Err(e) => {
                            warn!(
                                "Compaction failed, continuing without compaction: {}",
                                e
                            );
                        }
                    }
                }
            }
        }
    }
}
1023
1024async fn accumulate_stream(
1027 event_stream: &mut std::pin::Pin<Box<dyn futures::Stream<Item = Result<ClientStreamEvent>> + Send>>,
1028 tx: &mpsc::UnboundedSender<Result<Message>>,
1029 session_id: &str,
1030) -> Result<MessageResponse> {
1031 use crate::client::StreamEvent as SE;
1032
1033 let mut message_id = String::new();
1035 let mut model = String::new();
1036 let mut role = String::from("assistant");
1037 let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
1038 let mut stop_reason: Option<String> = None;
1039 let mut usage = ApiUsage::default();
1040
1041 let mut block_texts: Vec<String> = Vec::new();
1044 let mut block_types: Vec<String> = Vec::new(); let mut block_tool_ids: Vec<String> = Vec::new();
1046 let mut block_tool_names: Vec<String> = Vec::new();
1047
1048 while let Some(event_result) = FuturesStreamExt::next(event_stream).await {
1049 let event = event_result?;
1050 match event {
1051 SE::MessageStart { message } => {
1052 message_id = message.id;
1053 model = message.model;
1054 role = message.role;
1055 usage = message.usage;
1056 }
1057 SE::ContentBlockStart { index, content_block } => {
1058 while block_texts.len() <= index {
1060 block_texts.push(String::new());
1061 block_types.push(String::new());
1062 block_tool_ids.push(String::new());
1063 block_tool_names.push(String::new());
1064 }
1065 match &content_block {
1066 ApiContentBlock::Text { .. } => {
1067 block_types[index] = "text".to_string();
1068 }
1069 ApiContentBlock::ToolUse { id, name, .. } => {
1070 block_types[index] = "tool_use".to_string();
1071 block_tool_ids[index] = id.clone();
1072 block_tool_names[index] = name.clone();
1073 }
1074 ApiContentBlock::Thinking { .. } => {
1075 block_types[index] = "thinking".to_string();
1076 }
1077 _ => {}
1078 }
1079 }
1080 SE::ContentBlockDelta { index, delta } => {
1081 while block_texts.len() <= index {
1082 block_texts.push(String::new());
1083 block_types.push(String::new());
1084 block_tool_ids.push(String::new());
1085 block_tool_names.push(String::new());
1086 }
1087 match &delta {
1088 ContentDelta::TextDelta { text } => {
1089 block_texts[index].push_str(text);
1090 let stream_event = Message::StreamEvent(StreamEventMessage {
1092 event: serde_json::json!({
1093 "type": "content_block_delta",
1094 "index": index,
1095 "delta": { "type": "text_delta", "text": text }
1096 }),
1097 parent_tool_use_id: None,
1098 uuid: Uuid::new_v4(),
1099 session_id: session_id.to_string(),
1100 });
1101 if tx.send(Ok(stream_event)).is_err() {
1102 return Err(AgentError::Cancelled);
1103 }
1104 }
1105 ContentDelta::InputJsonDelta { partial_json } => {
1106 block_texts[index].push_str(partial_json);
1107 }
1108 ContentDelta::ThinkingDelta { thinking } => {
1109 block_texts[index].push_str(thinking);
1110 }
1111 }
1112 }
1113 SE::ContentBlockStop { index } => {
1114 if index < block_types.len() {
1115 let block = match block_types[index].as_str() {
1116 "text" => ApiContentBlock::Text {
1117 text: std::mem::take(&mut block_texts[index]),
1118 cache_control: None,
1119 },
1120 "tool_use" => {
1121 let input: serde_json::Value = serde_json::from_str(
1122 &block_texts[index],
1123 )
1124 .unwrap_or(serde_json::Value::Object(Default::default()));
1125 ApiContentBlock::ToolUse {
1126 id: std::mem::take(&mut block_tool_ids[index]),
1127 name: std::mem::take(&mut block_tool_names[index]),
1128 input,
1129 }
1130 }
1131 "thinking" => ApiContentBlock::Thinking {
1132 thinking: std::mem::take(&mut block_texts[index]),
1133 },
1134 _ => continue,
1135 };
1136 while content_blocks.len() <= index {
1138 content_blocks.push(ApiContentBlock::Text {
1139 text: String::new(),
1140 cache_control: None,
1141 });
1142 }
1143 content_blocks[index] = block;
1144 }
1145 }
1146 SE::MessageDelta { delta, usage: delta_usage } => {
1147 stop_reason = delta.stop_reason;
1148 usage.output_tokens = delta_usage.output_tokens;
1150 }
1151 SE::MessageStop => {
1152 break;
1153 }
1154 SE::Error { error } => {
1155 return Err(AgentError::Api(error.message));
1156 }
1157 SE::Ping => {}
1158 }
1159 }
1160
1161 Ok(MessageResponse {
1162 id: message_id,
1163 role,
1164 content: content_blocks,
1165 model,
1166 stop_reason,
1167 usage,
1168 })
1169}
1170
1171fn apply_cache_breakpoint(conversation: &mut [ApiMessage]) {
1176 for msg in conversation.iter_mut() {
1178 for block in msg.content.iter_mut() {
1179 match block {
1180 ApiContentBlock::Text { cache_control, .. }
1181 | ApiContentBlock::ToolResult { cache_control, .. } => {
1182 *cache_control = None;
1183 }
1184 ApiContentBlock::Image { .. }
1185 | ApiContentBlock::ToolUse { .. }
1186 | ApiContentBlock::Thinking { .. } => {}
1187 }
1188 }
1189 }
1190
1191 if let Some(last_user) = conversation.iter_mut().rev().find(|m| m.role == "user") {
1193 if let Some(last_block) = last_user.content.last_mut() {
1194 match last_block {
1195 ApiContentBlock::Text { cache_control, .. }
1196 | ApiContentBlock::ToolResult { cache_control, .. } => {
1197 *cache_control = Some(CacheControl::ephemeral());
1198 }
1199 ApiContentBlock::Image { .. }
1200 | ApiContentBlock::ToolUse { .. }
1201 | ApiContentBlock::Thinking { .. } => {}
1202 }
1203 }
1204 }
1205}
1206
1207fn api_block_to_content_block(block: &ApiContentBlock) -> ContentBlock {
1209 match block {
1210 ApiContentBlock::Text { text, .. } => ContentBlock::Text {
1211 text: text.clone(),
1212 },
1213 ApiContentBlock::Image { .. } => ContentBlock::Text {
1214 text: "[image]".to_string(),
1215 },
1216 ApiContentBlock::ToolUse { id, name, input } => ContentBlock::ToolUse {
1217 id: id.clone(),
1218 name: name.clone(),
1219 input: input.clone(),
1220 },
1221 ApiContentBlock::ToolResult {
1222 tool_use_id,
1223 content,
1224 is_error,
1225 ..
1226 } => ContentBlock::ToolResult {
1227 tool_use_id: tool_use_id.clone(),
1228 content: content.clone(),
1229 is_error: *is_error,
1230 },
1231 ApiContentBlock::Thinking { thinking } => ContentBlock::Thinking {
1232 thinking: thinking.clone(),
1233 },
1234 }
1235}
1236
1237fn value_to_api_message(value: &serde_json::Value) -> Option<ApiMessage> {
1239 let msg_type = value.get("type")?.as_str()?;
1240
1241 match msg_type {
1242 "assistant" => {
1243 let content = value.get("content")?;
1244 let blocks = parse_content_blocks(content)?;
1245 Some(ApiMessage {
1246 role: "assistant".to_string(),
1247 content: blocks,
1248 })
1249 }
1250 "user" => {
1251 let content = value.get("content")?;
1252 let blocks = parse_content_blocks(content)?;
1253 Some(ApiMessage {
1254 role: "user".to_string(),
1255 content: blocks,
1256 })
1257 }
1258 _ => None,
1259 }
1260}
1261
1262fn parse_content_blocks(content: &serde_json::Value) -> Option<Vec<ApiContentBlock>> {
1264 if let Some(text) = content.as_str() {
1265 return Some(vec![ApiContentBlock::Text {
1266 text: text.to_string(),
1267 cache_control: None,
1268 }]);
1269 }
1270
1271 if let Some(blocks) = content.as_array() {
1272 let parsed: Vec<ApiContentBlock> = blocks
1273 .iter()
1274 .filter_map(|b| serde_json::from_value(b.clone()).ok())
1275 .collect();
1276 if !parsed.is_empty() {
1277 return Some(parsed);
1278 }
1279 }
1280
1281 None
1282}
1283
1284fn build_result_message(
1286 subtype: ResultSubtype,
1287 session_id: &str,
1288 result_text: Option<String>,
1289 start_time: Instant,
1290 api_time_ms: u64,
1291 num_turns: u32,
1292 total_cost: f64,
1293 usage: &Usage,
1294 model_usage: &HashMap<String, ModelUsage>,
1295 permission_denials: &[PermissionDenial],
1296) -> Message {
1297 Message::Result(ResultMessage {
1298 subtype,
1299 uuid: Uuid::new_v4(),
1300 session_id: session_id.to_string(),
1301 duration_ms: start_time.elapsed().as_millis() as u64,
1302 duration_api_ms: api_time_ms,
1303 is_error: result_text.is_none(),
1304 num_turns,
1305 result: result_text,
1306 stop_reason: Some("end_turn".to_string()),
1307 total_cost_usd: total_cost,
1308 usage: Some(usage.clone()),
1309 model_usage: model_usage.clone(),
1310 permission_denials: permission_denials.to_vec(),
1311 structured_output: None,
1312 errors: Vec::new(),
1313 })
1314}
1315
1316fn build_error_result_message(
1318 session_id: &str,
1319 error_msg: &str,
1320 start_time: Instant,
1321 api_time_ms: u64,
1322 num_turns: u32,
1323 total_cost: f64,
1324 usage: &Usage,
1325 model_usage: &HashMap<String, ModelUsage>,
1326 permission_denials: &[PermissionDenial],
1327) -> Message {
1328 Message::Result(ResultMessage {
1329 subtype: ResultSubtype::ErrorDuringExecution,
1330 uuid: Uuid::new_v4(),
1331 session_id: session_id.to_string(),
1332 duration_ms: start_time.elapsed().as_millis() as u64,
1333 duration_api_ms: api_time_ms,
1334 is_error: true,
1335 num_turns,
1336 result: None,
1337 stop_reason: None,
1338 total_cost_usd: total_cost,
1339 usage: Some(usage.clone()),
1340 model_usage: model_usage.clone(),
1341 permission_denials: permission_denials.to_vec(),
1342 structured_output: None,
1343 errors: vec![error_msg.to_string()],
1344 })
1345}
1346
1347#[cfg(test)]
1348mod tests {
1349 use super::*;
1350 use std::sync::atomic::{AtomicUsize, Ordering};
1351 use std::sync::Arc;
1352 use std::time::Duration;
1353
1354 async fn run_concurrent_tools(
1357 tools: Vec<(String, String, serde_json::Value)>,
1358 handler: impl Fn(String, serde_json::Value) -> Pin<Box<dyn futures::Future<Output = Option<ToolResult>> + Send>>,
1359 ) -> Vec<(String, String, usize)> {
1360 let order = Arc::new(AtomicUsize::new(0));
1361 let handler = Arc::new(handler);
1362
1363 struct PermittedTool {
1364 tool_use_id: String,
1365 tool_name: String,
1366 actual_input: serde_json::Value,
1367 }
1368
1369 let permitted: Vec<PermittedTool> = tools
1370 .into_iter()
1371 .map(|(id, name, input)| PermittedTool {
1372 tool_use_id: id,
1373 tool_name: name,
1374 actual_input: input,
1375 })
1376 .collect();
1377
1378 let mut futs: FuturesUnordered<_> = permitted
1379 .iter()
1380 .map(|pt| {
1381 let handler = handler.clone();
1382 let order = order.clone();
1383 let name = pt.tool_name.clone();
1384 let input = pt.actual_input.clone();
1385 let id = pt.tool_use_id.clone();
1386 async move {
1387 let result = handler(name, input).await;
1388 let seq = order.fetch_add(1, Ordering::SeqCst);
1389 (id, result, seq)
1390 }
1391 })
1392 .collect();
1393
1394 let mut results = Vec::new();
1395 while let Some((id, result, seq)) = futs.next().await {
1396 let content = result
1397 .map(|r| r.content)
1398 .unwrap_or_else(|| "no handler".into());
1399 results.push((id, content, seq));
1400 }
1401 results
1402 }
1403
1404 #[tokio::test]
1405 async fn concurrent_tools_all_complete() {
1406 let results = run_concurrent_tools(
1407 vec![
1408 ("t1".into(), "Read".into(), json!({"path": "a.txt"})),
1409 ("t2".into(), "Read".into(), json!({"path": "b.txt"})),
1410 ("t3".into(), "Read".into(), json!({"path": "c.txt"})),
1411 ],
1412 |name, input| {
1413 Box::pin(async move {
1414 let path = input["path"].as_str().unwrap_or("?");
1415 Some(ToolResult {
1416 content: format!("{}: {}", name, path),
1417 is_error: false,
1418 raw_content: None,
1419 })
1420 })
1421 },
1422 )
1423 .await;
1424
1425 assert_eq!(results.len(), 3);
1426 let ids: Vec<&str> = results.iter().map(|(id, _, _)| id.as_str()).collect();
1427 assert!(ids.contains(&"t1"));
1428 assert!(ids.contains(&"t2"));
1429 assert!(ids.contains(&"t3"));
1430 }
1431
1432 #[tokio::test]
1433 async fn slow_tool_does_not_block_fast_tools() {
1434 let start = Instant::now();
1435
1436 let results = run_concurrent_tools(
1437 vec![
1438 ("slow".into(), "Bash".into(), json!({})),
1439 ("fast1".into(), "Read".into(), json!({})),
1440 ("fast2".into(), "Read".into(), json!({})),
1441 ],
1442 |name, _input| {
1443 Box::pin(async move {
1444 if name == "Bash" {
1445 tokio::time::sleep(Duration::from_millis(200)).await;
1446 Some(ToolResult {
1447 content: "slow done".into(),
1448 is_error: false,
1449 raw_content: None,
1450 })
1451 } else {
1452 Some(ToolResult {
1454 content: "fast done".into(),
1455 is_error: false,
1456 raw_content: None,
1457 })
1458 }
1459 })
1460 },
1461 )
1462 .await;
1463
1464 let elapsed = start.elapsed();
1465
1466 assert_eq!(results.len(), 3);
1468
1469 let slow = results.iter().find(|(id, _, _)| id == "slow").unwrap();
1471 let fast1 = results.iter().find(|(id, _, _)| id == "fast1").unwrap();
1472 let fast2 = results.iter().find(|(id, _, _)| id == "fast2").unwrap();
1473
1474 assert!(fast1.2 < slow.2, "fast1 should complete before slow");
1475 assert!(fast2.2 < slow.2, "fast2 should complete before slow");
1476
1477 assert!(
1479 elapsed < Duration::from_millis(400),
1480 "elapsed {:?} should be under 400ms (concurrent execution)",
1481 elapsed
1482 );
1483 }
1484
1485 #[tokio::test]
1486 async fn results_streamed_individually_as_they_complete() {
1487 let (tx, mut rx) = mpsc::unbounded_channel::<(String, String)>();
1490
1491 let tools = vec![
1492 ("t_slow".into(), "Slow".into(), json!({})),
1493 ("t_fast".into(), "Fast".into(), json!({})),
1494 ];
1495
1496 struct PT {
1497 tool_use_id: String,
1498 tool_name: String,
1499 }
1500
1501 let permitted: Vec<PT> = tools
1502 .into_iter()
1503 .map(|(id, name, _)| PT {
1504 tool_use_id: id,
1505 tool_name: name,
1506 })
1507 .collect();
1508
1509 let mut futs: FuturesUnordered<_> = permitted
1510 .iter()
1511 .map(|pt| {
1512 let name = pt.tool_name.clone();
1513 let id = pt.tool_use_id.clone();
1514 async move {
1515 if name == "Slow" {
1516 tokio::time::sleep(Duration::from_millis(100)).await;
1517 }
1518 let result = ToolResult {
1519 content: format!("{} result", name),
1520 is_error: false,
1521 raw_content: None,
1522 };
1523 (id, result)
1524 }
1525 })
1526 .collect();
1527
1528 while let Some((id, result)) = futs.next().await {
1530 tx.send((id, result.content)).unwrap();
1531 }
1532 drop(tx);
1533
1534 let mut streamed = Vec::new();
1536 while let Some(item) = rx.recv().await {
1537 streamed.push(item);
1538 }
1539
1540 assert_eq!(streamed.len(), 2);
1541 assert_eq!(streamed[0].0, "t_fast");
1543 assert_eq!(streamed[0].1, "Fast result");
1544 assert_eq!(streamed[1].0, "t_slow");
1545 assert_eq!(streamed[1].1, "Slow result");
1546 }
1547
1548 #[tokio::test]
1549 async fn error_tool_does_not_prevent_other_tools() {
1550 let results = run_concurrent_tools(
1551 vec![
1552 ("t_ok".into(), "Read".into(), json!({})),
1553 ("t_err".into(), "Fail".into(), json!({})),
1554 ],
1555 |name, _input| {
1556 Box::pin(async move {
1557 if name == "Fail" {
1558 Some(ToolResult {
1559 content: "something went wrong".into(),
1560 is_error: true,
1561 raw_content: None,
1562 })
1563 } else {
1564 Some(ToolResult {
1565 content: "ok".into(),
1566 is_error: false,
1567 raw_content: None,
1568 })
1569 }
1570 })
1571 },
1572 )
1573 .await;
1574
1575 assert_eq!(results.len(), 2);
1576 let ok = results.iter().find(|(id, _, _)| id == "t_ok").unwrap();
1577 let err = results.iter().find(|(id, _, _)| id == "t_err").unwrap();
1578 assert_eq!(ok.1, "ok");
1579 assert_eq!(err.1, "something went wrong");
1580 }
1581
1582 #[tokio::test]
1583 async fn external_handler_none_falls_through_correctly() {
1584 let results = run_concurrent_tools(
1587 vec![
1588 ("t_custom".into(), "MyTool".into(), json!({"x": 1})),
1589 ("t_builtin".into(), "Read".into(), json!({"path": "/tmp"})),
1590 ],
1591 |name, _input| {
1592 Box::pin(async move {
1593 if name == "MyTool" {
1594 Some(ToolResult {
1595 content: "custom handled".into(),
1596 is_error: false,
1597 raw_content: None,
1598 })
1599 } else {
1600 None
1602 }
1603 })
1604 },
1605 )
1606 .await;
1607
1608 assert_eq!(results.len(), 2);
1609 let custom = results.iter().find(|(id, _, _)| id == "t_custom").unwrap();
1610 let builtin = results.iter().find(|(id, _, _)| id == "t_builtin").unwrap();
1611 assert_eq!(custom.1, "custom handled");
1612 assert_eq!(builtin.1, "no handler"); }
1614
1615 #[tokio::test]
1616 async fn single_tool_works_same_as_before() {
1617 let results = run_concurrent_tools(
1618 vec![("t1".into(), "Read".into(), json!({"path": "file.txt"}))],
1619 |_name, _input| {
1620 Box::pin(async move {
1621 Some(ToolResult {
1622 content: "file contents".into(),
1623 is_error: false,
1624 raw_content: None,
1625 })
1626 })
1627 },
1628 )
1629 .await;
1630
1631 assert_eq!(results.len(), 1);
1632 assert_eq!(results[0].0, "t1");
1633 assert_eq!(results[0].1, "file contents");
1634 assert_eq!(results[0].2, 0); }
1636
1637 #[tokio::test]
1638 async fn empty_tool_list_produces_no_results() {
1639 let results = run_concurrent_tools(vec![], |_name, _input| {
1640 Box::pin(async move { None })
1641 })
1642 .await;
1643
1644 assert_eq!(results.len(), 0);
1645 }
1646
1647 #[tokio::test]
1648 async fn tool_use_ids_preserved_through_concurrent_execution() {
1649 let results = run_concurrent_tools(
1650 vec![
1651 ("toolu_abc123".into(), "Read".into(), json!({})),
1652 ("toolu_def456".into(), "Write".into(), json!({})),
1653 ("toolu_ghi789".into(), "Bash".into(), json!({})),
1654 ],
1655 |name, _input| {
1656 Box::pin(async move {
1657 match name.as_str() {
1659 "Read" => tokio::time::sleep(Duration::from_millis(30)).await,
1660 "Write" => tokio::time::sleep(Duration::from_millis(10)).await,
1661 _ => tokio::time::sleep(Duration::from_millis(50)).await,
1662 }
1663 Some(ToolResult {
1664 content: format!("{} result", name),
1665 is_error: false,
1666 raw_content: None,
1667 })
1668 })
1669 },
1670 )
1671 .await;
1672
1673 assert_eq!(results.len(), 3);
1674
1675 for (id, content, _) in &results {
1677 match id.as_str() {
1678 "toolu_abc123" => assert_eq!(content, "Read result"),
1679 "toolu_def456" => assert_eq!(content, "Write result"),
1680 "toolu_ghi789" => assert_eq!(content, "Bash result"),
1681 other => panic!("unexpected tool_use_id: {}", other),
1682 }
1683 }
1684 }
1685
1686 #[tokio::test]
1687 async fn concurrent_execution_timing_is_parallel() {
1688 let tools: Vec<_> = (0..5)
1690 .map(|i| (format!("t{}", i), "Tool".into(), json!({})))
1691 .collect();
1692
1693 let start = Instant::now();
1694
1695 let results = run_concurrent_tools(tools, |_name, _input| {
1696 Box::pin(async move {
1697 tokio::time::sleep(Duration::from_millis(50)).await;
1698 Some(ToolResult {
1699 content: "done".into(),
1700 is_error: false,
1701 raw_content: None,
1702 })
1703 })
1704 })
1705 .await;
1706
1707 let elapsed = start.elapsed();
1708
1709 assert_eq!(results.len(), 5);
1710 assert!(
1712 elapsed < Duration::from_millis(200),
1713 "5 x 50ms tools took {:?} — should be ~50ms if concurrent",
1714 elapsed
1715 );
1716 }
1717
1718 #[tokio::test]
1719 async fn api_block_to_content_block_preserves_tool_result_fields() {
1720 let block = ApiContentBlock::ToolResult {
1721 tool_use_id: "toolu_abc".into(),
1722 content: json!("result text"),
1723 is_error: Some(true),
1724 cache_control: None,
1725 name: None,
1726 };
1727
1728 let content = api_block_to_content_block(&block);
1729 match content {
1730 ContentBlock::ToolResult {
1731 tool_use_id,
1732 content,
1733 is_error,
1734 } => {
1735 assert_eq!(tool_use_id, "toolu_abc");
1736 assert_eq!(content, json!("result text"));
1737 assert_eq!(is_error, Some(true));
1738 }
1739 _ => panic!("expected ToolResult content block"),
1740 }
1741 }
1742
1743 #[tokio::test]
1744 async fn streamed_messages_each_contain_single_tool_result() {
1745 let (tx, mut rx) = mpsc::unbounded_channel::<Result<Message>>();
1747 let session_id = "test-session".to_string();
1748
1749 let tool_ids = vec!["t1", "t2", "t3"];
1751 for id in &tool_ids {
1752 let api_block = ApiContentBlock::ToolResult {
1753 tool_use_id: id.to_string(),
1754 content: json!(format!("result for {}", id)),
1755 is_error: None,
1756 cache_control: None,
1757 name: None,
1758 };
1759
1760 let result_msg = Message::User(UserMessage {
1761 uuid: Some(Uuid::new_v4()),
1762 session_id: session_id.clone(),
1763 content: vec![api_block_to_content_block(&api_block)],
1764 parent_tool_use_id: None,
1765 is_synthetic: true,
1766 tool_use_result: None,
1767 });
1768 tx.send(Ok(result_msg)).unwrap();
1769 }
1770 drop(tx);
1771
1772 let mut messages = Vec::new();
1773 while let Some(Ok(msg)) = rx.recv().await {
1774 messages.push(msg);
1775 }
1776
1777 assert_eq!(messages.len(), 3, "should have 3 individual messages");
1778
1779 for (i, msg) in messages.iter().enumerate() {
1780 if let Message::User(user) = msg {
1781 assert_eq!(user.content.len(), 1, "each message should have exactly 1 content block");
1782 assert!(user.is_synthetic);
1783 if let ContentBlock::ToolResult { tool_use_id, .. } = &user.content[0] {
1784 assert_eq!(tool_use_id, tool_ids[i]);
1785 } else {
1786 panic!("expected ToolResult block");
1787 }
1788 } else {
1789 panic!("expected User message");
1790 }
1791 }
1792 }
1793
1794 #[tokio::test]
1795 async fn accumulate_stream_emits_text_deltas_and_builds_response() {
1796 use crate::client::{ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE};
1797
1798 let events: Vec<Result<SE>> = vec![
1800 Ok(SE::MessageStart {
1801 message: MessageResponse {
1802 id: "msg_123".into(),
1803 role: "assistant".into(),
1804 content: vec![],
1805 model: "claude-test".into(),
1806 stop_reason: None,
1807 usage: ApiUsage {
1808 input_tokens: 100,
1809 output_tokens: 0,
1810 cache_creation_input_tokens: None,
1811 cache_read_input_tokens: None,
1812 },
1813 },
1814 }),
1815 Ok(SE::ContentBlockStart {
1816 index: 0,
1817 content_block: ApiContentBlock::Text {
1818 text: String::new(),
1819 cache_control: None,
1820 },
1821 }),
1822 Ok(SE::ContentBlockDelta {
1823 index: 0,
1824 delta: ContentDelta::TextDelta { text: "Hello".into() },
1825 }),
1826 Ok(SE::ContentBlockDelta {
1827 index: 0,
1828 delta: ContentDelta::TextDelta { text: " world".into() },
1829 }),
1830 Ok(SE::ContentBlockDelta {
1831 index: 0,
1832 delta: ContentDelta::TextDelta { text: "!".into() },
1833 }),
1834 Ok(SE::ContentBlockStop { index: 0 }),
1835 Ok(SE::MessageDelta {
1836 delta: crate::client::MessageDelta {
1837 stop_reason: Some("end_turn".into()),
1838 },
1839 usage: ApiUsage {
1840 input_tokens: 0,
1841 output_tokens: 15,
1842 cache_creation_input_tokens: None,
1843 cache_read_input_tokens: None,
1844 },
1845 }),
1846 Ok(SE::MessageStop),
1847 ];
1848
1849 let stream = futures::stream::iter(events);
1850 let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
1851 Box::pin(stream);
1852
1853 let (tx, mut rx) = mpsc::unbounded_channel();
1854
1855 let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
1856 .await
1857 .expect("accumulate_stream should succeed");
1858
1859 assert_eq!(response.id, "msg_123");
1861 assert_eq!(response.model, "claude-test");
1862 assert_eq!(response.stop_reason, Some("end_turn".into()));
1863 assert_eq!(response.usage.output_tokens, 15);
1864 assert_eq!(response.content.len(), 1);
1865 if let ApiContentBlock::Text { text, .. } = &response.content[0] {
1866 assert_eq!(text, "Hello world!");
1867 } else {
1868 panic!("expected Text content block");
1869 }
1870
1871 let mut stream_events = Vec::new();
1873 while let Ok(msg) = rx.try_recv() {
1874 stream_events.push(msg.unwrap());
1875 }
1876 assert_eq!(stream_events.len(), 3);
1877
1878 let expected_texts = ["Hello", " world", "!"];
1880 for (i, msg) in stream_events.iter().enumerate() {
1881 if let Message::StreamEvent(se) = msg {
1882 let delta = se.event.get("delta").unwrap();
1883 let text = delta.get("text").unwrap().as_str().unwrap();
1884 assert_eq!(text, expected_texts[i]);
1885 assert_eq!(se.session_id, "test-session");
1886 } else {
1887 panic!("expected StreamEvent message at index {}", i);
1888 }
1889 }
1890 }
1891
1892 #[tokio::test]
1893 async fn accumulate_stream_handles_tool_use() {
1894 use crate::client::{ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE};
1895
1896 let events: Vec<Result<SE>> = vec![
1897 Ok(SE::MessageStart {
1898 message: MessageResponse {
1899 id: "msg_456".into(),
1900 role: "assistant".into(),
1901 content: vec![],
1902 model: "claude-test".into(),
1903 stop_reason: None,
1904 usage: ApiUsage::default(),
1905 },
1906 }),
1907 Ok(SE::ContentBlockStart {
1909 index: 0,
1910 content_block: ApiContentBlock::Text {
1911 text: String::new(),
1912 cache_control: None,
1913 },
1914 }),
1915 Ok(SE::ContentBlockDelta {
1916 index: 0,
1917 delta: ContentDelta::TextDelta { text: "Let me check.".into() },
1918 }),
1919 Ok(SE::ContentBlockStop { index: 0 }),
1920 Ok(SE::ContentBlockStart {
1922 index: 1,
1923 content_block: ApiContentBlock::ToolUse {
1924 id: "toolu_abc".into(),
1925 name: "Read".into(),
1926 input: serde_json::json!({}),
1927 },
1928 }),
1929 Ok(SE::ContentBlockDelta {
1930 index: 1,
1931 delta: ContentDelta::InputJsonDelta {
1932 partial_json: r#"{"path":"/tmp/f.txt"}"#.into(),
1933 },
1934 }),
1935 Ok(SE::ContentBlockStop { index: 1 }),
1936 Ok(SE::MessageDelta {
1937 delta: crate::client::MessageDelta {
1938 stop_reason: Some("tool_use".into()),
1939 },
1940 usage: ApiUsage { input_tokens: 0, output_tokens: 20, ..Default::default() },
1941 }),
1942 Ok(SE::MessageStop),
1943 ];
1944
1945 let stream = futures::stream::iter(events);
1946 let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
1947 Box::pin(stream);
1948
1949 let (tx, _rx) = mpsc::unbounded_channel();
1950 let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
1951 .await
1952 .expect("should succeed");
1953
1954 assert_eq!(response.content.len(), 2);
1955 if let ApiContentBlock::Text { text, .. } = &response.content[0] {
1956 assert_eq!(text, "Let me check.");
1957 } else {
1958 panic!("expected Text block at index 0");
1959 }
1960 if let ApiContentBlock::ToolUse { id, name, input } = &response.content[1] {
1961 assert_eq!(id, "toolu_abc");
1962 assert_eq!(name, "Read");
1963 assert_eq!(input["path"], "/tmp/f.txt");
1964 } else {
1965 panic!("expected ToolUse block at index 1");
1966 }
1967 assert_eq!(response.stop_reason, Some("tool_use".into()));
1968 }
1969}