syncable_cli/agent/ide/
client.rs

//! MCP Client for IDE Communication
//!
//! Connects to the IDE companion extension's MCP server over HTTP (JSON-RPC
//! POST requests whose replies may be SSE-framed) and provides methods for
//! opening diffs and handling the IDE's diff notifications.

use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};

use tokio::sync::{mpsc, oneshot};

use super::detect::{detect_ide, get_ide_process_info, IdeInfo, IdeProcessInfo};
use super::types::*;

/// Outcome of an [`IdeClient::open_diff`] request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiffResult {
    /// The user accepted the diff. `content` is the file content reported by
    /// the IDE, which may include edits the user made in the diff view.
    Accepted { content: String },
    /// The user rejected the diff (a closed diff is reported the same way).
    Rejected,
}

/// Lifecycle state of the connection to the IDE companion server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// The MCP `initialize` handshake succeeded; requests can be sent.
    Connected,
    /// Not connected: the initial state, and the state after a failed
    /// connection attempt or an explicit disconnect.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
}

33/// Errors that can occur during IDE operations
34#[derive(Debug, thiserror::Error)]
35pub enum IdeError {
36    #[error("IDE not detected")]
37    NotDetected,
38    #[error("Connection failed: {0}")]
39    ConnectionFailed(String),
40    #[error("Request failed: {0}")]
41    RequestFailed(String),
42    #[error("No response received")]
43    NoResponse,
44    #[error("Operation cancelled")]
45    Cancelled,
46    #[error("IO error: {0}")]
47    Io(#[from] std::io::Error),
48}
49
50/// MCP Client for IDE communication
51#[derive(Debug)]
52pub struct IdeClient {
53    /// HTTP client
54    http_client: reqwest::Client,
55    /// Connection state
56    status: Arc<Mutex<ConnectionStatus>>,
57    /// Detected IDE info
58    ide_info: Option<IdeInfo>,
59    /// IDE process info
60    process_info: Option<IdeProcessInfo>,
61    /// Server port
62    port: Option<u16>,
63    /// Auth token
64    auth_token: Option<String>,
65    /// Session ID for MCP
66    session_id: Arc<Mutex<Option<String>>>,
67    /// Request ID counter
68    request_id: Arc<Mutex<u64>>,
69    /// Pending diff responses
70    diff_responses: Arc<Mutex<HashMap<String, oneshot::Sender<DiffResult>>>>,
71    /// SSE event receiver
72    sse_receiver: Option<mpsc::Receiver<JsonRpcNotification>>,
73}
74
75impl IdeClient {
76    /// Create a new IDE client (does not connect automatically)
77    pub async fn new() -> Self {
78        let process_info = get_ide_process_info().await;
79        let ide_info = detect_ide(process_info.as_ref());
80
81        Self {
82            http_client: reqwest::Client::builder()
83                .timeout(Duration::from_secs(30))
84                .build()
85                .unwrap_or_default(),
86            status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
87            ide_info,
88            process_info,
89            port: None,
90            auth_token: None,
91            session_id: Arc::new(Mutex::new(None)),
92            request_id: Arc::new(Mutex::new(0)),
93            diff_responses: Arc::new(Mutex::new(HashMap::new())),
94            sse_receiver: None,
95        }
96    }
97
98    /// Check if IDE integration is available
99    pub fn is_ide_available(&self) -> bool {
100        self.ide_info.is_some()
101    }
102
103    /// Get the detected IDE name
104    pub fn ide_name(&self) -> Option<&str> {
105        self.ide_info.as_ref().map(|i| i.display_name.as_str())
106    }
107
108    /// Check if connected to IDE
109    pub fn is_connected(&self) -> bool {
110        *self.status.lock().unwrap() == ConnectionStatus::Connected
111    }
112
113    /// Get connection status
114    pub fn status(&self) -> ConnectionStatus {
115        self.status.lock().unwrap().clone()
116    }
117
118    /// Try to connect to the IDE server
119    pub async fn connect(&mut self) -> Result<(), IdeError> {
120        if self.ide_info.is_none() {
121            return Err(IdeError::NotDetected);
122        }
123
124        *self.status.lock().unwrap() = ConnectionStatus::Connecting;
125
126        // Try to read connection config from file
127        if let Some(config) = self.read_connection_config().await {
128            self.port = Some(config.port);
129            self.auth_token = config.auth_token.clone();
130
131            // Try to establish connection
132            if self.establish_connection().await.is_ok() {
133                *self.status.lock().unwrap() = ConnectionStatus::Connected;
134                return Ok(());
135            }
136        }
137
138        // Try environment variables as fallback
139        if let Ok(port_str) = env::var("SYNCABLE_CLI_IDE_SERVER_PORT") {
140            if let Ok(port) = port_str.parse::<u16>() {
141                self.port = Some(port);
142                self.auth_token = env::var("SYNCABLE_CLI_IDE_AUTH_TOKEN").ok();
143
144                if self.establish_connection().await.is_ok() {
145                    *self.status.lock().unwrap() = ConnectionStatus::Connected;
146                    return Ok(());
147                }
148            }
149        }
150
151        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
152        Err(IdeError::ConnectionFailed(
153            "Could not connect to IDE companion extension".to_string(),
154        ))
155    }
156
157    /// Read connection config from port file
158    /// Supports both Syncable and Gemini CLI companion extensions
159    async fn read_connection_config(&self) -> Option<ConnectionConfig> {
160        let temp_dir = env::temp_dir();
161
162        // Try Syncable extension first - scan all port files, match by workspace
163        let syncable_port_dir = temp_dir.join("syncable").join("ide");
164        if let Some(config) = self.find_port_file_by_workspace(&syncable_port_dir, "syncable-ide-server") {
165            return Some(config);
166        }
167
168        // Try Gemini CLI extension (for compatibility)
169        let gemini_port_dir = temp_dir.join("gemini").join("ide");
170        if let Some(config) = self.find_port_file_by_workspace(&gemini_port_dir, "gemini-ide-server") {
171            return Some(config);
172        }
173
174        None
175    }
176
177    /// Find a port file in a directory by scanning all files and matching workspace path
178    fn find_port_file_by_workspace(&self, dir: &PathBuf, prefix: &str) -> Option<ConnectionConfig> {
179        let entries = fs::read_dir(dir).ok()?;
180
181        for entry in entries.flatten() {
182            let filename = entry.file_name().to_string_lossy().to_string();
183            // Match any file starting with the prefix and ending with .json
184            if filename.starts_with(prefix) && filename.ends_with(".json") {
185                if let Ok(content) = fs::read_to_string(entry.path()) {
186                    if let Ok(config) = serde_json::from_str::<ConnectionConfig>(&content) {
187                        if self.validate_workspace_path(&config.workspace_path) {
188                            return Some(config);
189                        }
190                    }
191                }
192            }
193        }
194        None
195    }
196
197    /// Validate that the workspace path matches our current directory
198    fn validate_workspace_path(&self, workspace_path: &Option<String>) -> bool {
199        let Some(ws_path) = workspace_path else {
200            return false;
201        };
202
203        if ws_path.is_empty() {
204            return false;
205        }
206
207        let cwd = match env::current_dir() {
208            Ok(p) => p,
209            Err(_) => return false,
210        };
211
212        // Check if cwd is within any of the workspace paths
213        for path in ws_path.split(std::path::MAIN_SEPARATOR) {
214            let ws = PathBuf::from(path);
215            if cwd.starts_with(&ws) || ws.starts_with(&cwd) {
216                return true;
217            }
218        }
219
220        false
221    }
222
223    /// Establish HTTP connection and initialize MCP session
224    async fn establish_connection(&mut self) -> Result<(), IdeError> {
225        let port = self.port.ok_or(IdeError::ConnectionFailed("No port".to_string()))?;
226        let url = format!("http://127.0.0.1:{}/mcp", port);
227
228        // Build initialize request
229        let init_request = JsonRpcRequest::new(
230            self.next_request_id(),
231            "initialize",
232            serde_json::to_value(InitializeParams {
233                protocol_version: "2024-11-05".to_string(),
234                client_info: ClientInfo {
235                    name: "syncable-cli".to_string(),
236                    version: env!("CARGO_PKG_VERSION").to_string(),
237                },
238                capabilities: ClientCapabilities {},
239            })
240            .unwrap(),
241        );
242
243        // Send initialize request
244        let mut request = self.http_client
245            .post(&url)
246            .header("Accept", "application/json, text/event-stream")
247            .json(&init_request);
248
249        if let Some(token) = &self.auth_token {
250            request = request.header("Authorization", format!("Bearer {}", token));
251        }
252
253        let response = request
254            .send()
255            .await
256            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
257
258        // Get session ID from response header
259        if let Some(session_id) = response.headers().get("mcp-session-id") {
260            if let Ok(id) = session_id.to_str() {
261                *self.session_id.lock().unwrap() = Some(id.to_string());
262            }
263        }
264
265        // Parse response (SSE format: "event: message\ndata: {json}")
266        let response_text = response
267            .text()
268            .await
269            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
270
271        let response_data: JsonRpcResponse = Self::parse_sse_response(&response_text)
272            .map_err(IdeError::ConnectionFailed)?;
273
274        if response_data.error.is_some() {
275            return Err(IdeError::ConnectionFailed(
276                response_data
277                    .error
278                    .map(|e| e.message)
279                    .unwrap_or_default(),
280            ));
281        }
282
283        Ok(())
284    }
285
286    /// Parse SSE response format to extract JSON
287    fn parse_sse_response(text: &str) -> Result<JsonRpcResponse, String> {
288        // SSE format: "event: message\ndata: {json}\n\n"
289        for line in text.lines() {
290            if let Some(json_str) = line.strip_prefix("data: ") {
291                return serde_json::from_str(json_str)
292                    .map_err(|e| format!("Failed to parse JSON: {}", e));
293            }
294        }
295        // Fallback: try parsing entire response as JSON (for non-SSE responses)
296        serde_json::from_str(text)
297            .map_err(|e| format!("Failed to parse response: {}", e))
298    }
299
300    /// Get next request ID
301    fn next_request_id(&self) -> u64 {
302        let mut id = self.request_id.lock().unwrap();
303        *id += 1;
304        *id
305    }
306
307    /// Send an MCP request
308    async fn send_request(
309        &self,
310        method: &str,
311        params: serde_json::Value,
312    ) -> Result<JsonRpcResponse, IdeError> {
313        let port = self.port.ok_or(IdeError::ConnectionFailed("Not connected".to_string()))?;
314        let url = format!("http://127.0.0.1:{}/mcp", port);
315
316        let request = JsonRpcRequest::new(self.next_request_id(), method, params);
317
318        let mut http_request = self.http_client
319            .post(&url)
320            .header("Accept", "application/json, text/event-stream")
321            .json(&request);
322
323        if let Some(token) = &self.auth_token {
324            http_request = http_request.header("Authorization", format!("Bearer {}", token));
325        }
326
327        if let Some(session_id) = &*self.session_id.lock().unwrap() {
328            http_request = http_request.header("mcp-session-id", session_id);
329        }
330
331        let response = http_request
332            .send()
333            .await
334            .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
335
336        let response_text = response
337            .text()
338            .await
339            .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
340
341        Self::parse_sse_response(&response_text)
342            .map_err(IdeError::RequestFailed)
343    }
344
345    /// Open a diff view in the IDE
346    ///
347    /// This sends the file path and new content to the IDE, which will show
348    /// a diff view. The method returns when the user accepts or rejects the diff.
349    pub async fn open_diff(&self, file_path: &str, new_content: &str) -> Result<DiffResult, IdeError> {
350        if !self.is_connected() {
351            return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
352        }
353
354        let params = serde_json::to_value(ToolCallParams {
355            name: "openDiff".to_string(),
356            arguments: serde_json::to_value(OpenDiffArgs {
357                file_path: file_path.to_string(),
358                new_content: new_content.to_string(),
359            })
360            .unwrap(),
361        })
362        .unwrap();
363
364        // Create a channel to receive the diff result
365        let (tx, rx) = oneshot::channel();
366        {
367            let mut responses = self.diff_responses.lock().unwrap();
368            responses.insert(file_path.to_string(), tx);
369        }
370
371        // Send the openDiff request
372        let response = self.send_request("tools/call", params).await;
373
374        if let Err(e) = response {
375            // Remove the pending response
376            let mut responses = self.diff_responses.lock().unwrap();
377            responses.remove(file_path);
378            return Err(e);
379        }
380
381        // Wait for the notification (with timeout)
382        match tokio::time::timeout(Duration::from_secs(300), rx).await {
383            Ok(Ok(result)) => Ok(result),
384            Ok(Err(_)) => Err(IdeError::Cancelled),
385            Err(_) => {
386                // Timeout - remove pending response
387                let mut responses = self.diff_responses.lock().unwrap();
388                responses.remove(file_path);
389                Err(IdeError::NoResponse)
390            }
391        }
392    }
393
394    /// Close a diff view in the IDE
395    pub async fn close_diff(&self, file_path: &str) -> Result<Option<String>, IdeError> {
396        if !self.is_connected() {
397            return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
398        }
399
400        let params = serde_json::to_value(ToolCallParams {
401            name: "closeDiff".to_string(),
402            arguments: serde_json::to_value(CloseDiffArgs {
403                file_path: file_path.to_string(),
404                suppress_notification: Some(false),
405            })
406            .unwrap(),
407        })
408        .unwrap();
409
410        let response = self.send_request("tools/call", params).await?;
411
412        // Parse the response to get content if available
413        if let Some(result) = response.result {
414            if let Ok(tool_result) = serde_json::from_value::<ToolCallResult>(result) {
415                for content in tool_result.content {
416                    if content.content_type == "text" {
417                        if let Some(text) = content.text {
418                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
419                                if let Some(content) = parsed.get("content").and_then(|c| c.as_str())
420                                {
421                                    return Ok(Some(content.to_string()));
422                                }
423                            }
424                        }
425                    }
426                }
427            }
428        }
429
430        Ok(None)
431    }
432
433    /// Handle an incoming notification from the IDE
434    pub fn handle_notification(&self, notification: JsonRpcNotification) {
435        match notification.method.as_str() {
436            "ide/diffAccepted" => {
437                if let Ok(params) =
438                    serde_json::from_value::<IdeDiffAcceptedParams>(notification.params)
439                {
440                    let mut responses = self.diff_responses.lock().unwrap();
441                    if let Some(tx) = responses.remove(&params.file_path) {
442                        let _ = tx.send(DiffResult::Accepted {
443                            content: params.content,
444                        });
445                    }
446                }
447            }
448            "ide/diffRejected" | "ide/diffClosed" => {
449                if let Ok(params) =
450                    serde_json::from_value::<IdeDiffRejectedParams>(notification.params)
451                {
452                    let mut responses = self.diff_responses.lock().unwrap();
453                    if let Some(tx) = responses.remove(&params.file_path) {
454                        let _ = tx.send(DiffResult::Rejected);
455                    }
456                }
457            }
458            "ide/contextUpdate" => {
459                // Handle IDE context updates (e.g., open files)
460                // This could be used to show relevant context in the agent
461            }
462            _ => {
463                // Unknown notification
464            }
465        }
466    }
467
468    /// Disconnect from the IDE
469    pub async fn disconnect(&mut self) {
470        // Close any pending diffs
471        let pending: Vec<String> = {
472            let responses = self.diff_responses.lock().unwrap();
473            responses.keys().cloned().collect()
474        };
475
476        for file_path in pending {
477            let _ = self.close_diff(&file_path).await;
478        }
479
480        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
481        *self.session_id.lock().unwrap() = None;
482    }
483}
484
485impl Default for IdeClient {
486    fn default() -> Self {
487        // Create with blocking runtime for sync context
488        tokio::runtime::Handle::current().block_on(Self::new())
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_ide_client_creation() {
        let client = IdeClient::new().await;
        assert!(!client.is_connected());
        assert_eq!(client.status(), ConnectionStatus::Disconnected);
    }

    #[tokio::test]
    async fn test_diff_operations_require_connection() {
        let client = IdeClient::new().await;

        // Both calls must bail out before any network traffic.
        assert!(matches!(
            client.open_diff("src/main.rs", "fn main() {}").await,
            Err(IdeError::ConnectionFailed(_))
        ));
        assert!(matches!(
            client.close_diff("src/main.rs").await,
            Err(IdeError::ConnectionFailed(_))
        ));
    }

    #[test]
    fn test_diff_result() {
        let accepted = DiffResult::Accepted {
            content: "test".to_string(),
        };
        let DiffResult::Accepted { content } = accepted else {
            panic!("Expected Accepted");
        };
        assert_eq!(content, "test");
    }
}