1use crate::config::constants::tools;
14use hashbrown::HashMap;
15use std::path::PathBuf;
16use std::sync::Arc;
17
18use anyhow::Result;
19use async_trait::async_trait;
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22pub use vtcode_utility_tool_specs::{
23 AdditionalProperties, FreeformTool, FreeformToolFormat, JsonSchema, ResponsesApiTool,
24};
25
/// Category a tool handler declares via `ToolHandler::kind`.
///
/// The default `ToolHandler::matches_kind` pairs each kind with the
/// `ToolPayload` variant of the same name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
    /// Paired with `ToolPayload::Function`.
    Function,
    /// Paired with `ToolPayload::Mcp`.
    Mcp,
    /// Paired with `ToolPayload::Custom`.
    Custom,
}
36
/// Input carried by a single tool call; the variant determines its shape.
#[derive(Clone, Debug)]
pub enum ToolPayload {
    /// Function-style call with its arguments as a string
    /// (presumably JSON-encoded; confirm against the caller).
    Function { arguments: String },
    /// Custom tool call carrying raw string input.
    Custom { input: String },
    /// MCP tool call with optional JSON arguments.
    Mcp { arguments: Option<Value> },
    /// Local shell call described by `ShellToolCallParams`.
    LocalShell { params: ShellToolCallParams },
}
49
/// Serializable parameters of a local shell tool call.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ShellToolCallParams {
    /// Command as a list of strings (presumably argv-style: program then arguments; confirm).
    pub command: Vec<String>,
    /// Optional working directory, as a string.
    pub workdir: Option<String>,
    /// Optional timeout (presumably milliseconds, per the field name; confirm).
    pub timeout_ms: Option<u64>,
    /// Optional sandbox permission request for this call.
    pub sandbox_permissions: Option<SandboxPermissions>,
    /// Optional free-text justification accompanying the call.
    pub justification: Option<String>,
}
59
/// Sandbox permission mode attached to a shell call.
///
/// Variants (de)serialize in snake_case (`use_default`, `require_escalated`,
/// `with_additional_permissions`); `UseDefault` is the `Default` value.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SandboxPermissions {
    /// The default variant.
    #[default]
    UseDefault,
    /// NOTE(review): presumably requests running outside the default sandbox; confirm.
    RequireEscalated,
    /// NOTE(review): presumably requests extra permissions on top of the default; confirm.
    WithAdditionalPermissions,
}
69
/// Result produced by a tool handler.
#[derive(Clone, Debug)]
pub enum ToolOutput {
    /// Output of a function-style tool.
    Function {
        /// Textual content of the output.
        content: String,
        /// Optional structured content items accompanying `content`.
        content_items: Option<Vec<ContentItem>>,
        /// Explicit success flag; `is_success` treats `None` as success.
        success: Option<bool>,
    },
    /// Output of an MCP tool, wrapping the MCP result as-is.
    Mcp { result: McpToolResult },
}
82
83impl ToolOutput {
84 pub fn simple(content: impl Into<String>) -> Self {
86 Self::Function {
87 content: content.into(),
88 content_items: None,
89 success: Some(true),
90 }
91 }
92
93 pub fn with_success(content: impl Into<String>, success: bool) -> Self {
95 Self::Function {
96 content: content.into(),
97 content_items: None,
98 success: Some(success),
99 }
100 }
101
102 pub fn error(message: impl Into<String>) -> Self {
104 Self::Function {
105 content: message.into(),
106 content_items: None,
107 success: Some(false),
108 }
109 }
110
111 pub fn content(&self) -> Option<&str> {
113 match self {
114 Self::Function { content, .. } => Some(content),
115 Self::Mcp { result } => result.content.first().and_then(|c| c.as_text()),
116 }
117 }
118
119 pub fn is_success(&self) -> bool {
121 match self {
122 Self::Function { success, .. } => success.unwrap_or(true),
123 Self::Mcp { result } => !result.is_error.unwrap_or(false),
124 }
125 }
126}
127
/// One piece of structured tool output.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (`text`, `image`, `resource`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentItem {
    /// Plain text content.
    Text {
        text: String,
    },
    /// Image content; `data` is a string (presumably base64-encoded; confirm).
    Image {
        data: String,
        mime_type: String,
    },
    /// Reference to a resource by URI, with an optional MIME type.
    Resource {
        uri: String,
        mime_type: Option<String>,
    },
}
144
145impl ContentItem {
146 pub fn as_text(&self) -> Option<&str> {
147 match self {
148 ContentItem::Text { text } => Some(text),
149 _ => None,
150 }
151 }
152}
153
/// Serializable result of an MCP tool call.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct McpToolResult {
    /// Content items returned by the tool, in order.
    pub content: Vec<ContentItem>,
    /// Error flag; `ToolOutput::is_success` treats `None` as "not an error".
    pub is_error: Option<bool>,
}
160
/// Everything passed to `ToolHandler::handle` for a single tool call.
pub struct ToolInvocation {
    /// Session the call runs in, shared behind an `Arc`.
    pub session: Arc<dyn ToolSession>,
    /// Context of the turn this call belongs to.
    pub turn: Arc<TurnContext>,
    /// Optional shared diff tracker.
    pub tracker: Option<SharedDiffTracker>,
    /// Identifier of this call.
    pub call_id: String,
    /// Name of the tool being invoked.
    pub tool_name: String,
    /// Input of the call.
    pub payload: ToolPayload,
}
170
/// A `DiffTracker` shared via `Arc` and guarded by a Tokio async mutex.
pub type SharedDiffTracker = Arc<tokio::sync::Mutex<DiffTracker>>;
173
/// Wrapper around a value of type `T`.
///
/// The only constructor visible here, [`Constrained::allow_any`], stores the
/// value unchanged; no validation is performed by this type as written.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Constrained<T> {
    value: T,
}

impl<T> Constrained<T> {
    /// Wraps `initial_value` without applying any restriction.
    pub fn allow_any(initial_value: T) -> Self {
        let value = initial_value;
        Self { value }
    }

    /// Borrows the wrapped value.
    pub fn get(&self) -> &T {
        let Self { value } = self;
        value
    }
}

/// Read-only access to the wrapped value through `*constrained`.
impl<T> std::ops::Deref for Constrained<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<T: Copy> Constrained<T> {
    /// Returns a copy of the wrapped value.
    pub fn value(&self) -> T {
        *self.get()
    }
}

/// The default wraps `T::default()`.
impl<T: Default> Default for Constrained<T> {
    fn default() -> Self {
        Self {
            value: T::default(),
        }
    }
}
215
/// Session-level services available to a tool handler during a call.
///
/// Implementors must be `Send + Sync`; `ToolInvocation` holds the session as
/// `Arc<dyn ToolSession>`.
// NOTE(review): returning `&Path` instead of `&PathBuf` would be more
// idiomatic, but changing the signatures would break existing implementors.
#[async_trait]
pub trait ToolSession: Send + Sync {
    /// Returns the session's current working directory.
    fn cwd(&self) -> &PathBuf;

    /// Returns the workspace root directory.
    fn workspace_root(&self) -> &PathBuf;

    /// Records a warning message for the session.
    async fn record_warning(&self, message: String);

    /// Returns the user's shell as a string (name vs. path not visible here; confirm).
    fn user_shell(&self) -> &str;

    /// Sends a tool lifecycle event.
    async fn send_event(&self, event: ToolEvent);
}
234
/// Per-turn settings handed to tool handlers.
#[derive(Clone, Debug)]
pub struct TurnContext {
    /// Working directory; the base for `resolve_path` / `resolve_path_ref`.
    pub cwd: PathBuf,
    /// Identifier of the turn.
    pub turn_id: String,
    /// Optional secondary identifier (meaning not visible here; confirm).
    pub sub_id: Option<String>,
    /// Environment policy for shell execution.
    pub shell_environment_policy: ShellEnvironmentPolicy,
    /// Approval policy for the turn.
    pub approval_policy: Constrained<ApprovalPolicy>,
    /// Optional path to a Linux sandbox executable (usage not visible here; confirm).
    pub codex_linux_sandbox_exe: Option<PathBuf>,
    /// Sandbox policy for the turn.
    pub sandbox_policy: Constrained<super::sandboxing::SandboxPolicy>,
}
247
248impl TurnContext {
249 pub fn resolve_path(&self, path: Option<String>) -> PathBuf {
251 self.resolve_path_ref(path.as_deref())
252 }
253
254 pub fn resolve_path_ref(&self, path: Option<&str>) -> PathBuf {
256 match path {
257 Some(p) => {
258 let path = PathBuf::from(p);
259 if path.is_absolute() {
260 path
261 } else {
262 self.cwd.join(path)
263 }
264 }
265 None => self.cwd.clone(),
266 }
267 }
268}
269
/// Policy describing the environment for shell execution.
///
/// How each variant is applied is decided by code outside this block; the
/// variant notes below are inferred from the names and should be confirmed.
#[derive(Clone, Debug, Default)]
pub enum ShellEnvironmentPolicy {
    /// The default variant (presumably: inherit the parent environment).
    #[default]
    Inherit,
    /// Presumably: start from an empty environment.
    Clean,
    /// Carries an explicit map of string keys to string values
    /// (presumably variable name to value).
    Custom(HashMap<String, String>),
}
278
/// Approval policy for tool calls; `Never` is the `Default` value.
///
/// Enforcement lives outside this block; the variant notes are inferred from
/// the names and should be confirmed.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ApprovalPolicy {
    /// The default variant (presumably: never ask for approval).
    #[default]
    Never,
    /// Presumably: ask only for calls a handler reports as mutating.
    OnMutation,
    /// Presumably: ask for every call.
    Always,
}
287
/// Accumulates file changes keyed by path.
#[derive(Default, Debug)]
pub struct DiffTracker {
    /// Recorded change per path; a later entry for the same path replaces the earlier one.
    pub changes: HashMap<PathBuf, FileChange>,
}
293
294impl DiffTracker {
295 pub fn on_patch_begin(&mut self, changes: &HashMap<PathBuf, FileChange>) {
296 self.changes.extend(changes.clone());
297 }
298
299 pub fn on_patch_end(&mut self, success: bool) {
300 if !success {
301 self.changes.clear();
302 }
303 }
304}
305
/// A change to a single file.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (`add`, `delete`, `update`, `rename`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FileChange {
    /// File added with the given content.
    Add {
        content: String,
    },
    /// File deleted.
    Delete,
    /// File content replaced; both the old and the new content are carried.
    Update {
        old_content: String,
        new_content: String,
    },
    /// File moved to `new_path`, optionally carrying content.
    Rename {
        new_path: PathBuf,
        content: Option<String>,
    },
}
323
/// Lifecycle event passed to `ToolSession::send_event`.
#[derive(Clone, Debug)]
pub enum ToolEvent {
    /// A tool call has started.
    Begin(ToolEventBegin),
    /// A tool call finished successfully.
    Success(ToolEventSuccess),
    /// A tool call failed.
    Failure(ToolEventFailure),
    /// A patch application is starting.
    PatchApplyBegin(PatchApplyBeginEvent),
    /// A patch application has finished.
    PatchApplyEnd(PatchApplyEndEvent),
}
333
/// Payload of `ToolEvent::Begin`.
#[derive(Clone, Debug)]
pub struct ToolEventBegin {
    /// Identifier of the tool call.
    pub call_id: String,
    /// Name of the tool being invoked.
    pub tool_name: String,
    /// Identifier of the turn the call belongs to.
    pub turn_id: String,
}
340
/// Payload of `ToolEvent::Success`.
#[derive(Clone, Debug)]
pub struct ToolEventSuccess {
    /// Identifier of the tool call.
    pub call_id: String,
    /// Output of the call, as text.
    pub output: String,
}
346
/// Payload of `ToolEvent::Failure`.
#[derive(Clone, Debug)]
pub struct ToolEventFailure {
    /// Identifier of the tool call.
    pub call_id: String,
    /// Error description, as text.
    pub error: String,
}
352
/// Payload of `ToolEvent::PatchApplyBegin`.
#[derive(Clone, Debug)]
pub struct PatchApplyBeginEvent {
    /// Identifier of the tool call.
    pub call_id: String,
    /// Identifier of the turn the call belongs to.
    pub turn_id: String,
    /// Changes contained in the patch, keyed by path.
    pub changes: HashMap<PathBuf, FileChange>,
    /// Flag recording whether the patch was auto-approved
    /// (exact approval semantics are defined by the emitter; confirm).
    pub auto_approved: bool,
}
360
/// Payload of `ToolEvent::PatchApplyEnd`.
#[derive(Clone, Debug)]
pub struct PatchApplyEndEvent {
    /// Identifier of the tool call.
    pub call_id: String,
    /// Whether the patch application succeeded.
    pub success: bool,
    /// Text reported as standard output.
    pub stdout: String,
    /// Text reported as standard error.
    pub stderr: String,
}
368
/// Error returned by `ToolHandler::handle`.
///
/// `Display` output comes from the `#[error(...)]` attribute on each variant.
#[derive(Debug, thiserror::Error)]
pub enum ToolCallError {
    /// Tool-level error carrying a message (presumably relayed to the model,
    /// per the variant name; confirm against the caller).
    #[error("Tool error: {0}")]
    RespondToModel(String),

    /// Internal failure; `#[from]` lets `?` convert an `anyhow::Error` into this variant.
    #[error("Internal error: {0}")]
    Internal(#[from] anyhow::Error),

    /// The call was rejected; the payload is the reason.
    #[error("Tool rejected: {0}")]
    Rejected(String),

    /// The call timed out; the payload is a duration in milliseconds.
    #[error("Tool timed out after {0}ms")]
    Timeout(u64),
}
388
389impl ToolCallError {
390 pub fn respond(message: impl Into<String>) -> Self {
392 Self::RespondToModel(message.into())
393 }
394}
395
396impl From<super::sandboxing::ToolError> for ToolCallError {
397 fn from(err: super::sandboxing::ToolError) -> Self {
398 match err {
399 super::sandboxing::ToolError::Rejected(msg) => ToolCallError::Rejected(msg),
400 super::sandboxing::ToolError::Codex(e) => ToolCallError::Internal(e),
401 super::sandboxing::ToolError::SandboxDenied(msg) => {
402 ToolCallError::Rejected(format!("Sandbox denied: {}", msg))
403 }
404 super::sandboxing::ToolError::Timeout(ms) => ToolCallError::Timeout(ms),
405 }
406 }
407}
408
/// Interface implemented by every tool handler.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ToolHandler: Send + Sync {
    /// Returns the kind of tool this handler serves.
    fn kind(&self) -> ToolKind;

    /// Reports whether `payload` is compatible with this handler's kind.
    ///
    /// The default pairs `Function`, `Mcp` and `Custom` kinds with the payload
    /// variant of the same name. `ToolPayload::LocalShell` is never matched by
    /// this default, so a handler that accepts it must override this method.
    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            (self.kind(), payload),
            (ToolKind::Function, ToolPayload::Function { .. })
                | (ToolKind::Mcp, ToolPayload::Mcp { .. })
                | (ToolKind::Custom, ToolPayload::Custom { .. })
        )
    }

    /// Reports whether handling `_invocation` would mutate state.
    ///
    /// The default returns `false`; what counts as mutation is up to the implementor.
    async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool {
        false
    }

    /// Executes the tool call, consuming `invocation`.
    ///
    /// # Errors
    ///
    /// Returns a `ToolCallError` when the call fails.
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, ToolCallError>;
}
438
/// Specification of a tool.
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (`function`, `freeform`, `web_search`, `local_shell`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolSpec {
    /// Function tool described by a `ResponsesApiTool`.
    Function(ResponsesApiTool),
    /// Freeform tool described by a `FreeformTool`.
    Freeform(FreeformTool),
    /// Web search tool; carries no fields.
    WebSearch {},
    /// Local shell tool; carries no fields.
    LocalShell {},
}
448
449impl ToolSpec {
450 pub fn name(&self) -> &str {
451 match self {
452 ToolSpec::Function(tool) => &tool.name,
453 ToolSpec::Freeform(tool) => &tool.name,
454 ToolSpec::WebSearch {} => tools::WEB_SEARCH,
455 ToolSpec::LocalShell {} => "local_shell",
456 }
457 }
458}
459
/// A `ToolSpec` paired with a flag for parallel tool-call support.
#[derive(Clone, Debug)]
pub struct ConfiguredToolSpec {
    /// The tool specification.
    pub spec: ToolSpec,
    /// Whether the tool supports parallel tool calls.
    pub supports_parallel_tool_calls: bool,
}
466
467impl ConfiguredToolSpec {
468 pub fn new(spec: ToolSpec, supports_parallel: bool) -> Self {
469 Self {
470 spec,
471 supports_parallel_tool_calls: supports_parallel,
472 }
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TurnContext` rooted at `/workspace` with default policies,
    /// shared by the path-resolution tests.
    fn workspace_context() -> TurnContext {
        TurnContext {
            cwd: PathBuf::from("/workspace"),
            turn_id: "test".to_string(),
            sub_id: None,
            shell_environment_policy: ShellEnvironmentPolicy::default(),
            approval_policy: Constrained::allow_any(ApprovalPolicy::default()),
            codex_linux_sandbox_exe: None,
            sandbox_policy: Constrained::allow_any(Default::default()),
        }
    }

    #[test]
    fn test_tool_output_simple() {
        let produced = ToolOutput::simple("Hello, world!");

        assert!(produced.is_success());
        assert_eq!(produced.content(), Some("Hello, world!"));
    }

    #[test]
    fn test_tool_output_error() {
        let produced = ToolOutput::error("Something went wrong");

        assert!(!produced.is_success());
        assert_eq!(produced.content(), Some("Something went wrong"));
    }

    #[test]
    fn test_sandbox_permissions_default() {
        assert_eq!(
            SandboxPermissions::default(),
            SandboxPermissions::UseDefault
        );
    }

    #[test]
    fn test_turn_context_resolve_path_absolute() {
        let requested = Some("/absolute/path".to_string());

        let resolved = workspace_context().resolve_path(requested);

        assert_eq!(resolved, PathBuf::from("/absolute/path"));
    }

    #[test]
    fn test_turn_context_resolve_path_relative() {
        let requested = Some("relative/path".to_string());

        let resolved = workspace_context().resolve_path(requested);

        assert_eq!(resolved, PathBuf::from("/workspace/relative/path"));
    }
}