1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::time::Duration;
9
10use glob::glob as glob_match;
11use regex::{Regex, RegexBuilder};
12use serde_json::Value;
13use tokio::io::AsyncReadExt;
14use tokio::process::Command;
15use tracing::{debug, warn};
16
17use crate::error::{AgentError, Result};
18use crate::types::tools::{
19 BashInput, FileEditInput, FileReadInput, FileWriteInput, GlobInput, GrepInput, GrepOutputMode,
20};
21
22#[derive(Debug, Clone)]
24pub struct ToolResult {
25 pub content: String,
27 pub is_error: bool,
29 pub raw_content: Option<serde_json::Value>,
32}
33
34impl ToolResult {
35 fn ok(content: String) -> Self {
37 Self {
38 content,
39 is_error: false,
40 raw_content: None,
41 }
42 }
43
44 fn err(content: String) -> Self {
46 Self {
47 content,
48 is_error: true,
49 raw_content: None,
50 }
51 }
52}
53
54async fn drain_pipe<R: tokio::io::AsyncRead + Unpin>(handle: Option<R>) -> Vec<u8> {
58 let mut buf = Vec::new();
59 let Some(mut reader) = handle else {
60 return buf;
61 };
62 let mut chunk = [0u8; 65536];
63 loop {
64 match tokio::time::timeout(Duration::from_millis(10), reader.read(&mut chunk)).await {
65 Ok(Ok(0)) => break, Ok(Ok(n)) => buf.extend_from_slice(&chunk[..n]),
67 Ok(Err(_)) => break, Err(_) => break, }
70 }
71 buf
72}
73
/// Executes the built-in tools (`Read`, `Write`, `Edit`, `Bash`, `Glob`, `Grep`).
pub struct ToolExecutor {
    /// Base directory: relative tool paths are joined onto it and Bash commands run in it.
    cwd: PathBuf,
    /// Optional path restriction for file tools; `None` disables boundary checks.
    boundary: Option<PathBoundary>,
    /// Names of environment variables removed from spawned Bash processes.
    env_blocklist: Vec<String>,
}
87
88struct PathBoundary {
97 allowed: Vec<PathBuf>,
99}
100
101impl PathBoundary {
102 fn new(cwd: &Path, additional: &[PathBuf]) -> Self {
104 let mut allowed = Vec::with_capacity(1 + additional.len());
105
106 let push_canon = |dirs: &mut Vec<PathBuf>, p: &Path| {
109 dirs.push(p.canonicalize().unwrap_or_else(|_| p.to_path_buf()));
110 };
111
112 push_canon(&mut allowed, cwd);
113 for dir in additional {
114 push_canon(&mut allowed, dir);
115 }
116
117 Self { allowed }
118 }
119
120 fn check(&self, path: &Path) -> std::result::Result<(), ToolResult> {
128 let normalized = Self::normalize(path)?;
129
130 for allowed in &self.allowed {
131 if normalized.starts_with(allowed) {
132 return Ok(());
133 }
134 }
135
136 Err(ToolResult::err(format!(
137 "Access denied: {} is outside the allowed directories",
138 path.display()
139 )))
140 }
141
142 fn normalize(path: &Path) -> std::result::Result<PathBuf, ToolResult> {
149 if let Ok(canon) = path.canonicalize() {
151 return Ok(canon);
152 }
153
154 let mut remaining = Vec::new();
156 let mut ancestor = path.to_path_buf();
157
158 loop {
159 if ancestor.exists() {
160 let base = ancestor.canonicalize().map_err(|_| {
161 ToolResult::err(format!(
162 "Access denied: cannot resolve {}",
163 path.display()
164 ))
165 })?;
166
167 let mut result = base;
169 for component in remaining.iter().rev() {
170 result = result.join(component);
171 }
172 return Ok(result);
173 }
174
175 match ancestor.file_name() {
176 Some(name) => {
177 let name = name.to_os_string();
178 remaining.push(name);
179 if !ancestor.pop() {
180 break;
181 }
182 }
183 None => break,
184 }
185 }
186
187 Err(ToolResult::err(format!(
189 "Access denied: cannot resolve {}",
190 path.display()
191 )))
192 }
193}
194
195impl ToolExecutor {
196 pub fn new(cwd: PathBuf) -> Self {
198 Self {
199 cwd,
200 boundary: None,
201 env_blocklist: Vec::new(),
202 }
203 }
204
205 pub fn with_allowed_dirs(cwd: PathBuf, additional: Vec<PathBuf>) -> Self {
208 let boundary = PathBoundary::new(&cwd, &additional);
209 Self {
210 cwd,
211 boundary: Some(boundary),
212 env_blocklist: Vec::new(),
213 }
214 }
215
216 pub fn with_env_blocklist(mut self, blocklist: Vec<String>) -> Self {
218 self.env_blocklist = blocklist;
219 self
220 }
221
222 pub async fn execute(&self, tool_name: &str, input: Value) -> Result<ToolResult> {
224 debug!(tool = tool_name, "executing built-in tool");
225
226 match tool_name {
227 "Read" => {
228 let params: FileReadInput = serde_json::from_value(input)?;
229 self.execute_read(¶ms).await
230 }
231 "Write" => {
232 let params: FileWriteInput = serde_json::from_value(input)?;
233 self.execute_write(¶ms).await
234 }
235 "Edit" => {
236 let params: FileEditInput = serde_json::from_value(input)?;
237 self.execute_edit(¶ms).await
238 }
239 "Bash" => {
240 let params: BashInput = serde_json::from_value(input)?;
241 self.execute_bash(¶ms).await
242 }
243 "Glob" => {
244 let params: GlobInput = serde_json::from_value(input)?;
245 self.execute_glob(¶ms).await
246 }
247 "Grep" => {
248 let params: GrepInput = serde_json::from_value(input)?;
249 self.execute_grep(¶ms).await
250 }
251 _ => Err(AgentError::ToolExecution(format!(
252 "unsupported built-in tool: {tool_name}"
253 ))),
254 }
255 }
256
257 fn resolve_and_check(&self, path: &str) -> std::result::Result<PathBuf, ToolResult> {
262 let p = Path::new(path);
263 let resolved = if p.is_absolute() {
264 p.to_path_buf()
265 } else {
266 self.cwd.join(p)
267 };
268
269 if let Some(ref boundary) = self.boundary {
270 boundary.check(&resolved)?;
271 }
272
273 Ok(resolved)
274 }
275
276 async fn execute_read(&self, input: &FileReadInput) -> Result<ToolResult> {
283 let path = match self.resolve_and_check(&input.file_path) {
284 Ok(p) => p,
285 Err(denied) => return Ok(denied),
286 };
287
288 let ext = path
290 .extension()
291 .unwrap_or_default()
292 .to_string_lossy()
293 .to_lowercase();
294
295 let media_type = match ext.as_str() {
296 "png" => Some("image/png"),
297 "jpg" | "jpeg" => Some("image/jpeg"),
298 "gif" => Some("image/gif"),
299 "webp" => Some("image/webp"),
300 _ => None,
301 };
302
303 if let Some(media_type) = media_type {
304 return self.read_image(&path, media_type).await;
305 }
306
307 let content = match tokio::fs::read_to_string(&path).await {
308 Ok(c) => c,
309 Err(e) => {
310 return Ok(ToolResult::err(format!(
311 "Failed to read {}: {e}",
312 path.display()
313 )));
314 }
315 };
316
317 let lines: Vec<&str> = content.lines().collect();
318 let total = lines.len();
319
320 let offset = input.offset.unwrap_or(0) as usize;
321 let limit = input.limit.unwrap_or(total as u64) as usize;
322
323 if offset >= total {
324 return Ok(ToolResult::ok(String::new()));
325 }
326
327 let end = (offset + limit).min(total);
328 let selected = &lines[offset..end];
329
330 let width = format!("{}", end).len();
332 let mut output = String::new();
333 for (i, line) in selected.iter().enumerate() {
334 let line_no = offset + i + 1; output.push_str(&format!("{line_no:>width$}\t{line}\n", width = width));
336 }
337
338 Ok(ToolResult::ok(output))
339 }
340
341 async fn read_image(&self, path: &Path, media_type: &str) -> Result<ToolResult> {
343 let bytes = match tokio::fs::read(path).await {
344 Ok(b) => b,
345 Err(e) => {
346 return Ok(ToolResult::err(format!(
347 "Failed to read {}: {e}",
348 path.display()
349 )));
350 }
351 };
352
353 if bytes.len() > 20 * 1024 * 1024 {
355 return Ok(ToolResult::err(format!(
356 "Image too large ({:.1} MB, max 20 MB): {}",
357 bytes.len() as f64 / (1024.0 * 1024.0),
358 path.display()
359 )));
360 }
361
362 use base64::Engine;
363 let b64 = base64::engine::general_purpose::STANDARD.encode(&bytes);
364
365 Ok(ToolResult {
366 content: format!("Image: {}", path.display()),
367 is_error: false,
368 raw_content: Some(serde_json::json!([
369 {
370 "type": "image",
371 "source": {
372 "type": "base64",
373 "media_type": media_type,
374 "data": b64,
375 }
376 }
377 ])),
378 })
379 }
380
381 async fn execute_write(&self, input: &FileWriteInput) -> Result<ToolResult> {
385 let path = match self.resolve_and_check(&input.file_path) {
386 Ok(p) => p,
387 Err(denied) => return Ok(denied),
388 };
389
390 if let Some(parent) = path.parent() {
392 if let Err(e) = tokio::fs::create_dir_all(parent).await {
393 return Ok(ToolResult::err(format!(
394 "Failed to create directories for {}: {e}",
395 path.display()
396 )));
397 }
398 }
399
400 match tokio::fs::write(&path, &input.content).await {
401 Ok(()) => Ok(ToolResult::ok(format!(
402 "Successfully wrote to {}",
403 path.display()
404 ))),
405 Err(e) => Ok(ToolResult::err(format!(
406 "Failed to write {}: {e}",
407 path.display()
408 ))),
409 }
410 }
411
412 async fn execute_edit(&self, input: &FileEditInput) -> Result<ToolResult> {
420 let path = match self.resolve_and_check(&input.file_path) {
421 Ok(p) => p,
422 Err(denied) => return Ok(denied),
423 };
424
425 let content = match tokio::fs::read_to_string(&path).await {
426 Ok(c) => c,
427 Err(e) => {
428 return Ok(ToolResult::err(format!(
429 "Failed to read {}: {e}",
430 path.display()
431 )));
432 }
433 };
434
435 let replace_all = input.replace_all.unwrap_or(false);
436 let count = content.matches(&input.old_string).count();
437
438 if count == 0 {
439 return Ok(ToolResult::err(format!(
440 "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.",
441 path.display()
442 )));
443 }
444
445 if count > 1 && !replace_all {
446 return Ok(ToolResult::err(format!(
447 "old_string found {count} times in {}. Provide more surrounding context to make it unique, or set replace_all to true.",
448 path.display()
449 )));
450 }
451
452 let new_content = if replace_all {
453 content.replace(&input.old_string, &input.new_string)
454 } else {
455 content.replacen(&input.old_string, &input.new_string, 1)
457 };
458
459 match tokio::fs::write(&path, &new_content).await {
460 Ok(()) => {
461 let replacements = if replace_all {
462 format!("{count} replacement(s)")
463 } else {
464 "1 replacement".to_string()
465 };
466 Ok(ToolResult::ok(format!(
467 "Successfully edited {} ({replacements})",
468 path.display()
469 )))
470 }
471 Err(e) => Ok(ToolResult::err(format!(
472 "Failed to write {}: {e}",
473 path.display()
474 ))),
475 }
476 }
477
478 async fn execute_bash(&self, input: &BashInput) -> Result<ToolResult> {
483 if input.run_in_background == Some(true) {
485 let mut cmd = Command::new("/bin/bash");
486 cmd.arg("-c")
487 .arg(&input.command)
488 .current_dir(&self.cwd)
489 .env("HOME", &self.cwd)
490 .stdout(std::process::Stdio::null())
491 .stderr(std::process::Stdio::null());
492 for key in &self.env_blocklist {
493 cmd.env_remove(key);
494 }
495 let child = cmd.spawn();
496
497 return match child {
498 Ok(child) => {
499 let pid = child.id().unwrap_or(0);
500 Ok(ToolResult {
501 content: format!("Process started in background (pid: {pid})"),
502 is_error: false,
503 raw_content: None,
504 })
505 }
506 Err(e) => Ok(ToolResult::err(format!("Failed to spawn process: {e}"))),
507 };
508 }
509
510 let timeout_ms = input.timeout.unwrap_or(120_000);
511 let timeout_dur = Duration::from_millis(timeout_ms);
512
513 let mut cmd = Command::new("/bin/bash");
514 cmd.arg("-c")
515 .arg(&input.command)
516 .current_dir(&self.cwd)
517 .env("HOME", &self.cwd)
518 .stdout(std::process::Stdio::piped())
519 .stderr(std::process::Stdio::piped());
520 for key in &self.env_blocklist {
521 cmd.env_remove(key);
522 }
523 let child = cmd.spawn();
524
525 let mut child = match child {
526 Ok(c) => c,
527 Err(e) => {
528 return Ok(ToolResult::err(format!("Failed to spawn process: {e}")));
529 }
530 };
531
532 let stdout_handle = child.stdout.take();
534 let stderr_handle = child.stderr.take();
535
536 let wait_result = tokio::time::timeout(timeout_dur, child.wait()).await;
538
539 match wait_result {
540 Ok(Ok(status)) => {
541 let (stdout_bytes, stderr_bytes) = tokio::join!(
545 drain_pipe(stdout_handle),
546 drain_pipe(stderr_handle),
547 );
548
549 let stdout = String::from_utf8_lossy(&stdout_bytes);
550 let stderr = String::from_utf8_lossy(&stderr_bytes);
551
552 let mut combined = String::new();
553 if !stdout.is_empty() {
554 combined.push_str(&stdout);
555 }
556 if !stderr.is_empty() {
557 if !combined.is_empty() {
558 combined.push('\n');
559 }
560 combined.push_str(&stderr);
561 }
562
563 let is_error = !status.success();
564 if is_error && combined.is_empty() {
565 combined = format!(
566 "Process exited with code {}",
567 status.code().unwrap_or(-1)
568 );
569 }
570
571 Ok(ToolResult {
572 content: combined,
573 is_error,
574 raw_content: None,
575 })
576 }
577 Ok(Err(e)) => Ok(ToolResult::err(format!("Process IO error: {e}"))),
578 Err(_) => {
579 let _ = child.kill().await;
581 Ok(ToolResult::err(format!(
582 "Command timed out after {timeout_ms}ms"
583 )))
584 }
585 }
586 }
587
588 async fn execute_glob(&self, input: &GlobInput) -> Result<ToolResult> {
592 let base = match &input.path {
593 Some(p) => match self.resolve_and_check(p) {
594 Ok(resolved) => resolved,
595 Err(denied) => return Ok(denied),
596 },
597 None => self.cwd.clone(),
598 };
599
600 let full_pattern = base.join(&input.pattern);
601 let pattern_str = full_pattern.to_string_lossy().to_string();
602
603 let result = tokio::task::spawn_blocking(move || -> std::result::Result<Vec<String>, String> {
605 let entries = glob_match(&pattern_str).map_err(|e| format!("Invalid glob pattern: {e}"))?;
606
607 let mut paths: Vec<String> = Vec::new();
608 for entry in entries {
609 match entry {
610 Ok(p) => paths.push(p.to_string_lossy().to_string()),
611 Err(e) => {
612 warn!("glob entry error: {e}");
613 }
614 }
615 }
616 paths.sort();
617 Ok(paths)
618 })
619 .await
620 .map_err(|e| AgentError::ToolExecution(format!("glob task panicked: {e}")))?;
621
622 match result {
623 Ok(paths) => {
624 if paths.is_empty() {
625 Ok(ToolResult::ok("No files matched the pattern.".to_string()))
626 } else {
627 Ok(ToolResult::ok(paths.join("\n")))
628 }
629 }
630 Err(e) => Ok(ToolResult::err(e)),
631 }
632 }
633
634 async fn execute_grep(&self, input: &GrepInput) -> Result<ToolResult> {
639 if let Some(ref p) = input.path {
640 if let Err(denied) = self.resolve_and_check(p) {
641 return Ok(denied);
642 }
643 }
644
645 let input = input.clone();
646 let cwd = self.cwd.clone();
647
648 let result = tokio::task::spawn_blocking(move || grep_sync(&input, &cwd))
650 .await
651 .map_err(|e| AgentError::ToolExecution(format!("grep task panicked: {e}")))?;
652
653 result
654 }
655}
656
/// Maps a file-type name (e.g. `"rust"`, `"py"`, `"cpp"`) to the file
/// extensions it covers, or `None` when the name is not recognised.
///
/// Matching is exact and case-sensitive.
fn extensions_for_type(file_type: &str) -> Option<Vec<&'static str>> {
    let extensions: &[&'static str] = match file_type {
        "rust" | "rs" => &["rs"],
        "py" | "python" => &["py", "pyi"],
        "js" => &["js", "mjs", "cjs"],
        "ts" => &["ts", "tsx", "mts", "cts"],
        "go" => &["go"],
        "java" => &["java"],
        "c" => &["c", "h"],
        "cpp" => &["cpp", "cxx", "cc", "hpp", "hxx", "hh", "h"],
        "rb" | "ruby" => &["rb"],
        "html" => &["html", "htm"],
        "css" => &["css"],
        "json" => &["json"],
        "yaml" => &["yaml", "yml"],
        "toml" => &["toml"],
        "md" => &["md", "markdown"],
        "sh" => &["sh", "bash", "zsh"],
        "sql" => &["sql"],
        "xml" => &["xml"],
        "swift" => &["swift"],
        "kt" => &["kt", "kts"],
        "scala" => &["scala"],
        _ => return None,
    };
    Some(extensions.to_vec())
}
689
690fn matches_file_filter(path: &Path, glob_filter: &Option<glob::Pattern>, type_exts: &Option<Vec<&str>>) -> bool {
692 if let Some(pat) = glob_filter {
693 let name = path.file_name().unwrap_or_default().to_string_lossy();
694 if !pat.matches(&name) {
695 return false;
696 }
697 }
698 if let Some(exts) = type_exts {
699 let ext = path
700 .extension()
701 .unwrap_or_default()
702 .to_string_lossy()
703 .to_lowercase();
704 if !exts.contains(&ext.as_str()) {
705 return false;
706 }
707 }
708 true
709}
710
711fn walk_files(dir: &Path) -> Vec<PathBuf> {
713 let mut files = Vec::new();
714 walk_files_recursive(dir, &mut files);
715 files.sort();
716 files
717}
718
/// Appends every regular file beneath `dir` to `out`, depth-first.
///
/// Entries whose name starts with `.` and directories named `node_modules`
/// or `target` are skipped. Unreadable directories and entries are ignored.
///
/// Symlinked directories are not descended into: following them lets a link
/// cycle (for example a link pointing at an ancestor) re-walk the same tree
/// over and over until the OS symlink limit is reached. Symlinks to regular
/// files are still reported.
fn walk_files_recursive(dir: &Path, out: &mut Vec<PathBuf>) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    for entry in entries.flatten() {
        let name = entry.file_name();
        let name = name.to_string_lossy();
        if name.starts_with('.') || name == "node_modules" || name == "target" {
            continue;
        }

        let path = entry.path();
        // `DirEntry::file_type` does not follow symlinks, unlike `Path::is_dir`.
        let is_symlink = entry
            .file_type()
            .map(|kind| kind.is_symlink())
            .unwrap_or(false);

        if path.is_dir() {
            if !is_symlink {
                walk_files_recursive(&path, out);
            }
        } else if path.is_file() {
            out.push(path);
        }
    }
}
741
742fn grep_sync(input: &GrepInput, cwd: &Path) -> Result<ToolResult> {
744 let output_mode = input
745 .output_mode
746 .clone()
747 .unwrap_or(GrepOutputMode::FilesWithMatches);
748 let case_insensitive = input.case_insensitive.unwrap_or(false);
749 let show_line_numbers = input.line_numbers.unwrap_or(true);
750 let multiline = input.multiline.unwrap_or(false);
751
752 let context_lines = input.context.or(input.context_alias);
754 let before_context = input.before_context.or(context_lines).unwrap_or(0) as usize;
755 let after_context = input.after_context.or(context_lines).unwrap_or(0) as usize;
756
757 let head_limit = input.head_limit.unwrap_or(0) as usize;
758 let offset = input.offset.unwrap_or(0) as usize;
759
760 let re = RegexBuilder::new(&input.pattern)
762 .case_insensitive(case_insensitive)
763 .multi_line(multiline)
764 .dot_matches_new_line(multiline)
765 .build()?;
766
767 let search_path = match &input.path {
769 Some(p) => {
770 let resolved = if Path::new(p).is_absolute() {
771 PathBuf::from(p)
772 } else {
773 cwd.join(p)
774 };
775 resolved
776 }
777 None => cwd.to_path_buf(),
778 };
779
780 let glob_filter = input.glob.as_ref().map(|g| {
782 glob::Pattern::new(g).unwrap_or_else(|_| glob::Pattern::new("*").unwrap())
783 });
784 let type_exts = input.file_type.as_ref().and_then(|t| {
785 extensions_for_type(t).map(|v| v.into_iter().collect::<Vec<_>>())
786 });
787
788 let files = if search_path.is_file() {
790 vec![search_path.clone()]
791 } else {
792 walk_files(&search_path)
793 };
794
795 let files: Vec<PathBuf> = files
797 .into_iter()
798 .filter(|f| matches_file_filter(f, &glob_filter, &type_exts))
799 .collect();
800
801 match output_mode {
802 GrepOutputMode::FilesWithMatches => {
803 grep_files_with_matches(&re, &files, offset, head_limit)
804 }
805 GrepOutputMode::Count => grep_count(&re, &files, offset, head_limit),
806 GrepOutputMode::Content => grep_content(
807 &re,
808 &files,
809 before_context,
810 after_context,
811 show_line_numbers,
812 offset,
813 head_limit,
814 ),
815 }
816}
817
818fn grep_files_with_matches(
819 re: &Regex,
820 files: &[PathBuf],
821 offset: usize,
822 head_limit: usize,
823) -> Result<ToolResult> {
824 let mut matched: Vec<String> = Vec::new();
825 for file in files {
826 if let Ok(content) = std::fs::read_to_string(file) {
827 if re.is_match(&content) {
828 matched.push(file.to_string_lossy().to_string());
829 }
830 }
831 }
832
833 let result = apply_offset_limit(matched, offset, head_limit);
834 if result.is_empty() {
835 Ok(ToolResult::ok("No matches found.".to_string()))
836 } else {
837 Ok(ToolResult::ok(result.join("\n")))
838 }
839}
840
841fn grep_count(
842 re: &Regex,
843 files: &[PathBuf],
844 offset: usize,
845 head_limit: usize,
846) -> Result<ToolResult> {
847 let mut entries: Vec<String> = Vec::new();
848 for file in files {
849 if let Ok(content) = std::fs::read_to_string(file) {
850 let count = re.find_iter(&content).count();
851 if count > 0 {
852 entries.push(format!("{}:{count}", file.to_string_lossy()));
853 }
854 }
855 }
856
857 let result = apply_offset_limit(entries, offset, head_limit);
858 if result.is_empty() {
859 Ok(ToolResult::ok("No matches found.".to_string()))
860 } else {
861 Ok(ToolResult::ok(result.join("\n")))
862 }
863}
864
865fn grep_content(
866 re: &Regex,
867 files: &[PathBuf],
868 before_context: usize,
869 after_context: usize,
870 show_line_numbers: bool,
871 offset: usize,
872 head_limit: usize,
873) -> Result<ToolResult> {
874 let mut output_lines: Vec<String> = Vec::new();
875
876 for file in files {
877 let content = match std::fs::read_to_string(file) {
878 Ok(c) => c,
879 Err(_) => continue,
880 };
881
882 let lines: Vec<&str> = content.lines().collect();
883 let file_display = file.to_string_lossy();
884
885 let mut matching_line_indices: Vec<usize> = Vec::new();
887 for (i, line) in lines.iter().enumerate() {
888 if re.is_match(line) {
889 matching_line_indices.push(i);
890 }
891 }
892
893 if matching_line_indices.is_empty() {
894 continue;
895 }
896
897 let mut display_set = Vec::new();
899 for &idx in &matching_line_indices {
900 let start = idx.saturating_sub(before_context);
901 let end = (idx + after_context + 1).min(lines.len());
902 for i in start..end {
903 display_set.push(i);
904 }
905 }
906 display_set.sort();
907 display_set.dedup();
908
909 let mut prev: Option<usize> = None;
911 for &line_idx in &display_set {
912 if let Some(p) = prev {
913 if line_idx > p + 1 {
914 output_lines.push("--".to_string());
915 }
916 }
917
918 let line_content = lines[line_idx];
919 if show_line_numbers {
920 let sep = if matching_line_indices.contains(&line_idx) {
921 ':'
922 } else {
923 '-'
924 };
925 output_lines.push(format!(
926 "{file_display}{sep}{}{sep}{line_content}",
927 line_idx + 1
928 ));
929 } else {
930 output_lines.push(format!("{file_display}:{line_content}"));
931 }
932
933 prev = Some(line_idx);
934 }
935 }
936
937 let result = apply_offset_limit(output_lines, offset, head_limit);
938 if result.is_empty() {
939 Ok(ToolResult::ok("No matches found.".to_string()))
940 } else {
941 Ok(ToolResult::ok(result.join("\n")))
942 }
943}
944
/// Skips the first `offset` items, then keeps at most `head_limit` of the rest.
///
/// A `head_limit` of 0 means "no limit".
fn apply_offset_limit(items: Vec<String>, offset: usize, head_limit: usize) -> Vec<String> {
    let keep = match head_limit {
        0 => usize::MAX,
        n => n,
    };
    items.into_iter().skip(offset).take(keep).collect()
}
954
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tempfile::TempDir;

    /// Unsandboxed executor rooted in a fresh temp dir. The `TempDir` is
    /// returned so the directory outlives the test body.
    fn setup() -> (TempDir, ToolExecutor) {
        let tmp = TempDir::new().unwrap();
        let executor = ToolExecutor::new(tmp.path().to_path_buf());
        (tmp, executor)
    }

    // ── Read: text files ────────────────────────────────────────────────

    #[tokio::test]
    async fn read_text_file() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("hello.txt");
        std::fs::write(&file, "line one\nline two\nline three\n").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.raw_content.is_none());
        assert!(result.content.contains("line one"));
        assert!(result.content.contains("line three"));
    }

    #[tokio::test]
    async fn read_text_file_with_offset_and_limit() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("lines.txt");
        std::fs::write(&file, "a\nb\nc\nd\ne\n").unwrap();

        let result = executor
            .execute(
                "Read",
                json!({ "file_path": file.to_str().unwrap(), "offset": 1, "limit": 2 }),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        // Offset is zero-based, so offset 1 / limit 2 selects lines "b" and "c".
        assert!(result.content.contains("b"));
        assert!(result.content.contains("c"));
        assert!(!result.content.contains("a"));
        assert!(!result.content.contains("d"));
    }

    #[tokio::test]
    async fn read_missing_file_returns_error() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("nope.txt");

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.content.contains("Failed to read"));
    }

    // ── Read: images ────────────────────────────────────────────────────

    #[tokio::test]
    async fn read_png_returns_image_content_block() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("test.png");
        let png_bytes = b"\x89PNG\r\n\x1a\nfake-png-payload";
        std::fs::write(&file, png_bytes).unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.raw_content.is_some(), "image should set raw_content");

        let blocks = result.raw_content.unwrap();
        let block = blocks.as_array().unwrap().first().unwrap();
        assert_eq!(block["type"], "image");
        assert_eq!(block["source"]["type"], "base64");
        assert_eq!(block["source"]["media_type"], "image/png");
        let data = block["source"]["data"].as_str().unwrap();
        // The payload must round-trip through base64 unchanged.
        assert!(!data.is_empty());
        use base64::Engine;
        let decoded = base64::engine::general_purpose::STANDARD.decode(data).unwrap();
        assert_eq!(decoded, png_bytes);
    }

    #[tokio::test]
    async fn read_jpeg_returns_image_content_block() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("photo.jpg");
        let fake_jpeg = b"\xFF\xD8\xFF\xE0fake-jpeg-data";
        std::fs::write(&file, fake_jpeg).unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(!result.is_error);
        let blocks = result.raw_content.unwrap();
        let block = blocks.as_array().unwrap().first().unwrap();
        assert_eq!(block["source"]["media_type"], "image/jpeg");
    }

    #[tokio::test]
    async fn read_jpeg_extension_detected() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("photo.jpeg");
        std::fs::write(&file, b"data").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(!result.is_error);
        let blocks = result.raw_content.unwrap();
        assert_eq!(blocks[0]["source"]["media_type"], "image/jpeg");
    }

    #[tokio::test]
    async fn read_gif_returns_image_content_block() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("anim.gif");
        std::fs::write(&file, b"GIF89adata").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(!result.is_error);
        let blocks = result.raw_content.unwrap();
        assert_eq!(blocks[0]["source"]["media_type"], "image/gif");
    }

    #[tokio::test]
    async fn read_webp_returns_image_content_block() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("img.webp");
        std::fs::write(&file, b"RIFF\x00\x00\x00\x00WEBP").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(!result.is_error);
        let blocks = result.raw_content.unwrap();
        assert_eq!(blocks[0]["source"]["media_type"], "image/webp");
    }

    #[tokio::test]
    async fn read_missing_image_returns_error() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("nope.png");

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.content.contains("Failed to read"));
        assert!(result.raw_content.is_none());
    }

    #[tokio::test]
    async fn read_non_image_extension_returns_text() {
        let (tmp, executor) = setup();
        let file = tmp.path().join("data.csv");
        std::fs::write(&file, "a,b,c\n1,2,3\n").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.raw_content.is_none(), "csv should not be treated as image");
        assert!(result.content.contains("a,b,c"));
    }

    // ── ToolResult constructors ─────────────────────────────────────────

    #[test]
    fn tool_result_ok_has_no_raw_content() {
        let r = ToolResult::ok("hello".into());
        assert!(!r.is_error);
        assert!(r.raw_content.is_none());
    }

    #[test]
    fn tool_result_err_has_no_raw_content() {
        let r = ToolResult::err("boom".into());
        assert!(r.is_error);
        assert!(r.raw_content.is_none());
    }

    // ── Sandbox (path boundary) ─────────────────────────────────────────

    /// Sandboxed executor: `project` is the cwd, `data` is an extra allowed dir.
    fn setup_sandboxed() -> (TempDir, TempDir, ToolExecutor) {
        let project = TempDir::new().unwrap();
        let data = TempDir::new().unwrap();
        let executor = ToolExecutor::with_allowed_dirs(
            project.path().to_path_buf(),
            vec![data.path().to_path_buf()],
        );
        (project, data, executor)
    }

    #[tokio::test]
    async fn sandbox_allows_read_inside_cwd() {
        let (project, _data, executor) = setup_sandboxed();
        let file = project.path().join("hello.txt");
        std::fs::write(&file, "ok").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn sandbox_allows_read_inside_additional_dir() {
        let (_project, data, executor) = setup_sandboxed();
        let file = data.path().join("MEMORY.md");
        std::fs::write(&file, "# Memory").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("Memory"));
    }

    #[tokio::test]
    async fn sandbox_denies_read_outside_boundaries() {
        let (_project, _data, executor) = setup_sandboxed();
        let outside = TempDir::new().unwrap();
        let file = outside.path().join("secret.txt");
        std::fs::write(&file, "secret data").unwrap();

        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Access denied"));
    }

    #[tokio::test]
    async fn sandbox_denies_write_outside_boundaries() {
        let (_project, _data, executor) = setup_sandboxed();
        let outside = TempDir::new().unwrap();
        let file = outside.path().join("hack.txt");

        let result = executor
            .execute(
                "Write",
                json!({ "file_path": file.to_str().unwrap(), "content": "pwned" }),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Access denied"));
        assert!(!file.exists());
    }

    #[tokio::test]
    async fn sandbox_denies_edit_outside_boundaries() {
        let (_project, _data, executor) = setup_sandboxed();
        let outside = TempDir::new().unwrap();
        let file = outside.path().join("target.txt");
        std::fs::write(&file, "original").unwrap();

        let result = executor
            .execute(
                "Edit",
                json!({
                    "file_path": file.to_str().unwrap(),
                    "old_string": "original",
                    "new_string": "modified"
                }),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Access denied"));
    }

    #[tokio::test]
    async fn no_sandbox_when_allowed_dirs_empty() {
        let outside = TempDir::new().unwrap();
        let file = outside.path().join("free.txt");
        std::fs::write(&file, "open access").unwrap();

        let executor = ToolExecutor::new(TempDir::new().unwrap().path().to_path_buf());
        let result = executor
            .execute("Read", json!({ "file_path": file.to_str().unwrap() }))
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn sandbox_denies_dotdot_traversal() {
        let (project, _data, executor) = setup_sandboxed();
        let outside = TempDir::new().unwrap();
        let secret = outside.path().join("secret.txt");
        std::fs::write(&secret, "sensitive").unwrap();

        // Starts inside the project directory but climbs out through `..`.
        let traversal = project
            .path()
            .join("..")
            .join("..")
            .join(outside.path().strip_prefix("/").unwrap())
            .join("secret.txt");

        let result = executor
            .execute("Read", json!({ "file_path": traversal.to_str().unwrap() }))
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Access denied"));
    }

    #[tokio::test]
    async fn sandbox_allows_write_new_file_inside_cwd() {
        let (project, _data, executor) = setup_sandboxed();
        let file = project.path().join("subdir").join("new.txt");

        let result = executor
            .execute(
                "Write",
                json!({ "file_path": file.to_str().unwrap(), "content": "hello" }),
            )
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(file.exists());
    }

    // `std::os::unix::fs::symlink` only exists on Unix targets.
    #[cfg(unix)]
    #[tokio::test]
    async fn sandbox_denies_symlink_escape() {
        let (project, _data, executor) = setup_sandboxed();
        let outside = TempDir::new().unwrap();
        let secret = outside.path().join("secret.txt");
        std::fs::write(&secret, "sensitive").unwrap();

        // A link inside the project that points at a directory outside it.
        let link = project.path().join("escape");
        std::os::unix::fs::symlink(outside.path(), &link).unwrap();

        let via_link = link.join("secret.txt");
        let result = executor
            .execute("Read", json!({ "file_path": via_link.to_str().unwrap() }))
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Access denied"));
    }

    // ── Bash environment ────────────────────────────────────────────────

    #[tokio::test]
    async fn bash_sets_home_to_cwd() {
        let (tmp, executor) = setup();

        let result = executor
            .execute("Bash", json!({ "command": "echo $HOME" }))
            .await
            .unwrap();
        assert!(!result.is_error, "Bash should succeed");

        let expected = tmp.path().to_string_lossy().to_string();
        assert!(
            result.content.trim().contains(&expected),
            "HOME should be set to cwd ({}), got: {}",
            expected,
            result.content.trim()
        );
    }

    #[tokio::test]
    async fn bash_tilde_resolves_to_cwd() {
        let (tmp, executor) = setup();

        std::fs::write(tmp.path().join("marker.txt"), "found").unwrap();

        let result = executor
            .execute("Bash", json!({ "command": "cat ~/marker.txt" }))
            .await
            .unwrap();
        assert!(!result.is_error, "Should read file via ~: {}", result.content);
        assert!(result.content.contains("found"), "~ should resolve to cwd");
    }

    #[tokio::test]
    async fn bash_env_blocklist_strips_vars() {
        let tmp = TempDir::new().unwrap();
        // SAFETY: environment mutation is process-global. The variable name is
        // unique to this test; NOTE(review): other tests in this binary may
        // read the environment concurrently (e.g. when spawning Bash) —
        // confirm that is acceptable for the test harness in use.
        unsafe {
            std::env::set_var("STARPOD_TEST_SECRET", "leaked");
        }

        let executor = ToolExecutor::new(tmp.path().to_path_buf())
            .with_env_blocklist(vec!["STARPOD_TEST_SECRET".to_string()]);

        let result = executor
            .execute("Bash", json!({ "command": "echo \"val=${STARPOD_TEST_SECRET}\"" }))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert_eq!(result.content.trim(), "val=", "Blocked env var should not be visible to child process");

        // SAFETY: same reasoning as for `set_var` above.
        unsafe {
            std::env::remove_var("STARPOD_TEST_SECRET");
        }
    }

    #[tokio::test]
    async fn bash_env_blocklist_does_not_affect_other_vars() {
        let tmp = TempDir::new().unwrap();
        // SAFETY: environment mutation is process-global. The variable names
        // are unique to this test; NOTE(review): other tests in this binary
        // may read the environment concurrently — confirm that is acceptable.
        unsafe {
            std::env::set_var("STARPOD_TEST_ALLOWED", "visible");
            std::env::set_var("STARPOD_TEST_BLOCKED", "hidden");
        }

        let executor = ToolExecutor::new(tmp.path().to_path_buf())
            .with_env_blocklist(vec!["STARPOD_TEST_BLOCKED".to_string()]);

        let result = executor
            .execute("Bash", json!({ "command": "echo $STARPOD_TEST_ALLOWED" }))
            .await
            .unwrap();
        assert!(result.content.contains("visible"), "Non-blocked vars should still be inherited");

        // SAFETY: same reasoning as for `set_var` above.
        unsafe {
            std::env::remove_var("STARPOD_TEST_ALLOWED");
            std::env::remove_var("STARPOD_TEST_BLOCKED");
        }
    }
}