1use futures::channel::mpsc;
2use futures::{SinkExt, StreamExt};
3use fxhash::FxHashMap;
4
5use sacp::schema::{NewSessionRequest, NewSessionResponse};
6use sacp::util::MatchMessage;
7use sacp::{
8 Channel, Component, DynComponent, Handled, JrConnectionCx, JrHandlerChain, JrMessageHandler,
9 JrRequestCx, MessageAndCx, UntypedMessage,
10};
11use std::sync::{Arc, Mutex};
12
13use crate::mcp_server_builder::McpServer;
14use crate::{
15 McpConnectRequest, McpConnectResponse, McpDisconnectNotification, McpOverAcpNotification,
16 McpOverAcpRequest, SuccessorNotification, SuccessorRequest,
17};
18
/// Registry of MCP servers that this component offers over ACP, together
/// with the connections currently open to them.
///
/// Clones share one underlying state: the only field is an `Arc<Mutex<..>>`,
/// so the derived `Clone` copies the handle, not the data.
#[derive(Clone, Default, Debug)]
pub struct McpServiceRegistry {
    /// Shared, mutex-guarded registry state.
    data: Arc<Mutex<McpServiceRegistryData>>,
}
36
/// State guarded by the mutex inside [`McpServiceRegistry`].
#[derive(Default, Debug)]
struct McpServiceRegistryData {
    /// Registered servers, keyed by the name they were registered under.
    registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer>>,
    /// The same registered servers, keyed by their generated `acp:` URL.
    registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer>>,
    /// Open connections, keyed by connection id. Each sender feeds the
    /// channel that is drained by the task spawned for that connection.
    connections: FxHashMap<String, mpsc::Sender<MessageAndCx>>,
}
43
44impl McpServiceRegistry {
45 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn with_mcp_server(
56 self,
57 name: impl ToString,
58 server: McpServer,
59 ) -> Result<Self, sacp::Error> {
60 self.add_mcp_server_with_context(name, move |mcp_cx| server.new_connection(mcp_cx))?;
61 Ok(self)
62 }
63
64 pub fn add_mcp_server<C: Component>(
77 &self,
78 name: impl ToString,
79 new_fn: impl Fn() -> C + Send + Sync + 'static,
80 ) -> Result<(), sacp::Error> {
81 struct FnSpawner<F> {
82 new_fn: F,
83 }
84
85 impl<C, F> SpawnMcpServer for FnSpawner<F>
86 where
87 F: Fn() -> C + Send + Sync + 'static,
88 C: Component,
89 {
90 fn spawn(&self, _cx: McpContext) -> DynComponent {
91 let component = (self.new_fn)();
92 DynComponent::new(component)
93 }
94 }
95
96 self.add_mcp_server_internal(name, FnSpawner { new_fn })
97 }
98
99 pub fn add_mcp_server_with_context<C: Component>(
111 &self,
112 name: impl ToString,
113 new_fn: impl Fn(McpContext) -> C + Send + Sync + 'static,
114 ) -> Result<(), sacp::Error> {
115 struct FnSpawner<F> {
116 new_fn: F,
117 }
118
119 impl<C, F> SpawnMcpServer for FnSpawner<F>
120 where
121 F: Fn(McpContext) -> C + Send + Sync + 'static,
122 C: Component,
123 {
124 fn spawn(&self, cx: McpContext) -> DynComponent {
125 let component = (self.new_fn)(cx);
126 DynComponent::new(component)
127 }
128 }
129
130 self.add_mcp_server_internal(name, FnSpawner { new_fn })
131 }
132 fn add_mcp_server_internal(
133 &self,
134 name: impl ToString,
135 spawner: impl SpawnMcpServer,
136 ) -> Result<(), sacp::Error> {
137 let name = name.to_string();
138 if let Some(_) = self.get_registered_server_by_name(&name) {
139 return Err(sacp::util::internal_error(format!(
140 "Server with name '{}' already exists",
141 name
142 )));
143 }
144
145 let uuid = uuid::Uuid::new_v4().to_string();
146 let service = Arc::new(RegisteredMcpServer {
147 name,
148 url: format!("acp:{uuid}"),
149 spawn: Arc::new(spawner),
150 });
151 self.insert_registered_server(service);
152 Ok(())
153 }
154
155 fn insert_registered_server(&self, service: Arc<RegisteredMcpServer>) {
156 let mut data = self.data.lock().expect("not poisoned");
157 data.registered_by_name
158 .insert(service.name.clone(), service.clone());
159 data.registered_by_url
160 .insert(service.url.clone(), service.clone());
161 }
162
163 fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer>> {
164 self.data
165 .lock()
166 .expect("not poisoned")
167 .registered_by_name
168 .get(name)
169 .cloned()
170 }
171
172 fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer>> {
173 self.data
174 .lock()
175 .expect("not poisoned")
176 .registered_by_url
177 .get(url)
178 .cloned()
179 }
180
181 fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<sacp::MessageAndCx>) {
182 self.data
183 .lock()
184 .expect("not poisoned")
185 .connections
186 .insert(connection_id.to_string(), tx);
187 }
188
189 fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<sacp::MessageAndCx>> {
190 self.data
191 .lock()
192 .expect("not poisoned")
193 .connections
194 .get(connection_id)
195 .cloned()
196 }
197
198 fn remove_connection(&self, connection_id: &str) -> bool {
199 self.data
200 .lock()
201 .expect("not poisoned")
202 .connections
203 .remove(connection_id)
204 .is_some()
205 }
206
207 pub fn add_registered_mcp_servers_to(&self, request: &mut NewSessionRequest) {
226 let data = self.data.lock().expect("not poisoned");
227 for server in data.registered_by_url.values() {
228 request.mcp_servers.push(server.acp_mcp_server());
229 }
230 }
231
232 async fn handle_connect_request(
233 &self,
234 successor_request: SuccessorRequest<McpConnectRequest>,
235 request_cx: JrRequestCx<McpConnectResponse>,
236 ) -> Result<
237 Handled<(
238 SuccessorRequest<McpConnectRequest>,
239 JrRequestCx<McpConnectResponse>,
240 )>,
241 sacp::Error,
242 > {
243 let SuccessorRequest { request, .. } = &successor_request;
244 let outer_cx = request_cx.connection_cx();
245
246 let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
248 return Ok(Handled::No((successor_request, request_cx)));
249 };
250
251 let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
253 let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
254 self.insert_connection(&connection_id, mcp_server_tx);
255
256 let (client_channel, server_channel) = Channel::duplex();
258
259 let client_component = {
261 let connection_id = connection_id.clone();
262 let outer_cx = outer_cx.clone();
263
264 JrHandlerChain::new()
265 .on_receive_message(async move |message: sacp::MessageAndCx| {
266 let wrapped = message.map(
268 |request, request_cx| {
269 (
270 McpOverAcpRequest {
271 connection_id: connection_id.clone(),
272 request,
273 meta: None,
274 },
275 request_cx,
276 )
277 },
278 |notification, cx| {
279 (
280 McpOverAcpNotification {
281 connection_id: connection_id.clone(),
282 notification,
283 meta: None,
284 },
285 cx,
286 )
287 },
288 );
289 outer_cx.send_proxied_message(wrapped)
290 })
291 .with_spawned(move |mcp_cx| async move {
292 while let Some(msg) = mcp_server_rx.next().await {
293 mcp_cx.send_proxied_message(msg)?;
294 }
295 Ok(())
296 })
297 };
298
299 let mcp_server = registered_server.spawn.spawn(McpContext {
301 acp_url: request.acp_url.clone(),
302 connection_cx: outer_cx.clone(),
303 });
304
305 let spawn_results = outer_cx
307 .spawn(async move { client_component.serve(client_channel).await })
308 .and_then(|()| {
309 request_cx
311 .connection_cx()
312 .spawn(async move { mcp_server.serve(server_channel).await })
313 });
314
315 match spawn_results {
316 Ok(()) => {
317 request_cx.respond(McpConnectResponse {
318 connection_id,
319 meta: None,
320 })?;
321 Ok(Handled::Yes)
322 }
323
324 Err(err) => {
325 request_cx.respond_with_error(err)?;
326 Ok(Handled::Yes)
327 }
328 }
329 }
330
331 async fn handle_mcp_over_acp_request(
332 &self,
333 successor_request: SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
334 request_cx: JrRequestCx<serde_json::Value>,
335 ) -> Result<
336 Handled<(
337 SuccessorRequest<McpOverAcpRequest<UntypedMessage>>,
338 JrRequestCx<serde_json::Value>,
339 )>,
340 sacp::Error,
341 > {
342 let Some(mut mcp_server_tx) = self.get_connection(&successor_request.request.connection_id)
344 else {
345 return Ok(Handled::No((successor_request, request_cx)));
346 };
347
348 let SuccessorRequest { request, .. } = successor_request;
349
350 mcp_server_tx
351 .send(MessageAndCx::Request(request.request, request_cx))
352 .await
353 .map_err(sacp::Error::into_internal_error)?;
354
355 Ok(Handled::Yes)
356 }
357
358 async fn handle_mcp_over_acp_notification(
359 &self,
360 successor_notification: SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
361 notification_cx: JrConnectionCx,
362 ) -> Result<
363 Handled<(
364 SuccessorNotification<McpOverAcpNotification<UntypedMessage>>,
365 JrConnectionCx,
366 )>,
367 sacp::Error,
368 > {
369 let Some(mut mcp_server_tx) =
371 self.get_connection(&successor_notification.notification.connection_id)
372 else {
373 return Ok(Handled::No((successor_notification, notification_cx)));
374 };
375
376 let SuccessorNotification { notification, .. } = successor_notification;
377
378 mcp_server_tx
379 .send(MessageAndCx::Notification(
380 notification.notification,
381 notification_cx.clone(),
382 ))
383 .await
384 .map_err(sacp::Error::into_internal_error)?;
385
386 Ok(Handled::Yes)
387 }
388
389 async fn handle_mcp_disconnect_notification(
390 &self,
391 successor_notification: SuccessorNotification<McpDisconnectNotification>,
392 notification_cx: JrConnectionCx,
393 ) -> Result<
394 Handled<(
395 SuccessorNotification<McpDisconnectNotification>,
396 JrConnectionCx,
397 )>,
398 sacp::Error,
399 > {
400 let SuccessorNotification { notification, .. } = &successor_notification;
401
402 if self.remove_connection(¬ification.connection_id) {
404 Ok(Handled::Yes)
405 } else {
406 Ok(Handled::No((successor_notification, notification_cx)))
407 }
408 }
409
410 async fn handle_new_session_request(
411 &self,
412 mut request: NewSessionRequest,
413 request_cx: JrRequestCx<NewSessionResponse>,
414 ) -> Result<Handled<(NewSessionRequest, JrRequestCx<NewSessionResponse>)>, sacp::Error> {
415 self.add_registered_mcp_servers_to(&mut request);
419
420 Ok(Handled::No((request, request_cx)))
422 }
423}
424
/// Lets the registry take part in a message-handler chain.
impl JrMessageHandler for McpServiceRegistry {
    /// Label for this handler; the value is the fixed string
    /// `"McpServiceRegistry"`.
    fn describe_chain(&self) -> impl std::fmt::Debug {
        "McpServiceRegistry"
    }

    /// Offers `message` to each of the registry's handlers in turn.
    ///
    /// The order is: connect requests, MCP-over-ACP requests, new-session
    /// requests, MCP-over-ACP notifications, then disconnect notifications.
    /// NOTE(review): presumably each `if_request` / `if_notification` step
    /// only runs its closure when the message is still unhandled and matches
    /// the closure's type, and `done()` yields the final outcome — confirm
    /// against `MatchMessage`.
    async fn handle_message(
        &mut self,
        message: sacp::MessageAndCx,
    ) -> Result<sacp::Handled<sacp::MessageAndCx>, sacp::Error> {
        MatchMessage::new(message)
            .if_request(|request, request_cx| self.handle_connect_request(request, request_cx))
            .await
            .if_request(|request, request_cx| self.handle_mcp_over_acp_request(request, request_cx))
            .await
            .if_request(|request, request_cx| self.handle_new_session_request(request, request_cx))
            .await
            .if_notification(|notification, notification_cx| {
                self.handle_mcp_over_acp_notification(notification, notification_cx)
            })
            .await
            .if_notification(|notification, notification_cx| {
                self.handle_mcp_disconnect_notification(notification, notification_cx)
            })
            .await
            .done()
    }
}
452
/// A server known to the registry.
#[derive(Clone)]
struct RegisteredMcpServer {
    /// Name the server was registered under.
    name: String,
    /// URL of the form `acp:<uuid>` generated at registration time.
    url: String,
    /// Factory that creates the component serving one connection.
    spawn: Arc<dyn SpawnMcpServer>,
}
460
461impl RegisteredMcpServer {
462 fn acp_mcp_server(&self) -> sacp::schema::McpServer {
463 sacp::schema::McpServer::Http {
464 name: self.name.clone(),
465 url: self.url.clone(),
466 headers: vec![],
467 }
468 }
469}
470
471impl std::fmt::Debug for RegisteredMcpServer {
472 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473 f.debug_struct("RegisteredMcpServer")
474 .field("name", &self.name)
475 .field("url", &self.url)
476 .finish()
477 }
478}
479
/// Factory for the component that serves a single connection to a
/// registered MCP server.
trait SpawnMcpServer: Send + Sync + 'static {
    /// Creates the component for a new connection described by `cx`.
    fn spawn(&self, cx: McpContext) -> sacp::DynComponent;
}
489
490impl AsRef<McpServiceRegistry> for McpServiceRegistry {
491 fn as_ref(&self) -> &McpServiceRegistry {
492 self
493 }
494}
495
/// Per-connection information handed to a server factory when a connection
/// is opened.
#[derive(Clone)]
pub struct McpContext {
    /// The `acp_url` taken from the connect request.
    acp_url: String,
    /// Connection context of the connection the connect request arrived on.
    connection_cx: JrConnectionCx,
}
502
503impl McpContext {
504 pub fn acp_url(&self) -> String {
506 self.acp_url.clone()
507 }
508
509 pub fn connection_cx(&self) -> JrConnectionCx {
512 self.connection_cx.clone()
513 }
514}