1use std::{
103 collections::BTreeMap,
104 error, fmt, io,
105 sync::{Arc, LazyLock},
106 time::SystemTime,
107};
108
109use backoff::{ExponentialBackoff, future::retry};
110use bytes::Bytes;
111use deadpool::managed::{self, BuildError, Object, PoolError};
112use opentelemetry::{
113 InstrumentationScope, KeyValue, global,
114 metrics::{Counter, Gauge, Histogram, Meter},
115};
116use opentelemetry_semantic_conventions::SCHEMA_URL;
117use rama::{Context, Layer, Service};
118use tansu_sans_io::{ApiKey, ApiVersionsRequest, Body, Frame, Header, Request, RootMessageMeta};
119use tansu_service::{FrameBytesLayer, FrameBytesService, host_port};
120use tokio::{
121 io::{AsyncReadExt as _, AsyncWriteExt as _},
122 net::TcpStream,
123 task::JoinError,
124};
125use tracing::{Instrument, Level, debug, span};
126use tracing_subscriber::filter::ParseError;
127use url::Url;
128
/// Errors produced by this client.
///
/// The enum derives `Clone`. Variants without a `#[from]` attribute that wrap
/// a source error hold it behind an [`Arc`] and are built by the hand-written
/// `From` impls in this file. `Display` is implemented by hand as well.
#[derive(thiserror::Error, Clone, Debug)]
pub enum Error {
    /// Building a deadpool pool failed.
    DeadPoolBuild(#[from] BuildError),
    /// An I/O error.
    Io(Arc<io::Error>),
    /// A Tokio task join error.
    Join(Arc<JoinError>),
    /// A free-form error message.
    Message(String),
    /// A `tracing_subscriber` filter directive failed to parse.
    ParseFilter(Arc<ParseError>),
    /// A URL failed to parse.
    ParseUrl(#[from] url::ParseError),
    /// A boxed deadpool `PoolError`, erased to a trait object.
    Pool(Arc<Box<dyn error::Error + Send + Sync>>),
    /// An error from `tansu_sans_io`.
    Protocol(#[from] tansu_sans_io::Error),
    /// An error from `tansu_service`.
    Service(#[from] tansu_service::Error),
    /// No version is known for this API key (see `ConnectionManager::api_version`).
    UnknownApiKey(i16),
    /// Carries a URL; not constructed in this file — presumably raised when the
    /// host of the URL cannot be resolved (TODO confirm against callers).
    UnknownHost(Url),
}
144
145impl fmt::Display for Error {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 write!(f, "{self:?}")
148 }
149}
150
151impl From<JoinError> for Error {
152 fn from(value: JoinError) -> Self {
153 Self::Join(Arc::new(value))
154 }
155}
156
157impl<E> From<PoolError<E>> for Error
158where
159 E: error::Error + Send + Sync + 'static,
160{
161 fn from(value: PoolError<E>) -> Self {
162 Self::Pool(Arc::new(Box::new(value)))
163 }
164}
165
166impl From<io::Error> for Error {
167 fn from(value: io::Error) -> Self {
168 Self::Io(Arc::new(value))
169 }
170}
171
172impl From<ParseError> for Error {
173 fn from(value: ParseError) -> Self {
174 Self::ParseFilter(Arc::new(value))
175 }
176}
177
178pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
179 global::meter_with_scope(
180 InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
181 .with_version(env!("CARGO_PKG_VERSION"))
182 .with_schema_url(SCHEMA_URL)
183 .build(),
184 )
185});
186
/// A pooled broker connection: the TCP stream plus the correlation id that
/// the next request sent on it will carry.
#[derive(Debug)]
pub struct Connection {
    // The socket connected to the broker.
    stream: TcpStream,
    // Starts at 0 when the connection is created and is advanced after each
    // frame is written.
    correlation_id: i32,
}
193
/// The deadpool manager that creates [`Connection`]s to a single broker.
///
/// Besides the broker URL it carries the optional client id and the
/// per-API-key version table used to fill in request headers.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ConnectionManager {
    // URL of the broker to connect to.
    broker: Url,
    // Client id placed in request headers, if any.
    client_id: Option<String>,
    // API key -> API version to use for that key.
    versions: BTreeMap<i16, i16>,
}
201
202impl ConnectionManager {
203 pub fn builder(broker: Url) -> Builder {
205 Builder::broker(broker)
206 }
207
208 pub fn client_id(&self) -> Option<String> {
210 self.client_id.clone()
211 }
212
213 pub fn api_version(&self, api_key: i16) -> Result<i16, Error> {
215 self.versions
216 .get(&api_key)
217 .copied()
218 .ok_or(Error::UnknownApiKey(api_key))
219 }
220}
221
/// Pool integration: how connections to the broker are created and recycled.
impl managed::Manager for ConnectionManager {
    type Type = Connection;
    type Error = Error;

    /// Opens a TCP connection to the broker, retrying failed attempts with the
    /// default exponential backoff policy.
    ///
    /// On success the connect-duration histogram is recorded; every failed
    /// attempt is logged and counted in the connect-error counter. The new
    /// connection starts with correlation id 0.
    async fn create(&self) -> Result<Self::Type, Self::Error> {
        debug!(%self.broker);

        let attributes = [KeyValue::new("broker", self.broker.to_string())];
        // The clock starts before the address lookup and outside the retry
        // loop, so the recorded duration covers the lookup and any failed
        // attempts as well as the successful connect.
        // NOTE(review): `SystemTime` is not monotonic; an `elapsed()` error is
        // reported as 0 ms below.
        let start = SystemTime::now();

        // `host_port` comes from `tansu_service`; presumably it turns the URL
        // into a connectable socket address — TODO confirm.
        let addr = host_port(self.broker.clone()).await?;

        retry(ExponentialBackoff::default(), || async {
            // The `?` inside `Ok(..)` converts the `io::Error` into the error
            // type `retry` expects — presumably marking it as retryable via
            // backoff's `From` impl; confirm against the backoff docs.
            Ok(TcpStream::connect(addr)
                .await
                .inspect(|_| {
                    TCP_CONNECT_DURATION.record(
                        start
                            .elapsed()
                            .map_or(0, |duration| duration.as_millis() as u64),
                        &attributes,
                    )
                })
                .inspect_err(|err| {
                    debug!(broker = %self.broker, ?err, elapsed = start.elapsed().map_or(0, |duration| duration.as_millis() as u64));
                    TCP_CONNECT_ERRORS.add(1, &attributes);
                })
                .map(|stream| Connection {
                    stream,
                    correlation_id: 0,
                })?)
        })
        .await
        .map_err(Into::into)
    }

    /// Logs the connection and pool metrics and always accepts the connection
    /// for reuse; no check is made that the socket is still usable.
    async fn recycle(
        &self,
        obj: &mut Self::Type,
        metrics: &managed::Metrics,
    ) -> managed::RecycleResult<Self::Error> {
        debug!(?obj, ?metrics);

        Ok(())
    }
}
268
/// A deadpool managed pool of broker [`Connection`]s.
pub type Pool = managed::Pool<ConnectionManager>;
271
272fn status_update(pool: &Pool) {
273 let status = pool.status();
274 POOL_AVAILABLE.record(status.available as u64, &[]);
275 POOL_CURRENT_SIZE.record(status.size as u64, &[]);
276 POOL_MAX_SIZE.record(status.max_size as u64, &[]);
277 POOL_WAITING.record(status.waiting as u64, &[]);
278}
279
/// Builder for a connection [`Pool`] to a single broker.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Builder {
    // URL of the broker the pool will connect to.
    broker: Url,
    // Optional client id to place in request headers.
    client_id: Option<String>,
}
286
287impl Builder {
288 pub fn broker(broker: Url) -> Self {
290 Self {
291 broker,
292 client_id: None,
293 }
294 }
295
296 pub fn client_id(self, client_id: Option<String>) -> Self {
298 Self { client_id, ..self }
299 }
300
301 async fn bootstrap(&self) -> Result<BTreeMap<i16, i16>, Error> {
303 let versions = BTreeMap::from([(ApiVersionsRequest::KEY, 0)]);
306
307 let req = ApiVersionsRequest::default()
308 .client_software_name(Some(env!("CARGO_PKG_NAME").into()))
309 .client_software_version(Some(env!("CARGO_PKG_VERSION").into()));
310
311 let client = Pool::builder(ConnectionManager {
312 broker: self.broker.clone(),
313 client_id: self.client_id.clone(),
314 versions,
315 })
316 .build()
317 .map(Client::new)?;
318
319 let supported = RootMessageMeta::messages().requests();
320
321 client.call(req).await.map(|response| {
322 response
323 .api_keys
324 .unwrap_or_default()
325 .into_iter()
326 .filter_map(|api| {
327 supported.get(&api.api_key).and_then(|supported| {
328 if api.min_version >= supported.version.valid.start {
329 Some((
330 api.api_key,
331 api.max_version.min(supported.version.valid.end),
332 ))
333 } else {
334 None
335 }
336 })
337 })
338 .collect()
339 })
340 }
341
342 pub async fn build(self) -> Result<Pool, Error> {
344 self.bootstrap().await.and_then(|versions| {
345 Pool::builder(ConnectionManager {
346 broker: self.broker,
347 client_id: self.client_id,
348 versions,
349 })
350 .build()
351 .map_err(Into::into)
352 })
353 }
354}
355
/// A layer that injects a connection [`Pool`] as the state seen by the
/// wrapped `Frame` service.
#[derive(Clone, Debug)]
pub struct FramePoolLayer {
    // The pool handed to every service produced by this layer.
    pool: Pool,
}
361
362impl FramePoolLayer {
363 pub fn new(pool: Pool) -> Self {
364 Self { pool }
365 }
366}
367
368impl<S> Layer<S> for FramePoolLayer {
369 type Service = FramePoolService<S>;
370
371 fn layer(&self, inner: S) -> Self::Service {
372 FramePoolService {
373 pool: self.pool.clone(),
374 inner,
375 }
376 }
377}
378
/// A service that replaces the context state with a connection [`Pool`]
/// before delegating a `Frame` to the inner service.
#[derive(Clone, Debug)]
pub struct FramePoolService<S> {
    // The pool placed into the context state for each request.
    pool: Pool,
    // The wrapped service.
    inner: S,
}
385
386impl<State, S> Service<State, Frame> for FramePoolService<S>
387where
388 S: Service<Pool, Frame, Response = Frame>,
389 State: Send + Sync + 'static,
390{
391 type Response = Frame;
392 type Error = S::Error;
393
394 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
395 let (ctx, _) = ctx.swap_state(self.pool.clone());
396 self.inner.serve(ctx, req).await
397 }
398}
399
/// A layer that injects a connection [`Pool`] as the state seen by the
/// wrapped typed-request service.
#[derive(Clone, Debug)]
pub struct RequestPoolLayer {
    // The pool handed to every service produced by this layer.
    pool: Pool,
}
405
406impl RequestPoolLayer {
407 pub fn new(pool: Pool) -> Self {
408 Self { pool }
409 }
410}
411
412impl<S> Layer<S> for RequestPoolLayer {
413 type Service = RequestPoolService<S>;
414
415 fn layer(&self, inner: S) -> Self::Service {
416 RequestPoolService {
417 pool: self.pool.clone(),
418 inner,
419 }
420 }
421}
422
/// A service that replaces the context state with a connection [`Pool`]
/// before delegating a typed request to the inner service.
#[derive(Clone, Debug)]
pub struct RequestPoolService<S> {
    // The pool placed into the context state for each request.
    pool: Pool,
    // The wrapped service.
    inner: S,
}
429
430impl<State, S, Q> Service<State, Q> for RequestPoolService<S>
431where
432 Q: Request,
433 S: Service<Pool, Q>,
434 State: Send + Sync + 'static,
435{
436 type Response = S::Response;
437 type Error = S::Error;
438
439 async fn serve(&self, ctx: Context<State>, req: Q) -> Result<Self::Response, Self::Error> {
441 let (ctx, _) = ctx.swap_state(self.pool.clone());
442 self.inner.serve(ctx, req).await
443 }
444}
445
/// A typed request/response client backed by a connection [`Pool`].
#[derive(Clone, Debug)]
pub struct Client {
    // The service stack assembled in `Client::new`: pool injection, request
    // to frame conversion, frame to bytes conversion, then the TCP exchange.
    service:
        RequestPoolService<RequestConnectionService<FrameBytesService<BytesConnectionService>>>,
}
452
453impl Client {
454 pub fn new(pool: Pool) -> Self {
456 let service = (
457 RequestPoolLayer::new(pool),
458 RequestConnectionLayer,
459 FrameBytesLayer,
460 )
461 .into_layer(BytesConnectionService);
462
463 Self { service }
464 }
465
466 pub async fn call<Q>(&self, req: Q) -> Result<Q::Response, Error>
468 where
469 Q: Request,
470 Error: From<<<Q as Request>::Response as TryFrom<Body>>::Error>,
471 {
472 self.service.serve(Context::default(), req).await
473 }
474}
475
/// A layer producing [`FrameConnectionService`], which takes a connection
/// from the pool for each `Frame`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct FrameConnectionLayer;
479
480impl<S> Layer<S> for FrameConnectionLayer {
481 type Service = FrameConnectionService<S>;
482
483 fn layer(&self, inner: S) -> Self::Service {
484 Self::Service { inner }
485 }
486}
487
/// A service that takes a connection from the pool, re-stamps the request
/// frame's header with that connection's correlation id, and passes the frame
/// on with the connection as context state.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct FrameConnectionService<S> {
    // The wrapped service, which receives the pooled connection as state.
    inner: S,
}
493
494impl<S> Service<Pool, Frame> for FrameConnectionService<S>
495where
496 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
497 S::Error: From<Error> + From<PoolError<Error>> + From<tansu_sans_io::Error>,
498{
499 type Response = Frame;
500 type Error = S::Error;
501
502 async fn serve(&self, ctx: Context<Pool>, req: Frame) -> Result<Self::Response, Self::Error> {
503 debug!(?req);
504
505 let api_key = req.api_key()?;
506 let api_version = req.api_version()?;
507 let client_id = req
508 .client_id()
509 .map(|client_id| client_id.map(|client_id| client_id.to_string()))?;
510
511 let pool = ctx.state();
512 status_update(pool);
513
514 let connection = pool.get().await?;
515 let correlation_id = connection.correlation_id;
516
517 let frame = Frame {
518 size: 0,
519 header: Header::Request {
520 api_key,
521 api_version,
522 correlation_id,
523 client_id,
524 },
525 body: req.body,
526 };
527
528 let (ctx, _) = ctx.swap_state(connection);
529
530 self.inner.serve(ctx, frame).await
531 }
532}
533
/// A layer producing [`RequestConnectionService`], which turns a typed
/// request into a `Frame` sent on a pooled connection.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RequestConnectionLayer;
537
538impl<S> Layer<S> for RequestConnectionLayer {
539 type Service = RequestConnectionService<S>;
540
541 fn layer(&self, inner: S) -> Self::Service {
542 Self::Service { inner }
543 }
544}
545
/// A service that wraps a typed request in a `Frame` (using the API version
/// and client id held by the pool's manager), serves it on a pooled
/// connection, and converts the response body back into the typed response.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RequestConnectionService<S> {
    // The wrapped service, which receives the pooled connection as state.
    inner: S,
}
553
554impl<Q, S> Service<Pool, Q> for RequestConnectionService<S>
555where
556 Q: Request,
557 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
558 S::Error: From<Error>
559 + From<PoolError<Error>>
560 + From<tansu_sans_io::Error>
561 + From<<Q::Response as TryFrom<Body>>::Error>,
562{
563 type Response = Q::Response;
564 type Error = S::Error;
565
566 async fn serve(&self, ctx: Context<Pool>, req: Q) -> Result<Self::Response, Self::Error> {
567 debug!(?req);
568 let pool = ctx.state();
569 let api_key = Q::KEY;
570 let api_version = pool.manager().api_version(api_key)?;
571 let client_id = pool.manager().client_id();
572 let connection = pool.get().await?;
573 let correlation_id = connection.correlation_id;
574
575 let frame = Frame {
576 size: 0,
577 header: Header::Request {
578 api_key,
579 api_version,
580 correlation_id,
581 client_id,
582 },
583 body: req.into(),
584 };
585
586 let (ctx, _) = ctx.swap_state(connection);
587
588 let frame = self.inner.serve(ctx, frame).await?;
589
590 Q::Response::try_from(frame.body)
591 .inspect(|response| debug!(?response))
592 .map_err(Into::into)
593 }
594}
595
/// The innermost service: writes an encoded request to the pooled
/// connection's TCP stream and reads back one length-prefixed response.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BytesConnectionService;
599
600impl BytesConnectionService {
601 async fn write(
602 &self,
603 stream: &mut TcpStream,
604 frame: Bytes,
605 attributes: &[KeyValue],
606 ) -> Result<(), Error> {
607 debug!(frame = ?&frame[..]);
608
609 let start = SystemTime::now();
610
611 stream
612 .write_all(&frame[..])
613 .await
614 .inspect(|_| {
615 TCP_SEND_DURATION.record(
616 start
617 .elapsed()
618 .map_or(0, |duration| duration.as_millis() as u64),
619 attributes,
620 );
621
622 TCP_BYTES_SENT.add(frame.len() as u64, attributes);
623 })
624 .inspect_err(|_| {
625 TCP_SEND_ERRORS.add(1, attributes);
626 })
627 .map_err(Into::into)
628 }
629
630 async fn read(&self, stream: &mut TcpStream, attributes: &[KeyValue]) -> Result<Bytes, Error> {
631 let start = SystemTime::now();
632
633 let mut size = [0u8; 4];
634 _ = stream.read_exact(&mut size).await?;
635
636 let mut buffer: Vec<u8> = vec![0u8; frame_length(size)];
637 buffer[0..size.len()].copy_from_slice(&size[..]);
638 _ = stream
639 .read_exact(&mut buffer[4..])
640 .await
641 .inspect(|_| {
642 TCP_RECEIVE_DURATION.record(
643 start
644 .elapsed()
645 .map_or(0, |duration| duration.as_millis() as u64),
646 attributes,
647 );
648
649 TCP_BYTES_RECEIVED.add(buffer.len() as u64, attributes);
650 })
651 .inspect_err(|_| {
652 TCP_RECEIVE_ERRORS.add(1, attributes);
653 })?;
654
655 Ok(Bytes::from(buffer)).inspect(|frame| debug!(frame = ?&frame[..]))
656 }
657}
658
659impl Service<Object<ConnectionManager>, Bytes> for BytesConnectionService {
660 type Response = Bytes;
661 type Error = Error;
662
663 async fn serve(
664 &self,
665 mut ctx: Context<Object<ConnectionManager>>,
666 req: Bytes,
667 ) -> Result<Self::Response, Self::Error> {
668 let c = ctx.state_mut();
669
670 let local = c.stream.local_addr()?;
671 let peer = c.stream.peer_addr()?;
672
673 let attributes = [KeyValue::new("peer", peer.to_string())];
674
675 let span = span!(Level::DEBUG, "client", local = %local, peer = %peer);
676
677 async move {
678 self.write(&mut c.stream, req, &attributes).await?;
679
680 c.correlation_id += 1;
681
682 self.read(&mut c.stream, &attributes).await
683 }
684 .instrument(span)
685 .await
686 }
687}
688
/// Returns the total length of a frame — the 4-byte size prefix plus the
/// payload length it encodes as a big-endian `i32`.
///
/// A negative encoded size is treated as an empty payload, and the addition
/// saturates, so the result never wraps and is always at least the prefix
/// length.
fn frame_length(encoded: [u8; 4]) -> usize {
    // `as usize` on a negative `i32` sign-extends to a huge value; convert
    // checked and fall back to an empty payload.
    let payload = usize::try_from(i32::from_be_bytes(encoded)).unwrap_or(0);

    payload.saturating_add(encoded.len())
}
692
693static TCP_CONNECT_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
694 METER
695 .u64_histogram("tcp_connect_duration")
696 .with_unit("ms")
697 .with_description("The TCP connect latencies in milliseconds")
698 .build()
699});
700
701static TCP_CONNECT_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
702 METER
703 .u64_counter("tcp_connect_errors")
704 .with_description("TCP connect errors")
705 .build()
706});
707
708static TCP_SEND_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
709 METER
710 .u64_histogram("tcp_send_duration")
711 .with_unit("ms")
712 .with_description("The TCP send latencies in milliseconds")
713 .build()
714});
715
716static TCP_SEND_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
717 METER
718 .u64_counter("tcp_send_errors")
719 .with_description("TCP send errors")
720 .build()
721});
722
723static TCP_RECEIVE_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
724 METER
725 .u64_histogram("tcp_receive_duration")
726 .with_unit("ms")
727 .with_description("The TCP receive latencies in milliseconds")
728 .build()
729});
730
731static TCP_RECEIVE_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
732 METER
733 .u64_counter("tcp_receive_errors")
734 .with_description("TCP receive errors")
735 .build()
736});
737
738static TCP_BYTES_SENT: LazyLock<Counter<u64>> = LazyLock::new(|| {
739 METER
740 .u64_counter("tcp_bytes_sent")
741 .with_description("TCP bytes sent")
742 .build()
743});
744
745static TCP_BYTES_RECEIVED: LazyLock<Counter<u64>> = LazyLock::new(|| {
746 METER
747 .u64_counter("tcp_bytes_received")
748 .with_description("TCP bytes received")
749 .build()
750});
751
752static POOL_MAX_SIZE: LazyLock<Gauge<u64>> = LazyLock::new(|| {
753 METER
754 .u64_gauge("pool_max_size")
755 .with_description("The maximum size of the pool")
756 .build()
757});
758
759static POOL_CURRENT_SIZE: LazyLock<Gauge<u64>> = LazyLock::new(|| {
760 METER
761 .u64_gauge("pool_current_size")
762 .with_description("The current size of the pool")
763 .build()
764});
765
766static POOL_AVAILABLE: LazyLock<Gauge<u64>> = LazyLock::new(|| {
767 METER
768 .u64_gauge("pool_available")
769 .with_description("The number of available objects in the pool")
770 .build()
771});
772
773static POOL_WAITING: LazyLock<Gauge<u64>> = LazyLock::new(|| {
774 METER
775 .u64_gauge("pool_waiting")
776 .with_description("The number of waiting objects in the pool")
777 .build()
778});
779
// End-to-end test: a local in-process server answers a `MetadataRequest`
// sent through the pooled client stack.
#[cfg(test)]
mod tests {
    use std::{fs::File, thread};

    use tansu_sans_io::{MetadataRequest, MetadataResponse};
    use tansu_service::{
        BytesFrameLayer, FrameRouteService, RequestLayer, ResponseService, TcpBytesLayer,
        TcpContextLayer, TcpListenerLayer,
    };
    use tokio::{net::TcpListener, task::JoinSet};
    use tokio_util::sync::CancellationToken;
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Installs a thread-default tracing subscriber that logs this crate at
    /// debug level to `../logs/<package>/<thread name>.log`.
    ///
    /// Fails if the current thread is unnamed, the directive does not parse,
    /// or the log file cannot be created.
    fn init_tracing() -> Result<DefaultGuard, Error> {
        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(
                    EnvFilter::from_default_env()
                        .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?),
                )
                .with_writer(
                    // The log file is named after the current thread.
                    thread::current()
                        .name()
                        .ok_or(Error::Message(String::from("unnamed thread")))
                        .and_then(|name| {
                            File::create(format!("../logs/{}/{name}.log", env!("CARGO_PKG_NAME"),))
                                .map_err(Into::into)
                        })
                        .map(Arc::new)?,
                )
                .finish(),
        ))
    }

    /// Runs a server on `listener` that answers every `MetadataRequest` with
    /// a fixed response (cluster id "abc", controller id 111); the
    /// `cancellation` token is handed to the listener layer.
    async fn server(cancellation: CancellationToken, listener: TcpListener) -> Result<(), Error> {
        let server = (
            TcpListenerLayer::new(cancellation),
            TcpContextLayer::default(),
            TcpBytesLayer::default(),
            BytesFrameLayer,
        )
            .into_layer(
                FrameRouteService::builder()
                    .with_service(RequestLayer::<MetadataRequest>::new().into_layer(
                        ResponseService::new(|_ctx: Context<()>, _req: MetadataRequest| {
                            Ok::<_, Error>(
                                MetadataResponse::default()
                                    .brokers(Some([].into()))
                                    .topics(Some([].into()))
                                    .cluster_id(Some("abc".into()))
                                    .controller_id(Some(111))
                                    .throttle_time_ms(Some(0))
                                    .cluster_authorized_operations(Some(-1)),
                            )
                        }),
                    ))
                    .and_then(|builder| builder.build())?,
            );

        server.serve(Context::default(), listener).await
    }

    /// Starts the server on an ephemeral local port, builds the client stack
    /// against it, sends a `MetadataRequest` and checks the fixed response.
    #[tokio::test]
    async fn tcp_client_server() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let cancellation = CancellationToken::new();
        // Port 0 asks the OS for a free port; the real address is read back.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let local_addr = listener.local_addr()?;

        let mut join = JoinSet::new();

        let _server = {
            let cancellation = cancellation.clone();
            join.spawn(async move { server(cancellation, listener).await })
        };

        // Same stack as `Client::new`, assembled by hand.
        let origin = (
            RequestPoolLayer::new(
                ConnectionManager::builder(
                    Url::parse(&format!("tcp://{local_addr}")).inspect(|url| debug!(%url))?,
                )
                .client_id(Some(env!("CARGO_PKG_NAME").into()))
                .build()
                .await
                .inspect(|pool| debug!(?pool))?,
            ),
            RequestConnectionLayer,
            FrameBytesLayer,
        )
            .into_layer(BytesConnectionService);

        let response = origin
            .serve(
                Context::default(),
                MetadataRequest::default()
                    .topics(Some([].into()))
                    .allow_auto_topic_creation(Some(false))
                    .include_cluster_authorized_operations(Some(false))
                    .include_topic_authorized_operations(Some(false)),
            )
            .await?;

        assert_eq!(Some("abc"), response.cluster_id.as_deref());
        assert_eq!(Some(111), response.controller_id);

        // Cancel the server, then wait for the spawned task to finish.
        cancellation.cancel();

        let joined = join.join_all().await;
        debug!(?joined);

        Ok(())
    }
}