sacp_proxy/
mcp_server.rs

1use futures::channel::mpsc;
2use futures::{SinkExt, StreamExt};
3use fxhash::FxHashMap;
4
5use sacp::schema::NewSessionRequest;
6use sacp::{
7    Channel, Component, DynComponent, Handled, JrConnectionCx, JrHandlerChain, JrMessage,
8    JrMessageHandler, JrRequestCx, MessageAndCx, UntypedMessage,
9};
10use std::sync::{Arc, Mutex};
11
12use crate::{
13    JrCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
14    McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
15};
16
17/// Manages MCP services offered to successor proxies and agents.
18///
19/// Use the [`Self::add_mcp_server`] method to register MCP servers. For rmcp-based servers,
20/// use the `sacp-rmcp` crate which provides convenient extension methods.
21///
22/// This struct is a handle to the underlying registry. Cloning the struct produces a second handle to the same registry.
23///
24/// # Handling requests
25///
26/// You must add the registry (or a clone of it) to the [`JrHandlerChain`] so that it can intercept MCP requests.
27/// Typically you do this by providing it as an argument to the handler chain methods.
28///
29/// [`JrHandlerChain`]: sacp::JrHandlerChain
30#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32    data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39    connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43    /// Creates a new empty MCP service registry
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    /// Add an MCP server to the registry using a custom spawner.
49    ///
50    /// This is the base method for adding MCP servers. Use this if you have a custom
51    /// way to create Component instances for your MCP server.
52    ///
53    /// For rmcp-based servers, use the `sacp-rmcp` crate which provides convenient
54    /// extension methods.
55    ///
56    /// # Parameters
57    ///
58    /// - `name`: The name of the server.
59    /// - `spawner`: A trait object that can create Component instances.
60    pub fn add_mcp_server<C: Component>(
61        &self,
62        name: impl ToString,
63        new_fn: impl Fn() -> C + Send + Sync + 'static,
64    ) -> Result<(), sacp::Error> {
65        struct FnSpawner<F> {
66            new_fn: F,
67        }
68
69        impl<C, F> SpawnMcpServer for FnSpawner<F>
70        where
71            F: Fn() -> C + Send + Sync + 'static,
72            C: Component,
73        {
74            fn spawn(&self) -> DynComponent {
75                let component = (self.new_fn)();
76                DynComponent::new(component)
77            }
78        }
79
80        let name = name.to_string();
81        if let Some(_) = self.get_registered_server_by_name(&name) {
82            return Err(sacp::util::internal_error(format!(
83                "Server with name '{}' already exists",
84                name
85            )));
86        }
87
88        let uuid = uuid::Uuid::new_v4().to_string();
89        let service = Arc::new(RegisteredMcpServer {
90            name,
91            url: format!("acp:{uuid}"),
92            spawn: Arc::new(FnSpawner { new_fn }),
93        });
94        self.insert_registered_server(service);
95        Ok(())
96    }
97
98    fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
99        let mut data = self.data.lock().expect("not poisoned");
100        data.registered_by_name
101            .insert(service.name.clone(), service.clone());
102        data.registered_by_url
103            .insert(service.url.clone(), service.clone());
104    }
105
106    fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
107        self.data
108            .lock()
109            .expect("not poisoned")
110            .registered_by_name
111            .get(name)
112            .cloned()
113    }
114
115    fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
116        self.data
117            .lock()
118            .expect("not poisoned")
119            .registered_by_url
120            .get(url)
121            .cloned()
122    }
123
124    fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
125        self.data
126            .lock()
127            .expect("not poisoned")
128            .connections
129            .insert(connection_id.to_string(), tx);
130    }
131
132    fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
133        self.data
134            .lock()
135            .expect("not poisoned")
136            .connections
137            .get(connection_id)
138            .cloned()
139    }
140
141    fn remove_connection(&self, connection_id: &str) -> bool {
142        self.data
143            .lock()
144            .expect("not poisoned")
145            .connections
146            .remove(connection_id)
147            .is_some()
148    }
149
150    async fn handle_connect_request(
151        &self,
152        result: Result<SuccessorRequest<McpConnectRequest>, sacp::Error>,
153        request_cx: JrRequestCx<serde_json::Value>,
154    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
155        // Check if we parsed this message successfully.
156        let SuccessorRequest { request } = match result {
157            Ok(request) => request,
158            Err(err) => {
159                request_cx.respond_with_error(err)?;
160                return Ok(Handled::Yes);
161            }
162        };
163
164        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
165        let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
166            return Ok(Handled::No(request_cx));
167        };
168
169        let request_cx = request_cx.cast::<McpConnectResponse>();
170
171        // Create a unique connection ID and a channel for future communication
172        let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
173        let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
174        self.insert_connection(&connection_id, mcp_server_tx);
175
176        // Create connected channel pair for client-server communication
177        let (client_channel, server_channel) = Channel::duplex();
178
179        // Create client-side handler that wraps messages and forwards to successor
180        let client_component = {
181            let connection_id = connection_id.clone();
182            let outer_cx = request_cx.connection_cx();
183
184            JrHandlerChain::new()
185                .on_receive_message(async move |message: sacp::MessageAndCx| {
186                    // Wrap the message in McpOverAcp{Request,Notification} and forward to successor
187                    let wrapped = message.map(
188                        |request, request_cx| {
189                            (
190                                McpOverAcpRequest {
191                                    connection_id: connection_id.clone(),
192                                    request,
193                                },
194                                request_cx,
195                            )
196                        },
197                        |notification, cx| {
198                            (
199                                McpOverAcpNotification {
200                                    connection_id: connection_id.clone(),
201                                    notification,
202                                },
203                                cx,
204                            )
205                        },
206                    );
207                    outer_cx.send_proxied_message(wrapped)
208                })
209                .with_spawned(move |mcp_cx| async move {
210                    while let Some(msg) = mcp_server_rx.next().await {
211                        mcp_cx.send_proxied_message(msg)?;
212                    }
213                    Ok(())
214                })
215        };
216
217        // Get the MCP server component
218        let mcp_server = registered_server.spawn.spawn();
219
220        // Spawn both sides of the connection
221        let spawn_results = request_cx
222            .connection_cx()
223            .spawn(async move { client_component.serve(client_channel).await })
224            .and_then(|()| {
225                // Spawn the MCP server serving the server channel
226                request_cx
227                    .connection_cx()
228                    .spawn(async move { mcp_server.serve(server_channel).await })
229            });
230
231        match spawn_results {
232            Ok(()) => {
233                request_cx.respond(McpConnectResponse { connection_id })?;
234                Ok(Handled::Yes)
235            }
236
237            Err(err) => {
238                request_cx.respond_with_error(err)?;
239                Ok(Handled::Yes)
240            }
241        }
242    }
243
244    async fn handle_mcp_over_acp_request(
245        &self,
246        result: Result<SuccessorRequest<McpOverAcpRequest<UntypedMessage>>, sacp::Error>,
247        request_cx: JrRequestCx<serde_json::Value>,
248    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
249        // Check if we parsed this message successfully.
250        let SuccessorRequest { request } = match result {
251            Ok(request) => request,
252            Err(err) => {
253                request_cx.respond_with_error(err)?;
254                return Ok(Handled::Yes);
255            }
256        };
257
258        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
259        let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
260            return Ok(Handled::No(request_cx));
261        };
262
263        mcp_server_tx
264            .send(MessageAndCx::Request(request.request, request_cx))
265            .await
266            .map_err(sacp::Error::into_internal_error)?;
267
268        Ok(Handled::Yes)
269    }
270
271    async fn handle_mcp_over_acp_notification(
272        &self,
273        result: Result<SuccessorNotification<McpOverAcpNotification<UntypedMessage>>, sacp::Error>,
274        notification_cx: JrConnectionCx,
275    ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
276        // Check if we parsed this message successfully.
277        let SuccessorNotification { notification } = match result {
278            Ok(request) => request,
279            Err(err) => {
280                notification_cx.send_error_notification(err)?;
281                return Ok(Handled::Yes);
282            }
283        };
284
285        // Check if we have a registered server with the given URL. If not, don't try to handle the request.
286        let Some(mut mcp_server_tx) = self.get_connection(&notification.connection_id) else {
287            return Ok(Handled::No(notification_cx));
288        };
289
290        mcp_server_tx
291            .send(MessageAndCx::Notification(
292                notification.notification,
293                notification_cx.clone(),
294            ))
295            .await
296            .map_err(sacp::Error::into_internal_error)?;
297
298        Ok(Handled::Yes)
299    }
300
301    async fn handle_mcp_disconnect_notification(
302        &self,
303        result: Result<SuccessorNotification<McpDisconnectNotification>, sacp::Error>,
304        notification_cx: JrConnectionCx,
305    ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
306        // Check if we parsed this message successfully.
307        let SuccessorNotification { notification } = match result {
308            Ok(request) => request,
309            Err(err) => {
310                notification_cx.send_error_notification(err)?;
311                return Ok(Handled::Yes);
312            }
313        };
314
315        // Remove connection if we have it. Otherwise, do not handle the notification.
316        if self.remove_connection(&notification.connection_id) {
317            Ok(Handled::Yes)
318        } else {
319            Ok(Handled::No(notification_cx))
320        }
321    }
322
323    async fn handle_new_session_request(
324        &self,
325        result: Result<NewSessionRequest, sacp::Error>,
326        request_cx: JrRequestCx<serde_json::Value>,
327    ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
328        // Check if we parsed this message successfully.
329        let mut request = match result {
330            Ok(request) => request,
331            Err(err) => {
332                request_cx.connection_cx().send_error_notification(err)?;
333                return Ok(Handled::Yes);
334            }
335        };
336
337        // Add the MCP servers into the session/new request.
338        //
339        // Q: Do we care if there are already servers with that name?
340        {
341            let data = self.data.lock().expect("not poisoned");
342            for server in data.registered_by_url.values() {
343                request.mcp_servers.push(server.acp_mcp_server());
344            }
345        }
346
347        // Forward it to the successor.
348        request_cx
349            .connection_cx()
350            .send_request_to_successor(request)
351            .forward_to_request_cx(request_cx.cast())?;
352
353        Ok(Handled::Yes)
354    }
355}
356
357impl JrMessageHandler for McpServiceRegistry {
358    fn describe_chain(&self) -> impl std::fmt::Debug {
359        "McpServiceRegistry"
360    }
361
362    async fn handle_message(
363        &mut self,
364        message: sacp::MessageAndCx,
365    ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
366        match message {
367            sacp::MessageAndCx::Request(msg, mut cx) => {
368                let params = msg.params();
369
370                if let Some(result) =
371                    <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
372                {
373                    cx = match self.handle_connect_request(result, cx).await? {
374                        Handled::Yes => return Ok(Handled::Yes),
375                        Handled::No(cx) => cx,
376                    };
377                }
378
379                if let Some(result) =
380                    <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
381                        cx.method(),
382                        params,
383                    )
384                {
385                    cx = match self.handle_mcp_over_acp_request(result, cx).await? {
386                        Handled::Yes => return Ok(Handled::Yes),
387                        Handled::No(cx) => cx,
388                    };
389                }
390
391                if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
392                    cx = match self.handle_new_session_request(result, cx).await? {
393                        Handled::Yes => return Ok(Handled::Yes),
394                        Handled::No(cx) => cx,
395                    };
396                }
397
398                Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
399            }
400            sacp::MessageAndCx::Notification(msg, mut cx) => {
401                let params = msg.params();
402
403                if let Some(result) =
404                    <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
405                        msg.method(),
406                        params,
407                    )
408                {
409                    cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
410                        Handled::Yes => return Ok(Handled::Yes),
411                        Handled::No(cx) => cx,
412                    };
413                }
414
415                if let Some(result) =
416                    <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
417                        msg.method(),
418                        params,
419                    )
420                {
421                    cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
422                        Handled::Yes => return Ok(Handled::Yes),
423                        Handled::No(cx) => cx,
424                    };
425                }
426
427                Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
428            }
429        }
430    }
431}
432
433#[derive(Clone)]
434struct RegisteredMcpServer {
435    name: String,
436    url: String,
437    spawn: Arc<dyn SpawnMcpServer>,
438}
439
440impl RegisteredMcpServer {
441    fn acp_mcp_server(&self) -> sacp::schema::McpServer {
442        sacp::schema::McpServer::Http {
443            name: self.name.clone(),
444            url: self.url.clone(),
445            headers: vec![],
446        }
447    }
448}
449
450impl std::fmt::Debug for RegisteredMcpServer {
451    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452        f.debug_struct("RegisteredMcpServer")
453            .field("name", &self.name)
454            .field("url", &self.url)
455            .finish()
456    }
457}
458
459/// Trait for spawning MCP server components.
460///
461/// This trait allows creating MCP server instances that implement the `Component` trait.
462trait SpawnMcpServer: Send + Sync + 'static {
463    /// Create a new MCP server component.
464    ///
465    /// Returns a `DynComponent` that can be used with the Component API.
466    fn spawn(&self) -> sacp::DynComponent;
467}
468
469impl AsRef<McpServiceRegistry> for McpServiceRegistry {
470    fn as_ref(&self) -> &McpServiceRegistry {
471        self
472    }
473}