syncable_cli/agent/ide/
client.rs

1//! MCP Client for IDE Communication
2//!
3//! Connects to the IDE's MCP server via HTTP SSE and provides methods
4//! for opening diffs and receiving notifications.
5
6use super::detect::{detect_ide, get_ide_process_info, IdeInfo, IdeProcessInfo};
7use super::types::*;
8use std::collections::HashMap;
9use std::env;
10use std::fs;
11use std::path::PathBuf;
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::sync::{mpsc, oneshot};
15
/// Outcome of a diff that was shown to the user in the IDE.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly, e.g. with
/// `assert_eq!` in tests or `==` in callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiffResult {
    /// User accepted the diff, possibly with edits; `content` is the text
    /// reported back by the IDE.
    Accepted { content: String },
    /// User rejected the diff
    Rejected,
}
24
/// IDE connection state
///
/// Equality on this fieldless enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// A session with the IDE server has been established.
    Connected,
    /// No session is active.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
}
32
/// Errors that can occur during IDE operations
///
/// The `#[error(...)]` attributes supply the `Display` text for each variant.
#[derive(Debug, thiserror::Error)]
pub enum IdeError {
    /// No IDE was detected, so there is nothing to connect to.
    #[error("IDE not detected")]
    NotDetected,
    /// The connection could not be established or is not active; the payload
    /// is a human-readable reason.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// A request to the IDE server failed; the payload is a human-readable
    /// reason.
    #[error("Request failed: {0}")]
    RequestFailed(String),
    /// Waiting for the user's decision on a diff timed out.
    #[error("No response received")]
    NoResponse,
    /// The channel carrying the diff result was closed before a result
    /// arrived.
    #[error("Operation cancelled")]
    Cancelled,
    /// Wrapped I/O error; `#[from]` provides the `From<std::io::Error>`
    /// conversion used by `?`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
49
/// MCP Client for IDE communication
///
/// Holds the HTTP client, the detected IDE metadata and the per-session
/// state (port, auth token, session id, pending diff channels).
#[derive(Debug)]
pub struct IdeClient {
    /// HTTP client used for every request to the IDE server
    http_client: reqwest::Client,
    /// Connection state, shared behind a mutex
    status: Arc<Mutex<ConnectionStatus>>,
    /// Detected IDE info (`None` when no IDE was detected)
    ide_info: Option<IdeInfo>,
    /// IDE process info; its PID is used to locate the port file
    process_info: Option<IdeProcessInfo>,
    /// Server port on 127.0.0.1, set during `connect`
    port: Option<u16>,
    /// Auth token sent as a `Bearer` authorization header when present
    auth_token: Option<String>,
    /// Session ID for MCP, taken from the `mcp-session-id` response header
    session_id: Arc<Mutex<Option<String>>>,
    /// Request ID counter; incremented before each request is built
    request_id: Arc<Mutex<u64>>,
    /// Pending diff responses, keyed by file path; each sender is resolved
    /// by `handle_notification`
    diff_responses: Arc<Mutex<HashMap<String, oneshot::Sender<DiffResult>>>>,
    /// SSE event receiver.
    /// NOTE(review): only ever initialized to `None` in the code visible
    /// here — confirm whether something else populates it.
    sse_receiver: Option<mpsc::Receiver<JsonRpcNotification>>,
}
74
75impl IdeClient {
76    /// Create a new IDE client (does not connect automatically)
77    pub async fn new() -> Self {
78        let process_info = get_ide_process_info().await;
79        let ide_info = detect_ide(process_info.as_ref());
80
81        Self {
82            http_client: reqwest::Client::builder()
83                .timeout(Duration::from_secs(30))
84                .build()
85                .unwrap_or_default(),
86            status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
87            ide_info,
88            process_info,
89            port: None,
90            auth_token: None,
91            session_id: Arc::new(Mutex::new(None)),
92            request_id: Arc::new(Mutex::new(0)),
93            diff_responses: Arc::new(Mutex::new(HashMap::new())),
94            sse_receiver: None,
95        }
96    }
97
98    /// Check if IDE integration is available
99    pub fn is_ide_available(&self) -> bool {
100        self.ide_info.is_some()
101    }
102
103    /// Get the detected IDE name
104    pub fn ide_name(&self) -> Option<&str> {
105        self.ide_info.as_ref().map(|i| i.display_name.as_str())
106    }
107
108    /// Check if connected to IDE
109    pub fn is_connected(&self) -> bool {
110        *self.status.lock().unwrap() == ConnectionStatus::Connected
111    }
112
113    /// Get connection status
114    pub fn status(&self) -> ConnectionStatus {
115        self.status.lock().unwrap().clone()
116    }
117
118    /// Try to connect to the IDE server
119    pub async fn connect(&mut self) -> Result<(), IdeError> {
120        if self.ide_info.is_none() {
121            return Err(IdeError::NotDetected);
122        }
123
124        *self.status.lock().unwrap() = ConnectionStatus::Connecting;
125
126        // Try to read connection config from file
127        if let Some(config) = self.read_connection_config().await {
128            self.port = Some(config.port);
129            self.auth_token = config.auth_token;
130
131            // Try to establish connection
132            if self.establish_connection().await.is_ok() {
133                *self.status.lock().unwrap() = ConnectionStatus::Connected;
134                return Ok(());
135            }
136        }
137
138        // Try environment variables as fallback
139        if let Ok(port_str) = env::var("SYNCABLE_CLI_IDE_SERVER_PORT") {
140            if let Ok(port) = port_str.parse::<u16>() {
141                self.port = Some(port);
142                self.auth_token = env::var("SYNCABLE_CLI_IDE_AUTH_TOKEN").ok();
143
144                if self.establish_connection().await.is_ok() {
145                    *self.status.lock().unwrap() = ConnectionStatus::Connected;
146                    return Ok(());
147                }
148            }
149        }
150
151        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
152        Err(IdeError::ConnectionFailed(
153            "Could not connect to IDE companion extension".to_string(),
154        ))
155    }
156
157    /// Read connection config from port file
158    /// Supports both Syncable and Gemini CLI companion extensions
159    async fn read_connection_config(&self) -> Option<ConnectionConfig> {
160        let process_info = self.process_info.as_ref()?;
161        let pid = process_info.pid;
162        let temp_dir = env::temp_dir();
163
164        // Try Syncable extension first
165        let syncable_port_dir = temp_dir.join("syncable").join("ide");
166        if let Some(config) = self.find_port_file(&syncable_port_dir, "syncable-ide-server", pid) {
167            return Some(config);
168        }
169
170        // Try Gemini CLI extension (for compatibility)
171        let gemini_port_dir = temp_dir.join("gemini").join("ide");
172        if let Some(config) = self.find_port_file(&gemini_port_dir, "gemini-ide-server", pid) {
173            return Some(config);
174        }
175
176        // Legacy Gemini format (single file in temp dir)
177        let legacy_gemini = temp_dir.join(format!("gemini-ide-server-{}.json", pid));
178        if let Ok(content) = fs::read_to_string(&legacy_gemini) {
179            if let Ok(config) = serde_json::from_str::<ConnectionConfig>(&content) {
180                if self.validate_workspace_path(&config.workspace_path) {
181                    return Some(config);
182                }
183            }
184        }
185
186        None
187    }
188
189    /// Find a port file in a directory matching the given prefix and PID
190    fn find_port_file(&self, dir: &PathBuf, prefix: &str, pid: u32) -> Option<ConnectionConfig> {
191        let entries = fs::read_dir(dir).ok()?;
192        let file_prefix = format!("{}-{}-", prefix, pid);
193
194        for entry in entries.flatten() {
195            let filename = entry.file_name().to_string_lossy().to_string();
196            if filename.starts_with(&file_prefix) && filename.ends_with(".json") {
197                if let Ok(content) = fs::read_to_string(entry.path()) {
198                    if let Ok(config) = serde_json::from_str::<ConnectionConfig>(&content) {
199                        if self.validate_workspace_path(&config.workspace_path) {
200                            return Some(config);
201                        }
202                    }
203                }
204            }
205        }
206        None
207    }
208
209    /// Validate that the workspace path matches our current directory
210    fn validate_workspace_path(&self, workspace_path: &Option<String>) -> bool {
211        let Some(ws_path) = workspace_path else {
212            return false;
213        };
214
215        if ws_path.is_empty() {
216            return false;
217        }
218
219        let cwd = match env::current_dir() {
220            Ok(p) => p,
221            Err(_) => return false,
222        };
223
224        // Check if cwd is within any of the workspace paths
225        for path in ws_path.split(std::path::MAIN_SEPARATOR) {
226            let ws = PathBuf::from(path);
227            if cwd.starts_with(&ws) || ws.starts_with(&cwd) {
228                return true;
229            }
230        }
231
232        false
233    }
234
235    /// Establish HTTP connection and initialize MCP session
236    async fn establish_connection(&mut self) -> Result<(), IdeError> {
237        let port = self.port.ok_or(IdeError::ConnectionFailed("No port".to_string()))?;
238        let url = format!("http://127.0.0.1:{}/mcp", port);
239
240        // Build initialize request
241        let init_request = JsonRpcRequest::new(
242            self.next_request_id(),
243            "initialize",
244            serde_json::to_value(InitializeParams {
245                protocol_version: "2024-11-05".to_string(),
246                client_info: ClientInfo {
247                    name: "syncable-cli".to_string(),
248                    version: env!("CARGO_PKG_VERSION").to_string(),
249                },
250                capabilities: ClientCapabilities {},
251            })
252            .unwrap(),
253        );
254
255        // Send initialize request
256        let mut request = self.http_client.post(&url).json(&init_request);
257
258        if let Some(token) = &self.auth_token {
259            request = request.header("Authorization", format!("Bearer {}", token));
260        }
261
262        let response = request
263            .send()
264            .await
265            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
266
267        // Get session ID from response header
268        if let Some(session_id) = response.headers().get("mcp-session-id") {
269            if let Ok(id) = session_id.to_str() {
270                *self.session_id.lock().unwrap() = Some(id.to_string());
271            }
272        }
273
274        // Parse response
275        let response_data: JsonRpcResponse = response
276            .json()
277            .await
278            .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
279
280        if response_data.error.is_some() {
281            return Err(IdeError::ConnectionFailed(
282                response_data
283                    .error
284                    .map(|e| e.message)
285                    .unwrap_or_default(),
286            ));
287        }
288
289        Ok(())
290    }
291
292    /// Get next request ID
293    fn next_request_id(&self) -> u64 {
294        let mut id = self.request_id.lock().unwrap();
295        *id += 1;
296        *id
297    }
298
299    /// Send an MCP request
300    async fn send_request(
301        &self,
302        method: &str,
303        params: serde_json::Value,
304    ) -> Result<JsonRpcResponse, IdeError> {
305        let port = self.port.ok_or(IdeError::ConnectionFailed("Not connected".to_string()))?;
306        let url = format!("http://127.0.0.1:{}/mcp", port);
307
308        let request = JsonRpcRequest::new(self.next_request_id(), method, params);
309
310        let mut http_request = self.http_client.post(&url).json(&request);
311
312        if let Some(token) = &self.auth_token {
313            http_request = http_request.header("Authorization", format!("Bearer {}", token));
314        }
315
316        if let Some(session_id) = &*self.session_id.lock().unwrap() {
317            http_request = http_request.header("mcp-session-id", session_id);
318        }
319
320        let response = http_request
321            .send()
322            .await
323            .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
324
325        response
326            .json()
327            .await
328            .map_err(|e| IdeError::RequestFailed(e.to_string()))
329    }
330
331    /// Open a diff view in the IDE
332    ///
333    /// This sends the file path and new content to the IDE, which will show
334    /// a diff view. The method returns when the user accepts or rejects the diff.
335    pub async fn open_diff(&self, file_path: &str, new_content: &str) -> Result<DiffResult, IdeError> {
336        if !self.is_connected() {
337            return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
338        }
339
340        let params = serde_json::to_value(ToolCallParams {
341            name: "openDiff".to_string(),
342            arguments: serde_json::to_value(OpenDiffArgs {
343                file_path: file_path.to_string(),
344                new_content: new_content.to_string(),
345            })
346            .unwrap(),
347        })
348        .unwrap();
349
350        // Create a channel to receive the diff result
351        let (tx, rx) = oneshot::channel();
352        {
353            let mut responses = self.diff_responses.lock().unwrap();
354            responses.insert(file_path.to_string(), tx);
355        }
356
357        // Send the openDiff request
358        let response = self.send_request("tools/call", params).await;
359
360        if let Err(e) = response {
361            // Remove the pending response
362            let mut responses = self.diff_responses.lock().unwrap();
363            responses.remove(file_path);
364            return Err(e);
365        }
366
367        // Wait for the notification (with timeout)
368        match tokio::time::timeout(Duration::from_secs(300), rx).await {
369            Ok(Ok(result)) => Ok(result),
370            Ok(Err(_)) => Err(IdeError::Cancelled),
371            Err(_) => {
372                // Timeout - remove pending response
373                let mut responses = self.diff_responses.lock().unwrap();
374                responses.remove(file_path);
375                Err(IdeError::NoResponse)
376            }
377        }
378    }
379
380    /// Close a diff view in the IDE
381    pub async fn close_diff(&self, file_path: &str) -> Result<Option<String>, IdeError> {
382        if !self.is_connected() {
383            return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
384        }
385
386        let params = serde_json::to_value(ToolCallParams {
387            name: "closeDiff".to_string(),
388            arguments: serde_json::to_value(CloseDiffArgs {
389                file_path: file_path.to_string(),
390                suppress_notification: Some(false),
391            })
392            .unwrap(),
393        })
394        .unwrap();
395
396        let response = self.send_request("tools/call", params).await?;
397
398        // Parse the response to get content if available
399        if let Some(result) = response.result {
400            if let Ok(tool_result) = serde_json::from_value::<ToolCallResult>(result) {
401                for content in tool_result.content {
402                    if content.content_type == "text" {
403                        if let Some(text) = content.text {
404                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
405                                if let Some(content) = parsed.get("content").and_then(|c| c.as_str())
406                                {
407                                    return Ok(Some(content.to_string()));
408                                }
409                            }
410                        }
411                    }
412                }
413            }
414        }
415
416        Ok(None)
417    }
418
419    /// Handle an incoming notification from the IDE
420    pub fn handle_notification(&self, notification: JsonRpcNotification) {
421        match notification.method.as_str() {
422            "ide/diffAccepted" => {
423                if let Ok(params) =
424                    serde_json::from_value::<IdeDiffAcceptedParams>(notification.params)
425                {
426                    let mut responses = self.diff_responses.lock().unwrap();
427                    if let Some(tx) = responses.remove(&params.file_path) {
428                        let _ = tx.send(DiffResult::Accepted {
429                            content: params.content,
430                        });
431                    }
432                }
433            }
434            "ide/diffRejected" | "ide/diffClosed" => {
435                if let Ok(params) =
436                    serde_json::from_value::<IdeDiffRejectedParams>(notification.params)
437                {
438                    let mut responses = self.diff_responses.lock().unwrap();
439                    if let Some(tx) = responses.remove(&params.file_path) {
440                        let _ = tx.send(DiffResult::Rejected);
441                    }
442                }
443            }
444            "ide/contextUpdate" => {
445                // Handle IDE context updates (e.g., open files)
446                // This could be used to show relevant context in the agent
447            }
448            _ => {
449                // Unknown notification
450            }
451        }
452    }
453
454    /// Disconnect from the IDE
455    pub async fn disconnect(&mut self) {
456        // Close any pending diffs
457        let pending: Vec<String> = {
458            let responses = self.diff_responses.lock().unwrap();
459            responses.keys().cloned().collect()
460        };
461
462        for file_path in pending {
463            let _ = self.close_diff(&file_path).await;
464        }
465
466        *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
467        *self.session_id.lock().unwrap() = None;
468    }
469}
470
471impl Default for IdeClient {
472    fn default() -> Self {
473        // Create with blocking runtime for sync context
474        tokio::runtime::Handle::current().block_on(Self::new())
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client must not report itself as connected.
    #[tokio::test]
    async fn test_ide_client_creation() {
        let client = IdeClient::new().await;
        let connected = client.is_connected();
        assert!(!connected);
    }

    /// The `Accepted` variant carries the content it was built with.
    #[test]
    fn test_diff_result() {
        let accepted = DiffResult::Accepted {
            content: "test".to_string(),
        };
        let DiffResult::Accepted { content } = accepted else {
            panic!("Expected Accepted");
        };
        assert_eq!(content, "test");
    }
}