1use futures::channel::mpsc;
2use futures::{SinkExt, StreamExt};
3use fxhash::FxHashMap;
4
5use sacp::schema::{NewSessionRequest, NewSessionResponse};
6use sacp::util::MatchMessage;
7use sacp::{
8 Channel, Component, DynComponent, Handled, JrConnectionCx, JrHandlerChain, JrMessageHandler,
9 JrRequestCx, MessageAndCx, UntypedMessage,
10};
11use std::sync::{Arc, Mutex};
12
13use crate::{
14 McpConnectRequest, McpConnectResponse, McpDisconnectNotification, McpOverAcpNotification,
15 McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
16};
17
18#[derive(Clone, Default, Debug)]
32pub struct McpServiceRegistry {
33 data: Arc<Mutex<McpServiceRegistryData>>,
34}
35
36#[derive(Default, Debug)]
37struct McpServiceRegistryData {
38 registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
39 registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
40 connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
41}
42
43impl McpServiceRegistry {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn add_mcp_server<C: Component>(
62 &self,
63 name: impl ToString,
64 new_fn: impl Fn() -> C + Send + Sync + 'static,
65 ) -> Result<(), sacp::Error> {
66 struct FnSpawner<F> {
67 new_fn: F,
68 }
69
70 impl<C, F> SpawnMcpServer for FnSpawner<F>
71 where
72 F: Fn() -> C + Send + Sync + 'static,
73 C: Component,
74 {
75 fn spawn(&self) -> DynComponent {
76 let component = (self.new_fn)();
77 DynComponent::new(component)
78 }
79 }
80
81 let name = name.to_string();
82 if let Some(_) = self.get_registered_server_by_name(&name) {
83 return Err(sacp::util::internal_error(format!(
84 "Server with name '{}' already exists",
85 name
86 )));
87 }
88
89 let uuid = uuid::Uuid::new_v4().to_string();
90 let service = Arc::new(RegisteredMcpServer {
91 name,
92 url: format!("acp:{uuid}"),
93 spawn: Arc::new(FnSpawner { new_fn }),
94 });
95 self.insert_registered_server(service);
96 Ok(())
97 }
98
99 fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
100 let mut data = self.data.lock().expect("not poisoned");
101 data.registered_by_name
102 .insert(service.name.clone(), service.clone());
103 data.registered_by_url
104 .insert(service.url.clone(), service.clone());
105 }
106
107 fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
108 self.data
109 .lock()
110 .expect("not poisoned")
111 .registered_by_name
112 .get(name)
113 .cloned()
114 }
115
116 fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
117 self.data
118 .lock()
119 .expect("not poisoned")
120 .registered_by_url
121 .get(url)
122 .cloned()
123 }
124
125 fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
126 self.data
127 .lock()
128 .expect("not poisoned")
129 .connections
130 .insert(connection_id.to_string(), tx);
131 }
132
133 fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
134 self.data
135 .lock()
136 .expect("not poisoned")
137 .connections
138 .get(connection_id)
139 .cloned()
140 }
141
142 fn remove_connection(&self, connection_id: &str) -> bool {
143 self.data
144 .lock()
145 .expect("not poisoned")
146 .connections
147 .remove(connection_id)
148 .is_some()
149 }
150
151 pub fn add_registered_mcp_servers_to(&self, request: &mut NewSessionRequest) {
174 let data = self.data.lock().expect("not poisoned");
175 for server in data.registered_by_url.values() {
176 request.mcp_servers.push(server.acp_mcp_server());
177 }
178 }
179
180 async fn handle_connect_request(
181 &self,
182 successor_request: SuccessorRequest<McpConnectRequest>,
183 request_cx: JrRequestCx<McpConnectResponse>,
184 ) -> Result<
185 Handled<(
186 SuccessorRequest<McpConnectRequest>,
187 JrRequestCx<McpConnectResponse>,
188 )>,
189 sacp::Error,
190 > {
191 let SuccessorRequest { request } = &successor_request;
192
193 let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
195 return Ok(Handled::No((successor_request, request_cx)));
196 };
197
198 let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
200 let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
201 self.insert_connection(&connection_id, mcp_server_tx);
202
203 let (client_channel, server_channel) = Channel::duplex();
205
206 let client_component = {
208 let connection_id = connection_id.clone();
209 let outer_cx = request_cx.connection_cx();
210
211 JrHandlerChain::new()
212 .on_receive_message(async move |message: sacp::MessageAndCx| {
213 let wrapped = message.map(
215 |request, request_cx| {
216 (
217 McpOverAcpRequest {
218 connection_id: connection_id.clone(),
219 request,
220 },
221 request_cx,
222 )
223 },
224 |notification, cx| {
225 (
226 McpOverAcpNotification {
227 connection_id: connection_id.clone(),
228 notification,
229 },
230 cx,
231 )
232 },
233 );
234 outer_cx.send_proxied_message(wrapped)
235 })
236 .with_spawned(move |mcp_cx| async move {
237 while let Some(msg) = mcp_server_rx.next().await {
238 mcp_cx.send_proxied_message(msg)?;
239 }
240 Ok(())
241 })
242 };
243
244 let mcp_server = registered_server.spawn.spawn();
246
247 let spawn_results = request_cx
249 .connection_cx()
250 .spawn(async move { client_component.serve(client_channel).await })
251 .and_then(|()| {
252 request_cx
254 .connection_cx()
255 .spawn(async move { mcp_server.serve(server_channel).await })
256 });
257
258 match spawn_results {
259 Ok(()) => {
260 request_cx.respond(McpConnectResponse { connection_id })?;
261 Ok(Handled::Yes)
262 }
263
264 Err(err) => {
265 request_cx.respond_with_error(err)?;
266 Ok(Handled::Yes)
267 }
268 }
269 }
270
271 async fn handle_mcp_over_acp_request(
272 &self,
273 successor_request: SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
274 request_cx: JrRequestCx<serde_json::Value>,
275 ) -> Result<
276 Handled<(
277 SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
278 JrRequestCx<serde_json::Value>,
279 )>,
280 sacp::Error,
281 > {
282 let Some(mut mcp_server_tx) = self.get_connection(&successor_request.request.connection_id)
284 else {
285 return Ok(Handled::No((successor_request, request_cx)));
286 };
287
288 let SuccessorRequest { request } = successor_request;
289
290 mcp_server_tx
291 .send(MessageAndCx::Request(request.request, request_cx))
292 .await
293 .map_err(sacp::Error::into_internal_error)?;
294
295 Ok(Handled::Yes)
296 }
297
298 async fn handle_mcp_over_acp_notification(
299 &self,
300 successor_notification: SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
301 notification_cx: JrConnectionCx,
302 ) -> Result<
303 Handled<(
304 SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
305 JrConnectionCx,
306 )>,
307 sacp::Error,
308 > {
309 let Some(mut mcp_server_tx) =
311 self.get_connection(&successor_notification.notification.connection_id)
312 else {
313 return Ok(Handled::No((successor_notification, notification_cx)));
314 };
315
316 let SuccessorNotification { notification } = successor_notification;
317
318 mcp_server_tx
319 .send(MessageAndCx::Notification(
320 notification.notification,
321 notification_cx.clone(),
322 ))
323 .await
324 .map_err(sacp::Error::into_internal_error)?;
325
326 Ok(Handled::Yes)
327 }
328
329 async fn handle_mcp_disconnect_notification(
330 &self,
331 successor_notification: SuccessorNotification<McpDisconnectNotification>,
332 notification_cx: JrConnectionCx,
333 ) -> Result<
334 Handled<(
335 SuccessorNotification<McpDisconnectNotification>,
336 JrConnectionCx,
337 )>,
338 sacp::Error,
339 > {
340 let SuccessorNotification { notification } = &successor_notification;
341
342 if self.remove_connection(¬ification.connection_id) {
344 Ok(Handled::Yes)
345 } else {
346 Ok(Handled::No((successor_notification, notification_cx)))
347 }
348 }
349
350 async fn handle_new_session_request(
351 &self,
352 mut request: NewSessionRequest,
353 request_cx: JrRequestCx<NewSessionResponse>,
354 ) -> Result<Handled<(NewSessionRequest, JrRequestCx<NewSessionResponse>)>, sacp::Error> {
355 self.add_registered_mcp_servers_to(&mut request);
359
360 Ok(Handled::No((request, request_cx)))
362 }
363}
364
365impl JrMessageHandler for McpServiceRegistry {
366 fn describe_chain(&self) -> impl std::fmt::Debug {
367 "McpServiceRegistry"
368 }
369
370 async fn handle_message(
371 &mut self,
372 message: sacp::MessageAndCx,
373 ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
374 MatchMessage::new(message)
375 .if_request(|request, request_cx| self.handle_connect_request(request, request_cx))
376 .await
377 .if_request(|request, request_cx| self.handle_mcp_over_acp_request(request, request_cx))
378 .await
379 .if_request(|request, request_cx| self.handle_new_session_request(request, request_cx))
380 .await
381 .if_notification(|notification, notification_cx| {
382 self.handle_mcp_over_acp_notification(notification, notification_cx)
383 })
384 .await
385 .if_notification(|notification, notification_cx| {
386 self.handle_mcp_disconnect_notification(notification, notification_cx)
387 })
388 .await
389 .done()
390 }
391}
392
393#[derive(Clone)]
394struct RegisteredMcpServer {
395 name: String,
396 url: String,
397 spawn: Arc<dyn SpawnMcpServer>,
398}
399
400impl RegisteredMcpServer {
401 fn acp_mcp_server(&self) -> sacp::schema::McpServer {
402 sacp::schema::McpServer::Http {
403 name: self.name.clone(),
404 url: self.url.clone(),
405 headers: vec![],
406 }
407 }
408}
409
410impl std::fmt::Debug for RegisteredMcpServer {
411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412 f.debug_struct("RegisteredMcpServer")
413 .field("name", &self.name)
414 .field("url", &self.url)
415 .finish()
416 }
417}
418
419trait SpawnMcpServer: Send + Sync + 'static {
423 fn spawn(&self) -> sacp::DynComponent;
427}
428
429impl AsRef<McpServiceRegistry> for McpServiceRegistry {
430 fn as_ref(&self) -> &McpServiceRegistry {
431 self
432 }
433}