// sacp_proxy/to_from_successor.rs

1use sacp::handler::ChainedHandler;
2use sacp::schema::{InitializeRequest, InitializeResponse};
3use sacp::{
4    Handled, JrConnectionCx, JrHandlerChain, JrMessage, JrMessageHandler, JrNotification,
5    JrRequest, JrRequestCx, MessageAndCx, MetaCapabilityExt, Proxy, UntypedMessage,
6};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9
10use crate::mcp_server::McpServiceRegistry;
11
12// Requests and notifications send between us and the successor
13// ============================================================
14
15const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";
16
17/// A request being sent to the successor component.
18///
19/// Used in `_proxy/successor/send` when the proxy wants to forward a request downstream.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SuccessorRequest<Req: JrRequest> {
22    /// The message to be sent to the successor component.
23    #[serde(flatten)]
24    pub request: Req,
25}
26
27impl<Req: JrRequest> JrMessage for SuccessorRequest<Req> {
28    fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
29        sacp::UntypedMessage::new(
30            SUCCESSOR_REQUEST_METHOD,
31            SuccessorRequest {
32                request: self.request.into_untyped_message()?,
33            },
34        )
35    }
36
37    fn method(&self) -> &str {
38        SUCCESSOR_REQUEST_METHOD
39    }
40
41    fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
42        if method == SUCCESSOR_REQUEST_METHOD {
43            match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
44                Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
45                {
46                    Some(Ok(request)) => Some(Ok(SuccessorRequest { request })),
47                    Some(Err(err)) => Some(Err(err)),
48                    None => None,
49                },
50                Err(err) => Some(Err(err)),
51            }
52        } else {
53            None
54        }
55    }
56
57    fn parse_notification(
58        _method: &str,
59        _params: &impl Serialize,
60    ) -> Option<Result<Self, sacp::Error>> {
61        None // Request, not notification
62    }
63}
64
65impl<Req: JrRequest> JrRequest for SuccessorRequest<Req> {
66    type Response = Req::Response;
67}
68
69const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";
70
71/// A notification being sent to the successor component.
72///
73/// Used in `_proxy/successor/send` when the proxy wants to forward a notification downstream.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct SuccessorNotification<Req: JrNotification> {
76    /// The message to be sent to the successor component.
77    #[serde(flatten)]
78    pub notification: Req,
79}
80
81impl<Req: JrNotification> JrMessage for SuccessorNotification<Req> {
82    fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
83        sacp::UntypedMessage::new(
84            SUCCESSOR_NOTIFICATION_METHOD,
85            SuccessorNotification {
86                notification: self.notification.into_untyped_message()?,
87            },
88        )
89    }
90
91    fn method(&self) -> &str {
92        SUCCESSOR_NOTIFICATION_METHOD
93    }
94
95    fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
96        None // Notification, not request
97    }
98
99    fn parse_notification(
100        method: &str,
101        params: &impl Serialize,
102    ) -> Option<Result<Self, sacp::Error>> {
103        if method == SUCCESSOR_NOTIFICATION_METHOD {
104            match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
105                Ok(outer) => match Req::parse_notification(
106                    &outer.notification.method,
107                    &outer.notification.params,
108                ) {
109                    Some(Ok(notification)) => Some(Ok(SuccessorNotification { notification })),
110                    Some(Err(err)) => Some(Err(err)),
111                    None => None,
112                },
113                Err(err) => Some(Err(err)),
114            }
115        } else {
116            None
117        }
118    }
119}
120
121impl<Req: JrNotification> JrNotification for SuccessorNotification<Req> {}
122
123// Proxy methods
124// ============================================================
125
126/// Extension trait for JrConnection that adds proxy-specific functionality
127pub trait AcpProxyExt<H: JrMessageHandler> {
128    /// Adds a handler for requests received from the successor component.
129    ///
130    /// The provided handler will receive unwrapped ACP messages - the
131    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
132    /// Your handler processes normal ACP requests and notifications as if it were
133    /// a regular ACP component.
134    ///
135    /// # Example
136    ///
137    /// ```rust,ignore
138    /// # use sacp::proxy::JrConnectionExt;
139    /// # use sacp::{JrConnection, JrHandler};
140    /// # struct MyHandler;
141    /// # impl JrHandler for MyHandler {}
142    /// # async fn example() -> Result<(), sacp::Error> {
143    /// JrConnection::new(tokio::io::stdin(), tokio::io::stdout())
144    ///     .on_receive_from_successor(MyHandler)
145    ///     .serve()
146    ///     .await?;
147    /// # Ok(())
148    /// # }
149    /// ```
150    fn on_receive_request_from_successor<R, F>(
151        self,
152        op: F,
153    ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
154    where
155        R: JrRequest,
156        F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>;
157
158    /// Adds a handler for notifications received from the successor component.
159    ///
160    /// The provided handler will receive unwrapped ACP messages - the
161    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
162    /// Your handler processes normal ACP requests and notifications as if it were
163    /// a regular ACP component.
164    ///
165    /// # Example
166    ///
167    /// ```rust,ignore
168    /// # use sacp::proxy::JrConnectionExt;
169    /// # use sacp::{JrConnection, JrHandler};
170    /// # struct MyHandler;
171    /// # impl JrHandler for MyHandler {}
172    /// # async fn example() -> Result<(), sacp::Error> {
173    /// JrConnection::new()
174    ///     .on_receive_from_successor(MyHandler)
175    ///     .serve()
176    ///     .await?;
177    /// # Ok(())
178    /// # }
179    /// ```
180    fn on_receive_notification_from_successor<N, F>(
181        self,
182        op: F,
183    ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
184    where
185        N: JrNotification,
186        F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>;
187
188    /// Adds a handler for messages received from the successor component.
189    ///
190    /// The provided handler will receive unwrapped ACP messages - the
191    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
192    /// Your handler processes normal ACP requests and notifications as if it were
193    /// a regular ACP component.
194    fn on_receive_message_from_successor<R, N, F>(
195        self,
196        op: F,
197    ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
198    where
199        R: JrRequest,
200        N: JrNotification,
201        F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>;
202
203    /// Installs a proxy layer that proxies all requests/notifications to/from the successor.
204    /// This is typically the last component in the chain.
205    fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>>;
206
207    /// Provide MCP servers to downstream successors.
208    /// This layer will modify `session/new` requests to include those MCP servers
209    /// (unless you intercept them earlier).
210    fn provide_mcp(
211        self,
212        registry: impl AsRef<McpServiceRegistry>,
213    ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>>;
214}
215
216impl<H: JrMessageHandler> AcpProxyExt<H> for JrHandlerChain<H> {
217    fn on_receive_request_from_successor<R, F>(
218        self,
219        op: F,
220    ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
221    where
222        R: JrRequest,
223        F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
224    {
225        self.with_handler(RequestFromSuccessorHandler::new(op))
226    }
227
228    fn on_receive_notification_from_successor<N, F>(
229        self,
230        op: F,
231    ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
232    where
233        N: JrNotification,
234        F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
235    {
236        self.with_handler(NotificationFromSuccessorHandler::new(op))
237    }
238
239    fn on_receive_message_from_successor<R, N, F>(
240        self,
241        op: F,
242    ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
243    where
244        R: JrRequest,
245        N: JrNotification,
246        F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
247    {
248        self.with_handler(MessageFromSuccessorHandler::new(op))
249    }
250
251    fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>> {
252        self.with_handler(ProxyHandler {})
253    }
254
255    fn provide_mcp(
256        self,
257        registry: impl AsRef<McpServiceRegistry>,
258    ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>> {
259        self.with_handler(registry.as_ref().clone())
260    }
261}
262
263/// Handler to process a message of type `R` coming from the successor component.
264pub struct MessageFromSuccessorHandler<R, N, F>
265where
266    R: JrRequest,
267    N: JrNotification,
268    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
269{
270    handler: F,
271    phantom: PhantomData<fn(R, N)>,
272}
273
274impl<R, N, F> MessageFromSuccessorHandler<R, N, F>
275where
276    R: JrRequest,
277    N: JrNotification,
278    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
279{
280    /// Creates a new handler for requests from the successor
281    pub fn new(handler: F) -> Self {
282        Self {
283            handler,
284            phantom: PhantomData,
285        }
286    }
287}
288
289impl<R, N, F> JrMessageHandler for MessageFromSuccessorHandler<R, N, F>
290where
291    R: JrRequest,
292    N: JrNotification,
293    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
294{
295    async fn handle_message(
296        &mut self,
297        message: sacp::MessageAndCx,
298    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
299        match message {
300            MessageAndCx::Request(request, request_cx) => {
301                tracing::trace!(
302                    request_type = std::any::type_name::<R>(),
303                    message = ?request,
304                    "MessageFromSuccessorHandler::handle_message"
305                );
306                match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
307                    Some(Ok(request)) => {
308                        tracing::trace!(
309                            ?request,
310                            "RequestHandler::handle_request: parse completed"
311                        );
312                        (self.handler)(MessageAndCx::Request(request.request, request_cx.cast()))
313                            .await?;
314                        Ok(Handled::Yes)
315                    }
316                    Some(Err(err)) => {
317                        tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
318                        Err(err)
319                    }
320                    None => {
321                        tracing::trace!("RequestHandler::handle_request: parse failed");
322                        Ok(Handled::No(MessageAndCx::Request(request, request_cx)))
323                    }
324                }
325            }
326            MessageAndCx::Notification(notification, connection_cx) => {
327                tracing::trace!(
328                    ?notification,
329                    "NotificationFromSuccessorHandler::handle_message"
330                );
331                match <SuccessorNotification<N>>::parse_notification(
332                    &notification.method,
333                    &notification.params,
334                ) {
335                    Some(Ok(notification)) => {
336                        tracing::trace!(
337                            ?notification,
338                            "NotificationFromSuccessorHandler::handle_message: parse completed"
339                        );
340                        (self.handler)(MessageAndCx::Notification(
341                            notification.notification,
342                            connection_cx,
343                        ))
344                        .await?;
345                        Ok(Handled::Yes)
346                    }
347                    Some(Err(err)) => {
348                        tracing::trace!(
349                            ?err,
350                            "NotificationFromSuccessorHandler::handle_message: parse errored"
351                        );
352                        Err(err)
353                    }
354                    None => {
355                        tracing::trace!(
356                            "NotificationFromSuccessorHandler::handle_message: parse failed"
357                        );
358                        Ok(Handled::No(MessageAndCx::Notification(
359                            notification,
360                            connection_cx,
361                        )))
362                    }
363                }
364            }
365        }
366    }
367
368    fn describe_chain(&self) -> impl std::fmt::Debug {
369        std::any::type_name::<R>()
370    }
371}
372
373/// Handler to process a request of type `R` coming from the successor component.
374pub struct RequestFromSuccessorHandler<R, F>
375where
376    R: JrRequest,
377    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
378{
379    handler: F,
380    phantom: PhantomData<fn(R)>,
381}
382
383impl<R, F> RequestFromSuccessorHandler<R, F>
384where
385    R: JrRequest,
386    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
387{
388    /// Creates a new handler for requests from the successor
389    pub fn new(handler: F) -> Self {
390        Self {
391            handler,
392            phantom: PhantomData,
393        }
394    }
395}
396
397impl<R, F> JrMessageHandler for RequestFromSuccessorHandler<R, F>
398where
399    R: JrRequest,
400    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
401{
402    async fn handle_message(
403        &mut self,
404        message: sacp::MessageAndCx,
405    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
406        let MessageAndCx::Request(request, cx) = message else {
407            return Ok(Handled::No(message));
408        };
409
410        tracing::debug!(
411            request_type = std::any::type_name::<R>(),
412            message = ?request,
413            "RequestHandler::handle_request"
414        );
415        match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
416            Some(Ok(request)) => {
417                tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
418                (self.handler)(request.request, cx.cast()).await?;
419                Ok(Handled::Yes)
420            }
421            Some(Err(err)) => {
422                tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
423                Err(err)
424            }
425            None => {
426                tracing::trace!("RequestHandler::handle_request: parse failed");
427                Ok(Handled::No(MessageAndCx::Request(request, cx)))
428            }
429        }
430    }
431
432    fn describe_chain(&self) -> impl std::fmt::Debug {
433        std::any::type_name::<R>()
434    }
435}
436
437/// Handler to process a notification of type `N` coming from the successor component.
438pub struct NotificationFromSuccessorHandler<N, F>
439where
440    N: JrNotification,
441    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
442{
443    handler: F,
444    phantom: PhantomData<fn(N)>,
445}
446
447impl<N, F> NotificationFromSuccessorHandler<N, F>
448where
449    N: JrNotification,
450    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
451{
452    /// Creates a new handler for notifications from the successor
453    pub fn new(handler: F) -> Self {
454        Self {
455            handler,
456            phantom: PhantomData,
457        }
458    }
459}
460
461impl<N, F> JrMessageHandler for NotificationFromSuccessorHandler<N, F>
462where
463    N: JrNotification,
464    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
465{
466    async fn handle_message(
467        &mut self,
468        message: sacp::MessageAndCx,
469    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
470        let MessageAndCx::Notification(message, cx) = message else {
471            return Ok(Handled::No(message));
472        };
473
474        match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
475            Some(Ok(notification)) => {
476                tracing::trace!(
477                    ?notification,
478                    "NotificationFromSuccessorHandler::handle_request: parse completed"
479                );
480                (self.handler)(notification.notification, cx).await?;
481                Ok(Handled::Yes)
482            }
483            Some(Err(err)) => {
484                tracing::trace!(
485                    ?err,
486                    "NotificationFromSuccessorHandler::handle_request: parse errored"
487                );
488                Err(err)
489            }
490            None => {
491                tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
492                Ok(Handled::No(MessageAndCx::Notification(message, cx)))
493            }
494        }
495    }
496
497    fn describe_chain(&self) -> impl std::fmt::Debug {
498        format!("FromSuccessor<{}>", std::any::type_name::<N>())
499    }
500}
501
/// Handler for the "default proxy" behavior, installed by [`AcpProxyExt::proxy`].
///
/// It handles every message it sees:
///
/// - envelopes coming from the successor are unwrapped and sent to the predecessor;
/// - an `initialize` request is checked for the `Proxy` meta capability;
/// - everything else is forwarded to the successor.
#[derive(Debug, Default)]
pub struct ProxyHandler {}
504
505impl JrMessageHandler for ProxyHandler {
506    fn describe_chain(&self) -> impl std::fmt::Debug {
507        "proxy"
508    }
509
510    async fn handle_message(
511        &mut self,
512        message: sacp::MessageAndCx,
513    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
514        tracing::debug!(
515            message = ?message.message(),
516            "ProxyHandler::handle_request"
517        );
518
519        match message {
520            MessageAndCx::Request(request, request_cx) => {
521                // If we receive a request from the successor, send it to our predecessor.
522                if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
523                    &request.method,
524                    &request.params,
525                ) {
526                    let request = result?;
527                    request_cx
528                        .send_request(request.request)
529                        .forward_to_request_cx(request_cx)?;
530                    return Ok(Handled::Yes);
531                }
532
533                // If we receive "Initialize", require the proxy capability (and remove it)
534                if let Some(result) =
535                    InitializeRequest::parse_request(&request.method, &request.params)
536                {
537                    let request = result?;
538                    return self
539                        .forward_initialize(request, request_cx.cast())
540                        .await
541                        .map(|()| Handled::Yes);
542                }
543
544                // If we receive any other request, send it to our successor.
545                request_cx
546                    .send_request_to_successor(request)
547                    .forward_to_request_cx(request_cx)?;
548                Ok(Handled::Yes)
549            }
550
551            MessageAndCx::Notification(notification, cx) => {
552                // If we receive a request from the successor, send it to our predecessor.
553                if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
554                    &notification.method,
555                    &notification.params,
556                ) {
557                    match result {
558                        Ok(r) => {
559                            cx.send_notification(r.notification)?;
560                            return Ok(Handled::Yes);
561                        }
562                        Err(err) => return Err(err),
563                    }
564                }
565
566                // If we receive any other request, send it to our successor.
567                cx.send_notification_to_successor(notification)?;
568                Ok(Handled::Yes)
569            }
570        }
571    }
572}
573
574impl ProxyHandler {
575    /// Proxy initialization requires (1) a `Proxy` capability to be
576    /// provided by the conductor and (2) provides a `Proxy` capability
577    /// in our response.
578    async fn forward_initialize(
579        &mut self,
580        mut request: InitializeRequest,
581        request_cx: JrRequestCx<InitializeResponse>,
582    ) -> Result<(), sacp::Error> {
583        tracing::debug!(
584            method = request_cx.method(),
585            params = ?request,
586            "ProxyHandler::forward_initialize"
587        );
588
589        if !request.has_meta_capability(Proxy) {
590            request_cx.respond_with_error(
591                sacp::Error::invalid_params()
592                    .with_data("this command requires the proxy capability"),
593            )?;
594            return Ok(());
595        }
596
597        request = request.remove_meta_capability(Proxy);
598        request_cx
599            .send_request_to_successor(request)
600            .await_when_result_received(async move |mut result| {
601                result = result.map(|r| r.add_meta_capability(Proxy));
602                request_cx.respond_with_result(result)
603            })
604    }
605}
606
607/// Extension trait for [`JrConnectionCx`](sacp::JrConnectionCx) that adds methods for sending to successor.
608///
609/// This trait provides convenient methods for proxies to forward messages downstream
610/// to their successor component (next proxy or agent). Messages are automatically
611/// wrapped in the `_proxy/successor/send/*` protocol format.
612///
613/// # Example
614///
615/// ```rust,ignore
616/// // Example using ACP request types
617/// use sacp::proxy::JrCxExt;
618/// use agent_client_protocol_schema_schema::agent::PromptRequest;
619///
620/// async fn forward_prompt(cx: &JsonRpcCx, prompt: PromptRequest) {
621///     let response = cx.send_request_to_successor(prompt).recv().await?;
622///     // response is the typed response from the successor
623/// }
624/// ```
625pub trait JrCxExt {
626    /// Send a request to the successor component.
627    ///
628    /// The request is automatically wrapped in a `ToSuccessorRequest` and sent
629    /// using the `_proxy/successor/send/request` method. The orchestrator routes
630    /// it to the next component in the chain.
631    ///
632    /// # Returns
633    ///
634    /// Returns a [`JrResponse`](sacp::JrResponse) that can be awaited to get the successor's
635    /// response.
636    ///
637    /// # Example
638    ///
639    /// ```rust,ignore
640    /// use sacp::proxy::JrCxExt;
641    /// use agent_client_protocol_schema_schema::agent::PromptRequest;
642    ///
643    /// let prompt = PromptRequest { /* ... */ };
644    /// let response = cx.send_request_to_successor(prompt).recv().await?;
645    /// // response is the typed PromptResponse
646    /// ```
647    fn send_request_to_successor<Req: JrRequest>(
648        &self,
649        request: Req,
650    ) -> sacp::JrResponse<Req::Response>;
651
652    /// Send a notification to the successor component.
653    ///
654    /// The notification is automatically wrapped in a `ToSuccessorNotification`
655    /// and sent using the `_proxy/successor/send/notification` method. The
656    /// orchestrator routes it to the next component in the chain.
657    ///
658    /// Notifications are fire-and-forget - no response is expected.
659    ///
660    /// # Errors
661    ///
662    /// Returns an error if the notification fails to send.
663    fn send_notification_to_successor<Req: JrNotification>(
664        &self,
665        notification: Req,
666    ) -> Result<(), sacp::Error>;
667}
668
669impl JrCxExt for JrConnectionCx {
670    fn send_request_to_successor<Req: JrRequest>(
671        &self,
672        request: Req,
673    ) -> sacp::JrResponse<Req::Response> {
674        let wrapper = SuccessorRequest { request };
675        self.send_request(wrapper)
676    }
677
678    fn send_notification_to_successor<Req: JrNotification>(
679        &self,
680        notification: Req,
681    ) -> Result<(), sacp::Error> {
682        let wrapper = SuccessorNotification { notification };
683        self.send_notification(wrapper)
684    }
685}