Skip to main content

tansu_service/
stream.rs

1// Copyright ⓒ 2024-2025 Peter Morgan <peter.james.morgan@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    error::{self},
17    fmt::Debug,
18    io,
19    marker::PhantomData,
20    time::SystemTime,
21};
22
23use bytes::Bytes;
24use nanoid::nanoid;
25use opentelemetry::KeyValue;
26use rama::{Context, Layer, Service};
27use tokio::{
28    io::{AsyncReadExt, AsyncWriteExt, BufWriter},
29    net::{TcpListener, TcpStream},
30    task::JoinSet,
31};
32use tokio_util::sync::CancellationToken;
33use tracing::{debug, error, instrument};
34
35use crate::{
36    BYTES_RECEIVED, BYTES_SENT, Error, REQUEST_DURATION, REQUEST_SIZE, RESPONSE_SIZE, frame_length,
37};
38
39/// A [`Layer`] that listens for TCP connections
40#[derive(Clone, Debug, Default)]
41pub struct TcpListenerLayer {
42    cancellation: CancellationToken,
43}
44
45impl TcpListenerLayer {
46    pub fn new(cancellation: CancellationToken) -> Self {
47        Self { cancellation }
48    }
49}
50
51impl<S> Layer<S> for TcpListenerLayer {
52    type Service = TcpListenerService<S>;
53
54    fn layer(&self, inner: S) -> Self::Service {
55        Self::Service {
56            cancellation: self.cancellation.clone(),
57            inner,
58        }
59    }
60}
61
62/// A [`Service`] that listens for TCP connections
63#[derive(Clone, Default)]
64pub struct TcpListenerService<S> {
65    cancellation: CancellationToken,
66    inner: S,
67}
68
69impl<S> Debug for TcpListenerService<S> {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct(stringify!(TcpListenerService)).finish()
72    }
73}
74
75impl<State, S> Service<State, TcpListener> for TcpListenerService<S>
76where
77    S: Service<State, TcpStream> + Clone,
78    S::Response: Debug,
79    S::Error: error::Error,
80    State: Clone + Send + Sync + 'static,
81{
82    type Response = ();
83    type Error = S::Error;
84
85    #[instrument(skip(ctx, req))]
86    async fn serve(
87        &self,
88        ctx: Context<State>,
89        req: TcpListener,
90    ) -> Result<Self::Response, Self::Error> {
91        let mut set = JoinSet::new();
92
93        loop {
94            tokio::select! {
95                Ok((stream, addr)) = req.accept() => {
96                    debug!(?req, ?stream, %addr);
97
98                    let service = self.inner.clone();
99                    let ctx = ctx.clone();
100
101                    let handle = set.spawn(async move {
102                            match service.serve(ctx, stream).await {
103                                Err(error) => {
104                                    debug!(%addr, %error);
105                                },
106
107                                Ok(response) => {
108                                    debug!(%addr, ?response)
109                                }
110                        }
111                    });
112
113                    debug!(?handle);
114                    continue;
115                }
116
117                v = set.join_next(), if !set.is_empty() => {
118                    debug!(?v);
119                }
120
121                cancelled = self.cancellation.cancelled() => {
122                    debug!(?cancelled);
123                    break;
124                }
125            }
126        }
127
128        Ok(())
129    }
130}
131
/// A [context state][`Context#method.state`] used by [`TcpContextLayer`] and [`TcpContextService`]
///
/// Built with the consuming setter methods, starting from
/// [`TcpContext::default`], in which every option is `None`.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TcpContext {
    /// Optional cluster identifier; when present [`TcpBytesService`] attaches
    /// it as the `cluster_id` attribute of the metrics it records.
    cluster_id: Option<String>,
    /// Optional limit on the size of a frame read by [`TcpBytesService`];
    /// `None` means no limit is checked.
    maximum_frame_size: Option<usize>,
}

impl TcpContext {
    /// Returns this context with `cluster_id` replaced, leaving every other
    /// option unchanged.
    #[must_use = "this returns the updated context rather than modifying it in place"]
    pub fn cluster_id(self, cluster_id: Option<String>) -> Self {
        Self { cluster_id, ..self }
    }

    /// Returns this context with `maximum_frame_size` replaced, leaving every
    /// other option unchanged.
    #[must_use = "this returns the updated context rather than modifying it in place"]
    pub fn maximum_frame_size(self, maximum_frame_size: Option<usize>) -> Self {
        Self {
            maximum_frame_size,
            ..self
        }
    }
}
152
153/// A [`Layer`] that injects the [`TcpContext`] into the service [`Context`] state
154#[derive(Clone, Debug, Default)]
155pub struct TcpContextLayer {
156    state: TcpContext,
157}
158
159impl TcpContextLayer {
160    pub fn new(state: TcpContext) -> Self {
161        Self { state }
162    }
163}
164
165impl<S> Layer<S> for TcpContextLayer {
166    type Service = TcpContextService<S>;
167
168    fn layer(&self, inner: S) -> Self::Service {
169        Self::Service {
170            inner,
171            state: self.state.clone(),
172        }
173    }
174}
175
176/// A [`Service`] that requires the [`TcpContext`] as the service [`Context`] state
177#[derive(Clone)]
178pub struct TcpContextService<S> {
179    inner: S,
180    state: TcpContext,
181}
182
183impl<S> Debug for TcpContextService<S> {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct(stringify!(TcpContextService)).finish()
186    }
187}
188
189impl<State, S> Service<State, TcpStream> for TcpContextService<S>
190where
191    S: Service<TcpContext, TcpStream>,
192    S::Error: From<io::Error>,
193    State: Clone + Send + Sync + 'static,
194{
195    type Response = S::Response;
196    type Error = S::Error;
197
198    #[instrument(skip_all, fields(peer = %req.peer_addr()?))]
199    async fn serve(
200        &self,
201        ctx: Context<State>,
202        req: TcpStream,
203    ) -> Result<Self::Response, Self::Error> {
204        let (ctx, _) = ctx.swap_state(self.state.clone());
205
206        self.inner.serve(ctx, req).await
207    }
208}
209
210/// A [`Service`] writing [`Bytes`] into a [`TcpStream`], responding with a length delimited frame of [`Bytes`]
211#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
212pub struct BytesTcpService;
213
214impl Service<TcpStream, Bytes> for BytesTcpService {
215    type Response = Bytes;
216    type Error = Error;
217
218    #[instrument(skip(ctx, req))]
219    async fn serve(
220        &self,
221        mut ctx: Context<TcpStream>,
222        req: Bytes,
223    ) -> Result<Self::Response, Self::Error> {
224        let stream = ctx.state_mut();
225
226        stream.write_all(&req[..]).await?;
227        BYTES_SENT.add(req.len() as u64, &[]);
228
229        let mut size = [0u8; 4];
230        _ = stream.read_exact(&mut size).await?;
231
232        let mut buffer: Vec<u8> = vec![0u8; frame_length(size)];
233        buffer[0..size.len()].copy_from_slice(&size[..]);
234        _ = stream.read_exact(&mut buffer[4..]).await?;
235        BYTES_RECEIVED.add(buffer.len() as u64, &[]);
236
237        Ok(Bytes::from(buffer))
238    }
239}
240
241/// A [`Layer`] receiving [`Bytes`] from a [`TcpStream`]
242#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
243pub struct TcpBytesLayer<State = ()> {
244    _state: PhantomData<State>,
245}
246
247impl<S, State> Layer<S> for TcpBytesLayer<State> {
248    type Service = TcpBytesService<S, State>;
249
250    fn layer(&self, inner: S) -> Self::Service {
251        Self::Service {
252            inner,
253            _state: PhantomData,
254        }
255    }
256}
257
/// A [`Service`] receiving [`Bytes`] from a [`TcpStream`], calling an inner [`Service`] and sending [`Bytes`] into the [`TcpStream`]
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TcpBytesService<S, State> {
    /// The service that turns a request frame into a response frame.
    inner: S,
    /// Marker for the inner service's state type; stores no data.
    _state: PhantomData<State>,
}

impl<S, State> Debug for TcpBytesService<S, State> {
    /// Writes only the type name, so neither `S` nor `State` is required to
    /// implement [`Debug`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct(stringify!(TcpBytesService)).finish()
    }
}

impl<S, State> TcpBytesService<S, State> {
    /// Whole milliseconds elapsed since `start`, saturating at [`u64::MAX`].
    ///
    /// Returns `0` when [`SystemTime::elapsed`] fails, which happens when the
    /// system clock now reads earlier than `start`.
    ///
    /// NOTE(review): `SystemTime` is not monotonic, so a clock adjustment
    /// during a request distorts the measurement; `std::time::Instant` would
    /// avoid that but requires changing the callers as well.
    fn elapsed_millis(&self, start: SystemTime) -> u64 {
        start.elapsed().map_or(0, |duration| {
            // `as_millis` yields a `u128`; saturate instead of truncating.
            u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
        })
    }
}
278
279impl<S, State> TcpBytesService<S, State>
280where
281    S: Service<State, Bytes, Response = Bytes>,
282    S::Error: From<Error> + From<io::Error> + Debug,
283    State: Clone + Default + Send + Sync + 'static,
284{
285    #[instrument(skip_all)]
286    async fn wait<R>(
287        &self,
288        req: &mut R,
289        maximum_frame_size: Option<usize>,
290    ) -> Result<[u8; 4], S::Error>
291    where
292        R: AsyncReadExt + Unpin,
293    {
294        let mut size = [0u8; 4];
295
296        _ = req
297            .read_exact(&mut size)
298            .await
299            .inspect_err(|err| debug!(?err))?;
300
301        if maximum_frame_size
302            .is_some_and(|maximum_frame_size| maximum_frame_size > frame_length(size))
303        {
304            return Err(Into::into(Error::FrameTooBig(frame_length(size))));
305        } else {
306            Ok(size)
307        }
308    }
309
310    #[instrument(skip_all)]
311    async fn read<R>(&self, req: &mut R, size: [u8; 4]) -> Result<Bytes, S::Error>
312    where
313        R: AsyncReadExt + Unpin,
314    {
315        let mut request: Vec<u8> = vec![0u8; frame_length(size)];
316
317        request[0..size.len()].copy_from_slice(&size[..]);
318
319        _ = req
320            .read_exact(&mut request[4..])
321            .await
322            .inspect_err(|err| error!(?err))?;
323        BYTES_RECEIVED.add(request.len() as u64, &[]);
324
325        Ok(Bytes::from(request))
326    }
327
328    #[instrument(skip_all)]
329    async fn process(
330        &self,
331        attributes: &[KeyValue],
332        ctx: Context<TcpContext>,
333        request: Bytes,
334    ) -> Result<Bytes, S::Error> {
335        REQUEST_SIZE.record(request.len() as u64, attributes);
336
337        let (ctx, _) = ctx.swap_state(State::default());
338        let request_start = SystemTime::now();
339
340        self.inner
341            .serve(ctx, request)
342            .await
343            .inspect_err(|err| error!(?err))
344            .inspect(|response| {
345                RESPONSE_SIZE.record(response.len() as u64, attributes);
346
347                let elapsed_millis = self.elapsed_millis(request_start);
348
349                REQUEST_DURATION.record(elapsed_millis, attributes);
350            })
351    }
352
353    #[instrument(skip_all)]
354    async fn write<W>(&self, req: &mut W, frame: Bytes) -> Result<(), S::Error>
355    where
356        W: AsyncWriteExt + Unpin,
357    {
358        let mut w = BufWriter::new(req);
359        w.write_all(&frame).await.inspect_err(|err| error!(?err))?;
360        BYTES_SENT.add(frame.len() as u64, &[]);
361        w.flush().await.map_err(Into::into)
362    }
363
364    #[instrument(skip_all, fields(id = nanoid!()))]
365    async fn req<R>(
366        &self,
367        req: &mut R,
368        maximum_frame_size: Option<usize>,
369        attributes: &[KeyValue],
370        ctx: Context<TcpContext>,
371    ) -> Result<(), S::Error>
372    where
373        R: AsyncReadExt + AsyncWriteExt + Unpin,
374    {
375        let size = self.wait(req, maximum_frame_size).await?;
376        let request = self.read(req, size).await?;
377        let response = self.process(attributes, ctx, request).await?;
378        self.write(req, response).await
379    }
380}
381
382impl<S, State, Stream> Service<TcpContext, Stream> for TcpBytesService<S, State>
383where
384    S: Service<State, Bytes, Response = Bytes>,
385    S::Error: From<Error> + From<io::Error> + Debug,
386    State: Clone + Default + Send + Sync + 'static,
387    Stream: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync + 'static,
388{
389    type Response = ();
390
391    type Error = S::Error;
392
393    #[instrument(skip(ctx, req))]
394    async fn serve(
395        &self,
396        ctx: Context<TcpContext>,
397        mut req: Stream,
398    ) -> Result<Self::Response, Self::Error> {
399        let attributes = {
400            let state = ctx.state();
401
402            let mut attributes = vec![];
403
404            if let Some(cluster_id) = state.cluster_id.clone() {
405                attributes.push(KeyValue::new("cluster_id", cluster_id))
406            }
407
408            attributes
409        };
410
411        let maximum_frame_size = ctx.state().maximum_frame_size;
412
413        loop {
414            let ctx = ctx.clone();
415            let attributes = attributes.clone();
416
417            self.req(&mut req, maximum_frame_size, &attributes[..], ctx)
418                .await?
419        }
420    }
421}
422
423/// A [`Layer`] that handles and responds with [`Bytes`]
424#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
425pub struct BytesLayer;
426
427impl<S> Layer<S> for BytesLayer {
428    type Service = BytesService<S>;
429
430    fn layer(&self, inner: S) -> Self::Service {
431        Self::Service { inner }
432    }
433}
434
/// A [`Service`] that handles and responds with [`Bytes`]
///
/// It forwards each request to the inner service unchanged, logging the
/// request and response bytes at debug level.
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BytesService<S> {
    /// The wrapped service.
    inner: S,
}

impl<S> Debug for BytesService<S> {
    /// Writes only the type name, so `S` is not required to implement [`Debug`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BytesService").finish()
    }
}
446
447impl<S, State> Service<State, Bytes> for BytesService<S>
448where
449    S: Service<State, Bytes, Response = Bytes>,
450    State: Clone + Send + Sync + 'static,
451{
452    type Response = Bytes;
453    type Error = S::Error;
454
455    #[instrument(skip_all)]
456    async fn serve(&self, ctx: Context<State>, req: Bytes) -> Result<Self::Response, Self::Error> {
457        debug!(req = ?&req[..]);
458        self.inner
459            .serve(ctx, req)
460            .await
461            .inspect(|response| debug!(response = ?&response[..]))
462    }
463}