viz_core/middleware/otel/
tracing.rs

1//! Request tracing middleware with [`OpenTelemetry`].
2//!
3//! [`OpenTelemetry`]: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md
4
5use http::{HeaderValue, uri::Scheme};
6use opentelemetry::{
7    Context, InstrumentationScope, KeyValue, global,
8    propagation::Extractor,
9    trace::{
10        FutureExt as OtelFutureExt, Span, SpanKind, Status, TraceContextExt, Tracer, TracerProvider,
11    },
12};
13use opentelemetry_semantic_conventions::trace::{
14    CLIENT_ADDRESS, EXCEPTION_MESSAGE, HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, HTTP_ROUTE,
15    NETWORK_PROTOCOL_VERSION, SERVER_ADDRESS, SERVER_PORT, URL_PATH, URL_QUERY, URL_SCHEME,
16    USER_AGENT_ORIGINAL,
17};
18
19use crate::{
20    Handler, IntoResponse, Request, RequestExt, Response, ResponseExt, Result, Transform,
21    header::{HeaderMap, HeaderName},
22    headers::UserAgent,
23};
24
/// Span attribute key for the size of the request body (`http.request.body.size`).
const HTTP_REQUEST_BODY_SIZE: &str = "http.request.body.size";
/// Span attribute key for the size of the response body (`http.response.body.size`).
const HTTP_RESPONSE_BODY_SIZE: &str = "http.response.body.size";

/// `OpenTelemetry` tracing config.
///
/// Holds the tracer provider handed to every [`TracingMiddleware`] built from this config,
/// plus an optional instrumentation-scope name.
#[derive(Debug)]
pub struct Config<T> {
    /// Tracer provider cloned into each middleware instance.
    tracer: T,
    /// Instrumentation-scope name; `None` means the middleware falls back to `"tracing"`.
    name: Option<String>,
}

impl<T> Config<T> {
    /// Creates a new `OpenTelemetry` tracing config.
    ///
    /// `t` is the tracer provider; `name` optionally overrides the instrumentation-scope
    /// name (the middleware uses `"tracing"` when it is `None`).
    pub fn new(t: T, name: Option<String>) -> Self {
        Self { tracer: t, name }
    }
}
41
42impl<H, T> Transform<H> for Config<T>
43where
44    T: Clone,
45{
46    type Output = TracingMiddleware<H, T>;
47
48    fn transform(&self, h: H) -> Self::Output {
49        TracingMiddleware {
50            h,
51            tracer: self.tracer.clone(),
52            name: self.name.clone().unwrap_or("tracing".to_string()),
53        }
54    }
55}
56
/// `OpenTelemetry` tracing middleware.
///
/// Built by [`Config`]'s `Transform` implementation; wraps handler `h` and records a
/// server span around each call.
#[derive(Clone, Debug)]
pub struct TracingMiddleware<H, T> {
    /// The wrapped handler.
    h: H,
    /// Tracer provider used to obtain a tracer for every request.
    tracer: T,
    /// Name given to the instrumentation scope of the tracer.
    name: String,
}
64
65#[crate::async_trait]
66impl<H, O, T> Handler<Request> for TracingMiddleware<H, T>
67where
68    H: Handler<Request, Output = Result<O>>,
69    O: IntoResponse,
70    T: TracerProvider + Send + Sync + Clone + 'static,
71    T::Tracer: Tracer + Send + Sync + 'static,
72    <T::Tracer as Tracer>::Span: Span + Send + Sync + 'static,
73{
74    type Output = Result<Response>;
75
76    async fn call(&self, req: Request) -> Self::Output {
77        let parent_context = global::get_text_map_propagator(|propagator| {
78            propagator.extract(&RequestHeaderCarrier::new(req.headers()))
79        });
80
81        let http_route = &req.route_info().pattern;
82        let attributes = build_attributes(&req, http_route.as_str());
83        let scope = InstrumentationScope::builder(self.name.clone())
84            .with_attributes(attributes)
85            .build();
86        let tracer = self.tracer.tracer_with_scope(scope);
87        let mut span = tracer.build_with_context(
88            tracer
89                .span_builder(format!("{} {}", req.method(), http_route))
90                .with_kind(SpanKind::Server),
91            &parent_context,
92        );
93
94        span.add_event("request.started".to_string(), vec![]);
95
96        let resp = self
97            .h
98            .call(req)
99            .with_context(Context::current_with_span(span))
100            .await;
101
102        let cx = Context::current();
103        let span = cx.span();
104
105        match resp {
106            Ok(resp) => {
107                let resp = resp.into_response();
108                span.add_event("request.completed".to_string(), vec![]);
109                span.set_attribute(KeyValue::new(
110                    HTTP_RESPONSE_STATUS_CODE,
111                    i64::from(resp.status().as_u16()),
112                ));
113                if let Some(content_length) = resp.content_length() {
114                    span.set_attribute(KeyValue::new(
115                        HTTP_RESPONSE_BODY_SIZE,
116                        i64::try_from(content_length).unwrap_or(i64::MAX),
117                    ));
118                }
119                if resp.status().is_server_error() {
120                    span.set_status(Status::error(
121                        resp.status()
122                            .canonical_reason()
123                            .map(ToString::to_string)
124                            .unwrap_or_default(),
125                    ));
126                }
127                span.end();
128                Ok(resp)
129            }
130            Err(err) => {
131                span.add_event(
132                    "request.error".to_string(),
133                    vec![KeyValue::new(EXCEPTION_MESSAGE, err.to_string())],
134                );
135                span.set_status(Status::error(err.to_string()));
136                span.end();
137                Err(err)
138            }
139        }
140    }
141}
142
143struct RequestHeaderCarrier<'a> {
144    headers: &'a HeaderMap,
145}
146
147impl<'a> RequestHeaderCarrier<'a> {
148    const fn new(headers: &'a HeaderMap) -> Self {
149        RequestHeaderCarrier { headers }
150    }
151}
152
153impl Extractor for RequestHeaderCarrier<'_> {
154    fn get(&self, key: &str) -> Option<&str> {
155        self.headers
156            .get(key)
157            .map(HeaderValue::to_str)
158            .and_then(Result::ok)
159    }
160
161    fn keys(&self) -> Vec<&str> {
162        self.headers.keys().map(HeaderName::as_str).collect()
163    }
164}
165
166fn build_attributes(req: &Request, http_route: &str) -> Vec<KeyValue> {
167    let mut attributes = Vec::with_capacity(10);
168    // <https://github.com/open-telemetry/semantic-conventions/blob/v1.21.0/docs/http/http-spans.md#http-server>
169    attributes.push(KeyValue::new(HTTP_ROUTE, http_route.to_string()));
170
171    // <https://github.com/open-telemetry/semantic-conventions/blob/v1.21.0/docs/http/http-spans.md#common-attributes>
172    attributes.push(KeyValue::new(HTTP_REQUEST_METHOD, req.method().to_string()));
173    attributes.push(KeyValue::new(
174        NETWORK_PROTOCOL_VERSION,
175        format!("{:?}", req.version()),
176    ));
177
178    if let Some(remote_addr) = req.remote_addr() {
179        attributes.push(KeyValue::new(CLIENT_ADDRESS, remote_addr.to_string()));
180    }
181
182    let uri = req.uri();
183    if let Some(host) = uri.host() {
184        attributes.push(KeyValue::new(SERVER_ADDRESS, host.to_string()));
185    }
186    if let Some(port) = uri
187        .port_u16()
188        .map(i64::from)
189        .filter(|port| *port != 80 && *port != 443)
190    {
191        attributes.push(KeyValue::new(SERVER_PORT, port.to_string()));
192    }
193
194    if let Some(path_query) = uri.path_and_query() {
195        if path_query.path() != "/" {
196            attributes.push(KeyValue::new(URL_PATH, path_query.path().to_string()));
197        }
198        if let Some(query) = path_query.query() {
199            attributes.push(KeyValue::new(URL_QUERY, query.to_string()));
200        }
201    }
202
203    attributes.push(KeyValue::new(
204        URL_SCHEME,
205        uri.scheme().unwrap_or(&Scheme::HTTP).to_string(),
206    ));
207
208    if let Some(content_length) = req
209        .content_length()
210        .and_then(|len| i64::try_from(len).ok())
211        .filter(|len| *len > 0)
212    {
213        attributes.push(KeyValue::new(
214            HTTP_REQUEST_BODY_SIZE,
215            content_length.to_string(),
216        ));
217    }
218
219    if let Some(user_agent) = req
220        .header_typed::<UserAgent>()
221        .as_ref()
222        .map(UserAgent::as_str)
223    {
224        attributes.push(KeyValue::new(USER_AGENT_ORIGINAL, user_agent.to_string()));
225    }
226
227    attributes
228}