1use futures::channel::mpsc;
2use futures::{FutureExt, future::BoxFuture};
3use futures::{SinkExt, StreamExt};
4use fxhash::FxHashMap;
5use rmcp::ServiceExt;
6use sacp::schema::NewSessionRequest;
7use sacp::{
8 Handled, JrConnection, JrConnectionCx, JrHandler, JrMessage, JrRequestCx, MessageAndCx,
9 UntypedMessage,
10};
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use crate::{
16 JrCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
17 McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
18};
19
20#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32 data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37 registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38 registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39 connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn with_rmcp_server<S>(
56 self,
57 name: impl ToString,
58 make_service: impl Fn() -> S + 'static + Send + Sync,
59 ) -> Result<Self, sacp::Error>
60 where
61 S: rmcp::Service<rmcp::RoleServer>,
62 {
63 self.add_rmcp_server(name, make_service)?;
64 Ok(self)
65 }
66
67 pub fn add_rmcp_server<S>(
74 &self,
75 name: impl ToString,
76 make_service: impl Fn() -> S + 'static + Send + Sync,
77 ) -> Result<(), sacp::Error>
78 where
79 S: rmcp::Service<rmcp::RoleServer>,
80 {
81 struct SpawnRmcpService<F> {
82 make_service: F,
83 }
84
85 impl<F, S> DynSpawnMcpServer for SpawnRmcpService<F>
86 where
87 F: Fn() -> S + Send + Sync + 'static,
88 S: rmcp::Service<rmcp::RoleServer>,
89 {
90 fn spawn(
91 &self,
92 outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
93 incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
94 ) -> BoxFuture<'static, Result<(), sacp::Error>> {
95 let server = (self.make_service)();
96 async move {
97 let running_server = server
98 .serve((incoming_bytes, outgoing_bytes))
99 .await
100 .map_err(sacp::Error::into_internal_error)?;
101
102 running_server
104 .waiting()
105 .await
106 .map(|_quit_reason| ())
107 .map_err(sacp::Error::into_internal_error)
108 }
109 .boxed()
110 }
111 }
112
113 let name = name.to_string();
114 self.add_mcp_service(name, Arc::new(SpawnRmcpService { make_service }))
115 }
116
117 fn add_mcp_service(
119 &self,
120 name: String,
121 spawn: Arc<dyn DynSpawnMcpServer>,
122 ) -> Result<(), sacp::Error> {
123 let name = name.to_string();
124 if let Some(_) = self.get_registered_server_by_name(&name) {
125 return Err(sacp::util::internal_error(format!(
126 "Server with name '{}' already exists",
127 name
128 )));
129 }
130
131 let uuid = uuid::Uuid::new_v4().to_string();
132 let service = Arc::new(RegisteredMcpServer {
133 name,
134 url: format!("acp:{uuid}"),
135 spawn,
136 });
137 self.insert_registered_server(service);
138 Ok(())
139 }
140
141 fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
142 let mut data = self.data.lock().expect("not poisoned");
143 data.registered_by_name
144 .insert(service.name.clone(), service.clone());
145 data.registered_by_url
146 .insert(service.url.clone(), service.clone());
147 }
148
149 fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
150 self.data
151 .lock()
152 .expect("not poisoned")
153 .registered_by_name
154 .get(name)
155 .cloned()
156 }
157
158 fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
159 self.data
160 .lock()
161 .expect("not poisoned")
162 .registered_by_url
163 .get(url)
164 .cloned()
165 }
166
167 fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
168 self.data
169 .lock()
170 .expect("not poisoned")
171 .connections
172 .insert(connection_id.to_string(), tx);
173 }
174
175 fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
176 self.data
177 .lock()
178 .expect("not poisoned")
179 .connections
180 .get(connection_id)
181 .cloned()
182 }
183
184 fn remove_connection(&self, connection_id: &str) -> bool {
185 self.data
186 .lock()
187 .expect("not poisoned")
188 .connections
189 .remove(connection_id)
190 .is_some()
191 }
192
193 async fn handle_connect_request(
194 &self,
195 result: Result<SuccessorRequest<McpConnectRequest>, sacp::Error>,
196 request_cx: JrRequestCx<serde_json::Value>,
197 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
198 let SuccessorRequest { request } = match result {
200 Ok(request) => request,
201 Err(err) => {
202 request_cx.respond_with_error(err)?;
203 return Ok(Handled::Yes);
204 }
205 };
206
207 let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
209 return Ok(Handled::No(request_cx));
210 };
211
212 let request_cx = request_cx.cast::<McpConnectResponse>();
213
214 let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
216 let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
217 self.insert_connection(&connection_id, mcp_server_tx);
218
219 let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
221 let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
222 let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
223
224 let spawn_results = request_cx
232 .spawn(
233 JrConnection::new(mcp_client_write.compat_write(), mcp_client_read.compat())
234 .on_receive_message({
235 let connection_id = connection_id.clone();
236 let outer_cx = request_cx.connection_cx();
237 async move |message: sacp::MessageAndCx| {
238 let wrapped = message.map(
240 |request, request_cx| {
241 (
242 McpOverAcpRequest {
243 connection_id: connection_id.clone(),
244 request,
245 },
246 request_cx,
247 )
248 },
249 |notification, cx| {
250 (
251 McpOverAcpNotification {
252 connection_id: connection_id.clone(),
253 notification,
254 },
255 cx,
256 )
257 },
258 );
259 outer_cx.send_proxied_message(wrapped)
260 }
261 })
262 .with_client({
263 async move |mcp_cx| {
264 while let Some(msg) = mcp_server_rx.next().await {
265 mcp_cx.send_proxied_message(msg)?;
266 }
267 Ok(())
268 }
269 }),
270 )
271 .and_then(|()| {
272 request_cx.spawn(async move {
274 registered_server
275 .spawn
276 .spawn(Box::pin(mcp_server_write), Box::pin(mcp_server_read))
277 .await
278 })
279 });
280
281 match spawn_results {
282 Ok(()) => {
283 request_cx.respond(McpConnectResponse { connection_id })?;
284 Ok(Handled::Yes)
285 }
286
287 Err(err) => {
288 request_cx.respond_with_error(err)?;
289 Ok(Handled::Yes)
290 }
291 }
292 }
293
294 async fn handle_mcp_over_acp_request(
295 &self,
296 result: Result<SuccessorRequest<McpOverAcpRequest<UntypedMessage>>, sacp::Error>,
297 request_cx: JrRequestCx<serde_json::Value>,
298 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
299 let SuccessorRequest { request } = match result {
301 Ok(request) => request,
302 Err(err) => {
303 request_cx.respond_with_error(err)?;
304 return Ok(Handled::Yes);
305 }
306 };
307
308 let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
310 return Ok(Handled::No(request_cx));
311 };
312
313 mcp_server_tx
314 .send(MessageAndCx::Request(request.request, request_cx))
315 .await
316 .map_err(sacp::Error::into_internal_error)?;
317
318 Ok(Handled::Yes)
319 }
320
321 async fn handle_mcp_over_acp_notification(
322 &self,
323 result: Result<SuccessorNotification<McpOverAcpNotification<UntypedMessage>>, sacp::Error>,
324 notification_cx: JrConnectionCx,
325 ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
326 let SuccessorNotification { notification } = match result {
328 Ok(request) => request,
329 Err(err) => {
330 notification_cx.send_error_notification(err)?;
331 return Ok(Handled::Yes);
332 }
333 };
334
335 let Some(mut mcp_server_tx) = self.get_connection(¬ification.connection_id) else {
337 return Ok(Handled::No(notification_cx));
338 };
339
340 mcp_server_tx
341 .send(MessageAndCx::Notification(
342 notification.notification,
343 notification_cx.clone(),
344 ))
345 .await
346 .map_err(sacp::Error::into_internal_error)?;
347
348 Ok(Handled::Yes)
349 }
350
351 async fn handle_mcp_disconnect_notification(
352 &self,
353 result: Result<SuccessorNotification<McpDisconnectNotification>, sacp::Error>,
354 notification_cx: JrConnectionCx,
355 ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
356 let SuccessorNotification { notification } = match result {
358 Ok(request) => request,
359 Err(err) => {
360 notification_cx.send_error_notification(err)?;
361 return Ok(Handled::Yes);
362 }
363 };
364
365 if self.remove_connection(¬ification.connection_id) {
367 Ok(Handled::Yes)
368 } else {
369 Ok(Handled::No(notification_cx))
370 }
371 }
372
373 async fn handle_new_session_request(
374 &self,
375 result: Result<NewSessionRequest, sacp::Error>,
376 request_cx: JrRequestCx<serde_json::Value>,
377 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
378 let mut request = match result {
380 Ok(request) => request,
381 Err(err) => {
382 request_cx.send_error_notification(err)?;
383 return Ok(Handled::Yes);
384 }
385 };
386
387 {
391 let data = self.data.lock().expect("not poisoned");
392 for server in data.registered_by_url.values() {
393 request.mcp_servers.push(server.acp_mcp_server());
394 }
395 }
396
397 request_cx
399 .send_request_to_successor(request)
400 .forward_to_request_cx(request_cx.cast())?;
401
402 Ok(Handled::Yes)
403 }
404}
405
406impl JrHandler for McpServiceRegistry {
407 fn describe_chain(&self) -> impl std::fmt::Debug {
408 "McpServiceRegistry"
409 }
410
411 async fn handle_message(
412 &mut self,
413 message: sacp::MessageAndCx,
414 ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
415 match message {
416 sacp::MessageAndCx::Request(msg, mut cx) => {
417 let params = msg.params();
418
419 if let Some(result) =
420 <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
421 {
422 cx = match self.handle_connect_request(result, cx).await? {
423 Handled::Yes => return Ok(Handled::Yes),
424 Handled::No(cx) => cx,
425 };
426 }
427
428 if let Some(result) =
429 <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
430 cx.method(),
431 params,
432 )
433 {
434 cx = match self.handle_mcp_over_acp_request(result, cx).await? {
435 Handled::Yes => return Ok(Handled::Yes),
436 Handled::No(cx) => cx,
437 };
438 }
439
440 if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
441 cx = match self.handle_new_session_request(result, cx).await? {
442 Handled::Yes => return Ok(Handled::Yes),
443 Handled::No(cx) => cx,
444 };
445 }
446
447 Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
448 }
449 sacp::MessageAndCx::Notification(msg, mut cx) => {
450 let params = msg.params();
451
452 if let Some(result) =
453 <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
454 msg.method(),
455 params,
456 )
457 {
458 cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
459 Handled::Yes => return Ok(Handled::Yes),
460 Handled::No(cx) => cx,
461 };
462 }
463
464 if let Some(result) =
465 <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
466 msg.method(),
467 params,
468 )
469 {
470 cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
471 Handled::Yes => return Ok(Handled::Yes),
472 Handled::No(cx) => cx,
473 };
474 }
475
476 Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
477 }
478 }
479 }
480}
481
482#[derive(Clone)]
483struct RegisteredMcpServer {
484 name: String,
485 url: String,
486 spawn: Arc<dyn DynSpawnMcpServer>,
487}
488
489impl RegisteredMcpServer {
490 fn acp_mcp_server(&self) -> sacp::schema::McpServer {
491 sacp::schema::McpServer::Http {
492 name: self.name.clone(),
493 url: self.url.clone(),
494 headers: vec![],
495 }
496 }
497}
498
499impl std::fmt::Debug for RegisteredMcpServer {
500 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501 f.debug_struct("RegisteredMcpServer")
502 .field("name", &self.name)
503 .field("url", &self.url)
504 .finish()
505 }
506}
507
508trait DynSpawnMcpServer: 'static + Send + Sync {
509 fn spawn(
510 &self,
511 outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
512 incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
513 ) -> BoxFuture<'static, Result<(), sacp::Error>>;
514}
515
516impl AsRef<McpServiceRegistry> for McpServiceRegistry {
517 fn as_ref(&self) -> &McpServiceRegistry {
518 self
519 }
520}