// sacp_proxy/mcp_server.rs

1use futures::channel::mpsc;
2use futures::{SinkExt, StreamExt};
3use fxhash::FxHashMap;
4
5use sacp::schema::{NewSessionRequest, NewSessionResponse};
6use sacp::util::MatchMessage;
7use sacp::{
8    Channel, Component, DynComponent, Handled, JrConnectionCx, JrHandlerChain, JrMessageHandler,
9    JrRequestCx, MessageAndCx, UntypedMessage,
10};
11use std::sync::{Arc, Mutex};
12
13use crate::{
14    McpConnectRequest, McpConnectResponse, McpDisconnectNotification, McpOverAcpNotification,
15    McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
16};
17
18/// Manages MCP services offered to successor proxies and agents.
19///
20/// Use the [`Self::add_mcp_server`] method to register MCP servers. For rmcp-based servers,
21/// use the `sacp-rmcp` crate which provides convenient extension methods.
22///
23/// This struct is a handle to the underlying registry. Cloning the struct produces a second handle to the same registry.
24///
25/// # Handling requests
26///
27/// You must add the registry (or a clone of it) to the [`JrHandlerChain`] so that it can intercept MCP requests.
28/// Typically you do this by providing it as an argument to the handler chain methods.
29///
30/// [`JrHandlerChain`]: sacp::JrHandlerChain
31#[derive(Clone, Default, Debug)]
32pub struct McpServiceRegistry {
33    data: Arc<Mutex<McpServiceRegistryData>>,
34}
35
36#[derive(Default, Debug)]
37struct McpServiceRegistryData {
38    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
39    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
40    connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
41}
42
43impl McpServiceRegistry {
44    /// Creates a new empty MCP service registry
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Add an MCP server to the registry using a custom spawner.
50    ///
51    /// This is the base method for adding MCP servers. Use this if you have a custom
52    /// way to create Component instances for your MCP server.
53    ///
54    /// For rmcp-based servers, use the `sacp-rmcp` crate which provides convenient
55    /// extension methods.
56    ///
57    /// # Parameters
58    ///
59    /// - `name`: The name of the server.
60    /// - `spawner`: A trait object that can create Component instances.
61    pub fn add_mcp_server<C: Component>(
62        &self,
63        name: impl ToString,
64        new_fn: impl Fn() -> C + Send + Sync + 'static,
65    ) -> Result<(), sacp::Error> {
66        struct FnSpawner<F> {
67            new_fn: F,
68        }
69
70        impl<C, F> SpawnMcpServer for FnSpawner<F>
71        where
72            F: Fn() -> C + Send + Sync + 'static,
73            C: Component,
74        {
75            fn spawn(&self) -> DynComponent {
76                let component = (self.new_fn)();
77                DynComponent::new(component)
78            }
79        }
80
81        let name = name.to_string();
82        if let Some(_) = self.get_registered_server_by_name(&name) {
83            return Err(sacp::util::internal_error(format!(
84                "Server with name '{}' already exists",
85                name
86            )));
87        }
88
89        let uuid = uuid::Uuid::new_v4().to_string();
90        let service = Arc::new(RegisteredMcpServer {
91            name,
92            url: format!("acp:{uuid}"),
93            spawn: Arc::new(FnSpawner { new_fn }),
94        });
95        self.insert_registered_server(service);
96        Ok(())
97    }
98
99    fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
100        let mut data = self.data.lock().expect("not poisoned");
101        data.registered_by_name
102            .insert(service.name.clone(), service.clone());
103        data.registered_by_url
104            .insert(service.url.clone(), service.clone());
105    }
106
107    fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
108        self.data
109            .lock()
110            .expect("not poisoned")
111            .registered_by_name
112            .get(name)
113            .cloned()
114    }
115
116    fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
117        self.data
118            .lock()
119            .expect("not poisoned")
120            .registered_by_url
121            .get(url)
122            .cloned()
123    }
124
125    fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
126        self.data
127            .lock()
128            .expect("not poisoned")
129            .connections
130            .insert(connection_id.to_string(), tx);
131    }
132
133    fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
134        self.data
135            .lock()
136            .expect("not poisoned")
137            .connections
138            .get(connection_id)
139            .cloned()
140    }
141
142    fn remove_connection(&self, connection_id: &str) -> bool {
143        self.data
144            .lock()
145            .expect("not poisoned")
146            .connections
147            .remove(connection_id)
148            .is_some()
149    }
150
151    /// Adds all registered MCP servers to the given `NewSessionRequest`.
152    ///
153    /// This method appends the MCP server configurations for all servers registered
154    /// with this registry to the `mcp_servers` field of the request. This is useful
155    /// when you want to manually populate a request with MCP servers outside of the
156    /// automatic handler chain processing.
157    ///
158    /// # Example
159    ///
160    /// ```rust,ignore
161    /// let registry = McpServiceRegistry::new();
162    /// registry.add_mcp_server("my-server", || MyMcpServer)?;
163    ///
164    /// let mut request = NewSessionRequest {
165    ///     mcp_servers: vec![],
166    ///     cwd: std::env::current_dir()?,
167    ///     meta: None,
168    /// };
169    ///
170    /// registry.add_registered_mcp_servers_to(&mut request);
171    /// // request.mcp_servers now contains "my-server"
172    /// ```
173    pub fn add_registered_mcp_servers_to(&self, request: &mut NewSessionRequest) {
174        let data = self.data.lock().expect("not poisoned");
175        for server in data.registered_by_url.values() {
176            request.mcp_servers.push(server.acp_mcp_server());
177        }
178    }
179
180    async fn handle_connect_request(
181        &self,
182        successor_request: SuccessorRequest<McpConnectRequest>,
183        request_cx: JrRequestCx<McpConnectResponse>,
184    ) -> Result<
185        Handled<(
186            SuccessorRequest<McpConnectRequest>,
187            JrRequestCx<McpConnectResponse>,
188        )>,
189        sacp::Error,
190    > {
191        let SuccessorRequest { request } = &successor_request;
192
193        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
194        let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
195            return Ok(Handled::No((successor_request, request_cx)));
196        };
197
198        // Create a unique connection ID and a channel for future communication
199        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
200        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
201        self.insert_connection(&connection_id, mcp_server_tx);
202
203        // Create connected channel pair for client-server communication
204        let (client_channel, server_channel) = Channel::duplex();
205
206        // Create client-side handler that wraps messages and forwards to successor
207        let client_component = {
208            let connection_id = connection_id.clone();
209            let outer_cx = request_cx.connection_cx();
210
211            JrHandlerChain::new()
212                .on_receive_message(async move |message: sacp::MessageAndCx| {
213                    // Wrap the message in McpOverAcp{Request,Notification} and forward to successor
214                    let wrapped = message.map(
215                        |request, request_cx| {
216                            (
217                                McpOverAcpRequest {
218                                    connection_id: connection_id.clone(),
219                                    request,
220                                },
221                                request_cx,
222                            )
223                        },
224                        |notification, cx| {
225                            (
226                                McpOverAcpNotification {
227                                    connection_id: connection_id.clone(),
228                                    notification,
229                                },
230                                cx,
231                            )
232                        },
233                    );
234                    outer_cx.send_proxied_message(wrapped)
235                })
236                .with_spawned(move |mcp_cx| async move {
237                    while let Some(msg) = mcp_server_rx.next().await {
238                        mcp_cx.send_proxied_message(msg)?;
239                    }
240                    Ok(())
241                })
242        };
243
244        // Get the MCP server component
245        let mcp_server = registered_server.spawn.spawn();
246
247        // Spawn both sides of the connection
248        let spawn_results = request_cx
249            .connection_cx()
250            .spawn(async move { client_component.serve(client_channel).await })
251            .and_then(|()| {
252                // Spawn the MCP server serving the server channel
253                request_cx
254                    .connection_cx()
255                    .spawn(async move { mcp_server.serve(server_channel).await })
256            });
257
258        match spawn_results {
259            Ok(()) => {
260                request_cx.respond(McpConnectResponse { connection_id })?;
261                Ok(Handled::Yes)
262            }
263
264            Err(err) => {
265                request_cx.respond_with_error(err)?;
266                Ok(Handled::Yes)
267            }
268        }
269    }
270
271    async fn handle_mcp_over_acp_request(
272        &self,
273        successor_request: SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
274        request_cx: JrRequestCx<serde_json::Value>,
275    ) -> Result<
276        Handled<(
277            SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
278            JrRequestCx<serde_json::Value>,
279        )>,
280        sacp::Error,
281    > {
282        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
283        let Some(mut mcp_server_tx) = self.get_connection(&successor_request.request.connection_id)
284        else {
285            return Ok(Handled::No((successor_request, request_cx)));
286        };
287
288        let SuccessorRequest { request } = successor_request;
289
290        mcp_server_tx
291            .send(MessageAndCx::Request(request.request, request_cx))
292            .await
293            .map_err(sacp::Error::into_internal_error)?;
294
295        Ok(Handled::Yes)
296    }
297
298    async fn handle_mcp_over_acp_notification(
299        &self,
300        successor_notification: SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
301        notification_cx: JrConnectionCx,
302    ) -> Result<
303        Handled<(
304            SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
305            JrConnectionCx,
306        )>,
307        sacp::Error,
308    > {
309        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
310        let Some(mut mcp_server_tx) =
311            self.get_connection(&successor_notification.notification.connection_id)
312        else {
313            return Ok(Handled::No((successor_notification, notification_cx)));
314        };
315
316        let SuccessorNotification { notification } = successor_notification;
317
318        mcp_server_tx
319            .send(MessageAndCx::Notification(
320                notification.notification,
321                notification_cx.clone(),
322            ))
323            .await
324            .map_err(sacp::Error::into_internal_error)?;
325
326        Ok(Handled::Yes)
327    }
328
329    async fn handle_mcp_disconnect_notification(
330        &self,
331        successor_notification: SuccessorNotification<McpDisconnectNotification>,
332        notification_cx: JrConnectionCx,
333    ) -> Result<
334        Handled<(
335            SuccessorNotification<McpDisconnectNotification>,
336            JrConnectionCx,
337        )>,
338        sacp::Error,
339    > {
340        let SuccessorNotification { notification } = &successor_notification;
341
342        // Remove connection if we have it. Otherwise, do not handle the notification.
343        if self.remove_connection(&notification.connection_id) {
344            Ok(Handled::Yes)
345        } else {
346            Ok(Handled::No((successor_notification, notification_cx)))
347        }
348    }
349
350    async fn handle_new_session_request(
351        &self,
352        mut request: NewSessionRequest,
353        request_cx: JrRequestCx<NewSessionResponse>,
354    ) -> Result<Handled<(NewSessionRequest, JrRequestCx<NewSessionResponse>)>, sacp::Error> {
355        // Add the MCP servers into the session/new request.
356        //
357        // Q: Do we care if there are already servers with that name?
358        self.add_registered_mcp_servers_to(&mut request);
359
360        // Return the modified request so subsequent handlers can see the MCP servers we added.
361        Ok(Handled::No((request, request_cx)))
362    }
363}
364
365impl JrMessageHandler for McpServiceRegistry {
366    fn describe_chain(&self) -> impl std::fmt::Debug {
367        "McpServiceRegistry"
368    }
369
370    async fn handle_message(
371        &mut self,
372        message: sacp::MessageAndCx,
373    ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
374        MatchMessage::new(message)
375            .if_request(|request, request_cx| self.handle_connect_request(request, request_cx))
376            .await
377            .if_request(|request, request_cx| self.handle_mcp_over_acp_request(request, request_cx))
378            .await
379            .if_request(|request, request_cx| self.handle_new_session_request(request, request_cx))
380            .await
381            .if_notification(|notification, notification_cx| {
382                self.handle_mcp_over_acp_notification(notification, notification_cx)
383            })
384            .await
385            .if_notification(|notification, notification_cx| {
386                self.handle_mcp_disconnect_notification(notification, notification_cx)
387            })
388            .await
389            .done()
390    }
391}
392
393#[derive(Clone)]
394struct RegisteredMcpServer {
395    name: String,
396    url: String,
397    spawn: Arc<dyn SpawnMcpServer>,
398}
399
400impl RegisteredMcpServer {
401    fn acp_mcp_server(&self) -> sacp::schema::McpServer {
402        sacp::schema::McpServer::Http {
403            name: self.name.clone(),
404            url: self.url.clone(),
405            headers: vec![],
406        }
407    }
408}
409
410impl std::fmt::Debug for RegisteredMcpServer {
411    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412        f.debug_struct("RegisteredMcpServer")
413            .field("name", &self.name)
414            .field("url", &self.url)
415            .finish()
416    }
417}
418
419/// Trait for spawning MCP server components.
420///
421/// This trait allows creating MCP server instances that implement the `Component` trait.
422trait SpawnMcpServer: Send + Sync + 'static {
423    /// Create a new MCP server component.
424    ///
425    /// Returns a `DynComponent` that can be used with the Component API.
426    fn spawn(&self) -> sacp::DynComponent;
427}
428
429impl AsRef<McpServiceRegistry> for McpServiceRegistry {
430    fn as_ref(&self) -> &McpServiceRegistry {
431        self
432    }
433}