1use std::{
4 fmt::Display,
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10use http::{Request, Response};
11use pin_project::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14use tracing::{Level, Span};
15use tracing_opentelemetry::OpenTelemetrySpanExt;
16
17use crate::{
18 trace::{extractor::HeaderExtractor, injector::HeaderInjector},
19 util,
20};
21
/// Which side of an HTTP exchange a span describes.
///
/// Drives the span's `otel.kind` value, whether trace context is injected into or
/// extracted from the request headers, and whether 4xx responses mark the span as failed.
#[derive(Debug, Clone, Copy)]
enum SpanKind {
    /// An outgoing request sent by this process.
    Client,
    /// An incoming request received by this process.
    Server,
}
30
31#[derive(Clone, Debug)]
33pub struct HttpLayer {
34 level: Level,
35 kind: SpanKind,
36}
37
38impl HttpLayer {
39 pub fn server(level: Level) -> Self {
41 Self {
42 level,
43 kind: SpanKind::Server,
44 }
45 }
46
47 pub fn client(level: Level) -> Self {
49 Self {
50 level,
51 kind: SpanKind::Client,
52 }
53 }
54}
55
56impl<S> Layer<S> for HttpLayer {
57 type Service = Http<S>;
58
59 fn layer(&self, inner: S) -> Self::Service {
60 Http {
61 inner,
62 level: self.level,
63 kind: self.kind,
64 }
65 }
66}
67
68#[derive(Clone, Debug)]
70pub struct Http<S> {
71 inner: S,
72 level: Level,
73 kind: SpanKind,
74}
75
76impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Http<S>
77where
78 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
79 S::Error: Display,
80{
81 type Response = S::Response;
82 type Error = S::Error;
83 type Future = ResponseFuture<S::Future>;
84
85 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.inner.poll_ready(cx)
87 }
88
89 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
90 let span = make_request_span(self.level, self.kind, &mut req);
91 let inner = {
92 let _enter = span.enter();
93 self.inner.call(req)
94 };
95
96 ResponseFuture {
97 inner,
98 span,
99 kind: self.kind,
100 }
101 }
102}
103
104#[pin_project]
106pub struct ResponseFuture<F> {
107 #[pin]
108 inner: F,
109 span: Span,
110 kind: SpanKind,
111}
112
113impl<F, ResBody, E> Future for ResponseFuture<F>
114where
115 F: Future<Output = Result<Response<ResBody>, E>>,
116 E: Display,
117{
118 type Output = Result<Response<ResBody>, E>;
119
120 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121 let this = self.project();
122 let _enter = this.span.enter();
123
124 match ready!(this.inner.poll(cx)) {
125 Ok(response) => {
126 record_response(this.span, *this.kind, &response);
127 Poll::Ready(Ok(response))
128 }
129 Err(err) => {
130 record_error(this.span, &err);
131 Poll::Ready(Err(err))
132 }
133 }
134 }
135}
136
137fn span_kind(kind: SpanKind) -> &'static str {
139 match kind {
140 SpanKind::Client => "client",
141 SpanKind::Server => "server",
142 }
143}
144
145fn make_request_span<B>(level: Level, kind: SpanKind, request: &mut Request<B>) -> Span {
147 macro_rules! make_span {
148 ($level:expr) => {{
149 use tracing::field::Empty;
150
151 tracing::span!(
152 $level,
153 "HTTP",
154 "client.address" = Empty,
155 "client.port" = Empty,
156 "error.message" = Empty,
157 "http.request.method" = util::http_method(request.method()),
158 "http.response.status_code" = Empty,
159 "network.protocol.name" = "http",
160 "network.protocol.version" = util::http_version(request.version()),
161 "otel.kind" = span_kind(kind),
162 "otel.status_code" = Empty,
163 "server.address" = Empty,
164 "server.port" = Empty,
165 "url.full" = Empty,
166 "url.path" = request.uri().path(),
167 "url.query" = Empty,
168 "url.scheme" = Empty,
169 )
170 }};
171 }
172
173 let span = match level {
174 Level::ERROR => make_span!(Level::ERROR),
175 Level::WARN => make_span!(Level::WARN),
176 Level::INFO => make_span!(Level::INFO),
177 Level::DEBUG => make_span!(Level::DEBUG),
178 Level::TRACE => make_span!(Level::TRACE),
179 };
180
181 for (header_name, header_value) in request.headers().iter() {
182 if let Ok(attribute_value) = header_value.to_str() {
183 let attribute_name = format!("http.request.header.{}", header_name);
184 span.set_attribute(attribute_name, attribute_value.to_owned());
185 }
186 }
187
188 if let Some(query) = request.uri().query() {
189 span.record("url.query", query);
190 }
191
192 match kind {
193 SpanKind::Client => {
194 span.record("url.full", tracing::field::display(request.uri()));
195
196 let util::HttpRequestAttributes {
197 url_scheme,
198 server_address,
199 server_port,
200 } = util::HttpRequestAttributes::from_sent_request(request);
201
202 if let Some(server_address) = server_address {
203 span.record("server.address", server_address);
204 }
205 if let Some(server_port) = server_port {
206 span.record("server.port", server_port);
207 }
208 if let Some(url_scheme) = url_scheme {
209 span.record("url.scheme", url_scheme);
210 }
211
212 let context = span.context();
213 opentelemetry::global::get_text_map_propagator(|injector| {
214 injector.inject_context(&context, &mut HeaderInjector(request.headers_mut()));
215 });
216 }
217 SpanKind::Server => {
218 if let Some(http_route) = util::http_route(request) {
219 span.record("http.route", http_route);
220 }
221 if let Some(client_address) = util::client_address(request) {
222 let ip = client_address.ip();
223 span.record("client.address", tracing::field::display(ip));
224 span.record("client.port", client_address.port());
225 }
226
227 let util::HttpRequestAttributes {
228 url_scheme,
229 server_address,
230 server_port,
231 } = util::HttpRequestAttributes::from_recv_request(request);
232
233 if let Some(server_address) = server_address {
234 span.record("server.address", server_address);
235 }
236 if let Some(server_port) = server_port {
237 span.record("server.port", server_port);
238 }
239 if let Some(url_scheme) = url_scheme {
240 span.record("url.scheme", url_scheme);
241 }
242
243 let context = opentelemetry::global::get_text_map_propagator(|extractor| {
244 extractor.extract(&HeaderExtractor(request.headers_mut()))
245 });
246 if let Err(err) = span.set_parent(context) {
247 tracing::warn!("Failed to set parent span: {err}");
248 }
249 }
250 }
251
252 span
253}
254
255fn record_response<B>(span: &Span, kind: SpanKind, response: &Response<B>) {
257 span.record(
258 "http.response.status_code",
259 response.status().as_u16() as i64,
260 );
261
262 for (header_name, header_value) in response.headers().iter() {
263 if let Ok(attribute_value) = header_value.to_str() {
264 let attribute_name = format!("http.response.header.{}", header_name);
265 span.set_attribute(attribute_name, attribute_value.to_owned());
266 }
267 }
268
269 if let SpanKind::Client = kind {
270 if response.status().is_client_error() {
271 span.record("otel.status_code", "ERROR");
272 }
273 }
274 if response.status().is_server_error() {
275 span.record("otel.status_code", "ERROR");
276 }
277}
278
279fn record_error<E: Display>(span: &Span, err: &E) {
281 span.record("otel.status_code", "ERROR");
282 span.record("error.message", err.to_string());
283}