sacp_proxy/
to_from_successor.rs

1use sacp::handler::ChainedHandler;
2use sacp::schema::{InitializeRequest, InitializeResponse};
3use sacp::{
4    Handled, JrConnectionCx, JrHandlerChain, JrMessage, JrMessageHandler, JrNotification,
5    JrRequest, JrRequestCx, MessageAndCx, MetaCapabilityExt, Proxy, UntypedMessage,
6};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9
10use crate::mcp_server_registry::McpServiceRegistry;
11
// Requests and notifications sent between us and the successor
13// ============================================================
14
15const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";
16
17/// A request being sent to the successor component.
18///
19/// Used in `_proxy/successor/send` when the proxy wants to forward a request downstream.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SuccessorRequest<Req: JrRequest> {
22    /// The message to be sent to the successor component.
23    #[serde(flatten)]
24    pub request: Req,
25
26    /// Optional metadata
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub meta: Option<serde_json::Value>,
29}
30
31impl<Req: JrRequest> JrMessage for SuccessorRequest<Req> {
32    fn to_untyped_message(&self) -> Result<sacp::UntypedMessage, sacp::Error> {
33        sacp::UntypedMessage::new(
34            SUCCESSOR_REQUEST_METHOD,
35            SuccessorRequest {
36                request: self.request.to_untyped_message()?,
37                meta: self.meta.clone(),
38            },
39        )
40    }
41
42    fn method(&self) -> &str {
43        SUCCESSOR_REQUEST_METHOD
44    }
45
46    fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
47        if method == SUCCESSOR_REQUEST_METHOD {
48            match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
49                Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
50                {
51                    Some(Ok(request)) => Some(Ok(SuccessorRequest {
52                        request,
53                        meta: outer.meta,
54                    })),
55                    Some(Err(err)) => Some(Err(err)),
56                    None => None,
57                },
58                Err(err) => Some(Err(err)),
59            }
60        } else {
61            None
62        }
63    }
64
65    fn parse_notification(
66        _method: &str,
67        _params: &impl Serialize,
68    ) -> Option<Result<Self, sacp::Error>> {
69        None // Request, not notification
70    }
71}
72
73impl<Req: JrRequest> JrRequest for SuccessorRequest<Req> {
74    type Response = Req::Response;
75}
76
77const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";
78
79/// A notification being sent to the successor component.
80///
81/// Used in `_proxy/successor/send` when the proxy wants to forward a notification downstream.
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct SuccessorNotification<Req: JrNotification> {
84    /// The message to be sent to the successor component.
85    #[serde(flatten)]
86    pub notification: Req,
87
88    /// Optional metadata
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub meta: Option<serde_json::Value>,
91}
92
93impl<Req: JrNotification> JrMessage for SuccessorNotification<Req> {
94    fn to_untyped_message(&self) -> Result<sacp::UntypedMessage, sacp::Error> {
95        sacp::UntypedMessage::new(
96            SUCCESSOR_NOTIFICATION_METHOD,
97            SuccessorNotification {
98                notification: self.notification.to_untyped_message()?,
99                meta: self.meta.clone(),
100            },
101        )
102    }
103
104    fn method(&self) -> &str {
105        SUCCESSOR_NOTIFICATION_METHOD
106    }
107
108    fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
109        None // Notification, not request
110    }
111
112    fn parse_notification(
113        method: &str,
114        params: &impl Serialize,
115    ) -> Option<Result<Self, sacp::Error>> {
116        if method == SUCCESSOR_NOTIFICATION_METHOD {
117            match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
118                Ok(outer) => match Req::parse_notification(
119                    &outer.notification.method,
120                    &outer.notification.params,
121                ) {
122                    Some(Ok(notification)) => Some(Ok(SuccessorNotification {
123                        notification,
124                        meta: outer.meta,
125                    })),
126                    Some(Err(err)) => Some(Err(err)),
127                    None => None,
128                },
129                Err(err) => Some(Err(err)),
130            }
131        } else {
132            None
133        }
134    }
135}
136
137impl<Req: JrNotification> JrNotification for SuccessorNotification<Req> {}
138
139// Proxy methods
140// ============================================================
141
142/// Extension trait for JrConnection that adds proxy-specific functionality
143pub trait AcpProxyExt<H: JrMessageHandler> {
144    /// Adds a handler for requests received from the successor component.
145    ///
146    /// The provided handler will receive unwrapped ACP messages - the
147    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
148    /// Your handler processes normal ACP requests and notifications as if it were
149    /// a regular ACP component.
150    ///
151    /// # Example
152    ///
153    /// ```rust,ignore
154    /// # use sacp::proxy::JrConnectionExt;
155    /// # use sacp::{JrConnection, JrHandler};
156    /// # struct MyHandler;
157    /// # impl JrHandler for MyHandler {}
158    /// # async fn example() -> Result<(), sacp::Error> {
159    /// JrConnection::new(tokio::io::stdin(), tokio::io::stdout())
160    ///     .on_receive_from_successor(MyHandler)
161    ///     .serve()
162    ///     .await?;
163    /// # Ok(())
164    /// # }
165    /// ```
166    fn on_receive_request_from_successor<R, F>(
167        self,
168        op: F,
169    ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
170    where
171        R: JrRequest,
172        F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>;
173
174    /// Adds a handler for notifications received from the successor component.
175    ///
176    /// The provided handler will receive unwrapped ACP messages - the
177    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
178    /// Your handler processes normal ACP requests and notifications as if it were
179    /// a regular ACP component.
180    ///
181    /// # Example
182    ///
183    /// ```rust,ignore
184    /// # use sacp::proxy::JrConnectionExt;
185    /// # use sacp::{JrConnection, JrHandler};
186    /// # struct MyHandler;
187    /// # impl JrHandler for MyHandler {}
188    /// # async fn example() -> Result<(), sacp::Error> {
189    /// JrConnection::new()
190    ///     .on_receive_from_successor(MyHandler)
191    ///     .serve()
192    ///     .await?;
193    /// # Ok(())
194    /// # }
195    /// ```
196    fn on_receive_notification_from_successor<N, F>(
197        self,
198        op: F,
199    ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
200    where
201        N: JrNotification,
202        F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>;
203
204    /// Adds a handler for messages received from the successor component.
205    ///
206    /// The provided handler will receive unwrapped ACP messages - the
207    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
208    /// Your handler processes normal ACP requests and notifications as if it were
209    /// a regular ACP component.
210    fn on_receive_message_from_successor<R, N, F>(
211        self,
212        op: F,
213    ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
214    where
215        R: JrRequest,
216        N: JrNotification,
217        F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>;
218
219    /// Installs a proxy layer that proxies all requests/notifications to/from the successor.
220    /// This is typically the last component in the chain.
221    fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>>;
222
223    /// Provide MCP servers to downstream successors.
224    /// This layer will modify `session/new` requests to include those MCP servers
225    /// (unless you intercept them earlier).
226    fn provide_mcp(
227        self,
228        registry: impl AsRef<McpServiceRegistry>,
229    ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>>;
230}
231
232impl<H: JrMessageHandler> AcpProxyExt<H> for JrHandlerChain<H> {
233    fn on_receive_request_from_successor<R, F>(
234        self,
235        op: F,
236    ) -> JrHandlerChain<ChainedHandler<H, RequestFromSuccessorHandler<R, F>>>
237    where
238        R: JrRequest,
239        F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
240    {
241        self.with_handler(RequestFromSuccessorHandler::new(op))
242    }
243
244    fn on_receive_notification_from_successor<N, F>(
245        self,
246        op: F,
247    ) -> JrHandlerChain<ChainedHandler<H, NotificationFromSuccessorHandler<N, F>>>
248    where
249        N: JrNotification,
250        F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
251    {
252        self.with_handler(NotificationFromSuccessorHandler::new(op))
253    }
254
255    fn on_receive_message_from_successor<R, N, F>(
256        self,
257        op: F,
258    ) -> JrHandlerChain<ChainedHandler<H, MessageFromSuccessorHandler<R, N, F>>>
259    where
260        R: JrRequest,
261        N: JrNotification,
262        F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
263    {
264        self.with_handler(MessageFromSuccessorHandler::new(op))
265    }
266
267    fn proxy(self) -> JrHandlerChain<ChainedHandler<H, ProxyHandler>> {
268        self.with_handler(ProxyHandler {})
269    }
270
271    fn provide_mcp(
272        self,
273        registry: impl AsRef<McpServiceRegistry>,
274    ) -> JrHandlerChain<ChainedHandler<H, McpServiceRegistry>> {
275        self.with_handler(registry.as_ref().clone())
276    }
277}
278
279/// Handler to process a message of type `R` coming from the successor component.
280pub struct MessageFromSuccessorHandler<R, N, F>
281where
282    R: JrRequest,
283    N: JrNotification,
284    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
285{
286    handler: F,
287    phantom: PhantomData<fn(R, N)>,
288}
289
290impl<R, N, F> MessageFromSuccessorHandler<R, N, F>
291where
292    R: JrRequest,
293    N: JrNotification,
294    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
295{
296    /// Creates a new handler for requests from the successor
297    pub fn new(handler: F) -> Self {
298        Self {
299            handler,
300            phantom: PhantomData,
301        }
302    }
303}
304
305impl<R, N, F> JrMessageHandler for MessageFromSuccessorHandler<R, N, F>
306where
307    R: JrRequest,
308    N: JrNotification,
309    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
310{
311    async fn handle_message(
312        &mut self,
313        message: sacp::MessageAndCx,
314    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
315        match message {
316            MessageAndCx::Request(request, request_cx) => {
317                tracing::trace!(
318                    request_type = std::any::type_name::<R>(),
319                    message = ?request,
320                    "MessageFromSuccessorHandler::handle_message"
321                );
322                match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
323                    Some(Ok(request)) => {
324                        tracing::trace!(
325                            ?request,
326                            "RequestHandler::handle_request: parse completed"
327                        );
328                        (self.handler)(MessageAndCx::Request(request.request, request_cx.cast()))
329                            .await?;
330                        Ok(Handled::Yes)
331                    }
332                    Some(Err(err)) => {
333                        tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
334                        Err(err)
335                    }
336                    None => {
337                        tracing::trace!("RequestHandler::handle_request: parse failed");
338                        Ok(Handled::No(MessageAndCx::Request(request, request_cx)))
339                    }
340                }
341            }
342            MessageAndCx::Notification(notification, connection_cx) => {
343                tracing::trace!(
344                    ?notification,
345                    "NotificationFromSuccessorHandler::handle_message"
346                );
347                match <SuccessorNotification<N>>::parse_notification(
348                    &notification.method,
349                    &notification.params,
350                ) {
351                    Some(Ok(notification)) => {
352                        tracing::trace!(
353                            ?notification,
354                            "NotificationFromSuccessorHandler::handle_message: parse completed"
355                        );
356                        (self.handler)(MessageAndCx::Notification(
357                            notification.notification,
358                            connection_cx,
359                        ))
360                        .await?;
361                        Ok(Handled::Yes)
362                    }
363                    Some(Err(err)) => {
364                        tracing::trace!(
365                            ?err,
366                            "NotificationFromSuccessorHandler::handle_message: parse errored"
367                        );
368                        Err(err)
369                    }
370                    None => {
371                        tracing::trace!(
372                            "NotificationFromSuccessorHandler::handle_message: parse failed"
373                        );
374                        Ok(Handled::No(MessageAndCx::Notification(
375                            notification,
376                            connection_cx,
377                        )))
378                    }
379                }
380            }
381        }
382    }
383
384    fn describe_chain(&self) -> impl std::fmt::Debug {
385        std::any::type_name::<R>()
386    }
387}
388
389/// Handler to process a request of type `R` coming from the successor component.
390pub struct RequestFromSuccessorHandler<R, F>
391where
392    R: JrRequest,
393    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
394{
395    handler: F,
396    phantom: PhantomData<fn(R)>,
397}
398
399impl<R, F> RequestFromSuccessorHandler<R, F>
400where
401    R: JrRequest,
402    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
403{
404    /// Creates a new handler for requests from the successor
405    pub fn new(handler: F) -> Self {
406        Self {
407            handler,
408            phantom: PhantomData,
409        }
410    }
411}
412
413impl<R, F> JrMessageHandler for RequestFromSuccessorHandler<R, F>
414where
415    R: JrRequest,
416    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
417{
418    async fn handle_message(
419        &mut self,
420        message: sacp::MessageAndCx,
421    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
422        let MessageAndCx::Request(request, cx) = message else {
423            return Ok(Handled::No(message));
424        };
425
426        tracing::debug!(
427            request_type = std::any::type_name::<R>(),
428            message = ?request,
429            "RequestHandler::handle_request"
430        );
431        match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
432            Some(Ok(request)) => {
433                tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
434                (self.handler)(request.request, cx.cast()).await?;
435                Ok(Handled::Yes)
436            }
437            Some(Err(err)) => {
438                tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
439                Err(err)
440            }
441            None => {
442                tracing::trace!("RequestHandler::handle_request: parse failed");
443                Ok(Handled::No(MessageAndCx::Request(request, cx)))
444            }
445        }
446    }
447
448    fn describe_chain(&self) -> impl std::fmt::Debug {
449        std::any::type_name::<R>()
450    }
451}
452
453/// Handler to process a notification of type `N` coming from the successor component.
454pub struct NotificationFromSuccessorHandler<N, F>
455where
456    N: JrNotification,
457    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
458{
459    handler: F,
460    phantom: PhantomData<fn(N)>,
461}
462
463impl<N, F> NotificationFromSuccessorHandler<N, F>
464where
465    N: JrNotification,
466    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
467{
468    /// Creates a new handler for notifications from the successor
469    pub fn new(handler: F) -> Self {
470        Self {
471            handler,
472            phantom: PhantomData,
473        }
474    }
475}
476
477impl<N, F> JrMessageHandler for NotificationFromSuccessorHandler<N, F>
478where
479    N: JrNotification,
480    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
481{
482    async fn handle_message(
483        &mut self,
484        message: sacp::MessageAndCx,
485    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
486        let MessageAndCx::Notification(message, cx) = message else {
487            return Ok(Handled::No(message));
488        };
489
490        match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
491            Some(Ok(notification)) => {
492                tracing::trace!(
493                    ?notification,
494                    "NotificationFromSuccessorHandler::handle_request: parse completed"
495                );
496                (self.handler)(notification.notification, cx).await?;
497                Ok(Handled::Yes)
498            }
499            Some(Err(err)) => {
500                tracing::trace!(
501                    ?err,
502                    "NotificationFromSuccessorHandler::handle_request: parse errored"
503                );
504                Err(err)
505            }
506            None => {
507                tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
508                Ok(Handled::No(MessageAndCx::Notification(message, cx)))
509            }
510        }
511    }
512
513    fn describe_chain(&self) -> impl std::fmt::Debug {
514        format!("FromSuccessor<{}>", std::any::type_name::<N>())
515    }
516}
517
518/// Handler for the "default proxy" behavior.
519pub struct ProxyHandler {}
520
521impl JrMessageHandler for ProxyHandler {
522    fn describe_chain(&self) -> impl std::fmt::Debug {
523        "proxy"
524    }
525
526    async fn handle_message(
527        &mut self,
528        message: sacp::MessageAndCx,
529    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
530        tracing::debug!(
531            message = ?message.message(),
532            "ProxyHandler::handle_request"
533        );
534
535        match message {
536            MessageAndCx::Request(request, request_cx) => {
537                // If we receive a request from the successor, send it to our predecessor.
538                if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
539                    &request.method,
540                    &request.params,
541                ) {
542                    let request = result?;
543                    request_cx
544                        .connection_cx()
545                        .send_request(request.request)
546                        .forward_to_request_cx(request_cx)?;
547                    return Ok(Handled::Yes);
548                }
549
550                // If we receive "Initialize", require the proxy capability (and remove it)
551                if let Some(result) =
552                    InitializeRequest::parse_request(&request.method, &request.params)
553                {
554                    let request = result?;
555                    return self
556                        .forward_initialize(request, request_cx.cast())
557                        .await
558                        .map(|()| Handled::Yes);
559                }
560
561                // If we receive any other request, send it to our successor.
562                request_cx
563                    .connection_cx()
564                    .send_request_to_successor(request)
565                    .forward_to_request_cx(request_cx)?;
566                Ok(Handled::Yes)
567            }
568
569            MessageAndCx::Notification(notification, cx) => {
570                // If we receive a request from the successor, send it to our predecessor.
571                if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
572                    &notification.method,
573                    &notification.params,
574                ) {
575                    match result {
576                        Ok(r) => {
577                            cx.send_notification(r.notification)?;
578                            return Ok(Handled::Yes);
579                        }
580                        Err(err) => return Err(err),
581                    }
582                }
583
584                // If we receive any other request, send it to our successor.
585                cx.send_notification_to_successor(notification)?;
586                Ok(Handled::Yes)
587            }
588        }
589    }
590}
591
592impl ProxyHandler {
593    /// Proxy initialization requires (1) a `Proxy` capability to be
594    /// provided by the conductor and (2) provides a `Proxy` capability
595    /// in our response.
596    async fn forward_initialize(
597        &mut self,
598        mut request: InitializeRequest,
599        request_cx: JrRequestCx<InitializeResponse>,
600    ) -> Result<(), sacp::Error> {
601        tracing::debug!(
602            method = request_cx.method(),
603            params = ?request,
604            "ProxyHandler::forward_initialize"
605        );
606
607        if !request.has_meta_capability(Proxy) {
608            request_cx.respond_with_error(
609                sacp::Error::invalid_params()
610                    .with_data("this command requires the proxy capability"),
611            )?;
612            return Ok(());
613        }
614
615        request = request.remove_meta_capability(Proxy);
616        request_cx
617            .connection_cx()
618            .send_request_to_successor(request)
619            .await_when_result_received(async move |mut result| {
620                result = result.map(|r| r.add_meta_capability(Proxy));
621                request_cx.respond_with_result(result)
622            })
623    }
624}
625
626/// Extension trait for [`JrConnectionCx`](sacp::JrConnectionCx) that adds methods for sending to successor.
627///
628/// This trait provides convenient methods for proxies to forward messages downstream
629/// to their successor component (next proxy or agent). Messages are automatically
630/// wrapped in the `_proxy/successor/send/*` protocol format.
631///
632/// # Example
633///
634/// ```rust,ignore
635/// // Example using ACP request types
636/// use sacp::proxy::JrCxExt;
637/// use agent_client_protocol_schema_schema::agent::PromptRequest;
638///
639/// async fn forward_prompt(cx: &JsonRpcCx, prompt: PromptRequest) {
640///     let response = cx.send_request_to_successor(prompt).recv().await?;
641///     // response is the typed response from the successor
642/// }
643/// ```
644pub trait JrCxExt {
645    /// Send a request to the successor component.
646    ///
647    /// The request is automatically wrapped in a `ToSuccessorRequest` and sent
648    /// using the `_proxy/successor/send/request` method. The orchestrator routes
649    /// it to the next component in the chain.
650    ///
651    /// # Returns
652    ///
653    /// Returns a [`JrResponse`](sacp::JrResponse) that can be awaited to get the successor's
654    /// response.
655    ///
656    /// # Example
657    ///
658    /// ```rust,ignore
659    /// use sacp::proxy::JrCxExt;
660    /// use agent_client_protocol_schema_schema::agent::PromptRequest;
661    ///
662    /// let prompt = PromptRequest { /* ... */ };
663    /// let response = cx.send_request_to_successor(prompt).recv().await?;
664    /// // response is the typed PromptResponse
665    /// ```
666    fn send_request_to_successor<Req: JrRequest>(
667        &self,
668        request: Req,
669    ) -> sacp::JrResponse<Req::Response>;
670
671    /// Send a notification to the successor component.
672    ///
673    /// The notification is automatically wrapped in a `ToSuccessorNotification`
674    /// and sent using the `_proxy/successor/send/notification` method. The
675    /// orchestrator routes it to the next component in the chain.
676    ///
677    /// Notifications are fire-and-forget - no response is expected.
678    ///
679    /// # Errors
680    ///
681    /// Returns an error if the notification fails to send.
682    fn send_notification_to_successor<Req: JrNotification>(
683        &self,
684        notification: Req,
685    ) -> Result<(), sacp::Error>;
686}
687
688impl JrCxExt for JrConnectionCx {
689    fn send_request_to_successor<Req: JrRequest>(
690        &self,
691        request: Req,
692    ) -> sacp::JrResponse<Req::Response> {
693        let wrapper = SuccessorRequest {
694            request,
695            meta: None,
696        };
697        self.send_request(wrapper)
698    }
699
700    fn send_notification_to_successor<Req: JrNotification>(
701        &self,
702        notification: Req,
703    ) -> Result<(), sacp::Error> {
704        let wrapper = SuccessorNotification {
705            notification,
706            meta: None,
707        };
708        self.send_notification(wrapper)
709    }
710}