1use sacp::handler::ChainedHandler;
2use sacp::schema::{InitializeRequest, InitializeResponse};
3use sacp::{
4 Handled, JrConnectionCx, JrHandlerChain, JrMessage, JrMessageHandler, JrNotification,
5 JrRequest, JrRequestCx, MessageAndCx, MetaCapabilityExt, Proxy, UntypedMessage,
6};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9
10use crate::mcp_server_registry::McpServiceRegistry;
11
/// Method name of the envelope that carries a request to or from the
/// successor (see [`SuccessorRequest`]).
const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";

/// Envelope wrapping a request `Req` that is exchanged with the successor.
///
/// When serialized, the envelope's params consist of the wrapped request's
/// own fields (flattened in via `#[serde(flatten)]`) plus the optional
/// metadata value, which is left out entirely when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuccessorRequest<Req: JrRequest> {
    /// The request being carried. Flattened, so its fields sit at the top
    /// level of the envelope instead of under a nested key.
    #[serde(flatten)]
    pub request: Req,

    /// Optional metadata attached to the envelope; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub meta: Option<serde_json::Value>,
}
30
31impl<Req: JrRequest> JrMessage for SuccessorRequest<Req> {
32 fn to_untyped_message(&self) -> Result<sacp::UntypedMessage, sacp::Error> {
33 sacp::UntypedMessage::new(
34 SUCCESSOR_REQUEST_METHOD,
35 SuccessorRequest {
36 request: self.request.to_untyped_message()?,
37 meta: self.meta.clone(),
38 },
39 )
40 }
41
42 fn method(&self) -> &str {
43 SUCCESSOR_REQUEST_METHOD
44 }
45
46 fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
47 if method == SUCCESSOR_REQUEST_METHOD {
48 match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
49 Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
50 {
51 Some(Ok(request)) => Some(Ok(SuccessorRequest {
52 request,
53 meta: outer.meta,
54 })),
55 Some(Err(err)) => Some(Err(err)),
56 None => None,
57 },
58 Err(err) => Some(Err(err)),
59 }
60 } else {
61 None
62 }
63 }
64
65 fn parse_notification(
66 _method: &str,
67 _params: &impl Serialize,
68 ) -> Option<Result<Self, sacp::Error>> {
69 None }
71}
72
73impl<Req: JrRequest> JrRequest for SuccessorRequest<Req> {
74 type Response = Req::Response;
75}
76
/// Method name of the envelope that carries a notification to or from the
/// successor (see [`SuccessorNotification`]).
const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";

/// Envelope wrapping a notification `Req` that is exchanged with the successor.
///
/// When serialized, the envelope's params consist of the wrapped
/// notification's own fields (flattened in via `#[serde(flatten)]`) plus the
/// optional metadata value, which is left out entirely when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuccessorNotification<Req: JrNotification> {
    /// The notification being carried. Flattened, so its fields sit at the
    /// top level of the envelope instead of under a nested key.
    #[serde(flatten)]
    pub notification: Req,

    /// Optional metadata attached to the envelope; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub meta: Option<serde_json::Value>,
}
92
93impl<Req: JrNotification> JrMessage for SuccessorNotification<Req> {
94 fn to_untyped_message(&self) -> Result<sacp::UntypedMessage, sacp::Error> {
95 sacp::UntypedMessage::new(
96 SUCCESSOR_NOTIFICATION_METHOD,
97 SuccessorNotification {
98 notification: self.notification.to_untyped_message()?,
99 meta: self.meta.clone(),
100 },
101 )
102 }
103
104 fn method(&self) -> &str {
105 SUCCESSOR_NOTIFICATION_METHOD
106 }
107
108 fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
109 None }
111
112 fn parse_notification(
113 method: &str,
114 params: &impl Serialize,
115 ) -> Option<Result<Self, sacp::Error>> {
116 if method == SUCCESSOR_NOTIFICATION_METHOD {
117 match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
118 Ok(outer) => match Req::parse_notification(
119 &outer.notification.method,
120 &outer.notification.params,
121 ) {
122 Some(Ok(notification)) => Some(Ok(SuccessorNotification {
123 notification,
124 meta: outer.meta,
125 })),
126 Some(Err(err)) => Some(Err(err)),
127 None => None,
128 },
129 Err(err) => Some(Err(err)),
130 }
131 } else {
132 None
133 }
134 }
135}
136
137impl<Req: JrNotification> JrNotification for SuccessorNotification<Req> {}
138
/// Builder extensions on a handler chain for writing proxy components.
///
/// Each method appends one more handler to the chain and returns the
/// extended chain type.
pub trait AcpProxyExt<H: JrMessageHandler> {
    /// Appends a handler for requests of type `R` that arrive wrapped in a
    /// successor request envelope (see [`RequestFromSuccessorHandler`]).
    ///
    /// `op` receives the unwrapped request and a request context typed for
    /// `R::Response`.
    fn on_receive_request_from_successor<R, F>(
        self,
        op: F,
    ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
    where
        R: JrRequest,
        F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>;

    /// Appends a handler for notifications of type `N` that arrive wrapped in
    /// a successor notification envelope (see
    /// [`NotificationFromSuccessorHandler`]).
    ///
    /// `op` receives the unwrapped notification and the connection context.
    fn on_receive_notification_from_successor<N, F>(
        self,
        op: F,
    ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
    where
        N: JrNotification,
        F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>;

    /// Appends a single handler that receives both successor requests of
    /// type `R` and successor notifications of type `N` as a typed
    /// [`MessageAndCx<R, N>`] (see [`MessageFromSuccessorHandler`]).
    fn on_receive_message_from_successor<R, N, F>(
        self,
        op: F,
    ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
    where
        R: JrRequest,
        N: JrNotification,
        F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>;

    /// Appends a [`ProxyHandler`], which forwards every message it sees.
    ///
    /// NOTE(review): `ProxyHandler` reports every message as handled, so it
    /// presumably has to be the last handler consulted. Confirm against how
    /// `ChainedHandler` orders its handlers.
    fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>>;

    /// Appends a clone of `registry` as a handler, serving the MCP services
    /// it contains.
    fn provide_mcp(
        self,
        registry: impl AsRef<McpServiceRegistry>,
    ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>>;
}
231
232impl<H: JrMessageHandler> AcpProxyExt<H> for JrHandlerChain<H> {
233 fn on_receive_request_from_successor<R, F>(
234 self,
235 op: F,
236 ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
237 where
238 R: JrRequest,
239 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
240 {
241 self.with_handler(RequestFromSuccessorHandler::new(op))
242 }
243
244 fn on_receive_notification_from_successor<N, F>(
245 self,
246 op: F,
247 ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
248 where
249 N: JrNotification,
250 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
251 {
252 self.with_handler(NotificationFromSuccessorHandler::new(op))
253 }
254
255 fn on_receive_message_from_successor<R, N, F>(
256 self,
257 op: F,
258 ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
259 where
260 R: JrRequest,
261 N: JrNotification,
262 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
263 {
264 self.with_handler(MessageFromSuccessorHandler::new(op))
265 }
266
267 fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>> {
268 self.with_handler(ProxyHandler {})
269 }
270
271 fn provide_mcp(
272 self,
273 registry: impl AsRef<McpServiceRegistry>,
274 ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>> {
275 self.with_handler(registry.as_ref().clone())
276 }
277}
278
/// Handler that accepts both requests (`R`) and notifications (`N`) arriving
/// in successor envelopes, and passes each to `handler` as a typed
/// [`MessageAndCx<R, N>`].
pub struct MessageFromSuccessorHandler<R, N, F>
where
    R: JrRequest,
    N: JrNotification,
    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
{
    /// Callback invoked for every successfully parsed message.
    handler: F,
    /// Marks the use of `R` and `N` without storing them. The `fn(R, N)` form
    /// keeps this struct's auto traits (e.g. `Send`/`Sync`) independent of
    /// `R` and `N`.
    phantom: PhantomData<fn(R, N)>,
}
289
290impl<R, N, F> MessageFromSuccessorHandler<R, N, F>
291where
292 R: JrRequest,
293 N: JrNotification,
294 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
295{
296 pub fn new(handler: F) -> Self {
298 Self {
299 handler,
300 phantom: PhantomData,
301 }
302 }
303}
304
305impl<R, N, F> JrMessageHandler for MessageFromSuccessorHandler<R, N, F>
306where
307 R: JrRequest,
308 N: JrNotification,
309 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
310{
311 async fn handle_message(
312 &mut self,
313 message: sacp::MessageAndCx,
314 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
315 match message {
316 MessageAndCx::Request(request, request_cx) => {
317 tracing::trace!(
318 request_type = std::any::type_name::<R>(),
319 message = ?request,
320 "MessageFromSuccessorHandler::handle_message"
321 );
322 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
323 Some(Ok(request)) => {
324 tracing::trace!(
325 ?request,
326 "RequestHandler::handle_request: parse completed"
327 );
328 (self.handler)(MessageAndCx::Request(request.request, request_cx.cast()))
329 .await?;
330 Ok(Handled::Yes)
331 }
332 Some(Err(err)) => {
333 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
334 Err(err)
335 }
336 None => {
337 tracing::trace!("RequestHandler::handle_request: parse failed");
338 Ok(Handled::No(MessageAndCx::Request(request, request_cx)))
339 }
340 }
341 }
342 MessageAndCx::Notification(notification, connection_cx) => {
343 tracing::trace!(
344 ?notification,
345 "NotificationFromSuccessorHandler::handle_message"
346 );
347 match <SuccessorNotification<N>>::parse_notification(
348 ¬ification.method,
349 ¬ification.params,
350 ) {
351 Some(Ok(notification)) => {
352 tracing::trace!(
353 ?notification,
354 "NotificationFromSuccessorHandler::handle_message: parse completed"
355 );
356 (self.handler)(MessageAndCx::Notification(
357 notification.notification,
358 connection_cx,
359 ))
360 .await?;
361 Ok(Handled::Yes)
362 }
363 Some(Err(err)) => {
364 tracing::trace!(
365 ?err,
366 "NotificationFromSuccessorHandler::handle_message: parse errored"
367 );
368 Err(err)
369 }
370 None => {
371 tracing::trace!(
372 "NotificationFromSuccessorHandler::handle_message: parse failed"
373 );
374 Ok(Handled::No(MessageAndCx::Notification(
375 notification,
376 connection_cx,
377 )))
378 }
379 }
380 }
381 }
382 }
383
384 fn describe_chain(&self) -> impl std::fmt::Debug {
385 std::any::type_name::<R>()
386 }
387}
388
/// Handler that accepts requests of type `R` arriving in successor request
/// envelopes and passes them, unwrapped, to `handler`.
pub struct RequestFromSuccessorHandler<R, F>
where
    R: JrRequest,
    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
{
    /// Callback invoked with each parsed request and its typed request context.
    handler: F,
    /// Marks the use of `R` without storing it. The `fn(R)` form keeps this
    /// struct's auto traits (e.g. `Send`/`Sync`) independent of `R`.
    phantom: PhantomData<fn(R)>,
}
398
399impl<R, F> RequestFromSuccessorHandler<R, F>
400where
401 R: JrRequest,
402 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
403{
404 pub fn new(handler: F) -> Self {
406 Self {
407 handler,
408 phantom: PhantomData,
409 }
410 }
411}
412
413impl<R, F> JrMessageHandler for RequestFromSuccessorHandler<R, F>
414where
415 R: JrRequest,
416 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
417{
418 async fn handle_message(
419 &mut self,
420 message: sacp::MessageAndCx,
421 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
422 let MessageAndCx::Request(request, cx) = message else {
423 return Ok(Handled::No(message));
424 };
425
426 tracing::debug!(
427 request_type = std::any::type_name::<R>(),
428 message = ?request,
429 "RequestHandler::handle_request"
430 );
431 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
432 Some(Ok(request)) => {
433 tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
434 (self.handler)(request.request, cx.cast()).await?;
435 Ok(Handled::Yes)
436 }
437 Some(Err(err)) => {
438 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
439 Err(err)
440 }
441 None => {
442 tracing::trace!("RequestHandler::handle_request: parse failed");
443 Ok(Handled::No(MessageAndCx::Request(request, cx)))
444 }
445 }
446 }
447
448 fn describe_chain(&self) -> impl std::fmt::Debug {
449 std::any::type_name::<R>()
450 }
451}
452
/// Handler that accepts notifications of type `N` arriving in successor
/// notification envelopes and passes them, unwrapped, to `handler`.
pub struct NotificationFromSuccessorHandler<N, F>
where
    N: JrNotification,
    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
{
    /// Callback invoked with each parsed notification and the connection context.
    handler: F,
    /// Marks the use of `N` without storing it. The `fn(N)` form keeps this
    /// struct's auto traits (e.g. `Send`/`Sync`) independent of `N`.
    phantom: PhantomData<fn(N)>,
}
462
463impl<N, F> NotificationFromSuccessorHandler<N, F>
464where
465 N: JrNotification,
466 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
467{
468 pub fn new(handler: F) -> Self {
470 Self {
471 handler,
472 phantom: PhantomData,
473 }
474 }
475}
476
477impl<N, F> JrMessageHandler for NotificationFromSuccessorHandler<N, F>
478where
479 N: JrNotification,
480 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
481{
482 async fn handle_message(
483 &mut self,
484 message: sacp::MessageAndCx,
485 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
486 let MessageAndCx::Notification(message, cx) = message else {
487 return Ok(Handled::No(message));
488 };
489
490 match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
491 Some(Ok(notification)) => {
492 tracing::trace!(
493 ?notification,
494 "NotificationFromSuccessorHandler::handle_request: parse completed"
495 );
496 (self.handler)(notification.notification, cx).await?;
497 Ok(Handled::Yes)
498 }
499 Some(Err(err)) => {
500 tracing::trace!(
501 ?err,
502 "NotificationFromSuccessorHandler::handle_request: parse errored"
503 );
504 Err(err)
505 }
506 None => {
507 tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
508 Ok(Handled::No(MessageAndCx::Notification(message, cx)))
509 }
510 }
511 }
512
513 fn describe_chain(&self) -> impl std::fmt::Debug {
514 format!("FromSuccessor<{}>", std::any::type_name::<N>())
515 }
516}
517
/// Catch-all handler that forwards every message it receives, either
/// unwrapping successor envelopes or wrapping messages for the successor.
/// See its `JrMessageHandler` implementation for the routing rules.
pub struct ProxyHandler {}
520
521impl JrMessageHandler for ProxyHandler {
522 fn describe_chain(&self) -> impl std::fmt::Debug {
523 "proxy"
524 }
525
526 async fn handle_message(
527 &mut self,
528 message: sacp::MessageAndCx,
529 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
530 tracing::debug!(
531 message = ?message.message(),
532 "ProxyHandler::handle_request"
533 );
534
535 match message {
536 MessageAndCx::Request(request, request_cx) => {
537 if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
539 &request.method,
540 &request.params,
541 ) {
542 let request = result?;
543 request_cx
544 .connection_cx()
545 .send_request(request.request)
546 .forward_to_request_cx(request_cx)?;
547 return Ok(Handled::Yes);
548 }
549
550 if let Some(result) =
552 InitializeRequest::parse_request(&request.method, &request.params)
553 {
554 let request = result?;
555 return self
556 .forward_initialize(request, request_cx.cast())
557 .await
558 .map(|()| Handled::Yes);
559 }
560
561 request_cx
563 .connection_cx()
564 .send_request_to_successor(request)
565 .forward_to_request_cx(request_cx)?;
566 Ok(Handled::Yes)
567 }
568
569 MessageAndCx::Notification(notification, cx) => {
570 if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
572 ¬ification.method,
573 ¬ification.params,
574 ) {
575 match result {
576 Ok(r) => {
577 cx.send_notification(r.notification)?;
578 return Ok(Handled::Yes);
579 }
580 Err(err) => return Err(err),
581 }
582 }
583
584 cx.send_notification_to_successor(notification)?;
586 Ok(Handled::Yes)
587 }
588 }
589 }
590}
591
592impl ProxyHandler {
593 async fn forward_initialize(
597 &mut self,
598 mut request: InitializeRequest,
599 request_cx: JrRequestCx<InitializeResponse>,
600 ) -> Result<(), sacp::Error> {
601 tracing::debug!(
602 method = request_cx.method(),
603 params = ?request,
604 "ProxyHandler::forward_initialize"
605 );
606
607 if !request.has_meta_capability(Proxy) {
608 request_cx.respond_with_error(
609 sacp::Error::invalid_params()
610 .with_data("this command requires the proxy capability"),
611 )?;
612 return Ok(());
613 }
614
615 request = request.remove_meta_capability(Proxy);
616 request_cx
617 .connection_cx()
618 .send_request_to_successor(request)
619 .await_when_result_received(async move |mut result| {
620 result = result.map(|r| r.add_meta_capability(Proxy));
621 request_cx.respond_with_result(result)
622 })
623 }
624}
625
/// Extension methods on a connection context for sending messages to the
/// successor.
pub trait JrCxExt {
    /// Sends `request` to the successor.
    ///
    /// The [`JrConnectionCx`] implementation wraps it in a
    /// [`SuccessorRequest`] envelope with no metadata. The returned response
    /// is typed as the wrapped request's own response.
    fn send_request_to_successor<Req: JrRequest>(
        &self,
        request: Req,
    ) -> sacp::JrResponse<Req::Response>;

    /// Sends `notification` to the successor.
    ///
    /// The [`JrConnectionCx`] implementation wraps it in a
    /// [`SuccessorNotification`] envelope with no metadata.
    ///
    /// # Errors
    ///
    /// Propagates any error from sending the wrapped notification.
    fn send_notification_to_successor<Req: JrNotification>(
        &self,
        notification: Req,
    ) -> Result<(), sacp::Error>;
}
687
688impl JrCxExt for JrConnectionCx {
689 fn send_request_to_successor<Req: JrRequest>(
690 &self,
691 request: Req,
692 ) -> sacp::JrResponse<Req::Response> {
693 let wrapper = SuccessorRequest {
694 request,
695 meta: None,
696 };
697 self.send_request(wrapper)
698 }
699
700 fn send_notification_to_successor<Req: JrNotification>(
701 &self,
702 notification: Req,
703 ) -> Result<(), sacp::Error> {
704 let wrapper = SuccessorNotification {
705 notification,
706 meta: None,
707 };
708 self.send_notification(wrapper)
709 }
710}