Skip to main content

rust_doctor/mcp/
tools.rs

1use crate::diagnostics::ScanResult;
2use crate::discovery::ProjectInfo;
3use crate::{config, scan};
4use rmcp::handler::server::wrapper::{Json, Parameters};
5use rmcp::model::{
6    CallToolResult, Content, GetPromptResult, LoggingLevel, LoggingMessageNotificationParam,
7    PromptMessage, PromptMessageRole,
8};
9use rmcp::{ErrorData as McpError, RoleServer, prompt, prompt_router, tool, tool_router};
10use std::sync::Arc;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::time::Duration;
13
14use super::RustDoctorServer;
15use super::helpers::{discover_and_resolve, format_scan_report, group_diagnostics};
16use super::rules::{get_all_rules_listing, get_rule_explanation};
17use super::types::{
18    DeepAuditArgs, ExplainRuleInput, HealthCheckArgs, ScanInput, ScoreInput, ScoreOutput,
19};
20
21/// MCP timeout for a single scan/score call. On expiry the work is cancelled
22/// cooperatively, not detached.
23const MCP_SCAN_TIMEOUT_SECS: u64 = 300;
24
25/// Run a scan on a blocking thread under a 5-minute absolute timeout.
26///
27/// On timeout the shared cancel flag is set so the (now-detached) blocking scan
28/// stops launching new passes instead of running to completion in the background
29/// and exhausting the blocking pool (US-007). The client-facing timeout message is
30/// unchanged.
31async fn run_scan_with_timeout(
32    project_info: ProjectInfo,
33    resolved: config::ResolvedConfig,
34    offline: bool,
35    tool: &str,
36) -> Result<ScanResult, McpError> {
37    let cancel = Arc::new(AtomicBool::new(false));
38    let cancel_task = Arc::clone(&cancel);
39    let scan_future = tokio::task::spawn_blocking(move || {
40        scan::scan_project_cancellable(&project_info, &resolved, offline, &[], true, &cancel_task)
41    });
42
43    match tokio::time::timeout(Duration::from_secs(MCP_SCAN_TIMEOUT_SECS), scan_future).await {
44        Ok(join_result) => join_result
45            .map_err(|e| McpError::internal_error(format!("scan task failed: {e}"), None))?
46            .map_err(|e| {
47                eprintln!("MCP {tool} error: {e}");
48                McpError::internal_error(
49                    "scan failed — check project compiles with `cargo check`",
50                    None,
51                )
52            }),
53        Err(_elapsed) => {
54            // Signal the detached blocking task to stop; do not leave it running.
55            cancel.store(true, Ordering::Relaxed);
56            Err(McpError::internal_error(
57                "scan timed out after 5 minutes — project may be too large or a subprocess is hanging",
58                None,
59            ))
60        }
61    }
62}
63
64// ---------------------------------------------------------------------------
65// Tool and prompt implementations
66// ---------------------------------------------------------------------------
67
68#[tool_router(vis = "pub(super)")]
69#[prompt_router(vis = "pub(super)")]
70impl RustDoctorServer {
71    pub(super) fn new() -> Self {
72        Self {
73            tool_router: Self::tool_router(),
74            prompt_router: Self::prompt_router(),
75        }
76    }
77
78    #[tool(
79        name = "scan",
80        description = "Run a full Rust code health analysis on a project directory. \
81Use this tool when you need detailed diagnostics — it returns all findings with file:line precision. \
82Takes 5-30 seconds depending on project size. \
83Returns JSON with: diagnostics array (each has rule, severity, message, file_path, line, column, help), \
84score (0-100), score_label, source_file_count, elapsed_secs, error_count, warning_count, info_count, skipped_passes. \
85Severity levels: error (bugs/security), warning (code smells), info (suggestions). \
86Runs 4 passes in parallel: clippy (55+ lints), 19 custom AST rules, cargo-audit (CVEs), cargo-machete (unused deps). \
87Set 'diff' to a branch name to only scan changed files. \
88After scanning, use explain_rule on any rule ID to get fix guidance.",
89        annotations(
90            title = "Scan Project",
91            read_only_hint = true,
92            destructive_hint = false,
93            idempotent_hint = true,
94            open_world_hint = false,
95        )
96    )]
97    async fn scan(
98        &self,
99        meta: rmcp::model::Meta,
100        client: rmcp::Peer<RoleServer>,
101        params: Parameters<ScanInput>,
102    ) -> Result<CallToolResult, McpError> {
103        let input = params.0;
104        let progress_token = meta.get_progress_token();
105
106        // Send start progress if client supports it
107        if let Some(ref token) = progress_token {
108            let _ = client
109                .notify_progress(rmcp::model::ProgressNotificationParam {
110                    progress_token: token.clone(),
111                    progress: 0.0,
112                    total: Some(2.0),
113                    message: Some("Bootstrapping project...".to_string()),
114                })
115                .await;
116        }
117        let _ = client
118            .notify_logging_message(LoggingMessageNotificationParam {
119                level: LoggingLevel::Info,
120                logger: Some("rust-doctor".into()),
121                data: serde_json::json!("Bootstrapping project..."),
122            })
123            .await;
124
125        let (_dir, project_info, mut resolved) =
126            discover_and_resolve(&input.directory, input.ignore_project_config)?;
127
128        if let Some(diff_base) = input.diff {
129            resolved.diff = Some(diff_base);
130        }
131
132        // Send scanning progress
133        if let Some(ref token) = progress_token {
134            let _ = client
135                .notify_progress(rmcp::model::ProgressNotificationParam {
136                    progress_token: token.clone(),
137                    progress: 1.0,
138                    total: Some(2.0),
139                    message: Some(
140                        "Running analysis passes (clippy, rules, audit, machete)...".to_string(),
141                    ),
142                })
143                .await;
144        }
145        let _ = client
146            .notify_logging_message(LoggingMessageNotificationParam {
147                level: LoggingLevel::Info,
148                logger: Some("rust-doctor".into()),
149                data: serde_json::json!(
150                    "Running 4 analysis passes (clippy, AST rules, cargo-audit, cargo-machete)..."
151                ),
152            })
153            .await;
154
155        // Run the CPU-bound scan on a blocking thread with a 5-minute absolute timeout
156        let offline = input.offline;
157        let result = run_scan_with_timeout(project_info, resolved, offline, "scan").await?;
158
159        // Send completion progress
160        if let Some(ref token) = progress_token {
161            let _ = client
162                .notify_progress(rmcp::model::ProgressNotificationParam {
163                    progress_token: token.clone(),
164                    progress: 2.0,
165                    total: Some(2.0),
166                    message: Some(format!(
167                        "Scan complete: score {}/100, {} findings",
168                        result.score,
169                        result.diagnostics.len()
170                    )),
171                })
172                .await;
173        }
174        let _ = client
175            .notify_logging_message(LoggingMessageNotificationParam {
176                level: LoggingLevel::Info,
177                logger: Some("rust-doctor".into()),
178                data: serde_json::Value::String(format!(
179                    "Scan complete: {}/100 ({}) — {} errors, {} warnings, {} info in {:.1}s",
180                    result.score,
181                    result.score_label,
182                    result.error_count,
183                    result.warning_count,
184                    result.info_count,
185                    result.elapsed.as_secs_f64()
186                )),
187            })
188            .await;
189
190        let grouped = group_diagnostics(&result.diagnostics);
191        let report = format_scan_report(&result, &grouped);
192
193        Ok(CallToolResult::success(vec![Content::text(report)]))
194    }
195
196    #[tool(
197        name = "score",
198        description = "Get just the health score of a Rust project (0-100 integer). \
199Use this tool for a quick pass/fail check without full diagnostics. \
200IMPORTANT: runs the same full analysis as scan internally, so takes the same 5-30 seconds. \
201Score thresholds: >=75 'Great', >=50 'Needs work', <50 'Critical'. \
202Scoring: each unique error-severity rule violated costs 1.5 points, each warning costs 0.75 points. \
203If you also need the diagnostics, use scan instead — it includes the score too.",
204        annotations(
205            title = "Score Project",
206            read_only_hint = true,
207            destructive_hint = false,
208            idempotent_hint = true,
209            open_world_hint = false,
210        )
211    )]
212    async fn score(
213        &self,
214        meta: rmcp::model::Meta,
215        client: rmcp::Peer<RoleServer>,
216        params: Parameters<ScoreInput>,
217    ) -> Result<Json<ScoreOutput>, McpError> {
218        let input = params.0;
219        let progress_token = meta.get_progress_token();
220
221        if let Some(ref token) = progress_token {
222            let _ = client
223                .notify_progress(rmcp::model::ProgressNotificationParam {
224                    progress_token: token.clone(),
225                    progress: 0.0,
226                    total: Some(1.0),
227                    message: Some("Scoring project...".to_string()),
228                })
229                .await;
230        }
231        let _ = client
232            .notify_logging_message(LoggingMessageNotificationParam {
233                level: LoggingLevel::Info,
234                logger: Some("rust-doctor".into()),
235                data: serde_json::json!("Scoring project..."),
236            })
237            .await;
238
239        let (_dir, project_info, resolved) =
240            discover_and_resolve(&input.directory, input.ignore_project_config)?;
241
242        // Run the CPU-bound scan on a blocking thread with a 5-minute absolute timeout
243        let offline = input.offline;
244        let result = run_scan_with_timeout(project_info, resolved, offline, "score").await?;
245
246        if let Some(ref token) = progress_token {
247            let _ = client
248                .notify_progress(rmcp::model::ProgressNotificationParam {
249                    progress_token: token.clone(),
250                    progress: 1.0,
251                    total: Some(1.0),
252                    message: Some(format!(
253                        "Score: {}/100 ({})",
254                        result.score, result.score_label
255                    )),
256                })
257                .await;
258        }
259        let _ = client
260            .notify_logging_message(LoggingMessageNotificationParam {
261                level: LoggingLevel::Info,
262                logger: Some("rust-doctor".into()),
263                data: serde_json::Value::String(format!(
264                    "Score: {}/100 ({})",
265                    result.score, result.score_label
266                )),
267            })
268            .await;
269
270        Ok(Json(ScoreOutput {
271            score: result.score,
272            score_label: result.score_label,
273        }))
274    }
275
276    #[tool(
277        name = "explain_rule",
278        description = "Get a detailed markdown explanation of a specific rust-doctor rule. \
279Use this after scan to understand what a rule detects and how to fix violations. \
280Returns: rule name, category, severity, description, and fix guidance. \
281Accepts custom rule IDs (e.g. 'unwrap-in-production') and clippy lint names (e.g. 'clippy::expect_used'). \
282Instant response — no project scanning required. \
283For unknown rules, returns guidance to use list_rules.",
284        annotations(
285            title = "Explain Rule",
286            read_only_hint = true,
287            destructive_hint = false,
288            idempotent_hint = true,
289            open_world_hint = false,
290        )
291    )]
292    async fn explain_rule(
293        &self,
294        params: Parameters<ExplainRuleInput>,
295    ) -> Result<CallToolResult, McpError> {
296        let explanation = get_rule_explanation(&params.0.rule);
297        Ok(CallToolResult::success(vec![Content::text(explanation)]))
298    }
299
300    #[tool(
301        name = "list_rules",
302        description = "List all available rust-doctor rules as formatted markdown. \
303Use this to discover which checks exist before scanning, or to find a rule ID for explain_rule. \
304Instant response — no project scanning required. \
305Returns: 19 custom AST rules (grouped by Error Handling, Performance, Architecture, Security, Async, Framework), \
30655+ clippy lints with custom severity overrides, and 2 external tools (cargo-audit, cargo-machete). \
307Each entry shows rule ID, severity, and one-line summary.",
308        annotations(
309            title = "List Rules",
310            read_only_hint = true,
311            destructive_hint = false,
312            idempotent_hint = true,
313            open_world_hint = false,
314        )
315    )]
316    async fn list_rules(&self) -> Result<CallToolResult, McpError> {
317        let listing = get_all_rules_listing();
318        Ok(CallToolResult::success(vec![Content::text(listing)]))
319    }
320
321    // -- Prompts --------------------------------------------------------------
322
323    #[prompt(
324        name = "deep-audit",
325        description = "Comprehensive Rust code audit: explores codebase architecture, runs rust-doctor \
326analysis, performs deep code review against production best practices, researches current Rust patterns \
327on the web, cross-references findings, and generates a full remediation report. Ends with a choice: \
328implement all fixes, generate a PRD, or manual prompt. Use this for thorough, expert-level code audits \
329that go far beyond linting."
330    )]
331    pub(super) async fn deep_audit(&self, params: Parameters<DeepAuditArgs>) -> GetPromptResult {
332        GetPromptResult::new(vec![PromptMessage::new_text(
333            PromptMessageRole::User,
334            super::prompts::deep_audit_prompt(&params.0.directory),
335        )])
336        .with_description(
337            "Expert-level Rust audit: codebase exploration + static analysis + deep code review \
338             + best practices research + synthesis report + actionable remediation choices",
339        )
340    }
341
342    #[prompt(
343        name = "health-check",
344        description = "Run a full health check on a Rust project: scan, generate a prioritized \
345remediation plan, and optionally apply fixes. Combines scan + plan + fix into one structured workflow."
346    )]
347    pub(super) async fn health_check(
348        &self,
349        params: Parameters<HealthCheckArgs>,
350    ) -> GetPromptResult {
351        GetPromptResult::new(vec![PromptMessage::new_text(
352            PromptMessageRole::User,
353            super::prompts::health_check_prompt(&params.0.directory),
354        )])
355        .with_description(
356            "Full health audit with prioritized remediation plan and structured fix workflow",
357        )
358    }
359}