1use futures::{AsyncRead, AsyncWrite};
2use sacp::handler::ChainHandler;
3use sacp::schema::{InitializeRequest, InitializeResponse};
4use sacp::{
5 Handled, JrConnection, JrConnectionCx, JrHandler, JrMessage, JrNotification, JrRequest,
6 JrRequestCx, MessageAndCx, MetaCapabilityExt, Proxy, UntypedMessage,
7};
8use serde::{Deserialize, Serialize};
9use std::marker::PhantomData;
10
11use crate::mcp_server::McpServiceRegistry;
12
13const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct SuccessorRequest<Req: JrRequest> {
23 #[serde(flatten)]
25 pub request: Req,
26}
27
28impl<Req: JrRequest> JrMessage for SuccessorRequest<Req> {
29 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
30 sacp::UntypedMessage::new(
31 SUCCESSOR_REQUEST_METHOD,
32 SuccessorRequest {
33 request: self.request.into_untyped_message()?,
34 },
35 )
36 }
37
38 fn method(&self) -> &str {
39 SUCCESSOR_REQUEST_METHOD
40 }
41
42 fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
43 if method == SUCCESSOR_REQUEST_METHOD {
44 match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
45 Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
46 {
47 Some(Ok(request)) => Some(Ok(SuccessorRequest { request })),
48 Some(Err(err)) => Some(Err(err)),
49 None => None,
50 },
51 Err(err) => Some(Err(err)),
52 }
53 } else {
54 None
55 }
56 }
57
58 fn parse_notification(
59 _method: &str,
60 _params: &impl Serialize,
61 ) -> Option<Result<Self, sacp::Error>> {
62 None }
64}
65
66impl<Req: JrRequest> JrRequest for SuccessorRequest<Req> {
67 type Response = Req::Response;
68}
69
70const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct SuccessorNotification<Req: JrNotification> {
77 #[serde(flatten)]
79 pub notification: Req,
80}
81
82impl<Req: JrNotification> JrMessage for SuccessorNotification<Req> {
83 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
84 sacp::UntypedMessage::new(
85 SUCCESSOR_NOTIFICATION_METHOD,
86 SuccessorNotification {
87 notification: self.notification.into_untyped_message()?,
88 },
89 )
90 }
91
92 fn method(&self) -> &str {
93 SUCCESSOR_NOTIFICATION_METHOD
94 }
95
96 fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
97 None }
99
100 fn parse_notification(
101 method: &str,
102 params: &impl Serialize,
103 ) -> Option<Result<Self, sacp::Error>> {
104 if method == SUCCESSOR_NOTIFICATION_METHOD {
105 match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
106 Ok(outer) => match Req::parse_notification(
107 &outer.notification.method,
108 &outer.notification.params,
109 ) {
110 Some(Ok(notification)) => Some(Ok(SuccessorNotification { notification })),
111 Some(Err(err)) => Some(Err(err)),
112 None => None,
113 },
114 Err(err) => Some(Err(err)),
115 }
116 } else {
117 None
118 }
119 }
120}
121
122impl<Req: JrNotification> JrNotification for SuccessorNotification<Req> {}
123
124pub trait AcpProxyExt<OB: AsyncWrite, IB: AsyncRead, H: JrHandler> {
129 fn on_receive_request_from_successor<R, F>(
152 self,
153 op: F,
154 ) -> JrConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
155 where
156 R: JrRequest,
157 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>;
158
159 fn on_receive_notification_from_successor<N, F>(
182 self,
183 op: F,
184 ) -> JrConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
185 where
186 N: JrNotification,
187 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>;
188
189 fn on_receive_message_from_successor<R, N, F>(
196 self,
197 op: F,
198 ) -> JrConnection<OB, IB, ChainHandler<H, MessageFromSuccessorHandler<R, N, F>>>
199 where
200 R: JrRequest,
201 N: JrNotification,
202 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>;
203
204 fn proxy(self) -> JrConnection<OB, IB, ChainHandler<H, ProxyHandler>>;
207
208 fn provide_mcp(
212 self,
213 registry: impl AsRef<McpServiceRegistry>,
214 ) -> JrConnection<OB, IB, ChainHandler<H, McpServiceRegistry>>;
215}
216
217impl<OB, IB, H> AcpProxyExt<OB, IB, H> for JrConnection<OB, IB, H>
218where
219 OB: AsyncWrite,
220 IB: AsyncRead,
221 H: JrHandler,
222{
223 fn on_receive_request_from_successor<R, F>(
224 self,
225 op: F,
226 ) -> JrConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
227 where
228 R: JrRequest,
229 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
230 {
231 self.chain_handler(RequestFromSuccessorHandler::new(op))
232 }
233
234 fn on_receive_notification_from_successor<N, F>(
235 self,
236 op: F,
237 ) -> JrConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
238 where
239 N: JrNotification,
240 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
241 {
242 self.chain_handler(NotificationFromSuccessorHandler::new(op))
243 }
244
245 fn on_receive_message_from_successor<R, N, F>(
246 self,
247 op: F,
248 ) -> JrConnection<OB, IB, ChainHandler<H, MessageFromSuccessorHandler<R, N, F>>>
249 where
250 R: JrRequest,
251 N: JrNotification,
252 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
253 {
254 self.chain_handler(MessageFromSuccessorHandler::new(op))
255 }
256
257 fn proxy(self) -> JrConnection<OB, IB, ChainHandler<H, ProxyHandler>> {
258 self.chain_handler(ProxyHandler {})
259 }
260
261 fn provide_mcp(
262 self,
263 registry: impl AsRef<McpServiceRegistry>,
264 ) -> JrConnection<OB, IB, ChainHandler<H, McpServiceRegistry>> {
265 self.chain_handler(registry.as_ref().clone())
266 }
267}
268
269pub struct MessageFromSuccessorHandler<R, N, F>
271where
272 R: JrRequest,
273 N: JrNotification,
274 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
275{
276 handler: F,
277 phantom: PhantomData<fn(R, N)>,
278}
279
280impl<R, N, F> MessageFromSuccessorHandler<R, N, F>
281where
282 R: JrRequest,
283 N: JrNotification,
284 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
285{
286 pub fn new(handler: F) -> Self {
288 Self {
289 handler,
290 phantom: PhantomData,
291 }
292 }
293}
294
295impl<R, N, F> JrHandler for MessageFromSuccessorHandler<R, N, F>
296where
297 R: JrRequest,
298 N: JrNotification,
299 F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
300{
301 async fn handle_message(
302 &mut self,
303 message: sacp::MessageAndCx,
304 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
305 match message {
306 MessageAndCx::Request(request, request_cx) => {
307 tracing::trace!(
308 request_type = std::any::type_name::<R>(),
309 message = ?request,
310 "MessageFromSuccessorHandler::handle_message"
311 );
312 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
313 Some(Ok(request)) => {
314 tracing::trace!(
315 ?request,
316 "RequestHandler::handle_request: parse completed"
317 );
318 (self.handler)(MessageAndCx::Request(request.request, request_cx.cast()))
319 .await?;
320 Ok(Handled::Yes)
321 }
322 Some(Err(err)) => {
323 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
324 Err(err)
325 }
326 None => {
327 tracing::trace!("RequestHandler::handle_request: parse failed");
328 Ok(Handled::No(MessageAndCx::Request(request, request_cx)))
329 }
330 }
331 }
332 MessageAndCx::Notification(notification, connection_cx) => {
333 tracing::trace!(
334 ?notification,
335 "NotificationFromSuccessorHandler::handle_message"
336 );
337 match <SuccessorNotification<N>>::parse_notification(
338 ¬ification.method,
339 ¬ification.params,
340 ) {
341 Some(Ok(notification)) => {
342 tracing::trace!(
343 ?notification,
344 "NotificationFromSuccessorHandler::handle_message: parse completed"
345 );
346 (self.handler)(MessageAndCx::Notification(
347 notification.notification,
348 connection_cx,
349 ))
350 .await?;
351 Ok(Handled::Yes)
352 }
353 Some(Err(err)) => {
354 tracing::trace!(
355 ?err,
356 "NotificationFromSuccessorHandler::handle_message: parse errored"
357 );
358 Err(err)
359 }
360 None => {
361 tracing::trace!(
362 "NotificationFromSuccessorHandler::handle_message: parse failed"
363 );
364 Ok(Handled::No(MessageAndCx::Notification(
365 notification,
366 connection_cx,
367 )))
368 }
369 }
370 }
371 }
372 }
373
374 fn describe_chain(&self) -> impl std::fmt::Debug {
375 std::any::type_name::<R>()
376 }
377}
378
379pub struct RequestFromSuccessorHandler<R, F>
381where
382 R: JrRequest,
383 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
384{
385 handler: F,
386 phantom: PhantomData<fn(R)>,
387}
388
389impl<R, F> RequestFromSuccessorHandler<R, F>
390where
391 R: JrRequest,
392 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
393{
394 pub fn new(handler: F) -> Self {
396 Self {
397 handler,
398 phantom: PhantomData,
399 }
400 }
401}
402
403impl<R, F> JrHandler for RequestFromSuccessorHandler<R, F>
404where
405 R: JrRequest,
406 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
407{
408 async fn handle_message(
409 &mut self,
410 message: sacp::MessageAndCx,
411 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
412 let MessageAndCx::Request(request, cx) = message else {
413 return Ok(Handled::No(message));
414 };
415
416 tracing::debug!(
417 request_type = std::any::type_name::<R>(),
418 message = ?request,
419 "RequestHandler::handle_request"
420 );
421 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
422 Some(Ok(request)) => {
423 tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
424 (self.handler)(request.request, cx.cast()).await?;
425 Ok(Handled::Yes)
426 }
427 Some(Err(err)) => {
428 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
429 Err(err)
430 }
431 None => {
432 tracing::trace!("RequestHandler::handle_request: parse failed");
433 Ok(Handled::No(MessageAndCx::Request(request, cx)))
434 }
435 }
436 }
437
438 fn describe_chain(&self) -> impl std::fmt::Debug {
439 std::any::type_name::<R>()
440 }
441}
442
443pub struct NotificationFromSuccessorHandler<N, F>
445where
446 N: JrNotification,
447 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
448{
449 handler: F,
450 phantom: PhantomData<fn(N)>,
451}
452
453impl<N, F> NotificationFromSuccessorHandler<N, F>
454where
455 N: JrNotification,
456 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
457{
458 pub fn new(handler: F) -> Self {
460 Self {
461 handler,
462 phantom: PhantomData,
463 }
464 }
465}
466
467impl<N, F> JrHandler for NotificationFromSuccessorHandler<N, F>
468where
469 N: JrNotification,
470 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
471{
472 async fn handle_message(
473 &mut self,
474 message: sacp::MessageAndCx,
475 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
476 let MessageAndCx::Notification(message, cx) = message else {
477 return Ok(Handled::No(message));
478 };
479
480 match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
481 Some(Ok(notification)) => {
482 tracing::trace!(
483 ?notification,
484 "NotificationFromSuccessorHandler::handle_request: parse completed"
485 );
486 (self.handler)(notification.notification, cx).await?;
487 Ok(Handled::Yes)
488 }
489 Some(Err(err)) => {
490 tracing::trace!(
491 ?err,
492 "NotificationFromSuccessorHandler::handle_request: parse errored"
493 );
494 Err(err)
495 }
496 None => {
497 tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
498 Ok(Handled::No(MessageAndCx::Notification(message, cx)))
499 }
500 }
501 }
502
503 fn describe_chain(&self) -> impl std::fmt::Debug {
504 format!("FromSuccessor<{}>", std::any::type_name::<N>())
505 }
506}
507
508pub struct ProxyHandler {}
510
511impl JrHandler for ProxyHandler {
512 fn describe_chain(&self) -> impl std::fmt::Debug {
513 "proxy"
514 }
515
516 async fn handle_message(
517 &mut self,
518 message: sacp::MessageAndCx,
519 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
520 tracing::debug!(
521 message = ?message.message(),
522 "ProxyHandler::handle_request"
523 );
524
525 match message {
526 MessageAndCx::Request(request, request_cx) => {
527 if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
529 &request.method,
530 &request.params,
531 ) {
532 let request = result?;
533 request_cx
534 .send_request(request.request)
535 .forward_to_request_cx(request_cx)?;
536 return Ok(Handled::Yes);
537 }
538
539 if let Some(result) =
541 InitializeRequest::parse_request(&request.method, &request.params)
542 {
543 let request = result?;
544 return self
545 .forward_initialize(request, request_cx.cast())
546 .await
547 .map(|()| Handled::Yes);
548 }
549
550 request_cx
552 .send_request_to_successor(request)
553 .forward_to_request_cx(request_cx)?;
554 Ok(Handled::Yes)
555 }
556
557 MessageAndCx::Notification(notification, cx) => {
558 if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
560 ¬ification.method,
561 ¬ification.params,
562 ) {
563 match result {
564 Ok(r) => {
565 cx.send_notification(r.notification)?;
566 return Ok(Handled::Yes);
567 }
568 Err(err) => return Err(err),
569 }
570 }
571
572 cx.send_notification_to_successor(notification)?;
574 Ok(Handled::Yes)
575 }
576 }
577 }
578}
579
580impl ProxyHandler {
581 async fn forward_initialize(
585 &mut self,
586 mut request: InitializeRequest,
587 request_cx: JrRequestCx<InitializeResponse>,
588 ) -> Result<(), sacp::Error> {
589 tracing::debug!(
590 method = request_cx.method(),
591 params = ?request,
592 "ProxyHandler::forward_initialize"
593 );
594
595 if !request.has_meta_capability(Proxy) {
596 request_cx.respond_with_error(
597 sacp::Error::invalid_params()
598 .with_data("this command requires the proxy capability"),
599 )?;
600 return Ok(());
601 }
602
603 request = request.remove_meta_capability(Proxy);
604 request_cx
605 .send_request_to_successor(request)
606 .await_when_result_received(async move |mut result| {
607 result = result.map(|r| r.add_meta_capability(Proxy));
608 request_cx.respond_with_result(result)
609 })
610 }
611}
612
613pub trait JrCxExt {
632 fn send_request_to_successor<Req: JrRequest>(
654 &self,
655 request: Req,
656 ) -> sacp::JrResponse<Req::Response>;
657
658 fn send_notification_to_successor<Req: JrNotification>(
670 &self,
671 notification: Req,
672 ) -> Result<(), sacp::Error>;
673}
674
675impl JrCxExt for JrConnectionCx {
676 fn send_request_to_successor<Req: JrRequest>(
677 &self,
678 request: Req,
679 ) -> sacp::JrResponse<Req::Response> {
680 let wrapper = SuccessorRequest { request };
681 self.send_request(wrapper)
682 }
683
684 fn send_notification_to_successor<Req: JrNotification>(
685 &self,
686 notification: Req,
687 ) -> Result<(), sacp::Error> {
688 let wrapper = SuccessorNotification { notification };
689 self.send_notification(wrapper)
690 }
691}