Skip to main content

zap/
gateway.rs

1//! ZAP MCP Gateway - Full Implementation
2//!
3//! A gateway that bridges multiple MCP servers, providing:
4//! - Multi-transport support (stdio, HTTP/SSE, WebSocket)
5//! - Tool/resource/prompt aggregation across servers
6//! - Request routing to correct backend servers
7//! - Health checking and automatic reconnection
8//! - Server lifecycle management
9//!
10//! This module implements MCP (Model Context Protocol) gateway functionality
11//! allowing ZAP to act as a unified interface to multiple MCP servers.
12
13use crate::{Config, Result, Error, config::{ServerConfig, Transport, Auth}};
14use serde::{Deserialize, Serialize};
15use serde_json::{json, Value};
16use std::collections::HashMap;
17use std::process::Stdio;
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
21use tokio::process::{Child, Command};
22use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
23use tokio::time::{interval, timeout};
24
25// ============================================================================
26// MCP Protocol Types
27// ============================================================================
28
29/// JSON-RPC 2.0 request
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct JsonRpcRequest {
32    pub jsonrpc: String,
33    pub id: Value,
34    pub method: String,
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub params: Option<Value>,
37}
38
39/// JSON-RPC 2.0 response
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct JsonRpcResponse {
42    pub jsonrpc: String,
43    pub id: Value,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub result: Option<Value>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub error: Option<JsonRpcError>,
48}
49
50/// JSON-RPC 2.0 notification (no id)
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct JsonRpcNotification {
53    pub jsonrpc: String,
54    pub method: String,
55    #[serde(default, skip_serializing_if = "Option::is_none")]
56    pub params: Option<Value>,
57}
58
59/// JSON-RPC 2.0 error
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct JsonRpcError {
62    pub code: i64,
63    pub message: String,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub data: Option<Value>,
66}
67
68/// MCP Tool definition
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct McpTool {
71    pub name: String,
72    #[serde(skip_serializing_if = "Option::is_none")]
73    pub title: Option<String>,
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub description: Option<String>,
76    #[serde(rename = "inputSchema")]
77    pub input_schema: Value,
78}
79
80/// MCP Resource definition
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct McpResource {
83    pub uri: String,
84    pub name: String,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub description: Option<String>,
87    #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
88    pub mime_type: Option<String>,
89}
90
91/// MCP Prompt definition
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct McpPrompt {
94    pub name: String,
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub description: Option<String>,
97    #[serde(default)]
98    pub arguments: Vec<McpPromptArgument>,
99}
100
101/// MCP Prompt argument
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct McpPromptArgument {
104    pub name: String,
105    #[serde(skip_serializing_if = "Option::is_none")]
106    pub description: Option<String>,
107    #[serde(default)]
108    pub required: bool,
109}
110
111/// MCP Server capabilities
112#[derive(Debug, Clone, Default, Serialize, Deserialize)]
113pub struct McpCapabilities {
114    #[serde(default)]
115    pub tools: Option<ToolsCapability>,
116    #[serde(default)]
117    pub resources: Option<ResourcesCapability>,
118    #[serde(default)]
119    pub prompts: Option<PromptsCapability>,
120    #[serde(default)]
121    pub logging: Option<Value>,
122}
123
124#[derive(Debug, Clone, Default, Serialize, Deserialize)]
125pub struct ToolsCapability {
126    #[serde(rename = "listChanged", default)]
127    pub list_changed: bool,
128}
129
130#[derive(Debug, Clone, Default, Serialize, Deserialize)]
131pub struct ResourcesCapability {
132    #[serde(rename = "listChanged", default)]
133    pub list_changed: bool,
134    #[serde(default)]
135    pub subscribe: bool,
136}
137
138#[derive(Debug, Clone, Default, Serialize, Deserialize)]
139pub struct PromptsCapability {
140    #[serde(rename = "listChanged", default)]
141    pub list_changed: bool,
142}
143
144/// MCP Server info from initialization
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct McpServerInfo {
147    pub name: String,
148    #[serde(default)]
149    pub version: String,
150}
151
152// ============================================================================
153// Server Connection Status
154// ============================================================================
155
/// Connection state of a backend server.
///
/// The `Display` implementation renders each variant as its lowercase name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerStatus {
    /// Rendered as `connecting`.
    Connecting,
    /// Rendered as `connected`.
    Connected,
    /// Rendered as `disconnected`.
    Disconnected,
    /// Rendered as `error`.
    Error,
    /// Rendered as `reconnecting`.
    Reconnecting,
}

impl std::fmt::Display for ServerStatus {
    /// Writes the lowercase label for this status.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Connecting => "connecting",
            Self::Connected => "connected",
            Self::Disconnected => "disconnected",
            Self::Error => "error",
            Self::Reconnecting => "reconnecting",
        };
        f.write_str(label)
    }
}
177
178// ============================================================================
179// Stdio Transport
180// ============================================================================
181
182/// Stdio transport for subprocess MCP servers
183pub struct StdioTransport {
184    stdin: Arc<Mutex<tokio::process::ChildStdin>>,
185    pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
186    connected: Arc<std::sync::atomic::AtomicBool>,
187    _child: Arc<Mutex<Child>>,
188}
189
190impl StdioTransport {
191    /// Spawn a subprocess and connect via stdio
192    pub async fn spawn(command: &str, args: &[String], env: Option<&HashMap<String, String>>) -> Result<Self> {
193        let mut cmd = Command::new(command);
194        cmd.args(args)
195            .stdin(Stdio::piped())
196            .stdout(Stdio::piped())
197            .stderr(Stdio::piped());
198
199        if let Some(env_vars) = env {
200            for (k, v) in env_vars {
201                cmd.env(k, v);
202            }
203        }
204
205        let mut child = cmd.spawn()
206            .map_err(|e| Error::Transport(format!("failed to spawn {}: {}", command, e)))?;
207
208        let stdin = child.stdin.take()
209            .ok_or_else(|| Error::Transport("failed to get stdin".into()))?;
210        let stdout = child.stdout.take()
211            .ok_or_else(|| Error::Transport("failed to get stdout".into()))?;
212
213        let pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
214            Arc::new(RwLock::new(HashMap::new()));
215        let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
216
217        // Spawn reader task
218        let pending_clone = pending.clone();
219        let connected_clone = connected.clone();
220        tokio::spawn(async move {
221            let mut reader = BufReader::new(stdout).lines();
222            while let Ok(Some(line)) = reader.next_line().await {
223                if line.is_empty() {
224                    continue;
225                }
226
227                match serde_json::from_str::<JsonRpcResponse>(&line) {
228                    Ok(response) => {
229                        let id_str = match &response.id {
230                            Value::Number(n) => n.to_string(),
231                            Value::String(s) => s.clone(),
232                            _ => continue,
233                        };
234
235                        let mut pending = pending_clone.write().await;
236                        if let Some(tx) = pending.remove(&id_str) {
237                            let _ = tx.send(response);
238                        }
239                    }
240                    Err(e) => {
241                        tracing::debug!("Failed to parse response: {} - line: {}", e, line);
242                    }
243                }
244            }
245            connected_clone.store(false, std::sync::atomic::Ordering::SeqCst);
246        });
247
248        Ok(Self {
249            stdin: Arc::new(Mutex::new(stdin)),
250            pending,
251            connected,
252            _child: Arc::new(Mutex::new(child)),
253        })
254    }
255
256    pub async fn request(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse> {
257        let id_str = match &req.id {
258            Value::Number(n) => n.to_string(),
259            Value::String(s) => s.clone(),
260            _ => return Err(Error::Protocol("invalid request id".into())),
261        };
262
263        let (tx, rx) = oneshot::channel();
264
265        {
266            let mut pending = self.pending.write().await;
267            pending.insert(id_str.clone(), tx);
268        }
269
270        let line = serde_json::to_string(&req)? + "\n";
271        {
272            let mut stdin = self.stdin.lock().await;
273            stdin.write_all(line.as_bytes()).await
274                .map_err(|e| Error::Transport(format!("write failed: {}", e)))?;
275            stdin.flush().await
276                .map_err(|e| Error::Transport(format!("flush failed: {}", e)))?;
277        }
278
279        match timeout(Duration::from_secs(30), rx).await {
280            Ok(Ok(response)) => Ok(response),
281            Ok(Err(_)) => Err(Error::Transport("response channel closed".into())),
282            Err(_) => {
283                let mut pending = self.pending.write().await;
284                pending.remove(&id_str);
285                Err(Error::Transport("request timeout".into()))
286            }
287        }
288    }
289
290    pub async fn notify(&self, notif: JsonRpcNotification) -> Result<()> {
291        let line = serde_json::to_string(&notif)? + "\n";
292        let mut stdin = self.stdin.lock().await;
293        stdin.write_all(line.as_bytes()).await
294            .map_err(|e| Error::Transport(format!("write failed: {}", e)))?;
295        stdin.flush().await
296            .map_err(|e| Error::Transport(format!("flush failed: {}", e)))?;
297        Ok(())
298    }
299
300    pub async fn close(&self) -> Result<()> {
301        let mut child = self._child.lock().await;
302        let _ = child.kill().await;
303        self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
304        Ok(())
305    }
306
307    pub fn is_connected(&self) -> bool {
308        self.connected.load(std::sync::atomic::Ordering::SeqCst)
309    }
310}
311
312// ============================================================================
313// HTTP Transport (using hyper directly)
314// ============================================================================
315
316/// HTTP transport with optional SSE support
317pub struct HttpTransport {
318    endpoint: String,
319    session_id: Arc<RwLock<Option<String>>>,
320    auth: Option<Auth>,
321    connected: Arc<std::sync::atomic::AtomicBool>,
322}
323
324impl HttpTransport {
325    pub fn new(endpoint: &str, auth: Option<Auth>) -> Result<Self> {
326        Ok(Self {
327            endpoint: endpoint.to_string(),
328            session_id: Arc::new(RwLock::new(None)),
329            auth,
330            connected: Arc::new(std::sync::atomic::AtomicBool::new(true)),
331        })
332    }
333
334    pub async fn request(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse> {
335        use http_body_util::{BodyExt, Full};
336        use hyper::body::Bytes;
337        use hyper::Request;
338        use hyper_util::client::legacy::Client;
339        use hyper_util::rt::TokioExecutor;
340
341        let body_json = serde_json::to_string(&req)?;
342
343        let uri: hyper::Uri = self.endpoint.parse()
344            .map_err(|e| Error::Transport(format!("invalid URI: {}", e)))?;
345
346        let mut request_builder = Request::builder()
347            .method("POST")
348            .uri(&uri)
349            .header("Content-Type", "application/json")
350            .header("Accept", "application/json, text/event-stream");
351
352        if let Some(ref auth) = self.auth {
353            match auth {
354                Auth::Bearer { token } => {
355                    request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
356                }
357                Auth::Basic { username, password } => {
358                    let credentials = format!("{}:{}", username, password);
359                    let encoded = hex::encode(credentials.as_bytes());
360                    request_builder = request_builder.header("Authorization", format!("Basic {}", encoded));
361                }
362            }
363        }
364
365        if let Some(ref sid) = *self.session_id.read().await {
366            request_builder = request_builder.header("Mcp-Session-Id", sid.as_str());
367        }
368
369        let request = request_builder
370            .body(Full::new(Bytes::from(body_json)))
371            .map_err(|e| Error::Transport(format!("failed to build request: {}", e)))?;
372
373        let https = hyper_util::client::legacy::connect::HttpConnector::new();
374        let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build(https);
375
376        let response = client.request(request).await
377            .map_err(|e| Error::Transport(format!("HTTP request failed: {}", e)))?;
378
379        if let Some(sid) = response.headers().get("Mcp-Session-Id") {
380            if let Ok(sid_str) = sid.to_str() {
381                *self.session_id.write().await = Some(sid_str.to_string());
382            }
383        }
384
385        let status = response.status();
386        if !status.is_success() {
387            self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
388            return Err(Error::Transport(format!("HTTP error: {}", status)));
389        }
390
391        let body_bytes = response.into_body().collect().await
392            .map_err(|e| Error::Transport(format!("failed to read response: {}", e)))?
393            .to_bytes();
394
395        let body = String::from_utf8_lossy(&body_bytes);
396
397        let json_str = if body.starts_with("data:") {
398            body.lines()
399                .filter(|l| l.starts_with("data:"))
400                .last()
401                .map(|l| l.trim_start_matches("data:").trim())
402                .unwrap_or(&body)
403        } else {
404            &body
405        };
406
407        serde_json::from_str(json_str)
408            .map_err(|e| Error::Protocol(format!("invalid JSON response: {}", e)))
409    }
410
411    pub async fn notify(&self, notif: JsonRpcNotification) -> Result<()> {
412        use http_body_util::Full;
413        use hyper::body::Bytes;
414        use hyper::Request;
415        use hyper_util::client::legacy::Client;
416        use hyper_util::rt::TokioExecutor;
417
418        let body_json = serde_json::to_string(&notif)?;
419        let uri: hyper::Uri = self.endpoint.parse()
420            .map_err(|e| Error::Transport(format!("invalid URI: {}", e)))?;
421
422        let mut request_builder = Request::builder()
423            .method("POST")
424            .uri(&uri)
425            .header("Content-Type", "application/json");
426
427        if let Some(ref auth) = self.auth {
428            match auth {
429                Auth::Bearer { token } => {
430                    request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
431                }
432                Auth::Basic { username, password } => {
433                    let credentials = format!("{}:{}", username, password);
434                    let encoded = hex::encode(credentials.as_bytes());
435                    request_builder = request_builder.header("Authorization", format!("Basic {}", encoded));
436                }
437            }
438        }
439
440        if let Some(ref sid) = *self.session_id.read().await {
441            request_builder = request_builder.header("Mcp-Session-Id", sid.as_str());
442        }
443
444        let request = request_builder
445            .body(Full::new(Bytes::from(body_json)))
446            .map_err(|e| Error::Transport(format!("failed to build request: {}", e)))?;
447
448        let https = hyper_util::client::legacy::connect::HttpConnector::new();
449        let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build(https);
450
451        let response = client.request(request).await
452            .map_err(|e| Error::Transport(format!("HTTP request failed: {}", e)))?;
453
454        let status = response.status();
455        if status != hyper::StatusCode::ACCEPTED && !status.is_success() {
456            return Err(Error::Transport(format!("unexpected status: {}", status)));
457        }
458
459        Ok(())
460    }
461
462    pub async fn close(&self) -> Result<()> {
463        self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
464        Ok(())
465    }
466
467    pub fn is_connected(&self) -> bool {
468        self.connected.load(std::sync::atomic::Ordering::SeqCst)
469    }
470}
471
472// ============================================================================
473// WebSocket Transport
474// ============================================================================
475
476/// WebSocket transport
477pub struct WebSocketTransport {
478    write: Arc<Mutex<futures::stream::SplitSink<
479        tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
480        tokio_tungstenite::tungstenite::Message
481    >>>,
482    pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
483    connected: Arc<std::sync::atomic::AtomicBool>,
484}
485
486impl WebSocketTransport {
487    pub async fn connect(url: &str) -> Result<Self> {
488        use futures::StreamExt;
489        use tokio_tungstenite::connect_async;
490
491        let (ws_stream, _) = connect_async(url).await
492            .map_err(|e| Error::Transport(format!("WebSocket connect failed: {}", e)))?;
493
494        let (write, mut read) = ws_stream.split();
495        let pending: Arc<RwLock<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
496            Arc::new(RwLock::new(HashMap::new()));
497        let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
498
499        let pending_clone = pending.clone();
500        let connected_clone = connected.clone();
501        tokio::spawn(async move {
502            while let Some(msg) = read.next().await {
503                match msg {
504                    Ok(tokio_tungstenite::tungstenite::Message::Text(text)) => {
505                        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
506                            let id_str = match &response.id {
507                                Value::Number(n) => n.to_string(),
508                                Value::String(s) => s.clone(),
509                                _ => continue,
510                            };
511
512                            let mut pending = pending_clone.write().await;
513                            if let Some(tx) = pending.remove(&id_str) {
514                                let _ = tx.send(response);
515                            }
516                        }
517                    }
518                    Ok(tokio_tungstenite::tungstenite::Message::Close(_)) => break,
519                    Err(_) => break,
520                    _ => {}
521                }
522            }
523            connected_clone.store(false, std::sync::atomic::Ordering::SeqCst);
524        });
525
526        Ok(Self { write: Arc::new(Mutex::new(write)), pending, connected })
527    }
528
529    pub async fn request(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse> {
530        use futures::SinkExt;
531        use tokio_tungstenite::tungstenite::Message;
532
533        let id_str = match &req.id {
534            Value::Number(n) => n.to_string(),
535            Value::String(s) => s.clone(),
536            _ => return Err(Error::Protocol("invalid request id".into())),
537        };
538
539        let (tx, rx) = oneshot::channel();
540        { self.pending.write().await.insert(id_str.clone(), tx); }
541
542        let json = serde_json::to_string(&req)?;
543        { self.write.lock().await.send(Message::Text(json.into())).await
544            .map_err(|e| Error::Transport(format!("WebSocket send failed: {}", e)))?; }
545
546        match timeout(Duration::from_secs(30), rx).await {
547            Ok(Ok(response)) => Ok(response),
548            Ok(Err(_)) => Err(Error::Transport("response channel closed".into())),
549            Err(_) => { self.pending.write().await.remove(&id_str); Err(Error::Transport("request timeout".into())) }
550        }
551    }
552
553    pub async fn notify(&self, notif: JsonRpcNotification) -> Result<()> {
554        use futures::SinkExt;
555        use tokio_tungstenite::tungstenite::Message;
556
557        let json = serde_json::to_string(&notif)?;
558        self.write.lock().await.send(Message::Text(json.into())).await
559            .map_err(|e| Error::Transport(format!("WebSocket send failed: {}", e)))
560    }
561
562    pub async fn close(&self) -> Result<()> {
563        use futures::SinkExt;
564        use tokio_tungstenite::tungstenite::Message;
565        let _ = self.write.lock().await.send(Message::Close(None)).await;
566        self.connected.store(false, std::sync::atomic::Ordering::SeqCst);
567        Ok(())
568    }
569
570    pub fn is_connected(&self) -> bool {
571        self.connected.load(std::sync::atomic::Ordering::SeqCst)
572    }
573}
574
575// ============================================================================
576// MCP Client
577// ============================================================================
578
579enum McpClientTransport {
580    Stdio(StdioTransport),
581    Http(HttpTransport),
582    WebSocket(WebSocketTransport),
583}
584
585/// MCP client for a single server connection
586pub struct McpClient {
587    transport: McpClientTransport,
588    server_info: RwLock<Option<McpServerInfo>>,
589    capabilities: RwLock<McpCapabilities>,
590    tools: RwLock<Vec<McpTool>>,
591    resources: RwLock<Vec<McpResource>>,
592    prompts: RwLock<Vec<McpPrompt>>,
593    request_id: std::sync::atomic::AtomicU64,
594}
595
596impl McpClient {
597    fn next_id(&self) -> Value {
598        Value::Number(self.request_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst).into())
599    }
600
601    async fn send_request(&self, method: &str, params: Option<Value>) -> Result<JsonRpcResponse> {
602        let req = JsonRpcRequest {
603            jsonrpc: "2.0".to_string(),
604            id: self.next_id(),
605            method: method.to_string(),
606            params,
607        };
608        match &self.transport {
609            McpClientTransport::Stdio(t) => t.request(req).await,
610            McpClientTransport::Http(t) => t.request(req).await,
611            McpClientTransport::WebSocket(t) => t.request(req).await,
612        }
613    }
614
615    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
616        let notif = JsonRpcNotification {
617            jsonrpc: "2.0".to_string(),
618            method: method.to_string(),
619            params,
620        };
621        match &self.transport {
622            McpClientTransport::Stdio(t) => t.notify(notif).await,
623            McpClientTransport::Http(t) => t.notify(notif).await,
624            McpClientTransport::WebSocket(t) => t.notify(notif).await,
625        }
626    }
627
628    pub async fn connect_stdio(command: &str, args: &[String], env: Option<&HashMap<String, String>>) -> Result<Self> {
629        let transport = StdioTransport::spawn(command, args, env).await?;
630        let client = Self {
631            transport: McpClientTransport::Stdio(transport),
632            server_info: RwLock::new(None),
633            capabilities: RwLock::new(McpCapabilities::default()),
634            tools: RwLock::new(Vec::new()),
635            resources: RwLock::new(Vec::new()),
636            prompts: RwLock::new(Vec::new()),
637            request_id: std::sync::atomic::AtomicU64::new(1),
638        };
639        client.initialize().await?;
640        Ok(client)
641    }
642
643    pub async fn connect_http(endpoint: &str, auth: Option<Auth>) -> Result<Self> {
644        let transport = HttpTransport::new(endpoint, auth)?;
645        let client = Self {
646            transport: McpClientTransport::Http(transport),
647            server_info: RwLock::new(None),
648            capabilities: RwLock::new(McpCapabilities::default()),
649            tools: RwLock::new(Vec::new()),
650            resources: RwLock::new(Vec::new()),
651            prompts: RwLock::new(Vec::new()),
652            request_id: std::sync::atomic::AtomicU64::new(1),
653        };
654        client.initialize().await?;
655        Ok(client)
656    }
657
658    pub async fn connect_websocket(url: &str) -> Result<Self> {
659        let transport = WebSocketTransport::connect(url).await?;
660        let client = Self {
661            transport: McpClientTransport::WebSocket(transport),
662            server_info: RwLock::new(None),
663            capabilities: RwLock::new(McpCapabilities::default()),
664            tools: RwLock::new(Vec::new()),
665            resources: RwLock::new(Vec::new()),
666            prompts: RwLock::new(Vec::new()),
667            request_id: std::sync::atomic::AtomicU64::new(1),
668        };
669        client.initialize().await?;
670        Ok(client)
671    }
672
673    async fn initialize(&self) -> Result<()> {
674        let params = json!({
675            "protocolVersion": "2024-11-05",
676            "capabilities": { "roots": { "listChanged": true }, "sampling": {} },
677            "clientInfo": { "name": "zap-gateway", "version": env!("CARGO_PKG_VERSION") }
678        });
679
680        let response = self.send_request("initialize", Some(params)).await?;
681        if let Some(error) = response.error {
682            return Err(Error::Protocol(format!("initialize failed: {}", error.message)));
683        }
684
685        if let Some(result) = response.result {
686            if let Some(server_info) = result.get("serverInfo") {
687                *self.server_info.write().await = serde_json::from_value(server_info.clone()).ok();
688            }
689            if let Some(caps) = result.get("capabilities") {
690                *self.capabilities.write().await = serde_json::from_value(caps.clone()).unwrap_or_default();
691            }
692        }
693
694        self.send_notification("notifications/initialized", None).await?;
695        self.refresh_all().await?;
696        Ok(())
697    }
698
699    pub async fn refresh_all(&self) -> Result<()> {
700        let caps = self.capabilities.read().await.clone();
701        if caps.tools.is_some() { let _ = self.refresh_tools().await; }
702        if caps.resources.is_some() { let _ = self.refresh_resources().await; }
703        if caps.prompts.is_some() { let _ = self.refresh_prompts().await; }
704        Ok(())
705    }
706
707    pub async fn refresh_tools(&self) -> Result<()> {
708        let response = self.send_request("tools/list", None).await?;
709        if let Some(result) = response.result {
710            if let Some(tools_val) = result.get("tools") {
711                *self.tools.write().await = serde_json::from_value(tools_val.clone()).unwrap_or_default();
712            }
713        }
714        Ok(())
715    }
716
717    pub async fn refresh_resources(&self) -> Result<()> {
718        let response = self.send_request("resources/list", None).await?;
719        if let Some(result) = response.result {
720            if let Some(resources_val) = result.get("resources") {
721                *self.resources.write().await = serde_json::from_value(resources_val.clone()).unwrap_or_default();
722            }
723        }
724        Ok(())
725    }
726
727    pub async fn refresh_prompts(&self) -> Result<()> {
728        let response = self.send_request("prompts/list", None).await?;
729        if let Some(result) = response.result {
730            if let Some(prompts_val) = result.get("prompts") {
731                *self.prompts.write().await = serde_json::from_value(prompts_val.clone()).unwrap_or_default();
732            }
733        }
734        Ok(())
735    }
736
737    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
738        let params = json!({ "name": name, "arguments": arguments });
739        let response = self.send_request("tools/call", Some(params)).await?;
740        if let Some(error) = response.error {
741            return Err(Error::ToolCallFailed(format!("{}: {}", name, error.message)));
742        }
743        response.result.ok_or_else(|| Error::Protocol("empty tool result".into()))
744    }
745
746    pub async fn read_resource(&self, uri: &str) -> Result<Value> {
747        let params = json!({ "uri": uri });
748        let response = self.send_request("resources/read", Some(params)).await?;
749        if let Some(error) = response.error {
750            return Err(Error::ResourceNotFound(format!("{}: {}", uri, error.message)));
751        }
752        response.result.ok_or_else(|| Error::Protocol("empty resource result".into()))
753    }
754
755    pub async fn get_prompt(&self, name: &str, arguments: Option<Value>) -> Result<Value> {
756        let params = json!({ "name": name, "arguments": arguments.unwrap_or(json!({})) });
757        let response = self.send_request("prompts/get", Some(params)).await?;
758        if let Some(error) = response.error {
759            return Err(Error::Protocol(format!("prompt {} failed: {}", name, error.message)));
760        }
761        response.result.ok_or_else(|| Error::Protocol("empty prompt result".into()))
762    }
763
764    pub async fn tools(&self) -> Vec<McpTool> { self.tools.read().await.clone() }
765    pub async fn resources(&self) -> Vec<McpResource> { self.resources.read().await.clone() }
766    pub async fn prompts(&self) -> Vec<McpPrompt> { self.prompts.read().await.clone() }
767    pub async fn server_info(&self) -> Option<McpServerInfo> { self.server_info.read().await.clone() }
768
769    pub fn is_connected(&self) -> bool {
770        match &self.transport {
771            McpClientTransport::Stdio(t) => t.is_connected(),
772            McpClientTransport::Http(t) => t.is_connected(),
773            McpClientTransport::WebSocket(t) => t.is_connected(),
774        }
775    }
776
777    pub async fn close(&self) -> Result<()> {
778        match &self.transport {
779            McpClientTransport::Stdio(t) => t.close().await,
780            McpClientTransport::Http(t) => t.close().await,
781            McpClientTransport::WebSocket(t) => t.close().await,
782        }
783    }
784}
785
786// ============================================================================
787// Connected Server State
788// ============================================================================
789
790struct ConnectedServer {
791    id: String,
792    name: String,
793    config: ServerConfig,
794    client: Option<Arc<McpClient>>,
795    status: ServerStatus,
796    last_error: Option<String>,
797    #[allow(dead_code)]
798    last_health_check: Option<Instant>,
799    reconnect_attempts: u32,
800}
801
802impl ConnectedServer {
803    fn new(id: String, name: String, config: ServerConfig) -> Self {
804        Self { id, name, config, client: None, status: ServerStatus::Disconnected,
805               last_error: None, last_health_check: None, reconnect_attempts: 0 }
806    }
807}
808
809// ============================================================================
810// Gateway Implementation
811// ============================================================================
812
813/// ZAP MCP Gateway - aggregates multiple MCP servers
814pub struct Gateway {
815    config: Config,
816    servers: Arc<RwLock<HashMap<String, ConnectedServer>>>,
817    tool_routing: Arc<RwLock<HashMap<String, String>>>,
818    resource_routing: Arc<RwLock<HashMap<String, String>>>,
819    prompt_routing: Arc<RwLock<HashMap<String, String>>>,
820    shutdown_tx: Option<mpsc::Sender<()>>,
821}
822
823/// Server info returned by list_servers
824#[derive(Debug, Clone)]
825pub struct ServerInfo {
826    pub id: String,
827    pub name: String,
828    pub url: String,
829    pub status: ServerStatus,
830    pub tools_count: usize,
831    pub resources_count: usize,
832    pub prompts_count: usize,
833    pub last_error: Option<String>,
834}
835
836impl Gateway {
837    pub fn new(config: Config) -> Self {
838        Self {
839            config,
840            servers: Arc::new(RwLock::new(HashMap::new())),
841            tool_routing: Arc::new(RwLock::new(HashMap::new())),
842            resource_routing: Arc::new(RwLock::new(HashMap::new())),
843            prompt_routing: Arc::new(RwLock::new(HashMap::new())),
844            shutdown_tx: None,
845        }
846    }
847
/// Produces a fresh identifier for a newly registered server.
///
/// The id is the wall-clock time in nanoseconds since the Unix epoch,
/// followed by `-` and a process-wide sequence number, both in lowercase hex.
///
/// The sequence number is what guarantees uniqueness within the process:
/// a timestamp alone can repeat on clocks with coarse resolution, and two
/// servers sharing an id would silently overwrite each other in the server
/// map. A clock set before the epoch yields a zero timestamp rather than a
/// panic.
fn generate_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Only atomicity of the increment matters; no other memory is synchronised.
    static NEXT_SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);
    let sequence = NEXT_SEQUENCE.fetch_add(1, Ordering::Relaxed);

    format!("{:x}-{:x}", nanos, sequence)
}
852
853    pub async fn add_server(&self, name: &str, config: ServerConfig) -> Result<String> {
854        let id = Self::generate_id();
855        let server = ConnectedServer::new(id.clone(), name.to_string(), config);
856        self.servers.write().await.insert(id.clone(), server);
857
858        let servers = self.servers.clone();
859        let tool_routing = self.tool_routing.clone();
860        let resource_routing = self.resource_routing.clone();
861        let prompt_routing = self.prompt_routing.clone();
862        let server_id = id.clone();
863
864        tokio::spawn(async move {
865            if let Err(e) = Self::connect_server(&servers, &tool_routing, &resource_routing, &prompt_routing, &server_id).await {
866                tracing::error!("Failed to connect to server {}: {}", server_id, e);
867            }
868        });
869
870        Ok(id)
871    }
872
873    async fn connect_server(
874        servers: &Arc<RwLock<HashMap<String, ConnectedServer>>>,
875        tool_routing: &Arc<RwLock<HashMap<String, String>>>,
876        resource_routing: &Arc<RwLock<HashMap<String, String>>>,
877        prompt_routing: &Arc<RwLock<HashMap<String, String>>>,
878        server_id: &str,
879    ) -> Result<()> {
880        let config = {
881            let mut servers = servers.write().await;
882            let server = servers.get_mut(server_id).ok_or_else(|| Error::Server(format!("server {} not found", server_id)))?;
883            server.status = ServerStatus::Connecting;
884            server.config.clone()
885        };
886
887        let client_result = match config.transport {
888            Transport::Stdio => {
889                let url = url::Url::parse(&config.url).map_err(|e| Error::Config(format!("invalid URL: {}", e)))?;
890                let command = url.path();
891                let args: Vec<String> = url.query_pairs().filter(|(k, _)| k == "arg").map(|(_, v)| v.to_string()).collect();
892                McpClient::connect_stdio(command, &args, None).await
893            }
894            Transport::Http => McpClient::connect_http(&config.url, config.auth.clone()).await,
895            Transport::WebSocket => McpClient::connect_websocket(&config.url).await,
896            Transport::Zap => return Err(Error::Transport("ZAP transport not yet implemented".into())),
897            Transport::Unix => return Err(Error::Transport("Unix transport not yet implemented".into())),
898        };
899
900        match client_result {
901            Ok(client) => {
902                let client = Arc::new(client);
903
904                { let tools = client.tools().await; let mut routing = tool_routing.write().await;
905                  for tool in &tools { routing.insert(tool.name.clone(), server_id.to_string()); } }
906
907                { let resources = client.resources().await; let mut routing = resource_routing.write().await;
908                  for resource in &resources {
909                      if let Some(scheme) = resource.uri.split(':').next() { routing.insert(format!("{}:", scheme), server_id.to_string()); }
910                      routing.insert(resource.uri.clone(), server_id.to_string());
911                  } }
912
913                { let prompts = client.prompts().await; let mut routing = prompt_routing.write().await;
914                  for prompt in &prompts { routing.insert(prompt.name.clone(), server_id.to_string()); } }
915
916                { let mut servers = servers.write().await;
917                  if let Some(server) = servers.get_mut(server_id) {
918                      server.client = Some(client);
919                      server.status = ServerStatus::Connected;
920                      server.last_error = None;
921                      server.reconnect_attempts = 0;
922                      server.last_health_check = Some(Instant::now());
923                  } }
924
925                tracing::info!("Connected to MCP server: {}", server_id);
926                Ok(())
927            }
928            Err(e) => {
929                let mut servers = servers.write().await;
930                if let Some(server) = servers.get_mut(server_id) {
931                    server.status = ServerStatus::Error;
932                    server.last_error = Some(e.to_string());
933                    server.reconnect_attempts += 1;
934                }
935                Err(e)
936            }
937        }
938    }
939
940    pub async fn remove_server(&self, id: &str) -> Result<()> {
941        let server = self.servers.write().await.remove(id);
942        if let Some(server) = server {
943            self.tool_routing.write().await.retain(|_, v| v != id);
944            self.resource_routing.write().await.retain(|_, v| v != id);
945            self.prompt_routing.write().await.retain(|_, v| v != id);
946            if let Some(client) = &server.client { let _ = client.close().await; }
947        }
948        Ok(())
949    }
950
951    pub async fn list_servers(&self) -> Vec<ServerInfo> {
952        let servers = self.servers.read().await;
953        let mut result = Vec::new();
954        for server in servers.values() {
955            let (tools_count, resources_count, prompts_count) = if let Some(client) = &server.client {
956                (client.tools().await.len(), client.resources().await.len(), client.prompts().await.len())
957            } else { (0, 0, 0) };
958            result.push(ServerInfo {
959                id: server.id.clone(), name: server.name.clone(), url: server.config.url.clone(),
960                status: server.status, tools_count, resources_count, prompts_count, last_error: server.last_error.clone(),
961            });
962        }
963        result
964    }
965
966    pub async fn server_status(&self, id: &str) -> Option<ServerStatus> {
967        self.servers.read().await.get(id).map(|s| s.status)
968    }
969
970    pub async fn list_tools(&self) -> Vec<McpTool> {
971        let servers = self.servers.read().await;
972        let mut tools = Vec::new();
973        for server in servers.values() {
974            if let Some(client) = &server.client {
975                if server.status == ServerStatus::Connected { tools.extend(client.tools().await); }
976            }
977        }
978        tools
979    }
980
981    pub async fn list_resources(&self) -> Vec<McpResource> {
982        let servers = self.servers.read().await;
983        let mut resources = Vec::new();
984        for server in servers.values() {
985            if let Some(client) = &server.client {
986                if server.status == ServerStatus::Connected { resources.extend(client.resources().await); }
987            }
988        }
989        resources
990    }
991
992    pub async fn list_prompts(&self) -> Vec<McpPrompt> {
993        let servers = self.servers.read().await;
994        let mut prompts = Vec::new();
995        for server in servers.values() {
996            if let Some(client) = &server.client {
997                if server.status == ServerStatus::Connected { prompts.extend(client.prompts().await); }
998            }
999        }
1000        prompts
1001    }
1002
1003    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
1004        let server_id = self.tool_routing.read().await.get(name).cloned()
1005            .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
1006        let client = self.servers.read().await.get(&server_id).and_then(|s| s.client.clone())
1007            .ok_or_else(|| Error::Server(format!("server {} not connected", server_id)))?;
1008        client.call_tool(name, arguments).await
1009    }
1010
1011    pub async fn read_resource(&self, uri: &str) -> Result<Value> {
1012        let server_id = {
1013            let routing = self.resource_routing.read().await;
1014            routing.get(uri).cloned().or_else(|| routing.iter().find(|(prefix, _)| uri.starts_with(prefix.as_str())).map(|(_, id)| id.clone()))
1015        }.ok_or_else(|| Error::ResourceNotFound(uri.to_string()))?;
1016        let client = self.servers.read().await.get(&server_id).and_then(|s| s.client.clone())
1017            .ok_or_else(|| Error::Server(format!("server {} not connected", server_id)))?;
1018        client.read_resource(uri).await
1019    }
1020
1021    pub async fn get_prompt(&self, name: &str, arguments: Option<Value>) -> Result<Value> {
1022        let server_id = self.prompt_routing.read().await.get(name).cloned()
1023            .ok_or_else(|| Error::Protocol(format!("prompt {} not found", name)))?;
1024        let client = self.servers.read().await.get(&server_id).and_then(|s| s.client.clone())
1025            .ok_or_else(|| Error::Server(format!("server {} not connected", server_id)))?;
1026        client.get_prompt(name, arguments).await
1027    }
1028
1029    pub async fn run(&mut self) -> Result<()> {
1030        let addr = format!("{}:{}", self.config.listen, self.config.port);
1031        tracing::info!("ZAP gateway starting on {}", addr);
1032
1033        for server_config in self.config.servers.clone() {
1034            let name = server_config.name.clone();
1035            match self.add_server(&name, server_config).await {
1036                Ok(id) => tracing::info!("Added server {} with id {}", name, id),
1037                Err(e) => tracing::error!("Failed to add server {}: {}", name, e),
1038            }
1039        }
1040
1041        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
1042        self.shutdown_tx = Some(shutdown_tx);
1043
1044        let servers = self.servers.clone();
1045        let tool_routing = self.tool_routing.clone();
1046        let resource_routing = self.resource_routing.clone();
1047        let prompt_routing = self.prompt_routing.clone();
1048
1049        let health_task = tokio::spawn(async move {
1050            let mut check_interval = interval(Duration::from_secs(30));
1051            loop {
1052                check_interval.tick().await;
1053                let server_ids: Vec<String> = servers.read().await.keys().cloned().collect();
1054                for server_id in server_ids {
1055                    let (needs_reconnect, client) = {
1056                        let servers = servers.read().await;
1057                        if let Some(server) = servers.get(&server_id) {
1058                            let needs_reconnect = match server.status {
1059                                ServerStatus::Error | ServerStatus::Disconnected => true,
1060                                ServerStatus::Connected => server.client.as_ref().map(|c| !c.is_connected()).unwrap_or(true),
1061                                _ => false,
1062                            };
1063                            (needs_reconnect, server.client.clone())
1064                        } else { (false, None) }
1065                    };
1066
1067                    if needs_reconnect {
1068                        tracing::info!("Health check: reconnecting {}", server_id);
1069                        { servers.write().await.get_mut(&server_id).map(|s| s.status = ServerStatus::Reconnecting); }
1070                        let _ = Self::connect_server(&servers, &tool_routing, &resource_routing, &prompt_routing, &server_id).await;
1071                    } else if let Some(client) = client {
1072                        let _ = client.refresh_all().await;
1073                    }
1074                }
1075            }
1076        });
1077
1078        tokio::select! {
1079            _ = shutdown_rx.recv() => { tracing::info!("Shutdown signal received"); }
1080            _ = tokio::signal::ctrl_c() => { tracing::info!("Ctrl+C received"); }
1081        }
1082
1083        health_task.abort();
1084        for id in self.servers.read().await.keys().cloned().collect::<Vec<_>>() { let _ = self.remove_server(&id).await; }
1085        tracing::info!("Gateway shutdown complete");
1086        Ok(())
1087    }
1088
1089    pub async fn shutdown(&self) -> Result<()> {
1090        if let Some(tx) = &self.shutdown_tx { let _ = tx.send(()).await; }
1091        Ok(())
1092    }
1093}
1094
1095// ============================================================================
1096// Tests
1097// ============================================================================
1098
#[cfg(test)]
mod tests {
    use super::*;

    /// A request without params serializes its version and method fields.
    #[test]
    fn test_json_rpc_request_serialize() {
        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: "tools/list".to_string(),
            params: None,
        };
        let encoded = serde_json::to_string(&request).unwrap();
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"method\":\"tools/list\""] {
            assert!(encoded.contains(fragment));
        }
    }

    /// A success response parses with its result populated.
    #[test]
    fn test_json_rpc_response_deserialize() {
        let raw = r#"{"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}"#;
        let response: JsonRpcResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(response.jsonrpc, "2.0");
        assert!(response.result.is_some());
    }

    /// A tool description using the camelCase `inputSchema` key parses.
    #[test]
    fn test_mcp_tool_deserialize() {
        let raw = r#"{"name": "calculator", "description": "Perform calculations", "inputSchema": {"type": "object"}}"#;
        let parsed: McpTool = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.name, "calculator");
    }

    /// A freshly built gateway has no servers.
    #[tokio::test]
    async fn test_gateway_create() {
        let gateway = Gateway::new(Config::default());
        let listed = gateway.list_servers().await;
        assert!(listed.is_empty());
    }

    /// Adding a server makes it listable; removing it empties the list again.
    #[tokio::test]
    async fn test_gateway_add_remove_server() {
        let gateway = Gateway::new(Config::default());
        let backend = ServerConfig {
            name: "test".to_string(),
            url: "http://localhost:8080".to_string(),
            transport: Transport::Http,
            timeout: 30000,
            auth: None,
        };

        let id = gateway.add_server("test", backend).await.unwrap();
        assert!(!id.is_empty());

        // Give the background connect task a moment to start.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(gateway.list_servers().await.len(), 1);

        gateway.remove_server(&id).await.unwrap();
        assert!(gateway.list_servers().await.is_empty());
    }

    /// Status values render as lowercase words.
    #[test]
    fn test_server_status_display() {
        let expectations = [
            (ServerStatus::Connecting, "connecting"),
            (ServerStatus::Connected, "connected"),
        ];
        for (status, rendered) in expectations {
            assert_eq!(status.to_string(), rendered);
        }
    }
}