spec_ai_core/tools/builtin/
file_write.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use base64::{engine::general_purpose, Engine as _};
5use serde::Deserialize;
6use serde_json::Value;
7use std::fs::{self, OpenOptions};
8use std::io::Write;
9use std::path::{Path, PathBuf};
10use tempfile::NamedTempFile;
11
/// Default upper bound, in decoded bytes, on a single `file_write` payload (1 MiB).
/// Used by `FileWriteTool::new`; override it with `FileWriteTool::with_max_bytes`.
const DEFAULT_MAX_BYTES: usize = 1024 * 1024;
13
14#[derive(Debug, Deserialize)]
15#[serde(rename_all = "snake_case")]
16#[derive(Default)]
17enum WriteMode {
18    #[default]
19    Overwrite,
20    Append,
21}
22
23#[derive(Debug, Deserialize)]
24#[serde(rename_all = "snake_case")]
25#[derive(Default)]
26enum ContentEncoding {
27    #[default]
28    Text,
29    Base64,
30}
31
32#[derive(Debug, Deserialize)]
33struct FileWriteArgs {
34    path: String,
35    content: String,
36    #[serde(default)]
37    mode: WriteMode,
38    #[serde(default)]
39    encoding: ContentEncoding,
40    #[serde(default = "FileWriteArgs::default_create_dirs")]
41    create_dirs: bool,
42}
43
44impl FileWriteArgs {
45    fn default_create_dirs() -> bool {
46        true
47    }
48}
49
50#[derive(Debug, serde::Serialize)]
51struct FileWriteOutput {
52    path: String,
53    mode: &'static str,
54    bytes_written: usize,
55    existed: bool,
56    message: String,
57}
58
/// Tool for writing files to disk with safeguards.
///
/// Payloads whose decoded size exceeds `max_bytes` are rejected before any
/// data is written.
pub struct FileWriteTool {
    /// Largest accepted payload, in bytes after decoding.
    max_bytes: usize,
}
63
64impl FileWriteTool {
65    pub fn new() -> Self {
66        Self {
67            max_bytes: DEFAULT_MAX_BYTES,
68        }
69    }
70
71    pub fn with_max_bytes(mut self, max_bytes: usize) -> Self {
72        self.max_bytes = max_bytes;
73        self
74    }
75
76    fn resolve_path(&self, path: &str) -> Result<PathBuf> {
77        if path.trim().is_empty() {
78            return Err(anyhow!("file_write requires a valid path"));
79        }
80        Ok(PathBuf::from(path))
81    }
82
83    fn ensure_parent(&self, path: &Path, create_dirs: bool) -> Result<()> {
84        if let Some(parent) = path.parent() {
85            if parent.exists() {
86                return Ok(());
87            }
88            if create_dirs {
89                fs::create_dir_all(parent).with_context(|| {
90                    format!("Failed to create parent directories for {}", path.display())
91                })?;
92            } else {
93                return Err(anyhow!(
94                    "Parent directory does not exist for {} (set create_dirs=true to create it)",
95                    path.display()
96                ));
97            }
98            return Ok(());
99        }
100        Ok(())
101    }
102
103    fn decode_content(&self, args: &FileWriteArgs) -> Result<Vec<u8>> {
104        let bytes = match args.encoding {
105            ContentEncoding::Text => args.content.clone().into_bytes(),
106            ContentEncoding::Base64 => general_purpose::STANDARD
107                .decode(&args.content)
108                .context("Failed to decode base64 content for file_write")?,
109        };
110
111        if bytes.len() > self.max_bytes {
112            return Err(anyhow!(
113                "Content exceeds maximum allowed size of {} bytes",
114                self.max_bytes
115            ));
116        }
117
118        Ok(bytes)
119    }
120
121    fn write_overwrite(&self, path: &Path, bytes: &[u8]) -> Result<()> {
122        if let Some(parent) = path.parent() {
123            if !parent.exists() {
124                return Err(anyhow!(
125                    "Parent directory {} must exist before writing",
126                    parent.display()
127                ));
128            }
129        }
130
131        let parent = path.parent().unwrap_or_else(|| Path::new("."));
132        let mut tmp = NamedTempFile::new_in(parent)
133            .with_context(|| format!("Failed to create temporary file near {}", path.display()))?;
134        tmp.write_all(bytes)
135            .with_context(|| format!("Failed to write temporary file for {}", path.display()))?;
136        tmp.flush()?;
137        tmp.as_file().sync_all().ok();
138
139        if path.exists() {
140            fs::remove_file(path)
141                .with_context(|| format!("Failed to remove existing file {}", path.display()))?;
142        }
143
144        let tmp_path = tmp.into_temp_path();
145        tmp_path
146            .persist(path)
147            .map_err(|err| anyhow!("Failed to persist file {}: {}", path.display(), err))?;
148        Ok(())
149    }
150
151    fn write_append(&self, path: &Path, bytes: &[u8]) -> Result<()> {
152        let mut file = OpenOptions::new()
153            .create(true)
154            .append(true)
155            .open(path)
156            .with_context(|| format!("Failed to open {} for appending", path.display()))?;
157
158        file.write_all(bytes)
159            .with_context(|| format!("Failed to append to {}", path.display()))?;
160        file.flush().ok();
161        file.sync_all().ok();
162        Ok(())
163    }
164}
165
166impl Default for FileWriteTool {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172#[async_trait]
173impl Tool for FileWriteTool {
174    fn name(&self) -> &str {
175        "file_write"
176    }
177
178    fn description(&self) -> &str {
179        "Writes text or base64-decoded content to files with optional append support"
180    }
181
182    fn parameters(&self) -> Value {
183        serde_json::json!({
184            "type": "object",
185            "properties": {
186                "path": {
187                    "type": "string",
188                    "description": "Relative or absolute file path to write"
189                },
190                "content": {
191                    "type": "string",
192                    "description": "Content to write to the file"
193                },
194                "mode": {
195                    "type": "string",
196                    "enum": ["overwrite", "append"],
197                    "description": "Overwrite (default) or append to existing files",
198                    "default": "overwrite"
199                },
200                "encoding": {
201                    "type": "string",
202                    "enum": ["text", "base64"],
203                    "description": "Encoding for the provided content",
204                    "default": "text"
205                },
206                "create_dirs": {
207                    "type": "boolean",
208                    "description": "Create parent directories when needed",
209                    "default": true
210                }
211            },
212            "required": ["path", "content"]
213        })
214    }
215
216    async fn execute(&self, args: Value) -> Result<ToolResult> {
217        let args: FileWriteArgs =
218            serde_json::from_value(args).context("Failed to parse file_write arguments")?;
219
220        let path = self.resolve_path(&args.path)?;
221        self.ensure_parent(&path, args.create_dirs)?;
222        let bytes = self.decode_content(&args)?;
223
224        let existed = path.exists();
225
226        match args.mode {
227            WriteMode::Overwrite => self.write_overwrite(&path, &bytes)?,
228            WriteMode::Append => self.write_append(&path, &bytes)?,
229        };
230
231        let message = match args.mode {
232            WriteMode::Overwrite if existed => "File overwritten",
233            WriteMode::Overwrite => "File created",
234            WriteMode::Append if existed => "Content appended to existing file",
235            WriteMode::Append => "Content appended to new file",
236        }
237        .to_string();
238
239        let output = FileWriteOutput {
240            path: path.to_string_lossy().into_owned(),
241            mode: match args.mode {
242                WriteMode::Overwrite => "overwrite",
243                WriteMode::Append => "append",
244            },
245            bytes_written: bytes.len(),
246            existed,
247            message,
248        };
249
250        Ok(ToolResult::success(
251            serde_json::to_string(&output).context("Failed to serialize file_write output")?,
252        ))
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[tokio::test]
    async fn test_file_write_overwrite() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("file.txt");
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "hello world",
            "create_dirs": true
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(fs::read_to_string(&path).unwrap(), "hello world");
    }

    /// Overwriting an existing file must fully replace its contents.
    #[tokio::test]
    async fn test_file_write_replaces_existing_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("existing.txt");
        fs::write(&path, "old contents that are longer").unwrap();
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "new"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(fs::read_to_string(&path).unwrap(), "new");
    }

    #[tokio::test]
    async fn test_file_write_append() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("append.txt");
        fs::write(&path, "line1\n").unwrap();
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "line2\n",
            "mode": "append"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        // Exact match: the original line must be kept and the new one appended after it.
        assert_eq!(fs::read_to_string(&path).unwrap(), "line1\nline2\n");
    }

    #[tokio::test]
    async fn test_file_write_base64() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("binary.bin");
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "AQID",
            "encoding": "base64"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let bytes = fs::read(&path).unwrap();
        assert_eq!(bytes, vec![1, 2, 3]);
    }

    /// `create_dirs` defaults to true, so nested missing directories are created.
    #[tokio::test]
    async fn test_file_write_creates_missing_dirs_by_default() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("nested").join("deeper").join("file.txt");
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "nested"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(fs::read_to_string(&path).unwrap(), "nested");
    }

    #[tokio::test]
    async fn test_file_write_rejects_missing_parent_without_create_dirs() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("missing").join("file.txt");
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "data",
            "create_dirs": false
        });

        assert!(tool.execute(args).await.is_err());
        assert!(!path.exists());
    }

    #[tokio::test]
    async fn test_file_write_rejects_oversized_content() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("big.txt");
        let tool = FileWriteTool::new().with_max_bytes(4);

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "hello"
        });

        assert!(tool.execute(args).await.is_err());
        assert!(!path.exists());
    }

    #[tokio::test]
    async fn test_file_write_rejects_invalid_base64() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("bad.bin");
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "not base64!",
            "encoding": "base64"
        });

        assert!(tool.execute(args).await.is_err());
        assert!(!path.exists());
    }

    #[tokio::test]
    async fn test_file_write_rejects_blank_path() {
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": "   ",
            "content": "data"
        });

        assert!(tool.execute(args).await.is_err());
    }
}