1use std::collections::BTreeMap;
2use std::fmt::{Display, Formatter};
3
4use crate::compact::{
5 compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
6};
7use crate::config::RuntimeFeatureConfig;
8use crate::hooks::{HookRunResult, HookRunner};
9use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
10use crate::session::{ContentBlock, ConversationMessage, Session};
11use crate::usage::{TokenUsage, UsageTracker};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct ApiRequest {
15 pub system_prompt: Vec<String>,
16 pub messages: Vec<ConversationMessage>,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum AssistantEvent {
21 TextDelta(String),
22 ToolUse {
23 id: String,
24 name: String,
25 input: String,
26 },
27 Usage(TokenUsage),
28 MessageStop,
29}
30
31pub trait ApiClient {
32 fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
33}
34
35pub trait ToolExecutor {
36 fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
37}
38
/// Failure reported by a tool invocation.
///
/// Carries only a human-readable message; its `Display` output is that
/// message verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    message: String,
}

impl ToolError {
    /// Creates an error carrying `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for ToolError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: width/fill flags on the outer formatter are not applied.
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
60
/// Failure that aborts a conversation turn.
///
/// Carries only a human-readable message; its `Display` output is that
/// message verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeError {
    message: String,
}

impl RuntimeError {
    /// Creates an error carrying `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for RuntimeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: width/fill flags on the outer formatter are not applied.
        f.write_str(&self.message)
    }
}

impl std::error::Error for RuntimeError {}
82
83#[derive(Debug, Clone, PartialEq, Eq)]
84pub struct TurnSummary {
85 pub assistant_messages: Vec<ConversationMessage>,
86 pub tool_results: Vec<ConversationMessage>,
87 pub iterations: usize,
88 pub usage: TokenUsage,
89}
90
91pub struct ConversationRuntime<C, T> {
92 session: Session,
93 api_client: C,
94 tool_executor: T,
95 permission_policy: PermissionPolicy,
96 system_prompt: Vec<String>,
97 max_iterations: usize,
98 usage_tracker: UsageTracker,
99 hook_runner: HookRunner,
100}
101
102impl<C, T> ConversationRuntime<C, T>
103where
104 C: ApiClient,
105 T: ToolExecutor,
106{
107 #[must_use]
108 pub fn new(
109 session: Session,
110 api_client: C,
111 tool_executor: T,
112 permission_policy: PermissionPolicy,
113 system_prompt: Vec<String>,
114 ) -> Self {
115 let default_config = RuntimeFeatureConfig::default();
116 Self::new_with_features(
117 session,
118 api_client,
119 tool_executor,
120 permission_policy,
121 system_prompt,
122 &default_config,
123 )
124 }
125
126 #[must_use]
127 pub fn new_with_features(
128 session: Session,
129 api_client: C,
130 tool_executor: T,
131 permission_policy: PermissionPolicy,
132 system_prompt: Vec<String>,
133 feature_config: &RuntimeFeatureConfig,
134 ) -> Self {
135 let usage_tracker = UsageTracker::from_session(&session);
136 Self {
137 session,
138 api_client,
139 tool_executor,
140 permission_policy,
141 system_prompt,
142 max_iterations: usize::MAX,
143 usage_tracker,
144 hook_runner: HookRunner::from_feature_config(feature_config),
145 }
146 }
147
148 #[must_use]
149 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
150 self.max_iterations = max_iterations;
151 self
152 }
153
154 pub fn run_turn(
155 &mut self,
156 user_input: impl Into<String>,
157 mut prompter: Option<&mut dyn PermissionPrompter>,
158 ) -> Result<TurnSummary, RuntimeError> {
159 self.session
160 .messages
161 .push(ConversationMessage::user_text(user_input.into()));
162
163 let mut assistant_messages = Vec::new();
164 let mut tool_results = Vec::new();
165 let mut iterations = 0;
166
167 loop {
168 iterations += 1;
169 if iterations > self.max_iterations {
170 return Err(RuntimeError::new(
171 "conversation loop exceeded the maximum number of iterations",
172 ));
173 }
174
175 let request = ApiRequest {
176 system_prompt: self.system_prompt.clone(),
177 messages: self.session.messages.clone(),
178 };
179 let events = self.api_client.stream(request)?;
180 let (assistant_message, usage) = build_assistant_message(events)?;
181 if let Some(usage) = usage {
182 self.usage_tracker.record(usage);
183 }
184 let pending_tool_uses = assistant_message
185 .blocks
186 .iter()
187 .filter_map(|block| match block {
188 ContentBlock::ToolUse { id, name, input } => {
189 Some((id.clone(), name.clone(), input.clone()))
190 }
191 _ => None,
192 })
193 .collect::<Vec<_>>();
194
195 self.session.messages.push(assistant_message.clone());
196 assistant_messages.push(assistant_message);
197
198 if pending_tool_uses.is_empty() {
199 break;
200 }
201
202 for (tool_use_id, tool_name, input) in pending_tool_uses {
203 let permission_outcome = if let Some(prompt) = prompter.as_mut() {
204 self.permission_policy
205 .authorize(&tool_name, &input, Some(*prompt))
206 } else {
207 self.permission_policy.authorize(&tool_name, &input, None)
208 };
209
210 let result_message = match permission_outcome {
211 PermissionOutcome::Allow => {
212 let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input);
213 if pre_hook_result.is_denied() {
214 let deny_message = format!("PreToolUse hook denied tool `{tool_name}`");
215 ConversationMessage::tool_result(
216 tool_use_id,
217 tool_name,
218 format_hook_message(&pre_hook_result, &deny_message),
219 true,
220 )
221 } else {
222 let (mut output, mut is_error) =
223 match self.tool_executor.execute(&tool_name, &input) {
224 Ok(output) => (output, false),
225 Err(error) => (error.to_string(), true),
226 };
227 output = merge_hook_feedback(pre_hook_result.messages(), output, false);
228
229 let post_hook_result = self
230 .hook_runner
231 .run_post_tool_use(&tool_name, &input, &output, is_error);
232 if post_hook_result.is_denied() {
233 is_error = true;
234 }
235 output = merge_hook_feedback(
236 post_hook_result.messages(),
237 output,
238 post_hook_result.is_denied(),
239 );
240
241 ConversationMessage::tool_result(
242 tool_use_id,
243 tool_name,
244 output,
245 is_error,
246 )
247 }
248 }
249 PermissionOutcome::Deny { reason } => {
250 ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
251 }
252 };
253 self.session.messages.push(result_message.clone());
254 tool_results.push(result_message);
255 }
256 }
257
258 Ok(TurnSummary {
259 assistant_messages,
260 tool_results,
261 iterations,
262 usage: self.usage_tracker.cumulative_usage(),
263 })
264 }
265
266 #[must_use]
267 pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
268 compact_session(&self.session, config)
269 }
270
271 #[must_use]
272 pub fn estimated_tokens(&self) -> usize {
273 estimate_session_tokens(&self.session)
274 }
275
276 #[must_use]
277 pub fn usage(&self) -> &UsageTracker {
278 &self.usage_tracker
279 }
280
281 #[must_use]
282 pub fn session(&self) -> &Session {
283 &self.session
284 }
285
286 #[must_use]
287 pub fn into_session(self) -> Session {
288 self.session
289 }
290}
291
292fn build_assistant_message(
293 events: Vec<AssistantEvent>,
294) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
295 let mut text = String::new();
296 let mut blocks = Vec::new();
297 let mut finished = false;
298 let mut usage = None;
299
300 for event in events {
301 match event {
302 AssistantEvent::TextDelta(delta) => text.push_str(&delta),
303 AssistantEvent::ToolUse { id, name, input } => {
304 flush_text_block(&mut text, &mut blocks);
305 blocks.push(ContentBlock::ToolUse { id, name, input });
306 }
307 AssistantEvent::Usage(value) => usage = Some(value),
308 AssistantEvent::MessageStop => {
309 finished = true;
310 }
311 }
312 }
313
314 flush_text_block(&mut text, &mut blocks);
315
316 if !finished {
317 return Err(RuntimeError::new(
318 "assistant stream ended without a message stop event",
319 ));
320 }
321 if blocks.is_empty() {
322 return Err(RuntimeError::new("assistant stream produced no content"));
323 }
324
325 Ok((
326 ConversationMessage::assistant_with_usage(blocks, usage),
327 usage,
328 ))
329}
330
331fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
332 if !text.is_empty() {
333 blocks.push(ContentBlock::Text {
334 text: std::mem::take(text),
335 });
336 }
337}
338
339fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
340 if result.messages().is_empty() {
341 fallback.to_string()
342 } else {
343 result.messages().join("\n")
344 }
345}
346
/// Appends hook feedback to a tool's output.
///
/// With no `messages`, `output` is returned untouched. Otherwise the messages
/// are joined by newlines under a `Hook feedback:` heading (or
/// `Hook feedback (denied):` when `denied` is set) and placed after `output`,
/// separated by a blank line. An `output` that is empty or whitespace-only is
/// dropped, so the result is the feedback section alone.
fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
    if messages.is_empty() {
        return output;
    }

    let label = if denied {
        "Hook feedback (denied)"
    } else {
        "Hook feedback"
    };
    let feedback = format!("{label}:\n{}", messages.join("\n"));

    if output.trim().is_empty() {
        return feedback;
    }

    // Reuse the output buffer instead of collecting sections into a Vec.
    let mut merged = output;
    merged.push_str("\n\n");
    merged.push_str(&feedback);
    merged
}
364
365type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
366
367#[derive(Default)]
368pub struct StaticToolExecutor {
369 handlers: BTreeMap<String, ToolHandler>,
370}
371
372impl StaticToolExecutor {
373 #[must_use]
374 pub fn new() -> Self {
375 Self::default()
376 }
377
378 #[must_use]
379 pub fn register(
380 mut self,
381 tool_name: impl Into<String>,
382 handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
383 ) -> Self {
384 self.handlers.insert(tool_name.into(), Box::new(handler));
385 self
386 }
387}
388
389impl ToolExecutor for StaticToolExecutor {
390 fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
391 self.handlers
392 .get_mut(tool_name)
393 .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::{
        ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
        StaticToolExecutor,
    };
    use crate::compact::CompactionConfig;
    use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
    use crate::permissions::{
        PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
        PermissionRequest,
    };
    use crate::prompt::{ProjectContext, SystemPromptBuilder};
    use crate::session::{ContentBlock, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::path::PathBuf;

    /// API client that answers the first call with a tool request and the
    /// second with a final text reply; any further call is an error.
    struct ScriptedApiClient {
        call_count: usize,
    }

    impl ApiClient for ScriptedApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.call_count += 1;
            match self.call_count {
                1 => {
                    assert!(request
                        .messages
                        .iter()
                        .any(|message| message.role == MessageRole::User));
                    Ok(vec![
                        AssistantEvent::TextDelta("Let me calculate that.".to_string()),
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: "2,2".to_string(),
                        },
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 20,
                            output_tokens: 6,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 2,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                2 => {
                    let last_message = request
                        .messages
                        .last()
                        .expect("tool result should be present");
                    assert_eq!(last_message.role, MessageRole::Tool);
                    Ok(vec![
                        AssistantEvent::TextDelta("The answer is 4.".to_string()),
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 24,
                            output_tokens: 4,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 3,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                _ => Err(RuntimeError::new("unexpected extra API call")),
            }
        }
    }

    /// Prompter that allows the `add` tool and asserts nothing else is asked for.
    struct PromptAllowOnce;

    impl PermissionPrompter for PromptAllowOnce {
        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
            assert_eq!(request.tool_name, "add");
            PermissionPromptDecision::Allow
        }
    }

    #[test]
    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
        let api_client = ScriptedApiClient { call_count: 0 };
        let tool_executor = StaticToolExecutor::new().register("add", |input| {
            let total = input
                .split(',')
                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                .sum::<i32>();
            Ok(total.to_string())
        });
        let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
        let system_prompt = SystemPromptBuilder::new()
            .with_project_context(ProjectContext {
                cwd: PathBuf::from("/tmp/project"),
                current_date: "2026-03-31".to_string(),
                git_status: None,
                git_diff: None,
                instruction_files: Vec::new(),
            })
            .with_os("linux", "6.8")
            .build();
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            api_client,
            tool_executor,
            permission_policy,
            system_prompt,
        );

        let summary = runtime
            .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
            .expect("conversation loop should succeed");

        assert_eq!(summary.iterations, 2);
        assert_eq!(summary.assistant_messages.len(), 2);
        assert_eq!(summary.tool_results.len(), 1);
        // user, assistant (text + tool use), tool result, assistant (final text)
        assert_eq!(runtime.session().messages.len(), 4);
        assert_eq!(summary.usage.output_tokens, 10);
        assert!(matches!(
            runtime.session().messages[1].blocks[1],
            ContentBlock::ToolUse { .. }
        ));
        assert!(matches!(
            runtime.session().messages[2].blocks[0],
            ContentBlock::ToolResult {
                is_error: false,
                ..
            }
        ));
    }

    #[test]
    fn records_denied_tool_results_when_prompt_rejects() {
        struct RejectPrompter;
        impl PermissionPrompter for RejectPrompter {
            fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
                PermissionPromptDecision::Deny {
                    reason: "not now".to_string(),
                }
            }
        }

        // Requests a tool until a tool result shows up in the history, then
        // finishes with plain text.
        struct SingleCallApiClient;
        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                if request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool)
                {
                    return Ok(vec![
                        AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: "secret".to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::WorkspaceWrite),
            vec!["system".to_string()],
        );

        let summary = runtime
            .run_turn("use the tool", Some(&mut RejectPrompter))
            .expect("conversation should continue after denied tool");

        assert_eq!(summary.tool_results.len(), 1);
        assert!(matches!(
            &summary.tool_results[0].blocks[0],
            ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
        ));
    }

    #[test]
    fn denies_tool_use_when_pre_tool_hook_blocks() {
        struct SingleCallApiClient;
        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                if request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool)
                {
                    return Ok(vec![
                        AssistantEvent::TextDelta("blocked".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: r#"{"path":"secret.txt"}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }

        // The pre-tool hook prints a message and exits with status 2.
        let feature_config = RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
            vec![shell_snippet("printf 'blocked by hook'; exit 2")],
            Vec::new(),
        ));

        let mut runtime = ConversationRuntime::new_with_features(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new().register("blocked", |_input| {
                panic!("tool should not execute when hook denies")
            }),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
            &feature_config,
        );

        let summary = runtime
            .run_turn("use the tool", None)
            .expect("conversation should continue after hook denial");

        assert_eq!(summary.tool_results.len(), 1);
        let ContentBlock::ToolResult {
            is_error, output, ..
        } = &summary.tool_results[0].blocks[0]
        else {
            panic!("expected tool result block");
        };
        assert!(
            *is_error,
            "hook denial should produce an error result: {output}"
        );
        assert!(
            output.contains("denied tool") || output.contains("blocked by hook"),
            "unexpected hook denial output: {output:?}"
        );
    }

    #[test]
    fn appends_post_tool_hook_feedback_to_tool_result() {
        struct TwoCallApiClient {
            calls: usize,
        }

        impl ApiClient for TwoCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                self.calls += 1;
                match self.calls {
                    1 => Ok(vec![
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: r#"{"lhs":2,"rhs":2}"#.to_string(),
                        },
                        AssistantEvent::MessageStop,
                    ]),
                    2 => {
                        assert!(request
                            .messages
                            .iter()
                            .any(|message| message.role == MessageRole::Tool));
                        Ok(vec![
                            AssistantEvent::TextDelta("done".to_string()),
                            AssistantEvent::MessageStop,
                        ])
                    }
                    _ => Err(RuntimeError::new("unexpected extra API call")),
                }
            }
        }

        let feature_config = RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
            vec![shell_snippet("printf 'pre hook ran'")],
            vec![shell_snippet("printf 'post hook ran'")],
        ));

        let mut runtime = ConversationRuntime::new_with_features(
            Session::new(),
            TwoCallApiClient { calls: 0 },
            StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
            &feature_config,
        );

        let summary = runtime
            .run_turn("use add", None)
            .expect("tool loop succeeds");

        assert_eq!(summary.tool_results.len(), 1);
        let ContentBlock::ToolResult {
            is_error, output, ..
        } = &summary.tool_results[0].blocks[0]
        else {
            panic!("expected tool result block");
        };
        assert!(
            !*is_error,
            "post hook should preserve non-error result: {output:?}"
        );
        assert!(
            output.contains('4'),
            "tool output missing value: {output:?}"
        );
        assert!(
            output.contains("pre hook ran"),
            "tool output missing pre hook feedback: {output:?}"
        );
        assert!(
            output.contains("post hook ran"),
            "tool output missing post hook feedback: {output:?}"
        );
    }

    #[test]
    fn reconstructs_usage_tracker_from_restored_session() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        // A session that already holds one assistant message with usage.
        let mut session = Session::new();
        session
            .messages
            .push(crate::session::ConversationMessage::assistant_with_usage(
                vec![ContentBlock::Text {
                    text: "earlier".to_string(),
                }],
                Some(TokenUsage {
                    input_tokens: 11,
                    output_tokens: 7,
                    cache_creation_input_tokens: 2,
                    cache_read_input_tokens: 1,
                }),
            ));

        let runtime = ConversationRuntime::new(
            session,
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );

        assert_eq!(runtime.usage().turns(), 1);
        assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
    }

    #[test]
    fn compacts_session_after_turns() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );
        runtime.run_turn("a", None).expect("turn a");
        runtime.run_turn("b", None).expect("turn b");
        runtime.run_turn("c", None).expect("turn c");

        let result = runtime.compact(CompactionConfig {
            preserve_recent_messages: 2,
            max_estimated_tokens: 1,
        });
        assert!(result.summary.contains("Conversation summary"));
        assert_eq!(
            result.compacted_session.messages[0].role,
            MessageRole::System
        );
    }

    /// Adapts a POSIX-style hook snippet for Windows by swapping single
    /// quotes for double quotes.
    #[cfg(windows)]
    fn shell_snippet(script: &str) -> String {
        script.replace('\'', "\"")
    }

    /// Returns the hook snippet unchanged on non-Windows targets.
    #[cfg(not(windows))]
    fn shell_snippet(script: &str) -> String {
        script.to_string()
    }
}