sacp_proxy/
mcp_server.rs

1use futures::channel::mpsc;
2use futures::{FutureExt, future::BoxFuture};
3use futures::{SinkExt, StreamExt};
4use fxhash::FxHashMap;
5use rmcp::ServiceExt;
6use sacp::NewSessionRequest;
7use sacp::{
8    Handled, JrConnection, JrConnectionCx, JrHandler, JrMessage, JrRequestCx, MessageAndCx,
9    UntypedMessage,
10};
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use crate::{
16    JrCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
17    McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
18};
19
20/// Manages MCP services offered to successor proxies and agents.
21///
22/// Use the [`Self::add_rmcp_server`] method to register MCP servers implemented using the [`rmcp`] crate.
23///
24/// This struct is a handle to the underlying registry. Cloning the struct produces a second handle to the same registry.
25///
26/// # Handling requests
27///
28/// You must add the registery (or a clone of it) to the [`JrConnection`] so that it can intercept MCP requests.
29/// Typically you do this by providing it as an argument to the [`]
30#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32    data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39    connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Add the MCP server to the registry and return `self`. Useful for chaining.
48    /// Equivalent to [`Self::add_rmcp_server`].
49    ///
50    /// # Parameters
51    ///
52    /// - `name`: The name of the server.
53    /// - `make_service`: A function that creates the service (e.g., `YourService::new`).
54    pub fn with_rmcp_server<S>(
55        self,
56        name: impl ToString,
57        make_service: impl Fn() -> S + 'static + Send + Sync,
58    ) -> Result<Self, sacp::Error>
59    where
60        S: rmcp::Service<rmcp::RoleServer>,
61    {
62        self.add_rmcp_server(name, make_service)?;
63        Ok(self)
64    }
65
66    /// Add an MCP server implemented using the rmcp crate.
67    ///
68    /// # Parameters
69    ///
70    /// - `name`: The name of the server.
71    /// - `make_service`: A function that creates the service (e.g., `YourService::new`).
72    pub fn add_rmcp_server<S>(
73        &self,
74        name: impl ToString,
75        make_service: impl Fn() -> S + 'static + Send + Sync,
76    ) -> Result<(), sacp::Error>
77    where
78        S: rmcp::Service<rmcp::RoleServer>,
79    {
80        struct SpawnRmcpService<F> {
81            make_service: F,
82        }
83
84        impl<F, S> DynSpawnMcpServer for SpawnRmcpService<F>
85        where
86            F: Fn() -> S + Send + Sync + 'static,
87            S: rmcp::Service<rmcp::RoleServer>,
88        {
89            fn spawn(
90                &self,
91                outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
92                incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
93            ) -> BoxFuture<'static, Result<(), sacp::Error>> {
94                let server = (self.make_service)();
95                async move {
96                    let running_server = server
97                        .serve((incoming_bytes, outgoing_bytes))
98                        .await
99                        .map_err(sacp::Error::into_internal_error)?;
100
101                    // Keep the server alive by waiting for it to finish
102                    running_server
103                        .waiting()
104                        .await
105                        .map(|_quit_reason| ())
106                        .map_err(sacp::Error::into_internal_error)
107                }
108                .boxed()
109            }
110        }
111
112        let name = name.to_string();
113        self.add_mcp_service(name, Arc::new(SpawnRmcpService { make_service }))
114    }
115
116    /// Internal helper for adding services, independent of how they are implemented
117    fn add_mcp_service(
118        &self,
119        name: String,
120        spawn: Arc<dyn DynSpawnMcpServer>,
121    ) -> Result<(), sacp::Error> {
122        let name = name.to_string();
123        if let Some(_) = self.get_registered_server_by_name(&name) {
124            return Err(sacp::util::internal_error(format!(
125                "Server with name '{}' already exists",
126                name
127            )));
128        }
129
130        let uuid = uuid::Uuid::new_v4().to_string();
131        let service = Arc::new(RegisteredMcpServer {
132            name,
133            url: format!("acp:{uuid}"),
134            spawn,
135        });
136        self.insert_registered_server(service);
137        Ok(())
138    }
139
140    fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
141        let mut data = self.data.lock().expect("not poisoned");
142        data.registered_by_name
143            .insert(service.name.clone(), service.clone());
144        data.registered_by_url
145            .insert(service.url.clone(), service.clone());
146    }
147
148    fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
149        self.data
150            .lock()
151            .expect("not poisoned")
152            .registered_by_name
153            .get(name)
154            .cloned()
155    }
156
157    fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
158        self.data
159            .lock()
160            .expect("not poisoned")
161            .registered_by_url
162            .get(url)
163            .cloned()
164    }
165
166    fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
167        self.data
168            .lock()
169            .expect("not poisoned")
170            .connections
171            .insert(connection_id.to_string(), tx);
172    }
173
174    fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
175        self.data
176            .lock()
177            .expect("not poisoned")
178            .connections
179            .get(connection_id)
180            .cloned()
181    }
182
183    fn remove_connection(&self, connection_id: &str) -> bool {
184        self.data
185            .lock()
186            .expect("not poisoned")
187            .connections
188            .remove(connection_id)
189            .is_some()
190    }
191
192    async fn handle_connect_request(
193        &self,
194        result: Result<SuccessorRequest<McpConnectRequest>, sacp::Error>,
195        request_cx: JrRequestCx<serde_json::Value>,
196    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
197        // Check if we parsed this message successfully.
198        let SuccessorRequest { request } = match result {
199            Ok(request) => request,
200            Err(err) => {
201                request_cx.respond_with_error(err)?;
202                return Ok(Handled::Yes);
203            }
204        };
205
206        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
207        let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
208            return Ok(Handled::No(request_cx));
209        };
210
211        let request_cx = request_cx.cast::<McpConnectResponse>();
212
213        // Create a unique connection ID and a channel for future communication
214        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
215        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
216        self.insert_connection(&connection_id, mcp_server_tx);
217
218        // Generate streams
219        let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
220        let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
221        let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
222
223        // Create JrConnection for communicating with the server.
224        //
225        // Every request/notification that the server sends up, we will package up
226        // as an McpOverAcpRequest/McpOverAcpNotification and send to our agent.
227        //
228        // Every request/notification that is sent over `mcp_server_tx` we will
229        // send to the MCP server.
230        let spawn_results = request_cx
231            .spawn(
232                JrConnection::new(mcp_client_write.compat_write(), mcp_client_read.compat())
233                    .on_receive_message({
234                        let connection_id = connection_id.clone();
235                        let outer_cx = request_cx.connection_cx();
236                        async move |message: sacp::MessageAndCx| {
237                            // Wrap the message in McpOverAcp{Request,Notification} and forward to successor
238                            let wrapped = message.map(
239                                |request, request_cx| {
240                                    (
241                                        McpOverAcpRequest {
242                                            connection_id: connection_id.clone(),
243                                            request,
244                                        },
245                                        request_cx,
246                                    )
247                                },
248                                |notification, cx| {
249                                    (
250                                        McpOverAcpNotification {
251                                            connection_id: connection_id.clone(),
252                                            notification,
253                                        },
254                                        cx,
255                                    )
256                                },
257                            );
258                            outer_cx.send_proxied_message(wrapped)
259                        }
260                    })
261                    .with_client({
262                        async move |mcp_cx| {
263                            while let Some(msg) = mcp_server_rx.next().await {
264                                mcp_cx.send_proxied_message(msg)?;
265                            }
266                            Ok(())
267                        }
268                    }),
269            )
270            .and_then(|()| {
271                // Spawn MCP server task
272                request_cx.spawn(async move {
273                    registered_server
274                        .spawn
275                        .spawn(Box::pin(mcp_server_write), Box::pin(mcp_server_read))
276                        .await
277                })
278            });
279
280        match spawn_results {
281            Ok(()) => {
282                request_cx.respond(McpConnectResponse { connection_id })?;
283                Ok(Handled::Yes)
284            }
285
286            Err(err) => {
287                request_cx.respond_with_error(err)?;
288                Ok(Handled::Yes)
289            }
290        }
291    }
292
293    async fn handle_mcp_over_acp_request(
294        &self,
295        result: Result<SuccessorRequest<McpOverAcpRequest<UntypedMessage>>, sacp::Error>,
296        request_cx: JrRequestCx<serde_json::Value>,
297    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
298        // Check if we parsed this message successfully.
299        let SuccessorRequest { request } = match result {
300            Ok(request) => request,
301            Err(err) => {
302                request_cx.respond_with_error(err)?;
303                return Ok(Handled::Yes);
304            }
305        };
306
307        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
308        let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
309            return Ok(Handled::No(request_cx));
310        };
311
312        mcp_server_tx
313            .send(MessageAndCx::Request(request.request, request_cx))
314            .await
315            .map_err(sacp::Error::into_internal_error)?;
316
317        Ok(Handled::Yes)
318    }
319
320    async fn handle_mcp_over_acp_notification(
321        &self,
322        result: Result<SuccessorNotification<McpOverAcpNotification<UntypedMessage>>, sacp::Error>,
323        notification_cx: JrConnectionCx,
324    ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
325        // Check if we parsed this message successfully.
326        let SuccessorNotification { notification } = match result {
327            Ok(request) => request,
328            Err(err) => {
329                notification_cx.send_error_notification(err)?;
330                return Ok(Handled::Yes);
331            }
332        };
333
334        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
335        let Some(mut mcp_server_tx) = self.get_connection(&notification.connection_id) else {
336            return Ok(Handled::No(notification_cx));
337        };
338
339        mcp_server_tx
340            .send(MessageAndCx::Notification(
341                notification.notification,
342                notification_cx.clone(),
343            ))
344            .await
345            .map_err(sacp::Error::into_internal_error)?;
346
347        Ok(Handled::Yes)
348    }
349
350    async fn handle_mcp_disconnect_notification(
351        &self,
352        result: Result<SuccessorNotification<McpDisconnectNotification>, sacp::Error>,
353        notification_cx: JrConnectionCx,
354    ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
355        // Check if we parsed this message successfully.
356        let SuccessorNotification { notification } = match result {
357            Ok(request) => request,
358            Err(err) => {
359                notification_cx.send_error_notification(err)?;
360                return Ok(Handled::Yes);
361            }
362        };
363
364        // Remove connection if we have it. Otherwise, do not handle the notification.
365        if self.remove_connection(&notification.connection_id) {
366            Ok(Handled::Yes)
367        } else {
368            Ok(Handled::No(notification_cx))
369        }
370    }
371
372    async fn handle_new_session_request(
373        &self,
374        result: Result<NewSessionRequest, sacp::Error>,
375        request_cx: JrRequestCx<serde_json::Value>,
376    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
377        // Check if we parsed this message successfully.
378        let mut request = match result {
379            Ok(request) => request,
380            Err(err) => {
381                request_cx.send_error_notification(err)?;
382                return Ok(Handled::Yes);
383            }
384        };
385
386        // Add the MCP servers into the session/new request.
387        //
388        // Q: Do we care if there are already servers with that name?
389        {
390            let data = self.data.lock().expect("not poisoned");
391            for server in data.registered_by_url.values() {
392                request.mcp_servers.push(server.acp_mcp_server());
393            }
394        }
395
396        // Forward it to the successor.
397        request_cx
398            .send_request_to_successor(request)
399            .forward_to_request_cx(request_cx.cast())?;
400
401        Ok(Handled::Yes)
402    }
403}
404
405impl JrHandler for McpServiceRegistry {
406    fn describe_chain(&self) -> impl std::fmt::Debug {
407        "McpServiceRegistry"
408    }
409
410    async fn handle_message(
411        &mut self,
412        message: sacp::MessageAndCx,
413    ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
414        match message {
415            sacp::MessageAndCx::Request(msg, mut cx) => {
416                let params = msg.params();
417
418                if let Some(result) =
419                    <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
420                {
421                    cx = match self.handle_connect_request(result, cx).await? {
422                        Handled::Yes => return Ok(Handled::Yes),
423                        Handled::No(cx) => cx,
424                    };
425                }
426
427                if let Some(result) =
428                    <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
429                        cx.method(),
430                        params,
431                    )
432                {
433                    cx = match self.handle_mcp_over_acp_request(result, cx).await? {
434                        Handled::Yes => return Ok(Handled::Yes),
435                        Handled::No(cx) => cx,
436                    };
437                }
438
439                if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
440                    cx = match self.handle_new_session_request(result, cx).await? {
441                        Handled::Yes => return Ok(Handled::Yes),
442                        Handled::No(cx) => cx,
443                    };
444                }
445
446                Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
447            }
448            sacp::MessageAndCx::Notification(msg, mut cx) => {
449                let params = msg.params();
450
451                if let Some(result) =
452                    <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
453                        msg.method(),
454                        params,
455                    )
456                {
457                    cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
458                        Handled::Yes => return Ok(Handled::Yes),
459                        Handled::No(cx) => cx,
460                    };
461                }
462
463                if let Some(result) =
464                    <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
465                        msg.method(),
466                        params,
467                    )
468                {
469                    cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
470                        Handled::Yes => return Ok(Handled::Yes),
471                        Handled::No(cx) => cx,
472                    };
473                }
474
475                Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
476            }
477        }
478    }
479}
480
481#[derive(Clone)]
482struct RegisteredMcpServer {
483    name: String,
484    url: String,
485    spawn: Arc<dyn DynSpawnMcpServer>,
486}
487
488impl RegisteredMcpServer {
489    fn acp_mcp_server(&self) -> sacp::McpServer {
490        sacp::McpServer::Http {
491            name: self.name.clone(),
492            url: self.url.clone(),
493            headers: vec![],
494        }
495    }
496}
497
498impl std::fmt::Debug for RegisteredMcpServer {
499    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500        f.debug_struct("RegisteredMcpServer")
501            .field("name", &self.name)
502            .field("url", &self.url)
503            .finish()
504    }
505}
506
507trait DynSpawnMcpServer: 'static + Send + Sync {
508    fn spawn(
509        &self,
510        outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
511        incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
512    ) -> BoxFuture<'static, Result<(), sacp::Error>>;
513}
514
515impl AsRef<McpServiceRegistry> for McpServiceRegistry {
516    fn as_ref(&self) -> &McpServiceRegistry {
517        self
518    }
519}