1use std::collections::HashMap;
18use std::path::{Path, PathBuf};
19use std::sync::Mutex;
20
21use serde::Deserialize;
22
23use crate::allowlist::{Allowlist, Rule, RuleError};
24
/// File name looked for in the home directory and in `cwd` and each of its ancestors.
const CONFIG_FILENAME: &str = ".shell-mcp.toml";
26
27#[derive(Debug, Clone, Default, Deserialize)]
29#[serde(deny_unknown_fields)]
30pub struct Config {
31 #[serde(default)]
33 pub allow: Vec<String>,
34
35 #[serde(default = "default_true")]
39 pub include_defaults: bool,
40}
41
/// Serde fallback for `Config::include_defaults`: platform defaults stay enabled unless a
/// config file explicitly opts out.
fn default_true() -> bool {
    true
}
45
46impl Config {
47 pub fn parse(text: &str) -> Result<Self, ConfigError> {
48 toml::from_str(text).map_err(|e| ConfigError::Parse {
49 reason: e.to_string(),
50 })
51 }
52
53 pub fn load(path: &Path) -> Result<Self, ConfigError> {
54 let text = std::fs::read_to_string(path).map_err(|e| ConfigError::Io {
55 path: path.to_path_buf(),
56 source: e,
57 })?;
58 Self::parse(&text).map_err(|e| match e {
59 ConfigError::Parse { reason } => ConfigError::ParseAt {
60 path: path.to_path_buf(),
61 reason,
62 },
63 other => other,
64 })
65 }
66}
67
/// The merged result produced by [`resolve`]: the effective allowlist plus where it came from.
#[derive(Debug, Clone)]
pub struct LoadedConfig {
    /// Platform defaults (when enabled) followed by every `allow` rule from `sources`, added
    /// in that order.
    pub allowlist: Allowlist,
    /// Config files that were read: the home-directory file first (if present), then the
    /// files found walking up from `cwd`, outermost first.
    pub sources: Vec<PathBuf>,
    /// Whether platform defaults were added; taken from the last entry of `sources`, or
    /// `true` when no config file was found.
    pub defaults_included: bool,
    /// The `root` argument given to [`resolve`], stored as-is.
    pub root: PathBuf,
    /// The `cwd` argument given to [`resolve`]; config discovery starts here.
    pub cwd: PathBuf,
}
84
85fn discover_config_files(cwd: &Path) -> Vec<PathBuf> {
89 let mut found_in_tree: Vec<PathBuf> = Vec::new();
90 let mut cursor = Some(cwd.to_path_buf());
91 while let Some(dir) = cursor {
92 let candidate = dir.join(CONFIG_FILENAME);
93 if candidate.is_file() {
94 found_in_tree.push(candidate);
95 }
96 cursor = dir.parent().map(|p| p.to_path_buf());
97 }
98 found_in_tree.reverse();
100
101 let mut all = Vec::new();
102 if let Some(home) = dirs::home_dir() {
103 let global = home.join(CONFIG_FILENAME);
104 if global.is_file() {
105 all.push(global);
106 }
107 }
108 all.extend(found_in_tree);
109 all
110}
111
112pub fn resolve(root: &Path, cwd: &Path) -> Result<LoadedConfig, ConfigError> {
118 let sources = discover_config_files(cwd);
119 let mut allowlist = Allowlist::new();
120 let mut include_defaults = true;
121 let mut configs: Vec<(PathBuf, Config)> = Vec::with_capacity(sources.len());
122 for path in &sources {
123 let cfg = Config::load(path)?;
124 configs.push((path.clone(), cfg));
125 }
126 if let Some((_, last)) = configs.last() {
128 include_defaults = last.include_defaults;
129 }
130 if include_defaults {
131 allowlist.extend(crate::allowlist::platform_defaults());
132 }
133 for (path, cfg) in &configs {
134 for raw in &cfg.allow {
135 let rule = Rule::parse(raw.clone(), path.display().to_string()).map_err(|e| {
136 ConfigError::Rule {
137 path: path.clone(),
138 source: e,
139 }
140 })?;
141 allowlist.push(rule);
142 }
143 }
144 Ok(LoadedConfig {
145 allowlist,
146 sources,
147 defaults_included: include_defaults,
148 root: root.to_path_buf(),
149 cwd: cwd.to_path_buf(),
150 })
151}
152
153#[derive(Default)]
155pub struct ConfigCache {
156 inner: Mutex<HashMap<(PathBuf, PathBuf), LoadedConfig>>,
157}
158
159impl ConfigCache {
160 pub fn new() -> Self {
161 Self::default()
162 }
163
164 pub fn get_or_load(&self, root: &Path, cwd: &Path) -> Result<LoadedConfig, ConfigError> {
165 let key = (root.to_path_buf(), cwd.to_path_buf());
166 {
167 let guard = self.inner.lock().expect("config cache poisoned");
168 if let Some(hit) = guard.get(&key) {
169 return Ok(hit.clone());
170 }
171 }
172 let loaded = resolve(root, cwd)?;
173 let mut guard = self.inner.lock().expect("config cache poisoned");
174 guard.insert(key, loaded.clone());
175 Ok(loaded)
176 }
177
178 pub fn clear(&self) {
179 self.inner.lock().expect("config cache poisoned").clear();
180 }
181}
182
/// Errors produced while loading config files and building the allowlist from them.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The config file at `path` could not be read.
    // NOTE(review): the message embeds `{source}` *and* `#[source]` exposes the same error,
    // so reporters that print the whole cause chain show the I/O error twice. Consider
    // dropping `: {source}` from this message (and from `Rule`) after checking whether any
    // caller or test matches on the full text.
    #[error("could not read config at {path}: {source}")]
    Io {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },

    /// Config text failed to parse. [`Config::parse`] returns this; it has no file path.
    #[error("could not parse config: {reason}")]
    Parse { reason: String },

    /// A parse failure re-wrapped by [`Config::load`] with the path of the offending file.
    #[error("could not parse config at {path}: {reason}")]
    ParseAt { path: PathBuf, reason: String },

    /// `Rule::parse` rejected an `allow` entry from the file at `path`.
    #[error("invalid rule in {path}: {source}")]
    Rule {
        path: PathBuf,
        #[source]
        source: RuleError,
    },
}
205
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    fn write(path: &Path, contents: &str) {
        std::fs::write(path, contents).unwrap();
    }

    fn tokens(s: &str) -> Vec<String> {
        shlex::split(s).unwrap()
    }

    #[test]
    fn empty_config_yields_defaults_only() {
        let dir = tempdir().unwrap();
        let loaded = resolve(dir.path(), dir.path()).unwrap();
        // A global ~/.shell-mcp.toml is always picked up. So this test can only check that
        // nothing inside the empty temp dir was loaded. `defaults_included` also assumes any
        // global config does not set `include_defaults = false`.
        assert!(loaded.defaults_included);
        assert!(loaded.sources.iter().all(|p| !p.starts_with(dir.path())));
    }

    #[test]
    fn walks_up_and_merges_inner_over_outer() {
        let outer = tempdir().unwrap();
        let inner = outer.path().join("project").join("sub");
        std::fs::create_dir_all(&inner).unwrap();

        write(
            &outer.path().join(".shell-mcp.toml"),
            r#"allow = ["outer-cmd **"]"#,
        );
        write(
            &outer.path().join("project").join(".shell-mcp.toml"),
            r#"allow = ["mid-cmd **"]"#,
        );
        write(
            &inner.join(".shell-mcp.toml"),
            r#"allow = ["inner-cmd **"]"#,
        );

        let loaded = resolve(outer.path(), &inner).unwrap();

        let in_tree: Vec<_> = loaded
            .sources
            .iter()
            .filter(|p| p.starts_with(outer.path()))
            .collect();
        assert_eq!(in_tree.len(), 3);
        // Sources are ordered outermost -> innermost.
        assert_eq!(in_tree[0].parent().unwrap(), outer.path());
        assert_eq!(
            in_tree[1].parent().unwrap(),
            outer.path().join("project")
        );
        assert_eq!(in_tree[2].parent().unwrap(), inner);

        assert!(loaded
            .allowlist
            .find_match(&tokens("outer-cmd a"))
            .is_some());
        assert!(loaded.allowlist.find_match(&tokens("mid-cmd a")).is_some());
        assert!(loaded
            .allowlist
            .find_match(&tokens("inner-cmd a"))
            .is_some());
    }

    #[test]
    fn include_defaults_false_disables_platform_defaults() {
        let dir = tempdir().unwrap();
        write(
            &dir.path().join(".shell-mcp.toml"),
            r#"
include_defaults = false
allow = ["only-this"]
"#,
        );
        let loaded = resolve(dir.path(), dir.path()).unwrap();
        assert!(!loaded.defaults_included);
        assert!(loaded.allowlist.find_match(&tokens("only-this")).is_some());
        assert!(loaded.allowlist.find_match(&tokens("pwd")).is_none());
    }

    #[test]
    fn parse_of_empty_text_uses_field_defaults() {
        let cfg = Config::parse("").unwrap();
        assert!(cfg.allow.is_empty());
        assert!(cfg.include_defaults);
    }

    #[test]
    fn unknown_key_is_rejected_with_path() {
        let dir = tempdir().unwrap();
        let file = dir.path().join(".shell-mcp.toml");
        write(&file, "not_a_real_key = 1");
        match Config::load(&file) {
            Err(ConfigError::ParseAt { path, .. }) => assert_eq!(path, file),
            other => panic!("expected ParseAt, got {other:?}"),
        }
    }

    #[test]
    fn missing_file_is_io_error() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("does-not-exist.toml");
        match Config::load(&file) {
            Err(ConfigError::Io { path, .. }) => assert_eq!(path, file),
            other => panic!("expected Io, got {other:?}"),
        }
    }

    #[test]
    fn cache_returns_stable_result() {
        let dir = tempdir().unwrap();
        let cache = ConfigCache::new();
        let a = cache.get_or_load(dir.path(), dir.path()).unwrap();
        let b = cache.get_or_load(dir.path(), dir.path()).unwrap();
        assert_eq!(a.sources, b.sources);
    }
}