tower_http_tracing/
lib.rs

1//!Tower tracing middleware to annotate every HTTP request with tracing's span.
2//!
3//!## Span creation
4//!
//!Use the [macro](macro.make_request_spanner.html) to declare a function that creates the desired span.
6//!
7//!## Example
8//!
//!Below is an illustration of how to initialize the request layer to pass into your service.
10//!
11//!```rust
12//!use std::net::IpAddr;
13//!
14//!use tower_http_tracing::HttpRequestLayer;
15//!
16//!//Logic to extract client ip has to be written by user
17//!//You can use utilities in separate crate to design this logic:
18//!//https://docs.rs/http-ip/latest/http_ip/
19//!fn extract_client_ip(_parts: &http::request::Parts) -> Option<IpAddr> {
20//!    None
21//!}
22//!tower_http_tracing::make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
23//!let layer = HttpRequestLayer::new(make_my_request_span).with_extract_client_ip(extract_client_ip)
24//!                                                       .with_inspect_headers(&[&http::header::FORWARDED]);
25//!//Use above layer in your service
26//!```
27
28#![warn(missing_docs)]
29#![allow(clippy::style)]
30
31mod grpc;
32mod headers;
33
34use std::net::IpAddr;
35use core::{cmp, fmt, ptr, task};
36use core::pin::Pin;
37use core::future::Future;
38
39pub use tracing;
40
41///RequestId's header name
42pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
43///Alias to function signature required to create span
44pub type MakeSpan = fn() -> tracing::Span;
45///ALias to function signature to extract client's ip from request
46pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
47
48#[inline]
49fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
50    None
51}
52
///Possible request protocol
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Protocol {
    ///Regular HTTP call
    ///
    ///Default value for all requests
    Http,
    ///gRPC call, identified by presence of `Content-Type` with grpc protocol signature
    Grpc,
}

impl Protocol {
    #[inline(always)]
    ///Determines protocol from value of `Content-Type`
    ///
    ///Any value starting with `application/grpc` (e.g. `application/grpc+proto`) is treated as
    ///[`Protocol::Grpc`]. The prefix is compared ASCII case-insensitively, because media type and
    ///subtype tokens are case-insensitive. Note that the prefix also matches `application/grpc-web`.
    ///Everything else, including an empty value, is [`Protocol::Http`].
    pub fn from_content_type(typ: &[u8]) -> Self {
        const GRPC_PREFIX: &[u8] = b"application/grpc";

        //`get` yields None for values shorter than prefix, avoiding out of bounds slicing
        match typ.get(..GRPC_PREFIX.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(GRPC_PREFIX) => Self::Grpc,
            _ => Self::Http,
        }
    }

    #[inline(always)]
    ///Returns textual representation of the `self`: `"http"` or `"grpc"`
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Grpc => "grpc",
            Self::Http => "http",
        }
    }
}

impl fmt::Debug for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_str(), fmt)
    }
}

impl fmt::Display for Protocol {
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), fmt)
    }
}
98
99type RequestIdBuffer = [u8; 64];
100
101#[derive(Clone)]
102///Request's id
103///
104///By default it is extracted from `X-Request-Id` header
105pub struct RequestId {
106    buffer: RequestIdBuffer,
107    len: u8,
108}
109
110impl RequestId {
111    fn from_bytes(bytes: &[u8]) -> Self {
112        let mut buffer: RequestIdBuffer = [0; 64];
113
114        let len = cmp::min(buffer.len(), bytes.len());
115
116        unsafe {
117            ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
118        };
119
120        Self {
121            buffer,
122            len: len as _,
123        }
124    }
125
126    fn from_uuid(uuid: uuid::Uuid) -> Self {
127        let mut buffer: RequestIdBuffer = [0; 64];
128        let uuid = uuid.as_hyphenated();
129        let len = uuid.encode_lower(&mut buffer).len();
130
131        Self {
132            buffer,
133            len: len as _,
134        }
135    }
136
137    #[inline]
138    ///Returns slice to already written data.
139    pub const fn as_bytes(&self) -> &[u8] {
140        unsafe {
141            core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
142        }
143    }
144
145    #[inline(always)]
146    ///Gets textual representation of the request id, if header value is string
147    pub const fn as_str(&self) -> Option<&str> {
148        match core::str::from_utf8(self.as_bytes()) {
149            Ok(header) => Some(header),
150            Err(_) => None,
151        }
152    }
153}
154
155impl fmt::Debug for RequestId {
156    #[inline(always)]
157    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
158        match self.as_str() {
159            Some(id) => fmt::Debug::fmt(id, fmt),
160            None => fmt::Debug::fmt(self.as_bytes(), fmt),
161        }
162    }
163}
164
165impl fmt::Display for RequestId {
166    #[inline(always)]
167    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
168        match self.as_str() {
169            Some(id) => fmt::Display::fmt(id, fmt),
170            None => fmt::Display::fmt("<non-utf8>", fmt),
171        }
172    }
173}
174
///Declares `fn` function compatible with `MakeSpan` using provided parameters
///
///## Span fields
///
///Following fields are declared when span is created:
///- `http.request.method`
///- `url.path`
///- `url.query`
///- `url.scheme`
///- `http.request_id` - Inherited from request 'X-Request-Id' or random uuid
///- `user_agent.original` - Only populated if user agent header is present
///- `http.headers` - Optional. Populated if at least 1 header is specified via layer [config](struct.HttpRequestLayer.html#method.with_inspect_headers)
///- `network.protocol.name` - Either `http` or `grpc` depending on `content-type`
///- `network.protocol.version` - Set to HTTP version in case of plain `http` protocol.
///- `client.address` - Optionally added if IP extractor is specified via layer [config](struct.HttpRequestLayer.html#method.with_extract_client_ip)
///- `http.response.status_code` - Semantics of this code depends on `protocol`
///- `error.type` - Populated with `core::any::type_name` value of error type used by the service.
///- `error.message` - Populated with `Display` content of the error, returned by underlying service, after processing request.
///
///All of these fields must be declared up front. `tracing` does not record values for fields that
///were not declared when the span was created.
///
///Loosely follows <https://opentelemetry.io/docs/specs/semconv/http/http-spans/#http-server>
///
///## Usage
///
///```
///use tower_http_tracing::make_request_spanner;
///
///make_request_spanner!(make_my_request_span("my_request", tracing::Level::INFO));
/////Customize span with extra fields. You can use tracing::field::Empty if you want to omit value
///make_request_spanner!(make_my_service_request_span("my_request", tracing::Level::INFO, service_name = "<your name>"));
///
///let span = make_my_request_span();
///span.record("url.path", "I can override span field");
///
///```
#[macro_export]
macro_rules! make_request_spanner {
    ($fn:ident($name:literal, $level:expr)) => {
        $crate::make_request_spanner!($fn($name, $level,));
    };
    ($fn:ident($name:literal, $level:expr, $($fields:tt)*)) => {
        #[track_caller]
        pub fn $fn() -> $crate::tracing::Span {
            use $crate::tracing::field;

            $crate::tracing::span!(
                $level,
                $name,
                //Assigned on creation of span
                http.request.method = field::Empty,
                url.path = field::Empty,
                url.query = field::Empty,
                url.scheme = field::Empty,
                http.request_id = field::Empty,
                user_agent.original = field::Empty,
                http.headers = field::Empty,
                network.protocol.name = field::Empty,
                network.protocol.version = field::Empty,
                //Optional
                client.address = field::Empty,
                //Assigned after request is complete
                http.response.status_code = field::Empty,
                //Must be declared, otherwise recording error's type on failure is silently ignored
                error.type = field::Empty,
                error.message = field::Empty,
                $(
                    $fields
                )*
            )
        }
    };
}
244
245#[derive(Clone, Debug)]
246///Request's information
247///
248///It is accessible via [extensions](https://docs.rs/http/latest/http/struct.Extensions.html)
249pub struct RequestInfo {
250    ///Request's protocol
251    pub protocol: Protocol,
252    ///Request's id
253    pub request_id: RequestId,
254    ///Client's IP address extracted, if available.
255    pub client_ip: Option<IpAddr>,
256}
257
258///Request's span information
259///
260///Created on every request by the middleware, but not accessible to the user directly
261pub struct RequestSpan {
262    ///Underlying tracing span
263    pub span: tracing::Span,
264    ///Request's information
265    pub info: RequestInfo,
266}
267
268impl RequestSpan {
269    ///Creates new request span
270    pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
271        let _entered = span.enter();
272
273        let client_ip = (extract_client_ip)(parts);
274        let protocol = parts.headers
275                            .get(http::header::CONTENT_TYPE)
276                            .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
277
278        let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
279            RequestId::from_bytes(request_id.as_bytes())
280        } else {
281            RequestId::from_uuid(uuid::Uuid::new_v4())
282        };
283
284        if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
285            span.record("user_agent.original", user_agent);
286        }
287        span.record("http.request.method", parts.method.as_str());
288        span.record("url.path", parts.uri.path());
289        if let Some(query) = parts.uri.query() {
290            span.record("url.query", query);
291        }
292        if let Some(scheme) = parts.uri.scheme() {
293            span.record("url.scheme", scheme.as_str());
294        }
295        if let Some(request_id) = request_id.as_str() {
296            span.record("http.request_id", &request_id);
297        } else {
298            span.record("http.request_id", request_id.as_bytes());
299        }
300        if let Some(client_ip) = client_ip {
301            span.record("client.address", tracing::field::display(client_ip));
302        }
303        span.record("network.protocol.name", protocol.as_str());
304        if let Protocol::Http = protocol {
305            match parts.version {
306                http::Version::HTTP_09 => span.record("network.protocol.version", 0.9),
307                http::Version::HTTP_10 => span.record("network.protocol.version", 1.0),
308                http::Version::HTTP_11 => span.record("network.protocol.version", 1.1),
309                http::Version::HTTP_2 => span.record("network.protocol.version", 2),
310                http::Version::HTTP_3 => span.record("network.protocol.version", 3),
311                //Invalid version so just set 0
312                _ => span.record("network.protocol.version", 0),
313            };
314        }
315
316        drop(_entered);
317
318        Self {
319            span,
320            info: RequestInfo {
321                protocol,
322                request_id,
323                client_ip
324            }
325        }
326    }
327}
328
329#[derive(Clone)]
330///Tower layer
331pub struct HttpRequestLayer {
332    make_span: MakeSpan,
333    inspect_headers: &'static [&'static http::HeaderName],
334    extract_client_ip: ExtractClientIp,
335}
336
337impl HttpRequestLayer {
338    #[inline]
339    ///Creates new layer with provided span maker
340    pub fn new(make_span: MakeSpan) -> Self {
341        Self {
342            make_span,
343            inspect_headers: &[],
344            extract_client_ip: default_client_ip
345        }
346    }
347
348    #[inline]
349    ///Specifies list of headers you want to inspect via `http.headers` attribute.
350    ///
351    ///By default none of the headers are inspected
352    pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
353        self.inspect_headers = inspect_headers;
354        self
355    }
356
357    ///Customizes client ip extraction method
358    ///
359    ///Default extracts none
360    pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
361        self.extract_client_ip = extract_client_ip;
362        self
363    }
364}
365
366impl<S> tower_layer::Layer<S> for HttpRequestLayer {
367    type Service = HttpRequestService<S>;
368    #[inline(always)]
369    fn layer(&self, inner: S) -> Self::Service {
370        HttpRequestService {
371            layer: self.clone(),
372            inner,
373        }
374    }
375}
376
377///Tower service to annotate requests with span
378pub struct HttpRequestService<S> {
379    layer: HttpRequestLayer,
380    inner: S
381}
382
383impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
384    type Response = S::Response;
385    type Error = S::Error;
386    type Future = ResponseFut<S::Future>;
387
388    #[inline(always)]
389    fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
390        self.inner.poll_ready(ctx)
391    }
392
393    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
394        let (parts, body) = req.into_parts();
395        let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
396
397        let _entered = span.enter();
398        if !self.layer.inspect_headers.is_empty() {
399            span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
400                header_list: self.layer.inspect_headers,
401                headers: &parts.headers
402            }));
403        }
404        let request_id = info.request_id.clone();
405        let protocol = info.protocol;
406        let mut req = http::Request::from_parts(parts, body);
407        req.extensions_mut().insert(info);
408        let inner = self.inner.call(req);
409
410        drop(_entered);
411        ResponseFut {
412            inner,
413            span,
414            protocol,
415            request_id
416        }
417    }
418}
419
420///Middleware's response future
421pub struct ResponseFut<F> {
422    inner: F,
423    span: tracing::Span,
424    protocol: Protocol,
425    request_id: RequestId,
426}
427
428impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
429    type Output = F::Output;
430
431    fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
432        let (fut, span, protocol, request_id) = unsafe {
433            let this = self.get_unchecked_mut();
434            (
435                Pin::new_unchecked(&mut this.inner),
436                &this.span,
437                this.protocol,
438                &this.request_id,
439            )
440        };
441        let _entered = span.enter();
442        match Future::poll(fut, ctx) {
443            task::Poll::Ready(Ok(mut resp)) => {
444                if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
445                    resp.headers_mut().insert(REQUEST_ID, request_id);
446                }
447                let status = match protocol {
448                    Protocol::Http => resp.status().as_u16(),
449                    Protocol::Grpc => match resp.headers().get("grpc-status") {
450                        Some(status) => grpc::parse_grpc_status(status.as_bytes()),
451                        None => 2,
452                    }
453                };
454                span.record("http.response.status_code", status);
455
456                task::Poll::Ready(Ok(resp))
457            }
458            task::Poll::Ready(Err(error)) => {
459                let status = match protocol {
460                    Protocol::Http => 500u16,
461                    Protocol::Grpc => 13,
462                };
463                span.record("http.response.status_code", status);
464                span.record("error.type", core::any::type_name::<E>());
465                span.record("error.message", tracing::field::display(&error));
466                task::Poll::Ready(Err(error))
467            },
468            task::Poll::Pending => task::Poll::Pending
469        }
470    }
471}