// syncable_cli/agent/ide/client.rs

//! MCP Client for IDE Communication
//!
//! Connects to the IDE companion extension's MCP server over HTTP (responses
//! may be framed as Server-Sent Events) and provides methods for opening and
//! closing diffs and for handling notifications sent by the IDE.

use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use tokio::sync::{mpsc, oneshot};

use super::detect::{detect_ide, get_ide_process_info, IdeInfo, IdeProcessInfo};
use super::types::*;

/// Outcome of presenting a diff to the user in the IDE.
#[derive(Clone, Debug)]
pub enum DiffResult {
    /// The diff was accepted. `content` is the final text reported by the IDE,
    /// which may include edits the user made inside the diff view.
    Accepted { content: String },
    /// The diff was rejected, or closed without being accepted.
    Rejected,
}

/// Lifecycle state of the connection between an `IdeClient` and the IDE.
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectionStatus {
    /// The MCP `initialize` handshake with the IDE succeeded.
    Connected,
    /// No live connection: the initial state, or the state after a failed
    /// attempt or an explicit disconnect.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
}

33/// Errors that can occur during IDE operations
34#[derive(Debug, thiserror::Error)]
35pub enum IdeError {
36    #[error("IDE not detected")]
37    NotDetected,
38    #[error("Connection failed: {0}")]
39    ConnectionFailed(String),
40    #[error("Request failed: {0}")]
41    RequestFailed(String),
42    #[error("No response received")]
43    NoResponse,
44    #[error("Operation cancelled")]
45    Cancelled,
46    #[error("IO error: {0}")]
47    Io(#[from] std::io::Error),
48}
49
50/// MCP Client for IDE communication
51#[derive(Debug)]
52pub struct IdeClient {
53    /// HTTP client
54    http_client: reqwest::Client,
55    /// Connection state
56    status: Arc<Mutex<ConnectionStatus>>,
57    /// Detected IDE info
58    ide_info: Option<IdeInfo>,
59    /// IDE process info
60    process_info: Option<IdeProcessInfo>,
61    /// Server port
62    port: Option<u16>,
63    /// Auth token
64    auth_token: Option<String>,
65    /// Session ID for MCP
66    session_id: Arc<Mutex<Option<String>>>,
67    /// Request ID counter
68    request_id: Arc<Mutex<u64>>,
69    /// Pending diff responses
70    diff_responses: Arc<Mutex<HashMap<String, oneshot::Sender<DiffResult>>>>,
71    /// SSE event receiver
72    sse_receiver: Option<mpsc::Receiver<JsonRpcNotification>>,
73}
74
75impl IdeClient {
76    /// Create a new IDE client (does not connect automatically)
77    pub async fn new() -> Self {
78        let process_info = get_ide_process_info().await;
79        let ide_info = detect_ide(process_info.as_ref());
80
81        Self {
82            http_client: reqwest::Client::builder()
83                .timeout(Duration::from_secs(30))
84                .build()
85                .unwrap_or_default(),
86            status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
87            ide_info,
88            process_info,
89            port: None,
90            auth_token: None,
91            session_id: Arc::new(Mutex::new(None)),
92            request_id: Arc::new(Mutex::new(0)),
93            diff_responses: Arc::new(Mutex::new(HashMap::new())),
94            sse_receiver: None,
95        }
96    }
97
98    /// Check if IDE integration is available
99    pub fn is_ide_available(&self) -> bool {
100        self.ide_info.is_some()
101    }
102
103    /// Get the detected IDE name
104    pub fn ide_name(&self) -> Option<&str> {
105        self.ide_info.as_ref().map(|i| i.display_name.as_str())
106    }
107
108    /// Check if connected to IDE
109    pub fn is_connected(&self) -> bool {
110        *self.status.lock().unwrap() == ConnectionStatus::Connected
111    }
112
113    /// Get connection status
114    pub fn status(&self) -> ConnectionStatus {
115        self.status.lock().unwrap().clone()
116    }
117
118    /// Try to connect to the IDE server
119    pub async fn connect(&mut self) -> Result<(), IdeError> {
120        if self.ide_info.is_none() {
121            return Err(IdeError::NotDetected);
122        }
123
124        *self.status.lock().unwrap() = ConnectionStatus::Connecting;
125
126        // Try to read connection config from file
127        if let Some(config) = self.read_connection_config().await {
128            self.port = Some(config.port);
129            self.auth_token = config.auth_token.clone();
130
131            // Try to establish connection
132            if self.establish_connection().await.is_ok() {
133                *self.status.lock().unwrap() = ConnectionStatus::Connected;
134                return Ok(());
135            }
136        }
137
138        // Try environment variables as fallback
139        if let Ok(port_str) = env::var("SYNCABLE_CLI_IDE_SERVER_PORT") {
140            if let Ok(port) = port_str.parse::<u16>() {
141                self.port = Some(port);
142                self.auth_token = env::var("SYNCABLE_CLI_IDE_AUTH_TOKEN").ok();
143
144                if self.establish_connection().await.is_ok() {
145                    *self.status.lock().unwrap() = ConnectionStatus::Connected;
146                    return Ok(());
147                }
148            }
149        }
150
151        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
152        Err(IdeError::ConnectionFailed(
153            "Could not connect to IDE companion extension".to_string(),
154        ))
155    }
156
157    /// Read connection config from port file
158    /// Supports both Syncable and Gemini CLI companion extensions
159    async fn read_connection_config(&self) -> Option<ConnectionConfig> {
160        let temp_dir = env::temp_dir();
161
162        // Debug: show where we're looking
163        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
164            eprintln!("[IDE Debug] Looking for port files in temp_dir: {:?}", temp_dir);
165        }
166
167        // Try Syncable extension first - scan all port files, match by workspace
168        let syncable_port_dir = temp_dir.join("syncable").join("ide");
169        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
170            eprintln!("[IDE Debug] Checking Syncable dir: {:?} (exists: {})",
171                     syncable_port_dir, syncable_port_dir.exists());
172        }
173        if let Some(config) = self.find_port_file_by_workspace(&syncable_port_dir, "syncable-ide-server") {
174            if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
175                eprintln!("[IDE Debug] Found Syncable config: port={}", config.port);
176            }
177            return Some(config);
178        }
179
180        // Try Gemini CLI extension (for compatibility)
181        let gemini_port_dir = temp_dir.join("gemini").join("ide");
182        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
183            eprintln!("[IDE Debug] Checking Gemini dir: {:?} (exists: {})",
184                     gemini_port_dir, gemini_port_dir.exists());
185        }
186        if let Some(config) = self.find_port_file_by_workspace(&gemini_port_dir, "gemini-ide-server") {
187            if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
188                eprintln!("[IDE Debug] Found Gemini config: port={}", config.port);
189            }
190            return Some(config);
191        }
192
193        if cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok() {
194            eprintln!("[IDE Debug] No port file found in either location");
195        }
196        None
197    }
198
199    /// Find a port file in a directory by scanning all files and matching workspace path
200    fn find_port_file_by_workspace(&self, dir: &PathBuf, prefix: &str) -> Option<ConnectionConfig> {
201        let entries = fs::read_dir(dir).ok()?;
202
203        let debug = cfg!(debug_assertions) || env::var("SYNCABLE_DEBUG").is_ok();
204
205        for entry in entries.flatten() {
206            let filename = entry.file_name().to_string_lossy().to_string();
207            // Match any file starting with the prefix and ending with .json
208            if filename.starts_with(prefix) && filename.ends_with(".json") {
209                if debug {
210                    eprintln!("[IDE Debug] Found port file: {:?}", entry.path());
211                }
212                if let Ok(content) = fs::read_to_string(entry.path()) {
213                    if let Ok(config) = serde_json::from_str::<ConnectionConfig>(&content) {
214                        if debug {
215                            eprintln!("[IDE Debug] Config workspace_path: {:?}", config.workspace_path);
216                        }
217                        if self.validate_workspace_path(&config.workspace_path) {
218                            return Some(config);
219                        } else if debug {
220                            let cwd = env::current_dir().ok();
221                            eprintln!("[IDE Debug] Workspace path did not match cwd: {:?}", cwd);
222                        }
223                    }
224                }
225            }
226        }
227        None
228    }
229
230    /// Validate that the workspace path matches our current directory
231    fn validate_workspace_path(&self, workspace_path: &Option<String>) -> bool {
232        let Some(ws_path) = workspace_path else {
233            return false;
234        };
235
236        if ws_path.is_empty() {
237            return false;
238        }
239
240        let cwd = match env::current_dir() {
241            Ok(p) => p,
242            Err(_) => return false,
243        };
244
245        // Check if cwd is within any of the workspace paths
246        for path in ws_path.split(std::path::MAIN_SEPARATOR) {
247            let ws = PathBuf::from(path);
248            if cwd.starts_with(&ws) || ws.starts_with(&cwd) {
249                return true;
250            }
251        }
252
253        false
254    }
255
256    /// Establish HTTP connection and initialize MCP session
257    async fn establish_connection(&mut self) -> Result<(), IdeError> {
258        let port = self.port.ok_or(IdeError::ConnectionFailed("No port".to_string()))?;
259        let url = format!("http://127.0.0.1:{}/mcp", port);
260
261        // Build initialize request
262        let init_request = JsonRpcRequest::new(
263            self.next_request_id(),
264            "initialize",
265            serde_json::to_value(InitializeParams {
266                protocol_version: "2024-11-05".to_string(),
267                client_info: ClientInfo {
268                    name: "syncable-cli".to_string(),
269                    version: env!("CARGO_PKG_VERSION").to_string(),
270                },
271                capabilities: ClientCapabilities {},
272            })
273            .unwrap(),
274        );
275
276        // Send initialize request
277        let mut request = self.http_client
278            .post(&url)
279            .header("Accept", "application/json, text/event-stream")
280            .json(&init_request);
281
282        if let Some(token) = &self.auth_token {
283            request = request.header("Authorization", format!("Bearer {}", token));
284        }
285
286        let response = request
287            .send()
288            .await
289            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
290
291        // Get session ID from response header
292        if let Some(session_id) = response.headers().get("mcp-session-id") {
293            if let Ok(id) = session_id.to_str() {
294                *self.session_id.lock().unwrap() = Some(id.to_string());
295            }
296        }
297
298        // Parse response (SSE format: "event: message\ndata: {json}")
299        let response_text = response
300            .text()
301            .await
302            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
303
304        let response_data: JsonRpcResponse = Self::parse_sse_response(&response_text)
305            .map_err(IdeError::ConnectionFailed)?;
306
307        if response_data.error.is_some() {
308            return Err(IdeError::ConnectionFailed(
309                response_data
310                    .error
311                    .map(|e| e.message)
312                    .unwrap_or_default(),
313            ));
314        }
315
316        Ok(())
317    }
318
319    /// Parse SSE response format to extract JSON
320    fn parse_sse_response(text: &str) -> Result<JsonRpcResponse, String> {
321        // SSE format: "event: message\ndata: {json}\n\n"
322        for line in text.lines() {
323            if let Some(json_str) = line.strip_prefix("data: ") {
324                return serde_json::from_str(json_str)
325                    .map_err(|e| format!("Failed to parse JSON: {}", e));
326            }
327        }
328        // Fallback: try parsing entire response as JSON (for non-SSE responses)
329        serde_json::from_str(text)
330            .map_err(|e| format!("Failed to parse response: {}", e))
331    }
332
333    /// Get next request ID
334    fn next_request_id(&self) -> u64 {
335        let mut id = self.request_id.lock().unwrap();
336        *id += 1;
337        *id
338    }
339
340    /// Send an MCP request
341    async fn send_request(
342        &self,
343        method: &str,
344        params: serde_json::Value,
345    ) -> Result<JsonRpcResponse, IdeError> {
346        let port = self.port.ok_or(IdeError::ConnectionFailed("Not connected".to_string()))?;
347        let url = format!("http://127.0.0.1:{}/mcp", port);
348
349        let request = JsonRpcRequest::new(self.next_request_id(), method, params);
350
351        let mut http_request = self.http_client
352            .post(&url)
353            .header("Accept", "application/json, text/event-stream")
354            .json(&request);
355
356        if let Some(token) = &self.auth_token {
357            http_request = http_request.header("Authorization", format!("Bearer {}", token));
358        }
359
360        if let Some(session_id) = &*self.session_id.lock().unwrap() {
361            http_request = http_request.header("mcp-session-id", session_id);
362        }
363
364        let response = http_request
365            .send()
366            .await
367            .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
368
369        let response_text = response
370            .text()
371            .await
372            .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
373
374        Self::parse_sse_response(&response_text)
375            .map_err(IdeError::RequestFailed)
376    }
377
378    /// Open a diff view in the IDE
379    ///
380    /// This sends the file path and new content to the IDE, which will show
381    /// a diff view. The method returns when the user accepts or rejects the diff.
382    pub async fn open_diff(&self, file_path: &str, new_content: &str) -> Result<DiffResult, IdeError> {
383        if !self.is_connected() {
384            return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
385        }
386
387        let params = serde_json::to_value(ToolCallParams {
388            name: "openDiff".to_string(),
389            arguments: serde_json::to_value(OpenDiffArgs {
390                file_path: file_path.to_string(),
391                new_content: new_content.to_string(),
392            })
393            .unwrap(),
394        })
395        .unwrap();
396
397        // Create a channel to receive the diff result
398        let (tx, rx) = oneshot::channel();
399        {
400            let mut responses = self.diff_responses.lock().unwrap();
401            responses.insert(file_path.to_string(), tx);
402        }
403
404        // Send the openDiff request
405        let response = self.send_request("tools/call", params).await;
406
407        if let Err(e) = response {
408            // Remove the pending response
409            let mut responses = self.diff_responses.lock().unwrap();
410            responses.remove(file_path);
411            return Err(e);
412        }
413
414        // Wait for the notification (with timeout)
415        match tokio::time::timeout(Duration::from_secs(300), rx).await {
416            Ok(Ok(result)) => Ok(result),
417            Ok(Err(_)) => Err(IdeError::Cancelled),
418            Err(_) => {
419                // Timeout - remove pending response
420                let mut responses = self.diff_responses.lock().unwrap();
421                responses.remove(file_path);
422                Err(IdeError::NoResponse)
423            }
424        }
425    }
426
427    /// Close a diff view in the IDE
428    pub async fn close_diff(&self, file_path: &str) -> Result<Option<String>, IdeError> {
429        if !self.is_connected() {
430            return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
431        }
432
433        let params = serde_json::to_value(ToolCallParams {
434            name: "closeDiff".to_string(),
435            arguments: serde_json::to_value(CloseDiffArgs {
436                file_path: file_path.to_string(),
437                suppress_notification: Some(false),
438            })
439            .unwrap(),
440        })
441        .unwrap();
442
443        let response = self.send_request("tools/call", params).await?;
444
445        // Parse the response to get content if available
446        if let Some(result) = response.result {
447            if let Ok(tool_result) = serde_json::from_value::<ToolCallResult>(result) {
448                for content in tool_result.content {
449                    if content.content_type == "text" {
450                        if let Some(text) = content.text {
451                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
452                                if let Some(content) = parsed.get("content").and_then(|c| c.as_str())
453                                {
454                                    return Ok(Some(content.to_string()));
455                                }
456                            }
457                        }
458                    }
459                }
460            }
461        }
462
463        Ok(None)
464    }
465
466    /// Handle an incoming notification from the IDE
467    pub fn handle_notification(&self, notification: JsonRpcNotification) {
468        match notification.method.as_str() {
469            "ide/diffAccepted" => {
470                if let Ok(params) =
471                    serde_json::from_value::<IdeDiffAcceptedParams>(notification.params)
472                {
473                    let mut responses = self.diff_responses.lock().unwrap();
474                    if let Some(tx) = responses.remove(&params.file_path) {
475                        let _ = tx.send(DiffResult::Accepted {
476                            content: params.content,
477                        });
478                    }
479                }
480            }
481            "ide/diffRejected" | "ide/diffClosed" => {
482                if let Ok(params) =
483                    serde_json::from_value::<IdeDiffRejectedParams>(notification.params)
484                {
485                    let mut responses = self.diff_responses.lock().unwrap();
486                    if let Some(tx) = responses.remove(&params.file_path) {
487                        let _ = tx.send(DiffResult::Rejected);
488                    }
489                }
490            }
491            "ide/contextUpdate" => {
492                // Handle IDE context updates (e.g., open files)
493                // This could be used to show relevant context in the agent
494            }
495            _ => {
496                // Unknown notification
497            }
498        }
499    }
500
501    /// Disconnect from the IDE
502    pub async fn disconnect(&mut self) {
503        // Close any pending diffs
504        let pending: Vec<String> = {
505            let responses = self.diff_responses.lock().unwrap();
506            responses.keys().cloned().collect()
507        };
508
509        for file_path in pending {
510            let _ = self.close_diff(&file_path).await;
511        }
512
513        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
514        *self.session_id.lock().unwrap() = None;
515    }
516}
517
518impl Default for IdeClient {
519    fn default() -> Self {
520        // Create with blocking runtime for sync context
521        tokio::runtime::Handle::current().block_on(Self::new())
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client has not connected to anything yet.
    #[tokio::test]
    async fn test_ide_client_creation() {
        let ide_client = IdeClient::new().await;
        assert!(!ide_client.is_connected());
    }

    /// `DiffResult::Accepted` carries the accepted content through unchanged.
    #[test]
    fn test_diff_result() {
        let outcome = DiffResult::Accepted {
            content: String::from("test"),
        };
        let DiffResult::Accepted { content } = outcome else {
            panic!("Expected Accepted");
        };
        assert_eq!(content, "test");
    }
}
545}