Skip to main content

sgr_agent/
prompt_loader.rs

1//! Load, cache, and merge prompt files from disk.
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::RwLock;
6
/// Loads and caches prompt files from a directory.
///
/// Loaded text (after `{{include:...}}` expansion) is memoised per requested
/// path until the cache is cleared.
#[derive(Debug)]
pub struct PromptLoader {
    /// Directory that prompt paths and include paths are resolved against.
    base_dir: PathBuf,
    /// Expanded prompt text, keyed by the relative path that was requested.
    cache: RwLock<HashMap<String, String>>,
}
12
13impl PromptLoader {
14    /// Create a new prompt loader rooted at the given directory.
15    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
16        Self {
17            base_dir: base_dir.into(),
18            cache: RwLock::new(HashMap::new()),
19        }
20    }
21
22    /// Load a prompt file by relative path (e.g., "system.md", "roles/explorer.md").
23    /// Returns cached version if already loaded.
24    pub fn load(&self, relative_path: &str) -> Result<String, PromptError> {
25        // Check cache first
26        if let Ok(cache) = self.cache.read()
27            && let Some(cached) = cache.get(relative_path)
28        {
29            return Ok(cached.clone());
30        }
31
32        let full_path = self.base_dir.join(relative_path);
33        let content = std::fs::read_to_string(&full_path)
34            .map_err(|e| PromptError::Io(full_path.clone(), e))?;
35
36        // Process includes: {{include:path/to/file.md}}
37        let processed = self.process_includes(&content, 0)?;
38
39        if let Ok(mut cache) = self.cache.write() {
40            cache.insert(relative_path.to_string(), processed.clone());
41        }
42
43        Ok(processed)
44    }
45
46    /// Load and merge multiple prompt files, separated by newlines.
47    pub fn load_merged(&self, paths: &[&str]) -> Result<String, PromptError> {
48        let mut parts = Vec::new();
49        for path in paths {
50            parts.push(self.load(path)?);
51        }
52        Ok(parts.join("\n\n"))
53    }
54
55    /// Load a prompt with variable substitution.
56    /// Variables are {{key}} patterns in the template.
57    pub fn load_with_vars(
58        &self,
59        path: &str,
60        vars: &HashMap<String, String>,
61    ) -> Result<String, PromptError> {
62        let mut content = self.load(path)?;
63        for (key, value) in vars {
64            content = content.replace(&format!("{{{{{}}}}}", key), value);
65        }
66        Ok(content)
67    }
68
69    /// Clear the cache, forcing reload on next access.
70    pub fn clear_cache(&self) {
71        if let Ok(mut cache) = self.cache.write() {
72            cache.clear();
73        }
74    }
75
76    /// Return the base directory this loader reads from.
77    pub fn base_dir(&self) -> &Path {
78        &self.base_dir
79    }
80
81    /// Process {{include:path}} directives recursively (max depth 5).
82    /// Include paths are canonicalized to prevent directory traversal attacks.
83    fn process_includes(&self, content: &str, depth: usize) -> Result<String, PromptError> {
84        if depth > 5 {
85            return Err(PromptError::MaxIncludeDepth);
86        }
87
88        let mut result = String::with_capacity(content.len());
89        let mut remaining = content;
90
91        while let Some(start) = remaining.find("{{include:") {
92            result.push_str(&remaining[..start]);
93            let after_tag = &remaining[start + 10..];
94            if let Some(end) = after_tag.find("}}") {
95                let include_path = &after_tag[..end];
96                let full_path = self.base_dir.join(include_path);
97                // Canonicalize to prevent path traversal (../ and symlinks)
98                let canonical = std::fs::canonicalize(&full_path)
99                    .map_err(|e| PromptError::Io(full_path.clone(), e))?;
100                let canonical_base = std::fs::canonicalize(&self.base_dir)
101                    .map_err(|e| PromptError::Io(self.base_dir.clone(), e))?;
102                if !canonical.starts_with(&canonical_base) {
103                    return Err(PromptError::PathTraversal(include_path.to_string()));
104                }
105                let included = std::fs::read_to_string(&canonical)
106                    .map_err(|e| PromptError::Io(full_path, e))?;
107                let processed = self.process_includes(&included, depth + 1)?;
108                result.push_str(&processed);
109                remaining = &after_tag[end + 2..];
110            } else {
111                result.push_str("{{include:");
112                remaining = after_tag;
113            }
114        }
115        result.push_str(remaining);
116
117        Ok(result)
118    }
119}
120
/// Errors from prompt loading.
#[derive(Debug)]
pub enum PromptError {
    /// File I/O error while reading or canonicalizing a prompt or include file.
    ///
    /// Holds the path that was being accessed and the underlying I/O error.
    Io(PathBuf, std::io::Error),
    /// Too many nested includes (this is also how an include cycle terminates).
    MaxIncludeDepth,
    /// Include path escapes base directory.
    ///
    /// Holds the path exactly as written in the `{{include:...}}` directive.
    PathTraversal(String),
}

impl std::fmt::Display for PromptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(path, e) => write!(f, "Failed to load prompt '{}': {}", path.display(), e),
            Self::MaxIncludeDepth => write!(f, "Maximum include depth (5) exceeded"),
            Self::PathTraversal(path) => {
                write!(f, "Include path '{}' escapes base directory", path)
            }
        }
    }
}

impl std::error::Error for PromptError {
    /// Expose the wrapped [`std::io::Error`] for [`PromptError::Io`].
    ///
    /// This lets callers inspect it programmatically (e.g. `downcast_ref` and
    /// check `kind()`). The I/O message is also part of this error's `Display`
    /// text. The other variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(_, e) => Some(e),
            Self::MaxIncludeDepth | Self::PathTraversal(_) => None,
        }
    }
}
145
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Build a temp prompt directory with a few plain files and one include.
    fn setup_test_dir() -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        fs::write(
            dir.path().join("system.md"),
            "You are an agent.\n\nBe helpful.",
        )
        .unwrap();
        fs::write(dir.path().join("mode.md"), "Mode: execute").unwrap();

        fs::create_dir_all(dir.path().join("roles")).unwrap();
        fs::write(
            dir.path().join("roles/explorer.md"),
            "You are an explorer. Read-only.",
        )
        .unwrap();

        // Include test
        fs::write(
            dir.path().join("with_include.md"),
            "Header\n{{include:roles/explorer.md}}\nFooter",
        )
        .unwrap();

        dir
    }

    #[test]
    fn load_basic() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let content = loader.load("system.md").unwrap();
        assert!(content.contains("You are an agent"));
    }

    #[test]
    fn load_cached() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let _ = loader.load("system.md").unwrap();
        // Rewrite the file on disk: a cached load must still return the old text.
        fs::write(dir.path().join("system.md"), "Rewritten.").unwrap();
        let content = loader.load("system.md").unwrap();
        assert!(content.contains("You are an agent"));
    }

    #[test]
    fn load_merged() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let content = loader.load_merged(&["system.md", "mode.md"]).unwrap();
        assert_eq!(content, "You are an agent.\n\nBe helpful.\n\nMode: execute");
    }

    #[test]
    fn load_with_vars() {
        let dir = setup_test_dir();
        fs::write(
            dir.path().join("template.md"),
            "Hello {{name}}, you are {{role}}.",
        )
        .unwrap();
        let loader = PromptLoader::new(dir.path());
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), "Agent-1".to_string());
        vars.insert("role".to_string(), "explorer".to_string());
        let content = loader.load_with_vars("template.md", &vars).unwrap();
        assert_eq!(content, "Hello Agent-1, you are explorer.");
    }

    #[test]
    fn load_with_includes() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let content = loader.load("with_include.md").unwrap();
        assert_eq!(content, "Header\nYou are an explorer. Read-only.\nFooter");
    }

    #[test]
    fn unclosed_include_is_kept_verbatim() {
        let dir = setup_test_dir();
        fs::write(
            dir.path().join("broken.md"),
            "Broken {{include:roles/explorer.md",
        )
        .unwrap();
        let loader = PromptLoader::new(dir.path());
        assert_eq!(
            loader.load("broken.md").unwrap(),
            "Broken {{include:roles/explorer.md"
        );
    }

    #[test]
    fn load_missing_file() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        assert!(matches!(
            loader.load("nonexistent.md"),
            Err(PromptError::Io(..))
        ));
    }

    #[test]
    fn include_path_traversal_blocked() {
        let root = tempfile::tempdir().unwrap();
        let base = root.path().join("prompts");
        fs::create_dir_all(&base).unwrap();
        // A real file just outside base_dir. Canonicalization succeeds, so only
        // the traversal check itself (not a missing-file error) can reject it.
        fs::write(root.path().join("secret.md"), "TOP SECRET").unwrap();
        fs::write(
            base.join("evil.md"),
            "Before\n{{include:../secret.md}}\nAfter",
        )
        .unwrap();
        let loader = PromptLoader::new(base.as_path());
        match loader.load("evil.md") {
            Err(PromptError::PathTraversal(path)) => assert_eq!(path, "../secret.md"),
            other => panic!("expected PathTraversal, got {:?}", other),
        }
    }

    #[test]
    fn include_cycle_hits_max_depth() {
        let dir = setup_test_dir();
        fs::write(dir.path().join("loop.md"), "again: {{include:loop.md}}").unwrap();
        let loader = PromptLoader::new(dir.path());
        assert!(matches!(
            loader.load("loop.md"),
            Err(PromptError::MaxIncludeDepth)
        ));
    }

    #[test]
    fn clear_cache_works() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let _ = loader.load("system.md").unwrap();
        fs::write(dir.path().join("system.md"), "Rewritten.").unwrap();
        loader.clear_cache();
        // With the cache emptied, this must re-read the rewritten file.
        assert_eq!(loader.load("system.md").unwrap(), "Rewritten.");
    }
}