telemetry_rust/middleware/axum.rs
1//! Axum web framework middleware.
2//!
3//! Provides middleware for the Axum web framework to automatically
4//! instrument HTTP requests with OpenTelemetry tracing.
5
// Originally derived from davidB/tracing-opentelemetry-instrumentation-sdk
7// https://github.com/davidB/tracing-opentelemetry-instrumentation-sdk/blob/d3609ac2cc699d3a24fbf89754053cc8e938e3bf/axum-tracing-opentelemetry/src/middleware/trace_extractor.rs#L53
8// which is licensed under CC0 1.0 Universal
9// https://github.com/davidB/tracing-opentelemetry-instrumentation-sdk/blob/d3609ac2cc699d3a24fbf89754053cc8e938e3bf/LICENSE
10
11use http::{Request, Response};
12use pin_project_lite::pin_project;
13use std::{
14 error::Error,
15 future::Future,
16 pin::Pin,
17 task::{Context, Poll},
18};
19use tower::{Layer, Service};
20use tracing::Span;
21use tracing_opentelemetry_instrumentation_sdk::http as otel_http;
22
/// Predicate that decides, from the request URI path, whether a request is traced.
///
/// It receives the path component of the request URI (for example `/api/users`).
/// Return `true` to create a span for the request, or `false` to skip tracing it.
pub type Filter = fn(path: &str) -> bool;
27
/// Accessor that turns a borrowed matched-path value into its route string.
///
/// The returned `&str` borrows from the input value. The middleware records it as
/// the span's `http.route` attribute and uses it to build the span name.
pub type AsStr<T> = fn(value: &T) -> &str;
32
33/// OpenTelemetry layer for Axum applications.
34///
35/// This layer provides automatic tracing instrumentation for Axum web applications,
36/// creating spans for HTTP requests with appropriate semantic attributes.
37///
38/// The layer is generic over [`axum::extract::MatchedPath`](https://docs.rs/axum/latest/axum/extract/struct.MatchedPath.html),
39/// making it compatible with different versions of axum without being tied to any specific one.
40///
41/// # Example
42///
43/// ```rust
44/// use axum::{Router, routing};
45/// use telemetry_rust::middleware::axum::OtelAxumLayer;
46///
47/// let app: Router = axum::Router::new()
48/// .nest("/api", Router::new()) // api_routes would be your actual routes
49/// .layer(OtelAxumLayer::new(axum::extract::MatchedPath::as_str));
50/// ```
51#[derive(Debug, Clone)]
52pub struct OtelAxumLayer<P> {
53 matched_path_as_str: AsStr<P>,
54 filter: Option<Filter>,
55 inject_context: bool,
56}
57
58// add a builder like api
59impl<P> OtelAxumLayer<P> {
60 /// Creates a new OpenTelemetry layer for Axum.
61 ///
62 /// # Arguments
63 ///
64 /// * `matched_path_as_str` - [`axum::extract::MatchedPath::as_str`] or any function to convert [`axum::extract::MatchedPath`] to a `&str`
65 ///
66 /// [`axum::extract::MatchedPath::as_str`]: https://docs.rs/axum/latest/axum/extract/struct.MatchedPath.html#method.as_str
67 /// [`axum::extract::MatchedPath`]: https://docs.rs/axum/latest/axum/extract/struct.MatchedPath.html
68 pub fn new(matched_path_as_str: AsStr<P>) -> Self {
69 OtelAxumLayer {
70 matched_path_as_str,
71 filter: None,
72 inject_context: false,
73 }
74 }
75
76 /// Sets a filter function to selectively trace requests.
77 ///
78 /// # Arguments
79 ///
80 /// * `filter` - Function that returns true for paths that should be traced
81 pub fn filter(self, filter: Filter) -> Self {
82 OtelAxumLayer {
83 filter: Some(filter),
84 ..self
85 }
86 }
87
88 /// Configures whether to inject OpenTelemetry context into responses.
89 ///
90 /// # Arguments
91 ///
92 /// * `inject_context` - Whether to inject trace context into response headers
93 pub fn inject_context(self, inject_context: bool) -> Self {
94 OtelAxumLayer {
95 inject_context,
96 ..self
97 }
98 }
99}
100
101impl<S, P> Layer<S> for OtelAxumLayer<P> {
102 /// The wrapped service
103 type Service = OtelAxumService<S, P>;
104 fn layer(&self, inner: S) -> Self::Service {
105 OtelAxumService {
106 inner,
107 matched_path_as_str: self.matched_path_as_str,
108 filter: self.filter,
109 inject_context: self.inject_context,
110 }
111 }
112}
113
114/// OpenTelemetry service wrapper for Axum applications.
115///
116/// This service wraps Axum services to provide automatic HTTP request tracing
117/// with OpenTelemetry spans and context propagation.
118#[derive(Debug, Clone)]
119pub struct OtelAxumService<S, P> {
120 inner: S,
121 matched_path_as_str: AsStr<P>,
122 filter: Option<Filter>,
123 inject_context: bool,
124}
125
126impl<S, B, B2, P> Service<Request<B>> for OtelAxumService<S, P>
127where
128 S: Service<Request<B>, Response = Response<B2>> + Clone + Send + 'static,
129 S::Error: Error + 'static, //fmt::Display + 'static,
130 S::Future: Send + 'static,
131 B: Send + 'static,
132 P: Send + Sync + 'static,
133{
134 type Response = S::Response;
135 type Error = S::Error;
136 // #[allow(clippy::type_complexity)]
137 // type Future = futures_core::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
138 type Future = ResponseFuture<S::Future>;
139
140 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141 self.inner.poll_ready(cx)
142 }
143
144 fn call(&mut self, req: Request<B>) -> Self::Future {
145 use tracing_opentelemetry::OpenTelemetrySpanExt;
146 let span = if self.filter.is_none_or(|f| f(req.uri().path())) {
147 let span = otel_http::http_server::make_span_from_request(&req);
148 let matched_path = req.extensions().get::<P>();
149 let route = matched_path.map_or("", self.matched_path_as_str);
150 let method = otel_http::http_method(req.method());
151 // let client_ip = parse_x_forwarded_for(req.headers())
152 // .or_else(|| {
153 // req.extensions()
154 // .get::<ConnectInfo<SocketAddr>>()
155 // .map(|ConnectInfo(client_ip)| Cow::from(client_ip.to_string()))
156 // })
157 // .unwrap_or_default();
158 span.record("http.route", route);
159 span.record("otel.name", format!("{method} {route}").trim());
160 // span.record("trace_id", find_trace_id_from_tracing(&span));
161 // span.record("client.address", client_ip);
162 span.set_parent(otel_http::extract_context(req.headers()));
163 span
164 } else {
165 tracing::Span::none()
166 };
167 let future = {
168 let _ = span.enter();
169 self.inner.call(req)
170 };
171 ResponseFuture {
172 inner: future,
173 inject_context: self.inject_context,
174 span,
175 }
176 }
177}
178
179pin_project! {
180 /// Response future for [`Trace`].
181 ///
182 /// [`Trace`]: super::Trace
183 pub struct ResponseFuture<F> {
184 #[pin]
185 pub(crate) inner: F,
186 pub(crate) inject_context: bool,
187 pub(crate) span: Span,
188 // pub(crate) start: Instant,
189 }
190}
191
192impl<Fut, ResBody, E> Future for ResponseFuture<Fut>
193where
194 Fut: Future<Output = Result<Response<ResBody>, E>>,
195 E: std::error::Error + 'static,
196{
197 type Output = Result<Response<ResBody>, E>;
198
199 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
200 let this = self.project();
201 let _guard = this.span.enter();
202 let mut result = futures_util::ready!(this.inner.poll(cx));
203 otel_http::http_server::update_span_from_response_or_error(this.span, &result);
204 if *this.inject_context
205 && let Ok(response) = result.as_mut()
206 {
207 otel_http::inject_context(
208 &tracing_opentelemetry_instrumentation_sdk::find_current_context(),
209 response.headers_mut(),
210 );
211 }
212
213 Poll::Ready(result)
214 }
215}