1use sacp::handler::ChainedHandler;
2use sacp::schema::{InitializeRequest, InitializeResponse};
3use sacp::{
4 Handled, JrConnectionCx, JrHandlerChain, JrMessage, JrMessageHandler, JrNotification,
5 JrRequest, JrRequestCx, MessageAndCx, MetaCapabilityExt, Proxy, UntypedMessage,
6};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9
10use crate::mcp_server::McpServiceRegistry;
11
12const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SuccessorRequest<Req: JrRequest> {
22 #[serde(flatten)]
24 pub request: Req,
25}
26
27impl<Req: JrRequest> JrMessage for SuccessorRequest<Req> {
28 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
29 sacp::UntypedMessage::new(
30 SUCCESSOR_REQUEST_METHOD,
31 SuccessorRequest {
32 request: self.request.into_untyped_message()?,
33 },
34 )
35 }
36
37 fn method(&self) -> &str {
38 SUCCESSOR_REQUEST_METHOD
39 }
40
41 fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
42 if method == SUCCESSOR_REQUEST_METHOD {
43 match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
44 Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
45 {
46 Some(Ok(request)) => Some(Ok(SuccessorRequest { request })),
47 Some(Err(err)) => Some(Err(err)),
48 None => None,
49 },
50 Err(err) => Some(Err(err)),
51 }
52 } else {
53 None
54 }
55 }
56
57 fn parse_notification(
58 _method: &str,
59 _params: &impl Serialize,
60 ) -> Option<Result<Self, sacp::Error>> {
61 None }
63}
64
65impl<Req: JrRequest> JrRequest for SuccessorRequest<Req> {
66 type Response = Req::Response;
67}
68
69const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct SuccessorNotification<Req: JrNotification> {
76 #[serde(flatten)]
78 pub notification: Req,
79}
80
81impl<Req: JrNotification> JrMessage for SuccessorNotification<Req> {
82 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
83 sacp::UntypedMessage::new(
84 SUCCESSOR_NOTIFICATION_METHOD,
85 SuccessorNotification {
86 notification: self.notification.into_untyped_message()?,
87 },
88 )
89 }
90
91 fn method(&self) -> &str {
92 SUCCESSOR_NOTIFICATION_METHOD
93 }
94
95 fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
96 None }
98
99 fn parse_notification(
100 method: &str,
101 params: &impl Serialize,
102 ) -> Option<Result<Self, sacp::Error>> {
103 if method == SUCCESSOR_NOTIFICATION_METHOD {
104 match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
105 Ok(outer) => match Req::parse_notification(
106 &outer.notification.method,
107 &outer.notification.params,
108 ) {
109 Some(Ok(notification)) => Some(Ok(SuccessorNotification { notification })),
110 Some(Err(err)) => Some(Err(err)),
111 None => None,
112 },
113 Err(err) => Some(Err(err)),
114 }
115 } else {
116 None
117 }
118 }
119}
120
121impl<Req: JrNotification> JrNotification for SuccessorNotification<Req> {}
122
123pub trait AcpProxyExt<H: JrMessageHandler> {
128 fn on_receive_request_from_successor<R, F>(
151 self,
152 op: F,
153 ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
154 where
155 R: JrRequest,
156 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>;
157
158 fn on_receive_notification_from_successor<N, F>(
181 self,
182 op: F,
183 ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
184 where
185 N: JrNotification,
186 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>;
187
188 fn on_receive_message_from_successor<R, N, F>(
195 self,
196 op: F,
197 ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
198 where
199 R: JrRequest,
200 N: JrNotification,
201 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>;
202
203 fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>>;
206
207 fn provide_mcp(
211 self,
212 registry: impl AsRef<McpServiceRegistry>,
213 ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>>;
214}
215
216impl<H: JrMessageHandler> AcpProxyExt<H> for JrHandlerChain<H> {
217 fn on_receive_request_from_successor<R, F>(
218 self,
219 op: F,
220 ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
221 where
222 R: JrRequest,
223 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
224 {
225 self.with_handler(RequestFromSuccessorHandler::new(op))
226 }
227
228 fn on_receive_notification_from_successor<N, F>(
229 self,
230 op: F,
231 ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
232 where
233 N: JrNotification,
234 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
235 {
236 self.with_handler(NotificationFromSuccessorHandler::new(op))
237 }
238
239 fn on_receive_message_from_successor<R, N, F>(
240 self,
241 op: F,
242 ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
243 where
244 R: JrRequest,
245 N: JrNotification,
246 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
247 {
248 self.with_handler(MessageFromSuccessorHandler::new(op))
249 }
250
251 fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>> {
252 self.with_handler(ProxyHandler {})
253 }
254
255 fn provide_mcp(
256 self,
257 registry: impl AsRef<McpServiceRegistry>,
258 ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>> {
259 self.with_handler(registry.as_ref().clone())
260 }
261}
262
263pub struct MessageFromSuccessorHandler<R, N, F>
265where
266 R: JrRequest,
267 N: JrNotification,
268 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
269{
270 handler: F,
271 phantom: PhantomData<fn(R, N)>,
272}
273
274impl<R, N, F> MessageFromSuccessorHandler<R, N, F>
275where
276 R: JrRequest,
277 N: JrNotification,
278 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
279{
280 pub fn new(handler: F) -> Self {
282 Self {
283 handler,
284 phantom: PhantomData,
285 }
286 }
287}
288
289impl<R, N, F> JrMessageHandler for MessageFromSuccessorHandler<R, N, F>
290where
291 R: JrRequest,
292 N: JrNotification,
293 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
294{
295 async fn handle_message(
296 &mut self,
297 message: sacp::MessageAndCx,
298 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
299 match message {
300 MessageAndCx::Request(request, request_cx) => {
301 tracing::trace!(
302 request_type = std::any::type_name::<R>(),
303 message = ?request,
304 "MessageFromSuccessorHandler::handle_message"
305 );
306 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
307 Some(Ok(request)) => {
308 tracing::trace!(
309 ?request,
310 "RequestHandler::handle_request: parse completed"
311 );
312 (self.handler)(MessageAndCx::Request(request.request, request_cx.cast()))
313 .await?;
314 Ok(Handled::Yes)
315 }
316 Some(Err(err)) => {
317 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
318 Err(err)
319 }
320 None => {
321 tracing::trace!("RequestHandler::handle_request: parse failed");
322 Ok(Handled::No(MessageAndCx::Request(request, request_cx)))
323 }
324 }
325 }
326 MessageAndCx::Notification(notification, connection_cx) => {
327 tracing::trace!(
328 ?notification,
329 "NotificationFromSuccessorHandler::handle_message"
330 );
331 match <SuccessorNotification<N>>::parse_notification(
332 ¬ification.method,
333 ¬ification.params,
334 ) {
335 Some(Ok(notification)) => {
336 tracing::trace!(
337 ?notification,
338 "NotificationFromSuccessorHandler::handle_message: parse completed"
339 );
340 (self.handler)(MessageAndCx::Notification(
341 notification.notification,
342 connection_cx,
343 ))
344 .await?;
345 Ok(Handled::Yes)
346 }
347 Some(Err(err)) => {
348 tracing::trace!(
349 ?err,
350 "NotificationFromSuccessorHandler::handle_message: parse errored"
351 );
352 Err(err)
353 }
354 None => {
355 tracing::trace!(
356 "NotificationFromSuccessorHandler::handle_message: parse failed"
357 );
358 Ok(Handled::No(MessageAndCx::Notification(
359 notification,
360 connection_cx,
361 )))
362 }
363 }
364 }
365 }
366 }
367
368 fn describe_chain(&self) -> impl std::fmt::Debug {
369 std::any::type_name::<R>()
370 }
371}
372
373pub struct RequestFromSuccessorHandler<R, F>
375where
376 R: JrRequest,
377 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
378{
379 handler: F,
380 phantom: PhantomData<fn(R)>,
381}
382
383impl<R, F> RequestFromSuccessorHandler<R, F>
384where
385 R: JrRequest,
386 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
387{
388 pub fn new(handler: F) -> Self {
390 Self {
391 handler,
392 phantom: PhantomData,
393 }
394 }
395}
396
397impl<R, F> JrMessageHandler for RequestFromSuccessorHandler<R, F>
398where
399 R: JrRequest,
400 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
401{
402 async fn handle_message(
403 &mut self,
404 message: sacp::MessageAndCx,
405 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
406 let MessageAndCx::Request(request, cx) = message else {
407 return Ok(Handled::No(message));
408 };
409
410 tracing::debug!(
411 request_type = std::any::type_name::<R>(),
412 message = ?request,
413 "RequestHandler::handle_request"
414 );
415 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
416 Some(Ok(request)) => {
417 tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
418 (self.handler)(request.request, cx.cast()).await?;
419 Ok(Handled::Yes)
420 }
421 Some(Err(err)) => {
422 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
423 Err(err)
424 }
425 None => {
426 tracing::trace!("RequestHandler::handle_request: parse failed");
427 Ok(Handled::No(MessageAndCx::Request(request, cx)))
428 }
429 }
430 }
431
432 fn describe_chain(&self) -> impl std::fmt::Debug {
433 std::any::type_name::<R>()
434 }
435}
436
437pub struct NotificationFromSuccessorHandler<N, F>
439where
440 N: JrNotification,
441 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
442{
443 handler: F,
444 phantom: PhantomData<fn(N)>,
445}
446
447impl<N, F> NotificationFromSuccessorHandler<N, F>
448where
449 N: JrNotification,
450 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
451{
452 pub fn new(handler: F) -> Self {
454 Self {
455 handler,
456 phantom: PhantomData,
457 }
458 }
459}
460
461impl<N, F> JrMessageHandler for NotificationFromSuccessorHandler<N, F>
462where
463 N: JrNotification,
464 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
465{
466 async fn handle_message(
467 &mut self,
468 message: sacp::MessageAndCx,
469 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
470 let MessageAndCx::Notification(message, cx) = message else {
471 return Ok(Handled::No(message));
472 };
473
474 match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
475 Some(Ok(notification)) => {
476 tracing::trace!(
477 ?notification,
478 "NotificationFromSuccessorHandler::handle_request: parse completed"
479 );
480 (self.handler)(notification.notification, cx).await?;
481 Ok(Handled::Yes)
482 }
483 Some(Err(err)) => {
484 tracing::trace!(
485 ?err,
486 "NotificationFromSuccessorHandler::handle_request: parse errored"
487 );
488 Err(err)
489 }
490 None => {
491 tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
492 Ok(Handled::No(MessageAndCx::Notification(message, cx)))
493 }
494 }
495 }
496
497 fn describe_chain(&self) -> impl std::fmt::Debug {
498 format!("FromSuccessor<{}>", std::any::type_name::<N>())
499 }
500}
501
502pub struct ProxyHandler {}
504
505impl JrMessageHandler for ProxyHandler {
506 fn describe_chain(&self) -> impl std::fmt::Debug {
507 "proxy"
508 }
509
510 async fn handle_message(
511 &mut self,
512 message: sacp::MessageAndCx,
513 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
514 tracing::debug!(
515 message = ?message.message(),
516 "ProxyHandler::handle_request"
517 );
518
519 match message {
520 MessageAndCx::Request(request, request_cx) => {
521 if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
523 &request.method,
524 &request.params,
525 ) {
526 let request = result?;
527 request_cx
528 .send_request(request.request)
529 .forward_to_request_cx(request_cx)?;
530 return Ok(Handled::Yes);
531 }
532
533 if let Some(result) =
535 InitializeRequest::parse_request(&request.method, &request.params)
536 {
537 let request = result?;
538 return self
539 .forward_initialize(request, request_cx.cast())
540 .await
541 .map(|()| Handled::Yes);
542 }
543
544 request_cx
546 .send_request_to_successor(request)
547 .forward_to_request_cx(request_cx)?;
548 Ok(Handled::Yes)
549 }
550
551 MessageAndCx::Notification(notification, cx) => {
552 if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
554 ¬ification.method,
555 ¬ification.params,
556 ) {
557 match result {
558 Ok(r) => {
559 cx.send_notification(r.notification)?;
560 return Ok(Handled::Yes);
561 }
562 Err(err) => return Err(err),
563 }
564 }
565
566 cx.send_notification_to_successor(notification)?;
568 Ok(Handled::Yes)
569 }
570 }
571 }
572}
573
574impl ProxyHandler {
575 async fn forward_initialize(
579 &mut self,
580 mut request: InitializeRequest,
581 request_cx: JrRequestCx<InitializeResponse>,
582 ) -> Result<(), sacp::Error> {
583 tracing::debug!(
584 method = request_cx.method(),
585 params = ?request,
586 "ProxyHandler::forward_initialize"
587 );
588
589 if !request.has_meta_capability(Proxy) {
590 request_cx.respond_with_error(
591 sacp::Error::invalid_params()
592 .with_data("this command requires the proxy capability"),
593 )?;
594 return Ok(());
595 }
596
597 request = request.remove_meta_capability(Proxy);
598 request_cx
599 .send_request_to_successor(request)
600 .await_when_result_received(async move |mut result| {
601 result = result.map(|r| r.add_meta_capability(Proxy));
602 request_cx.respond_with_result(result)
603 })
604 }
605}
606
607pub trait JrCxExt {
626 fn send_request_to_successor<Req: JrRequest>(
648 &self,
649 request: Req,
650 ) -> sacp::JrResponse<Req::Response>;
651
652 fn send_notification_to_successor<Req: JrNotification>(
664 &self,
665 notification: Req,
666 ) -> Result<(), sacp::Error>;
667}
668
669impl JrCxExt for JrConnectionCx {
670 fn send_request_to_successor<Req: JrRequest>(
671 &self,
672 request: Req,
673 ) -> sacp::JrResponse<Req::Response> {
674 let wrapper = SuccessorRequest { request };
675 self.send_request(wrapper)
676 }
677
678 fn send_notification_to_successor<Req: JrNotification>(
679 &self,
680 notification: Req,
681 ) -> Result<(), sacp::Error> {
682 let wrapper = SuccessorNotification { notification };
683 self.send_notification(wrapper)
684 }
685}