sgr_agent/
prompt_loader.rs1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::RwLock;
6
/// Loads prompt files from a base directory, expanding `{{include:...}}` directives and
/// caching the expanded text per requested path.
#[derive(Debug)]
pub struct PromptLoader {
    /// Directory that prompt paths and include targets are resolved against.
    base_dir: PathBuf,
    /// Expanded prompt text, keyed by the exact path string passed to `load`.
    cache: RwLock<HashMap<String, String>>,
}
12
13impl PromptLoader {
14 pub fn new(base_dir: impl Into<PathBuf>) -> Self {
16 Self {
17 base_dir: base_dir.into(),
18 cache: RwLock::new(HashMap::new()),
19 }
20 }
21
22 pub fn load(&self, relative_path: &str) -> Result<String, PromptError> {
25 if let Ok(cache) = self.cache.read()
27 && let Some(cached) = cache.get(relative_path)
28 {
29 return Ok(cached.clone());
30 }
31
32 let full_path = self.base_dir.join(relative_path);
33 let content = std::fs::read_to_string(&full_path)
34 .map_err(|e| PromptError::Io(full_path.clone(), e))?;
35
36 let processed = self.process_includes(&content, 0)?;
38
39 if let Ok(mut cache) = self.cache.write() {
40 cache.insert(relative_path.to_string(), processed.clone());
41 }
42
43 Ok(processed)
44 }
45
46 pub fn load_merged(&self, paths: &[&str]) -> Result<String, PromptError> {
48 let mut parts = Vec::new();
49 for path in paths {
50 parts.push(self.load(path)?);
51 }
52 Ok(parts.join("\n\n"))
53 }
54
55 pub fn load_with_vars(
58 &self,
59 path: &str,
60 vars: &HashMap<String, String>,
61 ) -> Result<String, PromptError> {
62 let mut content = self.load(path)?;
63 for (key, value) in vars {
64 content = content.replace(&format!("{{{{{}}}}}", key), value);
65 }
66 Ok(content)
67 }
68
69 pub fn clear_cache(&self) {
71 if let Ok(mut cache) = self.cache.write() {
72 cache.clear();
73 }
74 }
75
76 pub fn base_dir(&self) -> &Path {
78 &self.base_dir
79 }
80
81 fn process_includes(&self, content: &str, depth: usize) -> Result<String, PromptError> {
84 if depth > 5 {
85 return Err(PromptError::MaxIncludeDepth);
86 }
87
88 let mut result = String::with_capacity(content.len());
89 let mut remaining = content;
90
91 while let Some(start) = remaining.find("{{include:") {
92 result.push_str(&remaining[..start]);
93 let after_tag = &remaining[start + 10..];
94 if let Some(end) = after_tag.find("}}") {
95 let include_path = &after_tag[..end];
96 let full_path = self.base_dir.join(include_path);
97 let canonical = std::fs::canonicalize(&full_path)
99 .map_err(|e| PromptError::Io(full_path.clone(), e))?;
100 let canonical_base = std::fs::canonicalize(&self.base_dir)
101 .map_err(|e| PromptError::Io(self.base_dir.clone(), e))?;
102 if !canonical.starts_with(&canonical_base) {
103 return Err(PromptError::PathTraversal(include_path.to_string()));
104 }
105 let included = std::fs::read_to_string(&canonical)
106 .map_err(|e| PromptError::Io(full_path, e))?;
107 let processed = self.process_includes(&included, depth + 1)?;
108 result.push_str(&processed);
109 remaining = &after_tag[end + 2..];
110 } else {
111 result.push_str("{{include:");
112 remaining = after_tag;
113 }
114 }
115 result.push_str(remaining);
116
117 Ok(result)
118 }
119}
120
/// Failure modes reported by [`PromptLoader`].
#[derive(Debug)]
pub enum PromptError {
    /// Canonicalizing or reading a file failed.
    ///
    /// Carries the path that was being accessed and the underlying OS error.
    Io(PathBuf, std::io::Error),
    /// `{{include:...}}` directives were nested deeper than the loader permits.
    MaxIncludeDepth,
    /// An include target resolved outside the loader's base directory.
    ///
    /// Carries the target exactly as written in the directive.
    PathTraversal(String),
}
131
132impl std::fmt::Display for PromptError {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 match self {
135 Self::Io(path, e) => write!(f, "Failed to load prompt '{}': {}", path.display(), e),
136 Self::MaxIncludeDepth => write!(f, "Maximum include depth (5) exceeded"),
137 Self::PathTraversal(path) => {
138 write!(f, "Include path '{}' escapes base directory", path)
139 }
140 }
141 }
142}
143
144impl std::error::Error for PromptError {}
145
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a temp dir with a small prompt tree: two top-level prompts, a role prompt
    /// in a subdirectory, and a prompt that includes the role prompt.
    fn setup_test_dir() -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        fs::write(
            dir.path().join("system.md"),
            "You are an agent.\n\nBe helpful.",
        )
        .unwrap();
        fs::write(dir.path().join("mode.md"), "Mode: execute").unwrap();

        fs::create_dir_all(dir.path().join("roles")).unwrap();
        fs::write(
            dir.path().join("roles/explorer.md"),
            "You are an explorer. Read-only.",
        )
        .unwrap();

        fs::write(
            dir.path().join("with_include.md"),
            "Header\n{{include:roles/explorer.md}}\nFooter",
        )
        .unwrap();

        dir
    }

    #[test]
    fn load_basic() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let content = loader.load("system.md").unwrap();
        assert!(content.contains("You are an agent"));
    }

    #[test]
    fn load_cached() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let first = loader.load("system.md").unwrap();
        // Rewrite the file on disk: only a cache hit can still return the original text.
        fs::write(dir.path().join("system.md"), "rewritten").unwrap();
        let second = loader.load("system.md").unwrap();
        assert_eq!(first, second);
        assert!(second.contains("You are an agent"));
    }

    #[test]
    fn load_merged() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let content = loader.load_merged(&["system.md", "mode.md"]).unwrap();
        assert!(content.contains("You are an agent"));
        assert!(content.contains("Mode: execute"));
    }

    #[test]
    fn load_with_vars() {
        let dir = setup_test_dir();
        fs::write(
            dir.path().join("template.md"),
            "Hello {{name}}, you are {{role}}.",
        )
        .unwrap();
        let loader = PromptLoader::new(dir.path());
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), "Agent-1".to_string());
        vars.insert("role".to_string(), "explorer".to_string());
        let content = loader.load_with_vars("template.md", &vars).unwrap();
        assert_eq!(content, "Hello Agent-1, you are explorer.");
    }

    #[test]
    fn unknown_placeholders_are_left_intact() {
        let dir = setup_test_dir();
        fs::write(
            dir.path().join("partial.md"),
            "Hello {{name}} and {{unknown}}.",
        )
        .unwrap();
        let loader = PromptLoader::new(dir.path());
        let mut vars = HashMap::new();
        vars.insert("name".to_string(), "Agent-1".to_string());
        let content = loader.load_with_vars("partial.md", &vars).unwrap();
        assert_eq!(content, "Hello Agent-1 and {{unknown}}.");
    }

    #[test]
    fn load_with_includes() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let content = loader.load("with_include.md").unwrap();
        assert!(content.contains("Header"));
        assert!(content.contains("You are an explorer"));
        assert!(content.contains("Footer"));
    }

    #[test]
    fn unterminated_include_is_kept_literally() {
        let dir = setup_test_dir();
        fs::write(
            dir.path().join("broken.md"),
            "Start {{include:roles/explorer.md",
        )
        .unwrap();
        let loader = PromptLoader::new(dir.path());
        assert_eq!(
            loader.load("broken.md").unwrap(),
            "Start {{include:roles/explorer.md"
        );
    }

    #[test]
    fn self_include_hits_depth_limit() {
        let dir = setup_test_dir();
        fs::write(dir.path().join("loop.md"), "again {{include:loop.md}}").unwrap();
        let loader = PromptLoader::new(dir.path());
        let err = loader.load("loop.md").unwrap_err().to_string();
        assert!(err.contains("Maximum include depth"), "unexpected error: {}", err);
    }

    #[test]
    fn load_missing_file() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        assert!(loader.load("nonexistent.md").is_err());
    }

    #[test]
    fn include_path_traversal_blocked() {
        let dir = setup_test_dir();
        // Escaping include target. On systems without /etc/hostname this fails with an
        // I/O error before the traversal guard runs, so both messages are accepted here;
        // `include_of_existing_file_outside_base_is_rejected` pins the guard itself.
        fs::write(
            dir.path().join("evil.md"),
            "Before\n{{include:../../../etc/hostname}}\nAfter",
        )
        .unwrap();
        let loader = PromptLoader::new(dir.path());
        let result = loader.load("evil.md");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("escapes base directory") || err.contains("Failed to load"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn include_of_existing_file_outside_base_is_rejected() {
        let root = tempfile::tempdir().unwrap();
        let base = root.path().join("prompts");
        fs::create_dir_all(&base).unwrap();
        fs::write(root.path().join("secret.md"), "top secret").unwrap();
        fs::write(
            base.join("evil.md"),
            "Before\n{{include:../secret.md}}\nAfter",
        )
        .unwrap();
        let loader = PromptLoader::new(base.clone());
        let err = loader.load("evil.md").unwrap_err().to_string();
        assert!(
            err.contains("escapes base directory"),
            "unexpected error: {}",
            err
        );
        assert!(!err.contains("top secret"));
    }

    #[test]
    fn clear_cache_works() {
        let dir = setup_test_dir();
        let loader = PromptLoader::new(dir.path());
        let _ = loader.load("system.md").unwrap();
        fs::write(dir.path().join("system.md"), "rewritten").unwrap();
        loader.clear_cache();
        // After clearing, the next load must pick up the new file contents.
        assert_eq!(loader.load("system.md").unwrap(), "rewritten");
    }
}