Skip to main content

tracevault_cli/commands/
init.rs

1use crate::api_client::ApiClient;
2use crate::config::TracevaultConfig;
3use std::fs;
4use std::io;
5use std::path::Path;
6
/// Returns the URL of the `origin` remote of the repository at `project_root`.
///
/// Yields `None` when `git` cannot be spawned, `git remote get-url origin`
/// exits unsuccessfully (e.g. no `origin` remote), or it prints only whitespace.
pub fn git_remote_url(project_root: &Path) -> Option<String> {
    let output = std::process::Command::new("git")
        .args(["remote", "get-url", "origin"])
        .current_dir(project_root)
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
    (!url.is_empty()).then_some(url)
}
17
/// Extracts the GitHub organization (owner) from a git remote URL.
///
/// Supported forms:
/// - SCP-like SSH: `git@github.com:softwaremill/tracevault.git`
/// - URLs: `https://`, `http://` or `ssh://`, optionally with user info
///   (`git@`, `user:token@`) and/or a port, e.g.
///   `ssh://git@github.com/softwaremill/tracevault.git`.
///
/// Returns `None` for non-GitHub hosts and when the owner segment is empty
/// (e.g. `https://github.com/`).
fn parse_github_org(remote_url: &str) -> Option<String> {
    let path = match remote_url.strip_prefix("git@github.com:") {
        // SCP-like SSH: git@github.com:softwaremill/tracevault.git
        Some(path) => path,
        // URL form: <scheme>://[userinfo@]github.com[:port]/<owner>/<repo>
        None => {
            let rest = ["https://", "http://", "ssh://"]
                .into_iter()
                .find_map(|scheme| remote_url.strip_prefix(scheme))?;
            let (authority, path) = rest.split_once('/')?;
            // Ignore credentials/user names and an explicit port when matching the host.
            let host = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
            let host = host.split_once(':').map_or(host, |(host, _)| host);
            if !host.eq_ignore_ascii_case("github.com") {
                return None;
            }
            path
        }
    };
    path.split('/')
        .next()
        .filter(|org| !org.is_empty())
        .map(String::from)
}
32
33pub async fn init_in_directory(
34    project_root: &Path,
35    server_url: Option<&str>,
36) -> Result<(), io::Error> {
37    // Check for git repository
38    if !project_root.join(".git").exists() {
39        return Err(io::Error::new(
40            io::ErrorKind::NotFound,
41            "Not a git repository. Run 'git init' first.",
42        ));
43    }
44
45    // Create .tracevault/ directory
46    let config_dir = TracevaultConfig::config_dir(project_root);
47    fs::create_dir_all(&config_dir)?;
48    fs::create_dir_all(config_dir.join("sessions"))?;
49    fs::create_dir_all(config_dir.join("cache"))?;
50
51    // Register repo on server if authenticated, server URL known, and git remote available
52    let remote_url = git_remote_url(project_root);
53    if remote_url.is_none() {
54        eprintln!("Warning: no git remote 'origin' configured. Skipping server registration.");
55        eprintln!("Run 'git remote add origin <url>' then 'tracevault sync' to register.");
56    }
57
58    // Extract org slug from GitHub remote URL
59    let org_slug = remote_url.as_deref().and_then(parse_github_org);
60
61    // Write config (include server_url and org_slug if available)
62    let mut config = TracevaultConfig::default();
63    if let Some(url) = server_url {
64        config.server_url = Some(url.to_string());
65    }
66    config.org_slug = org_slug.clone();
67    fs::write(
68        TracevaultConfig::config_path(project_root),
69        config.to_toml(),
70    )?;
71
72    // Create .tracevault/.gitignore
73    fs::write(
74        config_dir.join(".gitignore"),
75        "sessions/\ncache/\n*.local.toml\n",
76    )?;
77
78    // Install Claude Code hooks into .claude/settings.json
79    install_claude_hooks(project_root)?;
80
81    // Install git hooks
82    install_git_hook(project_root)?;
83    install_post_commit_hook(project_root)?;
84
85    let (resolved_url, resolved_token) = crate::api_client::resolve_credentials(project_root);
86    let effective_url = server_url.map(String::from).or(resolved_url);
87
88    if resolved_token.is_none() {
89        eprintln!("Not logged in. Run 'tracevault login' to register this repo with the server.");
90    } else if let (Some(url), Some(remote), Some(slug)) = (effective_url, remote_url, org_slug) {
91        let client = ApiClient::new(&url, resolved_token.as_deref());
92        let repo_name = git_repo_name(project_root);
93
94        match client
95            .register_repo(
96                &slug,
97                crate::api_client::RegisterRepoRequest {
98                    repo_name,
99                    github_url: Some(remote),
100                },
101            )
102            .await
103        {
104            Ok(resp) => {
105                println!("Repo registered on server (id: {})", resp.repo_id);
106                // Save repo_id to config
107                if let Some(mut cfg) = TracevaultConfig::load(project_root) {
108                    cfg.repo_id = Some(resp.repo_id.to_string());
109                    let _ = fs::write(TracevaultConfig::config_path(project_root), cfg.to_toml());
110                }
111            }
112            Err(e) => {
113                let msg = e.to_string();
114                if msg.contains("404") {
115                    eprintln!("Warning: organization '{}' not found on the server.", slug);
116                    eprintln!(
117                        "Create it first at your TraceVault instance, then run 'tracevault sync'."
118                    );
119                } else if msg.contains("403") {
120                    eprintln!("Warning: you are not a member of organization '{}'.", slug);
121                } else {
122                    eprintln!("Warning: could not register repo on server: {e}");
123                }
124            }
125        }
126    }
127
128    Ok(())
129}
130
/// Returns the directory name of the repository's top-level working tree
/// (from `git rev-parse --show-toplevel`), or `"unknown"` if it cannot be
/// determined or would be empty.
fn git_repo_name(project_root: &Path) -> String {
    std::process::Command::new("git")
        .args(["rev-parse", "--show-toplevel"])
        .current_dir(project_root)
        .output()
        .ok()
        .filter(|o| o.status.success())
        .and_then(|o| {
            let toplevel = String::from_utf8_lossy(&o.stdout).trim().to_string();
            // `Path::file_name` handles the platform's separators and returns None
            // for paths without a final component (e.g. "/" or "").
            Path::new(&toplevel)
                .file_name()
                .map(|name| name.to_string_lossy().into_owned())
        })
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| "unknown".into())
}
144
/// Marker line identifying the TraceVault block in the `pre-push` hook.
const HOOK_MARKER: &str = "# tracevault:enforce";
/// Marker written by older TraceVault versions; its block is replaced on install.
const OLD_HOOK_MARKER: &str = "# tracevault:auto-push";

/// Installs the TraceVault enforcement block into `.git/hooks/pre-push`.
///
/// The block runs `tracevault sync` (failures ignored) and then `tracevault check`,
/// exiting with status 1 if the check fails. For an existing hook:
/// - if it already contains [`HOOK_MARKER`], it is left untouched;
/// - if it contains the legacy [`OLD_HOOK_MARKER`], that marker line and the
///   `tracevault ...` lines directly after it are removed before appending;
/// - otherwise the block is appended to the existing script.
///
/// A missing or blank hook becomes a fresh `#!/bin/sh` script. On Unix the hook
/// is made executable (`0o755`).
///
/// # Errors
///
/// Propagates I/O errors from creating the hooks directory or reading/writing the hook.
fn install_git_hook(project_root: &Path) -> Result<(), io::Error> {
    let hooks_dir = project_root.join(".git/hooks");
    fs::create_dir_all(&hooks_dir)?;

    let hook_path = hooks_dir.join("pre-push");
    let tracevault_block = format!(
        "{HOOK_MARKER}\ntracevault sync 2>/dev/null || true\ntracevault check || {{ echo \"tracevault: policy check failed\"; exit 1; }}\n"
    );

    let mut content = if hook_path.exists() {
        let existing = fs::read_to_string(&hook_path)?;

        // Already has new-style hook
        if existing.contains(HOOK_MARKER) {
            return Ok(());
        }

        if existing.contains(OLD_HOOK_MARKER) {
            // Drop the legacy marker and the `tracevault ...` lines that follow it.
            let mut kept = String::with_capacity(existing.len());
            let mut in_old_block = false;
            for line in existing.lines() {
                if line.contains(OLD_HOOK_MARKER) {
                    in_old_block = true;
                    continue;
                }
                if in_old_block && line.starts_with("tracevault ") {
                    continue;
                }
                in_old_block = false;
                kept.push_str(line);
                kept.push('\n');
            }
            kept
        } else {
            existing
        }
    } else {
        String::new()
    };

    if content.trim().is_empty() {
        // Blank content (new file, empty file, or nothing left after migration):
        // start a proper `/bin/sh` script instead of writing a shebang-less hook.
        content = String::from("#!/bin/sh\n");
    } else if !content.ends_with('\n') {
        content.push('\n');
    }
    content.push_str(&tracevault_block);
    fs::write(&hook_path, content)?;

    // Make executable
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&hook_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&hook_path, perms)?;
    }

    Ok(())
}
214
/// Marker line identifying the TraceVault block in the `post-commit` hook.
const POST_COMMIT_MARKER: &str = "# tracevault:post-commit";

/// Installs the TraceVault block into `.git/hooks/post-commit`.
///
/// The block runs `tracevault commit-push` in the background with stderr
/// discarded. If the hook already contains [`POST_COMMIT_MARKER`] it is left
/// untouched; otherwise the block is appended to the existing script. A missing
/// or blank hook becomes a fresh `#!/bin/sh` script. On Unix the hook is made
/// executable (`0o755`).
///
/// # Errors
///
/// Propagates I/O errors from creating the hooks directory or reading/writing the hook.
fn install_post_commit_hook(project_root: &Path) -> Result<(), io::Error> {
    let hooks_dir = project_root.join(".git/hooks");
    fs::create_dir_all(&hooks_dir)?;

    let hook_path = hooks_dir.join("post-commit");
    // `&` backgrounds the command so the hook returns without waiting for it.
    let tracevault_block = format!("{POST_COMMIT_MARKER}\ntracevault commit-push 2>/dev/null &\n");

    let mut content = if hook_path.exists() {
        let existing = fs::read_to_string(&hook_path)?;
        if existing.contains(POST_COMMIT_MARKER) {
            return Ok(());
        }
        existing
    } else {
        String::new()
    };

    if content.trim().is_empty() {
        // Blank content (new or empty file): start a proper `/bin/sh` script
        // instead of writing a shebang-less hook.
        content = String::from("#!/bin/sh\n");
    } else if !content.ends_with('\n') {
        content.push('\n');
    }
    content.push_str(&tracevault_block);
    fs::write(&hook_path, content)?;

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&hook_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&hook_path, perms)?;
    }

    Ok(())
}
252
253fn install_claude_hooks(project_root: &Path) -> Result<(), io::Error> {
254    let claude_dir = project_root.join(".claude");
255    fs::create_dir_all(&claude_dir)?;
256
257    let settings_path = claude_dir.join("settings.json");
258    let mut settings: serde_json::Value = if settings_path.exists() {
259        let content = fs::read_to_string(&settings_path)?;
260        serde_json::from_str(&content).map_err(|e| {
261            io::Error::new(
262                io::ErrorKind::InvalidData,
263                format!("Failed to parse .claude/settings.json: {e}"),
264            )
265        })?
266    } else {
267        serde_json::json!({})
268    };
269
270    let hooks = tracevault_hooks();
271
272    // Merge hooks into existing settings
273    let settings_obj = settings.as_object_mut().ok_or_else(|| {
274        io::Error::new(
275            io::ErrorKind::InvalidData,
276            ".claude/settings.json is not a JSON object",
277        )
278    })?;
279
280    settings_obj.insert("hooks".to_string(), hooks);
281
282    let formatted = serde_json::to_string_pretty(&settings)
283        .map_err(|e| io::Error::other(format!("Failed to serialize settings: {e}")))?;
284    fs::write(&settings_path, formatted)?;
285
286    Ok(())
287}
288
289pub fn tracevault_hooks() -> serde_json::Value {
290    serde_json::json!({
291        "PreToolUse": [{
292            "matcher": "Write|Edit|Bash",
293            "hooks": [{
294                "type": "command",
295                "command": "tracevault stream --event pre-tool-use",
296                "timeout": 10,
297                "statusMessage": "TraceVault: streaming pre-tool event"
298            }]
299        }],
300        "PostToolUse": [{
301            "matcher": "",
302            "hooks": [{
303                "type": "command",
304                "command": "tracevault stream --event post-tool-use",
305                "timeout": 10,
306                "statusMessage": "TraceVault: streaming post-tool event"
307            }]
308        }],
309        "Notification": [{
310            "matcher": "",
311            "hooks": [{
312                "type": "command",
313                "command": "tracevault stream --event notification",
314                "timeout": 10,
315                "statusMessage": "TraceVault: streaming notification"
316            }]
317        }],
318        "Stop": [{
319            "matcher": "",
320            "hooks": [{
321                "type": "command",
322                "command": "tracevault stream --event stop",
323                "timeout": 10,
324                "statusMessage": "TraceVault: finalizing session"
325            }]
326        }]
327    })
328}