1use serde::Deserialize;
2use std::collections::HashMap;
3use std::path::{Path, PathBuf};
4
5use crate::error::{Result, UbtError};
6
/// Every built-in command name.
///
/// Grouped commands use dotted names (`group.subcommand`); their group
/// prefixes are listed in [`BUILTIN_GROUPS`]. User-defined aliases may not
/// reuse any of these names (see [`validate_aliases`]).
pub const BUILTIN_COMMANDS: &[&str] = &[
    // `dep` group.
    "dep.install", "dep.remove", "dep.update", "dep.outdated",
    "dep.list", "dep.audit", "dep.lock", "dep.why",
    // Ungrouped commands.
    "build", "start", "run", "fmt", "run-file", "exec",
    "test", "lint", "check",
    // `db` group.
    "db.migrate", "db.rollback", "db.seed", "db.create",
    "db.drop", "db.reset", "db.status",
    // Ungrouped commands.
    "init", "clean", "release", "publish",
    // `tool` group.
    "tool.info", "tool.doctor", "tool.list", "tool.docs",
    // `config` group, followed by the remaining ungrouped commands.
    "config.show", "info", "completions",
];
46
/// Names of the built-in command groups.
///
/// Each group's subcommands appear in [`BUILTIN_COMMANDS`] as
/// `group.subcommand` (e.g. `dep.install`, `config.show`). Aliases may not
/// use a group name either (see [`validate_aliases`]).
pub const BUILTIN_GROUPS: &[&str] = &[
    "dep",
    "db",
    "tool",
    "config",
];
48
49#[derive(Debug, Clone, Deserialize, Default)]
52pub struct ProjectConfig {
53 pub tool: Option<String>,
54}
55
56#[derive(Debug, Clone, Deserialize, Default)]
57pub struct UbtConfig {
58 #[serde(default)]
59 pub project: Option<ProjectConfig>,
60 #[serde(default)]
61 pub commands: HashMap<String, String>,
62 #[serde(default)]
63 pub aliases: HashMap<String, String>,
64}
65
66pub fn parse_config(content: &str) -> Result<UbtConfig> {
70 toml::from_str(content).map_err(|e| {
71 let line = e.span().map(|s| {
72 content
73 .bytes()
74 .take(s.start)
75 .filter(|&b| b == b'\n')
76 .count()
77 + 1
78 });
79 UbtError::config_error(line, e.message())
80 })
81}
82
83pub fn validate_aliases(config: &UbtConfig) -> Result<()> {
87 for alias in config.aliases.keys() {
88 if let Some(&cmd) = BUILTIN_COMMANDS.iter().find(|&&c| c == alias.as_str()) {
89 return Err(UbtError::AliasConflict {
90 alias: alias.clone(),
91 command: cmd.to_string(),
92 });
93 }
94 if let Some(&group) = BUILTIN_GROUPS.iter().find(|&&g| g == alias.as_str()) {
95 return Err(UbtError::AliasConflict {
96 alias: alias.clone(),
97 command: group.to_string(),
98 });
99 }
100 }
101 Ok(())
102}
103
104pub fn find_config(start_dir: &Path) -> Result<Option<(UbtConfig, PathBuf)>> {
112 if let Ok(config_path) = std::env::var("UBT_CONFIG") {
114 let path = PathBuf::from(&config_path);
115 let content = std::fs::read_to_string(&path)?;
116 let config = parse_config(&content)?;
117 let project_root = path.parent().unwrap_or(Path::new(".")).to_path_buf();
118 return Ok(Some((config, project_root)));
119 }
120
121 let mut current = start_dir.to_path_buf();
123 loop {
124 let candidate = current.join("ubt.toml");
125 if candidate.is_file() {
126 let content = std::fs::read_to_string(&candidate)?;
127 let config = parse_config(&content)?;
128 return Ok(Some((config, current)));
129 }
130 if !current.pop() {
131 break;
132 }
133 }
134 Ok(None)
135}
136
137pub fn load_config(start_dir: &Path) -> Result<Option<(UbtConfig, PathBuf)>> {
139 match find_config(start_dir)? {
140 Some((config, root)) => {
141 validate_aliases(&config)?;
142 Ok(Some((config, root)))
143 }
144 None => Ok(None),
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Serializes the tests that read or modify the process-wide `UBT_CONFIG`.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Holds `ENV_MUTEX` and restores `UBT_CONFIG` to its original value on drop.
    ///
    /// Because restoration happens in `Drop`, the variable is put back even
    /// when an `unwrap` or assertion in the test panics. One failing test
    /// therefore cannot leak a modified environment into the others.
    struct EnvGuard {
        prev: Option<OsString>,
        // Fields are dropped after `Drop::drop` returns, so the lock is still
        // held while the variable is being restored.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Locks `ENV_MUTEX` and snapshots the current `UBT_CONFIG`.
        ///
        /// A poisoned mutex, left by an earlier panicking test, is recovered
        /// rather than unwrapped, so that a single failure does not cascade
        /// into every other env test.
        fn new() -> Self {
            let lock = ENV_MUTEX
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let prev = std::env::var_os("UBT_CONFIG");
            Self { prev, _lock: lock }
        }

        fn set(&self, value: impl AsRef<OsStr>) {
            // SAFETY: tests that touch `UBT_CONFIG` serialize on `ENV_MUTEX`,
            // which this guard holds. This assumes nothing else in the test
            // binary reads the environment concurrently.
            unsafe { std::env::set_var("UBT_CONFIG", value) };
        }

        fn unset(&self) {
            // SAFETY: same invariant as `set`; `ENV_MUTEX` is held.
            unsafe { std::env::remove_var("UBT_CONFIG") };
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_MUTEX` is still held here; `_lock` is released only
            // after this method returns.
            match self.prev.take() {
                Some(value) => unsafe { std::env::set_var("UBT_CONFIG", value) },
                None => unsafe { std::env::remove_var("UBT_CONFIG") },
            }
        }
    }

    /// Builds a config whose only content is a single alias named `name`.
    fn config_with_alias(name: &str) -> UbtConfig {
        UbtConfig {
            aliases: HashMap::from([(name.to_string(), "something".to_string())]),
            ..Default::default()
        }
    }

    #[test]
    fn parse_rails_example() {
        let input = r#"
[project]
tool = "bundler"

[commands]
start = "bin/rails server"
test = "bin/rails test"
lint = "bundle exec rubocop"
fmt = "bundle exec rubocop -a"
"db.migrate" = "bin/rails db:migrate"
"db.rollback" = "bin/rails db:rollback STEP={{args}}"
"db.seed" = "bin/rails db:seed"
"db.create" = "bin/rails db:create"
"db.drop" = "bin/rails db:drop"
"db.reset" = "bin/rails db:reset"
"db.status" = "bin/rails db:migrate:status"
run = "bin/rails {{args}}"

[aliases]
console = "bin/rails console"
routes = "bin/rails routes"
generate = "bin/rails generate"
"#;
        let config = parse_config(input).unwrap();
        assert_eq!(config.project.unwrap().tool.unwrap(), "bundler");
        assert_eq!(config.commands.len(), 12);
        assert_eq!(config.aliases.len(), 3);
    }

    #[test]
    fn parse_node_prisma_example() {
        let input = r#"
[project]
tool = "pnpm"

[commands]
start = "pnpm run dev"
build = "pnpm run build"
test = "pnpm exec vitest"
lint = "pnpm exec eslint ."
fmt = "pnpm exec prettier --write ."
"fmt.check" = "pnpm exec prettier --check ."
"db.migrate" = "pnpm exec prisma migrate deploy"
"db.seed" = "pnpm exec prisma db seed"
"db.status" = "pnpm exec prisma migrate status"
"db.reset" = "pnpm exec prisma migrate reset"

[aliases]
studio = "pnpm exec prisma studio"
generate = "pnpm exec prisma generate"
typecheck = "pnpm exec tsc --noEmit"
"#;
        let config = parse_config(input).unwrap();
        assert_eq!(config.project.unwrap().tool.unwrap(), "pnpm");
        assert_eq!(config.commands.len(), 10);
        assert_eq!(config.aliases.len(), 3);
    }

    #[test]
    fn parse_minimal_config() {
        let input = "[project]\ntool = \"go\"";
        let config = parse_config(input).unwrap();
        assert_eq!(config.project.unwrap().tool.unwrap(), "go");
        assert_eq!(config.commands.len(), 0);
        assert_eq!(config.aliases.len(), 0);
    }

    #[test]
    fn parse_empty_config() {
        let config = parse_config("").unwrap();
        assert!(config.project.is_none());
        assert_eq!(config.commands.len(), 0);
        assert_eq!(config.aliases.len(), 0);
    }

    #[test]
    fn parse_invalid_toml_returns_config_error() {
        let result = parse_config("[invalid");
        assert!(matches!(result, Err(UbtError::ConfigError { .. })));
    }

    #[test]
    fn validate_alias_conflicting_with_command() {
        let err = validate_aliases(&config_with_alias("test")).unwrap_err();
        match err {
            UbtError::AliasConflict { alias, command } => {
                assert_eq!(alias, "test");
                assert_eq!(command, "test");
            }
            other => panic!("expected AliasConflict, got: {other:?}"),
        }
    }

    #[test]
    fn validate_alias_conflicting_with_group() {
        let err = validate_aliases(&config_with_alias("dep")).unwrap_err();
        match err {
            UbtError::AliasConflict { alias, command } => {
                assert_eq!(alias, "dep");
                assert_eq!(command, "dep");
            }
            other => panic!("expected AliasConflict, got: {other:?}"),
        }
    }

    #[test]
    fn find_config_walks_upward() {
        let env = EnvGuard::new();
        env.unset();

        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("ubt.toml"), "[project]\ntool = \"go\"").unwrap();
        let nested = dir.path().join("a").join("b").join("c");
        std::fs::create_dir_all(&nested).unwrap();

        let (config, root) = find_config(&nested).unwrap().unwrap();
        assert_eq!(config.project.unwrap().tool.unwrap(), "go");
        assert_eq!(root, dir.path());
    }

    #[test]
    fn find_config_returns_none_when_absent() {
        let env = EnvGuard::new();
        env.unset();

        // NOTE(review): this walks up past the temp dir, so it assumes no
        // ancestor of the system temp directory contains a `ubt.toml`.
        let dir = TempDir::new().unwrap();
        let result = find_config(dir.path()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn find_config_respects_ubt_config_env() {
        let env = EnvGuard::new();

        let dir = TempDir::new().unwrap();
        let config_path = dir.path().join("custom.toml");
        std::fs::write(&config_path, "[project]\ntool = \"custom\"").unwrap();
        env.set(&config_path);

        // Search from an unrelated directory: the env var must take precedence
        // over the upward walk.
        let elsewhere = TempDir::new().unwrap();
        let (config, root) = find_config(elsewhere.path()).unwrap().unwrap();
        assert_eq!(config.project.unwrap().tool.unwrap(), "custom");
        assert_eq!(root, dir.path());
    }
}