sacp_proxy/
mcp_server_registry.rs

1use futures::channel::mpsc;
2use futures::{SinkExt, StreamExt};
3use fxhash::FxHashMap;
4
5use sacp::schema::{NewSessionRequest, NewSessionResponse, SessionId};
6use sacp::util::MatchMessage;
7use sacp::{
8    Channel, Component, DynComponent, Handled, JrConnectionCx, JrHandlerChain, JrMessageHandler,
9    JrRequestCx, MessageAndCx, UntypedMessage,
10};
11use std::sync::{Arc, Mutex};
12
13use crate::mcp_server_builder::McpServer;
14use crate::{
15    McpConnectRequest, McpConnectResponse, McpDisconnectNotification, McpOverAcpNotification,
16    McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
17};
18
19/// Manages MCP services offered to successor proxies and agents.
20///
21/// Use the [`Self::add_mcp_server`] method to register MCP servers. For rmcp-based servers,
22/// use the `sacp-rmcp` crate which provides convenient extension methods.
23///
24/// This struct is a handle to the underlying registry. Cloning the struct produces a second handle to the same registry.
25///
26/// # Handling requests
27///
28/// You must add the registry (or a clone of it) to the [`JrHandlerChain`] so that it can intercept MCP requests.
29/// Typically you do this by providing it as an argument to the handler chain methods.
30///
31/// [`JrHandlerChain`]: sacp::JrHandlerChain
32#[derive(Clone, Default, Debug)]
33pub struct McpServiceRegistry {
34    data: Arc<Mutex<McpServiceRegistryData>>,
35}
36
37#[derive(Default, Debug)]
38struct McpServiceRegistryData {
39    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
40    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
41    connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
42}
43
44impl McpServiceRegistry {
45    /// Creates a new empty MCP service registry
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Add an [`McpServer`] to the registry.
51    ///
52    /// This server will be added to all new sessions where this registry is in the handler chain.
53    ///
54    /// See the [`McpServer`] documentation for more information.
55    pub fn with_mcp_server(
56        self,
57        name: impl ToString,
58        server: McpServer,
59    ) -> Result<Self, sacp::Error> {
60        self.add_mcp_server_with_context(name, move |mcp_cx| server.new_connection(mcp_cx))?;
61        Ok(self)
62    }
63
64    /// Add an MCP server to the registry using a custom constructor function.
65    ///
66    /// This server will be added to all new sessions where this registry is in the handler chain.
67    ///
68    /// This method is for independent MCP servers that do not make use of ACP.
69    /// You may wish to use the `sacp-rmcp` crate which provides convenient
70    /// extension methods for working with MCP servers implemented using the `rmcp` crate.
71    ///
72    /// # Parameters
73    ///
74    /// - `name`: The name of the server.
75    /// - `new_fn`: Constructor function that creates the MCP server and returns a [`Component`] for connecting to it.
76    pub fn add_mcp_server<C: Component>(
77        &self,
78        name: impl ToString,
79        new_fn: impl Fn() -> C + Send + Sync + 'static,
80    ) -> Result<(), sacp::Error> {
81        struct FnSpawner<F> {
82            new_fn: F,
83        }
84
85        impl<C, F> SpawnMcpServer for FnSpawner<F>
86        where
87            F: Fn() -> C + Send + Sync + 'static,
88            C: Component,
89        {
90            fn spawn(&self, _cx: McpContext) -> DynComponent {
91                let component = (self.new_fn)();
92                DynComponent::new(component)
93            }
94        }
95
96        self.add_mcp_server_internal(name, FnSpawner { new_fn })
97    }
98
99    /// Add an MCP server to the registry that wishes to receive a [`McpContext`] when created.
100    ///
101    /// This server will be added to all new sessions where this registry is in the handler chain.
102    ///
103    /// This method is for MCP servers that require information about the ACP connection and/or
104    /// the ability to make ACP requests.
105    ///
106    /// # Parameters
107    ///
108    /// - `name`: The name of the server.
109    /// - `new_fn`: Constructor function that creates the MCP server and returns a [`Component`] for connecting to it.
110    pub fn add_mcp_server_with_context<C: Component>(
111        &self,
112        name: impl ToString,
113        new_fn: impl Fn(McpContext) -> C + Send + Sync + 'static,
114    ) -> Result<(), sacp::Error> {
115        struct FnSpawner<F> {
116            new_fn: F,
117        }
118
119        impl<C, F> SpawnMcpServer for FnSpawner<F>
120        where
121            F: Fn(McpContext) -> C + Send + Sync + 'static,
122            C: Component,
123        {
124            fn spawn(&self, cx: McpContext) -> DynComponent {
125                let component = (self.new_fn)(cx);
126                DynComponent::new(component)
127            }
128        }
129
130        self.add_mcp_server_internal(name, FnSpawner { new_fn })
131    }
132    fn add_mcp_server_internal(
133        &self,
134        name: impl ToString,
135        spawner: impl SpawnMcpServer,
136    ) -> Result<(), sacp::Error> {
137        let name = name.to_string();
138        if let Some(_) = self.get_registered_server_by_name(&name) {
139            return Err(sacp::util::internal_error(format!(
140                "Server with name '{}' already exists",
141                name
142            )));
143        }
144
145        let uuid = uuid::Uuid::new_v4().to_string();
146        let service = Arc::new(RegisteredMcpServer {
147            name,
148            url: format!("acp:{uuid}"),
149            spawn: Arc::new(spawner),
150        });
151        self.insert_registered_server(service);
152        Ok(())
153    }
154
155    fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
156        let mut data = self.data.lock().expect("not poisoned");
157        data.registered_by_name
158            .insert(service.name.clone(), service.clone());
159        data.registered_by_url
160            .insert(service.url.clone(), service.clone());
161    }
162
163    fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
164        self.data
165            .lock()
166            .expect("not poisoned")
167            .registered_by_name
168            .get(name)
169            .cloned()
170    }
171
172    fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
173        self.data
174            .lock()
175            .expect("not poisoned")
176            .registered_by_url
177            .get(url)
178            .cloned()
179    }
180
181    fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
182        self.data
183            .lock()
184            .expect("not poisoned")
185            .connections
186            .insert(connection_id.to_string(), tx);
187    }
188
189    fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
190        self.data
191            .lock()
192            .expect("not poisoned")
193            .connections
194            .get(connection_id)
195            .cloned()
196    }
197
198    fn remove_connection(&self, connection_id: &str) -> bool {
199        self.data
200            .lock()
201            .expect("not poisoned")
202            .connections
203            .remove(connection_id)
204            .is_some()
205    }
206
207    /// Adds all registered MCP servers to the given `NewSessionRequest`.
208    ///
209    /// This method appends the MCP server configurations for all servers registered
210    /// with this registry to the `mcp_servers` field of the request. This is useful
211    /// when you want to manually populate a request with MCP servers outside of the
212    /// automatic handler chain processing.
213    ///
214    /// # Example
215    ///
216    /// ```rust,ignore
217    /// let registry = McpServiceRegistry::new();
218    /// registry.add_mcp_server("my-server", || MyMcpServer)?;
219    ///
220    /// let mut request = NewSessionRequest {
221
222    /// registry.add_registered_mcp_servers_to(&mut request);
223    /// // request.mcp_servers now contains "my-server"
224    /// ```
225    pub fn add_registered_mcp_servers_to(&self, request: &mut NewSessionRequest) {
226        let data = self.data.lock().expect("not poisoned");
227        for server in data.registered_by_url.values() {
228            request.mcp_servers.push(server.acp_mcp_server());
229        }
230    }
231
232    async fn handle_connect_request(
233        &self,
234        successor_request: SuccessorRequest<McpConnectRequest>,
235        request_cx: JrRequestCx<McpConnectResponse>,
236    ) -> Result<
237        Handled<(
238            SuccessorRequest<McpConnectRequest>,
239            JrRequestCx<McpConnectResponse>,
240        )>,
241        sacp::Error,
242    > {
243        let SuccessorRequest { request, .. } = &successor_request;
244        let outer_cx = request_cx.connection_cx();
245
246        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
247        let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
248            return Ok(Handled::No((successor_request, request_cx)));
249        };
250
251        // Create a unique connection ID and a channel for future communication
252        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
253        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
254        self.insert_connection(&connection_id, mcp_server_tx);
255
256        // Create connected channel pair for client-server communication
257        let (client_channel, server_channel) = Channel::duplex();
258
259        // Create client-side handler that wraps messages and forwards to successor
260        let client_component = {
261            let connection_id = connection_id.clone();
262            let outer_cx = outer_cx.clone();
263
264            JrHandlerChain::new()
265                .on_receive_message(async move |message: sacp::MessageAndCx| {
266                    // Wrap the message in McpOverAcp{Request,Notification} and forward to successor
267                    let wrapped = message.map(
268                        |request, request_cx| {
269                            (
270                                McpOverAcpRequest {
271                                    connection_id: connection_id.clone(),
272                                    request,
273                                    meta: None,
274                                },
275                                request_cx,
276                            )
277                        },
278                        |notification, cx| {
279                            (
280                                McpOverAcpNotification {
281                                    connection_id: connection_id.clone(),
282                                    notification,
283                                    meta: None,
284                                },
285                                cx,
286                            )
287                        },
288                    );
289                    outer_cx.send_proxied_message(wrapped)
290                })
291                .with_spawned(move |mcp_cx| async move {
292                    while let Some(msg) = mcp_server_rx.next().await {
293                        mcp_cx.send_proxied_message(msg)?;
294                    }
295                    Ok(())
296                })
297        };
298
299        // Get the MCP server component
300        let mcp_server = registered_server.spawn.spawn(McpContext {
301            session_id: request.session_id.clone(),
302            connection_cx: outer_cx.clone(),
303        });
304
305        // Spawn both sides of the connection
306        let spawn_results = outer_cx
307            .spawn(async move { client_component.serve(client_channel).await })
308            .and_then(|()| {
309                // Spawn the MCP server serving the server channel
310                request_cx
311                    .connection_cx()
312                    .spawn(async move { mcp_server.serve(server_channel).await })
313            });
314
315        match spawn_results {
316            Ok(()) => {
317                request_cx.respond(McpConnectResponse {
318                    connection_id,
319                    meta: None,
320                })?;
321                Ok(Handled::Yes)
322            }
323
324            Err(err) => {
325                request_cx.respond_with_error(err)?;
326                Ok(Handled::Yes)
327            }
328        }
329    }
330
331    async fn handle_mcp_over_acp_request(
332        &self,
333        successor_request: SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
334        request_cx: JrRequestCx<serde_json::Value>,
335    ) -> Result<
336        Handled<(
337            SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
338            JrRequestCx<serde_json::Value>,
339        )>,
340        sacp::Error,
341    > {
342        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
343        let Some(mut mcp_server_tx) = self.get_connection(&successor_request.request.connection_id)
344        else {
345            return Ok(Handled::No((successor_request, request_cx)));
346        };
347
348        let SuccessorRequest { request, .. } = successor_request;
349
350        mcp_server_tx
351            .send(MessageAndCx::Request(request.request, request_cx))
352            .await
353            .map_err(sacp::Error::into_internal_error)?;
354
355        Ok(Handled::Yes)
356    }
357
358    async fn handle_mcp_over_acp_notification(
359        &self,
360        successor_notification: SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
361        notification_cx: JrConnectionCx,
362    ) -> Result<
363        Handled<(
364            SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
365            JrConnectionCx,
366        )>,
367        sacp::Error,
368    > {
369        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
370        let Some(mut mcp_server_tx) =
371            self.get_connection(&successor_notification.notification.connection_id)
372        else {
373            return Ok(Handled::No((successor_notification, notification_cx)));
374        };
375
376        let SuccessorNotification { notification, .. } = successor_notification;
377
378        mcp_server_tx
379            .send(MessageAndCx::Notification(
380                notification.notification,
381                notification_cx.clone(),
382            ))
383            .await
384            .map_err(sacp::Error::into_internal_error)?;
385
386        Ok(Handled::Yes)
387    }
388
389    async fn handle_mcp_disconnect_notification(
390        &self,
391        successor_notification: SuccessorNotification<McpDisconnectNotification>,
392        notification_cx: JrConnectionCx,
393    ) -> Result<
394        Handled<(
395            SuccessorNotification<McpDisconnectNotification>,
396            JrConnectionCx,
397        )>,
398        sacp::Error,
399    > {
400        let SuccessorNotification { notification, .. } = &successor_notification;
401
402        // Remove connection if we have it. Otherwise, do not handle the notification.
403        if self.remove_connection(&notification.connection_id) {
404            Ok(Handled::Yes)
405        } else {
406            Ok(Handled::No((successor_notification, notification_cx)))
407        }
408    }
409
410    async fn handle_new_session_request(
411        &self,
412        mut request: NewSessionRequest,
413        request_cx: JrRequestCx<NewSessionResponse>,
414    ) -> Result<Handled<(NewSessionRequest, JrRequestCx<NewSessionResponse>)>, sacp::Error> {
415        // Add the MCP servers into the session/new request.
416        //
417        // Q: Do we care if there are already servers with that name?
418        self.add_registered_mcp_servers_to(&mut request);
419
420        // Return the modified request so subsequent handlers can see the MCP servers we added.
421        Ok(Handled::No((request, request_cx)))
422    }
423}
424
425impl JrMessageHandler for McpServiceRegistry {
426    fn describe_chain(&self) -> impl std::fmt::Debug {
427        "McpServiceRegistry"
428    }
429
430    async fn handle_message(
431        &mut self,
432        message: sacp::MessageAndCx,
433    ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
434        MatchMessage::new(message)
435            .if_request(|request, request_cx| self.handle_connect_request(request, request_cx))
436            .await
437            .if_request(|request, request_cx| self.handle_mcp_over_acp_request(request, request_cx))
438            .await
439            .if_request(|request, request_cx| self.handle_new_session_request(request, request_cx))
440            .await
441            .if_notification(|notification, notification_cx| {
442                self.handle_mcp_over_acp_notification(notification, notification_cx)
443            })
444            .await
445            .if_notification(|notification, notification_cx| {
446                self.handle_mcp_disconnect_notification(notification, notification_cx)
447            })
448            .await
449            .done()
450    }
451}
452
453/// A "registeed" MCP server can be launched when a connection is established.
454#[derive(Clone)]
455struct RegisteredMcpServer {
456    name: String,
457    url: String,
458    spawn: Arc<dyn SpawnMcpServer>,
459}
460
461impl RegisteredMcpServer {
462    fn acp_mcp_server(&self) -> sacp::schema::McpServer {
463        sacp::schema::McpServer::Http {
464            name: self.name.clone(),
465            url: self.url.clone(),
466            headers: vec![],
467        }
468    }
469}
470
471impl std::fmt::Debug for RegisteredMcpServer {
472    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473        f.debug_struct("RegisteredMcpServer")
474            .field("name", &self.name)
475            .field("url", &self.url)
476            .finish()
477    }
478}
479
480/// Trait for spawning MCP server components.
481///
482/// This trait allows creating MCP server instances that implement the `Component` trait.
483trait SpawnMcpServer: Send + Sync + 'static {
484    /// Create a new MCP server component.
485    ///
486    /// Returns a `DynComponent` that can be used with the Component API.
487    fn spawn(&self, cx: McpContext) -> sacp::DynComponent;
488}
489
490impl AsRef<McpServiceRegistry> for McpServiceRegistry {
491    fn as_ref(&self) -> &McpServiceRegistry {
492        self
493    }
494}
495
496/// Context about the ACP and MCP connection available to an MCP server.
497#[derive(Clone)]
498pub struct McpContext {
499    session_id: SessionId,
500    connection_cx: JrConnectionCx,
501}
502
503impl McpContext {
504    /// The session-id of the session connected to this MCP server.
505    pub fn session_id(&self) -> &SessionId {
506        &self.session_id
507    }
508
509    /// The ACP connection context, which can be used to send ACP requests and notifications
510    /// to your successor.
511    pub fn connection_cx(&self) -> JrConnectionCx {
512        self.connection_cx.clone()
513    }
514}