1use super::detect::{detect_ide, get_ide_process_info, IdeInfo, IdeProcessInfo};
7use super::types::*;
8use std::collections::HashMap;
9use std::env;
10use std::fs;
11use std::path::PathBuf;
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::sync::{mpsc, oneshot};
15
/// Outcome of a diff that was shown to the user in the IDE.
#[derive(Clone, Debug)]
pub enum DiffResult {
    /// The diff was accepted.
    Accepted {
        /// Content delivered together with the acceptance.
        content: String,
    },
    /// The diff was rejected or closed without being accepted.
    Rejected,
}
24
/// Lifecycle state of the link between the CLI and the IDE.
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectionStatus {
    /// The handshake with the IDE succeeded.
    Connected,
    /// No connection is established.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
}
32
33#[derive(Debug, thiserror::Error)]
35pub enum IdeError {
36 #[error("IDE not detected")]
37 NotDetected,
38 #[error("Connection failed: {0}")]
39 ConnectionFailed(String),
40 #[error("Request failed: {0}")]
41 RequestFailed(String),
42 #[error("No response received")]
43 NoResponse,
44 #[error("Operation cancelled")]
45 Cancelled,
46 #[error("IO error: {0}")]
47 Io(#[from] std::io::Error),
48}
49
50#[derive(Debug)]
52pub struct IdeClient {
53 http_client: reqwest::Client,
55 status: Arc<Mutex<ConnectionStatus>>,
57 ide_info: Option<IdeInfo>,
59 process_info: Option<IdeProcessInfo>,
61 port: Option<u16>,
63 auth_token: Option<String>,
65 session_id: Arc<Mutex<Option<String>>>,
67 request_id: Arc<Mutex<u64>>,
69 diff_responses: Arc<Mutex<HashMap<String, oneshot::Sender<DiffResult>>>>,
71 sse_receiver: Option<mpsc::Receiver<JsonRpcNotification>>,
73}
74
75impl IdeClient {
76 pub async fn new() -> Self {
78 let process_info = get_ide_process_info().await;
79 let ide_info = detect_ide(process_info.as_ref());
80
81 Self {
82 http_client: reqwest::Client::builder()
83 .timeout(Duration::from_secs(30))
84 .build()
85 .unwrap_or_default(),
86 status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
87 ide_info,
88 process_info,
89 port: None,
90 auth_token: None,
91 session_id: Arc::new(Mutex::new(None)),
92 request_id: Arc::new(Mutex::new(0)),
93 diff_responses: Arc::new(Mutex::new(HashMap::new())),
94 sse_receiver: None,
95 }
96 }
97
98 pub fn is_ide_available(&self) -> bool {
100 self.ide_info.is_some()
101 }
102
103 pub fn ide_name(&self) -> Option<&str> {
105 self.ide_info.as_ref().map(|i| i.display_name.as_str())
106 }
107
108 pub fn is_connected(&self) -> bool {
110 *self.status.lock().unwrap() == ConnectionStatus::Connected
111 }
112
113 pub fn status(&self) -> ConnectionStatus {
115 self.status.lock().unwrap().clone()
116 }
117
118 pub async fn connect(&mut self) -> Result<(), IdeError> {
120 if self.ide_info.is_none() {
121 return Err(IdeError::NotDetected);
122 }
123
124 *self.status.lock().unwrap() = ConnectionStatus::Connecting;
125
126 if let Some(config) = self.read_connection_config().await {
128 self.port = Some(config.port);
129 self.auth_token = config.auth_token.clone();
130
131 if self.establish_connection().await.is_ok() {
133 *self.status.lock().unwrap() = ConnectionStatus::Connected;
134 return Ok(());
135 }
136 }
137
138 if let Ok(port_str) = env::var("SYNCABLE_CLI_IDE_SERVER_PORT") {
140 if let Ok(port) = port_str.parse::<u16>() {
141 self.port = Some(port);
142 self.auth_token = env::var("SYNCABLE_CLI_IDE_AUTH_TOKEN").ok();
143
144 if self.establish_connection().await.is_ok() {
145 *self.status.lock().unwrap() = ConnectionStatus::Connected;
146 return Ok(());
147 }
148 }
149 }
150
151 *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
152 Err(IdeError::ConnectionFailed(
153 "Could not connect to IDE companion extension".to_string(),
154 ))
155 }
156
157 async fn read_connection_config(&self) -> Option<ConnectionConfig> {
160 let temp_dir = env::temp_dir();
161
162 let syncable_port_dir = temp_dir.join("syncable").join("ide");
164 if let Some(config) = self.find_port_file_by_workspace(&syncable_port_dir, "syncable-ide-server") {
165 return Some(config);
166 }
167
168 let gemini_port_dir = temp_dir.join("gemini").join("ide");
170 if let Some(config) = self.find_port_file_by_workspace(&gemini_port_dir, "gemini-ide-server") {
171 return Some(config);
172 }
173
174 None
175 }
176
177 fn find_port_file_by_workspace(&self, dir: &PathBuf, prefix: &str) -> Option<ConnectionConfig> {
179 let entries = fs::read_dir(dir).ok()?;
180
181 for entry in entries.flatten() {
182 let filename = entry.file_name().to_string_lossy().to_string();
183 if filename.starts_with(prefix) && filename.ends_with(".json") {
185 if let Ok(content) = fs::read_to_string(entry.path()) {
186 if let Ok(config) = serde_json::from_str::<ConnectionConfig>(&content) {
187 if self.validate_workspace_path(&config.workspace_path) {
188 return Some(config);
189 }
190 }
191 }
192 }
193 }
194 None
195 }
196
197 fn validate_workspace_path(&self, workspace_path: &Option<String>) -> bool {
199 let Some(ws_path) = workspace_path else {
200 return false;
201 };
202
203 if ws_path.is_empty() {
204 return false;
205 }
206
207 let cwd = match env::current_dir() {
208 Ok(p) => p,
209 Err(_) => return false,
210 };
211
212 for path in ws_path.split(std::path::MAIN_SEPARATOR) {
214 let ws = PathBuf::from(path);
215 if cwd.starts_with(&ws) || ws.starts_with(&cwd) {
216 return true;
217 }
218 }
219
220 false
221 }
222
223 async fn establish_connection(&mut self) -> Result<(), IdeError> {
225 let port = self.port.ok_or(IdeError::ConnectionFailed("No port".to_string()))?;
226 let url = format!("http://127.0.0.1:{}/mcp", port);
227
228 let init_request = JsonRpcRequest::new(
230 self.next_request_id(),
231 "initialize",
232 serde_json::to_value(InitializeParams {
233 protocol_version: "2024-11-05".to_string(),
234 client_info: ClientInfo {
235 name: "syncable-cli".to_string(),
236 version: env!("CARGO_PKG_VERSION").to_string(),
237 },
238 capabilities: ClientCapabilities {},
239 })
240 .unwrap(),
241 );
242
243 let mut request = self.http_client
245 .post(&url)
246 .header("Accept", "application/json, text/event-stream")
247 .json(&init_request);
248
249 if let Some(token) = &self.auth_token {
250 request = request.header("Authorization", format!("Bearer {}", token));
251 }
252
253 let response = request
254 .send()
255 .await
256 .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
257
258 if let Some(session_id) = response.headers().get("mcp-session-id") {
260 if let Ok(id) = session_id.to_str() {
261 *self.session_id.lock().unwrap() = Some(id.to_string());
262 }
263 }
264
265 let response_text = response
267 .text()
268 .await
269 .map_err(|e| IdeError::ConnectionFailed(e.to_string()))?;
270
271 let response_data: JsonRpcResponse = Self::parse_sse_response(&response_text)
272 .map_err(IdeError::ConnectionFailed)?;
273
274 if response_data.error.is_some() {
275 return Err(IdeError::ConnectionFailed(
276 response_data
277 .error
278 .map(|e| e.message)
279 .unwrap_or_default(),
280 ));
281 }
282
283 Ok(())
284 }
285
286 fn parse_sse_response(text: &str) -> Result<JsonRpcResponse, String> {
288 for line in text.lines() {
290 if let Some(json_str) = line.strip_prefix("data: ") {
291 return serde_json::from_str(json_str)
292 .map_err(|e| format!("Failed to parse JSON: {}", e));
293 }
294 }
295 serde_json::from_str(text)
297 .map_err(|e| format!("Failed to parse response: {}", e))
298 }
299
300 fn next_request_id(&self) -> u64 {
302 let mut id = self.request_id.lock().unwrap();
303 *id += 1;
304 *id
305 }
306
307 async fn send_request(
309 &self,
310 method: &str,
311 params: serde_json::Value,
312 ) -> Result<JsonRpcResponse, IdeError> {
313 let port = self.port.ok_or(IdeError::ConnectionFailed("Not connected".to_string()))?;
314 let url = format!("http://127.0.0.1:{}/mcp", port);
315
316 let request = JsonRpcRequest::new(self.next_request_id(), method, params);
317
318 let mut http_request = self.http_client
319 .post(&url)
320 .header("Accept", "application/json, text/event-stream")
321 .json(&request);
322
323 if let Some(token) = &self.auth_token {
324 http_request = http_request.header("Authorization", format!("Bearer {}", token));
325 }
326
327 if let Some(session_id) = &*self.session_id.lock().unwrap() {
328 http_request = http_request.header("mcp-session-id", session_id);
329 }
330
331 let response = http_request
332 .send()
333 .await
334 .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
335
336 let response_text = response
337 .text()
338 .await
339 .map_err(|e| IdeError::RequestFailed(e.to_string()))?;
340
341 Self::parse_sse_response(&response_text)
342 .map_err(IdeError::RequestFailed)
343 }
344
345 pub async fn open_diff(&self, file_path: &str, new_content: &str) -> Result<DiffResult, IdeError> {
350 if !self.is_connected() {
351 return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
352 }
353
354 let params = serde_json::to_value(ToolCallParams {
355 name: "openDiff".to_string(),
356 arguments: serde_json::to_value(OpenDiffArgs {
357 file_path: file_path.to_string(),
358 new_content: new_content.to_string(),
359 })
360 .unwrap(),
361 })
362 .unwrap();
363
364 let (tx, rx) = oneshot::channel();
366 {
367 let mut responses = self.diff_responses.lock().unwrap();
368 responses.insert(file_path.to_string(), tx);
369 }
370
371 let response = self.send_request("tools/call", params).await;
373
374 if let Err(e) = response {
375 let mut responses = self.diff_responses.lock().unwrap();
377 responses.remove(file_path);
378 return Err(e);
379 }
380
381 match tokio::time::timeout(Duration::from_secs(300), rx).await {
383 Ok(Ok(result)) => Ok(result),
384 Ok(Err(_)) => Err(IdeError::Cancelled),
385 Err(_) => {
386 let mut responses = self.diff_responses.lock().unwrap();
388 responses.remove(file_path);
389 Err(IdeError::NoResponse)
390 }
391 }
392 }
393
394 pub async fn close_diff(&self, file_path: &str) -> Result<Option<String>, IdeError> {
396 if !self.is_connected() {
397 return Err(IdeError::ConnectionFailed("Not connected to IDE".to_string()));
398 }
399
400 let params = serde_json::to_value(ToolCallParams {
401 name: "closeDiff".to_string(),
402 arguments: serde_json::to_value(CloseDiffArgs {
403 file_path: file_path.to_string(),
404 suppress_notification: Some(false),
405 })
406 .unwrap(),
407 })
408 .unwrap();
409
410 let response = self.send_request("tools/call", params).await?;
411
412 if let Some(result) = response.result {
414 if let Ok(tool_result) = serde_json::from_value::<ToolCallResult>(result) {
415 for content in tool_result.content {
416 if content.content_type == "text" {
417 if let Some(text) = content.text {
418 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
419 if let Some(content) = parsed.get("content").and_then(|c| c.as_str())
420 {
421 return Ok(Some(content.to_string()));
422 }
423 }
424 }
425 }
426 }
427 }
428 }
429
430 Ok(None)
431 }
432
433 pub fn handle_notification(&self, notification: JsonRpcNotification) {
435 match notification.method.as_str() {
436 "ide/diffAccepted" => {
437 if let Ok(params) =
438 serde_json::from_value::<IdeDiffAcceptedParams>(notification.params)
439 {
440 let mut responses = self.diff_responses.lock().unwrap();
441 if let Some(tx) = responses.remove(¶ms.file_path) {
442 let _ = tx.send(DiffResult::Accepted {
443 content: params.content,
444 });
445 }
446 }
447 }
448 "ide/diffRejected" | "ide/diffClosed" => {
449 if let Ok(params) =
450 serde_json::from_value::<IdeDiffRejectedParams>(notification.params)
451 {
452 let mut responses = self.diff_responses.lock().unwrap();
453 if let Some(tx) = responses.remove(¶ms.file_path) {
454 let _ = tx.send(DiffResult::Rejected);
455 }
456 }
457 }
458 "ide/contextUpdate" => {
459 }
462 _ => {
463 }
465 }
466 }
467
468 pub async fn disconnect(&mut self) {
470 let pending: Vec<String> = {
472 let responses = self.diff_responses.lock().unwrap();
473 responses.keys().cloned().collect()
474 };
475
476 for file_path in pending {
477 let _ = self.close_diff(&file_path).await;
478 }
479
480 *self.status.lock().unwrap() = ConnectionStatus::Disconnected;
481 *self.session_id.lock().unwrap() = None;
482 }
483}
484
485impl Default for IdeClient {
486 fn default() -> Self {
487 tokio::runtime::Handle::current().block_on(Self::new())
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client must not report itself as connected.
    #[tokio::test]
    async fn test_ide_client_creation() {
        let client = IdeClient::new().await;
        let connected = client.is_connected();
        assert!(!connected);
    }

    /// The `Accepted` variant carries the content it was built with.
    #[test]
    fn test_diff_result() {
        let accepted = DiffResult::Accepted {
            content: "test".to_string(),
        };
        let DiffResult::Accepted { content } = accepted else {
            panic!("Expected Accepted");
        };
        assert_eq!(content, "test");
    }
}