spring_opentelemetry/trace/
http.rs

//! Middleware that adds tracing to a [`Service`] that handles HTTP requests.
//!
//! Span attributes follow the OpenTelemetry semantic conventions for HTTP spans:
//! <https://opentelemetry.io/docs/specs/semconv/http/http-spans/>
3
4use super::SpanKind;
5use crate::util::http as http_util;
6use http::{HeaderName, HeaderValue, Request, Response};
7use opentelemetry::trace::TraceContextExt;
8use opentelemetry_http::{HeaderExtractor, HeaderInjector};
9use opentelemetry_semantic_conventions::{
10    attribute::{EXCEPTION_MESSAGE, HTTP_RESPONSE_STATUS_CODE, OTEL_STATUS_CODE},
11    trace::HTTP_ROUTE,
12};
13use pin_project::pin_project;
14use std::{
15    fmt::Display,
16    future::Future,
17    pin::Pin,
18    task::{ready, Context, Poll},
19};
20use tower::{Layer, Service};
21use tracing::{Level, Span};
22use tracing_opentelemetry::OpenTelemetrySpanExt;
23
24/// [`Layer`] that adds tracing to a [`Service`] that handles HTTP requests.
25#[derive(Clone, Debug)]
26pub struct HttpLayer {
27    level: Level,
28    kind: SpanKind,
29    export_trace_id: bool,
30}
31
32impl HttpLayer {
33    /// [`Span`]s are constructed at the given level from server side.
34    pub fn server(level: Level) -> Self {
35        Self {
36            level,
37            kind: SpanKind::Server,
38            export_trace_id: false,
39        }
40    }
41
42    /// [`Span`]s are constructed at the given level from client side.
43    pub fn client(level: Level) -> Self {
44        Self {
45            level,
46            kind: SpanKind::Client,
47            export_trace_id: false,
48        }
49    }
50
51    pub fn export_trace_id(mut self, export_trace_id: bool) -> Self {
52        self.export_trace_id = export_trace_id;
53        self
54    }
55}
56
57impl<S> Layer<S> for HttpLayer {
58    type Service = HttpService<S>;
59
60    fn layer(&self, inner: S) -> Self::Service {
61        HttpService {
62            inner,
63            level: self.level,
64            kind: self.kind,
65            export_trace_id: self.export_trace_id,
66        }
67    }
68}
69
70/// Middleware that adds tracing to a [`Service`] that handles HTTP requests.
71#[derive(Clone, Debug)]
72pub struct HttpService<S> {
73    inner: S,
74    level: Level,
75    kind: SpanKind,
76    export_trace_id: bool,
77}
78
79impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for HttpService<S>
80where
81    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
82    S::Error: Display,
83{
84    type Response = S::Response;
85    type Error = S::Error;
86    type Future = ResponseFuture<S::Future>;
87
88    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89        self.inner.poll_ready(cx)
90    }
91
92    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
93        let span = self.make_request_span(&mut req);
94        let inner = {
95            let _enter = span.enter();
96            self.inner.call(req)
97        };
98
99        ResponseFuture {
100            inner,
101            span,
102            kind: self.kind,
103            export_trace_id: self.export_trace_id,
104        }
105    }
106}
107
108impl<S> HttpService<S> {
109    /// Creates a new [`Span`] for the given request.
110    fn make_request_span<B>(&self, request: &mut Request<B>) -> Span {
111        let Self { level, kind, .. } = self;
112
113        // attribute::HTTP_REQUEST_METHOD
114        macro_rules! make_span {
115            ($level:expr) => {{
116                use tracing::field::Empty;
117
118                tracing::span!(
119                    $level,
120                    "HTTP",
121                    "exception.message" = Empty,
122                    "http.request.method" = tracing::field::display(request.method()),
123                    "http.response.status_code" = Empty,
124                    "network.protocol.name" = "http",
125                    "network.protocol.version" = tracing::field::debug(request.version()),
126                    "otel.kind" = tracing::field::debug(kind),
127                    "otel.status_code" = Empty,
128                    "url.full" = tracing::field::display(request.uri()),
129                    "url.path" = request.uri().path(),
130                    "url.query" = request.uri().query(),
131                )
132            }};
133        }
134
135        let span = match *level {
136            Level::ERROR => make_span!(Level::ERROR),
137            Level::WARN => make_span!(Level::WARN),
138            Level::INFO => make_span!(Level::INFO),
139            Level::DEBUG => make_span!(Level::DEBUG),
140            Level::TRACE => make_span!(Level::TRACE),
141        };
142
143        let capture_request_headers = kind.capture_request_headers();
144
145        for (header_name, header_value) in request.headers().iter() {
146            let header_name = header_name.as_str().to_lowercase();
147            if capture_request_headers.contains(&header_name) {
148                if let Ok(attribute_value) = header_value.to_str() {
149                    // attribute::HTTP_REQUEST_HEADER
150                    let attribute_name = format!("http.request.header.{header_name}");
151                    span.set_attribute(attribute_name, attribute_value.to_owned());
152                }
153            }
154        }
155
156        match kind {
157            SpanKind::Client => {
158                let context = span.context();
159                opentelemetry::global::get_text_map_propagator(|injector| {
160                    injector.inject_context(&context, &mut HeaderInjector(request.headers_mut()));
161                });
162            }
163            SpanKind::Server => {
164                if let Some(http_route) = http_util::http_route(request) {
165                    span.record(HTTP_ROUTE, http_route);
166                }
167                let context = opentelemetry::global::get_text_map_propagator(|extractor| {
168                    extractor.extract(&HeaderExtractor(request.headers()))
169                });
170                span.set_parent(context);
171            }
172        }
173
174        span
175    }
176}
177
178/// Response future for [`Http`].
179#[pin_project]
180pub struct ResponseFuture<F> {
181    #[pin]
182    inner: F,
183    span: Span,
184    kind: SpanKind,
185    export_trace_id: bool,
186}
187
188impl<F, ResBody, E> Future for ResponseFuture<F>
189where
190    F: Future<Output = Result<Response<ResBody>, E>>,
191    E: Display,
192{
193    type Output = Result<Response<ResBody>, E>;
194
195    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
196        let this = self.project();
197        let _enter = this.span.enter();
198
199        match ready!(this.inner.poll(cx)) {
200            Ok(mut response) => {
201                Self::record_response(this.span, *this.kind, &response);
202                if *this.export_trace_id {
203                    let trace_id = this
204                        .span
205                        .context()
206                        .span()
207                        .span_context()
208                        .trace_id()
209                        .to_string();
210                    if let Ok(value) = HeaderValue::from_str(&trace_id) {
211                        response
212                            .headers_mut()
213                            .insert(HeaderName::from_static("x-trace-id"), value);
214                    }
215                }
216                Poll::Ready(Ok(response))
217            }
218            Err(err) => {
219                Self::record_error(this.span, &err);
220                Poll::Ready(Err(err))
221            }
222        }
223    }
224}
225
226impl<F, ResBody, E> ResponseFuture<F>
227where
228    F: Future<Output = Result<Response<ResBody>, E>>,
229    E: Display,
230{
231    /// Records fields associated to the response.
232    fn record_response<B>(span: &Span, kind: SpanKind, response: &Response<B>) {
233        span.record(HTTP_RESPONSE_STATUS_CODE, response.status().as_u16() as i64);
234
235        let capture_response_headers = kind.capture_response_headers();
236
237        for (header_name, header_value) in response.headers().iter() {
238            let header_name = header_name.as_str().to_lowercase();
239            if capture_response_headers.contains(&header_name) {
240                if let Ok(attribute_value) = header_value.to_str() {
241                    let attribute_name: String = format!("http.response.header.{header_name}");
242                    span.set_attribute(attribute_name, attribute_value.to_owned());
243                }
244            }
245        }
246
247        if let SpanKind::Client = kind {
248            if response.status().is_client_error() {
249                span.record(OTEL_STATUS_CODE, "ERROR");
250            }
251        }
252        if response.status().is_server_error() {
253            span.record(OTEL_STATUS_CODE, "ERROR");
254        }
255    }
256
257    /// Records the error message.
258    fn record_error(span: &Span, err: &E) {
259        span.record(OTEL_STATUS_CODE, "ERROR");
260        span.record(EXCEPTION_MESSAGE, err.to_string());
261    }
262}