1use std::path::Path;
8
/// The language ecosystem (toolchain) a workspace appears to belong to.
///
/// A workspace that matches several ecosystems is represented by [`ProjectType::Mixed`],
/// and one that matches none by [`ProjectType::Unknown`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProjectType {
    /// Rust (Cargo).
    Rust,
    /// JavaScript / TypeScript on Node.js (or a compatible runtime).
    Node,
    /// Python.
    Python,
    /// Go modules.
    Go,
    /// Java (Maven or Gradle).
    Java,
    /// Ruby (Bundler).
    Ruby,
    /// C# / .NET.
    CSharp,
    /// C or C++ (CMake or Make).
    Cpp,
    /// More than one ecosystem matched; holds each of them.
    Mixed(Vec<ProjectType>),
    /// No recognised ecosystem.
    Unknown,
}
25
26impl std::fmt::Display for ProjectType {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 ProjectType::Rust => write!(f, "Rust"),
30 ProjectType::Node => write!(f, "Node.js"),
31 ProjectType::Python => write!(f, "Python"),
32 ProjectType::Go => write!(f, "Go"),
33 ProjectType::Java => write!(f, "Java"),
34 ProjectType::Ruby => write!(f, "Ruby"),
35 ProjectType::CSharp => write!(f, "C#"),
36 ProjectType::Cpp => write!(f, "C/C++"),
37 ProjectType::Mixed(types) => {
38 let names: Vec<String> = types.iter().map(|t| t.to_string()).collect();
39 write!(f, "Mixed ({})", names.join(", "))
40 }
41 ProjectType::Unknown => write!(f, "Unknown"),
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
48pub struct ProjectInfo {
49 pub project_type: ProjectType,
51 pub has_git: bool,
53 pub git_clean: bool,
55 pub build_commands: Vec<String>,
57 pub test_commands: Vec<String>,
59 pub package_manager: Option<String>,
61 pub source_dirs: Vec<String>,
63 pub has_ci: bool,
65 pub framework: Option<String>,
67}
68
69pub fn detect_project(workspace: &Path) -> ProjectInfo {
71 let mut types = Vec::new();
72 let mut build_commands = Vec::new();
73 let mut test_commands = Vec::new();
74 let mut package_manager = None;
75 let mut source_dirs = Vec::new();
76 let mut framework = None;
77
78 if workspace.join("Cargo.toml").exists() {
80 types.push(ProjectType::Rust);
81 build_commands.push("cargo build".to_string());
82 test_commands.push("cargo test".to_string());
83 package_manager = Some("cargo".to_string());
84 if workspace.join("src").exists() {
85 source_dirs.push("src".to_string());
86 }
87 if let Ok(content) = std::fs::read_to_string(workspace.join("Cargo.toml")) {
89 if content.contains("actix-web") {
90 framework = Some("Actix Web".to_string());
91 } else if content.contains("axum") {
92 framework = Some("Axum".to_string());
93 } else if content.contains("rocket") {
94 framework = Some("Rocket".to_string());
95 } else if content.contains("tauri") {
96 framework = Some("Tauri".to_string());
97 }
98 }
99 }
100
101 if workspace.join("package.json").exists() {
103 types.push(ProjectType::Node);
104 if workspace.join("pnpm-lock.yaml").exists() {
105 build_commands.push("pnpm build".to_string());
106 test_commands.push("pnpm test".to_string());
107 package_manager = Some("pnpm".to_string());
108 } else if workspace.join("yarn.lock").exists() {
109 build_commands.push("yarn build".to_string());
110 test_commands.push("yarn test".to_string());
111 package_manager = Some("yarn".to_string());
112 } else if workspace.join("bun.lockb").exists() {
113 build_commands.push("bun build".to_string());
114 test_commands.push("bun test".to_string());
115 package_manager = Some("bun".to_string());
116 } else {
117 build_commands.push("npm run build".to_string());
118 test_commands.push("npm test".to_string());
119 package_manager = Some("npm".to_string());
120 }
121 if workspace.join("src").exists() {
122 source_dirs.push("src".to_string());
123 }
124 if let Ok(content) = std::fs::read_to_string(workspace.join("package.json")) {
126 if content.contains("\"next\"") {
127 framework = Some("Next.js".to_string());
128 } else if content.contains("\"react\"") {
129 framework = Some("React".to_string());
130 } else if content.contains("\"vue\"") {
131 framework = Some("Vue.js".to_string());
132 } else if content.contains("\"svelte\"") {
133 framework = Some("Svelte".to_string());
134 } else if content.contains("\"express\"") {
135 framework = Some("Express".to_string());
136 } else if content.contains("\"nestjs\"") || content.contains("\"@nestjs/core\"") {
137 framework = Some("NestJS".to_string());
138 }
139 }
140 }
141
142 if workspace.join("pyproject.toml").exists()
144 || workspace.join("setup.py").exists()
145 || workspace.join("requirements.txt").exists()
146 || workspace.join("Pipfile").exists()
147 {
148 types.push(ProjectType::Python);
149 if workspace.join("pyproject.toml").exists() {
150 if workspace.join("poetry.lock").exists() {
151 build_commands.push("poetry build".to_string());
152 test_commands.push("poetry run pytest".to_string());
153 package_manager = Some("poetry".to_string());
154 } else if workspace.join("uv.lock").exists() {
155 test_commands.push("uv run pytest".to_string());
156 package_manager = Some("uv".to_string());
157 } else {
158 test_commands.push("python -m pytest".to_string());
159 package_manager = Some("pip".to_string());
160 }
161 } else if workspace.join("Pipfile").exists() {
162 test_commands.push("pipenv run pytest".to_string());
163 package_manager = Some("pipenv".to_string());
164 } else {
165 test_commands.push("python -m pytest".to_string());
166 package_manager = Some("pip".to_string());
167 }
168 let py_files = ["pyproject.toml", "requirements.txt", "setup.py"];
170 for file in &py_files {
171 if let Ok(content) = std::fs::read_to_string(workspace.join(file)) {
172 if content.contains("django") {
173 framework = Some("Django".to_string());
174 break;
175 } else if content.contains("fastapi") {
176 framework = Some("FastAPI".to_string());
177 break;
178 } else if content.contains("flask") {
179 framework = Some("Flask".to_string());
180 break;
181 }
182 }
183 }
184 }
185
186 if workspace.join("go.mod").exists() {
188 types.push(ProjectType::Go);
189 build_commands.push("go build ./...".to_string());
190 test_commands.push("go test ./...".to_string());
191 package_manager = Some("go".to_string());
192 }
193
194 if workspace.join("pom.xml").exists() {
196 types.push(ProjectType::Java);
197 build_commands.push("mvn compile".to_string());
198 test_commands.push("mvn test".to_string());
199 package_manager = Some("maven".to_string());
200 } else if workspace.join("build.gradle").exists() || workspace.join("build.gradle.kts").exists()
201 {
202 types.push(ProjectType::Java);
203 build_commands.push("./gradlew build".to_string());
204 test_commands.push("./gradlew test".to_string());
205 package_manager = Some("gradle".to_string());
206 }
207
208 if workspace.join("Gemfile").exists() {
210 types.push(ProjectType::Ruby);
211 test_commands.push("bundle exec rspec".to_string());
212 package_manager = Some("bundler".to_string());
213 if workspace.join("config").join("routes.rb").exists() {
214 framework = Some("Rails".to_string());
215 }
216 }
217
218 let has_csharp = workspace.join("*.csproj").exists()
220 || std::fs::read_dir(workspace)
221 .map(|entries| {
222 entries
223 .filter_map(|e| e.ok())
224 .any(|e| e.path().extension().is_some_and(|ext| ext == "csproj"))
225 })
226 .unwrap_or(false);
227 if has_csharp || workspace.join("*.sln").exists() {
228 types.push(ProjectType::CSharp);
229 build_commands.push("dotnet build".to_string());
230 test_commands.push("dotnet test".to_string());
231 package_manager = Some("nuget".to_string());
232 }
233
234 if workspace.join("CMakeLists.txt").exists() || workspace.join("Makefile").exists() {
236 types.push(ProjectType::Cpp);
237 if workspace.join("CMakeLists.txt").exists() {
238 build_commands.push("cmake --build build".to_string());
239 package_manager = Some("cmake".to_string());
240 } else {
241 build_commands.push("make".to_string());
242 }
243 }
244
245 let has_git = workspace.join(".git").exists();
247 let git_clean = if has_git {
248 std::process::Command::new("git")
249 .args(["status", "--porcelain"])
250 .current_dir(workspace)
251 .output()
252 .map(|o| o.stdout.is_empty())
253 .unwrap_or(false)
254 } else {
255 false
256 };
257
258 let has_ci = workspace.join(".github").join("workflows").exists()
260 || workspace.join(".gitlab-ci.yml").exists()
261 || workspace.join(".circleci").exists()
262 || workspace.join("Jenkinsfile").exists();
263
264 let project_type = match types.len() {
266 0 => ProjectType::Unknown,
267 1 => types.into_iter().next().unwrap(),
268 _ => ProjectType::Mixed(types),
269 };
270
271 ProjectInfo {
272 project_type,
273 has_git,
274 git_clean,
275 build_commands,
276 test_commands,
277 package_manager,
278 source_dirs,
279 has_ci,
280 framework,
281 }
282}
283
284pub fn recommended_allowed_commands(info: &ProjectInfo) -> Vec<String> {
286 let mut commands = vec!["git".to_string(), "echo".to_string(), "cat".to_string()];
287
288 match &info.project_type {
289 ProjectType::Rust => {
290 commands.extend([
291 "cargo".to_string(),
292 "rustfmt".to_string(),
293 "clippy-driver".to_string(),
294 ]);
295 }
296 ProjectType::Node => {
297 commands.extend(["node".to_string(), "npx".to_string()]);
298 if let Some(pm) = &info.package_manager {
299 commands.push(pm.clone());
300 }
301 }
302 ProjectType::Python => {
303 commands.extend([
304 "python".to_string(),
305 "python3".to_string(),
306 "pytest".to_string(),
307 ]);
308 if let Some(pm) = &info.package_manager {
309 commands.push(pm.clone());
310 }
311 }
312 ProjectType::Go => {
313 commands.extend(["go".to_string(), "gofmt".to_string()]);
314 }
315 ProjectType::Java => {
316 if let Some(pm) = &info.package_manager {
317 match pm.as_str() {
318 "maven" => commands.push("mvn".to_string()),
319 "gradle" => commands.push("./gradlew".to_string()),
320 _ => {}
321 }
322 }
323 }
324 ProjectType::Ruby => {
325 commands.extend(["ruby".to_string(), "bundle".to_string(), "rake".to_string()]);
326 }
327 ProjectType::CSharp => {
328 commands.push("dotnet".to_string());
329 }
330 ProjectType::Cpp => {
331 commands.extend(["cmake".to_string(), "make".to_string()]);
332 }
333 ProjectType::Mixed(types) => {
334 for t in types {
335 let sub_info = ProjectInfo {
336 project_type: t.clone(),
337 has_git: false,
338 git_clean: false,
339 build_commands: vec![],
340 test_commands: vec![],
341 package_manager: info.package_manager.clone(),
342 source_dirs: vec![],
343 has_ci: false,
344 framework: None,
345 };
346 commands.extend(recommended_allowed_commands(&sub_info));
347 }
348 commands.sort();
349 commands.dedup();
350 }
351 ProjectType::Unknown => {}
352 }
353
354 commands
355}
356
357pub fn example_tasks(info: &ProjectInfo) -> Vec<String> {
359 let mut tasks = Vec::new();
360
361 match &info.project_type {
362 ProjectType::Rust => {
363 tasks.push("\"Fix the compiler warnings in src/main.rs\"".to_string());
364 tasks.push("\"Add error handling to the database module\"".to_string());
365 tasks.push("\"Write tests for the authentication logic\"".to_string());
366 }
367 ProjectType::Node => {
368 tasks.push("\"Add input validation to the API endpoints\"".to_string());
369 tasks.push("\"Fix the failing test in auth.test.ts\"".to_string());
370 tasks.push("\"Refactor the user service to use async/await\"".to_string());
371 }
372 ProjectType::Python => {
373 tasks.push("\"Add type hints to the data processing module\"".to_string());
374 tasks.push("\"Write unit tests for the API handlers\"".to_string());
375 tasks.push("\"Fix the race condition in the worker pool\"".to_string());
376 }
377 ProjectType::Go => {
378 tasks.push("\"Add error wrapping to the HTTP handlers\"".to_string());
379 tasks.push("\"Write table-driven tests for the parser\"".to_string());
380 tasks.push("\"Implement graceful shutdown for the server\"".to_string());
381 }
382 ProjectType::Java => {
383 tasks.push("\"Add null safety checks to the service layer\"".to_string());
384 tasks.push("\"Write integration tests for the REST controllers\"".to_string());
385 tasks.push("\"Refactor the DAO layer to use the repository pattern\"".to_string());
386 }
387 _ => {
388 tasks.push("\"Find and fix bugs in the codebase\"".to_string());
389 tasks.push("\"Add tests for the main module\"".to_string());
390 tasks.push("\"Explain the architecture of this project\"".to_string());
391 }
392 }
393
394 tasks
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a fresh temporary workspace containing `files` (name, contents) at its root.
    fn workspace_with(files: &[(&str, &str)]) -> TempDir {
        let dir = TempDir::new().unwrap();
        for &(name, contents) in files {
            std::fs::write(dir.path().join(name), contents).unwrap();
        }
        dir
    }

    /// A `ProjectInfo` of the given type with every other field empty or `false`.
    fn blank_info(project_type: ProjectType) -> ProjectInfo {
        ProjectInfo {
            project_type,
            has_git: false,
            git_clean: false,
            build_commands: Vec::new(),
            test_commands: Vec::new(),
            package_manager: None,
            source_dirs: Vec::new(),
            has_ci: false,
            framework: None,
        }
    }

    /// Owned-`String` shorthand for probing `Vec<String>` values.
    fn s(text: &str) -> String {
        text.to_string()
    }

    #[test]
    fn test_detect_rust_project() {
        let dir = workspace_with(&[(
            "Cargo.toml",
            "[package]\nname = \"test\"\nversion = \"0.1.0\"",
        )]);
        std::fs::create_dir(dir.path().join("src")).unwrap();

        let info = detect_project(dir.path());
        assert_eq!(info.project_type, ProjectType::Rust);
        assert!(info.build_commands.contains(&s("cargo build")));
        assert!(info.test_commands.contains(&s("cargo test")));
        assert_eq!(info.package_manager.as_deref(), Some("cargo"));
        assert!(info.source_dirs.contains(&s("src")));
    }

    #[test]
    fn test_detect_node_project_npm() {
        let dir = workspace_with(&[("package.json", r#"{"name": "test", "version": "1.0.0"}"#)]);

        let info = detect_project(dir.path());
        assert_eq!(info.project_type, ProjectType::Node);
        assert_eq!(info.package_manager.as_deref(), Some("npm"));
    }

    #[test]
    fn test_detect_node_project_pnpm() {
        let dir = workspace_with(&[
            ("package.json", r#"{"name": "test", "version": "1.0.0"}"#),
            ("pnpm-lock.yaml", "lockfileVersion: 9\n"),
        ]);

        let info = detect_project(dir.path());
        assert_eq!(info.project_type, ProjectType::Node);
        assert_eq!(info.package_manager.as_deref(), Some("pnpm"));
    }

    #[test]
    fn test_detect_python_project() {
        let dir = workspace_with(&[("pyproject.toml", "[project]\nname = \"test\"")]);

        assert_eq!(detect_project(dir.path()).project_type, ProjectType::Python);
    }

    #[test]
    fn test_detect_go_project() {
        let dir = workspace_with(&[("go.mod", "module example.com/test\n\ngo 1.21\n")]);

        let info = detect_project(dir.path());
        assert_eq!(info.project_type, ProjectType::Go);
        assert!(info.build_commands.contains(&s("go build ./...")));
    }

    #[test]
    fn test_detect_mixed_project() {
        let dir = workspace_with(&[
            ("Cargo.toml", "[package]\nname = \"test\""),
            ("package.json", r#"{"name": "test"}"#),
        ]);

        let info = detect_project(dir.path());
        let ProjectType::Mixed(types) = &info.project_type else {
            panic!("Expected Mixed project type");
        };
        assert!(types.contains(&ProjectType::Rust));
        assert!(types.contains(&ProjectType::Node));
    }

    #[test]
    fn test_detect_unknown_project() {
        let dir = workspace_with(&[]);
        assert_eq!(detect_project(dir.path()).project_type, ProjectType::Unknown);
    }

    #[test]
    fn test_detect_git_status() {
        // A fresh temporary directory has no `.git` entry.
        let dir = workspace_with(&[]);
        assert!(!detect_project(dir.path()).has_git);
    }

    #[test]
    fn test_detect_ci() {
        let dir = workspace_with(&[]);
        let workflows = dir.path().join(".github").join("workflows");
        std::fs::create_dir_all(&workflows).unwrap();
        std::fs::write(workflows.join("ci.yml"), "name: CI").unwrap();

        assert!(detect_project(dir.path()).has_ci);
    }

    #[test]
    fn test_recommended_commands_rust() {
        let info = ProjectInfo {
            has_git: true,
            git_clean: true,
            build_commands: vec![s("cargo build")],
            test_commands: vec![s("cargo test")],
            package_manager: Some(s("cargo")),
            source_dirs: vec![s("src")],
            ..blank_info(ProjectType::Rust)
        };
        let cmds = recommended_allowed_commands(&info);
        assert!(cmds.contains(&s("cargo")));
        assert!(cmds.contains(&s("git")));
    }

    #[test]
    fn test_example_tasks_rust() {
        let info = ProjectInfo {
            has_git: true,
            git_clean: true,
            ..blank_info(ProjectType::Rust)
        };
        assert!(!example_tasks(&info).is_empty());
    }

    #[test]
    fn test_detect_rust_framework_axum() {
        let dir = workspace_with(&[("Cargo.toml", "[dependencies]\naxum = \"0.7\"")]);

        assert_eq!(detect_project(dir.path()).framework.as_deref(), Some("Axum"));
    }

    #[test]
    fn test_project_type_display() {
        let cases = [
            (ProjectType::Rust, "Rust"),
            (ProjectType::Node, "Node.js"),
            (ProjectType::Unknown, "Unknown"),
            (
                ProjectType::Mixed(vec![ProjectType::Rust, ProjectType::Node]),
                "Mixed (Rust, Node.js)",
            ),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.to_string(), expected);
        }
    }
}