Skip to main content

telemetry_rust/instrumentations/http/
reqwest.rs

//! Async `reqwest` instrumentation helpers.
//!
//! Wrap a request builder with [`ReqwestBuilderInstrument::instrument`] to record an
//! OpenTelemetry HTTP client span and propagate trace context when the request is sent.
//!
//! # Example
//!
//! ```no_run
//! use telemetry_rust::instrumentations::http::reqwest::ReqwestBuilderInstrument;
//!
//! # async fn example() -> Result<(), reqwest::Error> {
//! let response = reqwest::Client::new()
//!     .get("https://example.com/health")
//!     .instrument()
//!     .send()
//!     .await?;
//! # let _ = response;
//! # Ok(())
//! # }
//! ```
18
19use ::reqwest as reqwest_crate;
20use std::future::Future;
21
22use crate::{Context, http, instrumentations::http::client::HttpClientSpanBuilder};
23
24/// A trait for instrumenting async reqwest request builders with OpenTelemetry tracing.
25///
26/// ```no_run
27/// use telemetry_rust::instrumentations::http::reqwest::ReqwestBuilderInstrument;
28///
29/// # async fn example() -> Result<(), reqwest::Error> {
30/// let response = reqwest::Client::new()
31///     .get("https://example.com/health")
32///     .instrument()
33///     .send()
34///     .await?;
35/// # let _ = response;
36/// # Ok(())
37/// # }
38/// ```
39pub trait ReqwestBuilderInstrument
40where
41    Self: Sized,
42{
43    /// Instruments this reqwest builder with OpenTelemetry tracing.
44    fn instrument(self) -> InstrumentedRequestBuilder;
45}
46
47impl ReqwestBuilderInstrument for reqwest_crate::RequestBuilder {
48    fn instrument(self) -> InstrumentedRequestBuilder {
49        InstrumentedRequestBuilder::new(self)
50    }
51}
52
53impl HttpClientSpanBuilder {
54    pub(crate) fn from_reqwest_request(request: &reqwest_crate::Request) -> Self {
55        Self::from_parts(request.method(), request.headers(), request.url())
56    }
57}
58
59/// A wrapper that instruments async reqwest request builders with OpenTelemetry tracing.
60#[must_use = "RequestBuilder does nothing until you call send()"]
61pub struct InstrumentedRequestBuilder {
62    inner: reqwest_crate::RequestBuilder,
63    context: Option<Context>,
64}
65
66impl InstrumentedRequestBuilder {
67    /// Creates a new instrumented reqwest request builder.
68    pub fn new(inner: reqwest_crate::RequestBuilder) -> Self {
69        Self {
70            inner,
71            context: None,
72        }
73    }
74
75    /// Sets the OpenTelemetry context for this instrumented request.
76    pub fn context(mut self, context: &Context) -> Self {
77        self.context = Some(context.clone());
78        self
79    }
80
81    /// Sets the optional OpenTelemetry context for this instrumented request.
82    pub fn set_context(mut self, context: Option<&Context>) -> Self {
83        self.context = context.cloned();
84        self
85    }
86
87    /// Sends the request and records an outbound HTTP client span around it.
88    pub fn send(
89        self,
90    ) -> impl Future<Output = Result<reqwest_crate::Response, reqwest_crate::Error>> {
91        let (client, request_result) = self.inner.build_split();
92        let context = self.context;
93
94        async move {
95            let mut request = request_result?;
96            let span_builder = HttpClientSpanBuilder::from_reqwest_request(&request);
97            let span = match context.as_ref() {
98                Some(context) => span_builder.start_with_context(context),
99                None => span_builder.start(),
100            };
101
102            http::inject_context_on_context(span.context(), request.headers_mut());
103
104            let result = client.execute(request).await;
105            match &result {
106                Ok(response) => {
107                    span.end_response(
108                        response.status(),
109                        response.version(),
110                        response.remote_addr(),
111                    );
112                }
113                Err(error) => span.end_error(reqwest_error_type(error), error),
114            }
115            result
116        }
117    }
118}
119
120fn reqwest_error_type(error: &reqwest_crate::Error) -> &'static str {
121    if error.is_timeout() {
122        "timeout"
123    } else if error.is_connect() {
124        "connect"
125    } else if error.is_redirect() {
126        "redirect"
127    } else if error.is_request() {
128        "request"
129    } else if error.is_body() {
130        "body"
131    } else if error.is_decode() {
132        "decode"
133    } else if error.is_builder() {
134        "builder"
135    } else {
136        "_OTHER"
137    }
138}
139
// Tests drive real requests against a local axum server and inspect the spans
// captured by an in-memory exporter.
#[cfg(test)]
mod tests {
    use super::ReqwestBuilderInstrument;
    use crate::{Context, OpenTelemetryLayer, Value, semconv};
    use assert2::assert;
    use axum::{
        Router,
        extract::State,
        http::{HeaderMap, StatusCode},
        response::{IntoResponse, Redirect},
        routing::get,
    };
    use opentelemetry::{
        global,
        trace::{Span as _, SpanKind, TraceContextExt, Tracer as _, TracerProvider as _},
    };
    use opentelemetry_sdk::{
        propagation::TraceContextPropagator,
        trace::{InMemorySpanExporter, SdkTracerProvider as TracerProvider},
    };
    use serial_test::serial;
    use std::sync::{Arc, Mutex};
    use tokio::{net::TcpListener, task::JoinHandle};
    use tracing_opentelemetry::OpenTelemetrySpanExt as _;
    use tracing_subscriber::{Registry, layer::SubscriberExt};

    /// Shared server state recording the `traceparent` header seen on each route.
    #[derive(Clone, Default)]
    struct TestState {
        /// `(path, traceparent)` pairs in the order requests arrived.
        traceparents: Arc<Mutex<Vec<(String, String)>>>,
    }

    impl TestState {
        /// Stores the request's `traceparent` for `path`; requests without a
        /// readable `traceparent` header are not recorded.
        fn record(&self, path: &str, headers: &HeaderMap) {
            if let Some(traceparent) = headers
                .get("traceparent")
                .and_then(|value| value.to_str().ok())
            {
                self.traceparents
                    .lock()
                    .unwrap()
                    .push((path.to_owned(), traceparent.to_owned()));
            }
        }

        /// Returns the most recently recorded `traceparent` for `path`, if any.
        fn traceparent_for(&self, path: &str) -> Option<String> {
            self.traceparents
                .lock()
                .unwrap()
                .iter()
                .rev()
                .find(|(recorded_path, _)| recorded_path == path)
                .map(|(_, traceparent)| traceparent.clone())
        }
    }

    /// A 200 response yields a client span with unset status, request and
    /// response attributes, and a propagated traceparent naming that span.
    #[tokio::test]
    #[serial]
    async fn instruments_successful_requests() {
        let telemetry = configure_test_tracing();
        let server = spawn_server().await;

        let response = test_client()
            .get(format!("{}/ok?ready=true", server.base_url))
            .header(::reqwest::header::USER_AGENT, "telemetry-rust-tests")
            .instrument()
            .send()
            .await
            .unwrap();

        assert!(response.status() == StatusCode::OK);

        let spans = force_flush_and_get_spans(&telemetry);
        let span = find_span(&spans, "GET");
        let traceparent = server.state.traceparent_for("/ok").unwrap();
        let (trace_id, span_id) = traceparent_ids(&traceparent);

        assert!(span.span_kind == SpanKind::Client);
        assert!(span.span_context.trace_id().to_string() == trace_id);
        assert!(span.span_context.span_id().to_string() == span_id);
        assert!(matches!(span.status, opentelemetry::trace::Status::Unset));
        assert!(string_attr(span, semconv::HTTP_REQUEST_METHOD) == Some("GET"));
        assert!(string_attr(span, semconv::URL_SCHEME) == Some("http"));
        assert!(string_attr(span, semconv::SERVER_ADDRESS) == Some("127.0.0.1"));
        assert!(string_attr(span, semconv::URL_PATH) == Some("/ok"));
        assert!(string_attr(span, semconv::URL_QUERY) == Some("ready=true"));
        assert!(
            string_attr(span, semconv::USER_AGENT_ORIGINAL)
                == Some("telemetry-rust-tests")
        );
        assert!(i64_attr(span, semconv::HTTP_RESPONSE_STATUS_CODE) == Some(200));
        assert!(string_attr(span, semconv::NETWORK_PROTOCOL_VERSION).is_some());
        assert!(string_attr(span, semconv::NETWORK_PEER_ADDRESS).is_some());
        assert!(i64_attr(span, semconv::NETWORK_PEER_PORT).is_some());
    }

    /// Inside an active tracing span, the outgoing traceparent keeps the
    /// parent's trace id but carries the client span's own span id.
    #[tokio::test]
    #[serial]
    async fn propagates_traceparent_with_client_span_id() {
        let telemetry = configure_test_tracing();
        let server = spawn_server().await;

        let tracer = global::tracer("reqwest-propagation-test");
        let subscriber = Registry::default().with(OpenTelemetryLayer::new(tracer));
        let _guard = tracing::subscriber::set_default(subscriber);

        let parent = tracing::info_span!("parent");
        let parent_context = parent.context();
        let expected_trace_id = parent_context.span().span_context().trace_id();

        tracing::Instrument::instrument(
            async {
                test_client()
                    .get(format!("{}/ok", server.base_url))
                    .instrument()
                    .send()
                    .await
                    .unwrap();
            },
            parent,
        )
        .await;

        let spans = force_flush_and_get_spans(&telemetry);
        let client_span = find_span(&spans, "GET");
        let traceparent = server.state.traceparent_for("/ok").unwrap();
        let (trace_id, span_id) = traceparent_ids(&traceparent);

        // The traceparent carries the client span's own span-id, not the parent's.
        assert!(trace_id == expected_trace_id.to_string());
        assert!(span_id == client_span.span_context.span_id().to_string());
    }

    /// A 4xx response marks the span as an error with the status code as error type.
    #[tokio::test]
    #[serial]
    async fn marks_client_error_responses_as_errors() {
        let telemetry = configure_test_tracing();
        let server = spawn_server().await;

        let response = test_client()
            .get(format!("{}/not-found", server.base_url))
            .instrument()
            .send()
            .await
            .unwrap();

        assert!(response.status() == StatusCode::NOT_FOUND);

        let spans = force_flush_and_get_spans(&telemetry);
        let span = find_span(&spans, "GET");

        assert!(matches!(
            span.status,
            opentelemetry::trace::Status::Error { .. }
        ));
        assert!(i64_attr(span, semconv::HTTP_RESPONSE_STATUS_CODE) == Some(404));
        assert!(string_attr(span, semconv::ERROR_TYPE) == Some("404"));
    }

    /// A 5xx response marks the span as an error with the status code as error type.
    #[tokio::test]
    #[serial]
    async fn marks_server_error_responses_as_errors() {
        let telemetry = configure_test_tracing();
        let server = spawn_server().await;

        test_client()
            .get(format!("{}/server-error", server.base_url))
            .instrument()
            .send()
            .await
            .unwrap();

        let spans = force_flush_and_get_spans(&telemetry);
        let span = find_span(&spans, "GET");

        assert!(i64_attr(span, semconv::HTTP_RESPONSE_STATUS_CODE) == Some(500));
        assert!(string_attr(span, semconv::ERROR_TYPE) == Some("500"));
        assert!(matches!(
            span.status,
            opentelemetry::trace::Status::Error { .. }
        ));
    }

    /// A connection failure yields an error span labelled `connect` and no
    /// response status code attribute.
    #[tokio::test]
    #[serial]
    async fn marks_transport_failures_as_errors() {
        let telemetry = configure_test_tracing();

        // Bind then immediately drop a listener to obtain a local port that
        // (most likely) has nothing listening on it.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);

        let error = test_client()
            .get(format!("http://{addr}/unavailable"))
            .instrument()
            .send()
            .await
            .unwrap_err();

        assert!(error.is_connect());

        let spans = force_flush_and_get_spans(&telemetry);
        let span = find_span(&spans, "GET");

        assert!(matches!(
            span.status,
            opentelemetry::trace::Status::Error { .. }
        ));
        assert!(string_attr(span, semconv::ERROR_TYPE) == Some("connect"));
        assert!(i64_attr(span, semconv::HTTP_RESPONSE_STATUS_CODE).is_none());
    }

    /// When reqwest follows a redirect, the span keeps the originally requested
    /// URL and both hops receive a traceparent header.
    #[tokio::test]
    #[serial]
    async fn preserves_original_url_when_redirects_are_followed() {
        let telemetry = configure_test_tracing();
        let server = spawn_server().await;

        let response = test_client()
            .get(format!("{}/redirect?step=1", server.base_url))
            .instrument()
            .send()
            .await
            .unwrap();

        assert!(response.url().path() == "/final");

        let spans = force_flush_and_get_spans(&telemetry);
        let span = find_span(&spans, "GET");

        let expected_url = format!("{}/redirect?step=1", server.base_url);
        assert!(string_attr(span, semconv::URL_FULL) == Some(expected_url.as_str()));
        assert!(server.state.traceparent_for("/redirect").is_some());
        assert!(server.state.traceparent_for("/final").is_some());
    }

    /// A context passed via `.context(...)` takes precedence over the current
    /// tracing span as the client span's parent.
    #[tokio::test]
    #[serial]
    async fn uses_explicit_parent_context_when_provided() {
        let telemetry = configure_test_tracing();
        let server = spawn_server().await;
        let tracer = telemetry.provider.tracer("reqwest-tests");
        let explicit_parent = tracer.start("explicit-parent");
        let explicit_parent_span_id = explicit_parent.span_context().span_id();
        let explicit_parent_cx = Context::current_with_span(explicit_parent);
        let tracing_tracer = telemetry.provider.tracer("tracing-tests");
        let subscriber = Registry::default()
            .with(tracing_opentelemetry::layer().with_tracer(tracing_tracer));
        let guard = tracing::subscriber::set_default(subscriber);
        let current_parent = tracing::info_span!("current-parent");

        tracing::Instrument::instrument(
            async {
                test_client()
                    .get(format!("{}/ok", server.base_url))
                    .instrument()
                    .context(&explicit_parent_cx)
                    .send()
                    .await
                    .unwrap();
            },
            current_parent,
        )
        .await;

        // Tear down the subscriber and end the explicit parent before flushing
        // so the exporter has received every finished span.
        drop(guard);
        explicit_parent_cx.span().end();

        let spans = force_flush_and_get_spans(&telemetry);
        let reqwest_span = find_span(&spans, "GET");
        let current_span = find_span(&spans, "current-parent");

        assert!(reqwest_span.parent_span_id == explicit_parent_span_id);
        assert!(reqwest_span.parent_span_id != current_span.span_context.span_id());
    }

    /// A request that fails to build returns a builder error and emits no client span.
    #[tokio::test]
    #[serial]
    async fn does_not_emit_span_for_invalid_builder() {
        let telemetry = configure_test_tracing();

        let result = test_client()
            .get("http://example.com")
            .header("bad\nheader", "value")
            .instrument()
            .send()
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().is_builder());

        let spans = force_flush_and_get_spans(&telemetry);
        assert!(client_spans(&spans).is_empty());
    }

    // --- helpers ---

    /// A running local test server and the state its handlers write to.
    struct TestServer {
        /// `http://127.0.0.1:<port>` of the bound listener.
        base_url: String,
        /// Traceparent headers recorded by the route handlers.
        state: TestState,
        /// Join handle of the spawned server task (held, never awaited).
        _handle: JoinHandle<()>,
    }

    /// Starts an axum server on an ephemeral `127.0.0.1` port exposing
    /// `/ok`, `/not-found`, `/server-error`, `/redirect` (temporary redirect to
    /// `/final`) and `/final`; every handler records the incoming traceparent.
    async fn spawn_server() -> TestServer {
        async fn ok(
            State(state): State<TestState>,
            headers: HeaderMap,
        ) -> impl IntoResponse {
            state.record("/ok", &headers);
            StatusCode::OK
        }

        async fn not_found(
            State(state): State<TestState>,
            headers: HeaderMap,
        ) -> impl IntoResponse {
            state.record("/not-found", &headers);
            StatusCode::NOT_FOUND
        }

        async fn server_error(
            State(state): State<TestState>,
            headers: HeaderMap,
        ) -> impl IntoResponse {
            state.record("/server-error", &headers);
            StatusCode::INTERNAL_SERVER_ERROR
        }

        async fn redirect(
            State(state): State<TestState>,
            headers: HeaderMap,
        ) -> impl IntoResponse {
            state.record("/redirect", &headers);
            Redirect::temporary("/final")
        }

        async fn final_route(
            State(state): State<TestState>,
            headers: HeaderMap,
        ) -> impl IntoResponse {
            state.record("/final", &headers);
            StatusCode::OK
        }

        let state = TestState::default();
        let app = Router::new()
            .route("/ok", get(ok))
            .route("/not-found", get(not_found))
            .route("/server-error", get(server_error))
            .route("/redirect", get(redirect))
            .route("/final", get(final_route))
            .with_state(state.clone());

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        TestServer {
            base_url: format!("http://{addr}"),
            state,
            _handle: handle,
        }
    }

    /// Installs a fresh in-memory exporter behind a new global tracer provider
    /// and a W3C trace-context propagator. Because this mutates global state,
    /// every test that calls it is marked `#[serial]`.
    fn configure_test_tracing() -> TestTelemetry {
        let exporter = InMemorySpanExporter::default();
        let provider = TracerProvider::builder()
            .with_simple_exporter(exporter.clone())
            .build();
        global::set_tracer_provider(provider.clone());
        global::set_text_map_propagator(TraceContextPropagator::new());
        TestTelemetry { exporter, provider }
    }

    /// Client with proxies disabled so requests go straight to the local server.
    fn test_client() -> ::reqwest::Client {
        ::reqwest::Client::builder().no_proxy().build().unwrap()
    }

    /// Flushes the provider and returns every span the exporter has received.
    fn force_flush_and_get_spans(
        telemetry: &TestTelemetry,
    ) -> Vec<opentelemetry_sdk::trace::SpanData> {
        telemetry.provider.force_flush().unwrap();
        telemetry.exporter.get_finished_spans().unwrap()
    }

    /// Returns only the spans whose kind is `Client`.
    fn client_spans(
        spans: &[opentelemetry_sdk::trace::SpanData],
    ) -> Vec<&opentelemetry_sdk::trace::SpanData> {
        spans
            .iter()
            .filter(|span| span.span_kind == SpanKind::Client)
            .collect()
    }

    /// Returns the first span named `name`; panics if there is none.
    fn find_span<'a>(
        spans: &'a [opentelemetry_sdk::trace::SpanData],
        name: &str,
    ) -> &'a opentelemetry_sdk::trace::SpanData {
        spans.iter().find(|span| span.name == name).unwrap()
    }

    /// Reads attribute `key` as a string; `None` if absent or not a string.
    fn string_attr<'a>(
        span: &'a opentelemetry_sdk::trace::SpanData,
        key: &str,
    ) -> Option<&'a str> {
        match attr(span, key) {
            Some(Value::String(value)) => Some(value.as_str()),
            _ => None,
        }
    }

    /// Reads attribute `key` as an `i64`; `None` if absent or not an integer.
    fn i64_attr(span: &opentelemetry_sdk::trace::SpanData, key: &str) -> Option<i64> {
        match attr(span, key) {
            Some(Value::I64(value)) => Some(*value),
            _ => None,
        }
    }

    /// Looks up the first attribute on `span` whose key equals `key`.
    fn attr<'a>(
        span: &'a opentelemetry_sdk::trace::SpanData,
        key: &str,
    ) -> Option<&'a Value> {
        span.attributes
            .iter()
            .find(|kv| kv.key.as_str() == key)
            .map(|kv| &kv.value)
    }

    /// Splits a `traceparent` value (`version-traceid-spanid-flags`) into its
    /// trace-id and span-id fields; panics if fewer than three fields exist.
    fn traceparent_ids(traceparent: &str) -> (&str, &str) {
        let mut parts = traceparent.split('-');
        let _version = parts.next().unwrap();
        let trace_id = parts.next().unwrap();
        let span_id = parts.next().unwrap();
        (trace_id, span_id)
    }

    /// Exporter and provider installed by `configure_test_tracing`.
    struct TestTelemetry {
        exporter: InMemorySpanExporter,
        provider: TracerProvider,
    }

    impl Drop for TestTelemetry {
        /// Shuts the provider down at the end of a test; the result is
        /// deliberately ignored during teardown.
        fn drop(&mut self) {
            let _ = self.provider.shutdown();
        }
    }
}