1use agent_client_protocol::{self as acp, InitializeRequest, InitializeResponse};
2use futures::{AsyncRead, AsyncWrite};
3use sacp::{
4 ChainHandler, Handled, JsonRpcConnection, JsonRpcConnectionCx, JsonRpcHandler, JsonRpcMessage,
5 JsonRpcNotification, JsonRpcRequest, JsonRpcRequestCx, MessageAndCx, MetaCapabilityExt, Proxy,
6 UntypedMessage,
7};
8use serde::{Deserialize, Serialize};
9use std::marker::PhantomData;
10
11use crate::mcp_server::McpServiceRegistry;
12
13const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct SuccessorRequest<Req: JsonRpcRequest> {
23 #[serde(flatten)]
25 pub request: Req,
26}
27
28impl<Req: JsonRpcRequest> JsonRpcMessage for SuccessorRequest<Req> {
29 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, acp::Error> {
30 sacp::UntypedMessage::new(
31 SUCCESSOR_REQUEST_METHOD,
32 SuccessorRequest {
33 request: self.request.into_untyped_message()?,
34 },
35 )
36 }
37
38 fn method(&self) -> &str {
39 SUCCESSOR_REQUEST_METHOD
40 }
41
42 fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, acp::Error>> {
43 if method == SUCCESSOR_REQUEST_METHOD {
44 match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
45 Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
46 {
47 Some(Ok(request)) => Some(Ok(SuccessorRequest { request })),
48 Some(Err(err)) => Some(Err(err)),
49 None => None,
50 },
51 Err(err) => Some(Err(err)),
52 }
53 } else {
54 None
55 }
56 }
57
58 fn parse_notification(
59 _method: &str,
60 _params: &impl Serialize,
61 ) -> Option<Result<Self, acp::Error>> {
62 None }
64}
65
66impl<Req: JsonRpcRequest> JsonRpcRequest for SuccessorRequest<Req> {
67 type Response = Req::Response;
68}
69
70const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct SuccessorNotification<Req: JsonRpcNotification> {
77 #[serde(flatten)]
79 pub notification: Req,
80}
81
82impl<Req: JsonRpcNotification> JsonRpcMessage for SuccessorNotification<Req> {
83 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, acp::Error> {
84 sacp::UntypedMessage::new(
85 SUCCESSOR_NOTIFICATION_METHOD,
86 SuccessorNotification {
87 notification: self.notification.into_untyped_message()?,
88 },
89 )
90 }
91
92 fn method(&self) -> &str {
93 SUCCESSOR_NOTIFICATION_METHOD
94 }
95
96 fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, acp::Error>> {
97 None }
99
100 fn parse_notification(
101 method: &str,
102 params: &impl Serialize,
103 ) -> Option<Result<Self, acp::Error>> {
104 if method == SUCCESSOR_NOTIFICATION_METHOD {
105 match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
106 Ok(outer) => match Req::parse_notification(
107 &outer.notification.method,
108 &outer.notification.params,
109 ) {
110 Some(Ok(notification)) => Some(Ok(SuccessorNotification { notification })),
111 Some(Err(err)) => Some(Err(err)),
112 None => None,
113 },
114 Err(err) => Some(Err(err)),
115 }
116 } else {
117 None
118 }
119 }
120}
121
122impl<Req: JsonRpcNotification> JsonRpcNotification for SuccessorNotification<Req> {}
123
124pub trait AcpProxyExt<OB: AsyncWrite, IB: AsyncRead, H: JsonRpcHandler> {
128 fn on_receive_request_from_successor<R, F>(
151 self,
152 op: F,
153 ) -> JsonRpcConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
154 where
155 R: JsonRpcRequest,
156 F: AsyncFnMut(R, JsonRpcRequestCx<R::Response>) -> Result<(), acp::Error>;
157
158 fn on_receive_notification_from_successor<N, F>(
181 self,
182 op: F,
183 ) -> JsonRpcConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
184 where
185 N: JsonRpcNotification,
186 F: AsyncFnMut(N, JsonRpcConnectionCx) -> Result<(), acp::Error>;
187
188 fn proxy(self) -> JsonRpcConnection<OB, IB, ChainHandler<H, ProxyHandler>>;
191
192 fn provide_mcp(
196 self,
197 registry: impl AsRef<McpServiceRegistry>,
198 ) -> JsonRpcConnection<OB, IB, ChainHandler<H, McpServiceRegistry>>;
199}
200
201impl<OB, IB, H> AcpProxyExt<OB, IB, H> for JsonRpcConnection<OB, IB, H>
202where
203 OB: AsyncWrite,
204 IB: AsyncRead,
205 H: JsonRpcHandler,
206{
207 fn on_receive_request_from_successor<R, F>(
208 self,
209 op: F,
210 ) -> JsonRpcConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
211 where
212 R: JsonRpcRequest,
213 F: AsyncFnMut(R, JsonRpcRequestCx<R::Response>) -> Result<(), acp::Error>,
214 {
215 self.chain_handler(RequestFromSuccessorHandler::new(op))
216 }
217
218 fn on_receive_notification_from_successor<N, F>(
219 self,
220 op: F,
221 ) -> JsonRpcConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
222 where
223 N: JsonRpcNotification,
224 F: AsyncFnMut(N, JsonRpcConnectionCx) -> Result<(), acp::Error>,
225 {
226 self.chain_handler(NotificationFromSuccessorHandler::new(op))
227 }
228
229 fn proxy(self) -> JsonRpcConnection<OB, IB, ChainHandler<H, ProxyHandler>> {
230 self.chain_handler(ProxyHandler {})
231 }
232
233 fn provide_mcp(
234 self,
235 registry: impl AsRef<McpServiceRegistry>,
236 ) -> JsonRpcConnection<OB, IB, ChainHandler<H, McpServiceRegistry>> {
237 self.chain_handler(registry.as_ref().clone())
238 }
239}
240
241pub struct RequestFromSuccessorHandler<R, F>
243where
244 R: JsonRpcRequest,
245 F: AsyncFnMut(R, JsonRpcRequestCx<R::Response>) -> Result<(), acp::Error>,
246{
247 handler: F,
248 phantom: PhantomData<fn(R)>,
249}
250
251impl<R, F> RequestFromSuccessorHandler<R, F>
252where
253 R: JsonRpcRequest,
254 F: AsyncFnMut(R, JsonRpcRequestCx<R::Response>) -> Result<(), acp::Error>,
255{
256 pub fn new(handler: F) -> Self {
257 Self {
258 handler,
259 phantom: PhantomData,
260 }
261 }
262}
263
264impl<R, F> JsonRpcHandler for RequestFromSuccessorHandler<R, F>
265where
266 R: JsonRpcRequest,
267 F: AsyncFnMut(R, JsonRpcRequestCx<R::Response>) -> Result<(), acp::Error>,
268{
269 async fn handle_message(
270 &mut self,
271 message: sacp::MessageAndCx,
272 ) -> Result<Handled<sacp::MessageAndCx>, agent_client_protocol::Error> {
273 let MessageAndCx::Request(request, cx) = message else {
274 return Ok(Handled::No(message));
275 };
276
277 tracing::debug!(
278 request_type = std::any::type_name::<R>(),
279 message = ?request,
280 "RequestHandler::handle_request"
281 );
282 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
283 Some(Ok(request)) => {
284 tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
285 (self.handler)(request.request, cx.cast()).await?;
286 Ok(Handled::Yes)
287 }
288 Some(Err(err)) => {
289 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
290 Err(err)
291 }
292 None => {
293 tracing::trace!("RequestHandler::handle_request: parse failed");
294 Ok(Handled::No(MessageAndCx::Request(request, cx)))
295 }
296 }
297 }
298
299 fn describe_chain(&self) -> impl std::fmt::Debug {
300 std::any::type_name::<R>()
301 }
302}
303
304pub struct NotificationFromSuccessorHandler<N, F>
306where
307 N: JsonRpcNotification,
308 F: AsyncFnMut(N, JsonRpcConnectionCx) -> Result<(), acp::Error>,
309{
310 handler: F,
311 phantom: PhantomData<fn(N)>,
312}
313
314impl<N, F> NotificationFromSuccessorHandler<N, F>
315where
316 N: JsonRpcNotification,
317 F: AsyncFnMut(N, JsonRpcConnectionCx) -> Result<(), acp::Error>,
318{
319 pub fn new(handler: F) -> Self {
320 Self {
321 handler,
322 phantom: PhantomData,
323 }
324 }
325}
326
327impl<N, F> JsonRpcHandler for NotificationFromSuccessorHandler<N, F>
328where
329 N: JsonRpcNotification,
330 F: AsyncFnMut(N, JsonRpcConnectionCx) -> Result<(), acp::Error>,
331{
332 async fn handle_message(
333 &mut self,
334 message: sacp::MessageAndCx,
335 ) -> Result<Handled<sacp::MessageAndCx>, agent_client_protocol::Error> {
336 let MessageAndCx::Notification(message, cx) = message else {
337 return Ok(Handled::No(message));
338 };
339
340 match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
341 Some(Ok(notification)) => {
342 tracing::trace!(
343 ?notification,
344 "NotificationFromSuccessorHandler::handle_request: parse completed"
345 );
346 (self.handler)(notification.notification, cx).await?;
347 Ok(Handled::Yes)
348 }
349 Some(Err(err)) => {
350 tracing::trace!(
351 ?err,
352 "NotificationFromSuccessorHandler::handle_request: parse errored"
353 );
354 Err(err)
355 }
356 None => {
357 tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
358 Ok(Handled::No(MessageAndCx::Notification(message, cx)))
359 }
360 }
361 }
362
363 fn describe_chain(&self) -> impl std::fmt::Debug {
364 format!("FromSuccessor<{}>", std::any::type_name::<N>())
365 }
366}
367
368pub struct ProxyHandler {}
370
371impl JsonRpcHandler for ProxyHandler {
372 fn describe_chain(&self) -> impl std::fmt::Debug {
373 "proxy"
374 }
375
376 async fn handle_message(
377 &mut self,
378 message: sacp::MessageAndCx,
379 ) -> Result<Handled<sacp::MessageAndCx>, agent_client_protocol::Error> {
380 tracing::debug!(
381 message = ?message.message(),
382 "ProxyHandler::handle_request"
383 );
384
385 match message {
386 MessageAndCx::Request(request, request_cx) => {
387 if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
389 &request.method,
390 &request.params,
391 ) {
392 let request = result?;
393 request_cx
394 .send_request(request.request)
395 .forward_to_request_cx(request_cx)?;
396 return Ok(Handled::Yes);
397 }
398
399 if let Some(result) =
401 InitializeRequest::parse_request(&request.method, &request.params)
402 {
403 let request = result?;
404 return self
405 .forward_initialize(request, request_cx.cast())
406 .await
407 .map(|()| Handled::Yes);
408 }
409
410 request_cx
412 .send_request_to_successor(request)
413 .forward_to_request_cx(request_cx)?;
414 Ok(Handled::Yes)
415 }
416
417 MessageAndCx::Notification(notification, cx) => {
418 if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
420 ¬ification.method,
421 ¬ification.params,
422 ) {
423 match result {
424 Ok(r) => {
425 cx.send_notification(r.notification)?;
426 return Ok(Handled::Yes);
427 }
428 Err(err) => return Err(err),
429 }
430 }
431
432 cx.send_notification_to_successor(notification)?;
434 Ok(Handled::Yes)
435 }
436 }
437 }
438}
439
440impl ProxyHandler {
441 async fn forward_initialize(
445 &mut self,
446 mut request: InitializeRequest,
447 request_cx: JsonRpcRequestCx<InitializeResponse>,
448 ) -> Result<(), agent_client_protocol::Error> {
449 tracing::debug!(
450 method = request_cx.method(),
451 params = ?request,
452 "ProxyHandler::forward_initialize"
453 );
454
455 if !request.has_meta_capability(Proxy) {
456 request_cx.respond_with_error(
457 acp::Error::invalid_params()
458 .with_data("this command requires the proxy capability"),
459 )?;
460 return Ok(());
461 }
462
463 request = request.remove_meta_capability(Proxy);
464 request_cx
465 .send_request_to_successor(request)
466 .await_when_result_received(async move |mut result| {
467 result = result.map(|r| r.add_meta_capability(Proxy));
468 request_cx.respond_with_result(result)
469 })
470 }
471}
472
473pub trait JsonRpcCxExt {
492 fn send_request_to_successor<Req: JsonRpcRequest>(
514 &self,
515 request: Req,
516 ) -> sacp::JsonRpcResponse<Req::Response>;
517
518 fn send_notification_to_successor<Req: JsonRpcNotification>(
530 &self,
531 notification: Req,
532 ) -> Result<(), acp::Error>;
533}
534
535impl JsonRpcCxExt for JsonRpcConnectionCx {
536 fn send_request_to_successor<Req: JsonRpcRequest>(
537 &self,
538 request: Req,
539 ) -> sacp::JsonRpcResponse<Req::Response> {
540 let wrapper = SuccessorRequest { request };
541 self.send_request(wrapper)
542 }
543
544 fn send_notification_to_successor<Req: JsonRpcNotification>(
545 &self,
546 notification: Req,
547 ) -> Result<(), acp::Error> {
548 let wrapper = SuccessorNotification { notification };
549 self.send_notification(wrapper)
550 }
551}