1use futures::channel::mpsc;
2use futures::{SinkExt, StreamExt};
3use fxhash::FxHashMap;
4
5use sacp::schema::NewSessionRequest;
6use sacp::{
7 Channel, Component, DynComponent, Handled, JrConnectionCx, JrHandlerChain, JrMessage,
8 JrMessageHandler, JrRequestCx, MessageAndCx, UntypedMessage,
9};
10use std::sync::{Arc, Mutex};
11
12use crate::{
13 JrCxExt, McpConnectRequest, McpConnectResponse, McpDisconnectNotification,
14 McpOverAcpNotification, McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
15};
16
17#[derive(Clone, Default, Debug)]
31pub struct McpServiceRegistry {
32 data: Arc<Mutex<McpServiceRegistryData>>,
33}
34
35#[derive(Default, Debug)]
36struct McpServiceRegistryData {
37 registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
38 registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
39 connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
40}
41
42impl McpServiceRegistry {
43 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn add_mcp_server<C: Component>(
61 &self,
62 name: impl ToString,
63 new_fn: impl Fn() -> C + Send + Sync + 'static,
64 ) -> Result<(), sacp::Error> {
65 struct FnSpawner<F> {
66 new_fn: F,
67 }
68
69 impl<C, F> SpawnMcpServer for FnSpawner<F>
70 where
71 F: Fn() -> C + Send + Sync + 'static,
72 C: Component,
73 {
74 fn spawn(&self) -> DynComponent {
75 let component = (self.new_fn)();
76 DynComponent::new(component)
77 }
78 }
79
80 let name = name.to_string();
81 if let Some(_) = self.get_registered_server_by_name(&name) {
82 return Err(sacp::util::internal_error(format!(
83 "Server with name '{}' already exists",
84 name
85 )));
86 }
87
88 let uuid = uuid::Uuid::new_v4().to_string();
89 let service = Arc::new(RegisteredMcpServer {
90 name,
91 url: format!("acp:{uuid}"),
92 spawn: Arc::new(FnSpawner { new_fn }),
93 });
94 self.insert_registered_server(service);
95 Ok(())
96 }
97
98 fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
99 let mut data = self.data.lock().expect("not poisoned");
100 data.registered_by_name
101 .insert(service.name.clone(), service.clone());
102 data.registered_by_url
103 .insert(service.url.clone(), service.clone());
104 }
105
106 fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
107 self.data
108 .lock()
109 .expect("not poisoned")
110 .registered_by_name
111 .get(name)
112 .cloned()
113 }
114
115 fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
116 self.data
117 .lock()
118 .expect("not poisoned")
119 .registered_by_url
120 .get(url)
121 .cloned()
122 }
123
124 fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
125 self.data
126 .lock()
127 .expect("not poisoned")
128 .connections
129 .insert(connection_id.to_string(), tx);
130 }
131
132 fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
133 self.data
134 .lock()
135 .expect("not poisoned")
136 .connections
137 .get(connection_id)
138 .cloned()
139 }
140
141 fn remove_connection(&self, connection_id: &str) -> bool {
142 self.data
143 .lock()
144 .expect("not poisoned")
145 .connections
146 .remove(connection_id)
147 .is_some()
148 }
149
150 async fn handle_connect_request(
151 &self,
152 result: Result<SuccessorRequest<McpConnectRequest>, sacp::Error>,
153 request_cx: JrRequestCx<serde_json::Value>,
154 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
155 let SuccessorRequest { request } = match result {
157 Ok(request) => request,
158 Err(err) => {
159 request_cx.respond_with_error(err)?;
160 return Ok(Handled::Yes);
161 }
162 };
163
164 let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
166 return Ok(Handled::No(request_cx));
167 };
168
169 let request_cx = request_cx.cast::<McpConnectResponse>();
170
171 let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
173 let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
174 self.insert_connection(&connection_id, mcp_server_tx);
175
176 let (client_channel, server_channel) = Channel::duplex();
178
179 let client_component = {
181 let connection_id = connection_id.clone();
182 let outer_cx = request_cx.connection_cx();
183
184 JrHandlerChain::new()
185 .on_receive_message(async move |message: sacp::MessageAndCx| {
186 let wrapped = message.map(
188 |request, request_cx| {
189 (
190 McpOverAcpRequest {
191 connection_id: connection_id.clone(),
192 request,
193 },
194 request_cx,
195 )
196 },
197 |notification, cx| {
198 (
199 McpOverAcpNotification {
200 connection_id: connection_id.clone(),
201 notification,
202 },
203 cx,
204 )
205 },
206 );
207 outer_cx.send_proxied_message(wrapped)
208 })
209 .with_spawned(move |mcp_cx| async move {
210 while let Some(msg) = mcp_server_rx.next().await {
211 mcp_cx.send_proxied_message(msg)?;
212 }
213 Ok(())
214 })
215 };
216
217 let mcp_server = registered_server.spawn.spawn();
219
220 let spawn_results = request_cx
222 .connection_cx()
223 .spawn(async move { client_component.serve(client_channel).await })
224 .and_then(|()| {
225 request_cx
227 .connection_cx()
228 .spawn(async move { mcp_server.serve(server_channel).await })
229 });
230
231 match spawn_results {
232 Ok(()) => {
233 request_cx.respond(McpConnectResponse { connection_id })?;
234 Ok(Handled::Yes)
235 }
236
237 Err(err) => {
238 request_cx.respond_with_error(err)?;
239 Ok(Handled::Yes)
240 }
241 }
242 }
243
244 async fn handle_mcp_over_acp_request(
245 &self,
246 result: Result<SuccessorRequest<McpOverAcpRequest<UntypedMessage>>, sacp::Error>,
247 request_cx: JrRequestCx<serde_json::Value>,
248 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
249 let SuccessorRequest { request } = match result {
251 Ok(request) => request,
252 Err(err) => {
253 request_cx.respond_with_error(err)?;
254 return Ok(Handled::Yes);
255 }
256 };
257
258 let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
260 return Ok(Handled::No(request_cx));
261 };
262
263 mcp_server_tx
264 .send(MessageAndCx::Request(request.request, request_cx))
265 .await
266 .map_err(sacp::Error::into_internal_error)?;
267
268 Ok(Handled::Yes)
269 }
270
271 async fn handle_mcp_over_acp_notification(
272 &self,
273 result: Result<SuccessorNotification<McpOverAcpNotification<UntypedMessage>>, sacp::Error>,
274 notification_cx: JrConnectionCx,
275 ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
276 let SuccessorNotification { notification } = match result {
278 Ok(request) => request,
279 Err(err) => {
280 notification_cx.send_error_notification(err)?;
281 return Ok(Handled::Yes);
282 }
283 };
284
285 let Some(mut mcp_server_tx) = self.get_connection(¬ification.connection_id) else {
287 return Ok(Handled::No(notification_cx));
288 };
289
290 mcp_server_tx
291 .send(MessageAndCx::Notification(
292 notification.notification,
293 notification_cx.clone(),
294 ))
295 .await
296 .map_err(sacp::Error::into_internal_error)?;
297
298 Ok(Handled::Yes)
299 }
300
301 async fn handle_mcp_disconnect_notification(
302 &self,
303 result: Result<SuccessorNotification<McpDisconnectNotification>, sacp::Error>,
304 notification_cx: JrConnectionCx,
305 ) -> Result<Handled<JrConnectionCx>, sacp::Error> {
306 let SuccessorNotification { notification } = match result {
308 Ok(request) => request,
309 Err(err) => {
310 notification_cx.send_error_notification(err)?;
311 return Ok(Handled::Yes);
312 }
313 };
314
315 if self.remove_connection(¬ification.connection_id) {
317 Ok(Handled::Yes)
318 } else {
319 Ok(Handled::No(notification_cx))
320 }
321 }
322
323 async fn handle_new_session_request(
324 &self,
325 result: Result<NewSessionRequest, sacp::Error>,
326 request_cx: JrRequestCx<serde_json::Value>,
327 ) -> Result<Handled<JrRequestCx<serde_json::Value>>, sacp::Error> {
328 let mut request = match result {
330 Ok(request) => request,
331 Err(err) => {
332 request_cx.connection_cx().send_error_notification(err)?;
333 return Ok(Handled::Yes);
334 }
335 };
336
337 {
341 let data = self.data.lock().expect("not poisoned");
342 for server in data.registered_by_url.values() {
343 request.mcp_servers.push(server.acp_mcp_server());
344 }
345 }
346
347 request_cx
349 .connection_cx()
350 .send_request_to_successor(request)
351 .forward_to_request_cx(request_cx.cast())?;
352
353 Ok(Handled::Yes)
354 }
355}
356
357impl JrMessageHandler for McpServiceRegistry {
358 fn describe_chain(&self) -> impl std::fmt::Debug {
359 "McpServiceRegistry"
360 }
361
362 async fn handle_message(
363 &mut self,
364 message: sacp::MessageAndCx,
365 ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
366 match message {
367 sacp::MessageAndCx::Request(msg, mut cx) => {
368 let params = msg.params();
369
370 if let Some(result) =
371 <SuccessorRequest<McpConnectRequest>>::parse_request(cx.method(), params)
372 {
373 cx = match self.handle_connect_request(result, cx).await? {
374 Handled::Yes => return Ok(Handled::Yes),
375 Handled::No(cx) => cx,
376 };
377 }
378
379 if let Some(result) =
380 <SuccessorRequest<McpOverAcpRequest<UntypedMessage>>>::parse_request(
381 cx.method(),
382 params,
383 )
384 {
385 cx = match self.handle_mcp_over_acp_request(result, cx).await? {
386 Handled::Yes => return Ok(Handled::Yes),
387 Handled::No(cx) => cx,
388 };
389 }
390
391 if let Some(result) = <NewSessionRequest>::parse_request(cx.method(), params) {
392 cx = match self.handle_new_session_request(result, cx).await? {
393 Handled::Yes => return Ok(Handled::Yes),
394 Handled::No(cx) => cx,
395 };
396 }
397
398 Ok(Handled::No(sacp::MessageAndCx::Request(msg, cx)))
399 }
400 sacp::MessageAndCx::Notification(msg, mut cx) => {
401 let params = msg.params();
402
403 if let Some(result) =
404 <SuccessorNotification<McpOverAcpNotification<UntypedMessage>>>::parse_notification(
405 msg.method(),
406 params,
407 )
408 {
409 cx = match self.handle_mcp_over_acp_notification(result, cx).await? {
410 Handled::Yes => return Ok(Handled::Yes),
411 Handled::No(cx) => cx,
412 };
413 }
414
415 if let Some(result) =
416 <SuccessorNotification<McpDisconnectNotification>>::parse_notification(
417 msg.method(),
418 params,
419 )
420 {
421 cx = match self.handle_mcp_disconnect_notification(result, cx).await? {
422 Handled::Yes => return Ok(Handled::Yes),
423 Handled::No(cx) => cx,
424 };
425 }
426
427 Ok(sacp::Handled::No(sacp::MessageAndCx::Notification(msg, cx)))
428 }
429 }
430 }
431}
432
433#[derive(Clone)]
434struct RegisteredMcpServer {
435 name: String,
436 url: String,
437 spawn: Arc<dyn SpawnMcpServer>,
438}
439
440impl RegisteredMcpServer {
441 fn acp_mcp_server(&self) -> sacp::schema::McpServer {
442 sacp::schema::McpServer::Http {
443 name: self.name.clone(),
444 url: self.url.clone(),
445 headers: vec![],
446 }
447 }
448}
449
450impl std::fmt::Debug for RegisteredMcpServer {
451 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452 f.debug_struct("RegisteredMcpServer")
453 .field("name", &self.name)
454 .field("url", &self.url)
455 .finish()
456 }
457}
458
459trait SpawnMcpServer: Send + Sync + 'static {
463 fn spawn(&self) -> sacp::DynComponent;
467}
468
469impl AsRef<McpServiceRegistry> for McpServiceRegistry {
470 fn as_ref(&self) -> &McpServiceRegistry {
471 self
472 }
473}