tansu_proxy/
lib.rs

1// Copyright ⓒ 2024-2025 Peter Morgan <peter.james.morgan@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use opentelemetry::{InstrumentationScope, global, metrics::Meter};
16use opentelemetry_semantic_conventions::SCHEMA_URL;
17use rama::{
18    Layer,
19    error::{BoxError, OpaqueError},
20    layer::{HijackLayer, MapResponseLayer},
21    tcp::server::TcpListener,
22};
23use std::{fmt::Debug, net::SocketAddr, ops::Deref, result, sync::LazyLock};
24use tansu_otel::meter_provider;
25use tansu_sans_io::{ErrorCode, metadata_response::MetadataResponseBroker};
26use tansu_service::{
27    api::{
28        ApiKey, ApiKeyVersionLayer,
29        describe_config::{ResourceConfig, ResourceConfigValueMatcher, TopicConfigLayer},
30        metadata::{MetadataIntoApiLayer, MetadataLayer, MetadataResponse},
31        produce::{self, ProduceIntoApiLayer, ProduceLayer},
32    },
33    service::{ApiClient, ApiRequestLayer, ByteLayer, TcpStreamLayer},
34};
35use tokio::{net::lookup_host, task::JoinSet};
36use tracing::debug;
37use url::Url;
38
39use crate::batch::BatchProduceLayer;
40
41mod batch;
42
43pub type Result<T, E = BoxError> = result::Result<T, E>;
44
45pub(crate) static METER: LazyLock<Meter> = LazyLock::new(|| {
46    global::meter_with_scope(
47        InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
48            .with_version(env!("CARGO_PKG_VERSION"))
49            .with_schema_url(SCHEMA_URL)
50            .build(),
51    )
52});
53
54#[derive(Clone, Debug)]
55pub struct Proxy {
56    listener: Url,
57    origin: Url,
58}
59
60async fn host_port(url: &Url) -> Result<SocketAddr> {
61    if let Some(host) = url.host_str()
62        && let Some(port) = url.port()
63    {
64        let mut addresses = lookup_host(format!("{host}:{port}"))
65            .await?
66            .filter(|socket_addr| matches!(socket_addr, SocketAddr::V4(_)));
67
68        if let Some(socket_addr) = addresses.next().inspect(|socket_addr| debug!(?socket_addr)) {
69            return Ok(socket_addr);
70        }
71    }
72
73    Err(OpaqueError::from_display(format!("unknown host: {url}")).into_boxed())
74}
75
76impl Proxy {
77    const METADATA_API_KEY: ApiKey = ApiKey(3);
78
79    const NODE_ID: i32 = 111;
80
81    pub fn new(listener: Url, origin: Url) -> Self {
82        Self { listener, origin }
83    }
84
85    pub async fn listen(&self) -> Result<()> {
86        debug!(%self.listener);
87
88        let configuration = ResourceConfig::default();
89
90        let listener = TcpListener::bind(host_port(&self.listener).await?).await?;
91
92        let origin = host_port(&self.origin)
93            .await
94            .map(Into::into)
95            .map(ApiClient::new)?;
96
97        let host = String::from(self.listener.host_str().unwrap_or("localhost"));
98        let port = i32::from(self.listener.port().unwrap_or(9092));
99
100        let meta = HijackLayer::new(
101            Self::METADATA_API_KEY,
102            (
103                MetadataLayer,
104                MapResponseLayer::new(move |response: MetadataResponse| MetadataResponse {
105                    brokers: Some(vec![
106                        MetadataResponseBroker::default()
107                            .node_id(Self::NODE_ID)
108                            .host(host)
109                            .port(port)
110                            .rack(None),
111                    ]),
112                    ..response
113                }),
114                MetadataIntoApiLayer,
115            )
116                .into_layer(origin.clone()),
117        );
118
119        let produce = HijackLayer::new(
120            produce::API_KEY_VERSION.deref().0,
121            (
122                ApiKeyVersionLayer,
123                ProduceLayer,
124                TopicConfigLayer::new(configuration.clone(), origin.clone()),
125                HijackLayer::new(
126                    ResourceConfigValueMatcher::new(configuration.clone(), "tansu.batch", "true"),
127                    (
128                        BatchProduceLayer::new(configuration.clone()),
129                        ProduceIntoApiLayer,
130                    )
131                        .into_layer(origin.clone()),
132                ),
133                ProduceIntoApiLayer,
134            )
135                .into_layer(origin.clone()),
136        );
137
138        let stack = (TcpStreamLayer, ByteLayer, ApiRequestLayer, meta, produce).into_layer(origin);
139
140        listener.serve(stack).await;
141
142        Ok(())
143    }
144
145    pub async fn main(
146        listener_url: Url,
147        origin_url: Url,
148        otlp_endpoint_url: Option<Url>,
149    ) -> Result<ErrorCode> {
150        let mut set = JoinSet::new();
151
152        let meter_provider = otlp_endpoint_url.map_or(Ok(None), |otlp_endpoint_url| {
153            meter_provider(otlp_endpoint_url, env!("CARGO_PKG_NAME")).map(Some)
154        })?;
155
156        {
157            let proxy = Proxy::new(listener_url, origin_url);
158            _ = set.spawn(async move { proxy.listen().await.unwrap() });
159        }
160
161        loop {
162            if set.join_next().await.is_none() {
163                break;
164            }
165        }
166
167        if let Some(meter_provider) = meter_provider {
168            meter_provider
169                .force_flush()
170                .inspect(|force_flush| debug!(?force_flush))?;
171
172            meter_provider
173                .shutdown()
174                .inspect(|shutdown| debug!(?shutdown))?;
175        }
176
177        Ok(ErrorCode::None)
178    }
179}