1use anyhow::anyhow;
2use lsp_types::{Position, TextDocumentIdentifier, TextDocumentPositionParams, Uri};
3use sacp::{ProxyToConductor, mcp_server::McpServer};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::{HashMap, HashSet};
8use std::path::{Path, PathBuf};
9use std::str::FromStr;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
13use crate::failed_obligations::{
14 FailedObligationsState, handle_failed_obligations, handle_failed_obligations_goal,
15};
16use crate::lsp_client::LspClient;
17
18pub type Result<T> = std::result::Result<T, sacp::Error>;
19
20pub struct BridgeState {
21 client: Option<LspClient>,
22 opened_documents: HashSet<String>,
23 document_versions: HashMap<String, i32>,
24}
25
26impl BridgeState {
27 pub fn new() -> Self {
28 Self {
29 client: None,
30 opened_documents: HashSet::new(),
31 document_versions: HashMap::new(),
32 }
33 }
34}
35
36pub type BridgeType = Arc<Mutex<BridgeState>>;
37
38#[derive(Serialize, Deserialize, JsonSchema, Debug)]
39pub struct FilePositionInputs {
40 pub file_path: String,
41 pub line: u32,
42 pub character: u32,
43}
44
45#[derive(Serialize, Deserialize, JsonSchema)]
46struct FileOnlyInputs {
47 pub file_path: String,
48}
49
50#[derive(Serialize, Deserialize, JsonSchema)]
62struct WorkspaceInputs {
63 pub workspace_path: String,
64}
65
66#[derive(Serialize, Deserialize, JsonSchema)]
67pub struct GoalIndexInputs {
68 pub goal_index: Value,
69}
70
71pub const SERVER_ID: &str = "rust-analyzer";
72
73pub(crate) async fn ensure_bridge(bridge: &BridgeType, workspace_path: Option<&str>) -> Result<()> {
74 let mut bridge_guard = bridge.lock().await;
75 if bridge_guard.client.is_none() || workspace_path.is_some() {
76 let workspace = workspace_path
77 .map(PathBuf::from)
78 .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
79
80 tracing::debug!(?workspace);
81
82 let root_uri = Uri::from_str(&format!("file://{}", workspace.display()))
83 .map_err(|e| anyhow!("Invalid workspace path: {}", e))?;
84
85 tracing::debug!(?root_uri);
86
87 let client = LspClient::new("rust-analyzer", &[], root_uri)
88 .await
89 .map_err(|e| anyhow!("Failed to start rust-analyzer: {}", e))?;
90
91 wait_for_start(&client).await;
92
93 bridge_guard.client = Some(client);
94 bridge_guard.opened_documents.clear();
95 bridge_guard.document_versions.clear();
96 }
97 Ok(())
98}
99
100pub(crate) async fn with_bridge<F, R>(
101 bridge: &BridgeType,
102 workspace_path: Option<&str>,
103 f: F,
104) -> Result<R>
105where
106 F: for<'a> AsyncFnOnce(&'a LspClient) -> Result<R>,
107{
108 ensure_bridge(bridge, workspace_path).await?;
109 let bridge_guard = bridge.lock().await;
110 f(bridge_guard.client.as_ref().unwrap()).await
111}
112
113pub async fn with_bridge_and_document<F, R>(
114 bridge: &BridgeType,
115 workspace_path: Option<&str>,
116 file_path: &str,
117 f: F,
118) -> Result<R>
119where
120 F: for<'a> AsyncFnOnce(&'a LspClient, Uri) -> Result<R>,
121{
122 ensure_bridge(bridge, workspace_path).await?;
123 let mut bridge_guard = bridge.lock().await;
124 let uri = ensure_document_open(&mut bridge_guard, file_path).await?;
125 f(bridge_guard.client.as_ref().unwrap(), uri).await
126}
127
128async fn wait_for_start(lsp: &LspClient) {
129 let _ = lsp
130 .subscribe_notification::<(), _>(
131 "experimental/serverStatus".to_string(),
132 |value: serde_json::Value| {
133 Box::pin(async move {
134 let Some(q) = value.get("quiescent").and_then(|q| q.as_bool()) else {
135 return Err(anyhow::anyhow!("quiescent not found or invalid"));
136 };
137 if q { Ok(Some(())) } else { Ok(None) }
138 })
139 },
140 )
141 .await;
142}
143
144fn file_path_to_uri(file_path: &str) -> anyhow::Result<Uri> {
145 if file_path.starts_with("file://") {
146 Uri::from_str(file_path).map_err(|e| anyhow!("Invalid URI: {}", e))
147 } else {
148 Uri::from_str(&format!("file://{}", file_path))
149 .map_err(|e| anyhow!("Invalid file path: {}", e))
150 }
151}
152
153async fn ensure_document_open(bridge_state: &mut BridgeState, file_path: &str) -> Result<Uri> {
154 let file_path = Path::new(file_path);
155 let file_path =
156 std::fs::canonicalize(file_path).map_err(|e| anyhow!("Invalid file path: {}", e))?;
157 let file_path = file_path
158 .to_str()
159 .ok_or_else(|| anyhow!("Invalid file path"))?;
160 let uri = file_path_to_uri(file_path)?;
161 let uri_str = uri.to_string();
162
163 if !bridge_state.opened_documents.contains(&uri_str) {
165 if let Ok(content) = std::fs::read_to_string(file_path) {
166 if let Some(client) = &bridge_state.client {
167 let version = bridge_state
168 .document_versions
169 .get(&uri_str)
170 .copied()
171 .unwrap_or(1);
172 client
173 .did_open(uri.clone(), "rust".to_string(), version, content)
174 .await
175 .map_err(|e| anyhow!("Failed to open document: {}", e))?;
176 bridge_state.opened_documents.insert(uri_str.clone());
177 bridge_state.document_versions.insert(uri_str, version);
178 }
179 }
180 }
181
182 Ok(uri)
183}
184
185pub async fn build_server(
186 workspace_path: Option<String>,
187) -> Result<McpServer<ProxyToConductor, impl sacp::JrResponder<ProxyToConductor>>> {
188 let bridge: BridgeType = Arc::new(Mutex::new(BridgeState::new()));
189 with_bridge(&bridge, workspace_path.as_deref(), async |_client| Ok(())).await?;
190
191 let failed_obligations_state = Arc::new(Mutex::new(FailedObligationsState::new()));
192 let server = McpServer::builder("rust-analyzer-mcp".to_string())
193 .instructions(indoc::indoc! {"
194 Rust analyzer LSP integration for code analysis, navigation, and diagnostics.
195 "})
196 .tool_fn_mut(
197 "rust_analyzer_hover",
198 "Get hover information for a symbol at a specific position in a Rust file",
199 {
200 let bridge = bridge.clone();
201 async move |input: FilePositionInputs, _mcp_cx| {
202 with_bridge_and_document(
203 &bridge,
204 None,
205 &input.file_path,
206 async move |client, uri| {
207 let position = Position::new(input.line, input.character);
208 let result = client
209 .hover(uri, position)
210 .await
211 .map_err(|e| anyhow!("Hover request failed: {}", e))?;
212 Ok(serde_json::to_string(&result)?)
213 },
214 )
215 .await
216 }
217 },
218 sacp::tool_fn_mut!(),
219 )
220 .tool_fn_mut(
221 "rust_analyzer_definition",
222 "Go to definition of a symbol at a specific position",
223 {
224 let bridge = bridge.clone();
225 async move |input: FilePositionInputs, _mcp_cx| {
226 with_bridge_and_document(
227 &bridge,
228 None,
229 &input.file_path,
230 async move |client, uri| {
231 let position = Position::new(input.line, input.character);
232 let result = client
233 .goto_definition(uri, position)
234 .await
235 .map_err(|e| anyhow!("Definition request failed: {}", e))?;
236 Ok(serde_json::to_string(&result)?)
237 },
238 )
239 .await
240 }
241 },
242 sacp::tool_fn_mut!(),
243 )
244 .tool_fn_mut(
245 "rust_analyzer_references",
246 "Find all references to a symbol at a specific position",
247 {
248 let bridge = bridge.clone();
249 async move |input: FilePositionInputs, _mcp_cx| {
250 with_bridge_and_document(
251 &bridge,
252 None,
253 &input.file_path,
254 async move |client, uri| {
255 let position = Position::new(input.line, input.character);
256 let result = client
257 .find_references(uri, position, true)
258 .await
259 .map_err(|e| anyhow!("References request failed: {}", e))?;
260 Ok(serde_json::to_string(&result)?)
261 },
262 )
263 .await
264 }
265 },
266 sacp::tool_fn_mut!(),
267 )
268 .tool_fn_mut(
269 "rust_analyzer_completion",
270 "Get code completions at a specific position",
271 {
272 let bridge = bridge.clone();
273 async move |input: FilePositionInputs, _mcp_cx| {
274 with_bridge_and_document(
275 &bridge,
276 None,
277 &input.file_path,
278 async move |client, uri| {
279 let position = Position::new(input.line, input.character);
280 let result = client
281 .completion(uri, position)
282 .await
283 .map_err(|e| anyhow!("Completion request failed: {}", e))?;
284 Ok(serde_json::to_string(&result)?)
285 },
286 )
287 .await
288 }
289 },
290 sacp::tool_fn_mut!(),
291 )
292 .tool_fn_mut(
293 "rust_analyzer_symbols",
294 "Get document symbols for a Rust file",
295 {
296 let bridge = bridge.clone();
297 async move |input: FileOnlyInputs, _mcp_cx| {
298 with_bridge_and_document(
299 &bridge,
300 None,
301 &input.file_path,
302 async move |client, uri| {
303 let result = client
304 .document_symbols(uri)
305 .await
306 .map_err(|e| anyhow!("Document symbols request failed: {}", e))?;
307 Ok(serde_json::to_string(&result)?)
308 },
309 )
310 .await
311 }
312 },
313 sacp::tool_fn_mut!(),
314 )
315 .tool_fn_mut(
373 "rust_analyzer_set_workspace",
374 "Set the workspace root for rust-analyzer",
375 {
376 let bridge = bridge.clone();
377 async move |input: WorkspaceInputs, _mcp_cx| {
378 with_bridge(&bridge, Some(&input.workspace_path), async move |_client| {
379 Ok("Workspace set successfully".to_string())
380 })
381 .await
382 }
383 },
384 sacp::tool_fn_mut!(),
385 )
386 .tool_fn_mut(
412 "rust_analyzer_failed_obligations",
413 "Get failed trait obligations for debugging (rust-analyzer specific)",
414 {
415 let bridge = bridge.clone();
416 let state = failed_obligations_state.clone();
417 async move |input: FilePositionInputs, _mcp_cx| {
418 let mut bridge_guard = bridge.lock().await;
419 let uri = ensure_document_open(&mut bridge_guard, &input.file_path).await?;
420 let doc = TextDocumentIdentifier { uri };
421 let position = Position::new(input.line, input.character);
422
423 let args = TextDocumentPositionParams {
424 text_document: doc,
425 position,
426 };
427
428 let mut state = state.lock().await;
429 use std::ops::DerefMut;
430 let state = state.deref_mut();
431 let result = handle_failed_obligations(
432 bridge_guard.client.as_ref().unwrap(),
433 state,
434 args,
435 )
436 .await?;
437
438 Ok(serde_json::to_string(&result).map_err(|e| anyhow::Error::new(e))?)
439 }
440 },
441 sacp::tool_fn_mut!(),
442 )
443 .tool_fn_mut(
444 "rust_analyzer_failed_obligations_goal",
445 "Explore nested goals in failed trait obligations (rust-analyzer specific)",
446 {
447 let bridge = bridge.clone();
448 let state = failed_obligations_state.clone();
449 async move |input: GoalIndexInputs, _mcp_cx| {
450 let bridge_guard = bridge.lock().await;
451 let mut state = state.lock().await;
452 use std::ops::DerefMut;
453 let state = state.deref_mut();
454 let result = handle_failed_obligations_goal(
455 bridge_guard.client.as_ref().unwrap(),
456 state,
457 input,
458 )
459 .await?;
460
461 Ok(serde_json::to_string(&result).map_err(|e| anyhow::Error::new(e))?)
462 }
463 },
464 sacp::tool_fn_mut!(),
465 )
466 .build();
467
468 Ok(server)
469}