1use std::{
16 error::{self},
17 fmt::Debug,
18 io,
19 marker::PhantomData,
20 time::SystemTime,
21};
22
23use bytes::Bytes;
24use nanoid::nanoid;
25use opentelemetry::KeyValue;
26use rama::{Context, Layer, Service};
27use tokio::{
28 io::{AsyncReadExt, AsyncWriteExt, BufWriter},
29 net::{TcpListener, TcpStream},
30 task::JoinSet,
31};
32use tokio_util::sync::CancellationToken;
33use tracing::{debug, error, instrument};
34
35use crate::{
36 BYTES_RECEIVED, BYTES_SENT, Error, REQUEST_DURATION, REQUEST_SIZE, RESPONSE_SIZE, frame_length,
37};
38
39#[derive(Clone, Debug, Default)]
41pub struct TcpListenerLayer {
42 cancellation: CancellationToken,
43}
44
45impl TcpListenerLayer {
46 pub fn new(cancellation: CancellationToken) -> Self {
47 Self { cancellation }
48 }
49}
50
51impl<S> Layer<S> for TcpListenerLayer {
52 type Service = TcpListenerService<S>;
53
54 fn layer(&self, inner: S) -> Self::Service {
55 Self::Service {
56 cancellation: self.cancellation.clone(),
57 inner,
58 }
59 }
60}
61
62#[derive(Clone, Default)]
64pub struct TcpListenerService<S> {
65 cancellation: CancellationToken,
66 inner: S,
67}
68
69impl<S> Debug for TcpListenerService<S> {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct(stringify!(TcpListenerService)).finish()
72 }
73}
74
75impl<State, S> Service<State, TcpListener> for TcpListenerService<S>
76where
77 S: Service<State, TcpStream> + Clone,
78 S::Response: Debug,
79 S::Error: error::Error,
80 State: Clone + Send + Sync + 'static,
81{
82 type Response = ();
83 type Error = S::Error;
84
85 #[instrument(skip(ctx, req))]
86 async fn serve(
87 &self,
88 ctx: Context<State>,
89 req: TcpListener,
90 ) -> Result<Self::Response, Self::Error> {
91 let mut set = JoinSet::new();
92
93 loop {
94 tokio::select! {
95 Ok((stream, addr)) = req.accept() => {
96 debug!(?req, ?stream, %addr);
97
98 let service = self.inner.clone();
99 let ctx = ctx.clone();
100
101 let handle = set.spawn(async move {
102 match service.serve(ctx, stream).await {
103 Err(error) => {
104 debug!(%addr, %error);
105 },
106
107 Ok(response) => {
108 debug!(%addr, ?response)
109 }
110 }
111 });
112
113 debug!(?handle);
114 continue;
115 }
116
117 v = set.join_next(), if !set.is_empty() => {
118 debug!(?v);
119 }
120
121 cancelled = self.cancellation.cancelled() => {
122 debug!(?cancelled);
123 break;
124 }
125 }
126 }
127
128 Ok(())
129 }
130}
131
/// Per-connection configuration installed as the context state of the TCP byte services.
///
/// Build it from [`Default`] and the consuming setters below.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TcpContext {
    /// Cluster identifier, attached as the `cluster_id` metric attribute when present.
    cluster_id: Option<String>,
    /// Maximum frame size handed to the frame reader; `None` disables the size check.
    maximum_frame_size: Option<usize>,
}

impl TcpContext {
    /// Returns this context with `cluster_id` replaced.
    ///
    /// The setter consumes `self`. Ignoring the return value discards the update, hence
    /// `#[must_use]`.
    #[must_use]
    pub fn cluster_id(self, cluster_id: Option<String>) -> Self {
        Self { cluster_id, ..self }
    }

    /// Returns this context with `maximum_frame_size` replaced.
    ///
    /// The setter consumes `self`. Ignoring the return value discards the update, hence
    /// `#[must_use]`.
    #[must_use]
    pub fn maximum_frame_size(self, maximum_frame_size: Option<usize>) -> Self {
        Self {
            maximum_frame_size,
            ..self
        }
    }
}
152
153#[derive(Clone, Debug, Default)]
155pub struct TcpContextLayer {
156 state: TcpContext,
157}
158
159impl TcpContextLayer {
160 pub fn new(state: TcpContext) -> Self {
161 Self { state }
162 }
163}
164
165impl<S> Layer<S> for TcpContextLayer {
166 type Service = TcpContextService<S>;
167
168 fn layer(&self, inner: S) -> Self::Service {
169 Self::Service {
170 inner,
171 state: self.state.clone(),
172 }
173 }
174}
175
176#[derive(Clone)]
178pub struct TcpContextService<S> {
179 inner: S,
180 state: TcpContext,
181}
182
183impl<S> Debug for TcpContextService<S> {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct(stringify!(TcpContextService)).finish()
186 }
187}
188
189impl<State, S> Service<State, TcpStream> for TcpContextService<S>
190where
191 S: Service<TcpContext, TcpStream>,
192 S::Error: From<io::Error>,
193 State: Clone + Send + Sync + 'static,
194{
195 type Response = S::Response;
196 type Error = S::Error;
197
198 #[instrument(skip_all, fields(peer = %req.peer_addr()?))]
199 async fn serve(
200 &self,
201 ctx: Context<State>,
202 req: TcpStream,
203 ) -> Result<Self::Response, Self::Error> {
204 let (ctx, _) = ctx.swap_state(self.state.clone());
205
206 self.inner.serve(ctx, req).await
207 }
208}
209
210#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
212pub struct BytesTcpService;
213
214impl Service<TcpStream, Bytes> for BytesTcpService {
215 type Response = Bytes;
216 type Error = Error;
217
218 #[instrument(skip(ctx, req))]
219 async fn serve(
220 &self,
221 mut ctx: Context<TcpStream>,
222 req: Bytes,
223 ) -> Result<Self::Response, Self::Error> {
224 let stream = ctx.state_mut();
225
226 stream.write_all(&req[..]).await?;
227 BYTES_SENT.add(req.len() as u64, &[]);
228
229 let mut size = [0u8; 4];
230 _ = stream.read_exact(&mut size).await?;
231
232 let mut buffer: Vec<u8> = vec![0u8; frame_length(size)];
233 buffer[0..size.len()].copy_from_slice(&size[..]);
234 _ = stream.read_exact(&mut buffer[4..]).await?;
235 BYTES_RECEIVED.add(buffer.len() as u64, &[]);
236
237 Ok(Bytes::from(buffer))
238 }
239}
240
241#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
243pub struct TcpBytesLayer<State = ()> {
244 _state: PhantomData<State>,
245}
246
247impl<S, State> Layer<S> for TcpBytesLayer<State> {
248 type Service = TcpBytesService<S, State>;
249
250 fn layer(&self, inner: S) -> Self::Service {
251 Self::Service {
252 inner,
253 _state: PhantomData,
254 }
255 }
256}
257
/// Service that reads length-prefixed frames from a stream, passes each to `inner`, and
/// writes the response back.
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TcpBytesService<S, State> {
    inner: S,
    _state: PhantomData<State>,
}

impl<S, State> Debug for TcpBytesService<S, State> {
    /// Prints only the type name; the inner service is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TcpBytesService").finish()
    }
}

impl<S, State> TcpBytesService<S, State> {
    /// Whole milliseconds elapsed since `start`, or `0` when the system clock now reads
    /// earlier than `start` (`SystemTime::elapsed` fails in that case).
    fn elapsed_millis(&self, start: SystemTime) -> u64 {
        match start.elapsed() {
            Ok(duration) => duration.as_millis() as u64,
            Err(_) => 0,
        }
    }
}
278
279impl<S, State> TcpBytesService<S, State>
280where
281 S: Service<State, Bytes, Response = Bytes>,
282 S::Error: From<Error> + From<io::Error> + Debug,
283 State: Clone + Default + Send + Sync + 'static,
284{
285 #[instrument(skip_all)]
286 async fn wait<R>(
287 &self,
288 req: &mut R,
289 maximum_frame_size: Option<usize>,
290 ) -> Result<[u8; 4], S::Error>
291 where
292 R: AsyncReadExt + Unpin,
293 {
294 let mut size = [0u8; 4];
295
296 _ = req
297 .read_exact(&mut size)
298 .await
299 .inspect_err(|err| debug!(?err))?;
300
301 if maximum_frame_size
302 .is_some_and(|maximum_frame_size| maximum_frame_size > frame_length(size))
303 {
304 return Err(Into::into(Error::FrameTooBig(frame_length(size))));
305 } else {
306 Ok(size)
307 }
308 }
309
310 #[instrument(skip_all)]
311 async fn read<R>(&self, req: &mut R, size: [u8; 4]) -> Result<Bytes, S::Error>
312 where
313 R: AsyncReadExt + Unpin,
314 {
315 let mut request: Vec<u8> = vec![0u8; frame_length(size)];
316
317 request[0..size.len()].copy_from_slice(&size[..]);
318
319 _ = req
320 .read_exact(&mut request[4..])
321 .await
322 .inspect_err(|err| error!(?err))?;
323 BYTES_RECEIVED.add(request.len() as u64, &[]);
324
325 Ok(Bytes::from(request))
326 }
327
328 #[instrument(skip_all)]
329 async fn process(
330 &self,
331 attributes: &[KeyValue],
332 ctx: Context<TcpContext>,
333 request: Bytes,
334 ) -> Result<Bytes, S::Error> {
335 REQUEST_SIZE.record(request.len() as u64, attributes);
336
337 let (ctx, _) = ctx.swap_state(State::default());
338 let request_start = SystemTime::now();
339
340 self.inner
341 .serve(ctx, request)
342 .await
343 .inspect_err(|err| error!(?err))
344 .inspect(|response| {
345 RESPONSE_SIZE.record(response.len() as u64, attributes);
346
347 let elapsed_millis = self.elapsed_millis(request_start);
348
349 REQUEST_DURATION.record(elapsed_millis, attributes);
350 })
351 }
352
353 #[instrument(skip_all)]
354 async fn write<W>(&self, req: &mut W, frame: Bytes) -> Result<(), S::Error>
355 where
356 W: AsyncWriteExt + Unpin,
357 {
358 let mut w = BufWriter::new(req);
359 w.write_all(&frame).await.inspect_err(|err| error!(?err))?;
360 BYTES_SENT.add(frame.len() as u64, &[]);
361 w.flush().await.map_err(Into::into)
362 }
363
364 #[instrument(skip_all, fields(id = nanoid!()))]
365 async fn req<R>(
366 &self,
367 req: &mut R,
368 maximum_frame_size: Option<usize>,
369 attributes: &[KeyValue],
370 ctx: Context<TcpContext>,
371 ) -> Result<(), S::Error>
372 where
373 R: AsyncReadExt + AsyncWriteExt + Unpin,
374 {
375 let size = self.wait(req, maximum_frame_size).await?;
376 let request = self.read(req, size).await?;
377 let response = self.process(attributes, ctx, request).await?;
378 self.write(req, response).await
379 }
380}
381
382impl<S, State, Stream> Service<TcpContext, Stream> for TcpBytesService<S, State>
383where
384 S: Service<State, Bytes, Response = Bytes>,
385 S::Error: From<Error> + From<io::Error> + Debug,
386 State: Clone + Default + Send + Sync + 'static,
387 Stream: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync + 'static,
388{
389 type Response = ();
390
391 type Error = S::Error;
392
393 #[instrument(skip(ctx, req))]
394 async fn serve(
395 &self,
396 ctx: Context<TcpContext>,
397 mut req: Stream,
398 ) -> Result<Self::Response, Self::Error> {
399 let attributes = {
400 let state = ctx.state();
401
402 let mut attributes = vec![];
403
404 if let Some(cluster_id) = state.cluster_id.clone() {
405 attributes.push(KeyValue::new("cluster_id", cluster_id))
406 }
407
408 attributes
409 };
410
411 let maximum_frame_size = ctx.state().maximum_frame_size;
412
413 loop {
414 let ctx = ctx.clone();
415 let attributes = attributes.clone();
416
417 self.req(&mut req, maximum_frame_size, &attributes[..], ctx)
418 .await?
419 }
420 }
421}
422
423#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
425pub struct BytesLayer;
426
427impl<S> Layer<S> for BytesLayer {
428 type Service = BytesService<S>;
429
430 fn layer(&self, inner: S) -> Self::Service {
431 Self::Service { inner }
432 }
433}
434
/// Logging wrapper that forwards byte requests to `inner` unchanged.
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BytesService<S> {
    inner: S,
}

impl<S> Debug for BytesService<S> {
    /// Prints only the type name; the inner service is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = "BytesService";
        f.debug_struct(name).finish()
    }
}
446
447impl<S, State> Service<State, Bytes> for BytesService<S>
448where
449 S: Service<State, Bytes, Response = Bytes>,
450 State: Clone + Send + Sync + 'static,
451{
452 type Response = Bytes;
453 type Error = S::Error;
454
455 #[instrument(skip_all)]
456 async fn serve(&self, ctx: Context<State>, req: Bytes) -> Result<Self::Response, Self::Error> {
457 debug!(req = ?&req[..]);
458 self.inner
459 .serve(ctx, req)
460 .await
461 .inspect(|response| debug!(response = ?&response[..]))
462 }
463}