Skip to main content

tam_worktree/
init.rs

1use anyhow::{bail, Context, Result};
2use serde::Deserialize;
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::process::Command;
6
7use crate::git;
8
9#[derive(Debug, Deserialize, Default)]
10struct ProjectConfigFile {
11    init: Option<InitConfig>,
12}
13
14#[derive(Debug, Deserialize, Default)]
15struct InitConfig {
16    include: Option<Vec<String>>,
17    commands: Option<Vec<String>>,
18}
19
/// Parsed `.worktree-init.toml` configuration for initializing new worktrees.
///
/// `ProjectInit::default()` has both lists empty, meaning "nothing to copy, nothing to
/// run". `Clone`/`PartialEq`/`Eq` are derived so callers can keep copies of a loaded
/// config and compare configs directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ProjectInit {
    /// File globs to copy from the main checkout (e.g. `[".env", ".claude/**"]`).
    pub include: Vec<String>,
    /// Shell commands to run in the new worktree (e.g. `["npm install"]`).
    pub commands: Vec<String>,
}
28
29/// Load project init config from `.worktree-init.toml` (preferred) or `.yawn.toml` (fallback).
30pub fn load_project_config(repo_root: &Path) -> Result<ProjectInit> {
31    let primary = repo_root.join(".worktree-init.toml");
32    let fallback = repo_root.join(".yawn.toml");
33
34    let config_path = if primary.exists() {
35        primary
36    } else if fallback.exists() {
37        eprintln!("note: using .yawn.toml — consider renaming to .worktree-init.toml");
38        fallback
39    } else {
40        return Ok(ProjectInit::default());
41    };
42
43    let content = fs::read_to_string(&config_path)
44        .with_context(|| format!("failed to read {}", config_path.display()))?;
45    let file: ProjectConfigFile = toml::from_str(&content)
46        .with_context(|| format!("failed to parse {}", config_path.display()))?;
47
48    let init = file.init.unwrap_or_default();
49    Ok(ProjectInit {
50        include: init.include.unwrap_or_default(),
51        commands: init.commands.unwrap_or_default(),
52    })
53}
54
55/// Expand a pattern relative to a directory, supporting globs.
56/// If the pattern contains glob characters, expand it; otherwise treat it as a literal path.
57fn expand_pattern(base: &Path, pattern: &str) -> Vec<PathBuf> {
58    if pattern.contains('*') || pattern.contains('?') || pattern.contains('[') {
59        let glob = match globset::Glob::new(pattern) {
60            Ok(g) => g.compile_matcher(),
61            Err(_) => return Vec::new(),
62        };
63        collect_files(base, base, &glob).unwrap_or_default()
64    } else {
65        let path = base.join(pattern);
66        if path.is_dir() {
67            collect_dir_files(base, &path).unwrap_or_default()
68        } else if path.exists() {
69            vec![PathBuf::from(pattern)]
70        } else {
71            Vec::new()
72        }
73    }
74}
75
/// Recursively collect all files under `dir`, returning paths relative to `base`.
///
/// Symlinks to files are returned like regular files. Symlinked directories found
/// *below* `dir` are not descended into, because following them can loop on a symlink
/// cycle. `dir` itself may be a symlink, since `fs::read_dir` follows it.
/// A `dir` that cannot be read yields no entries.
///
/// # Errors
///
/// Propagates errors from reading individual directory entries.
fn collect_dir_files(base: &Path, dir: &Path) -> Result<Vec<PathBuf>> {
    let mut results = Vec::new();
    let entries = match fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return Ok(results),
    };
    for entry in entries {
        let entry = entry?;
        let path = entry.path();
        // Unlike `Path::is_dir`, `DirEntry::file_type` does not follow symlinks.
        let file_type = entry.file_type()?;
        if file_type.is_dir() {
            results.extend(collect_dir_files(base, &path)?);
            continue;
        }
        if file_type.is_symlink() && path.is_dir() {
            // Symlinked directory: skip it rather than risk walking a cycle forever.
            continue;
        }
        let rel = path.strip_prefix(base).unwrap_or(&path);
        results.push(rel.to_path_buf());
    }
    Ok(results)
}
95
96/// Recursively collect files under `dir` that match `glob`, returning paths relative to `base`.
97fn collect_files(base: &Path, dir: &Path, glob: &globset::GlobMatcher) -> Result<Vec<PathBuf>> {
98    let mut results = Vec::new();
99    let entries = match fs::read_dir(dir) {
100        Ok(entries) => entries,
101        Err(_) => return Ok(results),
102    };
103    for entry in entries {
104        let entry = entry?;
105        let path = entry.path();
106        let rel = path.strip_prefix(base).unwrap_or(&path);
107        if path.is_dir() {
108            results.extend(collect_files(base, &path, glob)?);
109        } else if glob.is_match(rel) {
110            results.push(rel.to_path_buf());
111        }
112    }
113    Ok(results)
114}
115
116/// Copy include files from source to target directory.
117fn copy_include_files(source: &Path, target: &Path, patterns: &[String]) -> Result<()> {
118    for pattern in patterns {
119        let files = expand_pattern(source, pattern);
120        for file in &files {
121            let src = source.join(file);
122            let dst = target.join(file);
123            if let Some(parent) = dst.parent() {
124                fs::create_dir_all(parent)?;
125            }
126            fs::copy(&src, &dst).with_context(|| format!("failed to copy {}", file.display()))?;
127            eprintln!("copied {}", file.display());
128        }
129    }
130    Ok(())
131}
132
133/// Run commands sequentially in the target directory. Stops on first failure.
134fn run_commands(target: &Path, commands: &[String]) -> Result<()> {
135    for cmd in commands {
136        eprintln!("running: {}", cmd);
137        let status = Command::new("sh")
138            .arg("-c")
139            .arg(cmd)
140            .current_dir(target)
141            .status()
142            .with_context(|| format!("failed to run: {}", cmd))?;
143        if !status.success() {
144            bail!(
145                "command failed (exit {}): {}",
146                status.code().unwrap_or(-1),
147                cmd
148            );
149        }
150    }
151    Ok(())
152}
153
154/// Run init for the given target directory.
155///
156/// Resolves the working tree toplevel, then reads `.worktree-init.toml` (or `.yawn.toml`)
157/// from the main repo root. If the toplevel is a worktree (different from the main repo),
158/// copies include files. Then runs commands.
159pub fn run(target: &Path) -> Result<()> {
160    let toplevel = git::toplevel(target).context("not inside a git repository")?;
161    let repo_root = git::repo_root(target).context("not inside a git repository")?;
162    let config = load_project_config(&repo_root)?;
163
164    if config.include.is_empty() && config.commands.is_empty() {
165        eprintln!("nothing to do: no [init] config in .worktree-init.toml");
166        return Ok(());
167    }
168
169    // Copy include files only when toplevel differs from repo root (i.e. we're in a worktree)
170    if !config.include.is_empty() {
171        let toplevel_canonical = toplevel.canonicalize().unwrap_or_else(|_| toplevel.clone());
172        let root_canonical = repo_root
173            .canonicalize()
174            .unwrap_or_else(|_| repo_root.clone());
175        if toplevel_canonical != root_canonical {
176            copy_include_files(&repo_root, &toplevel, &config.include)?;
177        }
178    }
179
180    if !config.commands.is_empty() {
181        run_commands(&toplevel, &config.commands)?;
182    }
183
184    Ok(())
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::Command as StdCommand;
    use tempfile::TempDir;

    /// Run `git <args>` in `dir` and return its trimmed stdout.
    ///
    /// Panics with git's stderr if git cannot be spawned or exits unsuccessfully, so a
    /// broken fixture fails at the step that broke rather than at a later, unrelated
    /// assertion.
    fn git_cmd(dir: &Path, args: &[&str]) -> String {
        let output = StdCommand::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .expect("failed to spawn git");
        assert!(
            output.status.success(),
            "git {:?} failed in {}: {}",
            args,
            dir.display(),
            String::from_utf8_lossy(&output.stderr).trim()
        );
        String::from_utf8_lossy(&output.stdout).trim().to_string()
    }

    /// Create a git repository at `path` containing one commit on branch `main`.
    fn init_repo(path: &Path) {
        fs::create_dir_all(path).unwrap();
        git_cmd(path, &["init"]);
        git_cmd(path, &["config", "user.email", "test@test.com"]);
        git_cmd(path, &["config", "user.name", "Test"]);
        // A developer's global `commit.gpgsign = true` would otherwise make commits fail.
        git_cmd(path, &["config", "commit.gpgsign", "false"]);
        fs::write(path.join("README.md"), "# test").unwrap();
        git_cmd(path, &["add", "."]);
        git_cmd(path, &["commit", "-m", "init"]);
        git_cmd(path, &["branch", "-M", "main"]);
    }

    // --- load_project_config tests ---

    #[test]
    fn test_load_no_config_file() {
        let tmp = TempDir::new().unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_empty_config() {
        let tmp = TempDir::new().unwrap();
        fs::write(tmp.path().join(".worktree-init.toml"), "").unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_include_only() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".worktree-init.toml"),
            "[init]\ninclude = [\".env\", \"config/*.toml\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env", "config/*.toml"]);
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_commands_only() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".worktree-init.toml"),
            "[init]\ncommands = [\"npm install\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert_eq!(config.commands, vec!["npm install"]);
    }

    #[test]
    fn test_load_full_config() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".worktree-init.toml"),
            "[init]\ninclude = [\".env\"]\ncommands = [\"npm install\", \"cargo build\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env"]);
        assert_eq!(config.commands, vec!["npm install", "cargo build"]);
    }

    #[test]
    fn test_load_invalid_toml() {
        let tmp = TempDir::new().unwrap();
        fs::write(tmp.path().join(".worktree-init.toml"), "{{invalid").unwrap();
        assert!(load_project_config(tmp.path()).is_err());
    }

    #[test]
    fn test_load_fallback_to_yawn_toml() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".yawn.toml"),
            "[init]\ninclude = [\".env\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env"]);
    }

    #[test]
    fn test_load_prefers_worktree_init_toml() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".worktree-init.toml"),
            "[init]\ninclude = [\"from-new\"]\n",
        )
        .unwrap();
        fs::write(
            tmp.path().join(".yawn.toml"),
            "[init]\ninclude = [\"from-old\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec!["from-new"]);
    }

    // --- copy_include_files tests ---

    #[test]
    fn test_copy_include_literal_files() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::write(source.join(".env"), "SECRET=123").unwrap();
        fs::create_dir_all(source.join("config")).unwrap();
        fs::write(source.join("config/local.toml"), "[db]\nhost=localhost").unwrap();

        copy_include_files(
            &source,
            &target,
            &[".env".into(), "config/local.toml".into()],
        )
        .unwrap();

        assert_eq!(
            fs::read_to_string(target.join(".env")).unwrap(),
            "SECRET=123"
        );
        assert_eq!(
            fs::read_to_string(target.join("config/local.toml")).unwrap(),
            "[db]\nhost=localhost"
        );
    }

    #[test]
    fn test_copy_include_glob_pattern() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::write(source.join("data_users.csv"), "id,name").unwrap();
        fs::write(source.join("data_orders.csv"), "id,total").unwrap();
        fs::write(source.join("other.csv"), "should not copy").unwrap();

        copy_include_files(&source, &target, &["data_*.csv".into()]).unwrap();

        assert!(target.join("data_users.csv").exists());
        assert!(target.join("data_orders.csv").exists());
        assert!(!target.join("other.csv").exists());
    }

    #[test]
    fn test_copy_include_glob_in_subdir() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::create_dir_all(source.join("config")).unwrap();
        fs::write(source.join("config/dev.toml"), "dev").unwrap();
        fs::write(source.join("config/test.toml"), "test").unwrap();
        fs::write(source.join("config/keep.json"), "not matched").unwrap();

        copy_include_files(&source, &target, &["config/*.toml".into()]).unwrap();

        assert!(target.join("config/dev.toml").exists());
        assert!(target.join("config/test.toml").exists());
        assert!(!target.join("config/keep.json").exists());
    }

    #[test]
    fn test_copy_include_directory() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::create_dir_all(source.join(".cache/sub")).unwrap();
        fs::write(source.join(".cache/a.txt"), "aaa").unwrap();
        fs::write(source.join(".cache/sub/b.txt"), "bbb").unwrap();

        copy_include_files(&source, &target, &[".cache".into()]).unwrap();

        assert_eq!(
            fs::read_to_string(target.join(".cache/a.txt")).unwrap(),
            "aaa"
        );
        assert_eq!(
            fs::read_to_string(target.join(".cache/sub/b.txt")).unwrap(),
            "bbb"
        );
    }

    #[test]
    fn test_copy_include_missing_source() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::write(source.join(".env"), "SECRET=123").unwrap();

        copy_include_files(&source, &target, &[".env".into(), "missing-file".into()]).unwrap();
        assert!(target.join(".env").exists());
        assert!(!target.join("missing-file").exists());
    }

    #[test]
    fn test_copy_include_no_glob_matches() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        copy_include_files(&source, &target, &["*.xyz".into()]).unwrap();

        // Nothing matched, so nothing may have been written.
        assert!(fs::read_dir(&target).unwrap().next().is_none());
    }

    // --- run_commands tests ---

    #[test]
    fn test_run_commands_success() {
        let tmp = TempDir::new().unwrap();
        run_commands(tmp.path(), &["echo hello > out.txt".into()]).unwrap();
        assert_eq!(
            fs::read_to_string(tmp.path().join("out.txt"))
                .unwrap()
                .trim(),
            "hello"
        );
    }

    #[test]
    fn test_run_commands_failure() {
        let tmp = TempDir::new().unwrap();
        let result = run_commands(tmp.path(), &["false".into()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("command failed"));
    }

    #[test]
    fn test_run_commands_sequential() {
        let tmp = TempDir::new().unwrap();
        run_commands(
            tmp.path(),
            &[
                "echo first > first.txt".into(),
                "echo second > second.txt".into(),
            ],
        )
        .unwrap();
        assert!(tmp.path().join("first.txt").exists());
        assert!(tmp.path().join("second.txt").exists());
    }

    #[test]
    fn test_run_commands_stops_on_failure() {
        let tmp = TempDir::new().unwrap();
        let result = run_commands(
            tmp.path(),
            &[
                "echo first > first.txt".into(),
                "false".into(),
                "echo third > third.txt".into(),
            ],
        );
        assert!(result.is_err());
        assert!(tmp.path().join("first.txt").exists());
        assert!(!tmp.path().join("third.txt").exists());
    }

    // --- integration tests ---

    #[test]
    fn test_run_on_worktree_copies_files_and_runs_commands() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        // Set up .worktree-init.toml and include files. `.env` is git-ignored (as real
        // secrets would be), so the worktree checkout cannot already contain it and the
        // content assertion below genuinely exercises the copy.
        fs::write(
            repo.join(".worktree-init.toml"),
            "[init]\ninclude = [\".env\"]\ncommands = [\"echo done > .init_marker\"]\n",
        )
        .unwrap();
        fs::write(repo.join(".gitignore"), ".env\n").unwrap();
        fs::write(repo.join(".env"), "DB_HOST=localhost").unwrap();
        git_cmd(&repo, &["add", "."]);
        git_cmd(&repo, &["commit", "-m", "add config"]);

        // Create a worktree
        let wt_path = tmp.path().join("myproject--feature");
        git_cmd(
            &repo,
            &[
                "worktree",
                "add",
                "-b",
                "feature",
                &wt_path.to_string_lossy(),
            ],
        );
        assert!(
            !wt_path.join(".env").exists(),
            ".env must not come from the checkout"
        );

        // Run init on the worktree
        run(&wt_path).unwrap();

        assert_eq!(
            fs::read_to_string(wt_path.join(".env")).unwrap(),
            "DB_HOST=localhost"
        );
        assert!(wt_path.join(".init_marker").exists());
    }

    #[test]
    fn test_run_on_main_repo_skips_copy_runs_commands() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        fs::write(
            repo.join(".worktree-init.toml"),
            "[init]\ninclude = [\".env\"]\ncommands = [\"echo done > .init_marker\"]\n",
        )
        .unwrap();
        fs::write(repo.join(".env"), "DB_HOST=localhost").unwrap();
        git_cmd(&repo, &["add", "."]);
        git_cmd(&repo, &["commit", "-m", "add config"]);

        // Run init on the main repo itself
        run(&repo).unwrap();

        // Commands should have run
        assert!(repo.join(".init_marker").exists());
        // Include files must be left untouched (no copy of a file onto itself).
        assert_eq!(
            fs::read_to_string(repo.join(".env")).unwrap(),
            "DB_HOST=localhost"
        );
    }

    #[test]
    fn test_run_no_config() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        // Should silently succeed with no config file
        run(&repo).unwrap();
    }
}