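//! `tam_worktree/init.rs` — prepares a newly created worktree from the `[init]` table of the
//! repository's `.tam.toml`: copies the listed include files and runs the listed commands.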

1use anyhow::{bail, Context, Result};
2use serde::Deserialize;
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::process::Command;
6
7use crate::git;
8
9#[derive(Debug, Deserialize, Default)]
10struct ProjectConfigFile {
11    init: Option<InitConfig>,
12}
13
14#[derive(Debug, Deserialize, Default)]
15struct InitConfig {
16    include: Option<Vec<String>>,
17    commands: Option<Vec<String>>,
18}
19
/// Parsed `.tam.toml` configuration for initializing new worktrees.
///
/// Absent keys in the file are represented as empty lists, so `ProjectInit::default()` means
/// "nothing to do". The type is plain data, so it can be cloned and compared.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ProjectInit {
    /// File globs to copy from the main checkout (e.g. `[".env", ".claude/**"]`).
    pub include: Vec<String>,
    /// Shell commands to run in the new worktree (e.g. `["npm install"]`).
    pub commands: Vec<String>,
}
28
29/// Load project init config from `.tam.toml`.
30pub fn load_project_config(repo_root: &Path) -> Result<ProjectInit> {
31    let config_path = repo_root.join(".tam.toml");
32
33    if !config_path.exists() {
34        return Ok(ProjectInit::default());
35    }
36
37    let content = fs::read_to_string(&config_path)
38        .with_context(|| format!("failed to read {}", config_path.display()))?;
39    let file: ProjectConfigFile = toml::from_str(&content)
40        .with_context(|| format!("failed to parse {}", config_path.display()))?;
41
42    let init = file.init.unwrap_or_default();
43    Ok(ProjectInit {
44        include: init.include.unwrap_or_default(),
45        commands: init.commands.unwrap_or_default(),
46    })
47}
48
49/// Expand a pattern relative to a directory, supporting globs.
50/// If the pattern contains glob characters, expand it; otherwise treat it as a literal path.
51fn expand_pattern(base: &Path, pattern: &str) -> Vec<PathBuf> {
52    if pattern.contains('*') || pattern.contains('?') || pattern.contains('[') {
53        let glob = match globset::Glob::new(pattern) {
54            Ok(g) => g.compile_matcher(),
55            Err(_) => return Vec::new(),
56        };
57        collect_files(base, base, &glob).unwrap_or_default()
58    } else {
59        let path = base.join(pattern);
60        if path.is_dir() {
61            collect_dir_files(base, &path).unwrap_or_default()
62        } else if path.exists() {
63            vec![PathBuf::from(pattern)]
64        } else {
65            Vec::new()
66        }
67    }
68}
69
70/// Recursively collect all files under `dir`, returning paths relative to `base`.
71fn collect_dir_files(base: &Path, dir: &Path) -> Result<Vec<PathBuf>> {
72    let mut results = Vec::new();
73    let entries = match fs::read_dir(dir) {
74        Ok(entries) => entries,
75        Err(_) => return Ok(results),
76    };
77    for entry in entries {
78        let entry = entry?;
79        let path = entry.path();
80        if path.is_dir() {
81            results.extend(collect_dir_files(base, &path)?);
82        } else {
83            let rel = path.strip_prefix(base).unwrap_or(&path);
84            results.push(rel.to_path_buf());
85        }
86    }
87    Ok(results)
88}
89
90/// Recursively collect files under `dir` that match `glob`, returning paths relative to `base`.
91fn collect_files(base: &Path, dir: &Path, glob: &globset::GlobMatcher) -> Result<Vec<PathBuf>> {
92    let mut results = Vec::new();
93    let entries = match fs::read_dir(dir) {
94        Ok(entries) => entries,
95        Err(_) => return Ok(results),
96    };
97    for entry in entries {
98        let entry = entry?;
99        let path = entry.path();
100        let rel = path.strip_prefix(base).unwrap_or(&path);
101        if path.is_dir() {
102            results.extend(collect_files(base, &path, glob)?);
103        } else if glob.is_match(rel) {
104            results.push(rel.to_path_buf());
105        }
106    }
107    Ok(results)
108}
109
110/// Copy include files from source to target directory.
111fn copy_include_files(source: &Path, target: &Path, patterns: &[String]) -> Result<()> {
112    for pattern in patterns {
113        let files = expand_pattern(source, pattern);
114        for file in &files {
115            let src = source.join(file);
116            let dst = target.join(file);
117            if let Some(parent) = dst.parent() {
118                fs::create_dir_all(parent)?;
119            }
120            fs::copy(&src, &dst).with_context(|| format!("failed to copy {}", file.display()))?;
121            eprintln!("copied {}", file.display());
122        }
123    }
124    Ok(())
125}
126
127/// Run commands sequentially in the target directory. Stops on first failure.
128fn run_commands(target: &Path, commands: &[String]) -> Result<()> {
129    for cmd in commands {
130        eprintln!("running: {}", cmd);
131        let status = Command::new("sh")
132            .arg("-c")
133            .arg(cmd)
134            .current_dir(target)
135            .status()
136            .with_context(|| format!("failed to run: {}", cmd))?;
137        if !status.success() {
138            bail!(
139                "command failed (exit {}): {}",
140                status.code().unwrap_or(-1),
141                cmd
142            );
143        }
144    }
145    Ok(())
146}
147
148/// Run init for the given target directory.
149///
150/// Resolves the working tree toplevel, then reads `.tam.toml` (or legacy fallbacks)
151/// from the main repo root. If the toplevel is a worktree (different from the main repo),
152/// copies include files. Then runs commands.
153pub fn run(target: &Path) -> Result<()> {
154    let toplevel = git::toplevel(target).context("not inside a git repository")?;
155    let repo_root = git::repo_root(target).context("not inside a git repository")?;
156    let config = load_project_config(&repo_root)?;
157
158    if config.include.is_empty() && config.commands.is_empty() {
159        eprintln!("nothing to do: no [init] config in .tam.toml");
160        return Ok(());
161    }
162
163    // Copy include files only when toplevel differs from repo root (i.e. we're in a worktree)
164    if !config.include.is_empty() {
165        let toplevel_canonical = toplevel.canonicalize().unwrap_or_else(|_| toplevel.clone());
166        let root_canonical = repo_root
167            .canonicalize()
168            .unwrap_or_else(|_| repo_root.clone());
169        if toplevel_canonical != root_canonical {
170            copy_include_files(&repo_root, &toplevel, &config.include)?;
171        }
172    }
173
174    if !config.commands.is_empty() {
175        run_commands(&toplevel, &config.commands)?;
176    }
177
178    Ok(())
179}
180
// Tests for worktree init. Unit tests cover config loading, include-file copying and command
// execution; the integration tests at the bottom create real git repositories and worktrees
// in temporary directories, so they require a `git` binary on PATH.
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::Command as StdCommand;
    use tempfile::TempDir;

    // Run `git` with `args` in `dir` and return its trimmed stdout.
    // Panics only if git cannot be spawned. The exit status is NOT checked.
    // NOTE(review): a failing git command only surfaces later as a confusing assertion failure.
    fn git_cmd(dir: &Path, args: &[&str]) -> String {
        let output = StdCommand::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .unwrap();
        String::from_utf8_lossy(&output.stdout).trim().to_string()
    }

    // Create a git repo at `path` with one commit (README.md) on a branch renamed to `main`.
    // The identity is set in the repo's local config so commits don't depend on global config.
    fn init_repo(path: &Path) {
        fs::create_dir_all(path).unwrap();
        git_cmd(path, &["init"]);
        git_cmd(path, &["config", "user.email", "test@test.com"]);
        git_cmd(path, &["config", "user.name", "Test"]);
        fs::write(path.join("README.md"), "# test").unwrap();
        git_cmd(path, &["add", "."]);
        git_cmd(path, &["commit", "-m", "init"]);
        let _ = git_cmd(path, &["branch", "-M", "main"]);
    }

    // --- load_project_config tests ---

    // A missing `.tam.toml` is not an error: everything defaults to empty.
    #[test]
    fn test_load_no_config_file() {
        let tmp = TempDir::new().unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert!(config.commands.is_empty());
    }

    // An empty file parses as valid TOML with no `[init]` table.
    #[test]
    fn test_load_empty_config() {
        let tmp = TempDir::new().unwrap();
        fs::write(tmp.path().join(".tam.toml"), "").unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_include_only() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".tam.toml"),
            "[init]\ninclude = [\".env\", \"config/*.toml\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env", "config/*.toml"]);
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_load_commands_only() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".tam.toml"),
            "[init]\ncommands = [\"npm install\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
        assert_eq!(config.commands, vec!["npm install"]);
    }

    #[test]
    fn test_load_full_config() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".tam.toml"),
            "[init]\ninclude = [\".env\"]\ncommands = [\"npm install\", \"cargo build\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert_eq!(config.include, vec![".env"]);
        assert_eq!(config.commands, vec!["npm install", "cargo build"]);
    }

    // Malformed TOML must be reported, not silently treated as an empty config.
    #[test]
    fn test_load_invalid_toml() {
        let tmp = TempDir::new().unwrap();
        fs::write(tmp.path().join(".tam.toml"), "{{invalid").unwrap();
        assert!(load_project_config(tmp.path()).is_err());
    }

    // Only `.tam.toml` is read; older config file names are not consulted.
    #[test]
    fn test_load_ignores_legacy_files() {
        let tmp = TempDir::new().unwrap();
        fs::write(
            tmp.path().join(".worktree-init.toml"),
            "[init]\ninclude = [\".env\"]\n",
        )
        .unwrap();
        fs::write(
            tmp.path().join(".yawn.toml"),
            "[init]\ninclude = [\".env\"]\n",
        )
        .unwrap();
        let config = load_project_config(tmp.path()).unwrap();
        assert!(config.include.is_empty());
    }

    // --- copy_include_files tests ---

    // Literal paths, including one in a subdirectory that doesn't yet exist in the target.
    #[test]
    fn test_copy_include_literal_files() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::write(source.join(".env"), "SECRET=123").unwrap();
        fs::create_dir_all(source.join("config")).unwrap();
        fs::write(source.join("config/local.toml"), "[db]\nhost=localhost").unwrap();

        copy_include_files(
            &source,
            &target,
            &[".env".into(), "config/local.toml".into()],
        )
        .unwrap();

        assert_eq!(
            fs::read_to_string(target.join(".env")).unwrap(),
            "SECRET=123"
        );
        assert_eq!(
            fs::read_to_string(target.join("config/local.toml")).unwrap(),
            "[db]\nhost=localhost"
        );
    }

    // A top-level glob copies only the matching files.
    #[test]
    fn test_copy_include_glob_pattern() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::write(source.join("data_users.csv"), "id,name").unwrap();
        fs::write(source.join("data_orders.csv"), "id,total").unwrap();
        fs::write(source.join("other.csv"), "should not copy").unwrap();

        copy_include_files(&source, &target, &["data_*.csv".into()]).unwrap();

        assert!(target.join("data_users.csv").exists());
        assert!(target.join("data_orders.csv").exists());
        assert!(!target.join("other.csv").exists());
    }

    // Globs are matched against paths relative to the source root, so a directory prefix works.
    #[test]
    fn test_copy_include_glob_in_subdir() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::create_dir_all(source.join("config")).unwrap();
        fs::write(source.join("config/dev.toml"), "dev").unwrap();
        fs::write(source.join("config/test.toml"), "test").unwrap();
        fs::write(source.join("config/keep.json"), "not matched").unwrap();

        copy_include_files(&source, &target, &["config/*.toml".into()]).unwrap();

        assert!(target.join("config/dev.toml").exists());
        assert!(target.join("config/test.toml").exists());
        assert!(!target.join("config/keep.json").exists());
    }

    // A literal directory is copied recursively.
    #[test]
    fn test_copy_include_directory() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::create_dir_all(source.join(".cache/sub")).unwrap();
        fs::write(source.join(".cache/a.txt"), "aaa").unwrap();
        fs::write(source.join(".cache/sub/b.txt"), "bbb").unwrap();

        copy_include_files(&source, &target, &[".cache".into()]).unwrap();

        assert_eq!(
            fs::read_to_string(target.join(".cache/a.txt")).unwrap(),
            "aaa"
        );
        assert_eq!(
            fs::read_to_string(target.join(".cache/sub/b.txt")).unwrap(),
            "bbb"
        );
    }

    // A missing literal path is skipped without failing the rest of the copy.
    #[test]
    fn test_copy_include_missing_source() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        fs::write(source.join(".env"), "SECRET=123").unwrap();

        copy_include_files(&source, &target, &[".env".into(), "missing-file".into()]).unwrap();
        assert!(target.join(".env").exists());
        assert!(!target.join("missing-file").exists());
    }

    // A glob that matches nothing is not an error.
    #[test]
    fn test_copy_include_no_glob_matches() {
        let tmp = TempDir::new().unwrap();
        let source = tmp.path().join("source");
        let target = tmp.path().join("target");
        fs::create_dir_all(&source).unwrap();
        fs::create_dir_all(&target).unwrap();

        copy_include_files(&source, &target, &["*.xyz".into()]).unwrap();
    }

    // --- run_commands tests ---

    // Commands go through `sh -c` in the target dir, so shell redirection works.
    #[test]
    fn test_run_commands_success() {
        let tmp = TempDir::new().unwrap();
        run_commands(tmp.path(), &["echo hello > out.txt".into()]).unwrap();
        assert_eq!(
            fs::read_to_string(tmp.path().join("out.txt"))
                .unwrap()
                .trim(),
            "hello"
        );
    }

    // A non-zero exit becomes an error whose message starts with "command failed".
    #[test]
    fn test_run_commands_failure() {
        let tmp = TempDir::new().unwrap();
        let result = run_commands(tmp.path(), &["false".into()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("command failed"));
    }

    #[test]
    fn test_run_commands_sequential() {
        let tmp = TempDir::new().unwrap();
        run_commands(
            tmp.path(),
            &[
                "echo first > first.txt".into(),
                "echo second > second.txt".into(),
            ],
        )
        .unwrap();
        assert!(tmp.path().join("first.txt").exists());
        assert!(tmp.path().join("second.txt").exists());
    }

    // Commands after the first failure must not run.
    #[test]
    fn test_run_commands_stops_on_failure() {
        let tmp = TempDir::new().unwrap();
        let result = run_commands(
            tmp.path(),
            &[
                "echo first > first.txt".into(),
                "false".into(),
                "echo third > third.txt".into(),
            ],
        );
        assert!(result.is_err());
        assert!(tmp.path().join("first.txt").exists());
        assert!(!tmp.path().join("third.txt").exists());
    }

    // --- integration tests ---

    // In a linked worktree, include files come from the main checkout and commands run in the
    // worktree itself.
    #[test]
    fn test_run_on_worktree_copies_files_and_runs_commands() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        // Set up .tam.toml and include files
        fs::write(
            repo.join(".tam.toml"),
            "[init]\ninclude = [\".env\"]\ncommands = [\"echo done > .init_marker\"]\n",
        )
        .unwrap();
        fs::write(repo.join(".env"), "DB_HOST=localhost").unwrap();
        git_cmd(&repo, &["add", "."]);
        git_cmd(&repo, &["commit", "-m", "add config"]);

        // Create a worktree
        let wt_path = tmp.path().join("myproject--feature");
        git_cmd(
            &repo,
            &[
                "worktree",
                "add",
                "-b",
                "feature",
                &wt_path.to_string_lossy(),
            ],
        );

        // Run init on the worktree
        run(&wt_path).unwrap();

        assert_eq!(
            fs::read_to_string(wt_path.join(".env")).unwrap(),
            "DB_HOST=localhost"
        );
        assert!(wt_path.join(".init_marker").exists());
    }

    // In the main checkout, the copy step is skipped but commands still run.
    // NOTE(review): the skipped copy itself is not asserted — `.env` already exists in the repo.
    #[test]
    fn test_run_on_main_repo_skips_copy_runs_commands() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        fs::write(
            repo.join(".tam.toml"),
            "[init]\ninclude = [\".env\"]\ncommands = [\"echo done > .init_marker\"]\n",
        )
        .unwrap();
        fs::write(repo.join(".env"), "DB_HOST=localhost").unwrap();
        git_cmd(&repo, &["add", "."]);
        git_cmd(&repo, &["commit", "-m", "add config"]);

        // Run init on the main repo itself
        run(&repo).unwrap();

        // Commands should have run
        assert!(repo.join(".init_marker").exists());
    }

    // Without `.tam.toml`, `run` prints a "nothing to do" note and returns Ok.
    #[test]
    fn test_run_no_config() {
        let tmp = TempDir::new().unwrap();
        let repo = tmp.path().join("myproject");
        init_repo(&repo);

        // Should silently succeed with no config file
        run(&repo).unwrap();
    }
}