syncable_cli/agent/ide/
client.rs

1//! MCP Client for IDE Communication
2//!
3//! Connects to the IDE's MCP server via HTTP SSE and provides methods
4//! for opening diffs and receiving notifications.
5
6use super::detect::{IdeInfo, IdeProcessInfo, detect_ide, get_ide_process_info};
7use super::types::*;
8use std::collections::HashMap;
9use std::env;
10use std::fs;
11use std::path::PathBuf;
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::sync::{mpsc, oneshot};
15
/// Outcome of asking the IDE to show a diff to the user.
#[derive(Debug, Clone)]
pub enum DiffResult {
    /// The user accepted the change; `content` is the text reported back by the
    /// IDE, which may include edits the user made in the diff view.
    Accepted { content: String },
    /// The user rejected (or closed) the diff without accepting it.
    Rejected,
}
24
/// Lifecycle state of the connection between the CLI and the IDE companion.
#[derive(Debug, Clone, PartialEq)]
pub enum ConnectionStatus {
    /// The MCP handshake succeeded and requests may be sent.
    Connected,
    /// No usable connection (initial state, after a failure, or after disconnect).
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
}
32
33/// Errors that can occur during IDE operations
34#[derive(Debug, thiserror::Error)]
35pub enum IdeError {
36    #[error("IDE not detected")]
37    NotDetected,
38    #[error("Connection failed: {0}")]
39    ConnectionFailed(String),
40    #[error("Request failed: {0}")]
41    RequestFailed(String),
42    #[error("No response received")]
43    NoResponse,
44    #[error("Operation cancelled")]
45    Cancelled,
46    #[error("IO error: {0}")]
47    Io(#[from] std::io::Error),
48}
49
50/// MCP Client for IDE communication
51#[derive(Debug)]
52pub struct IdeClient {
53    /// HTTP client
54    http_client: reqwest::Client,
55    /// Connection state
56    status: Arc<Mutex<ConnectionStatus>>,
57    /// Detected IDE info
58    ide_info: Option<IdeInfo>,
59    /// IDE process info
60    process_info: Option<IdeProcessInfo>,
61    /// Server port
62    port: Option<u16>,
63    /// Auth token
64    auth_token: Option<String>,
65    /// Session ID for MCP
66    session_id: Arc<Mutex<Option<String>>>,
67    /// Request ID counter
68    request_id: Arc<Mutex<u64>>,
69    /// Pending diff responses
70    diff_responses: Arc<Mutex<HashMap<String, oneshot::Sender<DiffResult>>>>,
71    /// SSE event receiver
72    sse_receiver: Option<mpsc::Receiver<JsonRpcNotification>>,
73}
74
75impl IdeClient {
76    /// Create a new IDE client (does not connect automatically)
77    pub async fn new() -> Self {
78        let process_info = get_ide_process_info().await;
79        let ide_info = detect_ide(process_info.as_ref());
80
81        Self {
82            http_client: reqwest::Client::builder()
83                .timeout(Duration::from_secs(30))
84                .build()
85                .unwrap_or_default(),
86            status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
87            ide_info,
88            process_info,
89            port: None,
90            auth_token: None,
91            session_id: Arc::new(Mutex::new(None)),
92            request_id: Arc::new(Mutex::new(0)),
93            diff_responses: Arc::new(Mutex::new(HashMap::new())),
94            sse_receiver: None,
95        }
96    }
97
98    /// Check if IDE integration is available
99    pub fn is_ide_available(&self) -> bool {
100        self.ide_info.is_some()
101    }
102
103    /// Get the detected IDE name
104    pub fn ide_name(&self) -> Option<&str> {
105        self.ide_info.as_ref().map(|i| i.display_name.as_str())
106    }
107
108    /// Check if connected to IDE
109    pub fn is_connected(&self) -> bool {
110        *self.status.lock().unwrap() == ConnectionStatus::Connected
111    }
112
113    /// Get connection status
114    pub fn status(&self) -> ConnectionStatus {
115        self.status.lock().unwrap().clone()
116    }
117
118    /// Try to connect to the IDE server
119    pub async fn connect(&mut self) -> Result<(), IdeError> {
120        if self.ide_info.is_none() {
121            return Err(IdeError::NotDetected);
122        }
123
124        *self.status.lock().unwrap() = ConnectionStatus::Connecting;
125
126        // Try to read connection config from file
127        if let Some(config) = self.read_connection_config().await {
128            self.port = Some(config.port);
129            self.auth_token = config.auth_token.clone();
130
131            // Try to establish connection
132            if self.establish_connection().await.is_ok() {
133                *self.status.lock().unwrap() = ConnectionStatus::Connected;
134                return Ok(());
135            }
136        }
137
138        // Try environment variables as fallback
139        if let Ok(port_str) = env::var("SYNCABLE_CLI_IDE_SERVER_PORT") {
140            if let Ok(port) = port_str.parse::<u16>() {
141                self.port = Some(port);
142                self.auth_token = env::var("SYNCABLE_CLI_IDE_AUTH_TOKEN").ok();
143
144                if self.establish_connection().await.is_ok() {
145                    *self.status.lock().unwrap() = ConnectionStatus::Connected;
146                    return Ok(());
147                }
148            }
149        }
150
151        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
152        Err(IdeError::ConnectionFailed(
153            "Could not connect to IDE companion extension".to_string(),
154        ))
155    }
156
157    /// Read connection config from port file
158    /// Supports both Syncable and Gemini CLI companion extensions
159    async fn read_connection_config(&self) -> Option<ConnectionConfig> {
160        let temp_dir = env::temp_dir();
161
162        // Debug: show where we're looking
163        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
164            eprintln!(
165                "[IDE Debug] Looking for port files in temp_dir: {:?}",
166                temp_dir
167            );
168        }
169
170        // Try Syncable extension first - scan all port files, match by workspace
171        let syncable_port_dir = temp_dir.join("syncable").join("ide");
172        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
173            eprintln!(
174                "[IDE Debug] Checking Syncable dir: {:?} (exists: {})",
175                syncable_port_dir,
176                syncable_port_dir.exists()
177            );
178        }
179        if let Some(config) =
180            self.find_port_file_by_workspace(&syncable_port_dir, "syncable-ide-server")
181        {
182            if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
183                eprintln!("[IDE Debug] Found Syncable config: port={}", config.port);
184            }
185            return Some(config);
186        }
187
188        // Try Gemini CLI extension (for compatibility)
189        let gemini_port_dir = temp_dir.join("gemini").join("ide");
190        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
191            eprintln!(
192                "[IDE Debug] Checking Gemini dir: {:?} (exists: {})",
193                gemini_port_dir,
194                gemini_port_dir.exists()
195            );
196        }
197        if let Some(config) =
198            self.find_port_file_by_workspace(&gemini_port_dir, "gemini-ide-server")
199        {
200            if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
201                eprintln!("[IDE Debug] Found Gemini config: port={}", config.port);
202            }
203            return Some(config);
204        }
205
206        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
207            eprintln!("[IDE Debug] No port file found in either location");
208        }
209        None
210    }
211
212    /// Find a port file in a directory by scanning all files and matching workspace path
213    fn find_port_file_by_workspace(&self, dir: &PathBuf, prefix: &str) -> Option<ConnectionConfig> {
214        let entries = fs::read_dir(dir).ok()?;
215
216        let debug = cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok();
217
218        for entry in entries.flatten() {
219            let filename = entry.file_name().to_string_lossy().to_string();
220            // Match any file starting with the prefix and ending with .json
221            if filename.starts_with(prefix) && filename.ends_with(".json") {
222                if debug {
223                    eprintln!("[IDE Debug] Found port file: {:?}", entry.path());
224                }
225                if let Ok(content) = fs::read_to_string(entry.path()) {
226                    if let Ok(config) = serde_json::from_str::<ConnectionConfig>(&content) {
227                        if debug {
228                            eprintln!(
229                                "[IDE Debug] Config workspace_path: {:?}",
230                                config.workspace_path
231                            );
232                        }
233                        if self.validate_workspace_path(&config.workspace_path) {
234                            return Some(config);
235                        } else if debug {
236                            let cwd = env::current_dir().ok();
237                            eprintln!("[IDE Debug] Workspace path did not match cwd: {:?}", cwd);
238                        }
239                    }
240                }
241            }
242        }
243        None
244    }
245
246    /// Validate that the workspace path matches our current directory
247    fn validate_workspace_path(&self, workspace_path: &Option<String>) -> bool {
248        let Some(ws_path) = workspace_path else {
249            return false;
250        };
251
252        if ws_path.is_empty() {
253            return false;
254        }
255
256        let cwd = match env::current_dir() {
257            Ok(p) => p,
258            Err(_) => return false,
259        };
260
261        // Check if cwd is within any of the workspace paths
262        for path in ws_path.split(std::path::MAIN_SEPARATOR) {
263            let ws = PathBuf::from(path);
264            if cwd.starts_with(&ws) || ws.starts_with(&cwd) {
265                return true;
266            }
267        }
268
269        false
270    }
271
272    /// Establish HTTP connection and initialize MCP session
273    async fn establish_connection(&mut self) -> Result<(), IdeError> {
274        let port = self
275            .port
276            .ok_or(IdeError::ConnectionFailed("No port".to_string()))?;
277        let url = format!("http://127.0.0.1:{}/mcp", port);
278
279        // Build initialize request
280        let init_request = JsonRpcRequest::new(
281            self.next_request_id(),
282            "initialize",
283            serde_json::to_value(InitializeParams {
284                protocol_version: "2024-11-05".to_string(),
285                client_info: ClientInfo {
286                    name: "syncable-cli".to_string(),
287                    version: env!("CARGO_PKG_VERSION").to_string(),
288                },
289                capabilities: ClientCapabilities {},
290            })
291            .unwrap(),
292        );
293
294        // Send initialize request
295        let mut request = self
296            .http_client
297            .post(&url)
298            .header("Accept", "application/json, text/event-stream")
299            .json(&init_request);
300
301        if let Some(token) = &self.auth_token {
302            request = request.header("Authorization", format!("Bearer {}", token));
303        }
304
305        let response = request
306            .send()
307            .await
308            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
309
310        // Get session ID from response header
311        if let Some(session_id) = response.headers().get("mcp-session-id") {
312            if let Ok(id) = session_id.to_str() {
313                *self.session_id.lock().unwrap() = Some(id.to_string());
314            }
315        }
316
317        // Parse response (SSE format: "event: message\ndata: {json}")
318        let response_text = response
319            .text()
320            .await
321            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
322
323        let response_data: JsonRpcResponse =
324            Self::parse_sse_response(&response_text).map_err(IdeError::ConnectionFailed)?;
325
326        if response_data.error.is_some() {
327            return Err(IdeError::ConnectionFailed(
328                response_data.error.map(|e| e.message).unwrap_or_default(),
329            ));
330        }
331
332        Ok(())
333    }
334
335    /// Parse SSE response format to extract JSON
336    fn parse_sse_response(text: &str) -> Result<JsonRpcResponse, String> {
337        // SSE format: "event: message\ndata: {json}\n\n"
338        for line in text.lines() {
339            if let Some(json_str) = line.strip_prefix("data: ") {
340                return serde_json::from_str(json_str)
341                    .map_err(|e| format!("Failed to parse JSON: {}", e));
342            }
343        }
344        // Fallback: try parsing entire response as JSON (for non-SSE responses)
345        serde_json::from_str(text).map_err(|e| format!("Failed to parse response: {}", e))
346    }
347
348    /// Get next request ID
349    fn next_request_id(&self) -> u64 {
350        let mut id = self.request_id.lock().unwrap();
351        *id += 1;
352        *id
353    }
354
355    /// Send an MCP request
356    async fn send_request(
357        &self,
358        method: &str,
359        params: serde_json::Value,
360    ) -> Result<JsonRpcResponse, IdeError> {
361        let port = self
362            .port
363            .ok_or(IdeError::ConnectionFailed("Not connected".to_string()))?;
364        let url = format!("http://127.0.0.1:{}/mcp", port);
365
366        let request = JsonRpcRequest::new(self.next_request_id(), method, params);
367
368        let mut http_request = self
369            .http_client
370            .post(&url)
371            .header("Accept", "application/json, text/event-stream")
372            .json(&request);
373
374        if let Some(token) = &self.auth_token {
375            http_request = http_request.header("Authorization", format!("Bearer {}", token));
376        }
377
378        if let Some(session_id) = &*self.session_id.lock().unwrap() {
379            http_request = http_request.header("mcp-session-id", session_id);
380        }
381
382        let response = http_request
383            .send()
384            .await
385            .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
386
387        let response_text = response
388            .text()
389            .await
390            .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
391
392        Self::parse_sse_response(&response_text).map_err(IdeError::RequestFailed)
393    }
394
395    /// Open a diff view in the IDE
396    ///
397    /// This sends the file path and new content to the IDE, which will show
398    /// a diff view. The method returns when the user accepts or rejects the diff.
399    pub async fn open_diff(
400        &self,
401        file_path: &str,
402        new_content: &str,
403    ) -> Result<DiffResult, IdeError> {
404        if !self.is_connected() {
405            return Err(IdeError::ConnectionFailed(
406                "Not connected to IDE".to_string(),
407            ));
408        }
409
410        let params = serde_json::to_value(ToolCallParams {
411            name: "openDiff".to_string(),
412            arguments: serde_json::to_value(OpenDiffArgs {
413                file_path: file_path.to_string(),
414                new_content: new_content.to_string(),
415            })
416            .unwrap(),
417        })
418        .unwrap();
419
420        // Create a channel to receive the diff result
421        let (tx, rx) = oneshot::channel();
422        {
423            let mut responses = self.diff_responses.lock().unwrap();
424            responses.insert(file_path.to_string(), tx);
425        }
426
427        // Send the openDiff request
428        let response = self.send_request("tools/call", params).await;
429
430        if let Err(e) = response {
431            // Remove the pending response
432            let mut responses = self.diff_responses.lock().unwrap();
433            responses.remove(file_path);
434            return Err(e);
435        }
436
437        // Wait for the notification (with timeout)
438        match tokio::time::timeout(Duration::from_secs(300), rx).await {
439            Ok(Ok(result)) => Ok(result),
440            Ok(Err(_)) => Err(IdeError::Cancelled),
441            Err(_) => {
442                // Timeout - remove pending response
443                let mut responses = self.diff_responses.lock().unwrap();
444                responses.remove(file_path);
445                Err(IdeError::NoResponse)
446            }
447        }
448    }
449
450    /// Close a diff view in the IDE
451    pub async fn close_diff(&self, file_path: &str) -> Result<Option<String>, IdeError> {
452        if !self.is_connected() {
453            return Err(IdeError::ConnectionFailed(
454                "Not connected to IDE".to_string(),
455            ));
456        }
457
458        let params = serde_json::to_value(ToolCallParams {
459            name: "closeDiff".to_string(),
460            arguments: serde_json::to_value(CloseDiffArgs {
461                file_path: file_path.to_string(),
462                suppress_notification: Some(false),
463            })
464            .unwrap(),
465        })
466        .unwrap();
467
468        let response = self.send_request("tools/call", params).await?;
469
470        // Parse the response to get content if available
471        if let Some(result) = response.result {
472            if let Ok(tool_result) = serde_json::from_value::<ToolCallResult>(result) {
473                for content in tool_result.content {
474                    if content.content_type == "text" {
475                        if let Some(text) = content.text {
476                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
477                                if let Some(content) =
478                                    parsed.get("content").and_then(|c| c.as_str())
479                                {
480                                    return Ok(Some(content.to_string()));
481                                }
482                            }
483                        }
484                    }
485                }
486            }
487        }
488
489        Ok(None)
490    }
491
492    /// Handle an incoming notification from the IDE
493    pub fn handle_notification(&self, notification: JsonRpcNotification) {
494        match notification.method.as_str() {
495            "ide/diffAccepted" => {
496                if let Ok(params) =
497                    serde_json::from_value::<IdeDiffAcceptedParams>(notification.params)
498                {
499                    let mut responses = self.diff_responses.lock().unwrap();
500                    if let Some(tx) = responses.remove(&params.file_path) {
501                        let _ = tx.send(DiffResult::Accepted {
502                            content: params.content,
503                        });
504                    }
505                }
506            }
507            "ide/diffRejected" | "ide/diffClosed" => {
508                if let Ok(params) =
509                    serde_json::from_value::<IdeDiffRejectedParams>(notification.params)
510                {
511                    let mut responses = self.diff_responses.lock().unwrap();
512                    if let Some(tx) = responses.remove(&params.file_path) {
513                        let _ = tx.send(DiffResult::Rejected);
514                    }
515                }
516            }
517            "ide/contextUpdate" => {
518                // Handle IDE context updates (e.g., open files)
519                // This could be used to show relevant context in the agent
520            }
521            _ => {
522                // Unknown notification
523            }
524        }
525    }
526
527    /// Get diagnostics from the IDE's language servers
528    ///
529    /// This queries the IDE for all diagnostic messages (errors, warnings, etc.)
530    /// from the active language servers (rust-analyzer, ESLint, TypeScript, etc.)
531    ///
532    /// If `file_path` is provided, returns diagnostics only for that file.
533    /// Otherwise returns all diagnostics across the workspace.
534    pub async fn get_diagnostics(
535        &self,
536        file_path: Option<&str>,
537    ) -> Result<DiagnosticsResponse, IdeError> {
538        if !self.is_connected() {
539            return Err(IdeError::ConnectionFailed(
540                "Not connected to IDE".to_string(),
541            ));
542        }
543
544        let params = serde_json::to_value(ToolCallParams {
545            name: "getDiagnostics".to_string(),
546            arguments: serde_json::to_value(GetDiagnosticsArgs {
547                uri: file_path.map(|p| format!("file://{}", p)),
548            })
549            .unwrap(),
550        })
551        .unwrap();
552
553        let response = self.send_request("tools/call", params).await?;
554
555        // Parse the response
556        if let Some(result) = response.result {
557            if let Ok(tool_result) = serde_json::from_value::<ToolCallResult>(result) {
558                // Look for the text content with diagnostics
559                for content in tool_result.content {
560                    if content.content_type == "text" {
561                        if let Some(text) = content.text {
562                            // Try to parse as DiagnosticsResponse
563                            if let Ok(diag_response) =
564                                serde_json::from_str::<DiagnosticsResponse>(&text)
565                            {
566                                return Ok(diag_response);
567                            }
568                            // Try parsing as raw array of diagnostics
569                            if let Ok(diagnostics) = serde_json::from_str::<Vec<Diagnostic>>(&text)
570                            {
571                                let total_errors = diagnostics
572                                    .iter()
573                                    .filter(|d| d.severity == DiagnosticSeverity::Error)
574                                    .count()
575                                    as u32;
576                                let total_warnings = diagnostics
577                                    .iter()
578                                    .filter(|d| d.severity == DiagnosticSeverity::Warning)
579                                    .count()
580                                    as u32;
581                                return Ok(DiagnosticsResponse {
582                                    diagnostics,
583                                    total_errors,
584                                    total_warnings,
585                                });
586                            }
587                        }
588                    }
589                }
590            }
591        }
592
593        // No diagnostics found - return empty response
594        Ok(DiagnosticsResponse {
595            diagnostics: Vec::new(),
596            total_errors: 0,
597            total_warnings: 0,
598        })
599    }
600
601    /// Disconnect from the IDE
602    pub async fn disconnect(&mut self) {
603        // Close any pending diffs
604        let pending: Vec<String> = {
605            let responses = self.diff_responses.lock().unwrap();
606            responses.keys().cloned().collect()
607        };
608
609        for file_path in pending {
610            let _ = self.close_diff(&file_path).await;
611        }
612
613        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
614        *self.session_id.lock().unwrap() = None;
615    }
616}
617
618impl Default for IdeClient {
619    fn default() -> Self {
620        // Create with blocking runtime for sync context
621        tokio::runtime::Handle::current().block_on(Self::new())
622    }
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client must not report a live connection.
    #[tokio::test]
    async fn test_ide_client_creation() {
        let fresh = IdeClient::new().await;
        assert!(!fresh.is_connected());
    }

    /// `DiffResult::Accepted` carries its content through unchanged.
    #[test]
    fn test_diff_result() {
        let outcome = DiffResult::Accepted {
            content: "test".to_string(),
        };
        if let DiffResult::Accepted { content } = outcome {
            assert_eq!(content, "test");
        } else {
            panic!("Expected Accepted");
        }
    }
}