1use crate::config::constants::tools;
10use crate::types::CompactStr;
11use hashbrown::HashMap;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15use super::sandboxing::{ExecToolCallOutput, ToolError};
16use super::tool_handler::{
17 DiffTracker, FileChange, PatchApplyBeginEvent, PatchApplyEndEvent, ToolCallError, ToolEvent,
18 ToolEventBegin, ToolEventFailure, ToolEventSuccess, ToolSession, TurnContext,
19};
20
21pub struct ToolEventCtx<'a> {
23 pub session: &'a dyn ToolSession,
24 pub turn: &'a TurnContext,
25 pub call_id: &'a str,
26 pub turn_diff_tracker: Option<&'a Arc<tokio::sync::Mutex<DiffTracker>>>,
27}
28
29impl<'a> ToolEventCtx<'a> {
30 pub fn new(
31 session: &'a dyn ToolSession,
32 turn: &'a TurnContext,
33 call_id: &'a str,
34 tracker: Option<&'a Arc<tokio::sync::Mutex<DiffTracker>>>,
35 ) -> Self {
36 Self {
37 session,
38 turn,
39 call_id,
40 turn_diff_tracker: tracker,
41 }
42 }
43}
44
45#[derive(Clone, Debug)]
47pub enum ToolEventStage {
48 Begin,
49 Success(ExecToolCallOutput),
50 Failure(ToolEventFailureKind),
51}
52
53#[derive(Clone, Debug)]
55pub enum ToolEventFailureKind {
56 Output(ExecToolCallOutput),
57 Message(String),
58 Error(String),
59}
60
/// Origin of an executed command. Defaults to [`ExecCommandSource::Agent`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ExecCommandSource {
    /// Issued by the agent (the default).
    #[default]
    Agent,
    /// Issued by the user — presumably directly; confirm with callers.
    User,
    /// Issued while starting a unified-exec session (inferred from the name; confirm).
    UnifiedExecStartup,
    /// Issued as stdin input to a unified-exec session (inferred from the name; confirm).
    UnifiedExecWriteStdin,
}
70
/// A command split into its program and the arguments that follow it.
///
/// Produced by [`parse_command`]. `Default` is the empty command
/// (empty `program`, no `args`), which is what `parse_command(&[])` returns.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ParsedCommand {
    /// The first element of the command.
    pub program: String,
    /// All elements after the first, in order.
    pub args: Vec<String>,
}
77
78pub fn parse_command(command: &[String]) -> ParsedCommand {
80 let program = command.first().cloned().unwrap_or_default();
81 let args = command.get(1..).map(|s| s.to_vec()).unwrap_or_default();
82 ParsedCommand { program, args }
83}
84
85#[derive(Clone, Debug)]
87pub enum ToolEmitter {
88 Shell {
90 command: Vec<String>,
91 cwd: PathBuf,
92 source: ExecCommandSource,
93 parsed_cmd: ParsedCommand,
94 freeform: bool,
95 },
96 ApplyPatch {
98 changes: HashMap<PathBuf, FileChange>,
99 auto_approved: bool,
100 },
101 UnifiedExec {
103 command: Vec<String>,
104 cwd: PathBuf,
105 source: ExecCommandSource,
106 interaction_input: Option<String>,
107 parsed_cmd: ParsedCommand,
108 process_id: Option<String>,
109 },
110 Generic { tool_name: CompactStr },
112}
113
114impl ToolEmitter {
115 pub fn shell(
117 command: Vec<String>,
118 cwd: PathBuf,
119 source: ExecCommandSource,
120 freeform: bool,
121 ) -> Self {
122 let parsed_cmd = parse_command(&command);
123 Self::Shell {
124 command,
125 cwd,
126 source,
127 parsed_cmd,
128 freeform,
129 }
130 }
131
132 pub fn apply_patch(changes: HashMap<PathBuf, FileChange>, auto_approved: bool) -> Self {
134 Self::ApplyPatch {
135 changes,
136 auto_approved,
137 }
138 }
139
140 pub fn unified_exec(
142 command: &[String],
143 cwd: PathBuf,
144 source: ExecCommandSource,
145 process_id: Option<String>,
146 ) -> Self {
147 let parsed_cmd = parse_command(command);
148 Self::UnifiedExec {
149 command: command.to_vec(),
150 cwd,
151 source,
152 interaction_input: None,
153 parsed_cmd,
154 process_id,
155 }
156 }
157
158 pub fn generic(tool_name: impl Into<CompactStr>) -> Self {
160 Self::Generic {
161 tool_name: tool_name.into(),
162 }
163 }
164
165 pub async fn emit(&self, ctx: ToolEventCtx<'_>, stage: ToolEventStage) {
167 match (self, &stage) {
168 (
170 Self::ApplyPatch {
171 changes,
172 auto_approved,
173 },
174 ToolEventStage::Begin,
175 ) => {
176 if let Some(tracker) = ctx.turn_diff_tracker {
178 let mut guard = tracker.lock().await;
179 guard.on_patch_begin(changes);
180 }
181
182 let event = ToolEvent::PatchApplyBegin(PatchApplyBeginEvent {
183 call_id: ctx.call_id.to_string(),
184 turn_id: ctx.turn.turn_id.clone(),
185 changes: changes.clone(),
186 auto_approved: *auto_approved,
187 });
188 ctx.session.send_event(event).await;
189 }
190
191 (Self::ApplyPatch { changes: _, .. }, ToolEventStage::Success(output)) => {
193 self.emit_patch_end(ctx, output.stdout.clone(), output.stderr.clone(), true)
194 .await;
195 }
196
197 (
199 Self::ApplyPatch { .. },
200 ToolEventStage::Failure(ToolEventFailureKind::Output(output)),
201 ) => {
202 self.emit_patch_end(ctx, output.stdout.clone(), output.stderr.clone(), false)
203 .await;
204 }
205
206 (
207 Self::ApplyPatch { .. },
208 ToolEventStage::Failure(ToolEventFailureKind::Message(msg)),
209 ) => {
210 self.emit_patch_end(ctx, String::new(), msg.clone(), false)
211 .await;
212 }
213
214 (Self::Shell { .. } | Self::UnifiedExec { .. }, ToolEventStage::Begin) => {
216 let event = ToolEvent::Begin(ToolEventBegin {
217 call_id: ctx.call_id.to_string(),
218 tool_name: self.tool_name().into(),
219 turn_id: ctx.turn.turn_id.clone(),
220 });
221 ctx.session.send_event(event).await;
222 }
223
224 (Self::Shell { .. } | Self::UnifiedExec { .. }, ToolEventStage::Success(output)) => {
226 let event = ToolEvent::Success(ToolEventSuccess {
227 call_id: ctx.call_id.to_string(),
228 output: output.combined_output(),
229 });
230 ctx.session.send_event(event).await;
231 }
232
233 (Self::Shell { .. } | Self::UnifiedExec { .. }, ToolEventStage::Failure(kind)) => {
235 let error = match kind {
236 ToolEventFailureKind::Output(output) => output.combined_output(),
237 ToolEventFailureKind::Message(msg) => msg.clone(),
238 ToolEventFailureKind::Error(err) => err.clone(),
239 };
240 let event = ToolEvent::Failure(ToolEventFailure {
241 call_id: ctx.call_id.to_string(),
242 error,
243 });
244 ctx.session.send_event(event).await;
245 }
246
247 (Self::Generic { tool_name }, ToolEventStage::Begin) => {
249 let event = ToolEvent::Begin(ToolEventBegin {
250 call_id: ctx.call_id.to_string(),
251 tool_name: tool_name.to_string(),
252 turn_id: ctx.turn.turn_id.clone(),
253 });
254 ctx.session.send_event(event).await;
255 }
256
257 (Self::Generic { .. }, ToolEventStage::Success(output)) => {
258 let event = ToolEvent::Success(ToolEventSuccess {
259 call_id: ctx.call_id.to_string(),
260 output: output.combined_output(),
261 });
262 ctx.session.send_event(event).await;
263 }
264
265 (Self::Generic { .. }, ToolEventStage::Failure(kind)) => {
266 let error = match kind {
267 ToolEventFailureKind::Output(output) => output.combined_output(),
268 ToolEventFailureKind::Message(msg) => msg.clone(),
269 ToolEventFailureKind::Error(err) => err.clone(),
270 };
271 let event = ToolEvent::Failure(ToolEventFailure {
272 call_id: ctx.call_id.to_string(),
273 error,
274 });
275 ctx.session.send_event(event).await;
276 }
277
278 _ => {}
279 }
280 }
281
282 pub async fn begin(&self, ctx: ToolEventCtx<'_>) {
284 self.emit(ctx, ToolEventStage::Begin).await;
285 }
286
287 pub async fn finish(
289 &self,
290 ctx: ToolEventCtx<'_>,
291 result: Result<ExecToolCallOutput, ToolError>,
292 ) -> Result<String, ToolCallError> {
293 match result {
294 Ok(output) => {
295 self.emit(ctx, ToolEventStage::Success(output.clone()))
296 .await;
297 Ok(self.format_output_for_model(&output))
298 }
299 Err(ToolError::Rejected(msg)) => {
300 self.emit(
301 ctx,
302 ToolEventStage::Failure(ToolEventFailureKind::Message(msg.clone())),
303 )
304 .await;
305 Err(ToolCallError::Rejected(msg))
306 }
307 Err(ToolError::Timeout(ms)) => {
308 let msg = format!("Command timed out after {}ms", ms);
309 self.emit(
310 ctx,
311 ToolEventStage::Failure(ToolEventFailureKind::Message(msg.clone())),
312 )
313 .await;
314 Err(ToolCallError::Timeout(ms))
315 }
316 Err(e) => {
317 let msg = e.to_string();
318 self.emit(
319 ctx,
320 ToolEventStage::Failure(ToolEventFailureKind::Error(msg.clone())),
321 )
322 .await;
323 Err(ToolCallError::Internal(e.into()))
324 }
325 }
326 }
327
328 fn format_output_for_model(&self, output: &ExecToolCallOutput) -> String {
330 let mut result = String::new();
331
332 if !output.stdout.is_empty() {
333 result.push_str(&output.stdout);
334 }
335
336 if !output.stderr.is_empty() {
337 if !result.is_empty() {
338 result.push_str("\n\n[stderr]\n");
339 }
340 result.push_str(&output.stderr);
341 }
342
343 if output.exit_code != 0 {
344 if !result.is_empty() {
345 result.push('\n');
346 }
347 result.push_str(&format!("[exit code: {}]", output.exit_code));
348 }
349
350 if result.is_empty() {
351 result.push_str("[no output]");
352 }
353
354 result
355 }
356
357 fn tool_name(&self) -> CompactStr {
359 match self {
360 Self::Shell { .. } => CompactStr::from(tools::SHELL),
361 Self::ApplyPatch { .. } => CompactStr::from(tools::APPLY_PATCH),
362 Self::UnifiedExec { .. } => CompactStr::from(tools::UNIFIED_EXEC),
363 Self::Generic { tool_name } => tool_name.clone(),
364 }
365 }
366
367 async fn emit_patch_end(
369 &self,
370 ctx: ToolEventCtx<'_>,
371 stdout: String,
372 stderr: String,
373 success: bool,
374 ) {
375 {
377 if let Some(tracker) = ctx.turn_diff_tracker {
378 let mut guard = tracker.lock().await;
379 guard.on_patch_end(success);
380 }
381
382 let event = ToolEvent::PatchApplyEnd(PatchApplyEndEvent {
383 call_id: ctx.call_id.to_string(),
384 success,
385 stdout,
386 stderr,
387 });
388 ctx.session.send_event(event).await;
389 }
390 }
391}
392
393#[derive(Clone, Debug)]
395pub struct ExecCommandInput<'a> {
396 pub command: &'a [String],
397 pub cwd: &'a std::path::Path,
398 pub parsed_cmd: &'a ParsedCommand,
399 pub source: ExecCommandSource,
400 pub timeout_ms: Option<u64>,
401 pub justification: Option<&'a str>,
402}
403
404impl<'a> ExecCommandInput<'a> {
405 pub fn new(
406 command: &'a [String],
407 cwd: &'a std::path::Path,
408 parsed_cmd: &'a ParsedCommand,
409 source: ExecCommandSource,
410 timeout_ms: Option<u64>,
411 justification: Option<&'a str>,
412 ) -> Self {
413 Self {
414 command,
415 cwd,
416 parsed_cmd,
417 source,
418 timeout_ms,
419 justification,
420 }
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_command() {
        let cmd = vec!["ls".to_string(), "-la".to_string(), "/tmp".to_string()];
        let parsed = parse_command(&cmd);

        assert_eq!(parsed.program, "ls");
        assert_eq!(parsed.args, vec!["-la", "/tmp"]);
    }

    #[test]
    fn test_parse_command_empty() {
        let cmd: Vec<String> = vec![];
        let parsed = parse_command(&cmd);

        assert_eq!(parsed.program, "");
        assert!(parsed.args.is_empty());
    }

    #[test]
    fn test_parse_command_program_only() {
        let parsed = parse_command(&["pwd".to_string()]);

        assert_eq!(parsed.program, "pwd");
        assert!(parsed.args.is_empty());
    }

    #[test]
    fn test_emitter_tool_names() {
        let shell = ToolEmitter::shell(
            vec!["ls".to_string()],
            PathBuf::new(),
            ExecCommandSource::Agent,
            false,
        );
        assert_eq!(shell.tool_name(), "shell");

        let patch = ToolEmitter::apply_patch(HashMap::new(), true);
        assert_eq!(patch.tool_name(), "apply_patch");

        let exec = ToolEmitter::unified_exec(
            &["echo".to_string()],
            PathBuf::new(),
            ExecCommandSource::Agent,
            None,
        );
        assert_eq!(exec.tool_name(), "unified_exec");

        let generic = ToolEmitter::generic("custom_tool");
        assert_eq!(generic.tool_name(), "custom_tool");
    }

    #[test]
    fn test_format_output_for_model() {
        let emitter = ToolEmitter::generic("test");

        // stdout only, success: returned verbatim.
        let output = ExecToolCallOutput {
            stdout: "Hello, world!".to_string(),
            stderr: String::new(),
            exit_code: 0,
        };
        assert_eq!(emitter.format_output_for_model(&output), "Hello, world!");

        // stderr only: no `[stderr]` header because stdout is empty.
        let output = ExecToolCallOutput {
            stdout: String::new(),
            stderr: "Error!".to_string(),
            exit_code: 1,
        };
        assert_eq!(
            emitter.format_output_for_model(&output),
            "Error!\n[exit code: 1]"
        );

        let output = ExecToolCallOutput::default();
        assert_eq!(emitter.format_output_for_model(&output), "[no output]");
    }

    #[test]
    fn test_format_output_stdout_stderr_and_exit_code() {
        let emitter = ToolEmitter::generic("test");

        // Both streams present: stderr gets its own header section.
        let output = ExecToolCallOutput {
            stdout: "out".to_string(),
            stderr: "err".to_string(),
            exit_code: 2,
        };
        assert_eq!(
            emitter.format_output_for_model(&output),
            "out\n\n[stderr]\nerr\n[exit code: 2]"
        );

        // Only a non-zero exit code: no leading newline.
        let output = ExecToolCallOutput {
            stdout: String::new(),
            stderr: String::new(),
            exit_code: 3,
        };
        assert_eq!(emitter.format_output_for_model(&output), "[exit code: 3]");
    }
}