1use core::future::Future;
4use core::mem;
5use core::pin::Pin;
6
7use std::sync::Arc;
8
9use anyhow::{bail, Context as _};
10use futures::{SinkExt as _, Stream, TryStreamExt as _};
11use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
12use tokio_util::codec::{FramedRead, FramedWrite};
13use tracing::{debug, instrument, trace, Instrument as _, Span};
14
15use crate::{Deferred as _, Incoming, Index, TupleDecode, TupleEncode};
16
17pub trait Serve: Sync {
19 type Context: Send + Sync + 'static;
21
22 type Outgoing: AsyncWrite + Index<Self::Outgoing> + Send + Sync + Unpin + 'static;
24
25 type Incoming: AsyncRead + Index<Self::Incoming> + Send + Sync + Unpin + 'static;
27
28 fn serve(
30 &self,
31 instance: &str,
32 func: &str,
33 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
34 ) -> impl Future<
35 Output = anyhow::Result<
36 impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>>
37 + Send
38 + 'static,
39 >,
40 > + Send;
41}
42
43pub trait ServeExt: Serve {
45 #[instrument(level = "trace", skip(self, paths))]
47 fn serve_values<Params, Results>(
48 &self,
49 instance: &str,
50 func: &str,
51 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
52 ) -> impl Future<
53 Output = anyhow::Result<
54 impl Stream<
55 Item = anyhow::Result<(
56 Self::Context,
57 Params,
58 Option<impl Future<Output = std::io::Result<()>> + Send + Unpin + 'static>,
59 impl FnOnce(
60 Results,
61 ) -> Pin<
62 Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>,
63 > + Send
64 + 'static,
65 )>,
66 > + Send
67 + 'static,
68 >,
69 > + Send
70 where
71 Params: TupleDecode<Self::Incoming> + Send + 'static,
72 Results: TupleEncode<Self::Outgoing> + Send + 'static,
73 <Params::Decoder as tokio_util::codec::Decoder>::Error:
74 std::error::Error + Send + Sync + 'static,
75 <Results::Encoder as tokio_util::codec::Encoder<Results>>::Error:
76 std::error::Error + Send + Sync + 'static,
77 {
78 let span = Span::current();
79 async {
80 let invocations = self.serve(instance, func, paths).await?;
81 Ok(invocations.and_then(move |(cx, outgoing, incoming)| {
82 async {
83 let mut dec = FramedRead::new(incoming, Params::Decoder::default());
84 debug!("receiving sync parameters");
85 let Some(params) = dec
86 .try_next()
87 .await
88 .context("failed to receive sync parameters")?
89 else {
90 bail!("incomplete sync parameters")
91 };
92 trace!("received sync parameters");
93 let rx = dec.decoder_mut().take_deferred();
94 let buffer = mem::take(dec.read_buffer_mut());
95 let span = Span::current();
96 Ok((
97 cx,
98 params,
99 rx.map(|f| {
100 f(
101 Incoming {
102 buffer,
103 inner: dec.into_inner(),
104 },
105 Vec::default(),
106 )
107 }),
108 move |results| {
109 Box::pin(
110 async {
111 let mut enc =
112 FramedWrite::new(outgoing, Results::Encoder::default());
113 debug!("transmitting sync results");
114 enc.send(results)
115 .await
116 .context("failed to transmit synchronous results")?;
117 let tx = enc.encoder_mut().take_deferred();
118 let mut outgoing = enc.into_inner();
119 outgoing
120 .shutdown()
121 .await
122 .context("failed to shutdown synchronous return channel")?;
123 if let Some(tx) = tx {
124 debug!("transmitting async results");
125 tx(outgoing, Vec::default())
126 .await
127 .context("failed to write async results")?;
128 }
129 Ok(())
130 }
131 .instrument(span),
132 ) as Pin<_>
133 },
134 ))
135 }
136 .instrument(span.clone())
137 }))
138 }
139 }
140}
141
142impl<T: Serve> ServeExt for T {}
143
#[cfg(test)]
// The helpers below are not `#[test]` functions and are never called. They exist
// only so the compiler checks three things:
// - `serve`/`serve_values` accept array, `Vec` and slice `paths`;
// - the returned streams can be boxed as `'static`;
// - in `call_serve`, the stream can be moved onto a spawned task.
#[allow(dead_code)]
mod tests {
    use bytes::Bytes;
    use futures::{stream, StreamExt as _, TryStreamExt as _};

    use crate::Captures;

    use super::*;

    /// Serves `foo`/`bar` three times, passing `paths` as an array, a `Vec` and
    /// a slice, and collects every yielded invocation on a spawned task.
    async fn call_serve<T: Serve>(
        s: &T,
    ) -> anyhow::Result<Vec<(T::Context, T::Outgoing, T::Incoming)>> {
        let st = stream::empty()
            .chain(
                s.serve(
                    "foo",
                    "bar",
                    [Box::from([Some(42), None]), Box::from([None])],
                )
                .await?,
            )
            .chain(
                s.serve(
                    "foo",
                    "bar",
                    vec![Box::from([Some(42), None]), Box::from([None])],
                )
                .await?,
            )
            .chain(
                s.serve(
                    "foo",
                    "bar",
                    [Box::from([Some(42), None]), Box::from([None])].as_slice(),
                )
                .await?,
            );
        // `tokio::spawn` requires a `Send + 'static` future, so this also checks
        // that the chained stream satisfies those bounds.
        tokio::spawn(async move { st.try_collect().await }).await?
    }

    /// Checks that the stream returned by `serve` can be boxed as a `'static`
    /// trait object even though the returned future borrows `s`.
    fn serve_lifetime<T: Serve>(
        s: &T,
    ) -> impl Future<
        Output = anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<T::Context>> + 'static>>>,
    > + Captures<'_> {
        let fut = s.serve(
            "foo",
            "bar",
            [Box::from([Some(42), None]), Box::from([None])],
        );
        async move {
            let st = fut.await.unwrap();
            Ok(Box::pin(st.and_then(|(cx, _, _)| async { Ok(cx) }))
                as Pin<Box<dyn Stream<Item = _>>>)
        }
    }

    /// Same check as `serve_lifetime`, but for `serve_values`.
    ///
    /// It also invokes the results callback once per invocation.
    fn serve_values_lifetime<T: Serve>(
        s: &T,
    ) -> impl Future<
        Output = anyhow::Result<Pin<Box<dyn Stream<Item = anyhow::Result<T::Context>> + 'static>>>,
    > + Captures<'_> {
        let fut = s.serve_values::<(Bytes,), (Bytes,)>(
            "foo",
            "bar",
            [Box::from([Some(42), None]), Box::from([None])],
        );
        async move {
            let st = fut.await.unwrap();
            Ok(Box::pin(st.and_then(|(cx, _, _, tx)| async {
                tx((Bytes::from("test"),)).await.unwrap();
                Ok(cx)
            })) as Pin<Box<dyn Stream<Item = _>>>)
        }
    }
}