// tracevault_cli/commands/init.rs

1use crate::api_client::ApiClient;
2use crate::config::TracevaultConfig;
3use std::fs;
4use std::io;
5use std::path::Path;
6
/// Return the URL of the `origin` remote of the repository at `project_root`.
///
/// Runs `git remote get-url origin` with `project_root` as the working directory.
/// Yields `None` when git cannot be spawned, exits with a failure status, or prints
/// only whitespace; otherwise the trimmed stdout.
pub fn git_remote_url(project_root: &Path) -> Option<String> {
    let output = std::process::Command::new("git")
        .args(["remote", "get-url", "origin"])
        .current_dir(project_root)
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
    (!url.is_empty()).then_some(url)
}
17
/// Extract the GitHub organization (owner) from a git remote URL.
///
/// Recognised forms:
/// - SSH (scp-like): `git@github.com:softwaremill/tracevault.git`
/// - SSH URL:        `ssh://git@github.com/softwaremill/tracevault.git`
/// - HTTP(S):        `https://github.com/softwaremill/tracevault.git` (or `http://`)
///
/// Returns the first path segment after the host, or `None` if the URL is not one of
/// the forms above or has an empty owner segment (e.g. `https://github.com/`).
/// URLs with embedded credentials (`https://user:token@github.com/...`) are
/// intentionally not recognised, because the raw remote is sent to the server.
fn parse_github_org(remote_url: &str) -> Option<String> {
    const GITHUB_PREFIXES: [&str; 4] = [
        "git@github.com:",
        "ssh://git@github.com/",
        "https://github.com/",
        "http://github.com/",
    ];

    let path = GITHUB_PREFIXES
        .iter()
        .find_map(|prefix| remote_url.trim().strip_prefix(prefix))?;
    path.split('/')
        .next()
        // An empty owner would otherwise be used as an (invalid) org slug.
        .filter(|org| !org.is_empty())
        .map(String::from)
}
32
33pub async fn init_in_directory(
34    project_root: &Path,
35    server_url: Option<&str>,
36) -> Result<(), io::Error> {
37    // Check for git repository
38    if !project_root.join(".git").exists() {
39        return Err(io::Error::new(
40            io::ErrorKind::NotFound,
41            "Not a git repository. Run 'git init' first.",
42        ));
43    }
44
45    // Create .tracevault/ directory
46    let config_dir = TracevaultConfig::config_dir(project_root);
47    fs::create_dir_all(&config_dir)?;
48    fs::create_dir_all(config_dir.join("sessions"))?;
49    fs::create_dir_all(config_dir.join("cache"))?;
50
51    // Register repo on server if authenticated, server URL known, and git remote available
52    let remote_url = git_remote_url(project_root);
53    if remote_url.is_none() {
54        eprintln!("Warning: no git remote 'origin' configured. Skipping server registration.");
55        eprintln!("Run 'git remote add origin <url>' then 'tracevault sync' to register.");
56    }
57
58    // Extract org slug from GitHub remote URL
59    let org_slug = remote_url.as_deref().and_then(parse_github_org);
60
61    // Write config (include server_url and org_slug if available)
62    let mut config = TracevaultConfig::default();
63    if let Some(url) = server_url {
64        config.server_url = Some(url.to_string());
65    }
66    config.org_slug = org_slug.clone();
67    fs::write(
68        TracevaultConfig::config_path(project_root),
69        config.to_toml(),
70    )?;
71
72    // Create .tracevault/.gitignore
73    fs::write(
74        config_dir.join(".gitignore"),
75        "sessions/\ncache/\n*.local.toml\n",
76    )?;
77
78    // Install Claude Code hooks into .claude/settings.json
79    install_claude_hooks(project_root)?;
80
81    // Install git pre-push hook
82    install_git_hook(project_root)?;
83
84    let (resolved_url, resolved_token) = crate::api_client::resolve_credentials(project_root);
85    let effective_url = server_url.map(String::from).or(resolved_url);
86
87    if resolved_token.is_none() {
88        eprintln!("Not logged in. Run 'tracevault login' to register this repo with the server.");
89    } else if let (Some(url), Some(remote), Some(slug)) = (effective_url, remote_url, org_slug) {
90        let client = ApiClient::new(&url, resolved_token.as_deref());
91        let repo_name = git_repo_name(project_root);
92
93        match client
94            .register_repo(
95                &slug,
96                crate::api_client::RegisterRepoRequest {
97                    repo_name,
98                    github_url: Some(remote),
99                },
100            )
101            .await
102        {
103            Ok(resp) => {
104                println!("Repo registered on server (id: {})", resp.repo_id);
105            }
106            Err(e) => {
107                let msg = e.to_string();
108                if msg.contains("404") {
109                    eprintln!("Warning: organization '{}' not found on the server.", slug);
110                    eprintln!(
111                        "Create it first at your TraceVault instance, then run 'tracevault sync'."
112                    );
113                } else if msg.contains("403") {
114                    eprintln!("Warning: you are not a member of organization '{}'.", slug);
115                } else {
116                    eprintln!("Warning: could not register repo on server: {e}");
117                }
118            }
119        }
120    }
121
122    Ok(())
123}
124
/// Best-effort name of the repository rooted at `project_root`.
///
/// Uses the final component of `git rev-parse --show-toplevel`. If git is unavailable or
/// fails, it falls back to the final component of the canonicalized `project_root`. If
/// neither yields a non-empty name, it returns `"unknown"`.
fn git_repo_name(project_root: &Path) -> String {
    std::process::Command::new("git")
        .args(["rev-parse", "--show-toplevel"])
        .current_dir(project_root)
        .output()
        .ok()
        .filter(|o| o.status.success())
        .and_then(|o| path_basename(Path::new(String::from_utf8_lossy(&o.stdout).trim())))
        .or_else(|| {
            project_root
                .canonicalize()
                .ok()
                .as_deref()
                .and_then(path_basename)
        })
        .unwrap_or_else(|| "unknown".into())
}

/// Final component of `path` as an owned string; `None` for paths without one (`/`, `""`).
fn path_basename(path: &Path) -> Option<String> {
    path.file_name()
        .map(|name| name.to_string_lossy().into_owned())
        .filter(|name| !name.is_empty())
}
138
/// Marker line identifying the current TraceVault block in the `pre-push` hook.
const HOOK_MARKER: &str = "# tracevault:enforce";
/// Marker line of the legacy block written by older versions; it is replaced on re-init.
const OLD_HOOK_MARKER: &str = "# tracevault:auto-push";

/// Install (or upgrade) the TraceVault block in the repository's `pre-push` hook.
///
/// The hooks directory is resolved with `git rev-parse --git-path hooks`, so linked
/// worktrees (where `.git` is a file) and `core.hooksPath` are honored. If git cannot
/// answer, `<project_root>/.git/hooks` is used.
///
/// - A hook that already contains [`HOOK_MARKER`] is left untouched, including its mode.
/// - A legacy block introduced by [`OLD_HOOK_MARKER`] is removed, then the current
///   block is appended.
/// - Any other non-blank hook gets the block appended.
/// - A missing or blank hook is written as a fresh `#!/bin/sh` script, so git can execute it.
///
/// On Unix the written hook is made executable (mode `0755`).
///
/// NOTE(review): if an existing hook `exit`s before its end, the appended block never runs.
fn install_git_hook(project_root: &Path) -> Result<(), io::Error> {
    let hooks_dir = git_hooks_dir(project_root);
    fs::create_dir_all(&hooks_dir)?;

    let hook_path = hooks_dir.join("pre-push");
    let tracevault_block = format!(
        "{HOOK_MARKER}\ntracevault sync 2>/dev/null || true\ntracevault check || {{ echo \"tracevault: policy check failed, push blocked.\"; exit 1; }}\ntracevault push || {{ echo \"tracevault: push failed, git push blocked.\"; exit 1; }}\n"
    );

    let existing = if hook_path.exists() {
        fs::read_to_string(&hook_path)?
    } else {
        String::new()
    };

    // Already has new-style hook
    if existing.contains(HOOK_MARKER) {
        return Ok(());
    }

    let mut content = if existing.contains(OLD_HOOK_MARKER) {
        strip_old_hook_block(&existing)
    } else {
        existing
    };

    if content.trim().is_empty() {
        // Nothing worth keeping: write a complete script. Appending to an empty file would
        // produce a hook without a shebang, which git cannot execute.
        content = format!("#!/bin/sh\n{tracevault_block}");
    } else {
        if !content.ends_with('\n') {
            content.push('\n');
        }
        content.push_str(&tracevault_block);
    }
    fs::write(&hook_path, content)?;

    // Make executable
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&hook_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&hook_path, perms)?;
    }

    Ok(())
}

/// Directory git reads hooks from for the repository at `project_root`.
fn git_hooks_dir(project_root: &Path) -> std::path::PathBuf {
    std::process::Command::new("git")
        .args(["rev-parse", "--git-path", "hooks"])
        .current_dir(project_root)
        .output()
        .ok()
        .filter(|o| o.status.success())
        .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
        .filter(|p| !p.is_empty())
        // A relative answer is relative to the working directory used above; joining an
        // absolute path simply yields that path.
        .map(|p| project_root.join(p))
        .unwrap_or_else(|| project_root.join(".git/hooks"))
}

/// Remove legacy TraceVault blocks from a hook script.
///
/// A block is the line containing [`OLD_HOOK_MARKER`] plus the consecutive lines directly
/// after it that start with `"tracevault "`. All other lines are kept, each followed by `\n`.
fn strip_old_hook_block(existing: &str) -> String {
    let mut kept = String::with_capacity(existing.len());
    let mut in_old_block = false;
    for line in existing.lines() {
        if line.contains(OLD_HOOK_MARKER) {
            in_old_block = true;
            continue;
        }
        if in_old_block && line.starts_with("tracevault ") {
            continue;
        }
        in_old_block = false;
        kept.push_str(line);
        kept.push('\n');
    }
    kept
}
208
209fn install_claude_hooks(project_root: &Path) -> Result<(), io::Error> {
210    let claude_dir = project_root.join(".claude");
211    fs::create_dir_all(&claude_dir)?;
212
213    let settings_path = claude_dir.join("settings.json");
214    let mut settings: serde_json::Value = if settings_path.exists() {
215        let content = fs::read_to_string(&settings_path)?;
216        serde_json::from_str(&content).map_err(|e| {
217            io::Error::new(
218                io::ErrorKind::InvalidData,
219                format!("Failed to parse .claude/settings.json: {e}"),
220            )
221        })?
222    } else {
223        serde_json::json!({})
224    };
225
226    let hooks = tracevault_hooks();
227
228    // Merge hooks into existing settings
229    let settings_obj = settings.as_object_mut().ok_or_else(|| {
230        io::Error::new(
231            io::ErrorKind::InvalidData,
232            ".claude/settings.json is not a JSON object",
233        )
234    })?;
235
236    settings_obj.insert("hooks".to_string(), hooks);
237
238    let formatted = serde_json::to_string_pretty(&settings)
239        .map_err(|e| io::Error::other(format!("Failed to serialize settings: {e}")))?;
240    fs::write(&settings_path, formatted)?;
241
242    Ok(())
243}
244
245pub fn tracevault_hooks() -> serde_json::Value {
246    serde_json::json!({
247        "PreToolUse": [{
248            "matcher": "Write|Edit",
249            "hooks": [{
250                "type": "command",
251                "command": "tracevault hook --event pre-tool-use",
252                "timeout": 5,
253                "statusMessage": "TraceVault: capturing pre-edit state"
254            }]
255        }],
256        "PostToolUse": [{
257            "matcher": "Write|Edit|Bash",
258            "hooks": [{
259                "type": "command",
260                "command": "tracevault hook --event post-tool-use",
261                "timeout": 5,
262                "statusMessage": "TraceVault: recording change"
263            }]
264        }]
265    })
266}