sacp_proxy/
mcp_server.rs

1use agent_client_protocol::{self as acp, NewSessionRequest};
2use futures::channel::mpsc;
3use futures::{FutureExt, future::BoxFuture};
4use futures::{SinkExt, StreamExt};
5use fxhash::FxHashMap;
6use rmcp::ServiceExt;
7use sacp::{
8    Handled, JsonRpcConnection, JsonRpcConnectionCx, JsonRpcHandler, JsonRpcMessage,
9    JsonRpcRequestCx, MessageAndCx, UntypedMessage,
10};
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use crate::{
16    JsonRpcCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
17    McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
18};
19
20/// Manages MCP services offered to successor proxies and agents.
21///
22/// Use the [`Self::add_rmcp_server`] method to register MCP servers implemented using the [`rmcp`] crate.
23///
24/// This struct is a handle to the underlying registry. Cloning the struct produces a second handle to the same registry.
25///
26/// # Handling requests
27///
28/// You must add the registery (or a clone of it) to the [`JsonRpcConnection`] so that it can intercept MCP requests.
29/// Typically you do this by providing it as an argument to the [`]
30#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32    data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39    connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Add the MCP server to the registry and return `self`. Useful for chaining.
48    /// Equivalent to [`Self::add_rmcp_server`].
49    ///
50    /// # Parameters
51    ///
52    /// - `name`: The name of the server.
53    /// - `make_service`: A function that creates the service (e.g., `YourService::new`).
54    pub fn with_rmcp_server<S>(
55        self,
56        name: impl ToString,
57        make_service: impl Fn() -> S + 'static + Send + Sync,
58    ) -> Result<Self, acp::Error>
59    where
60        S: rmcp::Service<rmcp::RoleServer>,
61    {
62        self.add_rmcp_server(name, make_service)?;
63        Ok(self)
64    }
65
66    /// Add an MCP server implemented using the rmcp crate.
67    ///
68    /// # Parameters
69    ///
70    /// - `name`: The name of the server.
71    /// - `make_service`: A function that creates the service (e.g., `YourService::new`).
72    pub fn add_rmcp_server<S>(
73        &self,
74        name: impl ToString,
75        make_service: impl Fn() -> S + 'static + Send + Sync,
76    ) -> Result<(), acp::Error>
77    where
78        S: rmcp::Service<rmcp::RoleServer>,
79    {
80        struct SpawnRmcpService<F> {
81            make_service: F,
82        }
83
84        impl<F, S> DynSpawnMcpServer for SpawnRmcpService<F>
85        where
86            F: Fn() -> S + Send + Sync + 'static,
87            S: rmcp::Service<rmcp::RoleServer>,
88        {
89            fn spawn(
90                &self,
91                outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
92                incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
93            ) -> BoxFuture<'static, Result<(), acp::Error>> {
94                let server = (self.make_service)();
95                async move {
96                    let running_server = server
97                        .serve((incoming_bytes, outgoing_bytes))
98                        .await
99                        .map_err(acp::Error::into_internal_error)?;
100
101                    // Keep the server alive by waiting for it to finish
102                    running_server
103                        .waiting()
104                        .await
105                        .map(|_quit_reason| ())
106                        .map_err(acp::Error::into_internal_error)
107                }
108                .boxed()
109            }
110        }
111
112        let name = name.to_string();
113        self.add_mcp_service(name, Arc::new(SpawnRmcpService { make_service }))
114    }
115
116    /// Internal helper for adding services, independent of how they are implemented
117    fn add_mcp_service(
118        &self,
119        name: String,
120        spawn: Arc<dyn DynSpawnMcpServer>,
121    ) -> Result<(), acp::Error> {
122        let name = name.to_string();
123        if let Some(_) = self.get_registered_server_by_name(&name) {
124            return Err(sacp::util::internal_error(format!(
125                "Server with name '{}' already exists",
126                name
127            )));
128        }
129
130        let uuid = uuid::Uuid::new_v4().to_string();
131        let service = Arc::new(RegisteredMcpServer {
132            name,
133            url: format!("acp:{uuid}"),
134            spawn,
135        });
136        self.insert_registered_server(service);
137        Ok(())
138    }
139
140    fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
141        let mut data = self.data.lock().expect("not poisoned");
142        data.registered_by_name
143            .insert(service.name.clone(), service.clone());
144        data.registered_by_url
145            .insert(service.url.clone(), service.clone());
146    }
147
148    fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
149        self.data
150            .lock()
151            .expect("not poisoned")
152            .registered_by_name
153            .get(name)
154            .cloned()
155    }
156
157    fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
158        self.data
159            .lock()
160            .expect("not poisoned")
161            .registered_by_url
162            .get(url)
163            .cloned()
164    }
165
166    fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
167        self.data
168            .lock()
169            .expect("not poisoned")
170            .connections
171            .insert(connection_id.to_string(), tx);
172    }
173
174    fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
175        self.data
176            .lock()
177            .expect("not poisoned")
178            .connections
179            .get(connection_id)
180            .cloned()
181    }
182
183    fn remove_connection(&self, connection_id: &str) -> bool {
184        self.data
185            .lock()
186            .expect("not poisoned")
187            .connections
188            .remove(connection_id)
189            .is_some()
190    }
191
192    async fn handle_connect_request(
193        &self,
194        result: Result<SuccessorRequest<McpConnectRequest>, agent_client_protocol::Error>,
195        request_cx: JsonRpcRequestCx<serde_json::Value>,
196    ) -> Result<Handled<JsonRpcRequestCx<serde_json::Value>>, agent_client_protocol::Error> {
197        // Check if we parsed this message successfully.
198        let SuccessorRequest { request } = match result {
199            Ok(request) => request,
200            Err(err) => {
201                request_cx.respond_with_error(err)?;
202                return Ok(Handled::Yes);
203            }
204        };
205
206        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
207        let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
208            return Ok(Handled::No(request_cx));
209        };
210
211        let request_cx = request_cx.cast::<McpConnectResponse>();
212
213        // Create a unique connection ID and a channel for future communication
214        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
215        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
216        self.insert_connection(&connection_id, mcp_server_tx);
217
218        // Generate streams
219        let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
220        let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
221        let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
222
223        // Create JsonRpcConnection for communicating with the server.
224        //
225        // Every request/notification that the server sends up, we will package up
226        // as an McpOverAcpRequest/McpOverAcpNotification and send to our agent.
227        //
228        // Every request/notification that is sent over `mcp_server_tx` we will
229        // send to the MCP server.
230        let spawn_results = request_cx
231            .spawn(
232                JsonRpcConnection::new(mcp_client_write.compat_write(), mcp_client_read.compat())
233                    .on_receive_message({
234                        let connection_id = connection_id.clone();
235                        let outer_cx = request_cx.connection_cx();
236                        async move |message: sacp::MessageAndCx| {
237                            // Wrap the message in McpOverAcp{Request,Notification} and forward to successor
238                            let wrapped = message.map(
239                                |request, request_cx| {
240                                    (
241                                        McpOverAcpRequest {
242                                            connection_id: connection_id.clone(),
243                                            request,
244                                        },
245                                        request_cx,
246                                    )
247                                },
248                                |notification, cx| {
249                                    (
250                                        McpOverAcpNotification {
251                                            connection_id: connection_id.clone(),
252                                            notification,
253                                        },
254                                        cx,
255                                    )
256                                },
257                            );
258                            outer_cx.send_proxied_message(wrapped)
259                        }
260                    })
261                    .with_client({
262                        async move |mcp_cx| {
263                            while let Some(msg) = mcp_server_rx.next().await {
264                                mcp_cx.send_proxied_message(msg)?;
265                            }
266                            Ok(())
267                        }
268                    }),
269            )
270            .and_then(|()| {
271                // Spawn MCP server task
272                request_cx.spawn(async move {
273                    registered_server
274                        .spawn
275                        .spawn(Box::pin(mcp_server_write), Box::pin(mcp_server_read))
276                        .await
277                })
278            });
279
280        match spawn_results {
281            Ok(()) => {
282                request_cx.respond(McpConnectResponse { connection_id })?;
283                Ok(Handled::Yes)
284            }
285
286            Err(err) => {
287                request_cx.respond_with_error(err)?;
288                Ok(Handled::Yes)
289            }
290        }
291    }
292
293    async fn handle_mcp_over_acp_request(
294        &self,
295        result: Result<
296            SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
297            agent_client_protocol::Error,
298        >,
299        request_cx: JsonRpcRequestCx<serde_json::Value>,
300    ) -> Result<Handled<JsonRpcRequestCx<serde_json::Value>>, agent_client_protocol::Error> {
301        // Check if we parsed this message successfully.
302        let SuccessorRequest { request } = match result {
303            Ok(request) => request,
304            Err(err) => {
305                request_cx.respond_with_error(err)?;
306                return Ok(Handled::Yes);
307            }
308        };
309
310        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
311        let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
312            return Ok(Handled::No(request_cx));
313        };
314
315        mcp_server_tx
316            .send(MessageAndCx::Request(request.request, request_cx))
317            .await
318            .map_err(acp::Error::into_internal_error)?;
319
320        Ok(Handled::Yes)
321    }
322
323    async fn handle_mcp_over_acp_notification(
324        &self,
325        result: Result<
326            SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
327            agent_client_protocol::Error,
328        >,
329        notification_cx: JsonRpcConnectionCx,
330    ) -> Result<Handled<JsonRpcConnectionCx>, agent_client_protocol::Error> {
331        // Check if we parsed this message successfully.
332        let SuccessorNotification { notification } = match result {
333            Ok(request) => request,
334            Err(err) => {
335                notification_cx.send_error_notification(err)?;
336                return Ok(Handled::Yes);
337            }
338        };
339
340        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
341        let Some(mut mcp_server_tx) = self.get_connection(&notification.connection_id) else {
342            return Ok(Handled::No(notification_cx));
343        };
344
345        mcp_server_tx
346            .send(MessageAndCx::Notification(
347                notification.notification,
348                notification_cx.clone(),
349            ))
350            .await
351            .map_err(acp::Error::into_internal_error)?;
352
353        Ok(Handled::Yes)
354    }
355
356    async fn handle_mcp_disconnect_notification(
357        &self,
358        result: Result<
359            SuccessorNotification<McpDisconnectNotification>,
360            agent_client_protocol::Error,
361        >,
362        notification_cx: JsonRpcConnectionCx,
363    ) -> Result<Handled<JsonRpcConnectionCx>, agent_client_protocol::Error> {
364        // Check if we parsed this message successfully.
365        let SuccessorNotification { notification } = match result {
366            Ok(request) => request,
367            Err(err) => {
368                notification_cx.send_error_notification(err)?;
369                return Ok(Handled::Yes);
370            }
371        };
372
373        // Remove connection if we have it. Otherwise, do not handle the notification.
374        if self.remove_connection(&notification.connection_id) {
375            Ok(Handled::Yes)
376        } else {
377            Ok(Handled::No(notification_cx))
378        }
379    }
380
381    async fn handle_new_session_request(
382        &self,
383        result: Result<NewSessionRequest, agent_client_protocol::Error>,
384        request_cx: JsonRpcRequestCx<serde_json::Value>,
385    ) -> Result<Handled<JsonRpcRequestCx<serde_json::Value>>, agent_client_protocol::Error> {
386        // Check if we parsed this message successfully.
387        let mut request = match result {
388            Ok(request) => request,
389            Err(err) => {
390                request_cx.send_error_notification(err)?;
391                return Ok(Handled::Yes);
392            }
393        };
394
395        // Add the MCP servers into the session/new request.
396        //
397        // Q: Do we care if there are already servers with that name?
398        {
399            let data = self.data.lock().expect("not poisoned");
400            for server in data.registered_by_url.values() {
401                request.mcp_servers.push(server.acp_mcp_server());
402            }
403        }
404
405        // Forward it to the successor.
406        request_cx
407            .send_request_to_successor(request)
408            .forward_to_request_cx(request_cx.cast())?;
409
410        Ok(Handled::Yes)
411    }
412}
413
414impl JsonRpcHandler for McpServiceRegistry {
415    fn describe_chain(&self) -> impl std::fmt::Debug {
416        "McpServiceRegistry"
417    }
418
419    async fn handle_message(
420        &mut self,
421        message: sacp::MessageAndCx,
422    ) -> Result<sacp::Handled<sacp::MessageAndCx>, agent_client_protocol::Error> {
423        match message {
424            sacp::MessageAndCx::Request(msg, mut cx) => {
425                let params = msg.params();
426
427                if let Some(result) =
428                    <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
429                {
430                    cx = match self.handle_connect_request(result, cx).await? {
431                        Handled::Yes => return Ok(Handled::Yes),
432                        Handled::No(cx) => cx,
433                    };
434                }
435
436                if let Some(result) =
437                    <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
438                        cx.method(),
439                        params,
440                    )
441                {
442                    cx = match self.handle_mcp_over_acp_request(result, cx).await? {
443                        Handled::Yes => return Ok(Handled::Yes),
444                        Handled::No(cx) => cx,
445                    };
446                }
447
448                if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
449                    cx = match self.handle_new_session_request(result, cx).await? {
450                        Handled::Yes => return Ok(Handled::Yes),
451                        Handled::No(cx) => cx,
452                    };
453                }
454
455                Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
456            }
457            sacp::MessageAndCx::Notification(msg, mut cx) => {
458                let params = msg.params();
459
460                if let Some(result) =
461                    <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
462                        msg.method(),
463                        params,
464                    )
465                {
466                    cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
467                        Handled::Yes => return Ok(Handled::Yes),
468                        Handled::No(cx) => cx,
469                    };
470                }
471
472                if let Some(result) =
473                    <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
474                        msg.method(),
475                        params,
476                    )
477                {
478                    cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
479                        Handled::Yes => return Ok(Handled::Yes),
480                        Handled::No(cx) => cx,
481                    };
482                }
483
484                Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
485            }
486        }
487    }
488}
489
490#[derive(Clone)]
491struct RegisteredMcpServer {
492    name: String,
493    url: String,
494    spawn: Arc<dyn DynSpawnMcpServer>,
495}
496
497impl RegisteredMcpServer {
498    fn acp_mcp_server(&self) -> acp::McpServer {
499        acp::McpServer::Http {
500            name: self.name.clone(),
501            url: self.url.clone(),
502            headers: vec![],
503        }
504    }
505}
506
507impl std::fmt::Debug for RegisteredMcpServer {
508    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509        f.debug_struct("RegisteredMcpServer")
510            .field("name", &self.name)
511            .field("url", &self.url)
512            .finish()
513    }
514}
515
516trait DynSpawnMcpServer: 'static + Send + Sync {
517    fn spawn(
518        &self,
519        outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
520        incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
521    ) -> BoxFuture<'static, Result<(), acp::Error>>;
522}
523
524impl AsRef<McpServiceRegistry> for McpServiceRegistry {
525    fn as_ref(&self) -> &McpServiceRegistry {
526        self
527    }
528}