// spring_opentelemetry/middlewares/tracing.rs

//! Middleware that adds tracing to a [`Service`] that handles HTTP requests.
2
3use http::{HeaderName, HeaderValue, Request, Response};
4use opentelemetry::trace::TraceContextExt;
5use opentelemetry_http::{HeaderExtractor, HeaderInjector};
6use opentelemetry_semantic_conventions::attribute;
7use pin_project::pin_project;
8use std::{
9    fmt::Display,
10    future::Future,
11    pin::Pin,
12    task::{ready, Context, Poll},
13};
14use tower::{Layer, Service};
15use tracing::{Level, Span};
16use tracing_opentelemetry::OpenTelemetrySpanExt;
17
/// Which side of an HTTP exchange a [`Span`] is recorded on.
///
/// The variant's `Debug` name (`Client` / `Server`) is what gets stored in the
/// span's `otel.kind` field.
#[derive(Debug, Clone, Copy)]
enum SpanKind {
    /// Outgoing request: this process is calling a remote service.
    Client,
    /// Incoming request: this process is handling the request.
    Server,
}
26
27/// [`Layer`] that adds tracing to a [`Service`] that handles HTTP requests.
28#[derive(Clone, Debug)]
29pub struct HttpLayer {
30    level: Level,
31    kind: SpanKind,
32    with_headers: bool,
33    export_trace_id: bool,
34}
35
36impl HttpLayer {
37    /// [`Span`]s are constructed at the given level from server side.
38    pub fn server(level: Level) -> Self {
39        Self {
40            level,
41            kind: SpanKind::Server,
42            with_headers: true,
43            export_trace_id: false,
44        }
45    }
46
47    /// [`Span`]s are constructed at the given level from client side.
48    pub fn client(level: Level) -> Self {
49        Self {
50            level,
51            kind: SpanKind::Client,
52            with_headers: true,
53            export_trace_id: false,
54        }
55    }
56
57    pub fn with_headers(mut self, with_headers: bool) -> Self {
58        self.with_headers = with_headers;
59        self
60    }
61
62    pub fn export_trace_id(mut self, export_trace_id: bool) -> Self {
63        self.export_trace_id = export_trace_id;
64        self
65    }
66}
67
68impl<S> Layer<S> for HttpLayer {
69    type Service = HttpService<S>;
70
71    fn layer(&self, inner: S) -> Self::Service {
72        HttpService {
73            inner,
74            level: self.level,
75            kind: self.kind,
76            with_headers: self.with_headers,
77            export_trace_id: self.export_trace_id,
78        }
79    }
80}
81
82/// Middleware that adds tracing to a [`Service`] that handles HTTP requests.
83#[derive(Clone, Debug)]
84pub struct HttpService<S> {
85    inner: S,
86    level: Level,
87    kind: SpanKind,
88    with_headers: bool,
89    export_trace_id: bool,
90}
91
92impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for HttpService<S>
93where
94    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
95    S::Error: Display,
96{
97    type Response = S::Response;
98    type Error = S::Error;
99    type Future = ResponseFuture<S::Future>;
100
101    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102        self.inner.poll_ready(cx)
103    }
104
105    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
106        let span = self.make_request_span(&mut req);
107        let inner = {
108            let _enter = span.enter();
109            self.inner.call(req)
110        };
111
112        ResponseFuture {
113            inner,
114            span,
115            kind: self.kind,
116            with_headers: self.with_headers,
117            export_trace_id: self.export_trace_id,
118        }
119    }
120}
121
122impl<S> HttpService<S> {
123    /// Creates a new [`Span`] for the given request.
124    fn make_request_span<B>(&self, request: &mut Request<B>) -> Span {
125        let Self { level, kind, .. } = self;
126
127        macro_rules! make_span {
128            ($level:expr) => {{
129                use tracing::field::Empty;
130
131                tracing::span!(
132                    $level,
133                    "HTTP",
134                    "exception.message" = Empty,
135                    "http.request.method" = tracing::field::display(request.method()),
136                    "http.response.status_code" = Empty,
137                    "network.protocol.name" = "http",
138                    "network.protocol.version" = tracing::field::debug(request.version()),
139                    "otel.kind" = tracing::field::debug(kind),
140                    "otel.status_code" = Empty,
141                    "url.full" = tracing::field::display(request.uri()),
142                    "url.path" = request.uri().path(),
143                    "url.query" = Empty,
144                )
145            }};
146        }
147
148        let span = match *level {
149            Level::ERROR => make_span!(Level::ERROR),
150            Level::WARN => make_span!(Level::WARN),
151            Level::INFO => make_span!(Level::INFO),
152            Level::DEBUG => make_span!(Level::DEBUG),
153            Level::TRACE => make_span!(Level::TRACE),
154        };
155
156        if self.with_headers {
157            for (header_name, header_value) in request.headers().iter() {
158                if let Ok(attribute_value) = header_value.to_str() {
159                    // attribute::HTTP_REQUEST_HEADER
160                    let attribute_name = format!("http.request.header.{}", header_name);
161                    span.set_attribute(attribute_name, attribute_value.to_owned());
162                }
163            }
164        }
165
166        if let Some(query) = request.uri().query() {
167            span.record(attribute::URL_QUERY, query);
168        }
169
170        match kind {
171            SpanKind::Client => {
172                let context = span.context();
173                opentelemetry::global::get_text_map_propagator(|injector| {
174                    injector.inject_context(&context, &mut HeaderInjector(request.headers_mut()));
175                });
176            }
177            SpanKind::Server => {
178                let context = opentelemetry::global::get_text_map_propagator(|extractor| {
179                    extractor.extract(&HeaderExtractor(request.headers()))
180                });
181                span.set_parent(context);
182            }
183        }
184
185        span
186    }
187}
188
189/// Response future for [`Http`].
190#[pin_project]
191pub struct ResponseFuture<F> {
192    #[pin]
193    inner: F,
194    span: Span,
195    kind: SpanKind,
196    with_headers: bool,
197    export_trace_id: bool,
198}
199
200impl<F, ResBody, E> Future for ResponseFuture<F>
201where
202    F: Future<Output = Result<Response<ResBody>, E>>,
203    E: Display,
204{
205    type Output = Result<Response<ResBody>, E>;
206
207    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
208        let this = self.project();
209        let _enter = this.span.enter();
210
211        match ready!(this.inner.poll(cx)) {
212            Ok(mut response) => {
213                Self::record_response(this.span, *this.kind, *this.with_headers, &response);
214                if *this.export_trace_id {
215                    let trace_id = this
216                        .span
217                        .context()
218                        .span()
219                        .span_context()
220                        .trace_id()
221                        .to_string();
222                    if let Ok(value) = HeaderValue::from_str(&trace_id) {
223                        response
224                            .headers_mut()
225                            .insert(HeaderName::from_static("x-trace-id"), value);
226                    }
227                }
228                Poll::Ready(Ok(response))
229            }
230            Err(err) => {
231                Self::record_error(this.span, &err);
232                Poll::Ready(Err(err))
233            }
234        }
235    }
236}
237
238impl<F, ResBody, E> ResponseFuture<F>
239where
240    F: Future<Output = Result<Response<ResBody>, E>>,
241    E: Display,
242{
243    /// Records fields associated to the response.
244    fn record_response<B>(span: &Span, kind: SpanKind, with_headers: bool, response: &Response<B>) {
245        span.record(
246            attribute::HTTP_RESPONSE_STATUS_CODE,
247            response.status().as_u16() as i64,
248        );
249
250        if with_headers {
251            for (header_name, header_value) in response.headers().iter() {
252                if let Ok(attribute_value) = header_value.to_str() {
253                    let attribute_name = format!("http.response.header.{}", header_name);
254                    span.set_attribute(attribute_name, attribute_value.to_owned());
255                }
256            }
257        }
258
259        if let SpanKind::Client = kind {
260            if response.status().is_client_error() {
261                span.record(attribute::OTEL_STATUS_CODE, "ERROR");
262            }
263        }
264        if response.status().is_server_error() {
265            span.record(attribute::OTEL_STATUS_CODE, "ERROR");
266        }
267    }
268
269    /// Records the error message.
270    fn record_error(span: &Span, err: &E) {
271        span.record(attribute::OTEL_STATUS_CODE, "ERROR");
272        span.record(attribute::EXCEPTION_MESSAGE, err.to_string());
273    }
274}