1use agent_client_protocol::{self as acp, NewSessionRequest};
2use futures::channel::mpsc;
3use futures::{FutureExt, future::BoxFuture};
4use futures::{SinkExt, StreamExt};
5use fxhash::FxHashMap;
6use rmcp::ServiceExt;
7use sacp::{
8 Handled, JsonRpcConnection, JsonRpcConnectionCx, JsonRpcHandler, JsonRpcMessage,
9 JsonRpcRequestCx, MessageAndCx, UntypedMessage,
10};
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use crate::{
16 JsonRpcCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
17 McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
18};
19
20#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32 data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37 registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38 registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39 connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_rmcp_server<S>(
55 self,
56 name: impl ToString,
57 make_service: impl Fn() -> S + 'static + Send + Sync,
58 ) -> Result<Self, acp::Error>
59 where
60 S: rmcp::Service<rmcp::RoleServer>,
61 {
62 self.add_rmcp_server(name, make_service)?;
63 Ok(self)
64 }
65
66 pub fn add_rmcp_server<S>(
73 &self,
74 name: impl ToString,
75 make_service: impl Fn() -> S + 'static + Send + Sync,
76 ) -> Result<(), acp::Error>
77 where
78 S: rmcp::Service<rmcp::RoleServer>,
79 {
80 struct SpawnRmcpService<F> {
81 make_service: F,
82 }
83
84 impl<F, S> DynSpawnMcpServer for SpawnRmcpService<F>
85 where
86 F: Fn() -> S + Send + Sync + 'static,
87 S: rmcp::Service<rmcp::RoleServer>,
88 {
89 fn spawn(
90 &self,
91 outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
92 incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
93 ) -> BoxFuture<'static, Result<(), acp::Error>> {
94 let server = (self.make_service)();
95 async move {
96 let running_server = server
97 .serve((incoming_bytes, outgoing_bytes))
98 .await
99 .map_err(acp::Error::into_internal_error)?;
100
101 running_server
103 .waiting()
104 .await
105 .map(|_quit_reason| ())
106 .map_err(acp::Error::into_internal_error)
107 }
108 .boxed()
109 }
110 }
111
112 let name = name.to_string();
113 self.add_mcp_service(name, Arc::new(SpawnRmcpService { make_service }))
114 }
115
116 fn add_mcp_service(
118 &self,
119 name: String,
120 spawn: Arc<dyn DynSpawnMcpServer>,
121 ) -> Result<(), acp::Error> {
122 let name = name.to_string();
123 if let Some(_) = self.get_registered_server_by_name(&name) {
124 return Err(sacp::util::internal_error(format!(
125 "Server with name '{}' already exists",
126 name
127 )));
128 }
129
130 let uuid = uuid::Uuid::new_v4().to_string();
131 let service = Arc::new(RegisteredMcpServer {
132 name,
133 url: format!("acp:{uuid}"),
134 spawn,
135 });
136 self.insert_registered_server(service);
137 Ok(())
138 }
139
140 fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
141 let mut data = self.data.lock().expect("not poisoned");
142 data.registered_by_name
143 .insert(service.name.clone(), service.clone());
144 data.registered_by_url
145 .insert(service.url.clone(), service.clone());
146 }
147
148 fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
149 self.data
150 .lock()
151 .expect("not poisoned")
152 .registered_by_name
153 .get(name)
154 .cloned()
155 }
156
157 fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
158 self.data
159 .lock()
160 .expect("not poisoned")
161 .registered_by_url
162 .get(url)
163 .cloned()
164 }
165
166 fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
167 self.data
168 .lock()
169 .expect("not poisoned")
170 .connections
171 .insert(connection_id.to_string(), tx);
172 }
173
174 fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
175 self.data
176 .lock()
177 .expect("not poisoned")
178 .connections
179 .get(connection_id)
180 .cloned()
181 }
182
183 fn remove_connection(&self, connection_id: &str) -> bool {
184 self.data
185 .lock()
186 .expect("not poisoned")
187 .connections
188 .remove(connection_id)
189 .is_some()
190 }
191
192 async fn handle_connect_request(
193 &self,
194 result: Result<SuccessorRequest<McpConnectRequest>, agent_client_protocol::Error>,
195 request_cx: JsonRpcRequestCx<serde_json::Value>,
196 ) -> Result<Handled<JsonRpcRequestCx<serde_json::Value>>, agent_client_protocol::Error> {
197 let SuccessorRequest { request } = match result {
199 Ok(request) => request,
200 Err(err) => {
201 request_cx.respond_with_error(err)?;
202 return Ok(Handled::Yes);
203 }
204 };
205
206 let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
208 return Ok(Handled::No(request_cx));
209 };
210
211 let request_cx = request_cx.cast::<McpConnectResponse>();
212
213 let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
215 let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
216 self.insert_connection(&connection_id, mcp_server_tx);
217
218 let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
220 let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
221 let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
222
223 let spawn_results = request_cx
231 .spawn(
232 JsonRpcConnection::new(mcp_client_write.compat_write(), mcp_client_read.compat())
233 .on_receive_message({
234 let connection_id = connection_id.clone();
235 let outer_cx = request_cx.connection_cx();
236 async move |message: sacp::MessageAndCx| {
237 let wrapped = message.map(
239 |request, request_cx| {
240 (
241 McpOverAcpRequest {
242 connection_id: connection_id.clone(),
243 request,
244 },
245 request_cx,
246 )
247 },
248 |notification, cx| {
249 (
250 McpOverAcpNotification {
251 connection_id: connection_id.clone(),
252 notification,
253 },
254 cx,
255 )
256 },
257 );
258 outer_cx.send_proxied_message(wrapped)
259 }
260 })
261 .with_client({
262 async move |mcp_cx| {
263 while let Some(msg) = mcp_server_rx.next().await {
264 mcp_cx.send_proxied_message(msg)?;
265 }
266 Ok(())
267 }
268 }),
269 )
270 .and_then(|()| {
271 request_cx.spawn(async move {
273 registered_server
274 .spawn
275 .spawn(Box::pin(mcp_server_write), Box::pin(mcp_server_read))
276 .await
277 })
278 });
279
280 match spawn_results {
281 Ok(()) => {
282 request_cx.respond(McpConnectResponse { connection_id })?;
283 Ok(Handled::Yes)
284 }
285
286 Err(err) => {
287 request_cx.respond_with_error(err)?;
288 Ok(Handled::Yes)
289 }
290 }
291 }
292
293 async fn handle_mcp_over_acp_request(
294 &self,
295 result: Result<
296 SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
297 agent_client_protocol::Error,
298 >,
299 request_cx: JsonRpcRequestCx<serde_json::Value>,
300 ) -> Result<Handled<JsonRpcRequestCx<serde_json::Value>>, agent_client_protocol::Error> {
301 let SuccessorRequest { request } = match result {
303 Ok(request) => request,
304 Err(err) => {
305 request_cx.respond_with_error(err)?;
306 return Ok(Handled::Yes);
307 }
308 };
309
310 let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
312 return Ok(Handled::No(request_cx));
313 };
314
315 mcp_server_tx
316 .send(MessageAndCx::Request(request.request, request_cx))
317 .await
318 .map_err(acp::Error::into_internal_error)?;
319
320 Ok(Handled::Yes)
321 }
322
323 async fn handle_mcp_over_acp_notification(
324 &self,
325 result: Result<
326 SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
327 agent_client_protocol::Error,
328 >,
329 notification_cx: JsonRpcConnectionCx,
330 ) -> Result<Handled<JsonRpcConnectionCx>, agent_client_protocol::Error> {
331 let SuccessorNotification { notification } = match result {
333 Ok(request) => request,
334 Err(err) => {
335 notification_cx.send_error_notification(err)?;
336 return Ok(Handled::Yes);
337 }
338 };
339
340 let Some(mut mcp_server_tx) = self.get_connection(¬ification.connection_id) else {
342 return Ok(Handled::No(notification_cx));
343 };
344
345 mcp_server_tx
346 .send(MessageAndCx::Notification(
347 notification.notification,
348 notification_cx.clone(),
349 ))
350 .await
351 .map_err(acp::Error::into_internal_error)?;
352
353 Ok(Handled::Yes)
354 }
355
356 async fn handle_mcp_disconnect_notification(
357 &self,
358 result: Result<
359 SuccessorNotification<McpDisconnectNotification>,
360 agent_client_protocol::Error,
361 >,
362 notification_cx: JsonRpcConnectionCx,
363 ) -> Result<Handled<JsonRpcConnectionCx>, agent_client_protocol::Error> {
364 let SuccessorNotification { notification } = match result {
366 Ok(request) => request,
367 Err(err) => {
368 notification_cx.send_error_notification(err)?;
369 return Ok(Handled::Yes);
370 }
371 };
372
373 if self.remove_connection(¬ification.connection_id) {
375 Ok(Handled::Yes)
376 } else {
377 Ok(Handled::No(notification_cx))
378 }
379 }
380
381 async fn handle_new_session_request(
382 &self,
383 result: Result<NewSessionRequest, agent_client_protocol::Error>,
384 request_cx: JsonRpcRequestCx<serde_json::Value>,
385 ) -> Result<Handled<JsonRpcRequestCx<serde_json::Value>>, agent_client_protocol::Error> {
386 let mut request = match result {
388 Ok(request) => request,
389 Err(err) => {
390 request_cx.send_error_notification(err)?;
391 return Ok(Handled::Yes);
392 }
393 };
394
395 {
399 let data = self.data.lock().expect("not poisoned");
400 for server in data.registered_by_url.values() {
401 request.mcp_servers.push(server.acp_mcp_server());
402 }
403 }
404
405 request_cx
407 .send_request_to_successor(request)
408 .forward_to_request_cx(request_cx.cast())?;
409
410 Ok(Handled::Yes)
411 }
412}
413
414impl JsonRpcHandler for McpServiceRegistry {
415 fn describe_chain(&self) -> impl std::fmt::Debug {
416 "McpServiceRegistry"
417 }
418
419 async fn handle_message(
420 &mut self,
421 message: sacp::MessageAndCx,
422 ) -> Result<sacp::Handled<sacp::MessageAndCx>, agent_client_protocol::Error> {
423 match message {
424 sacp::MessageAndCx::Request(msg, mut cx) => {
425 let params = msg.params();
426
427 if let Some(result) =
428 <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
429 {
430 cx = match self.handle_connect_request(result, cx).await? {
431 Handled::Yes => return Ok(Handled::Yes),
432 Handled::No(cx) => cx,
433 };
434 }
435
436 if let Some(result) =
437 <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
438 cx.method(),
439 params,
440 )
441 {
442 cx = match self.handle_mcp_over_acp_request(result, cx).await? {
443 Handled::Yes => return Ok(Handled::Yes),
444 Handled::No(cx) => cx,
445 };
446 }
447
448 if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
449 cx = match self.handle_new_session_request(result, cx).await? {
450 Handled::Yes => return Ok(Handled::Yes),
451 Handled::No(cx) => cx,
452 };
453 }
454
455 Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
456 }
457 sacp::MessageAndCx::Notification(msg, mut cx) => {
458 let params = msg.params();
459
460 if let Some(result) =
461 <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
462 msg.method(),
463 params,
464 )
465 {
466 cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
467 Handled::Yes => return Ok(Handled::Yes),
468 Handled::No(cx) => cx,
469 };
470 }
471
472 if let Some(result) =
473 <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
474 msg.method(),
475 params,
476 )
477 {
478 cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
479 Handled::Yes => return Ok(Handled::Yes),
480 Handled::No(cx) => cx,
481 };
482 }
483
484 Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
485 }
486 }
487 }
488}
489
490#[derive(Clone)]
491struct RegisteredMcpServer {
492 name: String,
493 url: String,
494 spawn: Arc<dyn DynSpawnMcpServer>,
495}
496
497impl RegisteredMcpServer {
498 fn acp_mcp_server(&self) -> acp::McpServer {
499 acp::McpServer::Http {
500 name: self.name.clone(),
501 url: self.url.clone(),
502 headers: vec![],
503 }
504 }
505}
506
507impl std::fmt::Debug for RegisteredMcpServer {
508 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
509 f.debug_struct("RegisteredMcpServer")
510 .field("name", &self.name)
511 .field("url", &self.url)
512 .finish()
513 }
514}
515
516trait DynSpawnMcpServer: 'static + Send + Sync {
517 fn spawn(
518 &self,
519 outgoing_bytes: Pin<Box<dyn tokio::io::AsyncWrite + Send>>,
520 incoming_bytes: Pin<Box<dyn tokio::io::AsyncRead + Send>>,
521 ) -> BoxFuture<'static, Result<(), acp::Error>>;
522}
523
524impl AsRef<McpServiceRegistry> for McpServiceRegistry {
525 fn as_ref(&self) -> &McpServiceRegistry {
526 self
527 }
528}