sacp/mcp_server/registry.rs

1//! MCP service registry for managing MCP servers.
2
3use futures::channel::mpsc;
4use futures::{SinkExt, StreamExt};
5use fxhash::FxHashMap;
6
7use crate::mcp::{McpClientToServer, McpServerEnd};
8use crate::schema::{
9    McpConnectRequest, McpConnectResponse, McpDisconnectNotification, McpOverAcpMessage,
10    NewSessionRequest, NewSessionResponse,
11};
12use crate::util::MatchMessage;
13use crate::{
14    Agent, Channel, Component, DynComponent, Handled, HasEndpoint, JrConnectionCx,
15    JrMessageHandlerSend, JrRequestCx, JrRole, MessageCx, UntypedMessage,
16};
17use std::sync::{Arc, Mutex};
18
19use super::server::McpServer;
20
21/// Manages MCP services offered to successor proxies and agents.
22///
23/// Use the [`Self::with_mcp_server`] method to register MCP servers.
24///
25/// This struct is a handle to the underlying registry. Cloning the struct produces a second handle to the same registry.
26///
27/// # Handling requests
28///
29/// You must add the registry (or a clone of it) to the [`JrConnectionBuilder`] so that it can intercept MCP requests.
30/// Typically you do this by providing it as an argument to the handler chain methods.
31///
32/// [`JrConnectionBuilder`]: crate::JrConnectionBuilder
33#[derive(Clone, Default, Debug)]
34pub struct McpServiceRegistry<Role: JrRole> {
35    data: Arc<Mutex<McpServiceRegistryData<Role>>>,
36}
37
38#[derive(Default, Debug)]
39struct McpServiceRegistryData<Role: JrRole> {
40    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer<Role>>>,
41    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer<Role>>>,
42    connections: FxHashMap<String, mpsc::Sender<MessageCx>>,
43}
44
45impl<Role: JrRole> McpServiceRegistry<Role>
46where
47    Role: HasEndpoint<Agent>,
48{
49    /// Creates a new empty MCP service registry
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Add an [`McpServer`] to the registry.
55    ///
56    /// This server will be added to all new sessions where this registry is in the handler chain.
57    ///
58    /// See the [`McpServer`] documentation for more information.
59    pub fn with_mcp_server(
60        self,
61        name: impl ToString,
62        server: McpServer<Role>,
63    ) -> Result<Self, crate::Error> {
64        self.add_mcp_server_with_context(name, move |mcp_cx| server.new_connection(mcp_cx))?;
65        Ok(self)
66    }
67
68    /// Add an MCP server to the registry using a custom constructor function.
69    ///
70    /// This server will be added to all new sessions where this registry is in the handler chain.
71    ///
72    /// This method is for independent MCP servers that do not make use of ACP.
73    /// You may wish to use the `sacp-rmcp` crate which provides convenient
74    /// extension methods for working with MCP servers implemented using the `rmcp` crate.
75    ///
76    /// # Parameters
77    ///
78    /// - `name`: The name of the server.
79    /// - `new_fn`: Constructor function that creates the MCP server and returns a [`Component`] for connecting to it.
80    pub fn add_mcp_server<C: Component>(
81        &self,
82        name: impl ToString,
83        new_fn: impl Fn() -> C + Send + Sync + 'static,
84    ) -> Result<(), crate::Error> {
85        struct FnSpawner<F> {
86            new_fn: F,
87        }
88
89        impl<Role, C, F> SpawnMcpServer<Role> for FnSpawner<F>
90        where
91            Role: JrRole,
92            F: Fn() -> C + Send + Sync + 'static,
93            C: Component,
94        {
95            fn spawn(&self, _cx: McpContext<Role>) -> DynComponent {
96                let component = (self.new_fn)();
97                DynComponent::new(component)
98            }
99        }
100
101        self.add_mcp_server_internal(name, FnSpawner { new_fn })
102    }
103
104    /// Add an MCP server to the registry that wishes to receive a [`McpContext`] when created.
105    ///
106    /// This server will be added to all new sessions where this registry is in the handler chain.
107    ///
108    /// This method is for MCP servers that require information about the ACP connection and/or
109    /// the ability to make ACP requests.
110    ///
111    /// # Parameters
112    ///
113    /// - `name`: The name of the server.
114    /// - `new_fn`: Constructor function that creates the MCP server and returns a [`Component`] for connecting to it.
115    pub fn add_mcp_server_with_context<C: Component>(
116        &self,
117        name: impl ToString,
118        new_fn: impl Fn(McpContext<Role>) -> C + Send + Sync + 'static,
119    ) -> Result<(), crate::Error> {
120        struct FnSpawner<F> {
121            new_fn: F,
122        }
123
124        impl<Role, C, F> SpawnMcpServer<Role> for FnSpawner<F>
125        where
126            Role: JrRole,
127            F: Fn(McpContext<Role>) -> C + Send + Sync + 'static,
128            C: Component,
129        {
130            fn spawn(&self, cx: McpContext<Role>) -> DynComponent {
131                let component = (self.new_fn)(cx);
132                DynComponent::new(component)
133            }
134        }
135
136        self.add_mcp_server_internal(name, FnSpawner { new_fn })
137    }
138
139    fn add_mcp_server_internal(
140        &self,
141        name: impl ToString,
142        spawner: impl SpawnMcpServer<Role>,
143    ) -> Result<(), crate::Error> {
144        let name = name.to_string();
145        if self.get_registered_server_by_name(&name).is_some() {
146            return Err(crate::util::internal_error(format!(
147                "Server with name '{}' already exists",
148                name
149            )));
150        }
151
152        let uuid = uuid::Uuid::new_v4().to_string();
153        let service = Arc::new(RegisteredMcpServer {
154            name,
155            url: format!("acp:{uuid}"),
156            spawn: Arc::new(spawner),
157        });
158        self.insert_registered_server(service);
159        Ok(())
160    }
161
162    fn insert_registered_server(&self, service: Arc<RegisteredMcpServer<Role>>) {
163        let mut data = self.data.lock().expect("not poisoned");
164        data.registered_by_name
165            .insert(service.name.clone(), service.clone());
166        data.registered_by_url
167            .insert(service.url.clone(), service.clone());
168    }
169
170    fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer<Role>>> {
171        self.data
172            .lock()
173            .expect("not poisoned")
174            .registered_by_name
175            .get(name)
176            .cloned()
177    }
178
179    fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer<Role>>> {
180        self.data
181            .lock()
182            .expect("not poisoned")
183            .registered_by_url
184            .get(url)
185            .cloned()
186    }
187
188    fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<MessageCx>) {
189        self.data
190            .lock()
191            .expect("not poisoned")
192            .connections
193            .insert(connection_id.to_string(), tx);
194    }
195
196    fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<MessageCx>> {
197        self.data
198            .lock()
199            .expect("not poisoned")
200            .connections
201            .get(connection_id)
202            .cloned()
203    }
204
205    fn remove_connection(&self, connection_id: &str) -> bool {
206        self.data
207            .lock()
208            .expect("not poisoned")
209            .connections
210            .remove(connection_id)
211            .is_some()
212    }
213
214    /// Adds all registered MCP servers to the given `NewSessionRequest`.
215    ///
216    /// This method appends the MCP server configurations for all servers registered
217    /// with this registry to the `mcp_servers` field of the request. This is useful
218    /// when you want to manually populate a request with MCP servers outside of the
219    /// automatic handler chain processing.
220    pub fn add_registered_mcp_servers_to(&self, request: &mut NewSessionRequest) {
221        let data = self.data.lock().expect("not poisoned");
222        for server in data.registered_by_url.values() {
223            request.mcp_servers.push(server.acp_mcp_server());
224        }
225    }
226
227    async fn handle_connect_request(
228        &self,
229        request: McpConnectRequest,
230        request_cx: JrRequestCx<McpConnectResponse>,
231        outer_cx: JrConnectionCx<Role>,
232    ) -> Result<Handled<(McpConnectRequest, JrRequestCx<McpConnectResponse>)>, crate::Error> {
233        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
234        let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
235            return Ok(Handled::No((request, request_cx)));
236        };
237
238        // Create a unique connection ID and a channel for future communication
239        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
240        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
241        self.insert_connection(&connection_id, mcp_server_tx);
242
243        // Create connected channel pair for client-server communication
244        let (client_channel, server_channel) = Channel::duplex();
245
246        // Create client-side handler that wraps messages and forwards to successor
247        let client_component = {
248            let connection_id = connection_id.clone();
249            let outer_cx = outer_cx.clone();
250
251            McpClientToServer::builder()
252                .on_receive_message(async move |message: MessageCx, _mcp_cx| {
253                    // Wrap the message in McpOverAcp{Request,Notification} and forward to successor
254                    let wrapped = message.map(
255                        |request, request_cx| {
256                            (
257                                McpOverAcpMessage {
258                                    connection_id: connection_id.clone(),
259                                    message: request,
260                                    meta: None,
261                                },
262                                request_cx,
263                            )
264                        },
265                        |notification| McpOverAcpMessage {
266                            connection_id: connection_id.clone(),
267                            message: notification,
268                            meta: None,
269                        },
270                    );
271                    outer_cx.send_proxied_message_to(Agent, wrapped)
272                })
273                .with_spawned(move |mcp_cx| async move {
274                    // Messages we pull off this channel were sent from the agent.
275                    // Forward them back to the MCP server.
276                    while let Some(msg) = mcp_server_rx.next().await {
277                        mcp_cx.send_proxied_message_to(McpServerEnd, msg)?;
278                    }
279                    Ok(())
280                })
281        };
282
283        // Get the MCP server component
284        let mcp_server = registered_server.spawn.spawn(McpContext {
285            acp_url: request.acp_url.clone(),
286            connection_cx: outer_cx.clone(),
287        });
288
289        // Spawn both sides of the connection
290        let spawn_results = outer_cx
291            .spawn(async move { client_component.serve(client_channel).await })
292            .and_then(|()| {
293                // Spawn the MCP server serving the server channel
294                outer_cx.spawn(async move { mcp_server.serve(server_channel).await })
295            });
296
297        match spawn_results {
298            Ok(()) => {
299                request_cx.respond(McpConnectResponse {
300                    connection_id,
301                    meta: None,
302                })?;
303                Ok(Handled::Yes)
304            }
305
306            Err(err) => {
307                request_cx.respond_with_error(err)?;
308                Ok(Handled::Yes)
309            }
310        }
311    }
312
313    async fn handle_mcp_over_acp_request(
314        &self,
315        request: McpOverAcpMessage<UntypedMessage>,
316        request_cx: JrRequestCx<serde_json::Value>,
317    ) -> Result<
318        Handled<(
319            McpOverAcpMessage<UntypedMessage>,
320            JrRequestCx<serde_json::Value>,
321        )>,
322        crate::Error,
323    > {
324        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
325        let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
326            return Ok(Handled::No((request, request_cx)));
327        };
328
329        mcp_server_tx
330            .send(MessageCx::Request(request.message, request_cx))
331            .await
332            .map_err(crate::Error::into_internal_error)?;
333
334        Ok(Handled::Yes)
335    }
336
337    async fn handle_mcp_over_acp_notification(
338        &self,
339        notification: McpOverAcpMessage<UntypedMessage>,
340    ) -> Result<Handled<McpOverAcpMessage<UntypedMessage>>, crate::Error> {
341        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
342        let Some(mut mcp_server_tx) = self.get_connection(&notification.connection_id) else {
343            return Ok(Handled::No(notification));
344        };
345
346        mcp_server_tx
347            .send(MessageCx::Notification(notification.message))
348            .await
349            .map_err(crate::Error::into_internal_error)?;
350
351        Ok(Handled::Yes)
352    }
353
354    async fn handle_mcp_disconnect_notification(
355        &self,
356        successor_notification: McpDisconnectNotification,
357    ) -> Result<Handled<McpDisconnectNotification>, crate::Error> {
358        // Remove connection if we have it. Otherwise, do not handle the notification.
359        if self.remove_connection(&successor_notification.connection_id) {
360            Ok(Handled::Yes)
361        } else {
362            Ok(Handled::No(successor_notification))
363        }
364    }
365
366    async fn handle_new_session_request(
367        &self,
368        mut request: NewSessionRequest,
369        request_cx: JrRequestCx<NewSessionResponse>,
370    ) -> Result<Handled<(NewSessionRequest, JrRequestCx<NewSessionResponse>)>, crate::Error> {
371        // Add the MCP servers into the session/new request.
372        //
373        // Q: Do we care if there are already servers with that name?
374        self.add_registered_mcp_servers_to(&mut request);
375
376        // Return the modified request so subsequent handlers can see the MCP servers we added.
377        Ok(Handled::No((request, request_cx)))
378    }
379}
380
381impl<Role: JrRole> JrMessageHandlerSend for McpServiceRegistry<Role>
382where
383    Role: HasEndpoint<Agent>,
384{
385    type Role = Role;
386
387    fn describe_chain(&self) -> impl std::fmt::Debug {
388        "McpServiceRegistry"
389    }
390
391    async fn handle_message(
392        &mut self,
393        message: MessageCx,
394        connection_cx: JrConnectionCx<Role>,
395    ) -> Result<Handled<MessageCx>, crate::Error> {
396        // Hmm, this is a bit wacky:
397        //
398        // * In a proxy, we expect to receive MCP over ACP notifications wrapped as a "FromSuccessorNotification"
399        //   and we don't expect to receive them unwrapped (that would be the client sending it to us, not our agent,
400        //   and that's weird);
401        // * But in a *client*, we expect to receive incoming messages unwrapped (i.e., from our successor),
402        //   and not wrapped (we don't expect *anything* wrapped).
403        //
404        // So we just accept them in either direction for now. The whole thing feels a bit inelegant,
405        // but I guess it works.
406
407        MatchMessage::new(message)
408            // MCP connect requests come from the Agent direction (wrapped in SuccessorMessage)
409            .if_request_from(Agent, connection_cx.clone(), |request, request_cx, cx| {
410                self.handle_connect_request(request, request_cx, cx)
411            })
412            .await
413            // MCP over ACP requests come from the Agent direction
414            .if_request_from(Agent, connection_cx.clone(), |request, request_cx, _cx| {
415                self.handle_mcp_over_acp_request(request, request_cx)
416            })
417            .await
418            // session/new requests come from the Client direction (not wrapped)
419            .if_request(|request, request_cx| self.handle_new_session_request(request, request_cx))
420            .await
421            // MCP over ACP notifications come from the Agent direction
422            .if_notification_from(Agent, connection_cx.clone(), |notification, _cx| {
423                self.handle_mcp_over_acp_notification(notification)
424            })
425            .await
426            // MCP disconnect notifications come from the Agent direction
427            .if_notification_from(Agent, connection_cx, |notification, _cx| {
428                self.handle_mcp_disconnect_notification(notification)
429            })
430            .await
431            .done()
432    }
433}
434
435/// A "registered" MCP server can be launched when a connection is established.
436#[derive(Clone)]
437struct RegisteredMcpServer<Role: JrRole> {
438    name: String,
439    url: String,
440    spawn: Arc<dyn SpawnMcpServer<Role>>,
441}
442
443impl<Role: JrRole> RegisteredMcpServer<Role> {
444    fn acp_mcp_server(&self) -> crate::schema::McpServer {
445        crate::schema::McpServer::Http {
446            name: self.name.clone(),
447            url: self.url.clone(),
448            headers: vec![],
449        }
450    }
451}
452
453impl<Role: JrRole> std::fmt::Debug for RegisteredMcpServer<Role> {
454    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455        f.debug_struct("RegisteredMcpServer")
456            .field("name", &self.name)
457            .field("url", &self.url)
458            .finish()
459    }
460}
461
462/// Trait for spawning MCP server components.
463///
464/// This trait allows creating MCP server instances that implement the `Component` trait.
465trait SpawnMcpServer<Role: JrRole>: Send + Sync + 'static {
466    /// Create a new MCP server component.
467    ///
468    /// Returns a `DynComponent` that can be used with the Component API.
469    fn spawn(&self, cx: McpContext<Role>) -> DynComponent;
470}
471
472impl<Role: JrRole> AsRef<McpServiceRegistry<Role>> for McpServiceRegistry<Role>
473where
474    Role: HasEndpoint<Agent>,
475{
476    fn as_ref(&self) -> &McpServiceRegistry<Role> {
477        self
478    }
479}
480
481/// Context about the ACP and MCP connection available to an MCP server.
482#[derive(Clone)]
483pub struct McpContext<Role: JrRole> {
484    acp_url: String,
485    connection_cx: JrConnectionCx<Role>,
486}
487
488impl<Role: JrRole> McpContext<Role> {
489    /// The `acp:UUID` that was given.
490    pub fn acp_url(&self) -> String {
491        self.acp_url.clone()
492    }
493
494    /// The ACP connection context, which can be used to send ACP requests and notifications
495    /// to your successor.
496    pub fn connection_cx(&self) -> JrConnectionCx<Role> {
497        self.connection_cx.clone()
498    }
499}