sacp_proxy/
mcp_server.rs

1use futures::channel::mpsc;
2use futures::{FutureExt, future::BoxFuture};
3use futures::{SinkExt, StreamExt};
4use fxhash::FxHashMap;
5use rmcp::ServiceExt;
6use sacp::schema::NewSessionRequest;
7use sacp::{
8    Handled, JrConnection, JrConnectionCx, JrHandler, JrMessage, JrRequestCx, MessageAndCx,
9    UntypedMessage,
10};
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use crate::{
16    JrCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
17    McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
18};
19
20/// Manages MCP services offered to successor proxies and agents.
21///
22/// Use the [`Self::add_rmcp_server`] method to register MCP servers implemented using the [`rmcp`] crate.
23///
24/// This struct is a handle to the underlying registry. Cloning the struct produces a second handle to the same registry.
25///
26/// # Handling requests
27///
28/// You must add the registery (or a clone of it) to the [`JrConnection`] so that it can intercept MCP requests.
29/// Typically you do this by providing it as an argument to the [`]
30#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32    data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39    connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43    /// Creates a new empty MCP service registry
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    /// Add the MCP server to the registry and return `self`. Useful for chaining.
49    /// Equivalent to [`Self::add_rmcp_server`].
50    ///
51    /// # Parameters
52    ///
53    /// - `name`: The name of the server.
54    /// - `make_service`: A function that creates the service (e.g., `YourService::new`).
55    pub fn with_rmcp_server<S>(
56        self,
57        name: impl ToString,
58        make_service: impl Fn() -> S + 'static + Send + Sync,
59    ) -> Result<Self, sacp::Error>
60    where
61        S: rmcp::Service<rmcp::RoleServer>,
62    {
63        self.add_rmcp_server(name, make_service)?;
64        Ok(self)
65    }
66
67    /// Add an MCP server implemented using the rmcp crate.
68    ///
69    /// # Parameters
70    ///
71    /// - `name`: The name of the server.
72    /// - `make_service`: A function that creates the service (e.g., `YourService::new`).
73    pub fn add_rmcp_server<S>(
74        &self,
75        name: impl ToString,
76        make_service: impl Fn() -> S + 'static + Send + Sync,
77    ) -> Result<(), sacp::Error>
78    where
79        S: rmcp::Service<rmcp::RoleServer>,
80    {
81        struct SpawnRmcpService<F> {
82            make_service: F,
83        }
84
85        impl<F, S> DynSpawnMcpServer for SpawnRmcpService<F>
86        where
87            F: Fn() -> S + Send + Sync + 'static,
88            S: rmcp::Service<rmcp::RoleServer>,
89        {
90            fn spawn(
91                &self,
92                outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
93                incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
94            ) -> BoxFuture<'static, Result<(), sacp::Error>> {
95                let server = (self.make_service)();
96                async move {
97                    let running_server = server
98                        .serve((incoming_bytes, outgoing_bytes))
99                        .await
100                        .map_err(sacp::Error::into_internal_error)?;
101
102                    // Keep the server alive by waiting for it to finish
103                    running_server
104                        .waiting()
105                        .await
106                        .map(|_quit_reason| ())
107                        .map_err(sacp::Error::into_internal_error)
108                }
109                .boxed()
110            }
111        }
112
113        let name = name.to_string();
114        self.add_mcp_service(name, Arc::new(SpawnRmcpService { make_service }))
115    }
116
117    /// Internal helper for adding services, independent of how they are implemented
118    fn add_mcp_service(
119        &self,
120        name: String,
121        spawn: Arc<dyn DynSpawnMcpServer>,
122    ) -> Result<(), sacp::Error> {
123        let name = name.to_string();
124        if let Some(_) = self.get_registered_server_by_name(&name) {
125            return Err(sacp::util::internal_error(format!(
126                "Server with name '{}' already exists",
127                name
128            )));
129        }
130
131        let uuid = uuid::Uuid::new_v4().to_string();
132        let service = Arc::new(RegisteredMcpServer {
133            name,
134            url: format!("acp:{uuid}"),
135            spawn,
136        });
137        self.insert_registered_server(service);
138        Ok(())
139    }
140
141    fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
142        let mut data = self.data.lock().expect("not poisoned");
143        data.registered_by_name
144            .insert(service.name.clone(), service.clone());
145        data.registered_by_url
146            .insert(service.url.clone(), service.clone());
147    }
148
149    fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
150        self.data
151            .lock()
152            .expect("not poisoned")
153            .registered_by_name
154            .get(name)
155            .cloned()
156    }
157
158    fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
159        self.data
160            .lock()
161            .expect("not poisoned")
162            .registered_by_url
163            .get(url)
164            .cloned()
165    }
166
167    fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
168        self.data
169            .lock()
170            .expect("not poisoned")
171            .connections
172            .insert(connection_id.to_string(), tx);
173    }
174
175    fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
176        self.data
177            .lock()
178            .expect("not poisoned")
179            .connections
180            .get(connection_id)
181            .cloned()
182    }
183
184    fn remove_connection(&self, connection_id: &str) -> bool {
185        self.data
186            .lock()
187            .expect("not poisoned")
188            .connections
189            .remove(connection_id)
190            .is_some()
191    }
192
193    async fn handle_connect_request(
194        &self,
195        result: Result<SuccessorRequest<McpConnectRequest>, sacp::Error>,
196        request_cx: JrRequestCx<serde_json::Value>,
197    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
198        // Check if we parsed this message successfully.
199        let SuccessorRequest { request } = match result {
200            Ok(request) => request,
201            Err(err) => {
202                request_cx.respond_with_error(err)?;
203                return Ok(Handled::Yes);
204            }
205        };
206
207        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
208        let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
209            return Ok(Handled::No(request_cx));
210        };
211
212        let request_cx = request_cx.cast::<McpConnectResponse>();
213
214        // Create a unique connection ID and a channel for future communication
215        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
216        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
217        self.insert_connection(&connection_id, mcp_server_tx);
218
219        // Generate streams
220        let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
221        let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
222        let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
223
224        // Create JrConnection for communicating with the server.
225        //
226        // Every request/notification that the server sends up, we will package up
227        // as an McpOverAcpRequest/McpOverAcpNotification and send to our agent.
228        //
229        // Every request/notification that is sent over `mcp_server_tx` we will
230        // send to the MCP server.
231        let spawn_results = request_cx
232            .spawn(
233                JrConnection::new(mcp_client_write.compat_write(), mcp_client_read.compat())
234                    .on_receive_message({
235                        let connection_id = connection_id.clone();
236                        let outer_cx = request_cx.connection_cx();
237                        async move |message: sacp::MessageAndCx| {
238                            // Wrap the message in McpOverAcp{Request,Notification} and forward to successor
239                            let wrapped = message.map(
240                                |request, request_cx| {
241                                    (
242                                        McpOverAcpRequest {
243                                            connection_id: connection_id.clone(),
244                                            request,
245                                        },
246                                        request_cx,
247                                    )
248                                },
249                                |notification, cx| {
250                                    (
251                                        McpOverAcpNotification {
252                                            connection_id: connection_id.clone(),
253                                            notification,
254                                        },
255                                        cx,
256                                    )
257                                },
258                            );
259                            outer_cx.send_proxied_message(wrapped)
260                        }
261                    })
262                    .with_client({
263                        async move |mcp_cx| {
264                            while let Some(msg) = mcp_server_rx.next().await {
265                                mcp_cx.send_proxied_message(msg)?;
266                            }
267                            Ok(())
268                        }
269                    }),
270            )
271            .and_then(|()| {
272                // Spawn MCP server task
273                request_cx.spawn(async move {
274                    registered_server
275                        .spawn
276                        .spawn(Box::pin(mcp_server_write), Box::pin(mcp_server_read))
277                        .await
278                })
279            });
280
281        match spawn_results {
282            Ok(()) => {
283                request_cx.respond(McpConnectResponse { connection_id })?;
284                Ok(Handled::Yes)
285            }
286
287            Err(err) => {
288                request_cx.respond_with_error(err)?;
289                Ok(Handled::Yes)
290            }
291        }
292    }
293
294    async fn handle_mcp_over_acp_request(
295        &self,
296        result: Result<SuccessorRequest<McpOverAcpRequest<UntypedMessage>>, sacp::Error>,
297        request_cx: JrRequestCx<serde_json::Value>,
298    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
299        // Check if we parsed this message successfully.
300        let SuccessorRequest { request } = match result {
301            Ok(request) => request,
302            Err(err) => {
303                request_cx.respond_with_error(err)?;
304                return Ok(Handled::Yes);
305            }
306        };
307
308        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
309        let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
310            return Ok(Handled::No(request_cx));
311        };
312
313        mcp_server_tx
314            .send(MessageAndCx::Request(request.request, request_cx))
315            .await
316            .map_err(sacp::Error::into_internal_error)?;
317
318        Ok(Handled::Yes)
319    }
320
321    async fn handle_mcp_over_acp_notification(
322        &self,
323        result: Result<SuccessorNotification<McpOverAcpNotification<UntypedMessage>>, sacp::Error>,
324        notification_cx: JrConnectionCx,
325    ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
326        // Check if we parsed this message successfully.
327        let SuccessorNotification { notification } = match result {
328            Ok(request) => request,
329            Err(err) => {
330                notification_cx.send_error_notification(err)?;
331                return Ok(Handled::Yes);
332            }
333        };
334
335        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
336        let Some(mut mcp_server_tx) = self.get_connection(&notification.connection_id) else {
337            return Ok(Handled::No(notification_cx));
338        };
339
340        mcp_server_tx
341            .send(MessageAndCx::Notification(
342                notification.notification,
343                notification_cx.clone(),
344            ))
345            .await
346            .map_err(sacp::Error::into_internal_error)?;
347
348        Ok(Handled::Yes)
349    }
350
351    async fn handle_mcp_disconnect_notification(
352        &self,
353        result: Result<SuccessorNotification<McpDisconnectNotification>, sacp::Error>,
354        notification_cx: JrConnectionCx,
355    ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
356        // Check if we parsed this message successfully.
357        let SuccessorNotification { notification } = match result {
358            Ok(request) => request,
359            Err(err) => {
360                notification_cx.send_error_notification(err)?;
361                return Ok(Handled::Yes);
362            }
363        };
364
365        // Remove connection if we have it. Otherwise, do not handle the notification.
366        if self.remove_connection(&notification.connection_id) {
367            Ok(Handled::Yes)
368        } else {
369            Ok(Handled::No(notification_cx))
370        }
371    }
372
373    async fn handle_new_session_request(
374        &self,
375        result: Result<NewSessionRequest, sacp::Error>,
376        request_cx: JrRequestCx<serde_json::Value>,
377    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
378        // Check if we parsed this message successfully.
379        let mut request = match result {
380            Ok(request) => request,
381            Err(err) => {
382                request_cx.send_error_notification(err)?;
383                return Ok(Handled::Yes);
384            }
385        };
386
387        // Add the MCP servers into the session/new request.
388        //
389        // Q: Do we care if there are already servers with that name?
390        {
391            let data = self.data.lock().expect("not poisoned");
392            for server in data.registered_by_url.values() {
393                request.mcp_servers.push(server.acp_mcp_server());
394            }
395        }
396
397        // Forward it to the successor.
398        request_cx
399            .send_request_to_successor(request)
400            .forward_to_request_cx(request_cx.cast())?;
401
402        Ok(Handled::Yes)
403    }
404}
405
406impl JrHandler for McpServiceRegistry {
407    fn describe_chain(&self) -> impl std::fmt::Debug {
408        "McpServiceRegistry"
409    }
410
411    async fn handle_message(
412        &mut self,
413        message: sacp::MessageAndCx,
414    ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
415        match message {
416            sacp::MessageAndCx::Request(msg, mut cx) => {
417                let params = msg.params();
418
419                if let Some(result) =
420                    <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
421                {
422                    cx = match self.handle_connect_request(result, cx).await? {
423                        Handled::Yes => return Ok(Handled::Yes),
424                        Handled::No(cx) => cx,
425                    };
426                }
427
428                if let Some(result) =
429                    <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
430                        cx.method(),
431                        params,
432                    )
433                {
434                    cx = match self.handle_mcp_over_acp_request(result, cx).await? {
435                        Handled::Yes => return Ok(Handled::Yes),
436                        Handled::No(cx) => cx,
437                    };
438                }
439
440                if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
441                    cx = match self.handle_new_session_request(result, cx).await? {
442                        Handled::Yes => return Ok(Handled::Yes),
443                        Handled::No(cx) => cx,
444                    };
445                }
446
447                Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
448            }
449            sacp::MessageAndCx::Notification(msg, mut cx) => {
450                let params = msg.params();
451
452                if let Some(result) =
453                    <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
454                        msg.method(),
455                        params,
456                    )
457                {
458                    cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
459                        Handled::Yes => return Ok(Handled::Yes),
460                        Handled::No(cx) => cx,
461                    };
462                }
463
464                if let Some(result) =
465                    <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
466                        msg.method(),
467                        params,
468                    )
469                {
470                    cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
471                        Handled::Yes => return Ok(Handled::Yes),
472                        Handled::No(cx) => cx,
473                    };
474                }
475
476                Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
477            }
478        }
479    }
480}
481
482#[derive(Clone)]
483struct RegisteredMcpServer {
484    name: String,
485    url: String,
486    spawn: Arc<dyn DynSpawnMcpServer>,
487}
488
489impl RegisteredMcpServer {
490    fn acp_mcp_server(&self) -> sacp::schema::McpServer {
491        sacp::schema::McpServer::Http {
492            name: self.name.clone(),
493            url: self.url.clone(),
494            headers: vec![],
495        }
496    }
497}
498
499impl std::fmt::Debug for RegisteredMcpServer {
500    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501        f.debug_struct("RegisteredMcpServer")
502            .field("name", &self.name)
503            .field("url", &self.url)
504            .finish()
505    }
506}
507
508trait DynSpawnMcpServer: 'static + Send + Sync {
509    fn spawn(
510        &self,
511        outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
512        incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
513    ) -> BoxFuture<'static, Result<(), sacp::Error>>;
514}
515
516impl AsRef<McpServiceRegistry> for McpServiceRegistry {
517    fn as_ref(&self) -> &McpServiceRegistry {
518        self
519    }
520}