Skip to main content

tracevault_cli/commands/
init.rs

1use crate::api_client::ApiClient;
2use crate::config::TracevaultConfig;
3use std::fs;
4use std::io;
5use std::path::Path;
6
/// Returns the URL of the `origin` remote for the repository at `project_root`.
///
/// Runs `git remote get-url origin` in `project_root`. Yields `None` when git
/// cannot be spawned, when the command exits unsuccessfully, or when it prints
/// nothing but whitespace.
pub fn git_remote_url(project_root: &Path) -> Option<String> {
    let output = std::process::Command::new("git")
        .arg("remote")
        .arg("get-url")
        .arg("origin")
        .current_dir(project_root)
        .output()
        .ok()?;

    if !output.status.success() {
        return None;
    }

    let url = String::from_utf8_lossy(&output.stdout).trim().to_owned();
    if url.is_empty() {
        None
    } else {
        Some(url)
    }
}
17
/// Extracts the GitHub organization (or user) name from a git remote URL.
///
/// Recognized forms (for the URL forms the host must be `github.com`,
/// compared case-insensitively):
/// - scp-like SSH: `git@github.com:org/repo.git`
/// - `ssh://[user@]github.com[:port]/org/repo.git`
/// - `https://[user[:token]@]github.com[:port]/org/repo[.git]` (also `http://`)
///
/// Returns `None` for other hosts, unrecognized formats, or when the
/// organization segment is empty (e.g. `https://github.com/`).
fn parse_github_org(remote_url: &str) -> Option<String> {
    let url = remote_url.trim();

    let path = if let Some(rest) = url.strip_prefix("git@github.com:") {
        // scp-like SSH syntax has no scheme and separates host and path with ':'.
        rest
    } else {
        let rest = ["https://", "http://", "ssh://"]
            .iter()
            .find_map(|scheme| url.strip_prefix(*scheme))?;
        let (authority, path) = rest.split_once('/')?;
        // Drop optional `user[:password]@` credentials and a `:port` suffix so
        // only the bare host name is compared.
        let host = authority.rsplit('@').next().unwrap_or(authority);
        let host = host.split(':').next().unwrap_or(host);
        if !host.eq_ignore_ascii_case("github.com") {
            return None;
        }
        path
    };

    // The org is the first path segment; an empty segment means there is none.
    path.split('/')
        .next()
        .filter(|org| !org.is_empty())
        .map(String::from)
}
32
33pub async fn init_in_directory(
34    project_root: &Path,
35    server_url: Option<&str>,
36) -> Result<(), io::Error> {
37    // Check for git repository
38    if !project_root.join(".git").exists() {
39        return Err(io::Error::new(
40            io::ErrorKind::NotFound,
41            "Not a git repository. Run 'git init' first.",
42        ));
43    }
44
45    // Create .tracevault/ directory
46    let config_dir = TracevaultConfig::config_dir(project_root);
47    fs::create_dir_all(&config_dir)?;
48    fs::create_dir_all(config_dir.join("sessions"))?;
49    fs::create_dir_all(config_dir.join("cache"))?;
50
51    // Register repo on server if authenticated, server URL known, and git remote available
52    let remote_url = git_remote_url(project_root);
53    if remote_url.is_none() {
54        eprintln!("Warning: no git remote 'origin' configured. Skipping server registration.");
55        eprintln!("Run 'git remote add origin <url>' then 'tracevault sync' to register.");
56    }
57
58    // Extract org slug from GitHub remote URL
59    let org_slug = remote_url.as_deref().and_then(parse_github_org);
60
61    // Write config (include server_url and org_slug if available)
62    let mut config = TracevaultConfig::default();
63    if let Some(url) = server_url {
64        config.server_url = Some(url.to_string());
65    }
66    config.org_slug = org_slug.clone();
67    fs::write(
68        TracevaultConfig::config_path(project_root),
69        config.to_toml(),
70    )?;
71
72    // Create .tracevault/.gitignore
73    fs::write(
74        config_dir.join(".gitignore"),
75        "sessions/\ncache/\n*.local.toml\n",
76    )?;
77
78    // Install Claude Code hooks into .claude/settings.json
79    install_claude_hooks(project_root)?;
80
81    // Install git hooks
82    install_git_hook(project_root)?;
83    install_post_commit_hook(project_root)?;
84
85    // Detect AI tools in the project
86    let detected = crate::hooks::detect_tools(project_root);
87    for tool in &detected {
88        println!("  Detected: {}", tool.name());
89    }
90
91    let (resolved_url, resolved_token) = crate::api_client::resolve_credentials(project_root);
92    let effective_url = server_url.map(String::from).or(resolved_url);
93
94    if resolved_token.is_none() {
95        eprintln!("Not logged in. Run 'tracevault login' to register this repo with the server.");
96    } else if let (Some(url), Some(remote), Some(slug)) = (effective_url, remote_url, org_slug) {
97        let client = ApiClient::new(&url, resolved_token.as_deref());
98        let repo_name = git_repo_name(project_root);
99
100        match client
101            .register_repo(
102                &slug,
103                crate::api_client::RegisterRepoRequest {
104                    repo_name,
105                    github_url: Some(remote),
106                },
107            )
108            .await
109        {
110            Ok(resp) => {
111                println!("Repo registered on server (id: {})", resp.repo_id);
112                // Save repo_id to config
113                if let Some(mut cfg) = TracevaultConfig::load(project_root) {
114                    cfg.repo_id = Some(resp.repo_id.to_string());
115                    let _ = fs::write(TracevaultConfig::config_path(project_root), cfg.to_toml());
116                }
117            }
118            Err(e) => {
119                let msg = e.to_string();
120                if msg.contains("404") {
121                    eprintln!("Warning: organization '{}' not found on the server.", slug);
122                    eprintln!(
123                        "Create it first at your TraceVault instance, then run 'tracevault sync'."
124                    );
125                } else if msg.contains("403") {
126                    eprintln!("Warning: you are not a member of organization '{}'.", slug);
127                } else {
128                    eprintln!("Warning: could not register repo on server: {e}");
129                }
130            }
131        }
132    }
133
134    Ok(())
135}
136
/// Returns a best-effort name for the repository: the final path component
/// printed by `git rev-parse --show-toplevel` when run in `project_root`.
///
/// Falls back to `"unknown"` in any of these cases:
/// - git cannot be run, or the command fails;
/// - the top-level path has no usable final component (e.g. the repository
///   lives at `/`);
/// - that component is empty.
fn git_repo_name(project_root: &Path) -> String {
    std::process::Command::new("git")
        .args(["rev-parse", "--show-toplevel"])
        .current_dir(project_root)
        .output()
        .ok()
        .filter(|o| o.status.success())
        .and_then(|o| {
            let toplevel = String::from_utf8_lossy(&o.stdout);
            // `Path::file_name` understands the platform's separators and
            // returns None for roots such as "/".
            Path::new(toplevel.trim())
                .file_name()
                .map(|name| name.to_string_lossy().into_owned())
        })
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| "unknown".into())
}
150
/// Marker line identifying the current TraceVault block in `.git/hooks/pre-push`.
const HOOK_MARKER: &str = "# tracevault:enforce";
/// Marker used by older TraceVault versions; blocks carrying it are replaced on install.
const OLD_HOOK_MARKER: &str = "# tracevault:auto-push";

/// Installs (or upgrades) the TraceVault block in `.git/hooks/pre-push`.
///
/// The block runs `tracevault sync` (errors ignored), then `tracevault check`,
/// and aborts the push if the check fails.
///
/// - Hook already contains [`HOOK_MARKER`]: left untouched, returns immediately.
/// - Hook contains [`OLD_HOOK_MARKER`]: that marker line and the `tracevault ...`
///   lines directly after it are removed before the new block is appended.
/// - If nothing but whitespace remains (missing, empty, or only the legacy
///   block), a fresh `#!/bin/sh` script is written so the hook has a shebang.
/// - Otherwise the block is appended to the existing hook.
///
/// On Unix the hook is then made executable (mode `0o755`).
///
/// # Errors
/// Propagates any I/O error from creating the hooks directory, or from reading,
/// writing, or changing permissions on the hook file.
fn install_git_hook(project_root: &Path) -> Result<(), io::Error> {
    let hooks_dir = project_root.join(".git/hooks");
    fs::create_dir_all(&hooks_dir)?;

    let hook_path = hooks_dir.join("pre-push");
    let tracevault_block = format!(
        "{HOOK_MARKER}\ntracevault sync 2>/dev/null || true\ntracevault check || {{ echo \"tracevault: policy check failed\"; exit 1; }}\n"
    );

    let existing = if hook_path.exists() {
        fs::read_to_string(&hook_path)?
    } else {
        String::new()
    };

    // Already has new-style hook
    if existing.contains(HOOK_MARKER) {
        return Ok(());
    }

    let preserved = if existing.contains(OLD_HOOK_MARKER) {
        strip_old_hook_block(&existing)
    } else {
        existing
    };

    let content = if preserved.trim().is_empty() {
        // Nothing worth keeping: start a fresh script so it carries a shebang.
        format!("#!/bin/sh\n{tracevault_block}")
    } else {
        let mut content = preserved;
        if !content.ends_with('\n') {
            content.push('\n');
        }
        content.push_str(&tracevault_block);
        content
    };
    fs::write(&hook_path, content)?;

    // Make executable
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&hook_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&hook_path, perms)?;
    }

    Ok(())
}

/// Removes the legacy TraceVault block from `existing`.
///
/// The block is the line containing [`OLD_HOOK_MARKER`] plus the lines
/// starting with `tracevault ` directly after it. Every kept line ends with `\n`.
fn strip_old_hook_block(existing: &str) -> String {
    let mut kept = String::with_capacity(existing.len());
    let mut in_old_block = false;
    for line in existing.lines() {
        if line.contains(OLD_HOOK_MARKER) {
            in_old_block = true;
            continue;
        }
        if in_old_block && line.starts_with("tracevault ") {
            continue;
        }
        in_old_block = false;
        kept.push_str(line);
        kept.push('\n');
    }
    kept
}
220
/// Marker line identifying the TraceVault block in `.git/hooks/post-commit`.
const POST_COMMIT_MARKER: &str = "# tracevault:post-commit";

/// Installs the TraceVault block in `.git/hooks/post-commit`.
///
/// The block runs `tracevault commit-push` in the background with stderr
/// discarded.
///
/// - Hook already contains [`POST_COMMIT_MARKER`]: left untouched, returns immediately.
/// - Hook is missing, empty, or whitespace-only: a fresh `#!/bin/sh` script is written.
/// - Otherwise the block is appended to the existing hook.
///
/// On Unix the hook is then made executable (mode `0o755`).
///
/// # Errors
/// Propagates any I/O error from creating the hooks directory, or from reading,
/// writing, or changing permissions on the hook file.
fn install_post_commit_hook(project_root: &Path) -> Result<(), io::Error> {
    let hooks_dir = project_root.join(".git/hooks");
    fs::create_dir_all(&hooks_dir)?;

    let hook_path = hooks_dir.join("post-commit");
    let tracevault_block = format!("{POST_COMMIT_MARKER}\ntracevault commit-push 2>/dev/null &\n");

    let existing = if hook_path.exists() {
        fs::read_to_string(&hook_path)?
    } else {
        String::new()
    };

    if existing.contains(POST_COMMIT_MARKER) {
        return Ok(());
    }

    let content = if existing.trim().is_empty() {
        // Nothing worth keeping: start a fresh script so it carries a shebang.
        format!("#!/bin/sh\n{tracevault_block}")
    } else {
        let mut content = existing;
        if !content.ends_with('\n') {
            content.push('\n');
        }
        content.push_str(&tracevault_block);
        content
    };
    fs::write(&hook_path, content)?;

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&hook_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&hook_path, perms)?;
    }

    Ok(())
}
258
259fn install_claude_hooks(project_root: &Path) -> Result<(), io::Error> {
260    let claude_dir = project_root.join(".claude");
261    fs::create_dir_all(&claude_dir)?;
262
263    let settings_path = claude_dir.join("settings.json");
264    let mut settings: serde_json::Value = if settings_path.exists() {
265        let content = fs::read_to_string(&settings_path)?;
266        serde_json::from_str(&content).map_err(|e| {
267            io::Error::new(
268                io::ErrorKind::InvalidData,
269                format!("Failed to parse .claude/settings.json: {e}"),
270            )
271        })?
272    } else {
273        serde_json::json!({})
274    };
275
276    let hooks = tracevault_hooks();
277
278    // Merge hooks into existing settings
279    let settings_obj = settings.as_object_mut().ok_or_else(|| {
280        io::Error::new(
281            io::ErrorKind::InvalidData,
282            ".claude/settings.json is not a JSON object",
283        )
284    })?;
285
286    settings_obj.insert("hooks".to_string(), hooks);
287
288    let formatted = serde_json::to_string_pretty(&settings)
289        .map_err(|e| io::Error::other(format!("Failed to serialize settings: {e}")))?;
290    fs::write(&settings_path, formatted)?;
291
292    Ok(())
293}
294
295pub fn tracevault_hooks() -> serde_json::Value {
296    serde_json::json!({
297        "PreToolUse": [{
298            "matcher": "Write|Edit|Bash",
299            "hooks": [{
300                "type": "command",
301                "command": "tracevault stream --event pre-tool-use",
302                "timeout": 10,
303                "statusMessage": "TraceVault: streaming pre-tool event"
304            }]
305        }],
306        "PostToolUse": [{
307            "matcher": "",
308            "hooks": [{
309                "type": "command",
310                "command": "tracevault stream --event post-tool-use",
311                "timeout": 10,
312                "statusMessage": "TraceVault: streaming post-tool event"
313            }]
314        }],
315        "Notification": [{
316            "matcher": "",
317            "hooks": [{
318                "type": "command",
319                "command": "tracevault stream --event notification",
320                "timeout": 10,
321                "statusMessage": "TraceVault: streaming notification"
322            }]
323        }],
324        "Stop": [{
325            "matcher": "",
326            "hooks": [{
327                "type": "command",
328                "command": "tracevault stream --event stop",
329                "timeout": 10,
330                "statusMessage": "TraceVault: finalizing session"
331            }]
332        }]
333    })
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `url` parses to exactly `expected` as its GitHub org.
    fn assert_org(url: &str, expected: Option<&str>) {
        assert_eq!(parse_github_org(url).as_deref(), expected, "remote URL: {url}");
    }

    #[test]
    fn parse_github_org_ssh() {
        assert_org("git@github.com:myorg/myrepo.git", Some("myorg"));
    }

    #[test]
    fn parse_github_org_https() {
        assert_org("https://github.com/myorg/myrepo", Some("myorg"));
    }

    #[test]
    fn parse_github_org_non_github_returns_none() {
        assert_org("https://gitlab.com/org/repo", None);
    }

    #[test]
    fn parse_github_org_invalid() {
        assert_org("not-a-url", None);
    }
}