spec_ai_core/tools/builtin/
file_extract.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::fs;
8use std::path::PathBuf;
9
10#[cfg(not(target_os = "macos"))]
11use extractous::Extractor;
12
13/// Arguments accepted by the file_extract tool
14#[derive(Debug, Deserialize)]
15struct FileExtractArgs {
16    path: String,
17    #[serde(default)]
18    include_metadata: bool,
19    #[serde(default)]
20    xml_output: bool,
21    #[serde(default)]
22    max_chars: Option<i32>,
23}
24
25/// Output payload returned by the file_extract tool
26#[derive(Debug, Serialize, Deserialize)]
27struct FileExtractOutput {
28    path: String,
29    content: String,
30    metadata: Option<HashMap<String, Vec<String>>>,
31}
32
33/// Tool that extracts text from files.
34/// On macOS: Uses native Vision framework for OCR and PDFKit for PDFs
35/// On other platforms: Uses Extractous (Tika-based)
36pub struct FileExtractTool;
37
38impl Default for FileExtractTool {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl FileExtractTool {
45    pub fn new() -> Self {
46        Self
47    }
48
49    fn normalize_path(&self, input: &str) -> Result<PathBuf> {
50        let trimmed = input.trim();
51        if trimmed.is_empty() {
52            return Err(anyhow!("file_extract requires a valid path"));
53        }
54        Ok(PathBuf::from(trimmed))
55    }
56}
57
// macOS implementation using native Vision/PDFKit
#[cfg(target_os = "macos")]
mod macos_extract {
    //! macOS extraction backend.
    //!
    //! Writes an embedded Swift script to a temporary file and runs it with the `swift`
    //! command-line tool. The script prints a single JSON object on stdout, which is parsed
    //! into `SwiftResult`.

    use super::*;
    use std::process::Command;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::task;

    /// Swift script that uses Vision framework for OCR and PDFKit for PDF text extraction
    const SWIFT_EXTRACTOR: &str = r#"
import Foundation
import AppKit
import Vision
import PDFKit
import UniformTypeIdentifiers

struct ExtractionResult: Codable {
    let content: String
    let metadata: [String: [String]]?
    let error: String?
}

func extractText(from path: String, includeMetadata: Bool, maxChars: Int?) -> ExtractionResult {
    let url = URL(fileURLWithPath: path)
    let pathExtension = url.pathExtension.lowercased()

    // Determine file type
    var uti: UTType?
    if let typeIdentifier = try? url.resourceValues(forKeys: [.typeIdentifierKey]).typeIdentifier {
        uti = UTType(typeIdentifier)
    }

    // PDF handling
    if pathExtension == "pdf" || uti?.conforms(to: .pdf) == true {
        return extractFromPDF(url: url, includeMetadata: includeMetadata, maxChars: maxChars)
    }

    // Image handling (use Vision OCR)
    let imageExtensions = ["png", "jpg", "jpeg", "tiff", "tif", "gif", "bmp", "heic", "webp"]
    if imageExtensions.contains(pathExtension) || uti?.conforms(to: .image) == true {
        return extractFromImage(url: url, includeMetadata: includeMetadata, maxChars: maxChars)
    }

    // Plain text and other text-based files
    let textExtensions = ["txt", "md", "json", "xml", "html", "htm", "css", "js", "ts", "py", "rs", "go", "java", "c", "cpp", "h", "hpp", "swift", "rb", "php", "yaml", "yml", "toml", "ini", "cfg", "conf", "sh", "bash", "zsh", "csv", "log"]
    if textExtensions.contains(pathExtension) || uti?.conforms(to: .text) == true || uti?.conforms(to: .sourceCode) == true {
        return extractFromText(url: url, includeMetadata: includeMetadata, maxChars: maxChars)
    }

    // Try as text file as fallback
    return extractFromText(url: url, includeMetadata: includeMetadata, maxChars: maxChars)
}

func extractFromPDF(url: URL, includeMetadata: Bool, maxChars: Int?) -> ExtractionResult {
    guard let document = PDFDocument(url: url) else {
        return ExtractionResult(content: "", metadata: nil, error: "Failed to open PDF document")
    }

    var text = ""
    for i in 0..<document.pageCount {
        if let page = document.page(at: i), let pageText = page.string {
            text += pageText
            if i < document.pageCount - 1 {
                text += "\n\n"
            }
        }
    }

    // If PDF has no extractable text, try OCR on each page
    if text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
        text = ocrPDFPages(document: document)
    }

    if let maxChars = maxChars, text.count > maxChars {
        text = String(text.prefix(maxChars))
    }

    var metadata: [String: [String]]? = nil
    if includeMetadata, let attrs = document.documentAttributes {
        var meta: [String: [String]] = [:]
        if let title = attrs[PDFDocumentAttribute.titleAttribute] as? String {
            meta["title"] = [title]
        }
        if let author = attrs[PDFDocumentAttribute.authorAttribute] as? String {
            meta["author"] = [author]
        }
        if let subject = attrs[PDFDocumentAttribute.subjectAttribute] as? String {
            meta["subject"] = [subject]
        }
        if let creator = attrs[PDFDocumentAttribute.creatorAttribute] as? String {
            meta["creator"] = [creator]
        }
        meta["pageCount"] = [String(document.pageCount)]
        metadata = meta.isEmpty ? nil : meta
    }

    return ExtractionResult(content: text, metadata: metadata, error: nil)
}

func ocrPDFPages(document: PDFDocument) -> String {
    var allText = ""
    let semaphore = DispatchSemaphore(value: 0)

    for i in 0..<document.pageCount {
        guard let page = document.page(at: i) else { continue }
        let bounds = page.bounds(for: .mediaBox)

        // Render page to image
        let image = NSImage(size: bounds.size)
        image.lockFocus()
        if let context = NSGraphicsContext.current?.cgContext {
            context.setFillColor(NSColor.white.cgColor)
            context.fill(bounds)
            page.draw(with: .mediaBox, to: context)
        }
        image.unlockFocus()

        guard let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else { continue }

        let request = VNRecognizeTextRequest { request, error in
            defer { semaphore.signal() }
            guard let observations = request.results as? [VNRecognizedTextObservation] else { return }
            let pageText = observations.compactMap { $0.topCandidates(1).first?.string }.joined(separator: "\n")
            if !pageText.isEmpty {
                if !allText.isEmpty { allText += "\n\n" }
                allText += pageText
            }
        }
        request.recognitionLevel = .accurate
        request.usesLanguageCorrection = true

        let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
        do {
            try handler.perform([request])
        } catch {
            // Skip pages Vision cannot process. Only wait after a successful perform: if it
            // throws, the completion handler may never signal and the wait would block forever.
            continue
        }
        semaphore.wait()
    }

    return allText
}

func extractFromImage(url: URL, includeMetadata: Bool, maxChars: Int?) -> ExtractionResult {
    guard let image = NSImage(contentsOf: url),
          let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
        return ExtractionResult(content: "", metadata: nil, error: "Failed to load image")
    }

    var recognizedText = ""
    var recognitionError: Error? = nil
    let semaphore = DispatchSemaphore(value: 0)

    let request = VNRecognizeTextRequest { request, error in
        defer { semaphore.signal() }
        if let error = error {
            // Surface the failure instead of silently returning empty text.
            recognitionError = error
            return
        }
        guard let observations = request.results as? [VNRecognizedTextObservation] else { return }
        recognizedText = observations.compactMap { $0.topCandidates(1).first?.string }.joined(separator: "\n")
    }
    request.recognitionLevel = .accurate
    request.usesLanguageCorrection = true

    let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
    do {
        try handler.perform([request])
        semaphore.wait()
    } catch {
        return ExtractionResult(content: "", metadata: nil, error: "OCR failed: \(error.localizedDescription)")
    }

    if let recognitionError = recognitionError {
        return ExtractionResult(content: "", metadata: nil, error: "OCR failed: \(recognitionError.localizedDescription)")
    }

    if let maxChars = maxChars, recognizedText.count > maxChars {
        recognizedText = String(recognizedText.prefix(maxChars))
    }

    var metadata: [String: [String]]? = nil
    if includeMetadata {
        var meta: [String: [String]] = [:]
        meta["width"] = [String(Int(image.size.width))]
        meta["height"] = [String(Int(image.size.height))]
        metadata = meta
    }

    return ExtractionResult(content: recognizedText, metadata: metadata, error: nil)
}

func extractFromText(url: URL, includeMetadata: Bool, maxChars: Int?) -> ExtractionResult {
    do {
        var content = try String(contentsOf: url, encoding: .utf8)
        if let maxChars = maxChars, content.count > maxChars {
            content = String(content.prefix(maxChars))
        }

        var metadata: [String: [String]]? = nil
        if includeMetadata {
            let attrs = try FileManager.default.attributesOfItem(atPath: url.path)
            var meta: [String: [String]] = [:]
            if let size = attrs[.size] as? Int {
                meta["size"] = [String(size)]
            }
            if let modified = attrs[.modificationDate] as? Date {
                meta["modified"] = [ISO8601DateFormatter().string(from: modified)]
            }
            metadata = meta.isEmpty ? nil : meta
        }

        return ExtractionResult(content: content, metadata: metadata, error: nil)
    } catch {
        return ExtractionResult(content: "", metadata: nil, error: "Failed to read file: \(error.localizedDescription)")
    }
}

// Main
let args = CommandLine.arguments
guard args.count >= 2 else {
    let result = ExtractionResult(content: "", metadata: nil, error: "Usage: swift extract.swift <path> [includeMetadata] [maxChars]")
    print(String(data: try! JSONEncoder().encode(result), encoding: .utf8)!)
    exit(1)
}

let path = args[1]
let includeMetadata = args.count > 2 && args[2] == "true"
let maxChars: Int? = args.count > 3 ? Int(args[3]) : nil

let result = extractText(from: path, includeMetadata: includeMetadata, maxChars: maxChars)
let encoder = JSONEncoder()
encoder.outputFormatting = .sortedKeys
if let json = try? encoder.encode(result), let jsonString = String(data: json, encoding: .utf8) {
    print(jsonString)
} else {
    print("{\"content\":\"\",\"error\":\"JSON encoding failed\"}")
}
"#;

    /// JSON object printed by the Swift script; mirrors its `ExtractionResult` struct.
    #[derive(Debug, Deserialize)]
    struct SwiftResult {
        content: String,
        metadata: Option<HashMap<String, Vec<String>>>,
        error: Option<String>,
    }

    /// Per-process counter that makes every temporary script path unique.
    static SCRIPT_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Temporary on-disk copy of `SWIFT_EXTRACTOR`, deleted when the guard is dropped.
    struct TempScript(PathBuf);

    impl TempScript {
        /// Writes the script to a path unique to this process and call.
        ///
        /// A shared, fixed file name would let one extraction truncate and rewrite the
        /// script while a concurrent `swift` process is still reading it.
        fn create() -> Result<Self> {
            let id = SCRIPT_COUNTER.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir().join(format!(
                "spec_ai_extractor_{}_{}.swift",
                std::process::id(),
                id
            ));
            // Build the guard first so a partially written file is still cleaned up.
            let script = Self(path);
            fs::write(&script.0, SWIFT_EXTRACTOR)
                .context("Failed to write Swift extractor script")?;
            Ok(script)
        }
    }

    impl Drop for TempScript {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover file in the temp directory is harmless.
            let _ = fs::remove_file(&self.0);
        }
    }

    /// Extracts text (and optionally metadata) from `path` by running the Swift script.
    ///
    /// `max_chars`, when given, is forwarded to the script, which truncates the text to at
    /// most that many characters.
    ///
    /// # Errors
    /// Fails when the script cannot be written, `swift` cannot be launched or exits
    /// unsuccessfully, its stdout is not the expected JSON, or the script reports an
    /// extraction error.
    pub async fn extract_file(
        path: &str,
        include_metadata: bool,
        max_chars: Option<i32>,
    ) -> Result<(String, Option<HashMap<String, Vec<String>>>)> {
        let path = path.to_string();

        // Launching `swift` and waiting for it is blocking work; keep it off the executor.
        task::spawn_blocking(move || {
            let script = TempScript::create()?;

            // Argument order must match the script's CommandLine.arguments parsing:
            // <path> <includeMetadata> [maxChars].
            let mut command = Command::new("swift");
            command
                .arg(&script.0)
                .arg(&path)
                .arg(include_metadata.to_string());
            if let Some(max) = max_chars {
                command.arg(max.to_string());
            }

            let output = command
                .output()
                .context("Failed to execute Swift script. Ensure Xcode/Swift is installed.")?;

            if !output.status.success() {
                let stderr = String::from_utf8_lossy(&output.stderr);
                return Err(anyhow!("Swift extraction failed: {}", stderr));
            }

            let stdout = String::from_utf8_lossy(&output.stdout);
            let result: SwiftResult =
                serde_json::from_str(&stdout).context("Failed to parse Swift output")?;

            if let Some(error) = result.error {
                return Err(anyhow!("Extraction error: {}", error));
            }

            Ok((result.content, result.metadata))
        })
        .await
        .context("Task join error")?
    }
}
339
340#[async_trait]
341impl Tool for FileExtractTool {
342    fn name(&self) -> &str {
343        "file_extract"
344    }
345
346    fn description(&self) -> &str {
347        "Extracts text and metadata from files regardless of format (PDF, Office, HTML, images with OCR, etc.)"
348    }
349
350    fn parameters(&self) -> Value {
351        serde_json::json!({
352            "type": "object",
353            "properties": {
354                "path": {
355                    "type": "string",
356                    "description": "Relative or absolute path to the file that should be extracted"
357                },
358                "include_metadata": {
359                    "type": "boolean",
360                    "description": "Include metadata from the file",
361                    "default": false
362                },
363                "xml_output": {
364                    "type": "boolean",
365                    "description": "Request XML formatted result instead of plain text (non-macOS only)",
366                    "default": false
367                },
368                "max_chars": {
369                    "type": "integer",
370                    "description": "Limit the number of characters returned (must be > 0 if provided)",
371                    "minimum": 1
372                }
373            },
374            "required": ["path"]
375        })
376    }
377
378    async fn execute(&self, args: Value) -> Result<ToolResult> {
379        let args: FileExtractArgs =
380            serde_json::from_value(args).context("Failed to parse file_extract arguments")?;
381
382        let path = self.normalize_path(&args.path)?;
383        let metadata =
384            fs::metadata(&path).with_context(|| format!("File not found: {}", path.display()))?;
385
386        if !metadata.is_file() {
387            return Ok(ToolResult::failure(format!(
388                "{} is not a regular file",
389                path.display()
390            )));
391        }
392
393        if let Some(max_chars) = args.max_chars {
394            if max_chars <= 0 {
395                return Ok(ToolResult::failure(
396                    "max_chars must be greater than zero".to_string(),
397                ));
398            }
399        }
400
401        let display_path = path.to_string_lossy().into_owned();
402
403        // Platform-specific extraction
404        #[cfg(target_os = "macos")]
405        let (content, extracted_metadata) = {
406            macos_extract::extract_file(&display_path, args.include_metadata, args.max_chars)
407                .await
408                .map_err(|e| anyhow!("macOS extraction failed: {}", e))?
409        };
410
411        #[cfg(not(target_os = "macos"))]
412        let (content, extracted_metadata) = {
413            let mut extractor = Extractor::new();
414            if let Some(max_chars) = args.max_chars {
415                extractor = extractor.set_extract_string_max_length(max_chars);
416            }
417            if args.xml_output {
418                extractor = extractor.set_xml_output(true);
419            }
420            extractor
421                .extract_file_to_string(&display_path)
422                .map_err(|err| anyhow!("Failed to extract {}: {}", display_path, err))?
423        };
424
425        let metadata = if args.include_metadata {
426            extracted_metadata
427        } else {
428            None
429        };
430
431        let output = FileExtractOutput {
432            path: display_path,
433            content,
434            metadata,
435        };
436
437        Ok(ToolResult::success(
438            serde_json::to_string(&output).context("Failed to serialize file_extract output")?,
439        ))
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tempfile::NamedTempFile;

    /// String form of the temp file's path, as the tool's `path` argument expects.
    fn path_arg(file: &NamedTempFile) -> String {
        file.path().to_string_lossy().into_owned()
    }

    #[tokio::test]
    async fn name_and_description() {
        let tool = FileExtractTool::new();
        assert_eq!(tool.name(), "file_extract");
        assert!(tool.description().contains("Extracts text"));
    }

    #[tokio::test]
    async fn parameters_require_path() {
        let schema = FileExtractTool::new().parameters();
        let required = schema["required"]
            .as_array()
            .expect("parameter schema must list required fields");
        assert!(required.iter().any(|field| field.as_str() == Some("path")));
    }

    #[tokio::test]
    async fn invalid_max_chars_returns_failure() {
        let file = NamedTempFile::new().unwrap();
        let request = json!({ "path": path_arg(&file), "max_chars": 0 });

        let outcome = FileExtractTool::new().execute(request).await.unwrap();
        assert!(!outcome.success);
        assert_eq!(outcome.error.unwrap(), "max_chars must be greater than zero");
    }

    #[tokio::test]
    async fn extract_plain_text_file() {
        let file = NamedTempFile::new().unwrap();
        fs::write(file.path(), "Hello, World!").unwrap();

        let request = json!({ "path": path_arg(&file) });
        let outcome = FileExtractTool::new().execute(request).await.unwrap();
        assert!(outcome.success);

        let parsed: FileExtractOutput = serde_json::from_str(&outcome.output).unwrap();
        assert!(parsed.content.contains("Hello, World!"));
    }
}
492}