// wrpc_transport/serve.rs

//! wRPC transport server handle
2
3use core::future::Future;
4use core::mem;
5use core::pin::Pin;
6
7use std::sync::Arc;
8
9use anyhow::{bail, Context as _};
10use futures::{SinkExt as _, Stream, TryStreamExt as _};
11use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
12use tokio_util::codec::{FramedRead, FramedWrite};
13use tracing::{debug, instrument, trace, Instrument as _, Span};
14
15use crate::{Deferred as _, Incoming, Index, TupleDecode, TupleEncode};
16
17/// Server-side handle to a wRPC transport
18pub trait Serve: Sync {
19    /// Transport-specific invocation context
20    type Context: Send + Sync + 'static;
21
22    /// Outgoing multiplexed byte stream
23    type Outgoing: AsyncWrite + Index<Self::Outgoing> + Send + Sync + Unpin + 'static;
24
25    /// Incoming multiplexed byte stream
26    type Incoming: AsyncRead + Index<Self::Incoming> + Send + Sync + Unpin + 'static;
27
28    /// Serve function `func` from instance `instance`
29    fn serve(
30        &self,
31        instance: &str,
32        func: &str,
33        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
34    ) -> impl Future<
35        Output = anyhow::Result<
36            impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>>
37                + Send
38                + 'static,
39        >,
40    > + Send;
41}
42
43/// Extension trait for [Serve]
44pub trait ServeExt: Serve {
45    /// Serve function `func` from instance `instance` using typed `Params` and `Results`
46    #[instrument(level = "trace", skip(self, paths))]
47    fn serve_values<Params, Results>(
48        &self,
49        instance: &str,
50        func: &str,
51        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
52    ) -> impl Future<
53        Output = anyhow::Result<
54            impl Stream<
55                    Item = anyhow::Result<(
56                        Self::Context,
57                        Params,
58                        Option<impl Future<Output = std::io::Result<()>> + Send + Unpin + 'static>,
59                        impl FnOnce(
60                                Results,
61                            ) -> Pin<
62                                Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>,
63                            > + Send
64                            + 'static,
65                    )>,
66                > + Send
67                + 'static,
68        >,
69    > + Send
70    where
71        Params: TupleDecode<Self::Incoming> + Send + 'static,
72        Results: TupleEncode<Self::Outgoing> + Send + 'static,
73        <Params::Decoder as tokio_util::codec::Decoder>::Error:
74            std::error::Error + Send + Sync + 'static,
75        <Results::Encoder as tokio_util::codec::Encoder<Results>>::Error:
76            std::error::Error + Send + Sync + 'static,
77    {
78        let span = Span::current();
79        async {
80            let invocations = self.serve(instance, func, paths).await?;
81            Ok(invocations.and_then(move |(cx, outgoing, incoming)| {
82                async {
83                    let mut dec = FramedRead::new(incoming, Params::Decoder::default());
84                    debug!("receiving sync parameters");
85                    let Some(params) = dec
86                        .try_next()
87                        .await
88                        .context("failed to receive sync parameters")?
89                    else {
90                        bail!("incomplete sync parameters")
91                    };
92                    trace!("received sync parameters");
93                    let rx = dec.decoder_mut().take_deferred();
94                    let buffer = mem::take(dec.read_buffer_mut());
95                    let span = Span::current();
96                    Ok((
97                        cx,
98                        params,
99                        rx.map(|f| {
100                            f(
101                                Incoming {
102                                    buffer,
103                                    inner: dec.into_inner(),
104                                },
105                                Vec::default(),
106                            )
107                        }),
108                        move |results| {
109                            Box::pin(
110                                async {
111                                    let mut enc =
112                                        FramedWrite::new(outgoing, Results::Encoder::default());
113                                    debug!("transmitting sync results");
114                                    enc.send(results)
115                                        .await
116                                        .context("failed to transmit synchronous results")?;
117                                    let tx = enc.encoder_mut().take_deferred();
118                                    let mut outgoing = enc.into_inner();
119                                    outgoing
120                                        .shutdown()
121                                        .await
122                                        .context("failed to shutdown synchronous return channel")?;
123                                    if let Some(tx) = tx {
124                                        debug!("transmitting async results");
125                                        tx(outgoing, Vec::default())
126                                            .await
127                                            .context("failed to write async results")?;
128                                    }
129                                    Ok(())
130                                }
131                                .instrument(span),
132                            ) as Pin<_>
133                        },
134                    ))
135                }
136                .instrument(span.clone())
137            }))
138        }
139    }
140}
141
142impl<T: Serve> ServeExt for T {}
143
// None of the helpers below is ever called; they only have to type-check, which is why
// `dead_code` is allowed. They make the compiler verify several properties:
// - `Serve::serve` accepts array, `Vec` and slice forms of `paths`;
// - the resulting streams can be moved into `tokio::spawn` (`Send + 'static`);
// - both `serve` and `serve_values` streams can be boxed as `'static` trait objects.
#[cfg(test)]
#[allow(dead_code)]
mod tests {
    use bytes::Bytes;
    use futures::{stream, StreamExt as _, TryStreamExt as _};

    use crate::Captures;

    use super::*;

    /// Serves `foo`/`bar` three times, passing `paths` as an array, a `Vec` and a
    /// slice. All invocations are collected on a spawned task.
    async fn call_serve<T: Serve>(
        s: &T,
    ) -> anyhow::Result<Vec<(T::Context, T::Outgoing, T::Incoming)>> {
        let from_array = s
            .serve(
                "foo",
                "bar",
                [Box::from([Some(42), None]), Box::from([None])],
            )
            .await
            .unwrap();
        let from_vec = s
            .serve(
                "foo",
                "bar",
                vec![Box::from([Some(42), None]), Box::from([None])],
            )
            .await
            .unwrap();
        let from_slice = s
            .serve(
                "foo",
                "bar",
                [Box::from([Some(42), None]), Box::from([None])].as_slice(),
            )
            .await
            .unwrap();
        let all = stream::empty()
            .chain(from_array)
            .chain(from_vec)
            .chain(from_slice);
        tokio::spawn(async move { all.try_collect().await })
            .await
            .unwrap()
    }

    /// Boxes the `serve` stream as a `'static` trait object while the returned future
    /// only borrows `s`.
    fn serve_lifetime<T: Serve>(
        s: &T,
    ) -> impl Future<
        Output = anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<T::Context>> + 'static>>>,
    > + Captures<'_> {
        let pending = s.serve(
            "foo",
            "bar",
            [Box::from([Some(42), None]), Box::from([None])],
        );
        async move {
            let invocations = pending.await.unwrap();
            let contexts = invocations.and_then(|(cx, _, _)| async { Ok(cx) });
            Ok(Box::pin(contexts) as Pin<Box<dyn Stream<Item = _>>>)
        }
    }

    /// Same as `serve_lifetime`, but for the typed `serve_values` stream. It also
    /// invokes the result-transmitting closure for each invocation.
    fn serve_values_lifetime<T: Serve>(
        s: &T,
    ) -> impl Future<
        Output = anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<T::Context>> + 'static>>>,
    > + Captures<'_> {
        let pending = s.serve_values::<(Bytes,), (Bytes,)>(
            "foo",
            "bar",
            [Box::from([Some(42), None]), Box::from([None])],
        );
        async move {
            let invocations = pending.await.unwrap();
            let contexts = invocations.and_then(|(cx, _, _, tx)| async {
                tx((Bytes::from("test"),)).await.unwrap();
                Ok(cx)
            });
            Ok(Box::pin(contexts) as Pin<Box<dyn Stream<Item = _>>>)
        }
    }
}