syncable_cli/agent/tools/
security.rs

1//! Security and vulnerability scanning tools using Rig's Tool trait
2
3use rig::completion::ToolDefinition;
4use rig::tool::Tool;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use std::path::PathBuf;
8
9use super::compression::{CompressionConfig, compress_tool_output};
10use crate::analyzer::security::turbo::{ScanMode, TurboConfig, TurboSecurityAnalyzer};
11
12// ============================================================================
13// Security Scan Tool
14// ============================================================================
15
16#[derive(Debug, Deserialize)]
17pub struct SecurityScanArgs {
18    pub mode: Option<String>,
19    pub path: Option<String>,
20}
21
22#[derive(Debug, thiserror::Error)]
23#[error("Security scan error: {0}")]
24pub struct SecurityScanError(String);
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct SecurityScanTool {
28    project_path: PathBuf,
29}
30
31impl SecurityScanTool {
32    pub fn new(project_path: PathBuf) -> Self {
33        Self { project_path }
34    }
35}
36
37impl Tool for SecurityScanTool {
38    const NAME: &'static str = "security_scan";
39
40    type Error = SecurityScanError;
41    type Args = SecurityScanArgs;
42    type Output = String;
43
44    async fn definition(&self, _prompt: String) -> ToolDefinition {
45        ToolDefinition {
46            name: Self::NAME.to_string(),
47            description: "Perform a security scan to detect potential secrets, API keys, passwords, and sensitive data that might be accidentally committed.".to_string(),
48            parameters: json!({
49                "type": "object",
50                "properties": {
51                    "mode": {
52                        "type": "string",
53                        "enum": ["lightning", "fast", "balanced", "thorough", "paranoid"],
54                        "description": "Scan mode: lightning (fast), balanced (recommended), thorough, or paranoid"
55                    },
56                    "path": {
57                        "type": "string",
58                        "description": "Optional subdirectory path to scan"
59                    }
60                }
61            }),
62        }
63    }
64
65    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
66        let path = match args.path {
67            Some(subpath) => self.project_path.join(subpath),
68            None => self.project_path.clone(),
69        };
70
71        let scan_mode = match args.mode.as_deref() {
72            Some("lightning") => ScanMode::Lightning,
73            Some("fast") => ScanMode::Fast,
74            Some("thorough") => ScanMode::Thorough,
75            Some("paranoid") => ScanMode::Paranoid,
76            _ => ScanMode::Balanced,
77        };
78
79        let config = TurboConfig {
80            scan_mode,
81            ..TurboConfig::default()
82        };
83
84        let scanner = TurboSecurityAnalyzer::new(config)
85            .map_err(|e| SecurityScanError(format!("Failed to create scanner: {}", e)))?;
86
87        let report = scanner
88            .analyze_project(&path)
89            .map_err(|e| SecurityScanError(format!("Scan failed: {}", e)))?;
90
91        // Build full result with all findings (compression will handle size)
92        let result = json!({
93            "total_findings": report.total_findings,
94            "overall_score": report.overall_score,
95            "risk_level": format!("{:?}", report.risk_level),
96            "files_scanned": report.files_scanned,
97            "findings": report.findings.iter().map(|f| {
98                json!({
99                    "title": f.title,
100                    "description": f.description,
101                    "severity": format!("{:?}", f.severity),
102                    "category": format!("{:?}", f.category),
103                    "file": f.file_path.as_ref().map(|p| p.display().to_string()),
104                    "line": f.line_number,
105                    "evidence": f.evidence.as_ref().map(|e| e.chars().take(100).collect::<String>()),
106                })
107            }).collect::<Vec<_>>(),
108            "recommendations": report.recommendations.clone(),
109            "scan_mode": args.mode.as_deref().unwrap_or("balanced"),
110        });
111
112        // Use compression - stores full data for RAG retrieval if output is large
113        let config = CompressionConfig::default();
114        Ok(compress_tool_output(&result, "security_scan", &config))
115    }
116}
117
118// ============================================================================
119// Vulnerabilities Tool
120// ============================================================================
121
122#[derive(Debug, Deserialize)]
123pub struct VulnerabilitiesArgs {
124    pub path: Option<String>,
125}
126
127#[derive(Debug, thiserror::Error)]
128#[error("Vulnerability check error: {0}")]
129pub struct VulnerabilitiesError(String);
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct VulnerabilitiesTool {
133    project_path: PathBuf,
134}
135
136impl VulnerabilitiesTool {
137    pub fn new(project_path: PathBuf) -> Self {
138        Self { project_path }
139    }
140}
141
142impl Tool for VulnerabilitiesTool {
143    const NAME: &'static str = "check_vulnerabilities";
144
145    type Error = VulnerabilitiesError;
146    type Args = VulnerabilitiesArgs;
147    type Output = String;
148
149    async fn definition(&self, _prompt: String) -> ToolDefinition {
150        ToolDefinition {
151            name: Self::NAME.to_string(),
152            description:
153                "Check the project's dependencies for known security vulnerabilities (CVEs)."
154                    .to_string(),
155            parameters: json!({
156                "type": "object",
157                "properties": {
158                    "path": {
159                        "type": "string",
160                        "description": "Optional subdirectory path to check"
161                    }
162                }
163            }),
164        }
165    }
166
167    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
168        let path = match args.path {
169            Some(subpath) => self.project_path.join(subpath),
170            None => self.project_path.clone(),
171        };
172
173        let parser = crate::analyzer::dependency_parser::DependencyParser::new();
174        let dependencies = parser
175            .parse_all_dependencies(&path)
176            .map_err(|e| VulnerabilitiesError(format!("Failed to parse dependencies: {}", e)))?;
177
178        if dependencies.is_empty() {
179            return Ok(json!({
180                "message": "No dependencies found in project",
181                "total_vulnerabilities": 0
182            })
183            .to_string());
184        }
185
186        let checker = crate::analyzer::vulnerability::VulnerabilityChecker::new();
187        let report = checker
188            .check_all_dependencies(&dependencies, &path)
189            .await
190            .map_err(|e| VulnerabilitiesError(format!("Vulnerability check failed: {}", e)))?;
191
192        // Build findings array for compression (each vuln as a separate issue)
193        let mut findings = Vec::new();
194        for dep in &report.vulnerable_dependencies {
195            for v in &dep.vulnerabilities {
196                findings.push(json!({
197                    "code": v.id.clone(),
198                    "severity": format!("{:?}", v.severity),
199                    "title": v.title.clone(),
200                    "message": format!("{} {} has vulnerability: {}", dep.name, dep.version, v.title),
201                    "dependency": dep.name.clone(),
202                    "version": dep.version.clone(),
203                    "language": dep.language.as_str(),
204                    "cve": v.cve.clone(),
205                    "patched_versions": v.patched_versions.clone(),
206                }));
207            }
208        }
209
210        let result = json!({
211            "total_vulnerabilities": report.total_vulnerabilities,
212            "critical_count": report.critical_count,
213            "high_count": report.high_count,
214            "medium_count": report.medium_count,
215            "low_count": report.low_count,
216            "issues": findings,  // Use "issues" so compression can find it
217        });
218
219        // Use compression - stores full data for RAG retrieval if output is large
220        let config = CompressionConfig::default();
221        Ok(compress_tool_output(
222            &result,
223            "check_vulnerabilities",
224            &config,
225        ))
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Parses tool output as JSON, failing with the raw text if it is not valid JSON.
    fn parse_json(output: &str) -> serde_json::Value {
        serde_json::from_str(output)
            .unwrap_or_else(|e| panic!("tool output is not valid JSON ({}): {}", e, output))
    }

    #[tokio::test]
    async fn test_security_scan_empty_project() {
        let dir = tempdir().unwrap();
        // Minimal project: one source file containing no secrets.
        std::fs::write(dir.path().join("main.rs"), "fn main() {}").unwrap();

        let tool = SecurityScanTool::new(dir.path().to_path_buf());
        let args = SecurityScanArgs {
            mode: None,
            path: None,
        };

        // The scan itself must succeed (`unwrap` fails the test on `Err`) and
        // produce a JSON object.
        let result = tool.call(args).await.unwrap();
        assert!(parse_json(&result).is_object());
    }

    #[tokio::test]
    async fn test_security_scan_with_path() {
        let dir = tempdir().unwrap();
        let subdir = dir.path().join("src");
        std::fs::create_dir(&subdir).unwrap();
        std::fs::write(subdir.join("lib.rs"), "pub fn foo() {}").unwrap();

        let tool = SecurityScanTool::new(dir.path().to_path_buf());
        let args = SecurityScanArgs {
            mode: None,
            path: Some("src".to_string()),
        };

        // Scanning a relative sub-directory must succeed and produce a JSON object.
        let result = tool.call(args).await.unwrap();
        assert!(parse_json(&result).is_object());
    }

    #[tokio::test]
    async fn test_security_scan_definition_schema() {
        let tool = SecurityScanTool::new(PathBuf::from("."));
        let def = tool.definition(String::new()).await;

        assert_eq!(def.name, SecurityScanTool::NAME);
        let modes = def.parameters["properties"]["mode"]["enum"]
            .as_array()
            .expect("mode enum should be a JSON array");
        assert!(modes.iter().any(|m| m.as_str() == Some("balanced")));
    }

    #[tokio::test]
    async fn test_vulnerabilities_definition_schema() {
        let tool = VulnerabilitiesTool::new(PathBuf::from("."));
        let def = tool.definition(String::new()).await;

        assert_eq!(def.name, VulnerabilitiesTool::NAME);
        assert_eq!(def.parameters["properties"]["path"]["type"], "string");
    }
}