symposium_rust_analyzer/
rust_analyzer_mcp.rs

1use anyhow::anyhow;
2use lsp_types::{Position, TextDocumentIdentifier, TextDocumentPositionParams, Uri};
3use sacp::{ProxyToConductor, mcp_server::McpServer};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8use std::path::{Path, PathBuf};
9use std::str::FromStr;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
13use crate::failed_obligations::{
14    FailedObligationsState, handle_failed_obligations, handle_failed_obligations_goal,
15};
16use crate::lsp_client::LspClient;
17
18pub type Result<T> = std::result::Result<T, sacp::Error>;
19
20pub struct BridgeState {
21    client: Option<LspClient>,
22    opened_documents: HashSet<String>,
23    document_versions: HashMap<String, i32>,
24}
25
26impl BridgeState {
27    pub fn new() -> Self {
28        Self {
29            client: None,
30            opened_documents: HashSet::new(),
31            document_versions: HashMap::new(),
32        }
33    }
34}
35
36pub type BridgeType = Arc<Mutex<BridgeState>>;
37
38#[derive(Serialize, Deserialize, JsonSchema, Debug)]
39pub struct FilePositionInputs {
40    pub file_path: String,
41    pub line: u32,
42    pub character: u32,
43}
44
45#[derive(Serialize, Deserialize, JsonSchema)]
46struct FileOnlyInputs {
47    pub file_path: String,
48}
49
50/*
51#[derive(Serialize, Deserialize, JsonSchema)]
52struct RangeInputs {
53    pub file_path: String,
54    pub line: u32,
55    pub character: u32,
56    pub end_line: u32,
57    pub end_character: u32,
58}
59*/
60
61#[derive(Serialize, Deserialize, JsonSchema)]
62struct WorkspaceInputs {
63    pub workspace_path: String,
64}
65
66#[derive(Serialize, Deserialize, JsonSchema)]
67pub struct GoalIndexInputs {
68    pub goal_index: Value,
69}
70
71pub const SERVER_ID: &str = "rust-analyzer";
72
73pub(crate) async fn ensure_bridge(bridge: &BridgeType, workspace_path: Option<&str>) -> Result<()> {
74    let mut bridge_guard = bridge.lock().await;
75    if bridge_guard.client.is_none() || workspace_path.is_some() {
76        let workspace = workspace_path
77            .map(PathBuf::from)
78            .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
79
80        tracing::debug!(?workspace);
81
82        let root_uri = Uri::from_str(&format!("file://{}", workspace.display()))
83            .map_err(|e| anyhow!("Invalid workspace path: {}", e))?;
84
85        tracing::debug!(?root_uri);
86
87        let client = LspClient::new("rust-analyzer", &[], root_uri)
88            .await
89            .map_err(|e| anyhow!("Failed to start rust-analyzer: {}", e))?;
90
91        wait_for_start(&client).await;
92
93        bridge_guard.client = Some(client);
94        bridge_guard.opened_documents.clear();
95        bridge_guard.document_versions.clear();
96    }
97    Ok(())
98}
99
100pub(crate) async fn with_bridge<F, R>(
101    bridge: &BridgeType,
102    workspace_path: Option<&str>,
103    f: F,
104) -> Result<R>
105where
106    F: for<'a> AsyncFnOnce(&'a LspClient) -> Result<R>,
107{
108    ensure_bridge(bridge, workspace_path).await?;
109    let bridge_guard = bridge.lock().await;
110    f(bridge_guard.client.as_ref().unwrap()).await
111}
112
113pub async fn with_bridge_and_document<F, R>(
114    bridge: &BridgeType,
115    workspace_path: Option<&str>,
116    file_path: &str,
117    f: F,
118) -> Result<R>
119where
120    F: for<'a> AsyncFnOnce(&'a LspClient, Uri) -> Result<R>,
121{
122    ensure_bridge(bridge, workspace_path).await?;
123    let mut bridge_guard = bridge.lock().await;
124    let uri = ensure_document_open(&mut bridge_guard, file_path).await?;
125    f(bridge_guard.client.as_ref().unwrap(), uri).await
126}
127
128async fn wait_for_start(lsp: &LspClient) {
129    let _ = lsp
130        .subscribe_notification::<(), _>(
131            "experimental/serverStatus".to_string(),
132            |value: serde_json::Value| {
133                Box::pin(async move {
134                    let Some(q) = value.get("quiescent").and_then(|q| q.as_bool()) else {
135                        return Err(anyhow::anyhow!("quiescent not found or invalid"));
136                    };
137                    if q { Ok(Some(())) } else { Ok(None) }
138                })
139            },
140        )
141        .await;
142}
143
144fn file_path_to_uri(file_path: &str) -> anyhow::Result<Uri> {
145    if file_path.starts_with("file://") {
146        Uri::from_str(file_path).map_err(|e| anyhow!("Invalid URI: {}", e))
147    } else {
148        Uri::from_str(&format!("file://{}", file_path))
149            .map_err(|e| anyhow!("Invalid file path: {}", e))
150    }
151}
152
153async fn ensure_document_open(bridge_state: &mut BridgeState, file_path: &str) -> Result<Uri> {
154    let file_path = Path::new(file_path);
155    let file_path =
156        std::fs::canonicalize(file_path).map_err(|e| anyhow!("Invalid file path: {}", e))?;
157    let file_path = file_path
158        .to_str()
159        .ok_or_else(|| anyhow!("Invalid file path"))?;
160    let uri = file_path_to_uri(file_path)?;
161    let uri_str = uri.to_string();
162
163    // Only open if not already opened
164    if !bridge_state.opened_documents.contains(&uri_str) {
165        if let Ok(content) = std::fs::read_to_string(file_path) {
166            if let Some(client) = &bridge_state.client {
167                let version = bridge_state
168                    .document_versions
169                    .get(&uri_str)
170                    .copied()
171                    .unwrap_or(1);
172                client
173                    .did_open(uri.clone(), "rust".to_string(), version, content)
174                    .await
175                    .map_err(|e| anyhow!("Failed to open document: {}", e))?;
176                bridge_state.opened_documents.insert(uri_str.clone());
177                bridge_state.document_versions.insert(uri_str, version);
178            }
179        }
180    }
181
182    Ok(uri)
183}
184
185pub async fn build_server(
186    workspace_path: Option<String>,
187) -> Result<McpServer<ProxyToConductor, impl sacp::JrResponder<ProxyToConductor>>> {
188    let bridge: BridgeType = Arc::new(Mutex::new(BridgeState::new()));
189    with_bridge(&bridge, workspace_path.as_deref(), async |_client| Ok(())).await?;
190
191    let failed_obligations_state = Arc::new(Mutex::new(FailedObligationsState::new()));
192    let server = McpServer::builder("rust-analyzer-mcp".to_string())
193        .instructions(indoc::indoc! {"
194            Rust analyzer LSP integration for code analysis, navigation, and diagnostics.
195        "})
196        .tool_fn_mut(
197            "rust_analyzer_hover",
198            "Get hover information for a symbol at a specific position in a Rust file",
199            {
200                let bridge = bridge.clone();
201                async move |input: FilePositionInputs, _mcp_cx| {
202                    with_bridge_and_document(
203                        &bridge,
204                        None,
205                        &input.file_path,
206                        async move |client, uri| {
207                            let position = Position::new(input.line, input.character);
208                            let result = client
209                                .hover(uri, position)
210                                .await
211                                .map_err(|e| anyhow!("Hover request failed: {}", e))?;
212                            Ok(serde_json::to_string(&result)?)
213                        },
214                    )
215                    .await
216                }
217            },
218            sacp::tool_fn_mut!(),
219        )
220        .tool_fn_mut(
221            "rust_analyzer_definition",
222            "Go to definition of a symbol at a specific position",
223            {
224                let bridge = bridge.clone();
225                async move |input: FilePositionInputs, _mcp_cx| {
226                    with_bridge_and_document(
227                        &bridge,
228                        None,
229                        &input.file_path,
230                        async move |client, uri| {
231                            let position = Position::new(input.line, input.character);
232                            let result = client
233                                .goto_definition(uri, position)
234                                .await
235                                .map_err(|e| anyhow!("Definition request failed: {}", e))?;
236                            Ok(serde_json::to_string(&result)?)
237                        },
238                    )
239                    .await
240                }
241            },
242            sacp::tool_fn_mut!(),
243        )
244        .tool_fn_mut(
245            "rust_analyzer_references",
246            "Find all references to a symbol at a specific position",
247            {
248                let bridge = bridge.clone();
249                async move |input: FilePositionInputs, _mcp_cx| {
250                    with_bridge_and_document(
251                        &bridge,
252                        None,
253                        &input.file_path,
254                        async move |client, uri| {
255                            let position = Position::new(input.line, input.character);
256                            let result = client
257                                .find_references(uri, position, true)
258                                .await
259                                .map_err(|e| anyhow!("References request failed: {}", e))?;
260                            Ok(serde_json::to_string(&result)?)
261                        },
262                    )
263                    .await
264                }
265            },
266            sacp::tool_fn_mut!(),
267        )
268        .tool_fn_mut(
269            "rust_analyzer_completion",
270            "Get code completions at a specific position",
271            {
272                let bridge = bridge.clone();
273                async move |input: FilePositionInputs, _mcp_cx| {
274                    with_bridge_and_document(
275                        &bridge,
276                        None,
277                        &input.file_path,
278                        async move |client, uri| {
279                            let position = Position::new(input.line, input.character);
280                            let result = client
281                                .completion(uri, position)
282                                .await
283                                .map_err(|e| anyhow!("Completion request failed: {}", e))?;
284                            Ok(serde_json::to_string(&result)?)
285                        },
286                    )
287                    .await
288                }
289            },
290            sacp::tool_fn_mut!(),
291        )
292        .tool_fn_mut(
293            "rust_analyzer_symbols",
294            "Get document symbols for a Rust file",
295            {
296                let bridge = bridge.clone();
297                async move |input: FileOnlyInputs, _mcp_cx| {
298                    with_bridge_and_document(
299                        &bridge,
300                        None,
301                        &input.file_path,
302                        async move |client, uri| {
303                            let result = client
304                                .document_symbols(uri)
305                                .await
306                                .map_err(|e| anyhow!("Document symbols request failed: {}", e))?;
307                            Ok(serde_json::to_string(&result)?)
308                        },
309                    )
310                    .await
311                }
312            },
313            sacp::tool_fn_mut!(),
314        )
315        /*
316        .tool_fn_mut(
317            "rust_analyzer_format",
318            "Format a Rust document",
319            {
320                let bridge = bridge.clone();
321                async move |input: FileOnlyInputs, _mcp_cx| {
322                    with_bridge_and_document(
323                        &bridge,
324                        None,
325                        &input.file_path,
326                        async move |client, uri| {
327                            let result = client
328                                .format_document(uri)
329                                .await
330                                .map_err(|e| anyhow!("Format request failed: {}", e))?;
331                            Ok(serde_json::to_string(&result)?)
332                        },
333                    )
334                    .await
335                }
336            },
337            sacp::tool_fn_mut!(),
338        )
339        .tool_fn_mut(
340            "rust_analyzer_code_actions",
341            "Get available code actions for a range in a Rust file",
342            {
343                let bridge = bridge.clone();
344                async move |input: RangeInputs, _mcp_cx| {
345                    with_bridge_and_document(
346                        &bridge,
347                        None,
348                        &input.file_path,
349                        async move |client, uri| {
350                            let range = Range::new(
351                                Position::new(input.line, input.character),
352                                Position::new(input.end_line, input.end_character),
353                            );
354                            let context = CodeActionContext {
355                                diagnostics: vec![],
356                                only: None,
357                                trigger_kind: None,
358                            };
359                            let result = client
360                                .code_actions(uri, range, context)
361                                .await
362                                .map_err(|e| anyhow!("Code actions request failed: {}", e))?;
363                            Ok(serde_json::to_string(&result)?)
364                        },
365                    )
366                    .await
367                }
368            },
369            sacp::tool_fn_mut!(),
370        )
371        */
372        .tool_fn_mut(
373            "rust_analyzer_set_workspace",
374            "Set the workspace root for rust-analyzer",
375            {
376                let bridge = bridge.clone();
377                async move |input: WorkspaceInputs, _mcp_cx| {
378                    with_bridge(&bridge, Some(&input.workspace_path), async move |_client| {
379                        Ok("Workspace set successfully".to_string())
380                    })
381                    .await
382                }
383            },
384            sacp::tool_fn_mut!(),
385        )
386        /*
387        .tool_fn_mut(
388            "rust_analyzer_diagnostics",
389            "Get diagnostics for a Rust file",
390            {
391                let bridge = bridge.clone();
392                async move |input: FileOnlyInputs, _mcp_cx| {
393                    with_bridge_and_document(
394                        &bridge,
395                        None,
396                        &input.file_path,
397                        async move |client, uri| {
398                            let result = client
399                                .diagnostics(uri)
400                                .await
401                                .map_err(|e| anyhow!("Diagnostics request failed: {}", e))?;
402                            Ok(serde_json::to_string(&result)?)
403                        },
404                    )
405                    .await
406                }
407            },
408            sacp::tool_fn_mut!(),
409        )
410        */
411        .tool_fn_mut(
412            "rust_analyzer_failed_obligations",
413            "Get failed trait obligations for debugging (rust-analyzer specific)",
414            {
415                let bridge = bridge.clone();
416                let state = failed_obligations_state.clone();
417                async move |input: FilePositionInputs, _mcp_cx| {
418                    let mut bridge_guard = bridge.lock().await;
419                    let uri = ensure_document_open(&mut bridge_guard, &input.file_path).await?;
420                    let doc = TextDocumentIdentifier { uri };
421                    let position = Position::new(input.line, input.character);
422
423                    let args = TextDocumentPositionParams {
424                        text_document: doc,
425                        position,
426                    };
427
428                    let mut state = state.lock().await;
429                    use std::ops::DerefMut;
430                    let state = state.deref_mut();
431                    let result = handle_failed_obligations(
432                        bridge_guard.client.as_ref().unwrap(),
433                        state,
434                        args,
435                    )
436                    .await?;
437
438                    Ok(serde_json::to_string(&result).map_err(|e| anyhow::Error::new(e))?)
439                }
440            },
441            sacp::tool_fn_mut!(),
442        )
443        .tool_fn_mut(
444            "rust_analyzer_failed_obligations_goal",
445            "Explore nested goals in failed trait obligations (rust-analyzer specific)",
446            {
447                let bridge = bridge.clone();
448                let state = failed_obligations_state.clone();
449                async move |input: GoalIndexInputs, _mcp_cx| {
450                    let bridge_guard = bridge.lock().await;
451                    let mut state = state.lock().await;
452                    use std::ops::DerefMut;
453                    let state = state.deref_mut();
454                    let result = handle_failed_obligations_goal(
455                        bridge_guard.client.as_ref().unwrap(),
456                        state,
457                        input,
458                    )
459                    .await?;
460
461                    Ok(serde_json::to_string(&result).map_err(|e| anyhow::Error::new(e))?)
462                }
463            },
464            sacp::tool_fn_mut!(),
465        )
466        .build();
467
468    Ok(server)
469}