spec_ai_core/tools/builtin/
file_write.rs1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use base64::{engine::general_purpose, Engine as _};
5use serde::Deserialize;
6use serde_json::Value;
7use std::fs::{self, OpenOptions};
8use std::io::Write;
9use std::path::{Path, PathBuf};
10use tempfile::NamedTempFile;
11
12const DEFAULT_MAX_BYTES: usize = 1_048_576; #[derive(Debug, Deserialize)]
15#[serde(rename_all = "snake_case")]
16#[derive(Default)]
17enum WriteMode {
18 #[default]
19 Overwrite,
20 Append,
21}
22
23#[derive(Debug, Deserialize)]
24#[serde(rename_all = "snake_case")]
25#[derive(Default)]
26enum ContentEncoding {
27 #[default]
28 Text,
29 Base64,
30}
31
32#[derive(Debug, Deserialize)]
33struct FileWriteArgs {
34 path: String,
35 content: String,
36 #[serde(default)]
37 mode: WriteMode,
38 #[serde(default)]
39 encoding: ContentEncoding,
40 #[serde(default = "FileWriteArgs::default_create_dirs")]
41 create_dirs: bool,
42}
43
44impl FileWriteArgs {
45 fn default_create_dirs() -> bool {
46 true
47 }
48}
49
50#[derive(Debug, serde::Serialize)]
51struct FileWriteOutput {
52 path: String,
53 mode: &'static str,
54 bytes_written: usize,
55 existed: bool,
56 message: String,
57}
58
/// Built-in tool that writes text or base64-decoded content to a file.
pub struct FileWriteTool {
    /// Largest payload, in bytes (after decoding), that a single call may write.
    max_bytes: usize,
}
63
64impl FileWriteTool {
65 pub fn new() -> Self {
66 Self {
67 max_bytes: DEFAULT_MAX_BYTES,
68 }
69 }
70
71 pub fn with_max_bytes(mut self, max_bytes: usize) -> Self {
72 self.max_bytes = max_bytes;
73 self
74 }
75
76 fn resolve_path(&self, path: &str) -> Result<PathBuf> {
77 if path.trim().is_empty() {
78 return Err(anyhow!("file_write requires a valid path"));
79 }
80 Ok(PathBuf::from(path))
81 }
82
83 fn ensure_parent(&self, path: &Path, create_dirs: bool) -> Result<()> {
84 if let Some(parent) = path.parent() {
85 if parent.exists() {
86 return Ok(());
87 }
88 if create_dirs {
89 fs::create_dir_all(parent).with_context(|| {
90 format!("Failed to create parent directories for {}", path.display())
91 })?;
92 } else {
93 return Err(anyhow!(
94 "Parent directory does not exist for {} (set create_dirs=true to create it)",
95 path.display()
96 ));
97 }
98 return Ok(());
99 }
100 Ok(())
101 }
102
103 fn decode_content(&self, args: &FileWriteArgs) -> Result<Vec<u8>> {
104 let bytes = match args.encoding {
105 ContentEncoding::Text => args.content.clone().into_bytes(),
106 ContentEncoding::Base64 => general_purpose::STANDARD
107 .decode(&args.content)
108 .context("Failed to decode base64 content for file_write")?,
109 };
110
111 if bytes.len() > self.max_bytes {
112 return Err(anyhow!(
113 "Content exceeds maximum allowed size of {} bytes",
114 self.max_bytes
115 ));
116 }
117
118 Ok(bytes)
119 }
120
121 fn write_overwrite(&self, path: &Path, bytes: &[u8]) -> Result<()> {
122 if let Some(parent) = path.parent() {
123 if !parent.exists() {
124 return Err(anyhow!(
125 "Parent directory {} must exist before writing",
126 parent.display()
127 ));
128 }
129 }
130
131 let parent = path.parent().unwrap_or_else(|| Path::new("."));
132 let mut tmp = NamedTempFile::new_in(parent)
133 .with_context(|| format!("Failed to create temporary file near {}", path.display()))?;
134 tmp.write_all(bytes)
135 .with_context(|| format!("Failed to write temporary file for {}", path.display()))?;
136 tmp.flush()?;
137 tmp.as_file().sync_all().ok();
138
139 if path.exists() {
140 fs::remove_file(path)
141 .with_context(|| format!("Failed to remove existing file {}", path.display()))?;
142 }
143
144 let tmp_path = tmp.into_temp_path();
145 tmp_path
146 .persist(path)
147 .map_err(|err| anyhow!("Failed to persist file {}: {}", path.display(), err))?;
148 Ok(())
149 }
150
151 fn write_append(&self, path: &Path, bytes: &[u8]) -> Result<()> {
152 let mut file = OpenOptions::new()
153 .create(true)
154 .append(true)
155 .open(path)
156 .with_context(|| format!("Failed to open {} for appending", path.display()))?;
157
158 file.write_all(bytes)
159 .with_context(|| format!("Failed to append to {}", path.display()))?;
160 file.flush().ok();
161 file.sync_all().ok();
162 Ok(())
163 }
164}
165
166impl Default for FileWriteTool {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172#[async_trait]
173impl Tool for FileWriteTool {
174 fn name(&self) -> &str {
175 "file_write"
176 }
177
178 fn description(&self) -> &str {
179 "Writes text or base64-decoded content to files with optional append support"
180 }
181
182 fn parameters(&self) -> Value {
183 serde_json::json!({
184 "type": "object",
185 "properties": {
186 "path": {
187 "type": "string",
188 "description": "Relative or absolute file path to write"
189 },
190 "content": {
191 "type": "string",
192 "description": "Content to write to the file"
193 },
194 "mode": {
195 "type": "string",
196 "enum": ["overwrite", "append"],
197 "description": "Overwrite (default) or append to existing files",
198 "default": "overwrite"
199 },
200 "encoding": {
201 "type": "string",
202 "enum": ["text", "base64"],
203 "description": "Encoding for the provided content",
204 "default": "text"
205 },
206 "create_dirs": {
207 "type": "boolean",
208 "description": "Create parent directories when needed",
209 "default": true
210 }
211 },
212 "required": ["path", "content"]
213 })
214 }
215
216 async fn execute(&self, args: Value) -> Result<ToolResult> {
217 let args: FileWriteArgs =
218 serde_json::from_value(args).context("Failed to parse file_write arguments")?;
219
220 let path = self.resolve_path(&args.path)?;
221 self.ensure_parent(&path, args.create_dirs)?;
222 let bytes = self.decode_content(&args)?;
223
224 let existed = path.exists();
225
226 match args.mode {
227 WriteMode::Overwrite => self.write_overwrite(&path, &bytes)?,
228 WriteMode::Append => self.write_append(&path, &bytes)?,
229 };
230
231 let message = match args.mode {
232 WriteMode::Overwrite if existed => "File overwritten",
233 WriteMode::Overwrite => "File created",
234 WriteMode::Append if existed => "Content appended to existing file",
235 WriteMode::Append => "Content appended to new file",
236 }
237 .to_string();
238
239 let output = FileWriteOutput {
240 path: path.to_string_lossy().into_owned(),
241 mode: match args.mode {
242 WriteMode::Overwrite => "overwrite",
243 WriteMode::Append => "append",
244 },
245 bytes_written: bytes.len(),
246 existed,
247 message,
248 };
249
250 Ok(ToolResult::success(
251 serde_json::to_string(&output).context("Failed to serialize file_write output")?,
252 ))
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Overwrite mode creates a file that does not exist yet.
    #[tokio::test]
    async fn test_file_write_overwrite() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("file.txt");
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "hello world",
            "create_dirs": true
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(fs::read_to_string(&path).unwrap(), "hello world");
    }

    /// Overwrite mode replaces the previous contents instead of adding to them.
    #[tokio::test]
    async fn test_file_write_overwrite_replaces_existing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("existing.txt");
        fs::write(&path, "old contents that are longer").unwrap();
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "new"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(fs::read_to_string(&path).unwrap(), "new");
    }

    /// Append mode keeps the existing contents and adds the payload after them.
    #[tokio::test]
    async fn test_file_write_append() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("append.txt");
        fs::write(&path, "line1\n").unwrap();
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "line2\n",
            "mode": "append"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        // Exact comparison: a mere `contains` would also pass if "line1" were lost.
        assert_eq!(fs::read_to_string(&path).unwrap(), "line1\nline2\n");
    }

    /// Base64 payloads are decoded before being written.
    #[tokio::test]
    async fn test_file_write_base64() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("binary.bin");
        let tool = FileWriteTool::new();

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "AQID",
            "encoding": "base64"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let bytes = fs::read(&path).unwrap();
        assert_eq!(bytes, vec![1, 2, 3]);
    }

    /// Payloads above the configured limit are rejected and nothing is written.
    #[tokio::test]
    async fn test_file_write_rejects_oversized_content() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("too_big.txt");
        let tool = FileWriteTool::new().with_max_bytes(4);

        let args = serde_json::json!({
            "path": path.to_string_lossy(),
            "content": "hello"
        });

        assert!(tool.execute(args).await.is_err());
        assert!(!path.exists());
    }
}