1use futures::{AsyncRead, AsyncWrite};
2use sacp::{
3 ChainHandler, Handled, InitializeRequest, InitializeResponse, JrConnection, JrConnectionCx,
4 JrHandler, JrMessage, JrNotification, JrRequestCx, JsonRpcRequest, MessageAndCx,
5 MetaCapabilityExt, Proxy, UntypedMessage,
6};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9
10use crate::mcp_server::McpServiceRegistry;
11
12const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SuccessorRequest<Req: JsonRpcRequest> {
22 #[serde(flatten)]
24 pub request: Req,
25}
26
27impl<Req: JsonRpcRequest> JrMessage for SuccessorRequest<Req> {
28 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
29 sacp::UntypedMessage::new(
30 SUCCESSOR_REQUEST_METHOD,
31 SuccessorRequest {
32 request: self.request.into_untyped_message()?,
33 },
34 )
35 }
36
37 fn method(&self) -> &str {
38 SUCCESSOR_REQUEST_METHOD
39 }
40
41 fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
42 if method == SUCCESSOR_REQUEST_METHOD {
43 match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
44 Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
45 {
46 Some(Ok(request)) => Some(Ok(SuccessorRequest { request })),
47 Some(Err(err)) => Some(Err(err)),
48 None => None,
49 },
50 Err(err) => Some(Err(err)),
51 }
52 } else {
53 None
54 }
55 }
56
57 fn parse_notification(
58 _method: &str,
59 _params: &impl Serialize,
60 ) -> Option<Result<Self, sacp::Error>> {
61 None }
63}
64
65impl<Req: JsonRpcRequest> JsonRpcRequest for SuccessorRequest<Req> {
66 type Response = Req::Response;
67}
68
69const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct SuccessorNotification<Req: JrNotification> {
76 #[serde(flatten)]
78 pub notification: Req,
79}
80
81impl<Req: JrNotification> JrMessage for SuccessorNotification<Req> {
82 fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
83 sacp::UntypedMessage::new(
84 SUCCESSOR_NOTIFICATION_METHOD,
85 SuccessorNotification {
86 notification: self.notification.into_untyped_message()?,
87 },
88 )
89 }
90
91 fn method(&self) -> &str {
92 SUCCESSOR_NOTIFICATION_METHOD
93 }
94
95 fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
96 None }
98
99 fn parse_notification(
100 method: &str,
101 params: &impl Serialize,
102 ) -> Option<Result<Self, sacp::Error>> {
103 if method == SUCCESSOR_NOTIFICATION_METHOD {
104 match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
105 Ok(outer) => match Req::parse_notification(
106 &outer.notification.method,
107 &outer.notification.params,
108 ) {
109 Some(Ok(notification)) => Some(Ok(SuccessorNotification { notification })),
110 Some(Err(err)) => Some(Err(err)),
111 None => None,
112 },
113 Err(err) => Some(Err(err)),
114 }
115 } else {
116 None
117 }
118 }
119}
120
121impl<Req: JrNotification> JrNotification for SuccessorNotification<Req> {}
122
123pub trait AcpProxyExt<OB: AsyncWrite, IB: AsyncRead, H: JrHandler> {
127 fn on_receive_request_from_successor<R, F>(
150 self,
151 op: F,
152 ) -> JrConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
153 where
154 R: JsonRpcRequest,
155 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>;
156
157 fn on_receive_notification_from_successor<N, F>(
180 self,
181 op: F,
182 ) -> JrConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
183 where
184 N: JrNotification,
185 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>;
186
187 fn proxy(self) -> JrConnection<OB, IB, ChainHandler<H, ProxyHandler>>;
190
191 fn provide_mcp(
195 self,
196 registry: impl AsRef<McpServiceRegistry>,
197 ) -> JrConnection<OB, IB, ChainHandler<H, McpServiceRegistry>>;
198}
199
200impl<OB, IB, H> AcpProxyExt<OB, IB, H> for JrConnection<OB, IB, H>
201where
202 OB: AsyncWrite,
203 IB: AsyncRead,
204 H: JrHandler,
205{
206 fn on_receive_request_from_successor<R, F>(
207 self,
208 op: F,
209 ) -> JrConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
210 where
211 R: JsonRpcRequest,
212 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
213 {
214 self.chain_handler(RequestFromSuccessorHandler::new(op))
215 }
216
217 fn on_receive_notification_from_successor<N, F>(
218 self,
219 op: F,
220 ) -> JrConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
221 where
222 N: JrNotification,
223 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
224 {
225 self.chain_handler(NotificationFromSuccessorHandler::new(op))
226 }
227
228 fn proxy(self) -> JrConnection<OB, IB, ChainHandler<H, ProxyHandler>> {
229 self.chain_handler(ProxyHandler {})
230 }
231
232 fn provide_mcp(
233 self,
234 registry: impl AsRef<McpServiceRegistry>,
235 ) -> JrConnection<OB, IB, ChainHandler<H, McpServiceRegistry>> {
236 self.chain_handler(registry.as_ref().clone())
237 }
238}
239
240pub struct RequestFromSuccessorHandler<R, F>
242where
243 R: JsonRpcRequest,
244 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
245{
246 handler: F,
247 phantom: PhantomData<fn(R)>,
248}
249
250impl<R, F> RequestFromSuccessorHandler<R, F>
251where
252 R: JsonRpcRequest,
253 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
254{
255 pub fn new(handler: F) -> Self {
256 Self {
257 handler,
258 phantom: PhantomData,
259 }
260 }
261}
262
263impl<R, F> JrHandler for RequestFromSuccessorHandler<R, F>
264where
265 R: JsonRpcRequest,
266 F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
267{
268 async fn handle_message(
269 &mut self,
270 message: sacp::MessageAndCx,
271 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
272 let MessageAndCx::Request(request, cx) = message else {
273 return Ok(Handled::No(message));
274 };
275
276 tracing::debug!(
277 request_type = std::any::type_name::<R>(),
278 message = ?request,
279 "RequestHandler::handle_request"
280 );
281 match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
282 Some(Ok(request)) => {
283 tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
284 (self.handler)(request.request, cx.cast()).await?;
285 Ok(Handled::Yes)
286 }
287 Some(Err(err)) => {
288 tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
289 Err(err)
290 }
291 None => {
292 tracing::trace!("RequestHandler::handle_request: parse failed");
293 Ok(Handled::No(MessageAndCx::Request(request, cx)))
294 }
295 }
296 }
297
298 fn describe_chain(&self) -> impl std::fmt::Debug {
299 std::any::type_name::<R>()
300 }
301}
302
303pub struct NotificationFromSuccessorHandler<N, F>
305where
306 N: JrNotification,
307 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
308{
309 handler: F,
310 phantom: PhantomData<fn(N)>,
311}
312
313impl<N, F> NotificationFromSuccessorHandler<N, F>
314where
315 N: JrNotification,
316 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
317{
318 pub fn new(handler: F) -> Self {
319 Self {
320 handler,
321 phantom: PhantomData,
322 }
323 }
324}
325
326impl<N, F> JrHandler for NotificationFromSuccessorHandler<N, F>
327where
328 N: JrNotification,
329 F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
330{
331 async fn handle_message(
332 &mut self,
333 message: sacp::MessageAndCx,
334 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
335 let MessageAndCx::Notification(message, cx) = message else {
336 return Ok(Handled::No(message));
337 };
338
339 match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
340 Some(Ok(notification)) => {
341 tracing::trace!(
342 ?notification,
343 "NotificationFromSuccessorHandler::handle_request: parse completed"
344 );
345 (self.handler)(notification.notification, cx).await?;
346 Ok(Handled::Yes)
347 }
348 Some(Err(err)) => {
349 tracing::trace!(
350 ?err,
351 "NotificationFromSuccessorHandler::handle_request: parse errored"
352 );
353 Err(err)
354 }
355 None => {
356 tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
357 Ok(Handled::No(MessageAndCx::Notification(message, cx)))
358 }
359 }
360 }
361
362 fn describe_chain(&self) -> impl std::fmt::Debug {
363 format!("FromSuccessor<{}>", std::any::type_name::<N>())
364 }
365}
366
367pub struct ProxyHandler {}
369
370impl JrHandler for ProxyHandler {
371 fn describe_chain(&self) -> impl std::fmt::Debug {
372 "proxy"
373 }
374
375 async fn handle_message(
376 &mut self,
377 message: sacp::MessageAndCx,
378 ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
379 tracing::debug!(
380 message = ?message.message(),
381 "ProxyHandler::handle_request"
382 );
383
384 match message {
385 MessageAndCx::Request(request, request_cx) => {
386 if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
388 &request.method,
389 &request.params,
390 ) {
391 let request = result?;
392 request_cx
393 .send_request(request.request)
394 .forward_to_request_cx(request_cx)?;
395 return Ok(Handled::Yes);
396 }
397
398 if let Some(result) =
400 InitializeRequest::parse_request(&request.method, &request.params)
401 {
402 let request = result?;
403 return self
404 .forward_initialize(request, request_cx.cast())
405 .await
406 .map(|()| Handled::Yes);
407 }
408
409 request_cx
411 .send_request_to_successor(request)
412 .forward_to_request_cx(request_cx)?;
413 Ok(Handled::Yes)
414 }
415
416 MessageAndCx::Notification(notification, cx) => {
417 if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
419 ¬ification.method,
420 ¬ification.params,
421 ) {
422 match result {
423 Ok(r) => {
424 cx.send_notification(r.notification)?;
425 return Ok(Handled::Yes);
426 }
427 Err(err) => return Err(err),
428 }
429 }
430
431 cx.send_notification_to_successor(notification)?;
433 Ok(Handled::Yes)
434 }
435 }
436 }
437}
438
439impl ProxyHandler {
440 async fn forward_initialize(
444 &mut self,
445 mut request: InitializeRequest,
446 request_cx: JrRequestCx<InitializeResponse>,
447 ) -> Result<(), sacp::Error> {
448 tracing::debug!(
449 method = request_cx.method(),
450 params = ?request,
451 "ProxyHandler::forward_initialize"
452 );
453
454 if !request.has_meta_capability(Proxy) {
455 request_cx.respond_with_error(
456 sacp::Error::invalid_params()
457 .with_data("this command requires the proxy capability"),
458 )?;
459 return Ok(());
460 }
461
462 request = request.remove_meta_capability(Proxy);
463 request_cx
464 .send_request_to_successor(request)
465 .await_when_result_received(async move |mut result| {
466 result = result.map(|r| r.add_meta_capability(Proxy));
467 request_cx.respond_with_result(result)
468 })
469 }
470}
471
472pub trait JrCxExt {
491 fn send_request_to_successor<Req: JsonRpcRequest>(
513 &self,
514 request: Req,
515 ) -> sacp::JrResponse<Req::Response>;
516
517 fn send_notification_to_successor<Req: JrNotification>(
529 &self,
530 notification: Req,
531 ) -> Result<(), sacp::Error>;
532}
533
534impl JrCxExt for JrConnectionCx {
535 fn send_request_to_successor<Req: JsonRpcRequest>(
536 &self,
537 request: Req,
538 ) -> sacp::JrResponse<Req::Response> {
539 let wrapper = SuccessorRequest { request };
540 self.send_request(wrapper)
541 }
542
543 fn send_notification_to_successor<Req: JrNotification>(
544 &self,
545 notification: Req,
546 ) -> Result<(), sacp::Error> {
547 let wrapper = SuccessorNotification { notification };
548 self.send_notification(wrapper)
549 }
550}