1use std::{
16 fmt::{self, Debug},
17 marker::PhantomData,
18 sync::{Arc, Mutex},
19};
20
21use bytes::{BufMut as _, Bytes, BytesMut};
22use indicatif::ProgressBar;
23use opentelemetry::KeyValue;
24use rama::{Context, Layer, Service, context::Extensions, matcher::Matcher, service::BoxService};
25use rsasl::config::SASLConfig;
26use tansu_auth::Authentication;
27use tansu_sans_io::{
28 ApiKey, ApiVersionsRequest, Body, Frame, Header, Request, Response, RootMessageMeta,
29 SaslAuthenticateRequest, SaslAuthenticateResponse, SaslHandshakeRequest,
30};
31use tokio::task::spawn_blocking;
32use tracing::{debug, error, instrument};
33
34use crate::{API_ERRORS, API_REQUESTS};
35
36#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
38pub struct RequestApiKeyMatcher(pub i16);
39
40impl<State, Q> Matcher<State, Q> for RequestApiKeyMatcher
41where
42 Q: Request,
43 State: Clone + Debug,
44{
45 fn matches(&self, ext: Option<&mut Extensions>, ctx: &Context<State>, req: &Q) -> bool {
46 debug!(?ext, ?ctx, ?req);
47 Q::KEY == self.0
48 }
49}
50
51#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
53pub struct FrameApiKeyMatcher(pub i16);
54
55impl<State> Matcher<State, Frame> for FrameApiKeyMatcher
56where
57 State: Clone + Debug,
58{
59 fn matches(&self, ext: Option<&mut Extensions>, ctx: &Context<State>, req: &Frame) -> bool {
60 let _ = (ext, ctx);
61 req.api_key().is_ok_and(|api_key| api_key == self.0)
62 }
63}
64
/// [`Layer`] that wraps an inner service in a [`RequestService`] for requests of type `Q`.
///
/// `Q` is a type-level marker only; no request value is stored. The derived traits add the
/// matching bound on `Q` (for example, `Clone` requires `Q: Clone`).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RequestLayer<Q> {
    request: PhantomData<Q>,
}

impl<Q> RequestLayer<Q> {
    /// Creates the layer without requiring `Q: Default`.
    pub fn new() -> Self {
        RequestLayer {
            request: PhantomData,
        }
    }
}
78
79impl<S, Q> Layer<S> for RequestLayer<Q> {
80 type Service = RequestService<S, Q>;
81
82 fn layer(&self, inner: S) -> Self::Service {
83 Self::Service {
84 inner,
85 request: PhantomData,
86 }
87 }
88}
89
/// Pass-through service for requests of type `Q`.
///
/// Each request, and each successful response, is logged at debug level before the inner
/// result is returned unchanged.
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RequestService<S, Q> {
    inner: S,
    request: PhantomData<Q>,
}

/// Prints only the type name, so neither `S` nor `Q` needs to implement `Debug`.
impl<S, Q> Debug for RequestService<S, Q> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RequestService")
    }
}
102
103impl<State, S, Q> Service<State, Q> for RequestService<S, Q>
104where
105 S: Service<State, Q>,
106 Q: Request,
107 S::Error: From<<Q as TryFrom<Body>>::Error> + From<<S as Service<State, Q>>::Error>,
108 S::Response: Response,
109 Body: From<<S as Service<State, Q>>::Response>,
110 State: Send + Sync + 'static,
111{
112 type Response = S::Response;
113 type Error = S::Error;
114
115 #[instrument(skip(ctx, req))]
116 async fn serve(&self, ctx: Context<State>, req: Q) -> Result<Self::Response, Self::Error> {
117 debug!(?req);
118 self.inner
119 .serve(ctx, req)
120 .await
121 .inspect(|response| debug!(?response))
122 }
123}
124
/// Forwards the API key of the request type `Q`, so a `RequestService` reports the same
/// `KEY` as the requests it serves.
impl<S, Q> ApiKey for RequestService<S, Q>
where
    Q: Request,
{
    const KEY: i16 = Q::KEY;
}
131
/// [`Layer`] that wraps an inner service for requests of type `Q` in a
/// [`FrameRequestService`], so that it can be driven with [`Frame`]s.
///
/// `Q` is a type-level marker only. The derived traits add the matching bound on `Q`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct FrameRequestLayer<Q> {
    request: PhantomData<Q>,
}

impl<Q> FrameRequestLayer<Q> {
    /// Creates the layer without requiring `Q: Default`.
    pub fn new() -> Self {
        FrameRequestLayer {
            request: PhantomData,
        }
    }
}
145
146impl<S, Q> Layer<S> for FrameRequestLayer<Q> {
147 type Service = FrameRequestService<S, Q>;
148
149 fn layer(&self, inner: S) -> Self::Service {
150 Self::Service {
151 inner,
152 request: PhantomData,
153 }
154 }
155}
156
/// Adapts a service for typed requests `Q` so it serves [`Frame`]s.
///
/// The frame body is decoded into `Q`, and the typed response is wrapped in a response frame
/// that echoes the request's correlation id.
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct FrameRequestService<S, Q> {
    inner: S,
    request: PhantomData<Q>,
}

/// Prints only the type name, so neither `S` nor `Q` needs to implement `Debug`.
impl<S, Q> Debug for FrameRequestService<S, Q> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FrameRequestService")
    }
}
169
170impl<S, Q, State> Service<State, Frame> for FrameRequestService<S, Q>
171where
172 S: Service<State, Q>,
173 S::Response: Response,
174 S::Error: From<tansu_sans_io::Error>,
175 Q: Request + TryFrom<Body>,
176 <Q as TryFrom<Body>>::Error: Into<S::Error>,
177 State: Send + Sync + 'static,
178{
179 type Response = Frame;
180 type Error = S::Error;
181
182 #[instrument(skip(ctx, req))]
183 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
184 let correlation_id = req.correlation_id()?;
185
186 let req = Q::try_from(req.body).map_err(Into::into)?;
187
188 self.inner.serve(ctx, req).await.map(|response| Frame {
189 size: 0,
190 header: Header::Response { correlation_id },
191 body: response.into(),
192 })
193 }
194}
195
196impl<S, Q, State> Matcher<State, Frame> for FrameRequestService<S, Q>
197where
198 S: Clone + Send + Sync + 'static,
199 Q: Request,
200 State: Clone + Debug,
201{
202 fn matches(&self, ext: Option<&mut Extensions>, ctx: &Context<State>, req: &Frame) -> bool {
203 debug!(?ext, ?ctx, ?req);
204 req.api_key().is_ok_and(|api_key| api_key == Q::KEY)
205 }
206}
207
208#[derive(Clone, Debug, Default)]
210pub struct BytesFrameLayer {
211 sasl_config: Option<Arc<SASLConfig>>,
212}
213
214impl BytesFrameLayer {
215 pub fn with_sasl_config(self, sasl_config: Option<Arc<SASLConfig>>) -> Self {
216 Self { sasl_config }
217 }
218}
219
220impl<S> Layer<S> for BytesFrameLayer {
221 type Service = BytesFrameService<S>;
222
223 fn layer(&self, inner: S) -> Self::Service {
224 Self::Service {
225 inner,
226 af: self
227 .sasl_config
228 .clone()
229 .map(|sasl_config| AuthenticationFrame {
230 authentication: Authentication::server(sasl_config),
231 v0: Arc::new(Mutex::new(None)),
232 }),
233 }
234 }
235}
236
237#[derive(Clone, Default)]
238struct AuthenticationFrame {
239 authentication: Authentication,
240 v0: Arc<Mutex<Option<bool>>>,
241}
242
243impl AuthenticationFrame {
244 fn is_authenticated(&self) -> bool {
245 self.authentication.is_authenticated()
246 }
247}
248
249#[derive(Clone, Default)]
251pub struct BytesFrameService<S> {
252 inner: S,
253 af: Option<AuthenticationFrame>,
254}
255
256impl<S> Debug for BytesFrameService<S> {
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 f.debug_struct(stringify!(BytesFrameService)).finish()
259 }
260}
261
262impl<S> BytesFrameService<S> {
263 fn is_authenticated(&self, api_key: i16) -> bool {
264 self.af.as_ref().is_none_or(|af| {
265 af.authentication.is_authenticated()
266 || api_key == SaslHandshakeRequest::KEY
267 || api_key == SaslAuthenticateRequest::KEY
268 || api_key == ApiVersionsRequest::KEY
269 })
270 }
271}
272
/// Byte-level entry point.
///
/// It decodes a request buffer into a [`Frame`], applies the SASL authentication gate,
/// serves the frame with the inner service, and encodes the reply back into bytes.
impl<S, State> Service<State, Bytes> for BytesFrameService<S>
where
    S: Service<State, Frame, Response = Frame>,
    State: Clone + Send + Sync + 'static,
    S::Error: From<tansu_sans_io::Error> + From<tokio::task::JoinError> + Debug,
{
    type Response = Bytes;
    type Error = S::Error;

    /// Serves one request buffer.
    ///
    /// Normally `req` is decoded with `Frame::request_from_bytes`. The reply is encoded with
    /// `Frame::response`, and the `API_REQUESTS` / `API_ERRORS` metrics are recorded.
    ///
    /// Once a `SaslHandshakeRequest` with `api_version == 0` has been served (SASL configured),
    /// the service switches to a raw mode:
    /// - `req`, minus its first four bytes, is wrapped in a synthetic v0
    ///   `SaslAuthenticateRequest`;
    /// - the reply is the response's `auth_bytes`, prefixed with their length as a big-endian
    ///   `i32`.
    ///
    /// Raw mode ends once authentication succeeds. It presumably implements Kafka's legacy
    /// (SaslHandshake v0) unframed SASL token exchange — confirm.
    ///
    /// # Errors
    /// Returns `tansu_sans_io::Error::NotAuthenticated` (converted into `S::Error`) for a
    /// request not allowed before authentication. Otherwise it propagates decode, encode,
    /// blocking-task join and inner-service errors.
    ///
    /// # Panics
    /// With SASL configured, this panics via the `assert!` if `ctx` already holds an
    /// `Authentication` extension. That assumes `Context::insert` returns the previous value,
    /// as the `is_none` check implies.
    /// NOTE(review): in raw mode, `req.slice(4..)` panics if `req` is shorter than four bytes.
    /// Presumably the caller guarantees a length prefix — confirm.
    #[instrument(skip(ctx, req))]
    async fn serve(
        &self,
        mut ctx: Context<State>,
        req: Bytes,
    ) -> Result<Self::Response, Self::Error> {
        // Read the raw-mode flag. The lock guard is a temporary dropped at the end of this
        // statement. No SASL config, a poisoned lock, or an unset flag all read as `false`.
        let sasl_handshake_v0 = self
            .af
            .as_ref()
            .and_then(|af| af.v0.lock().ok())
            .inspect(|v0| debug!(?v0))
            .map(|v0| v0.unwrap_or_default())
            .unwrap_or_default();

        debug!(request = ?&req[..], sasl_handshake_v0);

        let req = if sasl_handshake_v0 {
            // Raw mode: the buffer is not a Kafka request frame. Skip the first four bytes
            // (presumably the length prefix — confirm) and treat the rest as SASL auth bytes
            // in a synthetic v0 SaslAuthenticate request.
            Frame {
                size: 0,
                header: Header::Request {
                    api_key: SaslAuthenticateRequest::KEY,
                    api_version: 0,
                    correlation_id: 0,
                    client_id: None,
                },
                body: Body::SaslAuthenticateRequest(
                    SaslAuthenticateRequest::default().auth_bytes(req.slice(4..)),
                ),
            }
        } else {
            // Decode on the blocking pool. The first `?` propagates a JoinError, the second a
            // decode error.
            spawn_blocking(|| Frame::request_from_bytes(req))
                .await?
                .inspect(|request| debug!(?request))?
        };

        let api_key = req.api_key()?;

        // Authentication gate: see `BytesFrameService::is_authenticated`.
        if !self.is_authenticated(api_key) {
            return Err(Into::into(tansu_sans_io::Error::NotAuthenticated));
        }

        let api_version = req.api_version()?;
        let correlation_id = req.correlation_id()?;

        // Optional progress reporting, when a ProgressBar was placed in the context.
        if let Some(pb) = ctx.get::<ProgressBar>() {
            let api_name = req.api_name();

            pb.set_message(format!("{api_name} v{api_version}/{correlation_id}"));
            pb.tick();
        }

        // Metric attributes; only used by the normal (non-raw) branch below.
        let attributes = vec![
            KeyValue::new("api_key", api_key as i64),
            KeyValue::new("api_version", api_version as i64),
        ];

        let Frame { body, .. } = {
            // Expose a clone of the Authentication to inner services through the context.
            if let Some(authentication) = self.af.as_ref().map(|af| af.authentication.clone()) {
                assert!(ctx.insert(authentication).is_none());
            }

            self.inner
                .serve(ctx, req)
                .await
                .inspect(|response| debug!(?response))?
        };

        if sasl_handshake_v0 {
            // Leave raw mode once authenticated. The guard only lives inside this `if let`,
            // so it is never held across an `.await`.
            if let Some(af) = self.af.as_ref()
                && af.is_authenticated()
                && let Ok(mut v0) = af.v0.lock()
                && v0.is_some()
            {
                *v0 = None
            }

            // Raw-mode reply: a 4-byte big-endian length followed by the auth bytes.
            SaslAuthenticateResponse::try_from(body)
                .and_then(|response| {
                    i32::try_from(response.auth_bytes.len())
                        .map_err(Into::into)
                        .map(|size| {
                            let mut frame = BytesMut::new();
                            frame.put(&size.to_be_bytes()[..]);
                            frame.put(response.auth_bytes);
                            Bytes::from(frame)
                        })
                })
                .map_err(Into::into)
        } else {
            // A served v0 SaslHandshake switches later requests on this service (and its
            // clones, which share `v0`) to raw mode.
            // NOTE(review): this happens whatever the handshake response's error code is.
            if let Some(af) = self.af.as_ref()
                && (api_key == SaslHandshakeRequest::KEY && api_version == 0)
                && let Ok(mut v0) = af.v0.lock()
            {
                *v0 = Some(true)
            }

            // Encode on the blocking pool, then record request/error metrics.
            spawn_blocking(move || {
                Frame::response(
                    Header::Response { correlation_id },
                    body,
                    api_key,
                    api_version,
                )
            })
            .await?
            .inspect(|response| {
                debug!(response = ?response[..]);
                API_REQUESTS.add(1, &attributes);
            })
            .inspect_err(|err| {
                error!(api_key, api_version, ?err);
                API_ERRORS.add(1, &attributes);
            })
            .map_err(Into::into)
        }
    }
}
415
416#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
418pub struct FrameBytesLayer;
419
420impl<S> Layer<S> for FrameBytesLayer {
421 type Service = FrameBytesService<S>;
422
423 fn layer(&self, inner: S) -> Self::Service {
424 Self::Service { inner }
425 }
426}
427
/// Frame-level service that forwards each request, encoded as bytes, to an inner byte-level
/// service and decodes the reply back into a [`Frame`].
#[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct FrameBytesService<S> {
    inner: S,
}

/// Prints only the type name, so `S` does not need to implement `Debug`.
impl<S> Debug for FrameBytesService<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FrameBytesService")
    }
}
439
440impl<S, State> Service<State, Frame> for FrameBytesService<S>
441where
442 S: Service<State, Bytes, Response = Bytes>,
443 S::Error: From<tansu_sans_io::Error>,
444 State: Send + Sync + 'static,
445{
446 type Response = Frame;
447 type Error = S::Error;
448
449 #[instrument(skip(ctx, req), fields(api_key = req.api_key()?, api_version = req.api_version()?, correlation_id = req.correlation_id()?))]
450 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
451 debug!(?req);
452
453 let api_key = req.api_key()?;
454 let api_version = req.api_version()?;
455
456 let req = Frame::request(req.header, req.body)?;
457
458 self.inner
459 .serve(ctx, req)
460 .await
461 .and_then(|response| {
462 Frame::response_from_bytes(response, api_key, api_version).map_err(Into::into)
463 })
464 .inspect(|response| debug!(?response))
465 }
466}
467
468#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
470pub struct FrameBodyLayer;
471
472impl<S> Layer<S> for FrameBodyLayer {
473 type Service = FrameBodyService<S>;
474
475 fn layer(&self, inner: S) -> Self::Service {
476 Self::Service { inner }
477 }
478}
479
/// Frame-level service that serves only the [`Body`] of each request frame.
///
/// The reply body is wrapped in a response frame carrying the request's correlation id.
#[derive(
    Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd,
)]
pub struct FrameBodyService<S> {
    inner: S,
}
485
486impl<S, State> Service<State, Frame> for FrameBodyService<S>
487where
488 S: Service<State, Body, Response = Body>,
489 S::Error: From<tansu_sans_io::Error>,
490 State: Send + Sync + 'static,
491{
492 type Response = Frame;
493
494 type Error = S::Error;
495
496 #[instrument(skip_all, fields(api_key = req.api_key()?, api_version = req.api_version()?, correlation_id = req.correlation_id()?))]
497 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
498 let correlation_id = req.correlation_id()?;
499
500 self.inner.serve(ctx, req.body).await.map(|body| Frame {
501 size: 0,
502 header: Header::Response { correlation_id },
503 body,
504 })
505 }
506}
507
/// [`Layer`] producing a [`BodyRequestService`].
///
/// The service converts an untyped [`Body`] into the typed request `Q` and back.
/// `Q` is a type-level marker only. The derived traits add the matching bound on `Q`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BodyRequestLayer<Q> {
    request: PhantomData<Q>,
}

impl<Q> BodyRequestLayer<Q> {
    /// Creates the layer without requiring `Q: Default`.
    pub fn new() -> Self {
        BodyRequestLayer {
            request: PhantomData,
        }
    }
}
521
522impl<S, Q> Layer<S> for BodyRequestLayer<Q> {
523 type Service = BodyRequestService<S, Q>;
524
525 fn layer(&self, inner: S) -> Self::Service {
526 Self::Service {
527 inner,
528 request: PhantomData,
529 }
530 }
531}
532
/// Body-level service that converts each [`Body`] into the typed request `Q`.
///
/// It serves the request with the inner service and converts the typed response back into
/// a [`Body`].
#[derive(
    Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd,
)]
pub struct BodyRequestService<S, Q> {
    inner: S,
    request: PhantomData<Q>,
}
539
/// Forwards the API key of the request type `Q`, so a `BodyRequestService` reports the same
/// `KEY` as the requests it decodes.
impl<S, Q> ApiKey for BodyRequestService<S, Q>
where
    Q: Request,
{
    const KEY: i16 = Q::KEY;
}
546
547impl<S, State, Q> Service<State, Body> for BodyRequestService<S, Q>
548where
549 S: Service<State, Q>,
550 Q: Request,
551 S::Error: From<<Q as TryFrom<Body>>::Error> + From<<S as Service<State, Q>>::Error>,
552 Body: From<<S as Service<State, Q>>::Response>,
553 State: Send + Sync + 'static,
554{
555 type Response = Body;
556 type Error = S::Error;
557
558 #[instrument(skip_all)]
559 async fn serve(&self, ctx: Context<State>, req: Body) -> Result<Self::Response, Self::Error> {
560 let req = Q::try_from(req)?;
561 self.inner.serve(ctx, req).await.map(Body::from)
562 }
563}
564
565#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
567pub struct RequestFrameLayer;
568
569impl<S> Layer<S> for RequestFrameLayer {
570 type Service = RequestFrameService<S>;
571
572 fn layer(&self, inner: S) -> Self::Service {
573 Self::Service { inner }
574 }
575}
576
/// Typed-request service that sends each request as a [`Frame`] to an inner frame-level
/// service and converts the reply body into the request's response type.
#[derive(
    Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd,
)]
pub struct RequestFrameService<S> {
    inner: S,
}
582
583impl<S, State, Q> Service<State, Q> for RequestFrameService<S>
584where
585 Q: Request,
586 S: Service<State, Frame, Response = Frame>,
587 S::Error: From<<<Q as Request>::Response as TryFrom<Body>>::Error>,
588 State: Send + Sync + 'static,
589{
590 type Response = Q::Response;
591 type Error = S::Error;
592
593 #[instrument(skip_all)]
594 async fn serve(&self, ctx: Context<State>, req: Q) -> Result<Self::Response, Self::Error> {
595 debug!(?req);
596
597 let api_key = Q::KEY;
598 let api_version = RootMessageMeta::messages()
599 .requests()
600 .get(&api_key)
601 .map(|message_meta| message_meta.version.valid().end)
602 .unwrap_or_default();
603 let correlation_id = 0;
604 let client_id = Some(env!("CARGO_CRATE_NAME").into());
605
606 let req = Frame {
607 size: 0,
608 header: Header::Request {
609 api_key,
610 api_version,
611 correlation_id,
612 client_id,
613 },
614 body: req.into(),
615 };
616
617 self.inner
618 .serve(ctx, req)
619 .await
620 .and_then(|response| Q::Response::try_from(response.body).map_err(Into::into))
621 .inspect(|response| debug!(?response))
622 }
623}
624
625impl<S, State, Q, E> From<RequestService<S, Q>> for BoxService<State, Body, Body, E>
626where
627 S: Service<State, Q, Error = E>,
628 Q: Request,
629 <S as Service<State, Q>>::Response: Response,
630 E: From<<Q as TryFrom<Body>>::Error> + From<<S as Service<State, Q>>::Error>,
631 Body: From<<S as Service<State, Q>>::Response>,
632 State: Send + Sync + 'static,
633{
634 fn from(value: RequestService<S, Q>) -> Self {
635 BodyRequestLayer::<Q>::new().into_layer(value).boxed()
636 }
637}
638
639impl<S, State, Q, E> From<RequestService<S, Q>> for BoxService<State, Frame, Frame, E>
640where
641 S: Service<State, Q, Error = E>,
642 Q: Request,
643 <S as Service<State, Q>>::Response: Response,
644 E: From<tansu_sans_io::Error>
645 + From<<Q as TryFrom<Body>>::Error>
646 + From<<S as Service<State, Q>>::Error>,
647 Body: From<<S as Service<State, Q>>::Response>,
648 State: Send + Sync + 'static,
649{
650 fn from(value: RequestService<S, Q>) -> Self {
651 (FrameBodyLayer, BodyRequestLayer::<Q>::new())
652 .into_layer(value)
653 .boxed()
654 }
655}
656
657impl<S, State, Q, E> From<BodyRequestService<S, Q>> for BoxService<State, Frame, Frame, E>
658where
659 S: Service<State, Q, Error = E>,
660 Q: Request,
661 E: From<tansu_sans_io::Error>
662 + From<<Q as TryFrom<Body>>::Error>
663 + From<<S as Service<State, Q>>::Error>,
664 Body: From<<S as Service<State, Q>>::Response>,
665 State: Send + Sync + 'static,
666{
667 fn from(value: BodyRequestService<S, Q>) -> Self {
668 FrameBodyLayer.into_layer(value).boxed()
669 }
670}
671
/// A frame-level service backed by a plain synchronous function or closure `F`, which maps
/// a context and a request frame to a response frame or an error.
///
/// `Debug` is implemented by hand and prints only the type name. A derived `Debug` would
/// require `F: Debug`, which closures never implement, so the service would never be
/// debuggable in practice.
#[derive(Clone, Copy, Hash)]
pub struct FrameService<F> {
    response: F,
}

/// Prints only the type name, so `F` does not need to implement `Debug`.
impl<F> Debug for FrameService<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct(stringify!(FrameService)).finish()
    }
}
677
678impl<State, E, F> Service<State, Frame> for FrameService<F>
679where
680 F: Fn(Context<State>, Frame) -> Result<Frame, E> + Clone + Send + Sync + 'static,
681 E: Send + Sync + 'static,
682 State: Send + Sync + 'static,
683{
684 type Response = Frame;
685 type Error = E;
686
687 #[instrument(skip_all)]
688 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
689 (self.response)(ctx, req)
690 }
691}
692
693impl<F> FrameService<F> {
694 pub fn new<State, E>(response: F) -> Self
695 where
696 F: Fn(Context<State>, Frame) -> Result<Frame, E> + Clone,
697 E: Send + Sync + 'static,
698 {
699 Self { response }
700 }
701}
702
/// A typed-request service backed by a plain synchronous function or closure `F`, which maps
/// a context and a request to the request's response type or an error.
#[derive(Clone, Copy, Hash)]
pub struct ResponseService<F> {
    response: F,
}

/// Prints only the type name, so `F` (typically a closure) does not need to implement `Debug`.
impl<F> Debug for ResponseService<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ResponseService")
    }
}
714
715impl<State, Q, E, F> Service<State, Q> for ResponseService<F>
716where
717 F: Fn(Context<State>, Q) -> Result<Q::Response, E> + Clone + Send + Sync + 'static,
718 Q: Request,
719 E: Send + Sync + 'static,
720 State: Send + Sync + 'static,
721{
722 type Response = Q::Response;
723 type Error = E;
724
725 #[instrument(skip(ctx, req))]
726 async fn serve(&self, ctx: Context<State>, req: Q) -> Result<Self::Response, Self::Error> {
727 (self.response)(ctx, req)
728 }
729}
730
731impl<F> ResponseService<F> {
732 pub fn new<State, Q, E>(response: F) -> Self
733 where
734 F: Fn(Context<State>, Q) -> Result<Q::Response, E> + Clone,
735 Q: Request,
736 E: Send + Sync + 'static,
737 {
738 Self { response }
739 }
740}