1use std::{
103 collections::BTreeMap,
104 error, fmt, io,
105 sync::{Arc, LazyLock},
106 time::SystemTime,
107};
108
109use bytes::Bytes;
110use deadpool::managed::{self, BuildError, Object, PoolError};
111use opentelemetry::{
112 InstrumentationScope, KeyValue, global,
113 metrics::{Counter, Histogram, Meter},
114};
115use opentelemetry_semantic_conventions::SCHEMA_URL;
116use rama::{Context, Layer, Service};
117use tansu_sans_io::{ApiKey, ApiVersionsRequest, Body, Frame, Header, Request};
118use tansu_service::{FrameBytesLayer, FrameBytesService, host_port};
119use tokio::{
120 io::{AsyncReadExt as _, AsyncWriteExt as _},
121 net::TcpStream,
122};
123use tracing::{Instrument, Level, debug, error, span};
124use tracing_subscriber::filter::ParseError;
125use url::Url;
126
127#[derive(thiserror::Error, Clone, Debug)]
129pub enum Error {
130 DeadPoolBuild(#[from] BuildError),
131 Io(Arc<io::Error>),
132 Message(String),
133 ParseFilter(Arc<ParseError>),
134 ParseUrl(#[from] url::ParseError),
135 Pool(Arc<Box<dyn error::Error + Send + Sync>>),
136 Protocol(#[from] tansu_sans_io::Error),
137 Service(#[from] tansu_service::Error),
138 UnknownApiKey(i16),
139 UnknownHost(Url),
140}
141
142impl fmt::Display for Error {
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 write!(f, "{self:?}")
145 }
146}
147
148impl<E> From<PoolError<E>> for Error
149where
150 E: error::Error + Send + Sync + 'static,
151{
152 fn from(value: PoolError<E>) -> Self {
153 Self::Pool(Arc::new(Box::new(value)))
154 }
155}
156
157impl From<io::Error> for Error {
158 fn from(value: io::Error) -> Self {
159 Self::Io(Arc::new(value))
160 }
161}
162
163impl From<ParseError> for Error {
164 fn from(value: ParseError) -> Self {
165 Self::ParseFilter(Arc::new(value))
166 }
167}
168
169pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
170 global::meter_with_scope(
171 InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
172 .with_version(env!("CARGO_PKG_VERSION"))
173 .with_schema_url(SCHEMA_URL)
174 .build(),
175 )
176});
177
178#[derive(Debug)]
180pub struct Connection {
181 stream: TcpStream,
182 correlation_id: i32,
183}
184
185#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
187pub struct ConnectionManager {
188 broker: Url,
189 client_id: Option<String>,
190 versions: BTreeMap<i16, i16>,
191}
192
193impl ConnectionManager {
194 pub fn builder(broker: Url) -> Builder {
196 Builder::broker(broker)
197 }
198
199 pub fn client_id(&self) -> Option<String> {
201 self.client_id.clone()
202 }
203
204 pub fn api_version(&self, api_key: i16) -> Result<i16, Error> {
206 self.versions
207 .get(&api_key)
208 .copied()
209 .ok_or(Error::UnknownApiKey(api_key))
210 }
211}
212
213impl managed::Manager for ConnectionManager {
214 type Type = Connection;
215 type Error = Error;
216
217 async fn create(&self) -> Result<Self::Type, Self::Error> {
218 debug!(%self.broker);
219
220 let attributes = [KeyValue::new("broker", self.broker.to_string())];
221 let start = SystemTime::now();
222
223 TcpStream::connect(host_port(self.broker.clone()).await?)
224 .await
225 .inspect(|_| {
226 TCP_CONNECT_DURATION.record(
227 start
228 .elapsed()
229 .map_or(0, |duration| duration.as_millis() as u64),
230 &attributes,
231 )
232 })
233 .inspect_err(|err| {
234 error!(broker = %self.broker, ?err);
235 TCP_CONNECT_ERRORS.add(1, &attributes);
236 })
237 .map(|stream| Connection {
238 stream,
239 correlation_id: 0,
240 })
241 .map_err(Into::into)
242 }
243
244 async fn recycle(
245 &self,
246 obj: &mut Self::Type,
247 metrics: &managed::Metrics,
248 ) -> managed::RecycleResult<Self::Error> {
249 debug!(?obj, ?metrics);
250 Ok(())
251 }
252}
253
254pub type Pool = managed::Pool<ConnectionManager>;
256
257#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
259pub struct Builder {
260 broker: Url,
261 client_id: Option<String>,
262}
263
264impl Builder {
265 pub fn broker(broker: Url) -> Self {
267 Self {
268 broker,
269 client_id: None,
270 }
271 }
272
273 pub fn client_id(self, client_id: Option<String>) -> Self {
275 Self { client_id, ..self }
276 }
277
278 async fn bootstrap(&self) -> Result<BTreeMap<i16, i16>, Error> {
280 let versions = BTreeMap::from([(ApiVersionsRequest::KEY, 0)]);
283
284 let req = ApiVersionsRequest::default()
285 .client_software_name(Some(env!("CARGO_PKG_NAME").into()))
286 .client_software_version(Some(env!("CARGO_PKG_VERSION").into()));
287
288 let client = Pool::builder(ConnectionManager {
289 broker: self.broker.clone(),
290 client_id: self.client_id.clone(),
291 versions,
292 })
293 .build()
294 .map(Client::new)?;
295
296 client.call(req).await.map(|response| {
297 response
298 .api_keys
299 .unwrap_or_default()
300 .into_iter()
301 .map(|api| (api.api_key, api.max_version))
302 .collect()
303 })
304 }
305
306 pub async fn build(self) -> Result<Pool, Error> {
308 self.bootstrap().await.and_then(|versions| {
309 Pool::builder(ConnectionManager {
310 broker: self.broker,
311 client_id: self.client_id,
312 versions,
313 })
314 .build()
315 .map_err(Into::into)
316 })
317 }
318}
319
320#[derive(Clone, Debug)]
322pub struct FramePoolLayer {
323 pool: Pool,
324}
325
326impl FramePoolLayer {
327 pub fn new(pool: Pool) -> Self {
328 Self { pool }
329 }
330}
331
332impl<S> Layer<S> for FramePoolLayer {
333 type Service = FramePoolService<S>;
334
335 fn layer(&self, inner: S) -> Self::Service {
336 FramePoolService {
337 pool: self.pool.clone(),
338 inner,
339 }
340 }
341}
342
343#[derive(Clone, Debug)]
345pub struct FramePoolService<S> {
346 pool: Pool,
347 inner: S,
348}
349
350impl<State, S> Service<State, Frame> for FramePoolService<S>
351where
352 S: Service<Pool, Frame, Response = Frame>,
353 State: Send + Sync + 'static,
354{
355 type Response = Frame;
356 type Error = S::Error;
357
358 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
359 let (ctx, _) = ctx.swap_state(self.pool.clone());
360 self.inner.serve(ctx, req).await
361 }
362}
363
364#[derive(Clone, Debug)]
366pub struct RequestPoolLayer {
367 pool: Pool,
368}
369
370impl RequestPoolLayer {
371 pub fn new(pool: Pool) -> Self {
372 Self { pool }
373 }
374}
375
376impl<S> Layer<S> for RequestPoolLayer {
377 type Service = RequestPoolService<S>;
378
379 fn layer(&self, inner: S) -> Self::Service {
380 RequestPoolService {
381 pool: self.pool.clone(),
382 inner,
383 }
384 }
385}
386
387#[derive(Clone, Debug)]
389pub struct RequestPoolService<S> {
390 pool: Pool,
391 inner: S,
392}
393
394impl<State, S, Q> Service<State, Q> for RequestPoolService<S>
395where
396 Q: Request,
397 S: Service<Pool, Q>,
398 State: Send + Sync + 'static,
399{
400 type Response = S::Response;
401 type Error = S::Error;
402
403 async fn serve(&self, ctx: Context<State>, req: Q) -> Result<Self::Response, Self::Error> {
405 let (ctx, _) = ctx.swap_state(self.pool.clone());
406 self.inner.serve(ctx, req).await
407 }
408}
409
410#[derive(Clone, Debug)]
412pub struct Client {
413 service:
414 RequestPoolService<RequestConnectionService<FrameBytesService<BytesConnectionService>>>,
415}
416
417impl Client {
418 pub fn new(pool: Pool) -> Self {
420 let service = (
421 RequestPoolLayer::new(pool),
422 RequestConnectionLayer,
423 FrameBytesLayer,
424 )
425 .into_layer(BytesConnectionService);
426
427 Self { service }
428 }
429
430 pub async fn call<Q>(&self, req: Q) -> Result<Q::Response, Error>
432 where
433 Q: Request,
434 Error: From<<<Q as Request>::Response as TryFrom<Body>>::Error>,
435 {
436 self.service.serve(Context::default(), req).await
437 }
438}
439
440#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
442pub struct FrameConnectionLayer;
443
444impl<S> Layer<S> for FrameConnectionLayer {
445 type Service = FrameConnectionService<S>;
446
447 fn layer(&self, inner: S) -> Self::Service {
448 Self::Service { inner }
449 }
450}
451
452#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
454pub struct FrameConnectionService<S> {
455 inner: S,
456}
457
458impl<S> Service<Pool, Frame> for FrameConnectionService<S>
459where
460 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
461 S::Error: From<Error> + From<PoolError<Error>> + From<tansu_sans_io::Error>,
462{
463 type Response = Frame;
464 type Error = S::Error;
465
466 async fn serve(&self, ctx: Context<Pool>, req: Frame) -> Result<Self::Response, Self::Error> {
467 debug!(?ctx, ?req);
468
469 let api_key = req.api_key()?;
470 let api_version = req.api_version()?;
471 let client_id = req
472 .client_id()
473 .map(|client_id| client_id.map(|client_id| client_id.to_string()))?;
474
475 let connection = ctx.state().get().await?;
476 let correlation_id = connection.correlation_id;
477
478 let frame = Frame {
479 size: 0,
480 header: Header::Request {
481 api_key,
482 api_version,
483 correlation_id,
484 client_id,
485 },
486 body: req.body,
487 };
488
489 let (ctx, _) = ctx.swap_state(connection);
490
491 self.inner.serve(ctx, frame).await
492 }
493}
494
495#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
497pub struct RequestConnectionLayer;
498
499impl<S> Layer<S> for RequestConnectionLayer {
500 type Service = RequestConnectionService<S>;
501
502 fn layer(&self, inner: S) -> Self::Service {
503 Self::Service { inner }
504 }
505}
506
507#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
511pub struct RequestConnectionService<S> {
512 inner: S,
513}
514
515impl<Q, S> Service<Pool, Q> for RequestConnectionService<S>
516where
517 Q: Request,
518 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
519 S::Error: From<Error>
520 + From<PoolError<Error>>
521 + From<tansu_sans_io::Error>
522 + From<<Q::Response as TryFrom<Body>>::Error>,
523{
524 type Response = Q::Response;
525 type Error = S::Error;
526
527 async fn serve(&self, ctx: Context<Pool>, req: Q) -> Result<Self::Response, Self::Error> {
528 debug!(?ctx, ?req);
529 let pool = ctx.state();
530 let api_key = Q::KEY;
531 let api_version = pool.manager().api_version(api_key)?;
532 let client_id = pool.manager().client_id();
533 let connection = pool.get().await?;
534 let correlation_id = connection.correlation_id;
535
536 let frame = Frame {
537 size: 0,
538 header: Header::Request {
539 api_key,
540 api_version,
541 correlation_id,
542 client_id,
543 },
544 body: req.into(),
545 };
546
547 let (ctx, _) = ctx.swap_state(connection);
548
549 let frame = self.inner.serve(ctx, frame).await?;
550
551 Q::Response::try_from(frame.body).map_err(Into::into)
552 }
553}
554
555#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
557pub struct BytesConnectionService;
558
559impl BytesConnectionService {
560 async fn write(
561 &self,
562 stream: &mut TcpStream,
563 frame: Bytes,
564 attributes: &[KeyValue],
565 ) -> Result<(), Error> {
566 debug!(?frame);
567
568 let start = SystemTime::now();
569
570 stream
571 .write_all(&frame[..])
572 .await
573 .inspect(|_| {
574 TCP_SEND_DURATION.record(
575 start
576 .elapsed()
577 .map_or(0, |duration| duration.as_millis() as u64),
578 attributes,
579 );
580
581 TCP_BYTES_SENT.add(frame.len() as u64, attributes);
582 })
583 .inspect_err(|_| {
584 TCP_SEND_ERRORS.add(1, attributes);
585 })
586 .map_err(Into::into)
587 }
588
589 async fn read(&self, stream: &mut TcpStream, attributes: &[KeyValue]) -> Result<Bytes, Error> {
590 let start = SystemTime::now();
591
592 let mut size = [0u8; 4];
593 _ = stream.read_exact(&mut size).await?;
594
595 let mut buffer: Vec<u8> = vec![0u8; frame_length(size)];
596 buffer[0..size.len()].copy_from_slice(&size[..]);
597 _ = stream
598 .read_exact(&mut buffer[4..])
599 .await
600 .inspect(|_| {
601 TCP_RECEIVE_DURATION.record(
602 start
603 .elapsed()
604 .map_or(0, |duration| duration.as_millis() as u64),
605 attributes,
606 );
607
608 TCP_BYTES_RECEIVED.add(buffer.len() as u64, attributes);
609 })
610 .inspect_err(|_| {
611 TCP_RECEIVE_ERRORS.add(1, attributes);
612 })?;
613
614 Ok(Bytes::from(buffer)).inspect(|frame| debug!(?frame))
615 }
616}
617
618impl Service<Object<ConnectionManager>, Bytes> for BytesConnectionService {
619 type Response = Bytes;
620 type Error = Error;
621
622 async fn serve(
623 &self,
624 mut ctx: Context<Object<ConnectionManager>>,
625 req: Bytes,
626 ) -> Result<Self::Response, Self::Error> {
627 let c = ctx.state_mut();
628
629 let local = c.stream.local_addr()?;
630 let peer = c.stream.peer_addr()?;
631
632 let attributes = [
633 KeyValue::new("correlation_id", c.correlation_id.to_string()),
634 KeyValue::new("peer", peer.to_string()),
635 ];
636
637 let span = span!(Level::DEBUG, "client", local = %local, peer = %peer);
638
639 async move {
640 self.write(&mut c.stream, req, &attributes).await?;
641
642 c.correlation_id += 1;
643
644 self.read(&mut c.stream, &attributes).await
645 }
646 .instrument(span)
647 .await
648 }
649}
650
/// Returns the total length of a frame, including its 4-byte size prefix.
///
/// `encoded` is the big-endian signed 32-bit size that precedes the frame
/// body. A negative size is treated as an empty body, so the result is never
/// smaller than the prefix itself and the addition cannot overflow. Casting a
/// negative size straight to `usize` would instead produce an enormous
/// length.
fn frame_length(encoded: [u8; 4]) -> usize {
    let body = usize::try_from(i32::from_be_bytes(encoded)).unwrap_or(0);

    body.saturating_add(encoded.len())
}
654
655static TCP_CONNECT_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
656 METER
657 .u64_histogram("tcp_connect_duration")
658 .with_unit("ms")
659 .with_description("The TCP connect latencies in milliseconds")
660 .build()
661});
662
663static TCP_CONNECT_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
664 METER
665 .u64_counter("tcp_connect_errors")
666 .with_description("TCP connect errors")
667 .build()
668});
669
670static TCP_SEND_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
671 METER
672 .u64_histogram("tcp_send_duration")
673 .with_unit("ms")
674 .with_description("The TCP send latencies in milliseconds")
675 .build()
676});
677
678static TCP_SEND_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
679 METER
680 .u64_counter("tcp_send_errors")
681 .with_description("TCP send errors")
682 .build()
683});
684
685static TCP_RECEIVE_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
686 METER
687 .u64_histogram("tcp_receive_duration")
688 .with_unit("ms")
689 .with_description("The TCP receive latencies in milliseconds")
690 .build()
691});
692
693static TCP_RECEIVE_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
694 METER
695 .u64_counter("tcp_receive_errors")
696 .with_description("TCP receive errors")
697 .build()
698});
699
700static TCP_BYTES_SENT: LazyLock<Counter<u64>> = LazyLock::new(|| {
701 METER
702 .u64_counter("tcp_bytes_sent")
703 .with_description("TCP bytes sent")
704 .build()
705});
706
707static TCP_BYTES_RECEIVED: LazyLock<Counter<u64>> = LazyLock::new(|| {
708 METER
709 .u64_counter("tcp_bytes_received")
710 .with_description("TCP bytes received")
711 .build()
712});
713
#[cfg(test)]
mod tests {
    use std::{fs::File, thread};

    use tansu_sans_io::{MetadataRequest, MetadataResponse};
    use tansu_service::{
        BytesFrameLayer, FrameRouteService, RequestLayer, ResponseService, TcpBytesLayer,
        TcpContextLayer, TcpListenerLayer,
    };
    use tokio::{net::TcpListener, task::JoinSet};
    use tokio_util::sync::CancellationToken;
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Installs a subscriber that logs this crate at debug level to a file
    /// named after the current thread, for as long as the guard is alive.
    fn init_tracing() -> Result<DefaultGuard, Error> {
        let filter = EnvFilter::from_default_env()
            .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?);

        let current = thread::current();
        let name = current
            .name()
            .ok_or_else(|| Error::Message(String::from("unnamed thread")))?;

        let log = File::create(format!("../logs/{}/{name}.log", env!("CARGO_PKG_NAME"),))?;

        let subscriber = tracing_subscriber::fmt()
            .with_level(true)
            .with_line_number(true)
            .with_thread_names(false)
            .with_env_filter(filter)
            .with_writer(Arc::new(log))
            .finish();

        Ok(tracing::subscriber::set_default(subscriber))
    }

    /// Serves a canned metadata response on `listener` until `cancellation`
    /// fires.
    async fn server(cancellation: CancellationToken, listener: TcpListener) -> Result<(), Error> {
        let layers = (
            TcpListenerLayer::new(cancellation),
            TcpContextLayer::default(),
            TcpBytesLayer::default(),
            BytesFrameLayer,
        );

        let metadata = RequestLayer::<MetadataRequest>::new().into_layer(ResponseService::new(
            |_ctx: Context<()>, _req: MetadataRequest| {
                Ok::<_, Error>(
                    MetadataResponse::default()
                        .brokers(Some([].into()))
                        .topics(Some([].into()))
                        .cluster_id(Some("abc".into()))
                        .controller_id(Some(111))
                        .throttle_time_ms(Some(0))
                        .cluster_authorized_operations(Some(-1)),
                )
            },
        ));

        let routes = FrameRouteService::builder()
            .with_service(metadata)
            .and_then(|builder| builder.build())?;

        layers
            .into_layer(routes)
            .serve(Context::default(), listener)
            .await
            .map_err(Into::into)
    }

    /// Round-trips a metadata request through a real TCP listener.
    #[tokio::test]
    async fn tcp_client_server() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let cancellation = CancellationToken::new();

        // Port 0 lets the operating system choose a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let local_addr = listener.local_addr()?;

        let mut join = JoinSet::new();
        let _server = join.spawn(server(cancellation.clone(), listener));

        let broker = Url::parse(&format!("tcp://{local_addr}")).inspect(|url| debug!(%url))?;

        let pool = ConnectionManager::builder(broker)
            .client_id(Some(env!("CARGO_PKG_NAME").into()))
            .build()
            .await
            .inspect(|pool| debug!(?pool))?;

        let origin = (
            RequestPoolLayer::new(pool),
            RequestConnectionLayer,
            FrameBytesLayer,
        )
            .into_layer(BytesConnectionService);

        let request = MetadataRequest::default()
            .topics(Some([].into()))
            .allow_auto_topic_creation(Some(false))
            .include_cluster_authorized_operations(Some(false))
            .include_topic_authorized_operations(Some(false));

        let response = origin.serve(Context::default(), request).await?;

        assert_eq!(Some("abc"), response.cluster_id.as_deref());
        assert_eq!(Some(111), response.controller_id);

        cancellation.cancel();

        let joined = join.join_all().await;
        debug!(?joined);

        Ok(())
    }
}