tansu_proxy/
lib.rs

1// Copyright ⓒ 2024-2025 Peter Morgan <peter.james.morgan@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use opentelemetry::{InstrumentationScope, global, metrics::Meter};
16use opentelemetry_otlp::ExporterBuildError;
17use opentelemetry_sdk::error::OTelSdkError;
18use opentelemetry_semantic_conventions::SCHEMA_URL;
19use rama::{
20    Context, Layer, Service,
21    layer::{HijackLayer, MapErrLayer, MapRequestLayer, MapResponseLayer},
22};
23use std::{
24    fmt, io,
25    sync::{Arc, LazyLock},
26};
27use tansu_client::{
28    BytesConnectionService, ConnectionManager, FrameConnectionLayer, FramePoolLayer,
29    RequestConnectionLayer, RequestPoolLayer,
30};
31use tansu_otel::meter_provider;
32use tansu_sans_io::{
33    ApiKey, ErrorCode, FindCoordinatorRequest, FindCoordinatorResponse, MetadataRequest,
34    MetadataResponse, NULL_TOPIC_ID, ProduceRequest, find_coordinator_response::Coordinator,
35    metadata_request::MetadataRequestTopic, metadata_response::MetadataResponseBroker,
36};
37use tansu_service::{
38    BytesFrameLayer, FrameApiKeyMatcher, FrameBytesLayer, FrameRequestLayer, TcpBytesLayer,
39    TcpContextLayer, TcpListenerLayer, host_port,
40};
41use tokio::{
42    net::TcpListener,
43    task::{JoinError, JoinSet},
44};
45use tokio_util::sync::CancellationToken;
46use tracing::debug;
47use tracing_subscriber::filter::ParseError;
48use url::Url;
49
50use crate::{
51    produce::BatchProduceLayer,
52    topic::{ResourceConfig, ResourceConfigValue, ResourceConfigValueMatcher, TopicConfigLayer},
53};
54
55mod produce;
56mod topic;
57
58#[derive(Clone, Debug, thiserror::Error)]
59pub enum Error {
60    Client(#[from] tansu_client::Error),
61    ExporterBuild(Arc<ExporterBuildError>),
62    FrameTooBig(usize),
63    Io(Arc<io::Error>),
64    Join(Arc<JoinError>),
65    Otel(#[from] tansu_otel::Error),
66    OtelSdk(Arc<OTelSdkError>),
67    ParseFilter(Arc<ParseError>),
68    Protocol(#[from] tansu_sans_io::Error),
69    ResourceLock {
70        name: String,
71        key: Option<String>,
72        value: Option<ResourceConfigValue>,
73    },
74    Service(#[from] tansu_service::Error),
75    UnknownHost(Url),
76    Message(String),
77}
78
79impl fmt::Display for Error {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        write!(f, "{self:?}")
82    }
83}
84
85impl From<JoinError> for Error {
86    fn from(value: JoinError) -> Self {
87        Self::Join(Arc::new(value))
88    }
89}
90
91impl From<OTelSdkError> for Error {
92    fn from(value: OTelSdkError) -> Self {
93        Self::OtelSdk(Arc::new(value))
94    }
95}
96
97impl From<ExporterBuildError> for Error {
98    fn from(value: ExporterBuildError) -> Self {
99        Self::ExporterBuild(Arc::new(value))
100    }
101}
102
103impl From<ParseError> for Error {
104    fn from(value: ParseError) -> Self {
105        Self::ParseFilter(Arc::new(value))
106    }
107}
108
109impl From<io::Error> for Error {
110    fn from(value: io::Error) -> Self {
111        Self::Io(Arc::new(value))
112    }
113}
114
115pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
116    global::meter_with_scope(
117        InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
118            .with_version(env!("CARGO_PKG_VERSION"))
119            .with_schema_url(SCHEMA_URL)
120            .build(),
121    )
122});
123
124#[derive(Clone, Debug)]
125pub struct Proxy {
126    listener: Url,
127    advertised_listener: Url,
128    origin: Url,
129}
130
131impl Proxy {
132    pub fn new(listener: Url, advertised_listener: Url, origin: Url) -> Self {
133        Self {
134            listener,
135            advertised_listener,
136            origin,
137        }
138    }
139
140    pub async fn listen(&self) -> Result<(), Error> {
141        debug!(%self.listener, %self.advertised_listener, %self.origin);
142
143        let configuration = ResourceConfig::default();
144
145        let listener = TcpListener::bind(host_port(self.listener.clone()).await?).await?;
146
147        let token = CancellationToken::new();
148
149        let pool = ConnectionManager::builder(self.origin.clone())
150            .client_id(Some(env!("CARGO_PKG_NAME").into()))
151            .build()
152            .await
153            .inspect(|pool| debug!(?pool))?;
154
155        let request_origin = (
156            MapErrLayer::new(Error::from),
157            RequestPoolLayer::new(pool.clone()),
158            RequestConnectionLayer,
159            FrameBytesLayer,
160        )
161            .into_layer(BytesConnectionService);
162
163        let frame_origin = (
164            MapErrLayer::new(Error::from),
165            FramePoolLayer::new(pool.clone()),
166            FrameConnectionLayer,
167            FrameBytesLayer,
168        )
169            .into_layer(BytesConnectionService);
170
171        let host = String::from(self.advertised_listener.host_str().unwrap_or("localhost"));
172        let port = i32::from(self.advertised_listener.port().unwrap_or(9092));
173
174        let meta = HijackLayer::new(
175            FrameApiKeyMatcher(MetadataRequest::KEY),
176            (
177                FrameRequestLayer::<MetadataRequest>::new(),
178                MapRequestLayer::new(move |request: MetadataRequest| {
179                    MetadataRequest::default()
180                        .topics(request.topics.map(|topics| {
181                            topics
182                                .into_iter()
183                                .map(|topic| {
184                                    MetadataRequestTopic::default()
185                                        .name(topic.name)
186                                        .topic_id(topic.topic_id.or(Some(NULL_TOPIC_ID)))
187                                })
188                                .collect()
189                        }))
190                        .allow_auto_topic_creation(
191                            request.allow_auto_topic_creation.or(Some(false)),
192                        )
193                        .include_cluster_authorized_operations(
194                            request.include_cluster_authorized_operations,
195                        )
196                        .include_topic_authorized_operations(
197                            request.include_topic_authorized_operations.or(Some(false)),
198                        )
199                }),
200                MapResponseLayer::new(move |response: MetadataResponse| {
201                    let brokers = response.brokers.as_ref().map(|brokers| {
202                        brokers
203                            .iter()
204                            .map(|broker| {
205                                MetadataResponseBroker::default()
206                                    .node_id(broker.node_id)
207                                    .host(host.clone())
208                                    .port(port)
209                                    .rack(broker.rack.clone())
210                            })
211                            .collect()
212                    });
213
214                    response.brokers(brokers)
215                }),
216            )
217                .into_layer(request_origin.clone()),
218        );
219
220        let host = String::from(self.advertised_listener.host_str().unwrap_or("localhost"));
221
222        let find_coordinator = HijackLayer::new(
223            FrameApiKeyMatcher(FindCoordinatorRequest::KEY),
224            (
225                FrameRequestLayer::<FindCoordinatorRequest>::new(),
226                MapRequestLayer::new(move |request: FindCoordinatorRequest| {
227                    FindCoordinatorRequest::default()
228                        .key_type(request.key_type)
229                        .coordinator_keys(
230                            request
231                                .coordinator_keys
232                                .or(request.key.map(|key| vec![key])),
233                        )
234                }),
235                MapResponseLayer::new(move |response: FindCoordinatorResponse| {
236                    let coordinators = response.coordinators.as_ref().map(|coordinators| {
237                        coordinators
238                            .iter()
239                            .map(|coordinator| {
240                                Coordinator::default()
241                                    .key(coordinator.key.clone())
242                                    .error_code(coordinator.error_code)
243                                    .host(host.clone())
244                                    .port(port)
245                                    .node_id(coordinator.node_id)
246                            })
247                            .collect()
248                    });
249
250                    let coordinator = response
251                        .coordinators
252                        .as_deref()
253                        .and_then(|coordinators| coordinators.first());
254
255                    FindCoordinatorResponse::default()
256                        .throttle_time_ms(Some(0))
257                        .coordinators(coordinators)
258                        .error_code(
259                            response
260                                .error_code
261                                .or(coordinator.map(|coordinator| coordinator.error_code)),
262                        )
263                        .error_message(
264                            response
265                                .error_message
266                                .or(coordinator
267                                    .and_then(|coordinator| coordinator.error_message.clone()))
268                                .or(Some("NONE".into())),
269                        )
270                        .host(Some(host.clone()))
271                        .port(Some(port))
272                        .node_id(
273                            response
274                                .node_id
275                                .or(coordinator.map(|coordinator| coordinator.node_id)),
276                        )
277                }),
278            )
279                .into_layer(request_origin.clone()),
280        );
281
282        let produce = HijackLayer::new(
283            FrameApiKeyMatcher(ProduceRequest::KEY),
284            (
285                FrameRequestLayer::<ProduceRequest>::new(),
286                TopicConfigLayer::new(configuration.clone(), request_origin.clone()),
287            )
288                .into_layer(
289                    HijackLayer::new(
290                        ResourceConfigValueMatcher::new(
291                            configuration.clone(),
292                            "tansu.batch",
293                            "true",
294                        ),
295                        BatchProduceLayer::new(configuration.clone())
296                            .into_layer(request_origin.clone()),
297                    )
298                    .into_layer(request_origin.clone()),
299                ),
300        );
301
302        let s = (
303            TcpListenerLayer::new(token),
304            TcpContextLayer::default(),
305            TcpBytesLayer::<()>::default(),
306            BytesFrameLayer,
307            meta,
308            produce,
309            find_coordinator,
310        )
311            .into_layer(frame_origin);
312
313        s.serve(Context::with_state(()), listener).await?;
314
315        Ok(())
316    }
317
318    pub async fn main(
319        listener_url: Url,
320        advertised_listener_url: Url,
321        origin_url: Url,
322        otlp_endpoint_url: Option<Url>,
323    ) -> Result<ErrorCode, Error> {
324        let mut set = JoinSet::new();
325
326        let meter_provider = otlp_endpoint_url.map_or(Ok(None), |otlp_endpoint_url| {
327            meter_provider(otlp_endpoint_url, env!("CARGO_PKG_NAME")).map(Some)
328        })?;
329
330        {
331            let proxy = Proxy::new(listener_url, advertised_listener_url, origin_url);
332            _ = set.spawn(async move { proxy.listen().await });
333        }
334
335        loop {
336            if set.join_next().await.is_none() {
337                break;
338            }
339        }
340
341        if let Some(meter_provider) = meter_provider {
342            meter_provider
343                .force_flush()
344                .inspect(|force_flush| debug!(?force_flush))?;
345
346            meter_provider
347                .shutdown()
348                .inspect(|shutdown| debug!(?shutdown))?;
349        }
350
351        Ok(ErrorCode::None)
352    }
353}
354
#[cfg(test)]
mod tests {
    use std::{
        fs::{self, File},
        sync::Arc,
        thread,
    };

    use tansu_sans_io::{
        DescribeConfigsRequest, DescribeConfigsResponse, Frame, Header, ProduceResponse,
    };
    use tansu_service::{FrameService, RequestApiKeyMatcher, ResponseService};
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Installs a thread-local tracing subscriber that writes debug output
    /// for this crate to `../logs/<package>/<thread name>.log`.
    ///
    /// The log directory is created when missing so that the tests also run
    /// on a fresh checkout. Tracing stays active until the returned guard is
    /// dropped.
    ///
    /// # Errors
    ///
    /// Fails when the filter directive does not parse, when the current
    /// thread has no name, or when the log directory or file cannot be
    /// created.
    fn init_tracing() -> Result<DefaultGuard, Error> {
        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(
                    EnvFilter::from_default_env()
                        .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?),
                )
                .with_writer(
                    thread::current()
                        .name()
                        .ok_or_else(|| Error::Message(String::from("unnamed thread")))
                        .and_then(|name| {
                            let directory = format!("../logs/{}", env!("CARGO_PKG_NAME"));
                            fs::create_dir_all(&directory)?;

                            File::create(format!("{directory}/{name}.log")).map_err(Error::from)
                        })
                        .map(Arc::new)?,
                )
                .finish(),
        ))
    }

    /// A produce frame is answered by the hijacking service, recognisable by
    /// its throttle time, rather than by the fallback service.
    #[tokio::test]
    async fn produce_hijack() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let produce =
            HijackLayer::new(
                FrameApiKeyMatcher(ProduceRequest::KEY),
                FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                    |_ctx: Context<()>, _req: ProduceRequest| {
                        Ok::<_, Error>(
                            ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS),
                        )
                    },
                )),
            )
            .into_layer(FrameRequestLayer::<ProduceRequest>::new().into_layer(
                ResponseService::new(|_ctx: Context<()>, _req: ProduceRequest| {
                    Ok::<_, Error>(ProduceResponse::default())
                }),
            ));

        let frame = produce
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 12,
                        correlation_id: 12321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        let response = ProduceResponse::try_from(frame.body)?;
        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// A typed produce request matched by `RequestApiKeyMatcher` is answered
    /// by the hijacking service.
    #[tokio::test]
    async fn request_api_matcher() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            RequestApiKeyMatcher(ProduceRequest::KEY),
            ResponseService::new(|_, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            }),
        )
        .into_layer(ResponseService::new(|_, _req: ProduceRequest| {
            Ok::<_, Error>(ProduceResponse::default())
        }));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// A produce frame hijacked into the topic configuration layer yields a
    /// produce response, not the metadata response of the fallback service.
    #[tokio::test]
    async fn frame_topic_config() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            (
                FrameRequestLayer::<ProduceRequest>::new(),
                TopicConfigLayer::new(
                    configuration.clone(),
                    ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                        Ok::<_, Error>(DescribeConfigsResponse::default())
                    }),
                ),
            )
                .into_layer(ResponseService::new(
                    |_: Context<()>, _req: ProduceRequest| {
                        Ok::<_, Error>(
                            ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS),
                        )
                    },
                )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        assert!(ProduceResponse::try_from(response.body).is_ok());

        Ok(())
    }

    /// The topic configuration layer hands back the produce response of the
    /// service it wraps.
    #[tokio::test]
    async fn response_topic_config() -> Result<(), Error> {
        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = TopicConfigLayer::new(
            configuration,
            ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                Ok::<_, Error>(DescribeConfigsResponse::default())
            }),
        )
        .layer(ResponseService::new(
            |_: Context<()>, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            },
        ));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// Produce frames are hijacked while metadata frames fall through to the
    /// wrapped frame service.
    #[tokio::test]
    async fn frame_api_matcher() -> Result<(), Error> {
        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                |_: Context<()>, _req: ProduceRequest| Ok::<_, Error>(ProduceResponse::default()),
            )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        // Matching API key: served by the hijacking produce service.
        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        assert!(ProduceResponse::try_from(response.body).is_ok());

        // Any other API key: served by the fallback frame service.
        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: MetadataRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: MetadataRequest::default().into(),
                },
            )
            .await?;

        assert!(MetadataResponse::try_from(response.body).is_ok());

        Ok(())
    }
}