1use opentelemetry::{InstrumentationScope, global, metrics::Meter};
16use opentelemetry_otlp::ExporterBuildError;
17use opentelemetry_sdk::error::OTelSdkError;
18use opentelemetry_semantic_conventions::SCHEMA_URL;
19use rama::{
20 Context, Layer, Service,
21 layer::{HijackLayer, MapErrLayer, MapRequestLayer, MapResponseLayer},
22};
23use std::{
24 fmt, io,
25 sync::{Arc, LazyLock},
26};
27use tansu_client::{
28 BytesConnectionService, ConnectionManager, FrameConnectionLayer, FramePoolLayer,
29 RequestConnectionLayer, RequestPoolLayer,
30};
31use tansu_otel::meter_provider;
32use tansu_sans_io::{
33 ApiKey, ErrorCode, FindCoordinatorRequest, FindCoordinatorResponse, MetadataRequest,
34 MetadataResponse, NULL_TOPIC_ID, ProduceRequest, find_coordinator_response::Coordinator,
35 metadata_request::MetadataRequestTopic, metadata_response::MetadataResponseBroker,
36};
37use tansu_service::{
38 BytesFrameLayer, FrameApiKeyMatcher, FrameBytesLayer, FrameRequestLayer, TcpBytesLayer,
39 TcpContextLayer, TcpListenerLayer, host_port,
40};
41use tokio::{
42 net::TcpListener,
43 task::{JoinError, JoinSet},
44};
45use tokio_util::sync::CancellationToken;
46use tracing::debug;
47use tracing_subscriber::filter::ParseError;
48use url::Url;
49
50use crate::{
51 produce::BatchProduceLayer,
52 topic::{ResourceConfig, ResourceConfigValue, ResourceConfigValueMatcher, TopicConfigLayer},
53};
54
55mod produce;
56mod topic;
57
58#[derive(Clone, Debug, thiserror::Error)]
59pub enum Error {
60 Client(#[from] tansu_client::Error),
61 ExporterBuild(Arc<ExporterBuildError>),
62 FrameTooBig(usize),
63 Io(Arc<io::Error>),
64 Join(Arc<JoinError>),
65 Otel(#[from] tansu_otel::Error),
66 OtelSdk(Arc<OTelSdkError>),
67 ParseFilter(Arc<ParseError>),
68 Protocol(#[from] tansu_sans_io::Error),
69 ResourceLock {
70 name: String,
71 key: Option<String>,
72 value: Option<ResourceConfigValue>,
73 },
74 Service(#[from] tansu_service::Error),
75 UnknownHost(Url),
76 Message(String),
77}
78
79impl fmt::Display for Error {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 write!(f, "{self:?}")
82 }
83}
84
85impl From<JoinError> for Error {
86 fn from(value: JoinError) -> Self {
87 Self::Join(Arc::new(value))
88 }
89}
90
91impl From<OTelSdkError> for Error {
92 fn from(value: OTelSdkError) -> Self {
93 Self::OtelSdk(Arc::new(value))
94 }
95}
96
97impl From<ExporterBuildError> for Error {
98 fn from(value: ExporterBuildError) -> Self {
99 Self::ExporterBuild(Arc::new(value))
100 }
101}
102
103impl From<ParseError> for Error {
104 fn from(value: ParseError) -> Self {
105 Self::ParseFilter(Arc::new(value))
106 }
107}
108
109impl From<io::Error> for Error {
110 fn from(value: io::Error) -> Self {
111 Self::Io(Arc::new(value))
112 }
113}
114
115pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
116 global::meter_with_scope(
117 InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
118 .with_version(env!("CARGO_PKG_VERSION"))
119 .with_schema_url(SCHEMA_URL)
120 .build(),
121 )
122});
123
/// A Kafka wire-protocol proxy that sits between clients and an origin broker.
#[derive(Clone, Debug)]
pub struct Proxy {
    /// Address the proxy binds to for incoming client connections.
    listener: Url,
    /// Host and port written into the broker/coordinator addresses returned to
    /// clients (`localhost` / `9092` are used when the URL omits them).
    advertised_listener: Url,
    /// Broker that client requests are forwarded to.
    origin: Url,
}
130
131impl Proxy {
132 pub fn new(listener: Url, advertised_listener: Url, origin: Url) -> Self {
133 Self {
134 listener,
135 advertised_listener,
136 origin,
137 }
138 }
139
140 pub async fn listen(&self) -> Result<(), Error> {
141 debug!(%self.listener, %self.advertised_listener, %self.origin);
142
143 let configuration = ResourceConfig::default();
144
145 let listener = TcpListener::bind(host_port(self.listener.clone()).await?).await?;
146
147 let token = CancellationToken::new();
148
149 let pool = ConnectionManager::builder(self.origin.clone())
150 .client_id(Some(env!("CARGO_PKG_NAME").into()))
151 .build()
152 .await
153 .inspect(|pool| debug!(?pool))?;
154
155 let request_origin = (
156 MapErrLayer::new(Error::from),
157 RequestPoolLayer::new(pool.clone()),
158 RequestConnectionLayer,
159 FrameBytesLayer,
160 )
161 .into_layer(BytesConnectionService);
162
163 let frame_origin = (
164 MapErrLayer::new(Error::from),
165 FramePoolLayer::new(pool.clone()),
166 FrameConnectionLayer,
167 FrameBytesLayer,
168 )
169 .into_layer(BytesConnectionService);
170
171 let host = String::from(self.advertised_listener.host_str().unwrap_or("localhost"));
172 let port = i32::from(self.advertised_listener.port().unwrap_or(9092));
173
174 let meta = HijackLayer::new(
175 FrameApiKeyMatcher(MetadataRequest::KEY),
176 (
177 FrameRequestLayer::<MetadataRequest>::new(),
178 MapRequestLayer::new(move |request: MetadataRequest| {
179 MetadataRequest::default()
180 .topics(request.topics.map(|topics| {
181 topics
182 .into_iter()
183 .map(|topic| {
184 MetadataRequestTopic::default()
185 .name(topic.name)
186 .topic_id(topic.topic_id.or(Some(NULL_TOPIC_ID)))
187 })
188 .collect()
189 }))
190 .allow_auto_topic_creation(
191 request.allow_auto_topic_creation.or(Some(false)),
192 )
193 .include_cluster_authorized_operations(
194 request.include_cluster_authorized_operations,
195 )
196 .include_topic_authorized_operations(
197 request.include_topic_authorized_operations.or(Some(false)),
198 )
199 }),
200 MapResponseLayer::new(move |response: MetadataResponse| {
201 let brokers = response.brokers.as_ref().map(|brokers| {
202 brokers
203 .iter()
204 .map(|broker| {
205 MetadataResponseBroker::default()
206 .node_id(broker.node_id)
207 .host(host.clone())
208 .port(port)
209 .rack(broker.rack.clone())
210 })
211 .collect()
212 });
213
214 response.brokers(brokers)
215 }),
216 )
217 .into_layer(request_origin.clone()),
218 );
219
220 let host = String::from(self.advertised_listener.host_str().unwrap_or("localhost"));
221
222 let find_coordinator = HijackLayer::new(
223 FrameApiKeyMatcher(FindCoordinatorRequest::KEY),
224 (
225 FrameRequestLayer::<FindCoordinatorRequest>::new(),
226 MapRequestLayer::new(move |request: FindCoordinatorRequest| {
227 FindCoordinatorRequest::default()
228 .key_type(request.key_type)
229 .coordinator_keys(
230 request
231 .coordinator_keys
232 .or(request.key.map(|key| vec![key])),
233 )
234 }),
235 MapResponseLayer::new(move |response: FindCoordinatorResponse| {
236 let coordinators = response.coordinators.as_ref().map(|coordinators| {
237 coordinators
238 .iter()
239 .map(|coordinator| {
240 Coordinator::default()
241 .key(coordinator.key.clone())
242 .error_code(coordinator.error_code)
243 .host(host.clone())
244 .port(port)
245 .node_id(coordinator.node_id)
246 })
247 .collect()
248 });
249
250 let coordinator = response
251 .coordinators
252 .as_deref()
253 .and_then(|coordinators| coordinators.first());
254
255 FindCoordinatorResponse::default()
256 .throttle_time_ms(Some(0))
257 .coordinators(coordinators)
258 .error_code(
259 response
260 .error_code
261 .or(coordinator.map(|coordinator| coordinator.error_code)),
262 )
263 .error_message(
264 response
265 .error_message
266 .or(coordinator
267 .and_then(|coordinator| coordinator.error_message.clone()))
268 .or(Some("NONE".into())),
269 )
270 .host(Some(host.clone()))
271 .port(Some(port))
272 .node_id(
273 response
274 .node_id
275 .or(coordinator.map(|coordinator| coordinator.node_id)),
276 )
277 }),
278 )
279 .into_layer(request_origin.clone()),
280 );
281
282 let produce = HijackLayer::new(
283 FrameApiKeyMatcher(ProduceRequest::KEY),
284 (
285 FrameRequestLayer::<ProduceRequest>::new(),
286 TopicConfigLayer::new(configuration.clone(), request_origin.clone()),
287 )
288 .into_layer(
289 HijackLayer::new(
290 ResourceConfigValueMatcher::new(
291 configuration.clone(),
292 "tansu.batch",
293 "true",
294 ),
295 BatchProduceLayer::new(configuration.clone())
296 .into_layer(request_origin.clone()),
297 )
298 .into_layer(request_origin.clone()),
299 ),
300 );
301
302 let s = (
303 TcpListenerLayer::new(token),
304 TcpContextLayer::default(),
305 TcpBytesLayer::<()>::default(),
306 BytesFrameLayer,
307 meta,
308 produce,
309 find_coordinator,
310 )
311 .into_layer(frame_origin);
312
313 s.serve(Context::with_state(()), listener).await?;
314
315 Ok(())
316 }
317
318 pub async fn main(
319 listener_url: Url,
320 advertised_listener_url: Url,
321 origin_url: Url,
322 otlp_endpoint_url: Option<Url>,
323 ) -> Result<ErrorCode, Error> {
324 let mut set = JoinSet::new();
325
326 let meter_provider = otlp_endpoint_url.map_or(Ok(None), |otlp_endpoint_url| {
327 meter_provider(otlp_endpoint_url, env!("CARGO_PKG_NAME")).map(Some)
328 })?;
329
330 {
331 let proxy = Proxy::new(listener_url, advertised_listener_url, origin_url);
332 _ = set.spawn(async move { proxy.listen().await });
333 }
334
335 loop {
336 if set.join_next().await.is_none() {
337 break;
338 }
339 }
340
341 if let Some(meter_provider) = meter_provider {
342 meter_provider
343 .force_flush()
344 .inspect(|force_flush| debug!(?force_flush))?;
345
346 meter_provider
347 .shutdown()
348 .inspect(|shutdown| debug!(?shutdown))?;
349 }
350
351 Ok(ErrorCode::None)
352 }
353}
354
#[cfg(test)]
mod tests {
    use std::{
        fs::{self, File},
        sync::Arc,
        thread,
    };

    use tansu_sans_io::{
        DescribeConfigsRequest, DescribeConfigsResponse, Frame, Header, ProduceResponse,
    };
    use tansu_service::{FrameService, RequestApiKeyMatcher, ResponseService};
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Installs a thread-local tracing subscriber that writes debug output for this
    /// crate to `../logs/<package>/<thread name>.log`.
    ///
    /// The log directory is created when missing, so the tests do not depend on it
    /// having been prepared by hand. The returned guard restores the previous
    /// subscriber when dropped.
    fn init_tracing() -> Result<DefaultGuard, Error> {
        let filter = EnvFilter::from_default_env()
            .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?);

        let name = thread::current()
            .name()
            .map(String::from)
            .ok_or_else(|| Error::Message(String::from("unnamed thread")))?;

        let directory = format!("../logs/{}", env!("CARGO_PKG_NAME"));
        fs::create_dir_all(&directory)?;
        let writer = File::create(format!("{directory}/{name}.log")).map(Arc::new)?;

        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(filter)
                .with_writer(writer)
                .finish(),
        ))
    }

    /// A produce frame matched by the hijack is answered by the hijacking service,
    /// not the fallback.
    #[tokio::test]
    async fn produce_hijack() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let produce =
            HijackLayer::new(
                FrameApiKeyMatcher(ProduceRequest::KEY),
                FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                    |_ctx: Context<()>, _req: ProduceRequest| {
                        Ok::<_, Error>(
                            ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS),
                        )
                    },
                )),
            )
            .into_layer(FrameRequestLayer::<ProduceRequest>::new().into_layer(
                ResponseService::new(|_ctx: Context<()>, _req: ProduceRequest| {
                    Ok::<_, Error>(ProduceResponse::default())
                }),
            ));

        let frame = produce
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 12,
                        correlation_id: 12321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        let response = ProduceResponse::try_from(frame.body)?;
        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// A typed request matched by API key is answered by the hijacking service.
    #[tokio::test]
    async fn request_api_matcher() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            RequestApiKeyMatcher(ProduceRequest::KEY),
            ResponseService::new(|_, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            }),
        )
        .into_layer(ResponseService::new(|_, _req: ProduceRequest| {
            Ok::<_, Error>(ProduceResponse::default())
        }));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// A produce frame routed through the topic config layer reaches the inner
    /// produce service and its response is returned as a frame.
    #[tokio::test]
    async fn frame_topic_config() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            (
                FrameRequestLayer::<ProduceRequest>::new(),
                TopicConfigLayer::new(
                    configuration.clone(),
                    ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                        Ok::<_, Error>(DescribeConfigsResponse::default())
                    }),
                ),
            )
                .into_layer(ResponseService::new(
                    |_: Context<()>, _req: ProduceRequest| {
                        Ok::<_, Error>(
                            ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS),
                        )
                    },
                )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        // Pin that the response came from the inner produce service, not merely
        // that it decodes as a produce response.
        let response = ProduceResponse::try_from(response.body)?;
        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// The topic config layer passes a produce request through to its inner service.
    #[tokio::test]
    async fn response_topic_config() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = TopicConfigLayer::new(
            configuration,
            ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                Ok::<_, Error>(DescribeConfigsResponse::default())
            }),
        )
        .layer(ResponseService::new(
            |_: Context<()>, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            },
        ));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// Frames are split by API key: produce goes to the hijack, metadata falls
    /// through to the frame service.
    #[tokio::test]
    async fn frame_api_matcher() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                |_: Context<()>, _req: ProduceRequest| Ok::<_, Error>(ProduceResponse::default()),
            )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        assert!(ProduceResponse::try_from(response.body).is_ok());

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: MetadataRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: MetadataRequest::default().into(),
                },
            )
            .await?;

        assert!(MetadataResponse::try_from(response.body).is_ok());

        Ok(())
    }
}