// sacp_proxy/to_from_successor.rs

1use futures::{AsyncRead, AsyncWrite};
2use sacp::handler::ChainHandler;
3use sacp::schema::{InitializeRequest, InitializeResponse};
4use sacp::{
5    Handled, JrConnection, JrConnectionCx, JrHandler, JrMessage, JrNotification, JrRequest,
6    JrRequestCx, MessageAndCx, MetaCapabilityExt, Proxy, UntypedMessage,
7};
8use serde::{Deserialize, Serialize};
9use std::marker::PhantomData;
10
11use crate::mcp_server::McpServiceRegistry;
12
// Requests and notifications sent between us and the successor
14// ============================================================
15
16const SUCCESSOR_REQUEST_METHOD: &str = "_proxy/successor/request";
17
18/// A request being sent to the successor component.
19///
20/// Used in `_proxy/successor/send` when the proxy wants to forward a request downstream.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct SuccessorRequest<Req: JrRequest> {
23    /// The message to be sent to the successor component.
24    #[serde(flatten)]
25    pub request: Req,
26}
27
28impl<Req: JrRequest> JrMessage for SuccessorRequest<Req> {
29    fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
30        sacp::UntypedMessage::new(
31            SUCCESSOR_REQUEST_METHOD,
32            SuccessorRequest {
33                request: self.request.into_untyped_message()?,
34            },
35        )
36    }
37
38    fn method(&self) -> &str {
39        SUCCESSOR_REQUEST_METHOD
40    }
41
42    fn parse_request(method: &str, params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
43        if method == SUCCESSOR_REQUEST_METHOD {
44            match sacp::util::json_cast::<_, SuccessorRequest<sacp::UntypedMessage>>(params) {
45                Ok(outer) => match Req::parse_request(&outer.request.method, &outer.request.params)
46                {
47                    Some(Ok(request)) => Some(Ok(SuccessorRequest { request })),
48                    Some(Err(err)) => Some(Err(err)),
49                    None => None,
50                },
51                Err(err) => Some(Err(err)),
52            }
53        } else {
54            None
55        }
56    }
57
58    fn parse_notification(
59        _method: &str,
60        _params: &impl Serialize,
61    ) -> Option<Result<Self, sacp::Error>> {
62        None // Request, not notification
63    }
64}
65
66impl<Req: JrRequest> JrRequest for SuccessorRequest<Req> {
67    type Response = Req::Response;
68}
69
70const SUCCESSOR_NOTIFICATION_METHOD: &str = "_proxy/successor/notification";
71
72/// A notification being sent to the successor component.
73///
74/// Used in `_proxy/successor/send` when the proxy wants to forward a notification downstream.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct SuccessorNotification<Req: JrNotification> {
77    /// The message to be sent to the successor component.
78    #[serde(flatten)]
79    pub notification: Req,
80}
81
82impl<Req: JrNotification> JrMessage for SuccessorNotification<Req> {
83    fn into_untyped_message(self) -> Result<sacp::UntypedMessage, sacp::Error> {
84        sacp::UntypedMessage::new(
85            SUCCESSOR_NOTIFICATION_METHOD,
86            SuccessorNotification {
87                notification: self.notification.into_untyped_message()?,
88            },
89        )
90    }
91
92    fn method(&self) -> &str {
93        SUCCESSOR_NOTIFICATION_METHOD
94    }
95
96    fn parse_request(_method: &str, _params: &impl Serialize) -> Option<Result<Self, sacp::Error>> {
97        None // Notification, not request
98    }
99
100    fn parse_notification(
101        method: &str,
102        params: &impl Serialize,
103    ) -> Option<Result<Self, sacp::Error>> {
104        if method == SUCCESSOR_NOTIFICATION_METHOD {
105            match sacp::util::json_cast::<_, SuccessorNotification<sacp::UntypedMessage>>(params) {
106                Ok(outer) => match Req::parse_notification(
107                    &outer.notification.method,
108                    &outer.notification.params,
109                ) {
110                    Some(Ok(notification)) => Some(Ok(SuccessorNotification { notification })),
111                    Some(Err(err)) => Some(Err(err)),
112                    None => None,
113                },
114                Err(err) => Some(Err(err)),
115            }
116        } else {
117            None
118        }
119    }
120}
121
122impl<Req: JrNotification> JrNotification for SuccessorNotification<Req> {}
123
124// Proxy methods
125// ============================================================
126
127/// Extension trait for JrConnection that adds proxy-specific functionality
128pub trait AcpProxyExt<OB: AsyncWrite, IB: AsyncRead, H: JrHandler> {
129    /// Adds a handler for requests received from the successor component.
130    ///
131    /// The provided handler will receive unwrapped ACP messages - the
132    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
133    /// Your handler processes normal ACP requests and notifications as if it were
134    /// a regular ACP component.
135    ///
136    /// # Example
137    ///
138    /// ```rust,ignore
139    /// # use sacp::proxy::JrConnectionExt;
140    /// # use sacp::{JrConnection, JrHandler};
141    /// # struct MyHandler;
142    /// # impl JrHandler for MyHandler {}
143    /// # async fn example() -> Result<(), sacp::Error> {
144    /// JrConnection::new(tokio::io::stdin(), tokio::io::stdout())
145    ///     .on_receive_from_successor(MyHandler)
146    ///     .serve()
147    ///     .await?;
148    /// # Ok(())
149    /// # }
150    /// ```
151    fn on_receive_request_from_successor<R, F>(
152        self,
153        op: F,
154    ) -> JrConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
155    where
156        R: JrRequest,
157        F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>;
158
159    /// Adds a handler for notifications received from the successor component.
160    ///
161    /// The provided handler will receive unwrapped ACP messages - the
162    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
163    /// Your handler processes normal ACP requests and notifications as if it were
164    /// a regular ACP component.
165    ///
166    /// # Example
167    ///
168    /// ```rust,ignore
169    /// # use sacp::proxy::JrConnectionExt;
170    /// # use sacp::{JrConnection, JrHandler};
171    /// # struct MyHandler;
172    /// # impl JrHandler for MyHandler {}
173    /// # async fn example() -> Result<(), sacp::Error> {
174    /// JrConnection::new(tokio::io::stdin(), tokio::io::stdout())
175    ///     .on_receive_from_successor(MyHandler)
176    ///     .serve()
177    ///     .await?;
178    /// # Ok(())
179    /// # }
180    /// ```
181    fn on_receive_notification_from_successor<N, F>(
182        self,
183        op: F,
184    ) -> JrConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
185    where
186        N: JrNotification,
187        F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>;
188
189    /// Adds a handler for messages received from the successor component.
190    ///
191    /// The provided handler will receive unwrapped ACP messages - the
192    /// `_proxy/successor/receive/*` protocol wrappers are handled automatically.
193    /// Your handler processes normal ACP requests and notifications as if it were
194    /// a regular ACP component.
195    fn on_receive_message_from_successor<R, N, F>(
196        self,
197        op: F,
198    ) -> JrConnection<OB, IB, ChainHandler<H, MessageFromSuccessorHandler<R, N, F>>>
199    where
200        R: JrRequest,
201        N: JrNotification,
202        F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>;
203
204    /// Installs a proxy layer that proxies all requests/notifications to/from the successor.
205    /// This is typically the last component in the chain.
206    fn proxy(self) -> JrConnection<OB, IB, ChainHandler<H, ProxyHandler>>;
207
208    /// Provide MCP servers to downstream successors.
209    /// This layer will modify `session/new` requests to include those MCP servers
210    /// (unless you intercept them earlier).
211    fn provide_mcp(
212        self,
213        registry: impl AsRef<McpServiceRegistry>,
214    ) -> JrConnection<OB, IB, ChainHandler<H, McpServiceRegistry>>;
215}
216
217impl<OB, IB, H> AcpProxyExt<OB, IB, H> for JrConnection<OB, IB, H>
218where
219    OB: AsyncWrite,
220    IB: AsyncRead,
221    H: JrHandler,
222{
223    fn on_receive_request_from_successor<R, F>(
224        self,
225        op: F,
226    ) -> JrConnection<OB, IB, ChainHandler<H, RequestFromSuccessorHandler<R, F>>>
227    where
228        R: JrRequest,
229        F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
230    {
231        self.chain_handler(RequestFromSuccessorHandler::new(op))
232    }
233
234    fn on_receive_notification_from_successor<N, F>(
235        self,
236        op: F,
237    ) -> JrConnection<OB, IB, ChainHandler<H, NotificationFromSuccessorHandler<N, F>>>
238    where
239        N: JrNotification,
240        F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
241    {
242        self.chain_handler(NotificationFromSuccessorHandler::new(op))
243    }
244
245    fn on_receive_message_from_successor<R, N, F>(
246        self,
247        op: F,
248    ) -> JrConnection<OB, IB, ChainHandler<H, MessageFromSuccessorHandler<R, N, F>>>
249    where
250        R: JrRequest,
251        N: JrNotification,
252        F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
253    {
254        self.chain_handler(MessageFromSuccessorHandler::new(op))
255    }
256
257    fn proxy(self) -> JrConnection<OB, IB, ChainHandler<H, ProxyHandler>> {
258        self.chain_handler(ProxyHandler {})
259    }
260
261    fn provide_mcp(
262        self,
263        registry: impl AsRef<McpServiceRegistry>,
264    ) -> JrConnection<OB, IB, ChainHandler<H, McpServiceRegistry>> {
265        self.chain_handler(registry.as_ref().clone())
266    }
267}
268
269/// Handler to process a message of type `R` coming from the successor component.
270pub struct MessageFromSuccessorHandler<R, N, F>
271where
272    R: JrRequest,
273    N: JrNotification,
274    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
275{
276    handler: F,
277    phantom: PhantomData<fn(R, N)>,
278}
279
280impl<R, N, F> MessageFromSuccessorHandler<R, N, F>
281where
282    R: JrRequest,
283    N: JrNotification,
284    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
285{
286    /// Creates a new handler for requests from the successor
287    pub fn new(handler: F) -> Self {
288        Self {
289            handler,
290            phantom: PhantomData,
291        }
292    }
293}
294
295impl<R, N, F> JrHandler for MessageFromSuccessorHandler<R, N, F>
296where
297    R: JrRequest,
298    N: JrNotification,
299    F: AsyncFnMut(MessageAndCx<R, N>) -> Result<(), sacp::Error>,
300{
301    async fn handle_message(
302        &mut self,
303        message: sacp::MessageAndCx,
304    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
305        match message {
306            MessageAndCx::Request(request, request_cx) => {
307                tracing::trace!(
308                    request_type = std::any::type_name::<R>(),
309                    message = ?request,
310                    "MessageFromSuccessorHandler::handle_message"
311                );
312                match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
313                    Some(Ok(request)) => {
314                        tracing::trace!(
315                            ?request,
316                            "RequestHandler::handle_request: parse completed"
317                        );
318                        (self.handler)(MessageAndCx::Request(request.request, request_cx.cast()))
319                            .await?;
320                        Ok(Handled::Yes)
321                    }
322                    Some(Err(err)) => {
323                        tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
324                        Err(err)
325                    }
326                    None => {
327                        tracing::trace!("RequestHandler::handle_request: parse failed");
328                        Ok(Handled::No(MessageAndCx::Request(request, request_cx)))
329                    }
330                }
331            }
332            MessageAndCx::Notification(notification, connection_cx) => {
333                tracing::trace!(
334                    ?notification,
335                    "NotificationFromSuccessorHandler::handle_message"
336                );
337                match <SuccessorNotification<N>>::parse_notification(
338                    &notification.method,
339                    &notification.params,
340                ) {
341                    Some(Ok(notification)) => {
342                        tracing::trace!(
343                            ?notification,
344                            "NotificationFromSuccessorHandler::handle_message: parse completed"
345                        );
346                        (self.handler)(MessageAndCx::Notification(
347                            notification.notification,
348                            connection_cx,
349                        ))
350                        .await?;
351                        Ok(Handled::Yes)
352                    }
353                    Some(Err(err)) => {
354                        tracing::trace!(
355                            ?err,
356                            "NotificationFromSuccessorHandler::handle_message: parse errored"
357                        );
358                        Err(err)
359                    }
360                    None => {
361                        tracing::trace!(
362                            "NotificationFromSuccessorHandler::handle_message: parse failed"
363                        );
364                        Ok(Handled::No(MessageAndCx::Notification(
365                            notification,
366                            connection_cx,
367                        )))
368                    }
369                }
370            }
371        }
372    }
373
374    fn describe_chain(&self) -> impl std::fmt::Debug {
375        std::any::type_name::<R>()
376    }
377}
378
379/// Handler to process a request of type `R` coming from the successor component.
380pub struct RequestFromSuccessorHandler<R, F>
381where
382    R: JrRequest,
383    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
384{
385    handler: F,
386    phantom: PhantomData<fn(R)>,
387}
388
389impl<R, F> RequestFromSuccessorHandler<R, F>
390where
391    R: JrRequest,
392    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
393{
394    /// Creates a new handler for requests from the successor
395    pub fn new(handler: F) -> Self {
396        Self {
397            handler,
398            phantom: PhantomData,
399        }
400    }
401}
402
403impl<R, F> JrHandler for RequestFromSuccessorHandler<R, F>
404where
405    R: JrRequest,
406    F: AsyncFnMut(R, JrRequestCx<R::Response>) -> Result<(), sacp::Error>,
407{
408    async fn handle_message(
409        &mut self,
410        message: sacp::MessageAndCx,
411    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
412        let MessageAndCx::Request(request, cx) = message else {
413            return Ok(Handled::No(message));
414        };
415
416        tracing::debug!(
417            request_type = std::any::type_name::<R>(),
418            message = ?request,
419            "RequestHandler::handle_request"
420        );
421        match <SuccessorRequest<R>>::parse_request(&request.method, &request.params) {
422            Some(Ok(request)) => {
423                tracing::trace!(?request, "RequestHandler::handle_request: parse completed");
424                (self.handler)(request.request, cx.cast()).await?;
425                Ok(Handled::Yes)
426            }
427            Some(Err(err)) => {
428                tracing::trace!(?err, "RequestHandler::handle_request: parse errored");
429                Err(err)
430            }
431            None => {
432                tracing::trace!("RequestHandler::handle_request: parse failed");
433                Ok(Handled::No(MessageAndCx::Request(request, cx)))
434            }
435        }
436    }
437
438    fn describe_chain(&self) -> impl std::fmt::Debug {
439        std::any::type_name::<R>()
440    }
441}
442
443/// Handler to process a notification of type `N` coming from the successor component.
444pub struct NotificationFromSuccessorHandler<N, F>
445where
446    N: JrNotification,
447    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
448{
449    handler: F,
450    phantom: PhantomData<fn(N)>,
451}
452
453impl<N, F> NotificationFromSuccessorHandler<N, F>
454where
455    N: JrNotification,
456    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
457{
458    /// Creates a new handler for notifications from the successor
459    pub fn new(handler: F) -> Self {
460        Self {
461            handler,
462            phantom: PhantomData,
463        }
464    }
465}
466
467impl<N, F> JrHandler for NotificationFromSuccessorHandler<N, F>
468where
469    N: JrNotification,
470    F: AsyncFnMut(N, JrConnectionCx) -> Result<(), sacp::Error>,
471{
472    async fn handle_message(
473        &mut self,
474        message: sacp::MessageAndCx,
475    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
476        let MessageAndCx::Notification(message, cx) = message else {
477            return Ok(Handled::No(message));
478        };
479
480        match <SuccessorNotification<N>>::parse_notification(&message.method, &message.params) {
481            Some(Ok(notification)) => {
482                tracing::trace!(
483                    ?notification,
484                    "NotificationFromSuccessorHandler::handle_request: parse completed"
485                );
486                (self.handler)(notification.notification, cx).await?;
487                Ok(Handled::Yes)
488            }
489            Some(Err(err)) => {
490                tracing::trace!(
491                    ?err,
492                    "NotificationFromSuccessorHandler::handle_request: parse errored"
493                );
494                Err(err)
495            }
496            None => {
497                tracing::trace!("NotificationFromSuccessorHandler::handle_request: parse failed");
498                Ok(Handled::No(MessageAndCx::Notification(message, cx)))
499            }
500        }
501    }
502
503    fn describe_chain(&self) -> impl std::fmt::Debug {
504        format!("FromSuccessor<{}>", std::any::type_name::<N>())
505    }
506}
507
508/// Handler for the "default proxy" behavior.
509pub struct ProxyHandler {}
510
511impl JrHandler for ProxyHandler {
512    fn describe_chain(&self) -> impl std::fmt::Debug {
513        "proxy"
514    }
515
516    async fn handle_message(
517        &mut self,
518        message: sacp::MessageAndCx,
519    ) -> Result<Handled<sacp::MessageAndCx>, sacp::Error> {
520        tracing::debug!(
521            message = ?message.message(),
522            "ProxyHandler::handle_request"
523        );
524
525        match message {
526            MessageAndCx::Request(request, request_cx) => {
527                // If we receive a request from the successor, send it to our predecessor.
528                if let Some(result) = <SuccessorRequest<UntypedMessage>>::parse_request(
529                    &request.method,
530                    &request.params,
531                ) {
532                    let request = result?;
533                    request_cx
534                        .send_request(request.request)
535                        .forward_to_request_cx(request_cx)?;
536                    return Ok(Handled::Yes);
537                }
538
539                // If we receive "Initialize", require the proxy capability (and remove it)
540                if let Some(result) =
541                    InitializeRequest::parse_request(&request.method, &request.params)
542                {
543                    let request = result?;
544                    return self
545                        .forward_initialize(request, request_cx.cast())
546                        .await
547                        .map(|()| Handled::Yes);
548                }
549
550                // If we receive any other request, send it to our successor.
551                request_cx
552                    .send_request_to_successor(request)
553                    .forward_to_request_cx(request_cx)?;
554                Ok(Handled::Yes)
555            }
556
557            MessageAndCx::Notification(notification, cx) => {
558                // If we receive a request from the successor, send it to our predecessor.
559                if let Some(result) = <SuccessorNotification<UntypedMessage>>::parse_notification(
560                    &notification.method,
561                    &notification.params,
562                ) {
563                    match result {
564                        Ok(r) => {
565                            cx.send_notification(r.notification)?;
566                            return Ok(Handled::Yes);
567                        }
568                        Err(err) => return Err(err),
569                    }
570                }
571
572                // If we receive any other request, send it to our successor.
573                cx.send_notification_to_successor(notification)?;
574                Ok(Handled::Yes)
575            }
576        }
577    }
578}
579
580impl ProxyHandler {
581    /// Proxy initialization requires (1) a `Proxy` capability to be
582    /// provided by the conductor and (2) provides a `Proxy` capability
583    /// in our response.
584    async fn forward_initialize(
585        &mut self,
586        mut request: InitializeRequest,
587        request_cx: JrRequestCx<InitializeResponse>,
588    ) -> Result<(), sacp::Error> {
589        tracing::debug!(
590            method = request_cx.method(),
591            params = ?request,
592            "ProxyHandler::forward_initialize"
593        );
594
595        if !request.has_meta_capability(Proxy) {
596            request_cx.respond_with_error(
597                sacp::Error::invalid_params()
598                    .with_data("this command requires the proxy capability"),
599            )?;
600            return Ok(());
601        }
602
603        request = request.remove_meta_capability(Proxy);
604        request_cx
605            .send_request_to_successor(request)
606            .await_when_result_received(async move |mut result| {
607                result = result.map(|r| r.add_meta_capability(Proxy));
608                request_cx.respond_with_result(result)
609            })
610    }
611}
612
613/// Extension trait for [`JrConnectionCx`](sacp::JrConnectionCx) that adds methods for sending to successor.
614///
615/// This trait provides convenient methods for proxies to forward messages downstream
616/// to their successor component (next proxy or agent). Messages are automatically
617/// wrapped in the `_proxy/successor/send/*` protocol format.
618///
619/// # Example
620///
621/// ```rust,ignore
622/// // Example using ACP request types
623/// use sacp::proxy::JrCxExt;
624/// use agent_client_protocol_schema_schema::agent::PromptRequest;
625///
626/// async fn forward_prompt(cx: &JsonRpcCx, prompt: PromptRequest) {
627///     let response = cx.send_request_to_successor(prompt).recv().await?;
628///     // response is the typed response from the successor
629/// }
630/// ```
631pub trait JrCxExt {
632    /// Send a request to the successor component.
633    ///
634    /// The request is automatically wrapped in a `ToSuccessorRequest` and sent
635    /// using the `_proxy/successor/send/request` method. The orchestrator routes
636    /// it to the next component in the chain.
637    ///
638    /// # Returns
639    ///
640    /// Returns a [`JrResponse`](sacp::JrResponse) that can be awaited to get the successor's
641    /// response.
642    ///
643    /// # Example
644    ///
645    /// ```rust,ignore
646    /// use sacp::proxy::JrCxExt;
647    /// use agent_client_protocol_schema_schema::agent::PromptRequest;
648    ///
649    /// let prompt = PromptRequest { /* ... */ };
650    /// let response = cx.send_request_to_successor(prompt).recv().await?;
651    /// // response is the typed PromptResponse
652    /// ```
653    fn send_request_to_successor<Req: JrRequest>(
654        &self,
655        request: Req,
656    ) -> sacp::JrResponse<Req::Response>;
657
658    /// Send a notification to the successor component.
659    ///
660    /// The notification is automatically wrapped in a `ToSuccessorNotification`
661    /// and sent using the `_proxy/successor/send/notification` method. The
662    /// orchestrator routes it to the next component in the chain.
663    ///
664    /// Notifications are fire-and-forget - no response is expected.
665    ///
666    /// # Errors
667    ///
668    /// Returns an error if the notification fails to send.
669    fn send_notification_to_successor<Req: JrNotification>(
670        &self,
671        notification: Req,
672    ) -> Result<(), sacp::Error>;
673}
674
675impl JrCxExt for JrConnectionCx {
676    fn send_request_to_successor<Req: JrRequest>(
677        &self,
678        request: Req,
679    ) -> sacp::JrResponse<Req::Response> {
680        let wrapper = SuccessorRequest { request };
681        self.send_request(wrapper)
682    }
683
684    fn send_notification_to_successor<Req: JrNotification>(
685        &self,
686        notification: Req,
687    ) -> Result<(), sacp::Error> {
688        let wrapper = SuccessorNotification { notification };
689        self.send_notification(wrapper)
690    }
691}