1use std::{
103 collections::BTreeMap,
104 error, fmt, io,
105 sync::{Arc, LazyLock},
106 time::SystemTime,
107};
108
109use backoff::{ExponentialBackoff, future::retry};
110use bytes::Bytes;
111use deadpool::managed::{self, BuildError, Object, PoolError};
112use opentelemetry::{
113 InstrumentationScope, KeyValue, global,
114 metrics::{Counter, Histogram, Meter},
115};
116use opentelemetry_semantic_conventions::SCHEMA_URL;
117use rama::{Context, Layer, Service};
118use tansu_sans_io::{ApiKey, ApiVersionsRequest, Body, Frame, Header, Request, RootMessageMeta};
119use tansu_service::{FrameBytesLayer, FrameBytesService, host_port};
120use tokio::{
121 io::{AsyncReadExt as _, AsyncWriteExt as _},
122 net::TcpStream,
123 task::JoinError,
124};
125use tracing::{Instrument, Level, debug, span};
126use tracing_subscriber::filter::ParseError;
127use url::Url;
128
/// Errors surfaced by the pooled TCP client.
///
/// Several sources are held behind [`Arc`] (for example [`io::Error`], which is not `Clone`),
/// which lets the enum as a whole derive `Clone`.
#[derive(thiserror::Error, Clone, Debug)]
pub enum Error {
    /// The deadpool connection pool could not be built.
    DeadPoolBuild(#[from] BuildError),
    /// An I/O error, e.g. while connecting, reading or writing a TCP stream.
    Io(Arc<io::Error>),
    /// A spawned tokio task failed to join.
    Join(Arc<JoinError>),
    /// A free-form error message.
    Message(String),
    /// A tracing filter directive could not be parsed.
    ParseFilter(Arc<ParseError>),
    /// A URL could not be parsed.
    ParseUrl(#[from] url::ParseError),
    /// A type-erased [`PoolError`] raised while obtaining a connection from the pool.
    Pool(Arc<Box<dyn error::Error + Send + Sync>>),
    /// An error reported by `tansu_sans_io`.
    Protocol(#[from] tansu_sans_io::Error),
    /// An error reported by `tansu_service`.
    Service(#[from] tansu_service::Error),
    /// No version is known for the given API key (see `ConnectionManager::api_version`).
    UnknownApiKey(i16),
    /// NOTE(review): not constructed in the visible part of this file; presumably raised when
    /// the host of a broker URL cannot be resolved — confirm against callers.
    UnknownHost(Url),
}
144
145impl fmt::Display for Error {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 write!(f, "{self:?}")
148 }
149}
150
151impl From<JoinError> for Error {
152 fn from(value: JoinError) -> Self {
153 Self::Join(Arc::new(value))
154 }
155}
156
157impl<E> From<PoolError<E>> for Error
158where
159 E: error::Error + Send + Sync + 'static,
160{
161 fn from(value: PoolError<E>) -> Self {
162 Self::Pool(Arc::new(Box::new(value)))
163 }
164}
165
166impl From<io::Error> for Error {
167 fn from(value: io::Error) -> Self {
168 Self::Io(Arc::new(value))
169 }
170}
171
172impl From<ParseError> for Error {
173 fn from(value: ParseError) -> Self {
174 Self::ParseFilter(Arc::new(value))
175 }
176}
177
178pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
179 global::meter_with_scope(
180 InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
181 .with_version(env!("CARGO_PKG_VERSION"))
182 .with_schema_url(SCHEMA_URL)
183 .build(),
184 )
185});
186
/// A pooled TCP connection to the broker.
#[derive(Debug)]
pub struct Connection {
    /// The underlying TCP stream.
    stream: TcpStream,
    /// Correlation id stamped into the header of the next request sent on this connection.
    /// It starts at zero and is advanced after each frame has been written.
    correlation_id: i32,
}
193
/// deadpool manager that creates [`Connection`]s to a single broker.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ConnectionManager {
    /// URL of the broker that connections are opened to.
    broker: Url,
    /// Optional client id placed in request headers.
    client_id: Option<String>,
    /// API key to the API version used when sending that request to the broker.
    versions: BTreeMap<i16, i16>,
}
201
202impl ConnectionManager {
203 pub fn builder(broker: Url) -> Builder {
205 Builder::broker(broker)
206 }
207
208 pub fn client_id(&self) -> Option<String> {
210 self.client_id.clone()
211 }
212
213 pub fn api_version(&self, api_key: i16) -> Result<i16, Error> {
215 self.versions
216 .get(&api_key)
217 .copied()
218 .ok_or(Error::UnknownApiKey(api_key))
219 }
220}
221
222impl managed::Manager for ConnectionManager {
223 type Type = Connection;
224 type Error = Error;
225
226 async fn create(&self) -> Result<Self::Type, Self::Error> {
227 debug!(%self.broker);
228
229 let attributes = [KeyValue::new("broker", self.broker.to_string())];
230 let start = SystemTime::now();
231
232 let addr = host_port(self.broker.clone()).await?;
233
234 retry(ExponentialBackoff::default(), || async {
235 Ok(TcpStream::connect(addr)
236 .await
237 .inspect(|_| {
238 TCP_CONNECT_DURATION.record(
239 start
240 .elapsed()
241 .map_or(0, |duration| duration.as_millis() as u64),
242 &attributes,
243 )
244 })
245 .inspect_err(|err| {
246 debug!(broker = %self.broker, ?err, elapsed = start.elapsed().map_or(0, |duration| duration.as_millis() as u64));
247 TCP_CONNECT_ERRORS.add(1, &attributes);
248 })
249 .map(|stream| Connection {
250 stream,
251 correlation_id: 0,
252 })?)
253 })
254 .await
255 .map_err(Into::into)
256 }
257
258 async fn recycle(
259 &self,
260 obj: &mut Self::Type,
261 metrics: &managed::Metrics,
262 ) -> managed::RecycleResult<Self::Error> {
263 debug!(?obj, ?metrics);
264 Ok(())
265 }
266}
267
/// A deadpool pool of broker [`Connection`]s managed by [`ConnectionManager`].
pub type Pool = managed::Pool<ConnectionManager>;
270
/// Builder for a connection [`Pool`] to a single broker.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Builder {
    /// URL of the broker to connect to.
    broker: Url,
    /// Optional client id placed in request headers.
    client_id: Option<String>,
}
277
278impl Builder {
279 pub fn broker(broker: Url) -> Self {
281 Self {
282 broker,
283 client_id: None,
284 }
285 }
286
287 pub fn client_id(self, client_id: Option<String>) -> Self {
289 Self { client_id, ..self }
290 }
291
292 async fn bootstrap(&self) -> Result<BTreeMap<i16, i16>, Error> {
294 let versions = BTreeMap::from([(ApiVersionsRequest::KEY, 0)]);
297
298 let req = ApiVersionsRequest::default()
299 .client_software_name(Some(env!("CARGO_PKG_NAME").into()))
300 .client_software_version(Some(env!("CARGO_PKG_VERSION").into()));
301
302 let client = Pool::builder(ConnectionManager {
303 broker: self.broker.clone(),
304 client_id: self.client_id.clone(),
305 versions,
306 })
307 .build()
308 .map(Client::new)?;
309
310 let supported = RootMessageMeta::messages().requests();
311
312 client.call(req).await.map(|response| {
313 response
314 .api_keys
315 .unwrap_or_default()
316 .into_iter()
317 .filter_map(|api| {
318 supported.get(&api.api_key).and_then(|supported| {
319 if api.min_version >= supported.version.valid.start {
320 Some((
321 api.api_key,
322 api.max_version.min(supported.version.valid.end),
323 ))
324 } else {
325 None
326 }
327 })
328 })
329 .collect()
330 })
331 }
332
333 pub async fn build(self) -> Result<Pool, Error> {
335 self.bootstrap().await.and_then(|versions| {
336 Pool::builder(ConnectionManager {
337 broker: self.broker,
338 client_id: self.client_id,
339 versions,
340 })
341 .build()
342 .map_err(Into::into)
343 })
344 }
345}
346
347#[derive(Clone, Debug)]
349pub struct FramePoolLayer {
350 pool: Pool,
351}
352
353impl FramePoolLayer {
354 pub fn new(pool: Pool) -> Self {
355 Self { pool }
356 }
357}
358
359impl<S> Layer<S> for FramePoolLayer {
360 type Service = FramePoolService<S>;
361
362 fn layer(&self, inner: S) -> Self::Service {
363 FramePoolService {
364 pool: self.pool.clone(),
365 inner,
366 }
367 }
368}
369
/// Service produced by [`FramePoolLayer`]: serves frames with the pool as the context state.
#[derive(Clone, Debug)]
pub struct FramePoolService<S> {
    /// Pool installed as the state of every request context.
    pool: Pool,
    /// The wrapped service, which expects a [`Pool`] as its state.
    inner: S,
}
376
377impl<State, S> Service<State, Frame> for FramePoolService<S>
378where
379 S: Service<Pool, Frame, Response = Frame>,
380 State: Send + Sync + 'static,
381{
382 type Response = Frame;
383 type Error = S::Error;
384
385 async fn serve(&self, ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
386 let (ctx, _) = ctx.swap_state(self.pool.clone());
387 self.inner.serve(ctx, req).await
388 }
389}
390
391#[derive(Clone, Debug)]
393pub struct RequestPoolLayer {
394 pool: Pool,
395}
396
397impl RequestPoolLayer {
398 pub fn new(pool: Pool) -> Self {
399 Self { pool }
400 }
401}
402
403impl<S> Layer<S> for RequestPoolLayer {
404 type Service = RequestPoolService<S>;
405
406 fn layer(&self, inner: S) -> Self::Service {
407 RequestPoolService {
408 pool: self.pool.clone(),
409 inner,
410 }
411 }
412}
413
/// Service produced by [`RequestPoolLayer`]: serves typed requests with the pool as the
/// context state.
#[derive(Clone, Debug)]
pub struct RequestPoolService<S> {
    /// Pool installed as the state of every request context.
    pool: Pool,
    /// The wrapped service, which expects a [`Pool`] as its state.
    inner: S,
}
420
421impl<State, S, Q> Service<State, Q> for RequestPoolService<S>
422where
423 Q: Request,
424 S: Service<Pool, Q>,
425 State: Send + Sync + 'static,
426{
427 type Response = S::Response;
428 type Error = S::Error;
429
430 async fn serve(&self, ctx: Context<State>, req: Q) -> Result<Self::Response, Self::Error> {
432 let (ctx, _) = ctx.swap_state(self.pool.clone());
433 self.inner.serve(ctx, req).await
434 }
435}
436
/// Typed request/response client backed by a connection [`Pool`].
///
/// The service stack is, outermost first: [`RequestPoolService`] (installs the pool),
/// [`RequestConnectionService`] (wraps the request in a frame and checks out a connection),
/// `FrameBytesService` (from `tansu_service`; presumably converts between frames and bytes)
/// and [`BytesConnectionService`] (exchanges bytes over the TCP stream).
#[derive(Clone, Debug)]
pub struct Client {
    service:
        RequestPoolService<RequestConnectionService<FrameBytesService<BytesConnectionService>>>,
}
443
444impl Client {
445 pub fn new(pool: Pool) -> Self {
447 let service = (
448 RequestPoolLayer::new(pool),
449 RequestConnectionLayer,
450 FrameBytesLayer,
451 )
452 .into_layer(BytesConnectionService);
453
454 Self { service }
455 }
456
457 pub async fn call<Q>(&self, req: Q) -> Result<Q::Response, Error>
459 where
460 Q: Request,
461 Error: From<<<Q as Request>::Response as TryFrom<Body>>::Error>,
462 {
463 self.service.serve(Context::default(), req).await
464 }
465}
466
467#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
469pub struct FrameConnectionLayer;
470
471impl<S> Layer<S> for FrameConnectionLayer {
472 type Service = FrameConnectionService<S>;
473
474 fn layer(&self, inner: S) -> Self::Service {
475 Self::Service { inner }
476 }
477}
478
/// Service that checks a connection out of the pool for each frame and forwards the frame,
/// re-stamped with that connection's correlation id, to the inner service.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct FrameConnectionService<S> {
    /// The wrapped service, which runs with the checked-out connection as its state.
    inner: S,
}
484
485impl<S> Service<Pool, Frame> for FrameConnectionService<S>
486where
487 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
488 S::Error: From<Error> + From<PoolError<Error>> + From<tansu_sans_io::Error>,
489{
490 type Response = Frame;
491 type Error = S::Error;
492
493 async fn serve(&self, ctx: Context<Pool>, req: Frame) -> Result<Self::Response, Self::Error> {
494 debug!(?req);
495
496 let api_key = req.api_key()?;
497 let api_version = req.api_version()?;
498 let client_id = req
499 .client_id()
500 .map(|client_id| client_id.map(|client_id| client_id.to_string()))?;
501
502 let connection = ctx.state().get().await?;
503 let correlation_id = connection.correlation_id;
504
505 let frame = Frame {
506 size: 0,
507 header: Header::Request {
508 api_key,
509 api_version,
510 correlation_id,
511 client_id,
512 },
513 body: req.body,
514 };
515
516 let (ctx, _) = ctx.swap_state(connection);
517
518 self.inner.serve(ctx, frame).await
519 }
520}
521
522#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
524pub struct RequestConnectionLayer;
525
526impl<S> Layer<S> for RequestConnectionLayer {
527 type Service = RequestConnectionService<S>;
528
529 fn layer(&self, inner: S) -> Self::Service {
530 Self::Service { inner }
531 }
532}
533
/// Service that turns a typed request into a frame, sends it over a connection checked out
/// of the pool, and converts the response frame's body back into the typed response.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RequestConnectionService<S> {
    /// The wrapped frame service, which runs with the checked-out connection as its state.
    inner: S,
}
541
542impl<Q, S> Service<Pool, Q> for RequestConnectionService<S>
543where
544 Q: Request,
545 S: Service<Object<ConnectionManager>, Frame, Response = Frame>,
546 S::Error: From<Error>
547 + From<PoolError<Error>>
548 + From<tansu_sans_io::Error>
549 + From<<Q::Response as TryFrom<Body>>::Error>,
550{
551 type Response = Q::Response;
552 type Error = S::Error;
553
554 async fn serve(&self, ctx: Context<Pool>, req: Q) -> Result<Self::Response, Self::Error> {
555 debug!(?req);
556 let pool = ctx.state();
557 let api_key = Q::KEY;
558 let api_version = pool.manager().api_version(api_key)?;
559 let client_id = pool.manager().client_id();
560 let connection = pool.get().await?;
561 let correlation_id = connection.correlation_id;
562
563 let frame = Frame {
564 size: 0,
565 header: Header::Request {
566 api_key,
567 api_version,
568 correlation_id,
569 client_id,
570 },
571 body: req.into(),
572 };
573
574 let (ctx, _) = ctx.swap_state(connection);
575
576 let frame = self.inner.serve(ctx, frame).await?;
577
578 Q::Response::try_from(frame.body)
579 .inspect(|response| debug!(?response))
580 .map_err(Into::into)
581 }
582}
583
/// Innermost service of the stack: writes request bytes to the connection's TCP stream and
/// reads back one length-prefixed response frame.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct BytesConnectionService;
587
588impl BytesConnectionService {
589 async fn write(
590 &self,
591 stream: &mut TcpStream,
592 frame: Bytes,
593 attributes: &[KeyValue],
594 ) -> Result<(), Error> {
595 debug!(frame = ?&frame[..]);
596
597 let start = SystemTime::now();
598
599 stream
600 .write_all(&frame[..])
601 .await
602 .inspect(|_| {
603 TCP_SEND_DURATION.record(
604 start
605 .elapsed()
606 .map_or(0, |duration| duration.as_millis() as u64),
607 attributes,
608 );
609
610 TCP_BYTES_SENT.add(frame.len() as u64, attributes);
611 })
612 .inspect_err(|_| {
613 TCP_SEND_ERRORS.add(1, attributes);
614 })
615 .map_err(Into::into)
616 }
617
618 async fn read(&self, stream: &mut TcpStream, attributes: &[KeyValue]) -> Result<Bytes, Error> {
619 let start = SystemTime::now();
620
621 let mut size = [0u8; 4];
622 _ = stream.read_exact(&mut size).await?;
623
624 let mut buffer: Vec<u8> = vec![0u8; frame_length(size)];
625 buffer[0..size.len()].copy_from_slice(&size[..]);
626 _ = stream
627 .read_exact(&mut buffer[4..])
628 .await
629 .inspect(|_| {
630 TCP_RECEIVE_DURATION.record(
631 start
632 .elapsed()
633 .map_or(0, |duration| duration.as_millis() as u64),
634 attributes,
635 );
636
637 TCP_BYTES_RECEIVED.add(buffer.len() as u64, attributes);
638 })
639 .inspect_err(|_| {
640 TCP_RECEIVE_ERRORS.add(1, attributes);
641 })?;
642
643 Ok(Bytes::from(buffer)).inspect(|frame| debug!(frame = ?&frame[..]))
644 }
645}
646
647impl Service<Object<ConnectionManager>, Bytes> for BytesConnectionService {
648 type Response = Bytes;
649 type Error = Error;
650
651 async fn serve(
652 &self,
653 mut ctx: Context<Object<ConnectionManager>>,
654 req: Bytes,
655 ) -> Result<Self::Response, Self::Error> {
656 let c = ctx.state_mut();
657
658 let local = c.stream.local_addr()?;
659 let peer = c.stream.peer_addr()?;
660
661 let attributes = [KeyValue::new("peer", peer.to_string())];
662
663 let span = span!(Level::DEBUG, "client", local = %local, peer = %peer);
664
665 async move {
666 self.write(&mut c.stream, req, &attributes).await?;
667
668 c.correlation_id += 1;
669
670 self.read(&mut c.stream, &attributes).await
671 }
672 .instrument(span)
673 .await
674 }
675}
676
/// Total length in bytes of a frame whose 4-byte big-endian length prefix is `encoded`.
///
/// The prefix is a signed 32-bit count of the bytes that follow it, so the total is that
/// count plus the four bytes of the prefix itself. The prefix comes from the peer: a
/// negative count is treated as an empty body (the result is just the prefix length)
/// rather than being cast into an enormous `usize` or overflowing the addition.
fn frame_length(encoded: [u8; 4]) -> usize {
    let body = usize::try_from(i32::from_be_bytes(encoded)).unwrap_or(0);
    body.saturating_add(encoded.len())
}
680
/// Histogram of TCP connect latencies in milliseconds, recorded when a connection is created.
static TCP_CONNECT_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
    METER
        .u64_histogram("tcp_connect_duration")
        .with_unit("ms")
        .with_description("The TCP connect latencies in milliseconds")
        .build()
});

/// Counter of failed TCP connect attempts.
static TCP_CONNECT_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("tcp_connect_errors")
        .with_description("TCP connect errors")
        .build()
});

/// Histogram of TCP send latencies in milliseconds, recorded after a frame has been written.
static TCP_SEND_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
    METER
        .u64_histogram("tcp_send_duration")
        .with_unit("ms")
        .with_description("The TCP send latencies in milliseconds")
        .build()
});

/// Counter of failed TCP writes.
static TCP_SEND_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("tcp_send_errors")
        .with_description("TCP send errors")
        .build()
});

/// Histogram of TCP receive latencies in milliseconds, recorded after a frame has been read.
static TCP_RECEIVE_DURATION: LazyLock<Histogram<u64>> = LazyLock::new(|| {
    METER
        .u64_histogram("tcp_receive_duration")
        .with_unit("ms")
        .with_description("The TCP receive latencies in milliseconds")
        .build()
});

/// Counter of failed TCP reads.
static TCP_RECEIVE_ERRORS: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("tcp_receive_errors")
        .with_description("TCP receive errors")
        .build()
});

/// Counter of bytes written to TCP streams.
static TCP_BYTES_SENT: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("tcp_bytes_sent")
        .with_description("TCP bytes sent")
        .build()
});

/// Counter of bytes read from TCP streams, including the length prefix of each frame.
static TCP_BYTES_RECEIVED: LazyLock<Counter<u64>> = LazyLock::new(|| {
    METER
        .u64_counter("tcp_bytes_received")
        .with_description("TCP bytes received")
        .build()
});
739
// End-to-end test: a local in-process server answers a metadata request sent through the
// pooled client stack.
#[cfg(test)]
mod tests {
    use std::{fs::File, thread};

    use tansu_sans_io::{MetadataRequest, MetadataResponse};
    use tansu_service::{
        BytesFrameLayer, FrameRouteService, RequestLayer, ResponseService, TcpBytesLayer,
        TcpContextLayer, TcpListenerLayer,
    };
    use tokio::{net::TcpListener, task::JoinSet};
    use tokio_util::sync::CancellationToken;
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Installs a thread-default tracing subscriber that logs this crate at debug level to
    /// `../logs/<package>/<thread name>.log`.
    ///
    /// Fails when the current thread has no name or the log file cannot be created (the
    /// directory is not created here, so it has to exist already).
    fn init_tracing() -> Result<DefaultGuard, Error> {
        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(
                    EnvFilter::from_default_env()
                        .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?),
                )
                .with_writer(
                    thread::current()
                        .name()
                        .ok_or(Error::Message(String::from("unnamed thread")))
                        .and_then(|name| {
                            File::create(format!("../logs/{}/{name}.log", env!("CARGO_PKG_NAME"),))
                                .map_err(Into::into)
                        })
                        .map(Arc::new)?,
                )
                .finish(),
        ))
    }

    /// Serves `listener` until `cancellation` fires, routing `MetadataRequest` to a handler
    /// that returns a fixed response (cluster id `abc`, controller id 111, no brokers and
    /// no topics).
    async fn server(cancellation: CancellationToken, listener: TcpListener) -> Result<(), Error> {
        let server = (
            TcpListenerLayer::new(cancellation),
            TcpContextLayer::default(),
            TcpBytesLayer::default(),
            BytesFrameLayer,
        )
            .into_layer(
                FrameRouteService::builder()
                    .with_service(RequestLayer::<MetadataRequest>::new().into_layer(
                        ResponseService::new(|_ctx: Context<()>, _req: MetadataRequest| {
                            Ok::<_, Error>(
                                MetadataResponse::default()
                                    .brokers(Some([].into()))
                                    .topics(Some([].into()))
                                    .cluster_id(Some("abc".into()))
                                    .controller_id(Some(111))
                                    .throttle_time_ms(Some(0))
                                    .cluster_authorized_operations(Some(-1)),
                            )
                        }),
                    ))
                    .and_then(|builder| builder.build())?,
            );

        server.serve(Context::default(), listener).await
    }

    /// Starts the server on an ephemeral local port, builds a pool against it, sends a
    /// `MetadataRequest` through the client stack and checks the fixed response fields.
    #[tokio::test]
    async fn tcp_client_server() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let cancellation = CancellationToken::new();
        // Port 0 lets the operating system pick a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let local_addr = listener.local_addr()?;

        let mut join = JoinSet::new();

        let _server = {
            let cancellation = cancellation.clone();
            join.spawn(async move { server(cancellation, listener).await })
        };

        // Same stack as `Client::new`, assembled by hand.
        let origin = (
            RequestPoolLayer::new(
                ConnectionManager::builder(
                    Url::parse(&format!("tcp://{local_addr}")).inspect(|url| debug!(%url))?,
                )
                .client_id(Some(env!("CARGO_PKG_NAME").into()))
                .build()
                .await
                .inspect(|pool| debug!(?pool))?,
            ),
            RequestConnectionLayer,
            FrameBytesLayer,
        )
            .into_layer(BytesConnectionService);

        let response = origin
            .serve(
                Context::default(),
                MetadataRequest::default()
                    .topics(Some([].into()))
                    .allow_auto_topic_creation(Some(false))
                    .include_cluster_authorized_operations(Some(false))
                    .include_topic_authorized_operations(Some(false)),
            )
            .await?;

        assert_eq!(Some("abc"), response.cluster_id.as_deref());
        assert_eq!(Some(111), response.controller_id);

        // Stop the server and wait for its task to finish.
        cancellation.cancel();

        let joined = join.join_all().await;
        debug!(?joined);

        Ok(())
    }
}
859}