spring_opentelemetry/trace/grpc.rs

//! Middleware that adds tracing to a [`Service`] that handles gRPC requests.
//!
//! Refs: <https://opentelemetry.io/docs/specs/semconv/rpc/grpc/>
3
4use super::SpanKind;
5use http::{Request, Response};
6use opentelemetry_http::{HeaderExtractor, HeaderInjector};
7use opentelemetry_semantic_conventions::attribute::{
8    EXCEPTION_MESSAGE, OTEL_STATUS_CODE, RPC_GRPC_STATUS_CODE,
9};
10use pin_project::pin_project;
11use std::{
12    fmt::Display,
13    future::Future,
14    pin::Pin,
15    task::{ready, Context, Poll},
16};
17use tower::{Layer, Service};
18use tracing::{Level, Span};
19use tracing_opentelemetry::OpenTelemetrySpanExt;
20
21/// [`Layer`] that adds tracing to a [`Service`] that handles gRPC requests.
22#[derive(Clone, Debug)]
23pub struct GrpcLayer {
24    level: Level,
25    kind: SpanKind,
26}
27
28impl GrpcLayer {
29    /// [`Span`]s are constructed at the given level from server side.
30    pub fn server(level: Level) -> Self {
31        Self {
32            level,
33            kind: SpanKind::Server,
34        }
35    }
36
37    /// [`Span`]s are constructed at the given level from client side.
38    pub fn client(level: Level) -> Self {
39        Self {
40            level,
41            kind: SpanKind::Client,
42        }
43    }
44}
45
46impl<S> Layer<S> for GrpcLayer {
47    type Service = GrpcService<S>;
48
49    fn layer(&self, inner: S) -> Self::Service {
50        GrpcService {
51            inner,
52            level: self.level,
53            kind: self.kind,
54        }
55    }
56}
57
58/// Middleware that adds tracing to a [`Service`] that handles gRPC requests.
59#[derive(Clone, Debug)]
60pub struct GrpcService<S> {
61    inner: S,
62    level: Level,
63    kind: SpanKind,
64}
65
66impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for GrpcService<S>
67where
68    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
69    S::Error: Display,
70{
71    type Response = S::Response;
72    type Error = S::Error;
73    type Future = ResponseFuture<S::Future>;
74
75    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76        self.inner.poll_ready(cx)
77    }
78
79    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
80        let span = self.make_request_span(&mut req);
81        let inner = {
82            let _enter = span.enter();
83            self.inner.call(req)
84        };
85
86        ResponseFuture {
87            inner,
88            span,
89            kind: self.kind,
90        }
91    }
92}
93
94impl<S> GrpcService<S> {
95    /// Creates a new [`Span`] for the given request.
96    fn make_request_span<B>(&self, request: &mut Request<B>) -> Span {
97        let Self { level, kind, .. } = self;
98
99        macro_rules! make_span {
100            ($level:expr) => {{
101                use tracing::field::Empty;
102
103                tracing::span!(
104                    $level,
105                    "GRPC",
106                    "exception.message" = Empty,
107                    "otel.kind" = tracing::field::debug(kind),
108                    "otel.name" = Empty,
109                    "otel.status_code" = Empty,
110                    "rpc.grpc.status_code" = Empty,
111                    "rpc.method" = Empty,
112                    "rpc.service" = Empty,
113                    "rpc.system" = "grpc",
114                )
115            }};
116        }
117
118        let span = match *level {
119            Level::ERROR => make_span!(Level::ERROR),
120            Level::WARN => make_span!(Level::WARN),
121            Level::INFO => make_span!(Level::INFO),
122            Level::DEBUG => make_span!(Level::DEBUG),
123            Level::TRACE => make_span!(Level::TRACE),
124        };
125
126        let path = request.uri().path();
127        let name = path.trim_start_matches('/');
128        span.record("otel.name", name);
129        if let Some((service, method)) = name.split_once('/') {
130            span.record("rpc.service", service);
131            span.record("rpc.method", method);
132        }
133
134        let capture_request_headers = kind.capture_request_headers();
135
136        for (header_name, header_value) in request.headers().iter() {
137            let header_name = header_name.as_str().to_lowercase();
138            if capture_request_headers.contains(&header_name) {
139                if let Ok(attribute_value) = header_value.to_str() {
140                    // attribute::RPC_GRPC_REQUEST_METADATA
141                    let attribute_name = format!("rpc.grpc.request.metadata.{header_name}");
142                    span.set_attribute(attribute_name, attribute_value.to_owned());
143                }
144            }
145        }
146
147        match kind {
148            SpanKind::Client => {
149                let context = span.context();
150                opentelemetry::global::get_text_map_propagator(|injector| {
151                    injector.inject_context(&context, &mut HeaderInjector(request.headers_mut()));
152                });
153            }
154            SpanKind::Server => {
155                let context = opentelemetry::global::get_text_map_propagator(|extractor| {
156                    extractor.extract(&HeaderExtractor(request.headers()))
157                });
158                span.set_parent(context);
159            }
160        }
161
162        span
163    }
164}
165
166/// Response future for [`GrpcService`].
167#[pin_project]
168pub struct ResponseFuture<F> {
169    #[pin]
170    inner: F,
171    span: Span,
172    kind: SpanKind,
173}
174
175impl<F, ResBody, E> Future for ResponseFuture<F>
176where
177    F: Future<Output = Result<Response<ResBody>, E>>,
178    E: Display,
179{
180    type Output = Result<Response<ResBody>, E>;
181
182    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
183        let this = self.project();
184        let _enter = this.span.enter();
185
186        match ready!(this.inner.poll(cx)) {
187            Ok(response) => {
188                Self::record_response(this.span, *this.kind, &response);
189                Poll::Ready(Ok(response))
190            }
191            Err(err) => {
192                Self::record_error(this.span, &err);
193                Poll::Ready(Err(err))
194            }
195        }
196    }
197}
198
199impl<F, ResBody, E> ResponseFuture<F>
200where
201    F: Future<Output = Result<Response<ResBody>, E>>,
202    E: Display,
203{
204    /// Records fields associated to the response.
205    fn record_response<B>(span: &Span, kind: SpanKind, response: &Response<B>) {
206        let capture_response_headers = kind.capture_response_headers();
207
208        for (header_name, header_value) in response.headers().iter() {
209            let header_name = header_name.as_str().to_lowercase();
210            if capture_response_headers.contains(&header_name) {
211                if let Ok(attribute_value) = header_value.to_str() {
212                    let attribute_name: String =
213                        format!("rpc.grpc.response.metadata.{header_name}");
214                    span.set_attribute(attribute_name, attribute_value.to_owned());
215                }
216            }
217        }
218
219        if let Some(header_value) = response.headers().get("grpc-status") {
220            if let Ok(header_value) = header_value.to_str() {
221                if let Ok(status_code) = header_value.parse::<i32>() {
222                    span.record(RPC_GRPC_STATUS_CODE, status_code);
223                }
224            }
225        } else {
226            span.record(RPC_GRPC_STATUS_CODE, 0);
227        }
228    }
229
230    /// Records the error message.
231    fn record_error(span: &Span, err: &E) {
232        span.record(OTEL_STATUS_CODE, "ERROR");
233        span.record(EXCEPTION_MESSAGE, err.to_string());
234    }
235}