tansu_proxy/
lib.rs

1// Copyright ⓒ 2024-2025 Peter Morgan <peter.james.morgan@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use opentelemetry::{InstrumentationScope, global, metrics::Meter};
16use opentelemetry_otlp::ExporterBuildError;
17use opentelemetry_sdk::error::OTelSdkError;
18use opentelemetry_semantic_conventions::SCHEMA_URL;
19use rama::{
20    Context, Layer, Service,
21    layer::{HijackLayer, MapErrLayer, MapResponseLayer},
22};
23use std::{
24    fmt, io,
25    sync::{Arc, LazyLock},
26};
27use tansu_client::{
28    BytesConnectionService, ConnectionManager, FrameConnectionLayer, FramePoolLayer,
29    RequestConnectionLayer, RequestPoolLayer,
30};
31use tansu_otel::meter_provider;
32use tansu_sans_io::{
33    ApiKey, ErrorCode, MetadataRequest, MetadataResponse, ProduceRequest,
34    metadata_response::MetadataResponseBroker,
35};
36use tansu_service::{
37    BytesFrameLayer, FrameApiKeyMatcher, FrameBytesLayer, FrameRequestLayer, TcpBytesLayer,
38    TcpContextLayer, TcpListenerLayer, host_port,
39};
40use tokio::{
41    net::TcpListener,
42    task::{JoinError, JoinSet},
43};
44use tokio_util::sync::CancellationToken;
45use tracing::debug;
46use tracing_subscriber::filter::ParseError;
47use url::Url;
48
49use crate::{
50    produce::BatchProduceLayer,
51    topic::{ResourceConfig, ResourceConfigValue, ResourceConfigValueMatcher, TopicConfigLayer},
52};
53
54mod produce;
55mod topic;
56
57#[derive(Clone, Debug, thiserror::Error)]
58pub enum Error {
59    Client(#[from] tansu_client::Error),
60    ExporterBuild(Arc<ExporterBuildError>),
61    FrameTooBig(usize),
62    Io(Arc<io::Error>),
63    Join(Arc<JoinError>),
64    Otel(#[from] tansu_otel::Error),
65    OtelSdk(Arc<OTelSdkError>),
66    ParseFilter(Arc<ParseError>),
67    Protocol(#[from] tansu_sans_io::Error),
68    ResourceLock {
69        name: String,
70        key: Option<String>,
71        value: Option<ResourceConfigValue>,
72    },
73    Service(#[from] tansu_service::Error),
74    UnknownHost(Url),
75    Message(String),
76}
77
78impl fmt::Display for Error {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        write!(f, "{self:?}")
81    }
82}
83
84impl From<JoinError> for Error {
85    fn from(value: JoinError) -> Self {
86        Self::Join(Arc::new(value))
87    }
88}
89
90impl From<OTelSdkError> for Error {
91    fn from(value: OTelSdkError) -> Self {
92        Self::OtelSdk(Arc::new(value))
93    }
94}
95
96impl From<ExporterBuildError> for Error {
97    fn from(value: ExporterBuildError) -> Self {
98        Self::ExporterBuild(Arc::new(value))
99    }
100}
101
102impl From<ParseError> for Error {
103    fn from(value: ParseError) -> Self {
104        Self::ParseFilter(Arc::new(value))
105    }
106}
107
108impl From<io::Error> for Error {
109    fn from(value: io::Error) -> Self {
110        Self::Io(Arc::new(value))
111    }
112}
113
114pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
115    global::meter_with_scope(
116        InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
117            .with_version(env!("CARGO_PKG_VERSION"))
118            .with_schema_url(SCHEMA_URL)
119            .build(),
120    )
121});
122
123#[derive(Clone, Debug)]
124pub struct Proxy {
125    listener: Url,
126    origin: Url,
127}
128
129impl Proxy {
130    const NODE_ID: i32 = 111;
131
132    pub fn new(listener: Url, origin: Url) -> Self {
133        Self { listener, origin }
134    }
135
136    pub async fn listen(&self) -> Result<(), Error> {
137        debug!(%self.listener);
138
139        let configuration = ResourceConfig::default();
140
141        let listener = TcpListener::bind(host_port(self.listener.clone()).await?).await?;
142
143        let token = CancellationToken::new();
144
145        let pool = ConnectionManager::builder(self.origin.clone())
146            .client_id(Some(env!("CARGO_PKG_NAME").into()))
147            .build()
148            .await
149            .inspect(|pool| debug!(?pool))?;
150
151        let request_origin = (
152            MapErrLayer::new(Error::from),
153            RequestPoolLayer::new(pool.clone()),
154            RequestConnectionLayer,
155            FrameBytesLayer,
156        )
157            .into_layer(BytesConnectionService);
158
159        let frame_origin = (
160            MapErrLayer::new(Error::from),
161            FramePoolLayer::new(pool.clone()),
162            FrameConnectionLayer,
163            FrameBytesLayer,
164        )
165            .into_layer(BytesConnectionService);
166
167        let host = String::from(self.listener.host_str().unwrap_or("localhost"));
168        let port = i32::from(self.listener.port().unwrap_or(9092));
169
170        let meta = HijackLayer::new(
171            FrameApiKeyMatcher(MetadataRequest::KEY),
172            (
173                FrameRequestLayer::<MetadataRequest>::new(),
174                MapResponseLayer::new(move |response: MetadataResponse| {
175                    response.brokers(Some(vec![
176                        MetadataResponseBroker::default()
177                            .node_id(Self::NODE_ID)
178                            .host(host)
179                            .port(port)
180                            .rack(None),
181                    ]))
182                }),
183            )
184                .into_layer(request_origin.clone()),
185        );
186
187        let produce = HijackLayer::new(
188            FrameApiKeyMatcher(ProduceRequest::KEY),
189            (
190                FrameRequestLayer::<ProduceRequest>::new(),
191                TopicConfigLayer::new(configuration.clone(), request_origin.clone()),
192            )
193                .into_layer(
194                    HijackLayer::new(
195                        ResourceConfigValueMatcher::new(
196                            configuration.clone(),
197                            "tansu.batch",
198                            "true",
199                        ),
200                        BatchProduceLayer::new(configuration.clone())
201                            .into_layer(request_origin.clone()),
202                    )
203                    .into_layer(request_origin.clone()),
204                ),
205        );
206
207        let s = (
208            TcpListenerLayer::new(token),
209            TcpContextLayer::default(),
210            TcpBytesLayer::<()>::default(),
211            BytesFrameLayer,
212            meta,
213            produce,
214        )
215            .into_layer(frame_origin);
216
217        s.serve(Context::with_state(()), listener).await?;
218
219        Ok(())
220    }
221
222    pub async fn main(
223        listener_url: Url,
224        origin_url: Url,
225        otlp_endpoint_url: Option<Url>,
226    ) -> Result<ErrorCode, Error> {
227        let mut set = JoinSet::new();
228
229        let meter_provider = otlp_endpoint_url.map_or(Ok(None), |otlp_endpoint_url| {
230            meter_provider(otlp_endpoint_url, env!("CARGO_PKG_NAME")).map(Some)
231        })?;
232
233        {
234            let proxy = Proxy::new(listener_url, origin_url);
235            _ = set.spawn(async move { proxy.listen().await.unwrap() });
236        }
237
238        loop {
239            if set.join_next().await.is_none() {
240                break;
241            }
242        }
243
244        if let Some(meter_provider) = meter_provider {
245            meter_provider
246                .force_flush()
247                .inspect(|force_flush| debug!(?force_flush))?;
248
249            meter_provider
250                .shutdown()
251                .inspect(|shutdown| debug!(?shutdown))?;
252        }
253
254        Ok(ErrorCode::None)
255    }
256}
257
#[cfg(test)]
mod tests {
    use std::{
        fs::{self, File},
        sync::Arc,
        thread,
    };

    use tansu_sans_io::{
        DescribeConfigsRequest, DescribeConfigsResponse, Frame, Header, ProduceResponse,
    };
    use tansu_service::{FrameService, RequestApiKeyMatcher, ResponseService};
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Installs a tracing subscriber as the default for the current thread,
    /// writing this crate's debug output to `../logs/<package>/<thread name>.log`.
    ///
    /// The log directory is created when missing, so the tests no longer fail
    /// with an I/O error on a checkout that lacks it. The returned guard
    /// restores the previous default subscriber when dropped.
    ///
    /// # Errors
    ///
    /// Fails if the filter directive does not parse, if the current thread has
    /// no name, or if the log directory or file cannot be created.
    fn init_tracing() -> Result<DefaultGuard, Error> {
        let filter = EnvFilter::from_default_env()
            .add_directive(format!("{}=debug", env!("CARGO_CRATE_NAME")).parse()?);

        // The log file is named after the current thread, which must be named.
        let name = thread::current()
            .name()
            .map(String::from)
            .ok_or(Error::Message(String::from("unnamed thread")))?;

        let directory = format!("../logs/{}", env!("CARGO_PKG_NAME"));
        fs::create_dir_all(&directory)?;

        let writer = File::create(format!("{directory}/{name}.log")).map(Arc::new)?;

        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(filter)
                .with_writer(writer)
                .finish(),
        ))
    }

    /// A produce frame matching the hijack's API key is answered by the
    /// hijacking service, not the inner one.
    #[tokio::test]
    async fn produce_hijack() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let produce = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                |_ctx: Context<()>, _req: ProduceRequest| {
                    Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
                },
            )),
        )
        .into_layer(FrameRequestLayer::<ProduceRequest>::new().into_layer(
            ResponseService::new(|_ctx: Context<()>, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default())
            }),
        ));

        let frame = produce
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 12,
                        correlation_id: 12321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        let response = ProduceResponse::try_from(frame.body)?;
        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// A typed produce request matching the API key is answered by the
    /// hijacking service.
    #[tokio::test]
    async fn request_api_matcher() -> Result<(), Error> {
        let _guard = init_tracing()?;

        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            RequestApiKeyMatcher(ProduceRequest::KEY),
            ResponseService::new(|_, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            }),
        )
        .into_layer(ResponseService::new(|_, _req: ProduceRequest| {
            Ok::<_, Error>(ProduceResponse::default())
        }));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// A produce frame routed through `TopicConfigLayer` comes back as a
    /// produce response rather than the inner frame service's metadata response.
    #[tokio::test]
    async fn frame_topic_config() -> Result<(), Error> {
        let _guard = init_tracing()?;

        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            (
                FrameRequestLayer::<ProduceRequest>::new(),
                TopicConfigLayer::new(
                    configuration.clone(),
                    ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                        Ok::<_, Error>(DescribeConfigsResponse::default())
                    }),
                ),
            )
                .into_layer(ResponseService::new(
                    |_: Context<()>, _req: ProduceRequest| {
                        Ok::<_, Error>(
                            ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS),
                        )
                    },
                )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        assert!(ProduceResponse::try_from(response.body).is_ok());

        Ok(())
    }

    /// `TopicConfigLayer` returns the inner produce service's response with
    /// its throttle time intact.
    #[tokio::test]
    async fn response_topic_config() -> Result<(), Error> {
        let configuration = ResourceConfig::default();
        const THROTTLE_TIME_MS: Option<i32> = Some(43234);

        let service = TopicConfigLayer::new(
            configuration,
            ResponseService::new(|_: Context<()>, _req: DescribeConfigsRequest| {
                Ok::<_, Error>(DescribeConfigsResponse::default())
            }),
        )
        .layer(ResponseService::new(
            |_: Context<()>, _req: ProduceRequest| {
                Ok::<_, Error>(ProduceResponse::default().throttle_time_ms(THROTTLE_TIME_MS))
            },
        ));

        let response = service
            .serve(Context::default(), ProduceRequest::default())
            .await?;

        assert_eq!(THROTTLE_TIME_MS, response.throttle_time_ms);

        Ok(())
    }

    /// Produce frames are hijacked, while a metadata frame falls through to the
    /// inner frame service.
    #[tokio::test]
    async fn frame_api_matcher() -> Result<(), Error> {
        let service = HijackLayer::new(
            FrameApiKeyMatcher(ProduceRequest::KEY),
            FrameRequestLayer::<ProduceRequest>::new().into_layer(ResponseService::new(
                |_: Context<()>, _req: ProduceRequest| Ok::<_, Error>(ProduceResponse::default()),
            )),
        )
        .into_layer(FrameService::new(|_: Context<()>, _req: Frame| {
            Ok::<_, Error>(Frame {
                size: 0,
                header: Header::Response {
                    correlation_id: 12321,
                },
                body: MetadataResponse::default().into(),
            })
        }));

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: ProduceRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: ProduceRequest::default().into(),
                },
            )
            .await?;

        assert!(ProduceResponse::try_from(response.body).is_ok());

        let response = service
            .serve(
                Context::default(),
                Frame {
                    size: 0,
                    header: Header::Request {
                        api_key: MetadataRequest::KEY,
                        api_version: 123,
                        correlation_id: 321,
                        client_id: Some("abc".into()),
                    },
                    body: MetadataRequest::default().into(),
                },
            )
            .await?;

        assert!(MetadataResponse::try_from(response.body).is_ok());

        Ok(())
    }
}