tower_otel/trace/
http.rs

1//! Middleware that adds tracing to a [`Service`] that handles HTTP requests.
2
3use std::{
4    fmt::Display,
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10use http::{Request, Response};
11use pin_project::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14use tracing::{Level, Span};
15use tracing_opentelemetry::OpenTelemetrySpanExt;
16
17use crate::{
18    trace::{extractor::HeaderExtractor, injector::HeaderInjector},
19    util,
20};
21
/// Role played by the instrumented service with respect to the request a [`Span`] describes.
#[derive(Debug, Clone, Copy)]
enum SpanKind {
    /// The service sends the request to some remote service.
    Client,
    /// The service receives the request and handles it.
    Server,
}
30
31/// [`Layer`] that adds tracing to a [`Service`] that handles HTTP requests.
32#[derive(Clone, Debug)]
33pub struct HttpLayer {
34    level: Level,
35    kind: SpanKind,
36}
37
38impl HttpLayer {
39    /// [`Span`] are constructed at the given level from server side.
40    pub fn server(level: Level) -> Self {
41        Self {
42            level,
43            kind: SpanKind::Server,
44        }
45    }
46
47    /// [`Span`] are constructed at the given level from client side.
48    pub fn client(level: Level) -> Self {
49        Self {
50            level,
51            kind: SpanKind::Client,
52        }
53    }
54}
55
56impl<S> Layer<S> for HttpLayer {
57    type Service = Http<S>;
58
59    fn layer(&self, inner: S) -> Self::Service {
60        Http {
61            inner,
62            level: self.level,
63            kind: self.kind,
64        }
65    }
66}
67
68/// Middleware that adds tracing to a [`Service`] that handles HTTP requests.
69#[derive(Clone, Debug)]
70pub struct Http<S> {
71    inner: S,
72    level: Level,
73    kind: SpanKind,
74}
75
76impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Http<S>
77where
78    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
79    S::Error: Display,
80{
81    type Response = S::Response;
82    type Error = S::Error;
83    type Future = ResponseFuture<S::Future>;
84
85    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.inner.poll_ready(cx)
87    }
88
89    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
90        let span = make_request_span(self.level, self.kind, &mut req);
91        let inner = {
92            let _enter = span.enter();
93            self.inner.call(req)
94        };
95
96        ResponseFuture {
97            inner,
98            span,
99            kind: self.kind,
100        }
101    }
102}
103
104/// Response future for [`Http`].
105#[pin_project]
106pub struct ResponseFuture<F> {
107    #[pin]
108    inner: F,
109    span: Span,
110    kind: SpanKind,
111}
112
113impl<F, ResBody, E> Future for ResponseFuture<F>
114where
115    F: Future<Output = Result<Response<ResBody>, E>>,
116    E: Display,
117{
118    type Output = Result<Response<ResBody>, E>;
119
120    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121        let this = self.project();
122        let _enter = this.span.enter();
123
124        match ready!(this.inner.poll(cx)) {
125            Ok(response) => {
126                record_response(this.span, *this.kind, &response);
127                Poll::Ready(Ok(response))
128            }
129            Err(err) => {
130                record_error(this.span, &err);
131                Poll::Ready(Err(err))
132            }
133        }
134    }
135}
136
137/// String representation of span kind
138fn span_kind(kind: SpanKind) -> &'static str {
139    match kind {
140        SpanKind::Client => "client",
141        SpanKind::Server => "server",
142    }
143}
144
145/// Creates a new [`Span`] for the given request.
146fn make_request_span<B>(level: Level, kind: SpanKind, request: &mut Request<B>) -> Span {
147    macro_rules! make_span {
148        ($level:expr) => {{
149            use tracing::field::Empty;
150
151            tracing::span!(
152                $level,
153                "HTTP",
154                "client.address" = Empty,
155                "client.port" = Empty,
156                "error.message" = Empty,
157                "http.request.method" = util::http_method(request.method()),
158                "http.response.status_code" = Empty,
159                "network.protocol.name" = "http",
160                "network.protocol.version" = util::http_version(request.version()),
161                "otel.kind" = span_kind(kind),
162                "otel.status_code" = Empty,
163                "server.address" = Empty,
164                "server.port" = Empty,
165                "url.full" = Empty,
166                "url.path" = request.uri().path(),
167                "url.query" = Empty,
168                "url.scheme" = Empty,
169            )
170        }};
171    }
172
173    let span = match level {
174        Level::ERROR => make_span!(Level::ERROR),
175        Level::WARN => make_span!(Level::WARN),
176        Level::INFO => make_span!(Level::INFO),
177        Level::DEBUG => make_span!(Level::DEBUG),
178        Level::TRACE => make_span!(Level::TRACE),
179    };
180
181    for (header_name, header_value) in request.headers().iter() {
182        if let Ok(attribute_value) = header_value.to_str() {
183            let attribute_name = format!("http.request.header.{}", header_name);
184            span.set_attribute(attribute_name, attribute_value.to_owned());
185        }
186    }
187
188    if let Some(query) = request.uri().query() {
189        span.record("url.query", query);
190    }
191
192    match kind {
193        SpanKind::Client => {
194            span.record("url.full", tracing::field::display(request.uri()));
195
196            let util::HttpRequestAttributes {
197                url_scheme,
198                server_address,
199                server_port,
200            } = util::HttpRequestAttributes::from_sent_request(request);
201
202            if let Some(server_address) = server_address {
203                span.record("server.address", server_address);
204            }
205            if let Some(server_port) = server_port {
206                span.record("server.port", server_port);
207            }
208            if let Some(url_scheme) = url_scheme {
209                span.record("url.scheme", url_scheme);
210            }
211
212            let context = span.context();
213            opentelemetry::global::get_text_map_propagator(|injector| {
214                injector.inject_context(&context, &mut HeaderInjector(request.headers_mut()));
215            });
216        }
217        SpanKind::Server => {
218            if let Some(http_route) = util::http_route(request) {
219                span.record("http.route", http_route);
220            }
221            if let Some(client_address) = util::client_address(request) {
222                let ip = client_address.ip();
223                span.record("client.address", tracing::field::display(ip));
224                span.record("client.port", client_address.port());
225            }
226
227            let util::HttpRequestAttributes {
228                url_scheme,
229                server_address,
230                server_port,
231            } = util::HttpRequestAttributes::from_recv_request(request);
232
233            if let Some(server_address) = server_address {
234                span.record("server.address", server_address);
235            }
236            if let Some(server_port) = server_port {
237                span.record("server.port", server_port);
238            }
239            if let Some(url_scheme) = url_scheme {
240                span.record("url.scheme", url_scheme);
241            }
242
243            let context = opentelemetry::global::get_text_map_propagator(|extractor| {
244                extractor.extract(&HeaderExtractor(request.headers_mut()))
245            });
246            if let Err(err) = span.set_parent(context) {
247                tracing::warn!("Failed to set parent span: {err}");
248            }
249        }
250    }
251
252    span
253}
254
255/// Records fields associated to the response.
256fn record_response<B>(span: &Span, kind: SpanKind, response: &Response<B>) {
257    span.record(
258        "http.response.status_code",
259        response.status().as_u16() as i64,
260    );
261
262    for (header_name, header_value) in response.headers().iter() {
263        if let Ok(attribute_value) = header_value.to_str() {
264            let attribute_name = format!("http.response.header.{}", header_name);
265            span.set_attribute(attribute_name, attribute_value.to_owned());
266        }
267    }
268
269    if let SpanKind::Client = kind {
270        if response.status().is_client_error() {
271            span.record("otel.status_code", "ERROR");
272        }
273    }
274    if response.status().is_server_error() {
275        span.record("otel.status_code", "ERROR");
276    }
277}
278
279/// Records the error message.
280fn record_error<E: Display>(span: &Span, err: &E) {
281    span.record("otel.status_code", "ERROR");
282    span.record("error.message", err.to_string());
283}