1use futures::channel::mpsc;
4use futures::{SinkExt, StreamExt};
5use fxhash::FxHashMap;
6
7use crate::mcp::{McpClientToServer, McpServerEnd};
8use crate::schema::{
9 McpConnectRequest, McpConnectResponse, McpDisconnectNotification, McpOverAcpMessage,
10 NewSessionRequest, NewSessionResponse,
11};
12use crate::util::MatchMessage;
13use crate::{
14 Agent, Channel, Component, DynComponent, Handled, HasEndpoint, JrConnectionCx,
15 JrMessageHandlerSend, JrRequestCx, JrRole, MessageCx, UntypedMessage,
16};
17use std::sync::{Arc, Mutex};
18
19use super::server::McpServer;
20
21#[derive(Clone, Default, Debug)]
34pub struct McpServiceRegistry<Role: JrRole> {
35 data: Arc<Mutex<McpServiceRegistryData<Role>>>,
36}
37
38#[derive(Default, Debug)]
39struct McpServiceRegistryData<Role: JrRole> {
40 registered_by_name: FxHashMap<String, Arc<RegisteredMcpServer<Role>>>,
41 registered_by_url: FxHashMap<String, Arc<RegisteredMcpServer<Role>>>,
42 connections: FxHashMap<String, mpsc::Sender<MessageCx>>,
43}
44
45impl<Role: JrRole> McpServiceRegistry<Role>
46where
47 Role: HasEndpoint<Agent>,
48{
49 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn with_mcp_server(
60 self,
61 name: impl ToString,
62 server: McpServer<Role>,
63 ) -> Result<Self, crate::Error> {
64 self.add_mcp_server_with_context(name, move |mcp_cx| server.new_connection(mcp_cx))?;
65 Ok(self)
66 }
67
68 pub fn add_mcp_server<C: Component>(
81 &self,
82 name: impl ToString,
83 new_fn: impl Fn() -> C + Send + Sync + 'static,
84 ) -> Result<(), crate::Error> {
85 struct FnSpawner<F> {
86 new_fn: F,
87 }
88
89 impl<Role, C, F> SpawnMcpServer<Role> for FnSpawner<F>
90 where
91 Role: JrRole,
92 F: Fn() -> C + Send + Sync + 'static,
93 C: Component,
94 {
95 fn spawn(&self, _cx: McpContext<Role>) -> DynComponent {
96 let component = (self.new_fn)();
97 DynComponent::new(component)
98 }
99 }
100
101 self.add_mcp_server_internal(name, FnSpawner { new_fn })
102 }
103
104 pub fn add_mcp_server_with_context<C: Component>(
116 &self,
117 name: impl ToString,
118 new_fn: impl Fn(McpContext<Role>) -> C + Send + Sync + 'static,
119 ) -> Result<(), crate::Error> {
120 struct FnSpawner<F> {
121 new_fn: F,
122 }
123
124 impl<Role, C, F> SpawnMcpServer<Role> for FnSpawner<F>
125 where
126 Role: JrRole,
127 F: Fn(McpContext<Role>) -> C + Send + Sync + 'static,
128 C: Component,
129 {
130 fn spawn(&self, cx: McpContext<Role>) -> DynComponent {
131 let component = (self.new_fn)(cx);
132 DynComponent::new(component)
133 }
134 }
135
136 self.add_mcp_server_internal(name, FnSpawner { new_fn })
137 }
138
139 fn add_mcp_server_internal(
140 &self,
141 name: impl ToString,
142 spawner: impl SpawnMcpServer<Role>,
143 ) -> Result<(), crate::Error> {
144 let name = name.to_string();
145 if self.get_registered_server_by_name(&name).is_some() {
146 return Err(crate::util::internal_error(format!(
147 "Server with name '{}' already exists",
148 name
149 )));
150 }
151
152 let uuid = uuid::Uuid::new_v4().to_string();
153 let service = Arc::new(RegisteredMcpServer {
154 name,
155 url: format!("acp:{uuid}"),
156 spawn: Arc::new(spawner),
157 });
158 self.insert_registered_server(service);
159 Ok(())
160 }
161
162 fn insert_registered_server(&self, service: Arc<RegisteredMcpServer<Role>>) {
163 let mut data = self.data.lock().expect("not poisoned");
164 data.registered_by_name
165 .insert(service.name.clone(), service.clone());
166 data.registered_by_url
167 .insert(service.url.clone(), service.clone());
168 }
169
170 fn get_registered_server_by_name(&self, name: &str) -> Option<Arc<RegisteredMcpServer<Role>>> {
171 self.data
172 .lock()
173 .expect("not poisoned")
174 .registered_by_name
175 .get(name)
176 .cloned()
177 }
178
179 fn get_registered_server_by_url(&self, url: &str) -> Option<Arc<RegisteredMcpServer<Role>>> {
180 self.data
181 .lock()
182 .expect("not poisoned")
183 .registered_by_url
184 .get(url)
185 .cloned()
186 }
187
188 fn insert_connection(&self, connection_id: &str, tx: mpsc::Sender<MessageCx>) {
189 self.data
190 .lock()
191 .expect("not poisoned")
192 .connections
193 .insert(connection_id.to_string(), tx);
194 }
195
196 fn get_connection(&self, connection_id: &str) -> Option<mpsc::Sender<MessageCx>> {
197 self.data
198 .lock()
199 .expect("not poisoned")
200 .connections
201 .get(connection_id)
202 .cloned()
203 }
204
205 fn remove_connection(&self, connection_id: &str) -> bool {
206 self.data
207 .lock()
208 .expect("not poisoned")
209 .connections
210 .remove(connection_id)
211 .is_some()
212 }
213
214 pub fn add_registered_mcp_servers_to(&self, request: &mut NewSessionRequest) {
221 let data = self.data.lock().expect("not poisoned");
222 for server in data.registered_by_url.values() {
223 request.mcp_servers.push(server.acp_mcp_server());
224 }
225 }
226
227 async fn handle_connect_request(
228 &self,
229 request: McpConnectRequest,
230 request_cx: JrRequestCx<McpConnectResponse>,
231 outer_cx: JrConnectionCx<Role>,
232 ) -> Result<Handled<(McpConnectRequest, JrRequestCx<McpConnectResponse>)>, crate::Error> {
233 let Some(registered_server) = self.get_registered_server_by_url(&request.acp_url) else {
235 return Ok(Handled::No((request, request_cx)));
236 };
237
238 let connection_id = format!("mcp-over-acp-connection:{}", uuid::Uuid::new_v4());
240 let (mcp_server_tx, mut mcp_server_rx) = mpsc::channel(128);
241 self.insert_connection(&connection_id, mcp_server_tx);
242
243 let (client_channel, server_channel) = Channel::duplex();
245
246 let client_component = {
248 let connection_id = connection_id.clone();
249 let outer_cx = outer_cx.clone();
250
251 McpClientToServer::builder()
252 .on_receive_message(async move |message: MessageCx, _mcp_cx| {
253 let wrapped = message.map(
255 |request, request_cx| {
256 (
257 McpOverAcpMessage {
258 connection_id: connection_id.clone(),
259 message: request,
260 meta: None,
261 },
262 request_cx,
263 )
264 },
265 |notification| McpOverAcpMessage {
266 connection_id: connection_id.clone(),
267 message: notification,
268 meta: None,
269 },
270 );
271 outer_cx.send_proxied_message_to(Agent, wrapped)
272 })
273 .with_spawned(move |mcp_cx| async move {
274 while let Some(msg) = mcp_server_rx.next().await {
277 mcp_cx.send_proxied_message_to(McpServerEnd, msg)?;
278 }
279 Ok(())
280 })
281 };
282
283 let mcp_server = registered_server.spawn.spawn(McpContext {
285 acp_url: request.acp_url.clone(),
286 connection_cx: outer_cx.clone(),
287 });
288
289 let spawn_results = outer_cx
291 .spawn(async move { client_component.serve(client_channel).await })
292 .and_then(|()| {
293 outer_cx.spawn(async move { mcp_server.serve(server_channel).await })
295 });
296
297 match spawn_results {
298 Ok(()) => {
299 request_cx.respond(McpConnectResponse {
300 connection_id,
301 meta: None,
302 })?;
303 Ok(Handled::Yes)
304 }
305
306 Err(err) => {
307 request_cx.respond_with_error(err)?;
308 Ok(Handled::Yes)
309 }
310 }
311 }
312
313 async fn handle_mcp_over_acp_request(
314 &self,
315 request: McpOverAcpMessage<UntypedMessage>,
316 request_cx: JrRequestCx<serde_json::Value>,
317 ) -> Result<
318 Handled<(
319 McpOverAcpMessage<UntypedMessage>,
320 JrRequestCx<serde_json::Value>,
321 )>,
322 crate::Error,
323 > {
324 let Some(mut mcp_server_tx) = self.get_connection(&request.connection_id) else {
326 return Ok(Handled::No((request, request_cx)));
327 };
328
329 mcp_server_tx
330 .send(MessageCx::Request(request.message, request_cx))
331 .await
332 .map_err(crate::Error::into_internal_error)?;
333
334 Ok(Handled::Yes)
335 }
336
337 async fn handle_mcp_over_acp_notification(
338 &self,
339 notification: McpOverAcpMessage<UntypedMessage>,
340 ) -> Result<Handled<McpOverAcpMessage<UntypedMessage>>, crate::Error> {
341 let Some(mut mcp_server_tx) = self.get_connection(¬ification.connection_id) else {
343 return Ok(Handled::No(notification));
344 };
345
346 mcp_server_tx
347 .send(MessageCx::Notification(notification.message))
348 .await
349 .map_err(crate::Error::into_internal_error)?;
350
351 Ok(Handled::Yes)
352 }
353
354 async fn handle_mcp_disconnect_notification(
355 &self,
356 successor_notification: McpDisconnectNotification,
357 ) -> Result<Handled<McpDisconnectNotification>, crate::Error> {
358 if self.remove_connection(&successor_notification.connection_id) {
360 Ok(Handled::Yes)
361 } else {
362 Ok(Handled::No(successor_notification))
363 }
364 }
365
366 async fn handle_new_session_request(
367 &self,
368 mut request: NewSessionRequest,
369 request_cx: JrRequestCx<NewSessionResponse>,
370 ) -> Result<Handled<(NewSessionRequest, JrRequestCx<NewSessionResponse>)>, crate::Error> {
371 self.add_registered_mcp_servers_to(&mut request);
375
376 Ok(Handled::No((request, request_cx)))
378 }
379}
380
381impl<Role: JrRole> JrMessageHandlerSend for McpServiceRegistry<Role>
382where
383 Role: HasEndpoint<Agent>,
384{
385 type Role = Role;
386
387 fn describe_chain(&self) -> impl std::fmt::Debug {
388 "McpServiceRegistry"
389 }
390
391 async fn handle_message(
392 &mut self,
393 message: MessageCx,
394 connection_cx: JrConnectionCx<Role>,
395 ) -> Result<Handled<MessageCx>, crate::Error> {
396 MatchMessage::new(message)
408 .if_request_from(Agent, connection_cx.clone(), |request, request_cx, cx| {
410 self.handle_connect_request(request, request_cx, cx)
411 })
412 .await
413 .if_request_from(Agent, connection_cx.clone(), |request, request_cx, _cx| {
415 self.handle_mcp_over_acp_request(request, request_cx)
416 })
417 .await
418 .if_request(|request, request_cx| self.handle_new_session_request(request, request_cx))
420 .await
421 .if_notification_from(Agent, connection_cx.clone(), |notification, _cx| {
423 self.handle_mcp_over_acp_notification(notification)
424 })
425 .await
426 .if_notification_from(Agent, connection_cx, |notification, _cx| {
428 self.handle_mcp_disconnect_notification(notification)
429 })
430 .await
431 .done()
432 }
433}
434
435#[derive(Clone)]
437struct RegisteredMcpServer<Role: JrRole> {
438 name: String,
439 url: String,
440 spawn: Arc<dyn SpawnMcpServer<Role>>,
441}
442
443impl<Role: JrRole> RegisteredMcpServer<Role> {
444 fn acp_mcp_server(&self) -> crate::schema::McpServer {
445 crate::schema::McpServer::Http {
446 name: self.name.clone(),
447 url: self.url.clone(),
448 headers: vec![],
449 }
450 }
451}
452
453impl<Role: JrRole> std::fmt::Debug for RegisteredMcpServer<Role> {
454 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455 f.debug_struct("RegisteredMcpServer")
456 .field("name", &self.name)
457 .field("url", &self.url)
458 .finish()
459 }
460}
461
462trait SpawnMcpServer<Role: JrRole>: Send + Sync + 'static {
466 fn spawn(&self, cx: McpContext<Role>) -> DynComponent;
470}
471
472impl<Role: JrRole> AsRef<McpServiceRegistry<Role>> for McpServiceRegistry<Role>
473where
474 Role: HasEndpoint<Agent>,
475{
476 fn as_ref(&self) -> &McpServiceRegistry<Role> {
477 self
478 }
479}
480
481#[derive(Clone)]
483pub struct McpContext<Role: JrRole> {
484 acp_url: String,
485 connection_cx: JrConnectionCx<Role>,
486}
487
488impl<Role: JrRole> McpContext<Role> {
489 pub fn acp_url(&self) -> String {
491 self.acp_url.clone()
492 }
493
494 pub fn connection_cx(&self) -> JrConnectionCx<Role> {
497 self.connection_cx.clone()
498 }
499}