1use futures::channel::mpsc;
2use futures::{FutureExt, future::BoxFuture};
3use futures::{SinkExt, StreamExt};
4use fxhash::FxHashMap;
5use rmcp::ServiceExt;
6use sacp::NewSessionRequest;
7use sacp::{
8 Handled, JrConnection, JrConnectionCx, JrHandler, JrMessage, JrRequestCx, MessageAndCx,
9 UntypedMessage,
10};
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use crate::{
16 JrCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
17 McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
18};
19
20#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32 data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37 registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38 registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39 connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_rmcp_server<S>(
55 self,
56 name: impl ToString,
57 make_service: impl Fn() -> S + 'static + Send + Sync,
58 ) -> Result<Self, sacp::Error>
59 where
60 S: rmcp::Service<rmcp::RoleServer>,
61 {
62 self.add_rmcp_server(name, make_service)?;
63 Ok(self)
64 }
65
66 pub fn add_rmcp_server<S>(
73 &self,
74 name: impl ToString,
75 make_service: impl Fn() -> S + 'static + Send + Sync,
76 ) -> Result<(), sacp::Error>
77 where
78 S: rmcp::Service<rmcp::RoleServer>,
79 {
80 struct SpawnRmcpService<F> {
81 make_service: F,
82 }
83
84 impl<F, S> DynSpawnMcpServer for SpawnRmcpService<F>
85 where
86 F: Fn() -> S + Send + Sync + 'static,
87 S: rmcp::Service<rmcp::RoleServer>,
88 {
89 fn spawn(
90 &self,
91 outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
92 incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
93 ) -> BoxFuture<'static, Result<(), sacp::Error>> {
94 let server = (self.make_service)();
95 async move {
96 let running_server = server
97 .serve((incoming_bytes, outgoing_bytes))
98 .await
99 .map_err(sacp::Error::into_internal_error)?;
100
101 running_server
103 .waiting()
104 .await
105 .map(|_quit_reason| ())
106 .map_err(sacp::Error::into_internal_error)
107 }
108 .boxed()
109 }
110 }
111
112 let name = name.to_string();
113 self.add_mcp_service(name, Arc::new(SpawnRmcpService { make_service }))
114 }
115
116 fn add_mcp_service(
118 &self,
119 name: String,
120 spawn: Arc<dyn DynSpawnMcpServer>,
121 ) -> Result<(), sacp::Error> {
122 let name = name.to_string();
123 if let Some(_) = self.get_registered_server_by_name(&name) {
124 return Err(sacp::util::internal_error(format!(
125 "Server with name '{}' already exists",
126 name
127 )));
128 }
129
130 let uuid = uuid::Uuid::new_v4().to_string();
131 let service = Arc::new(RegisteredMcpServer {
132 name,
133 url: format!("acp:{uuid}"),
134 spawn,
135 });
136 self.insert_registered_server(service);
137 Ok(())
138 }
139
140 fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
141 let mut data = self.data.lock().expect("not poisoned");
142 data.registered_by_name
143 .insert(service.name.clone(), service.clone());
144 data.registered_by_url
145 .insert(service.url.clone(), service.clone());
146 }
147
148 fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
149 self.data
150 .lock()
151 .expect("not poisoned")
152 .registered_by_name
153 .get(name)
154 .cloned()
155 }
156
157 fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
158 self.data
159 .lock()
160 .expect("not poisoned")
161 .registered_by_url
162 .get(url)
163 .cloned()
164 }
165
166 fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
167 self.data
168 .lock()
169 .expect("not poisoned")
170 .connections
171 .insert(connection_id.to_string(), tx);
172 }
173
174 fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
175 self.data
176 .lock()
177 .expect("not poisoned")
178 .connections
179 .get(connection_id)
180 .cloned()
181 }
182
183 fn remove_connection(&self, connection_id: &str) -> bool {
184 self.data
185 .lock()
186 .expect("not poisoned")
187 .connections
188 .remove(connection_id)
189 .is_some()
190 }
191
192 async fn handle_connect_request(
193 &self,
194 result: Result<SuccessorRequest<McpConnectRequest>, sacp::Error>,
195 request_cx: JrRequestCx<serde_json::Value>,
196 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
197 let SuccessorRequest { request } = match result {
199 Ok(request) => request,
200 Err(err) => {
201 request_cx.respond_with_error(err)?;
202 return Ok(Handled::Yes);
203 }
204 };
205
206 let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
208 return Ok(Handled::No(request_cx));
209 };
210
211 let request_cx = request_cx.cast::<McpConnectResponse>();
212
213 let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
215 let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
216 self.insert_connection(&connection_id, mcp_server_tx);
217
218 let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
220 let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
221 let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
222
223 let spawn_results = request_cx
231 .spawn(
232 JrConnection::new(mcp_client_write.compat_write(), mcp_client_read.compat())
233 .on_receive_message({
234 let connection_id = connection_id.clone();
235 let outer_cx = request_cx.connection_cx();
236 async move |message: sacp::MessageAndCx| {
237 let wrapped = message.map(
239 |request, request_cx| {
240 (
241 McpOverAcpRequest {
242 connection_id: connection_id.clone(),
243 request,
244 },
245 request_cx,
246 )
247 },
248 |notification, cx| {
249 (
250 McpOverAcpNotification {
251 connection_id: connection_id.clone(),
252 notification,
253 },
254 cx,
255 )
256 },
257 );
258 outer_cx.send_proxied_message(wrapped)
259 }
260 })
261 .with_client({
262 async move |mcp_cx| {
263 while let Some(msg) = mcp_server_rx.next().await {
264 mcp_cx.send_proxied_message(msg)?;
265 }
266 Ok(())
267 }
268 }),
269 )
270 .and_then(|()| {
271 request_cx.spawn(async move {
273 registered_server
274 .spawn
275 .spawn(Box::pin(mcp_server_write), Box::pin(mcp_server_read))
276 .await
277 })
278 });
279
280 match spawn_results {
281 Ok(()) => {
282 request_cx.respond(McpConnectResponse { connection_id })?;
283 Ok(Handled::Yes)
284 }
285
286 Err(err) => {
287 request_cx.respond_with_error(err)?;
288 Ok(Handled::Yes)
289 }
290 }
291 }
292
293 async fn handle_mcp_over_acp_request(
294 &self,
295 result: Result<SuccessorRequest<McpOverAcpRequest<UntypedMessage>>, sacp::Error>,
296 request_cx: JrRequestCx<serde_json::Value>,
297 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
298 let SuccessorRequest { request } = match result {
300 Ok(request) => request,
301 Err(err) => {
302 request_cx.respond_with_error(err)?;
303 return Ok(Handled::Yes);
304 }
305 };
306
307 let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
309 return Ok(Handled::No(request_cx));
310 };
311
312 mcp_server_tx
313 .send(MessageAndCx::Request(request.request, request_cx))
314 .await
315 .map_err(sacp::Error::into_internal_error)?;
316
317 Ok(Handled::Yes)
318 }
319
320 async fn handle_mcp_over_acp_notification(
321 &self,
322 result: Result<SuccessorNotification<McpOverAcpNotification<UntypedMessage>>, sacp::Error>,
323 notification_cx: JrConnectionCx,
324 ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
325 let SuccessorNotification { notification } = match result {
327 Ok(request) => request,
328 Err(err) => {
329 notification_cx.send_error_notification(err)?;
330 return Ok(Handled::Yes);
331 }
332 };
333
334 let Some(mut mcp_server_tx) = self.get_connection(¬ification.connection_id) else {
336 return Ok(Handled::No(notification_cx));
337 };
338
339 mcp_server_tx
340 .send(MessageAndCx::Notification(
341 notification.notification,
342 notification_cx.clone(),
343 ))
344 .await
345 .map_err(sacp::Error::into_internal_error)?;
346
347 Ok(Handled::Yes)
348 }
349
350 async fn handle_mcp_disconnect_notification(
351 &self,
352 result: Result<SuccessorNotification<McpDisconnectNotification>, sacp::Error>,
353 notification_cx: JrConnectionCx,
354 ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
355 let SuccessorNotification { notification } = match result {
357 Ok(request) => request,
358 Err(err) => {
359 notification_cx.send_error_notification(err)?;
360 return Ok(Handled::Yes);
361 }
362 };
363
364 if self.remove_connection(¬ification.connection_id) {
366 Ok(Handled::Yes)
367 } else {
368 Ok(Handled::No(notification_cx))
369 }
370 }
371
372 async fn handle_new_session_request(
373 &self,
374 result: Result<NewSessionRequest, sacp::Error>,
375 request_cx: JrRequestCx<serde_json::Value>,
376 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
377 let mut request = match result {
379 Ok(request) => request,
380 Err(err) => {
381 request_cx.send_error_notification(err)?;
382 return Ok(Handled::Yes);
383 }
384 };
385
386 {
390 let data = self.data.lock().expect("not poisoned");
391 for server in data.registered_by_url.values() {
392 request.mcp_servers.push(server.acp_mcp_server());
393 }
394 }
395
396 request_cx
398 .send_request_to_successor(request)
399 .forward_to_request_cx(request_cx.cast())?;
400
401 Ok(Handled::Yes)
402 }
403}
404
405impl JrHandler for McpServiceRegistry {
406 fn describe_chain(&self) -> impl std::fmt::Debug {
407 "McpServiceRegistry"
408 }
409
410 async fn handle_message(
411 &mut self,
412 message: sacp::MessageAndCx,
413 ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
414 match message {
415 sacp::MessageAndCx::Request(msg, mut cx) => {
416 let params = msg.params();
417
418 if let Some(result) =
419 <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
420 {
421 cx = match self.handle_connect_request(result, cx).await? {
422 Handled::Yes => return Ok(Handled::Yes),
423 Handled::No(cx) => cx,
424 };
425 }
426
427 if let Some(result) =
428 <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
429 cx.method(),
430 params,
431 )
432 {
433 cx = match self.handle_mcp_over_acp_request(result, cx).await? {
434 Handled::Yes => return Ok(Handled::Yes),
435 Handled::No(cx) => cx,
436 };
437 }
438
439 if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
440 cx = match self.handle_new_session_request(result, cx).await? {
441 Handled::Yes => return Ok(Handled::Yes),
442 Handled::No(cx) => cx,
443 };
444 }
445
446 Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
447 }
448 sacp::MessageAndCx::Notification(msg, mut cx) => {
449 let params = msg.params();
450
451 if let Some(result) =
452 <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
453 msg.method(),
454 params,
455 )
456 {
457 cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
458 Handled::Yes => return Ok(Handled::Yes),
459 Handled::No(cx) => cx,
460 };
461 }
462
463 if let Some(result) =
464 <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
465 msg.method(),
466 params,
467 )
468 {
469 cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
470 Handled::Yes => return Ok(Handled::Yes),
471 Handled::No(cx) => cx,
472 };
473 }
474
475 Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
476 }
477 }
478 }
479}
480
481#[derive(Clone)]
482struct RegisteredMcpServer {
483 name: String,
484 url: String,
485 spawn: Arc<dyn DynSpawnMcpServer>,
486}
487
488impl RegisteredMcpServer {
489 fn acp_mcp_server(&self) -> sacp::McpServer {
490 sacp::McpServer::Http {
491 name: self.name.clone(),
492 url: self.url.clone(),
493 headers: vec![],
494 }
495 }
496}
497
498impl std::fmt::Debug for RegisteredMcpServer {
499 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500 f.debug_struct("RegisteredMcpServer")
501 .field("name", &self.name)
502 .field("url", &self.url)
503 .finish()
504 }
505}
506
507trait DynSpawnMcpServer: 'static + Send + Sync {
508 fn spawn(
509 &self,
510 outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
511 incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
512 ) -> BoxFuture<'static, Result<(), sacp::Error>>;
513}
514
515impl AsRef<McpServiceRegistry> for McpServiceRegistry {
516 fn as_ref(&self) -> &McpServiceRegistry {
517 self
518 }
519}