Skip to main content

rustant_core/
project_detect.rs

1//! Project type detection for zero-config initialization.
2//!
3//! Scans a workspace directory to identify the project's language, framework,
4//! and build system. Used by `rustant init` to generate optimal default
5//! configurations without requiring manual setup.
6
7use std::path::Path;
8
/// The language / ecosystem a workspace was classified as.
///
/// Produced by [`detect_project`]. A workspace that matches several
/// ecosystems is reported as [`ProjectType::Mixed`], holding each match in the
/// order it was detected.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProjectType {
    /// Cargo-based Rust project (`Cargo.toml`).
    Rust,
    /// Node.js / JavaScript project (`package.json`).
    Node,
    /// Python project (`pyproject.toml`, `setup.py`, `requirements.txt` or `Pipfile`).
    Python,
    /// Go module (`go.mod`).
    Go,
    /// Java project built with Maven (`pom.xml`) or Gradle (`build.gradle[.kts]`).
    Java,
    /// Ruby project (`Gemfile`).
    Ruby,
    /// C# / .NET project (a top-level `.csproj` file).
    CSharp,
    /// C or C++ project (`CMakeLists.txt` or `Makefile`).
    Cpp,
    /// More than one ecosystem was detected; members are kept in detection order.
    Mixed(Vec<ProjectType>),
    /// No recognized project markers were found.
    Unknown,
}
25
26impl std::fmt::Display for ProjectType {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            ProjectType::Rust => write!(f, "Rust"),
30            ProjectType::Node => write!(f, "Node.js"),
31            ProjectType::Python => write!(f, "Python"),
32            ProjectType::Go => write!(f, "Go"),
33            ProjectType::Java => write!(f, "Java"),
34            ProjectType::Ruby => write!(f, "Ruby"),
35            ProjectType::CSharp => write!(f, "C#"),
36            ProjectType::Cpp => write!(f, "C/C++"),
37            ProjectType::Mixed(types) => {
38                let names: Vec<String> = types.iter().map(|t| t.to_string()).collect();
39                write!(f, "Mixed ({})", names.join(", "))
40            }
41            ProjectType::Unknown => write!(f, "Unknown"),
42        }
43    }
44}
45
46/// Result of project detection with rich metadata.
47#[derive(Debug, Clone)]
48pub struct ProjectInfo {
49    /// Primary detected project type.
50    pub project_type: ProjectType,
51    /// Whether the project has a git repository.
52    pub has_git: bool,
53    /// Whether the git working tree is clean.
54    pub git_clean: bool,
55    /// Detected build tool commands (e.g., "cargo", "npm", "make").
56    pub build_commands: Vec<String>,
57    /// Detected test commands (e.g., "cargo test", "npm test").
58    pub test_commands: Vec<String>,
59    /// Detected package manager.
60    pub package_manager: Option<String>,
61    /// Key source directories found.
62    pub source_dirs: Vec<String>,
63    /// Whether a CI configuration was found.
64    pub has_ci: bool,
65    /// Detected framework (e.g., "React", "Django", "Actix").
66    pub framework: Option<String>,
67}
68
69/// Detect the project type and metadata from a workspace directory.
70pub fn detect_project(workspace: &Path) -> ProjectInfo {
71    let mut types = Vec::new();
72    let mut build_commands = Vec::new();
73    let mut test_commands = Vec::new();
74    let mut package_manager = None;
75    let mut source_dirs = Vec::new();
76    let mut framework = None;
77
78    // Rust detection
79    if workspace.join("Cargo.toml").exists() {
80        types.push(ProjectType::Rust);
81        build_commands.push("cargo build".to_string());
82        test_commands.push("cargo test".to_string());
83        package_manager = Some("cargo".to_string());
84        if workspace.join("src").exists() {
85            source_dirs.push("src".to_string());
86        }
87        // Detect Rust frameworks
88        if let Ok(content) = std::fs::read_to_string(workspace.join("Cargo.toml")) {
89            if content.contains("actix-web") {
90                framework = Some("Actix Web".to_string());
91            } else if content.contains("axum") {
92                framework = Some("Axum".to_string());
93            } else if content.contains("rocket") {
94                framework = Some("Rocket".to_string());
95            } else if content.contains("tauri") {
96                framework = Some("Tauri".to_string());
97            }
98        }
99    }
100
101    // Node.js detection
102    if workspace.join("package.json").exists() {
103        types.push(ProjectType::Node);
104        if workspace.join("pnpm-lock.yaml").exists() {
105            build_commands.push("pnpm build".to_string());
106            test_commands.push("pnpm test".to_string());
107            package_manager = Some("pnpm".to_string());
108        } else if workspace.join("yarn.lock").exists() {
109            build_commands.push("yarn build".to_string());
110            test_commands.push("yarn test".to_string());
111            package_manager = Some("yarn".to_string());
112        } else if workspace.join("bun.lockb").exists() {
113            build_commands.push("bun build".to_string());
114            test_commands.push("bun test".to_string());
115            package_manager = Some("bun".to_string());
116        } else {
117            build_commands.push("npm run build".to_string());
118            test_commands.push("npm test".to_string());
119            package_manager = Some("npm".to_string());
120        }
121        if workspace.join("src").exists() {
122            source_dirs.push("src".to_string());
123        }
124        // Detect Node frameworks
125        if let Ok(content) = std::fs::read_to_string(workspace.join("package.json")) {
126            if content.contains("\"next\"") {
127                framework = Some("Next.js".to_string());
128            } else if content.contains("\"react\"") {
129                framework = Some("React".to_string());
130            } else if content.contains("\"vue\"") {
131                framework = Some("Vue.js".to_string());
132            } else if content.contains("\"svelte\"") {
133                framework = Some("Svelte".to_string());
134            } else if content.contains("\"express\"") {
135                framework = Some("Express".to_string());
136            } else if content.contains("\"nestjs\"") || content.contains("\"@nestjs/core\"") {
137                framework = Some("NestJS".to_string());
138            }
139        }
140    }
141
142    // Python detection
143    if workspace.join("pyproject.toml").exists()
144        || workspace.join("setup.py").exists()
145        || workspace.join("requirements.txt").exists()
146        || workspace.join("Pipfile").exists()
147    {
148        types.push(ProjectType::Python);
149        if workspace.join("pyproject.toml").exists() {
150            if workspace.join("poetry.lock").exists() {
151                build_commands.push("poetry build".to_string());
152                test_commands.push("poetry run pytest".to_string());
153                package_manager = Some("poetry".to_string());
154            } else if workspace.join("uv.lock").exists() {
155                test_commands.push("uv run pytest".to_string());
156                package_manager = Some("uv".to_string());
157            } else {
158                test_commands.push("python -m pytest".to_string());
159                package_manager = Some("pip".to_string());
160            }
161        } else if workspace.join("Pipfile").exists() {
162            test_commands.push("pipenv run pytest".to_string());
163            package_manager = Some("pipenv".to_string());
164        } else {
165            test_commands.push("python -m pytest".to_string());
166            package_manager = Some("pip".to_string());
167        }
168        // Detect Python frameworks
169        let py_files = ["pyproject.toml", "requirements.txt", "setup.py"];
170        for file in &py_files {
171            if let Ok(content) = std::fs::read_to_string(workspace.join(file)) {
172                if content.contains("django") {
173                    framework = Some("Django".to_string());
174                    break;
175                } else if content.contains("fastapi") {
176                    framework = Some("FastAPI".to_string());
177                    break;
178                } else if content.contains("flask") {
179                    framework = Some("Flask".to_string());
180                    break;
181                }
182            }
183        }
184    }
185
186    // Go detection
187    if workspace.join("go.mod").exists() {
188        types.push(ProjectType::Go);
189        build_commands.push("go build ./...".to_string());
190        test_commands.push("go test ./...".to_string());
191        package_manager = Some("go".to_string());
192    }
193
194    // Java detection
195    if workspace.join("pom.xml").exists() {
196        types.push(ProjectType::Java);
197        build_commands.push("mvn compile".to_string());
198        test_commands.push("mvn test".to_string());
199        package_manager = Some("maven".to_string());
200    } else if workspace.join("build.gradle").exists() || workspace.join("build.gradle.kts").exists()
201    {
202        types.push(ProjectType::Java);
203        build_commands.push("./gradlew build".to_string());
204        test_commands.push("./gradlew test".to_string());
205        package_manager = Some("gradle".to_string());
206    }
207
208    // Ruby detection
209    if workspace.join("Gemfile").exists() {
210        types.push(ProjectType::Ruby);
211        test_commands.push("bundle exec rspec".to_string());
212        package_manager = Some("bundler".to_string());
213        if workspace.join("config").join("routes.rb").exists() {
214            framework = Some("Rails".to_string());
215        }
216    }
217
218    // C# detection
219    let has_csharp = workspace.join("*.csproj").exists()
220        || std::fs::read_dir(workspace)
221            .map(|entries| {
222                entries
223                    .filter_map(|e| e.ok())
224                    .any(|e| e.path().extension().is_some_and(|ext| ext == "csproj"))
225            })
226            .unwrap_or(false);
227    if has_csharp || workspace.join("*.sln").exists() {
228        types.push(ProjectType::CSharp);
229        build_commands.push("dotnet build".to_string());
230        test_commands.push("dotnet test".to_string());
231        package_manager = Some("nuget".to_string());
232    }
233
234    // C/C++ detection
235    if workspace.join("CMakeLists.txt").exists() || workspace.join("Makefile").exists() {
236        types.push(ProjectType::Cpp);
237        if workspace.join("CMakeLists.txt").exists() {
238            build_commands.push("cmake --build build".to_string());
239            package_manager = Some("cmake".to_string());
240        } else {
241            build_commands.push("make".to_string());
242        }
243    }
244
245    // Git detection
246    let has_git = workspace.join(".git").exists();
247    let git_clean = if has_git {
248        std::process::Command::new("git")
249            .args(["status", "--porcelain"])
250            .current_dir(workspace)
251            .output()
252            .map(|o| o.stdout.is_empty())
253            .unwrap_or(false)
254    } else {
255        false
256    };
257
258    // CI detection
259    let has_ci = workspace.join(".github").join("workflows").exists()
260        || workspace.join(".gitlab-ci.yml").exists()
261        || workspace.join(".circleci").exists()
262        || workspace.join("Jenkinsfile").exists();
263
264    // Determine primary project type
265    let project_type = match types.len() {
266        0 => ProjectType::Unknown,
267        1 => types.into_iter().next().unwrap(),
268        _ => ProjectType::Mixed(types),
269    };
270
271    ProjectInfo {
272        project_type,
273        has_git,
274        git_clean,
275        build_commands,
276        test_commands,
277        package_manager,
278        source_dirs,
279        has_ci,
280        framework,
281    }
282}
283
284/// Generate recommended safety allowed_commands based on project type.
285pub fn recommended_allowed_commands(info: &ProjectInfo) -> Vec<String> {
286    let mut commands = vec!["git".to_string(), "echo".to_string(), "cat".to_string()];
287
288    match &info.project_type {
289        ProjectType::Rust => {
290            commands.extend([
291                "cargo".to_string(),
292                "rustfmt".to_string(),
293                "clippy-driver".to_string(),
294            ]);
295        }
296        ProjectType::Node => {
297            commands.extend(["node".to_string(), "npx".to_string()]);
298            if let Some(pm) = &info.package_manager {
299                commands.push(pm.clone());
300            }
301        }
302        ProjectType::Python => {
303            commands.extend([
304                "python".to_string(),
305                "python3".to_string(),
306                "pytest".to_string(),
307            ]);
308            if let Some(pm) = &info.package_manager {
309                commands.push(pm.clone());
310            }
311        }
312        ProjectType::Go => {
313            commands.extend(["go".to_string(), "gofmt".to_string()]);
314        }
315        ProjectType::Java => {
316            if let Some(pm) = &info.package_manager {
317                match pm.as_str() {
318                    "maven" => commands.push("mvn".to_string()),
319                    "gradle" => commands.push("./gradlew".to_string()),
320                    _ => {}
321                }
322            }
323        }
324        ProjectType::Ruby => {
325            commands.extend(["ruby".to_string(), "bundle".to_string(), "rake".to_string()]);
326        }
327        ProjectType::CSharp => {
328            commands.push("dotnet".to_string());
329        }
330        ProjectType::Cpp => {
331            commands.extend(["cmake".to_string(), "make".to_string()]);
332        }
333        ProjectType::Mixed(types) => {
334            for t in types {
335                let sub_info = ProjectInfo {
336                    project_type: t.clone(),
337                    has_git: false,
338                    git_clean: false,
339                    build_commands: vec![],
340                    test_commands: vec![],
341                    package_manager: info.package_manager.clone(),
342                    source_dirs: vec![],
343                    has_ci: false,
344                    framework: None,
345                };
346                commands.extend(recommended_allowed_commands(&sub_info));
347            }
348            commands.sort();
349            commands.dedup();
350        }
351        ProjectType::Unknown => {}
352    }
353
354    commands
355}
356
357/// Generate example tasks tailored to the detected project type.
358pub fn example_tasks(info: &ProjectInfo) -> Vec<String> {
359    let mut tasks = Vec::new();
360
361    match &info.project_type {
362        ProjectType::Rust => {
363            tasks.push("\"Fix the compiler warnings in src/main.rs\"".to_string());
364            tasks.push("\"Add error handling to the database module\"".to_string());
365            tasks.push("\"Write tests for the authentication logic\"".to_string());
366        }
367        ProjectType::Node => {
368            tasks.push("\"Add input validation to the API endpoints\"".to_string());
369            tasks.push("\"Fix the failing test in auth.test.ts\"".to_string());
370            tasks.push("\"Refactor the user service to use async/await\"".to_string());
371        }
372        ProjectType::Python => {
373            tasks.push("\"Add type hints to the data processing module\"".to_string());
374            tasks.push("\"Write unit tests for the API handlers\"".to_string());
375            tasks.push("\"Fix the race condition in the worker pool\"".to_string());
376        }
377        ProjectType::Go => {
378            tasks.push("\"Add error wrapping to the HTTP handlers\"".to_string());
379            tasks.push("\"Write table-driven tests for the parser\"".to_string());
380            tasks.push("\"Implement graceful shutdown for the server\"".to_string());
381        }
382        ProjectType::Java => {
383            tasks.push("\"Add null safety checks to the service layer\"".to_string());
384            tasks.push("\"Write integration tests for the REST controllers\"".to_string());
385            tasks.push("\"Refactor the DAO layer to use the repository pattern\"".to_string());
386        }
387        _ => {
388            tasks.push("\"Find and fix bugs in the codebase\"".to_string());
389            tasks.push("\"Add tests for the main module\"".to_string());
390            tasks.push("\"Explain the architecture of this project\"".to_string());
391        }
392    }
393
394    tasks
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `body` to `name` inside the temporary workspace `ws`.
    fn put(ws: &TempDir, name: &str, body: &str) {
        std::fs::write(ws.path().join(name), body).unwrap();
    }

    /// Returns `true` when `list` holds an entry equal to `needle`.
    fn has(list: &[String], needle: &str) -> bool {
        list.iter().any(|item| item == needle)
    }

    /// Converts borrowed string slices into owned `String`s.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Builds a `ProjectInfo` fixture with git present and clean, and no CI or framework.
    fn info_for(
        project_type: ProjectType,
        build: &[&str],
        test: &[&str],
        package_manager: Option<&str>,
        source_dirs: &[&str],
    ) -> ProjectInfo {
        ProjectInfo {
            project_type,
            has_git: true,
            git_clean: true,
            build_commands: strings(build),
            test_commands: strings(test),
            package_manager: package_manager.map(str::to_string),
            source_dirs: strings(source_dirs),
            has_ci: false,
            framework: None,
        }
    }

    #[test]
    fn test_detect_rust_project() {
        let ws = TempDir::new().unwrap();
        put(&ws, "Cargo.toml", "[package]\nname = \"test\"\nversion = \"0.1.0\"");
        std::fs::create_dir(ws.path().join("src")).unwrap();

        let detected = detect_project(ws.path());
        assert_eq!(detected.project_type, ProjectType::Rust);
        assert!(has(&detected.build_commands, "cargo build"));
        assert!(has(&detected.test_commands, "cargo test"));
        assert_eq!(detected.package_manager.as_deref(), Some("cargo"));
        assert!(has(&detected.source_dirs, "src"));
    }

    #[test]
    fn test_detect_node_project_npm() {
        let ws = TempDir::new().unwrap();
        put(&ws, "package.json", r#"{"name": "test", "version": "1.0.0"}"#);

        let detected = detect_project(ws.path());
        assert_eq!(detected.project_type, ProjectType::Node);
        assert_eq!(detected.package_manager.as_deref(), Some("npm"));
    }

    #[test]
    fn test_detect_node_project_pnpm() {
        let ws = TempDir::new().unwrap();
        put(&ws, "package.json", r#"{"name": "test", "version": "1.0.0"}"#);
        put(&ws, "pnpm-lock.yaml", "lockfileVersion: 9\n");

        let detected = detect_project(ws.path());
        assert_eq!(detected.project_type, ProjectType::Node);
        assert_eq!(detected.package_manager.as_deref(), Some("pnpm"));
    }

    #[test]
    fn test_detect_python_project() {
        let ws = TempDir::new().unwrap();
        put(&ws, "pyproject.toml", "[project]\nname = \"test\"");

        assert_eq!(detect_project(ws.path()).project_type, ProjectType::Python);
    }

    #[test]
    fn test_detect_go_project() {
        let ws = TempDir::new().unwrap();
        put(&ws, "go.mod", "module example.com/test\n\ngo 1.21\n");

        let detected = detect_project(ws.path());
        assert_eq!(detected.project_type, ProjectType::Go);
        assert!(has(&detected.build_commands, "go build ./..."));
    }

    #[test]
    fn test_detect_mixed_project() {
        let ws = TempDir::new().unwrap();
        put(&ws, "Cargo.toml", "[package]\nname = \"test\"");
        put(&ws, "package.json", r#"{"name": "test"}"#);

        let detected = detect_project(ws.path());
        let ProjectType::Mixed(members) = &detected.project_type else {
            panic!("Expected Mixed project type");
        };
        assert!(members.contains(&ProjectType::Rust));
        assert!(members.contains(&ProjectType::Node));
    }

    #[test]
    fn test_detect_unknown_project() {
        let ws = TempDir::new().unwrap();
        assert_eq!(detect_project(ws.path()).project_type, ProjectType::Unknown);
    }

    #[test]
    fn test_detect_git_status() {
        // A fresh temporary directory has no `.git` entry.
        let ws = TempDir::new().unwrap();
        assert!(!detect_project(ws.path()).has_git);
    }

    #[test]
    fn test_detect_ci() {
        let ws = TempDir::new().unwrap();
        let workflows = ws.path().join(".github").join("workflows");
        std::fs::create_dir_all(&workflows).unwrap();
        std::fs::write(workflows.join("ci.yml"), "name: CI").unwrap();

        assert!(detect_project(ws.path()).has_ci);
    }

    #[test]
    fn test_recommended_commands_rust() {
        let info = info_for(
            ProjectType::Rust,
            &["cargo build"],
            &["cargo test"],
            Some("cargo"),
            &["src"],
        );
        let allowed = recommended_allowed_commands(&info);
        assert!(has(&allowed, "cargo"));
        assert!(has(&allowed, "git"));
    }

    #[test]
    fn test_example_tasks_rust() {
        let info = info_for(ProjectType::Rust, &[], &[], None, &[]);
        assert!(!example_tasks(&info).is_empty());
    }

    #[test]
    fn test_detect_rust_framework_axum() {
        let ws = TempDir::new().unwrap();
        put(&ws, "Cargo.toml", "[dependencies]\naxum = \"0.7\"");

        assert_eq!(detect_project(ws.path()).framework.as_deref(), Some("Axum"));
    }

    #[test]
    fn test_project_type_display() {
        assert_eq!(ProjectType::Rust.to_string(), "Rust");
        assert_eq!(ProjectType::Node.to_string(), "Node.js");
        assert_eq!(ProjectType::Unknown.to_string(), "Unknown");
        assert_eq!(
            ProjectType::Mixed(vec![ProjectType::Rust, ProjectType::Node]).to_string(),
            "Mixed (Rust, Node.js)"
        );
    }
}