1use anyhow::{bail, Context, Result};
2use serde::Deserialize;
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::process::Command;
6
7use crate::git;
8
9#[derive(Debug, Deserialize, Default)]
10struct ProjectConfigFile {
11 init: Option<InitConfig>,
12}
13
14#[derive(Debug, Deserialize, Default)]
15struct InitConfig {
16 include: Option<Vec<String>>,
17 commands: Option<Vec<String>>,
18}
19
/// Resolved `[init]` settings, with absent keys replaced by empty lists.
#[derive(Default, Debug)]
pub struct ProjectInit {
    /// Include entries copied from the repository root into the worktree.
    pub include: Vec<String>,
    /// Commands executed with `sh -c` in the worktree, stopping at the first failure.
    pub commands: Vec<String>,
}
28
29pub fn load_project_config(repo_root: &Path) -> Result<ProjectInit> {
31 let primary = repo_root.join(".worktree-init.toml");
32 let fallback = repo_root.join(".yawn.toml");
33
34 let config_path = if primary.exists() {
35 primary
36 } else if fallback.exists() {
37 eprintln!("note: using .yawn.toml — consider renaming to .worktree-init.toml");
38 fallback
39 } else {
40 return Ok(ProjectInit::default());
41 };
42
43 let content = fs::read_to_string(&config_path)
44 .with_context(|| format!("failed to read {}", config_path.display()))?;
45 let file: ProjectConfigFile = toml::from_str(&content)
46 .with_context(|| format!("failed to parse {}", config_path.display()))?;
47
48 let init = file.init.unwrap_or_default();
49 Ok(ProjectInit {
50 include: init.include.unwrap_or_default(),
51 commands: init.commands.unwrap_or_default(),
52 })
53}
54
55fn expand_pattern(base: &Path, pattern: &str) -> Vec<PathBuf> {
58 if pattern.contains('*') || pattern.contains('?') || pattern.contains('[') {
59 let glob = match globset::Glob::new(pattern) {
60 Ok(g) => g.compile_matcher(),
61 Err(_) => return Vec::new(),
62 };
63 collect_files(base, base, &glob).unwrap_or_default()
64 } else {
65 let path = base.join(pattern);
66 if path.is_dir() {
67 collect_dir_files(base, &path).unwrap_or_default()
68 } else if path.exists() {
69 vec![PathBuf::from(pattern)]
70 } else {
71 Vec::new()
72 }
73 }
74}
75
76fn collect_dir_files(base: &Path, dir: &Path) -> Result<Vec<PathBuf>> {
78 let mut results = Vec::new();
79 let entries = match fs::read_dir(dir) {
80 Ok(entries) => entries,
81 Err(_) => return Ok(results),
82 };
83 for entry in entries {
84 let entry = entry?;
85 let path = entry.path();
86 if path.is_dir() {
87 results.extend(collect_dir_files(base, &path)?);
88 } else {
89 let rel = path.strip_prefix(base).unwrap_or(&path);
90 results.push(rel.to_path_buf());
91 }
92 }
93 Ok(results)
94}
95
96fn collect_files(base: &Path, dir: &Path, glob: &globset::GlobMatcher) -> Result<Vec<PathBuf>> {
98 let mut results = Vec::new();
99 let entries = match fs::read_dir(dir) {
100 Ok(entries) => entries,
101 Err(_) => return Ok(results),
102 };
103 for entry in entries {
104 let entry = entry?;
105 let path = entry.path();
106 let rel = path.strip_prefix(base).unwrap_or(&path);
107 if path.is_dir() {
108 results.extend(collect_files(base, &path, glob)?);
109 } else if glob.is_match(rel) {
110 results.push(rel.to_path_buf());
111 }
112 }
113 Ok(results)
114}
115
116fn copy_include_files(source: &Path, target: &Path, patterns: &[String]) -> Result<()> {
118 for pattern in patterns {
119 let files = expand_pattern(source, pattern);
120 for file in &files {
121 let src = source.join(file);
122 let dst = target.join(file);
123 if let Some(parent) = dst.parent() {
124 fs::create_dir_all(parent)?;
125 }
126 fs::copy(&src, &dst).with_context(|| format!("failed to copy {}", file.display()))?;
127 eprintln!("copied {}", file.display());
128 }
129 }
130 Ok(())
131}
132
133fn run_commands(target: &Path, commands: &[String]) -> Result<()> {
135 for cmd in commands {
136 eprintln!("running: {}", cmd);
137 let status = Command::new("sh")
138 .arg("-c")
139 .arg(cmd)
140 .current_dir(target)
141 .status()
142 .with_context(|| format!("failed to run: {}", cmd))?;
143 if !status.success() {
144 bail!(
145 "command failed (exit {}): {}",
146 status.code().unwrap_or(-1),
147 cmd
148 );
149 }
150 }
151 Ok(())
152}
153
154pub fn run(target: &Path) -> Result<()> {
160 let toplevel = git::toplevel(target).context("not inside a git repository")?;
161 let repo_root = git::repo_root(target).context("not inside a git repository")?;
162 let config = load_project_config(&repo_root)?;
163
164 if config.include.is_empty() && config.commands.is_empty() {
165 eprintln!("nothing to do: no [init] config in .worktree-init.toml");
166 return Ok(());
167 }
168
169 if !config.include.is_empty() {
171 let toplevel_canonical = toplevel.canonicalize().unwrap_or_else(|_| toplevel.clone());
172 let root_canonical = repo_root
173 .canonicalize()
174 .unwrap_or_else(|_| repo_root.clone());
175 if toplevel_canonical != root_canonical {
176 copy_include_files(&repo_root, &toplevel, &config.include)?;
177 }
178 }
179
180 if !config.commands.is_empty() {
181 run_commands(&toplevel, &config.commands)?;
182 }
183
184 Ok(())
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::Command as StdCommand;
    use tempfile::TempDir;

    /// Runs `git` with `args` in `dir` and returns its trimmed stdout.
    ///
    /// Panics with git's stderr if git cannot be spawned or exits unsuccessfully,
    /// so a broken fixture fails at the step that broke instead of in a later,
    /// unrelated assertion.
    fn git_cmd(dir: &Path, args: &[&str]) -> String {
        let output = StdCommand::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .expect("failed to spawn git");
        assert!(
            output.status.success(),
            "git {:?} failed in {}: {}",
            args,
            dir.display(),
            String::from_utf8_lossy(&output.stderr)
        );
        String::from_utf8_lossy(&output.stdout).trim().to_string()
    }

    /// Creates a git repository at `path` containing one commit.
    fn init_repo(path: &Path) {
        fs::create_dir_all(path).unwrap();
        git_cmd(path, &["init"]);
        git_cmd(path, &["config", "user.email", "test@test.com"]);
        git_cmd(path, &["config", "user.name", "Test"]);
        fs::write(path.join("README.md"), "# test").unwrap();
        git_cmd(path, &["add", "."]);
        git_cmd(path, &["commit", "-m", "init"]);
        // Best-effort rename (result ignored); no test relies on the branch name.
        let _ = StdCommand::new("git")
            .args(["branch", "-M", "main"])
            .current_dir(path)
            .output();
    }

    /// Creates empty `source` and `target` directories inside `tmp`.
    fn source_and_target(tmp: &TempDir) -> (PathBuf, PathBuf) {
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();
        (source, target)
    }

    /// Writes `contents` to `.worktree-init.toml` in `dir`.
    fn write_config(dir: &Path, contents: &str) {
        fs::write(dir.join(".worktree-init.toml"), contents).unwrap();
    }

    // --- load_project_config ---

    #[test]
    fn test_load_no_config_file() {
        let tmp = TempDir::new().unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_empty_config() {
        let tmp = TempDir::new().unwrap();
        write_config(tmp.path(), "");
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_include_only() {
        let tmp = TempDir::new().unwrap();
        write_config(tmp.path(), "[init]\ninclude = [\".env\", \"config/*.toml\"]\n");
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env", "config/*.toml"]);
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_commands_only() {
        let tmp = TempDir::new().unwrap();
        write_config(tmp.path(), "[init]\ncommands = [\"npm install\"]\n");
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert_eq!(config.commands, vec!["npm install"]);
    }

    #[test]
    fn test_load_full_config() {
        let tmp = TempDir::new().unwrap();
        write_config(
            tmp.path(),
            "[init]\ninclude = [\".env\"]\ncommands = [\"npm install\", \"cargo build\"]\n",
        );
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env"]);
        assert_eq!(config.commands, vec!["npm install", "cargo build"]);
    }

    #[test]
    fn test_load_invalid_toml() {
        let tmp = TempDir::new().unwrap();
        write_config(tmp.path(), "{{invalid");
        assert!(load_project_config(tmp.path()).is_err());
    }

    #[test]
    fn test_load_fallback_to_yawn_toml() {
        let tmp = TempDir::new().unwrap();
        fs::write(tmp.path().join(".yawn.toml"), "[init]\ninclude = [\".env\"]\n").unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env"]);
    }

    #[test]
    fn test_load_prefers_worktree_init_toml() {
        let tmp = TempDir::new().unwrap();
        write_config(tmp.path(), "[init]\ninclude = [\"from-new\"]\n");
        fs::write(tmp.path().join(".yawn.toml"), "[init]\ninclude = [\"from-old\"]\n").unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec!["from-new"]);
    }

    // --- copy_include_files ---

    #[test]
    fn test_copy_include_literal_files() {
        let tmp = TempDir::new().unwrap();
        let (source, target) = source_and_target(&tmp);
        fs::write(source.join(".env"), "SECRET=123").unwrap();
        fs::create_dir_all(source.join("config")).unwrap();
        fs::write(source.join("config/local.toml"), "[db]\nhost=localhost").unwrap();

        copy_include_files(
            &source,
            &target,
            &[".env".into(), "config/local.toml".into()],
        )
        .unwrap();

        assert_eq!(fs::read_to_string(target.join(".env")).unwrap(), "SECRET=123");
        assert_eq!(
            fs::read_to_string(target.join("config/local.toml")).unwrap(),
            "[db]\nhost=localhost"
        );
    }

    #[test]
    fn test_copy_include_glob_pattern() {
        let tmp = TempDir::new().unwrap();
        let (source, target) = source_and_target(&tmp);
        fs::write(source.join("data_users.csv"), "id,name").unwrap();
        fs::write(source.join("data_orders.csv"), "id,total").unwrap();
        fs::write(source.join("other.csv"), "should not copy").unwrap();

        copy_include_files(&source, &target, &["data_*.csv".into()]).unwrap();

        assert!(target.join("data_users.csv").exists());
        assert!(target.join("data_orders.csv").exists());
        assert!(!target.join("other.csv").exists());
    }

    #[test]
    fn test_copy_include_glob_in_subdir() {
        let tmp = TempDir::new().unwrap();
        let (source, target) = source_and_target(&tmp);
        fs::create_dir_all(source.join("config")).unwrap();
        fs::write(source.join("config/dev.toml"), "dev").unwrap();
        fs::write(source.join("config/test.toml"), "test").unwrap();
        fs::write(source.join("config/keep.json"), "not matched").unwrap();

        copy_include_files(&source, &target, &["config/*.toml".into()]).unwrap();

        assert!(target.join("config/dev.toml").exists());
        assert!(target.join("config/test.toml").exists());
        assert!(!target.join("config/keep.json").exists());
    }

    #[test]
    fn test_copy_include_directory() {
        let tmp = TempDir::new().unwrap();
        let (source, target) = source_and_target(&tmp);
        fs::create_dir_all(source.join(".cache/sub")).unwrap();
        fs::write(source.join(".cache/a.txt"), "aaa").unwrap();
        fs::write(source.join(".cache/sub/b.txt"), "bbb").unwrap();

        copy_include_files(&source, &target, &[".cache".into()]).unwrap();

        assert_eq!(fs::read_to_string(target.join(".cache/a.txt")).unwrap(), "aaa");
        assert_eq!(fs::read_to_string(target.join(".cache/sub/b.txt")).unwrap(), "bbb");
    }

    #[test]
    fn test_copy_include_missing_source() {
        let tmp = TempDir::new().unwrap();
        let (source, target) = source_and_target(&tmp);
        fs::write(source.join(".env"), "SECRET=123").unwrap();

        copy_include_files(&source, &target, &[".env".into(), "missing-file".into()]).unwrap();

        assert!(target.join(".env").exists());
        assert!(!target.join("missing-file").exists());
    }

    #[test]
    fn test_copy_include_no_glob_matches() {
        let tmp = TempDir::new().unwrap();
        let (source, target) = source_and_target(&tmp);

        copy_include_files(&source, &target, &["*.xyz".into()]).unwrap();

        assert_eq!(fs::read_dir(&target).unwrap().count(), 0);
    }

    // --- run_commands ---

    #[test]
    fn test_run_commands_success() {
        let tmp = TempDir::new().unwrap();
        run_commands(tmp.path(), &["echo hello > out.txt".into()]).unwrap();
        assert_eq!(
            fs::read_to_string(tmp.path().join("out.txt")).unwrap().trim(),
            "hello"
        );
    }

    #[test]
    fn test_run_commands_failure() {
        let tmp = TempDir::new().unwrap();
        let result = run_commands(tmp.path(), &["false".into()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("command failed"));
    }

    #[test]
    fn test_run_commands_sequential() {
        let tmp = TempDir::new().unwrap();
        run_commands(
            tmp.path(),
            &[
                "echo first > first.txt".into(),
                "echo second > second.txt".into(),
            ],
        )
        .unwrap();
        assert!(tmp.path().join("first.txt").exists());
        assert!(tmp.path().join("second.txt").exists());
    }

    #[test]
    fn test_run_commands_stops_on_failure() {
        let tmp = TempDir::new().unwrap();
        let result = run_commands(
            tmp.path(),
            &[
                "echo first > first.txt".into(),
                "false".into(),
                "echo third > third.txt".into(),
            ],
        );
        assert!(result.is_err());
        assert!(tmp.path().join("first.txt").exists());
        assert!(!tmp.path().join("third.txt").exists());
    }

    // --- run ---

    #[test]
    fn test_run_on_worktree_copies_files_and_runs_commands() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        write_config(
            &repo,
            "[init]\ninclude = [\".env\"]\ncommands = [\"echo done > .init_marker\"]\n",
        );
        git_cmd(&repo, &["add", "."]);
        git_cmd(&repo, &["commit", "-m", "add config"]);
        // Left untracked, so the worktree checkout cannot provide it; only the copy can.
        fs::write(repo.join(".env"), "DB_HOST=localhost").unwrap();

        let wt_path = tmp.path().join("myproject--feature");
        git_cmd(
            &repo,
            &[
                "worktree",
                "add",
                "-b",
                "feature",
                &wt_path.to_string_lossy(),
            ],
        );

        run(&wt_path).unwrap();

        assert_eq!(
            fs::read_to_string(wt_path.join(".env")).unwrap(),
            "DB_HOST=localhost"
        );
        assert!(wt_path.join(".init_marker").exists());
    }

    #[test]
    fn test_run_on_main_repo_skips_copy_runs_commands() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        write_config(
            &repo,
            "[init]\ninclude = [\".env\"]\ncommands = [\"echo done > .init_marker\"]\n",
        );
        fs::write(repo.join(".env"), "DB_HOST=localhost").unwrap();
        git_cmd(&repo, &["add", "."]);
        git_cmd(&repo, &["commit", "-m", "add config"]);

        run(&repo).unwrap();

        assert!(repo.join(".init_marker").exists());
        // Source and destination are the same here, so the file must be left intact.
        assert_eq!(
            fs::read_to_string(repo.join(".env")).unwrap(),
            "DB_HOST=localhost"
        );
    }

    #[test]
    fn test_run_no_config() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        run(&repo).unwrap();
    }
}