1use crate::{AttachMetricLabels, CallType, callback_based, dbg_panic};
2use futures_util::{
3 FutureExt, TryFutureExt,
4 future::{BoxFuture, Either},
5};
6use std::{
7 fmt,
8 task::{Context, Poll},
9 time::{Duration, Instant},
10};
11use temporalio_common::telemetry::{
12 TaskQueueLabelStrategy,
13 metrics::{
14 Counter, CounterBase, HistogramDuration, HistogramDurationBase, MetricAttributable,
15 MetricAttributes, MetricKeyValue, MetricParameters, TemporalMeter,
16 },
17};
18use tonic::{Code, body::Body, transport::Channel};
19use tower::Service;
20
/// Name of the histogram that records latencies of normal client requests.
pub static REQUEST_LATENCY_HISTOGRAM_NAME: &str = "request_latency";
/// Name of the histogram that records latencies of long-poll client requests.
pub static LONG_REQUEST_LATENCY_HISTOGRAM_NAME: &str = "long_request_latency";
25
/// Bundles a meter with the client request instruments created from it.
///
/// The `Debug` output intentionally shows only `poll_is_long`.
#[derive(Clone, derive_more::Debug)]
#[debug("MetricsContext {{ poll_is_long: {poll_is_long} }}")]
pub(crate) struct MetricsContext {
    /// Meter the instruments were created from; also supplies the default attributes.
    meter: TemporalMeter,
    /// When `true`, the long-poll variants of the instruments are used.
    poll_is_long: bool,
    /// Counters and histograms for normal and long-poll requests.
    instruments: Instruments,
}
/// The set of instruments used to record client request metrics.
#[derive(Clone)]
struct Instruments {
    /// Counter registered under the name `request`.
    svc_request: Counter,
    /// Counter registered under the name `request_failure`.
    svc_request_failed: Counter,
    /// Counter registered under the name `long_request`.
    long_svc_request: Counter,
    /// Counter registered under the name `long_request_failure`.
    long_svc_request_failed: Counter,

    /// Latency histogram for normal requests.
    svc_request_latency: HistogramDuration,
    /// Latency histogram for long-poll requests.
    long_svc_request_latency: HistogramDuration,
}
44
45impl MetricsContext {
46 pub(crate) fn new(tm: TemporalMeter) -> Self {
47 let instruments = Instruments {
48 svc_request: tm.counter(MetricParameters {
49 name: "request".into(),
50 description: "Count of client request successes by rpc name".into(),
51 unit: "".into(),
52 }),
53 svc_request_failed: tm.counter(MetricParameters {
54 name: "request_failure".into(),
55 description: "Count of client request failures by rpc name".into(),
56 unit: "".into(),
57 }),
58 long_svc_request: tm.counter(MetricParameters {
59 name: "long_request".into(),
60 description: "Count of long-poll request successes by rpc name".into(),
61 unit: "".into(),
62 }),
63 long_svc_request_failed: tm.counter(MetricParameters {
64 name: "long_request_failure".into(),
65 description: "Count of long-poll request failures by rpc name".into(),
66 unit: "".into(),
67 }),
68 svc_request_latency: tm.histogram_duration(MetricParameters {
69 name: REQUEST_LATENCY_HISTOGRAM_NAME.into(),
70 unit: "duration".into(),
71 description: "Histogram of client request latencies".into(),
72 }),
73 long_svc_request_latency: tm.histogram_duration(MetricParameters {
74 name: LONG_REQUEST_LATENCY_HISTOGRAM_NAME.into(),
75 unit: "duration".into(),
76 description: "Histogram of client long-poll request latencies".into(),
77 }),
78 };
79 Self {
80 poll_is_long: false,
81 instruments,
82 meter: tm,
83 }
84 }
85
86 pub(crate) fn with_new_attrs(&mut self, new_kvs: impl IntoIterator<Item = MetricKeyValue>) {
88 self.meter.merge_attributes(new_kvs.into());
89
90 let _ = self
91 .instruments
92 .svc_request
93 .with_attributes(self.meter.get_default_attributes())
94 .and_then(|v| {
95 self.instruments.svc_request = v;
96 self.instruments
97 .long_svc_request
98 .with_attributes(self.meter.get_default_attributes())
99 })
100 .and_then(|v| {
101 self.instruments.long_svc_request = v;
102 self.instruments
103 .svc_request_latency
104 .with_attributes(self.meter.get_default_attributes())
105 })
106 .and_then(|v| {
107 self.instruments.svc_request_latency = v;
108 self.instruments
109 .long_svc_request_latency
110 .with_attributes(self.meter.get_default_attributes())
111 })
112 .map(|v| {
113 self.instruments.long_svc_request_latency = v;
114 })
115 .inspect_err(|e| {
116 dbg_panic!("Failed to extend client metrics attributes: {:?}", e);
117 });
118 }
119
120 pub(crate) fn set_is_long_poll(&mut self) {
121 self.poll_is_long = true;
122 }
123
124 pub(crate) fn svc_request(&self) {
126 if self.poll_is_long {
127 self.instruments.long_svc_request.adds(1);
128 } else {
129 self.instruments.svc_request.adds(1);
130 }
131 }
132
133 pub(crate) fn svc_request_failed(&self, code: Option<Code>) {
135 self.svc_request_failed_with_label(code.map(status_code_kv));
136 }
137
138 pub(crate) fn svc_request_failed_transport(&self) {
141 self.svc_request_failed_with_label(Some(transport_error_kv()));
142 }
143
144 fn svc_request_failed_with_label(&self, label: Option<MetricKeyValue>) {
145 let refme: MetricAttributes;
146 let kvs = if let Some(kv) = label {
147 refme = self
148 .meter
149 .extend_attributes(self.meter.get_default_attributes().clone(), [kv].into());
150 &refme
151 } else {
152 self.meter.get_default_attributes()
153 };
154 if self.poll_is_long {
155 self.instruments.long_svc_request_failed.add(1, kvs);
156 } else {
157 self.instruments.svc_request_failed.add(1, kvs);
158 }
159 }
160
161 pub(crate) fn record_svc_req_latency(&self, dur: Duration) {
163 if self.poll_is_long {
164 self.instruments.long_svc_request_latency.records(dur);
165 } else {
166 self.instruments.svc_request_latency.records(dur);
167 }
168 }
169}
170
/// Label key under which the namespace is attached.
const KEY_NAMESPACE: &str = "namespace";
/// Label key under which the service method (operation) name is attached.
const KEY_SVC_METHOD: &str = "operation";
/// Label key under which the task queue name is attached.
const KEY_TASK_QUEUE: &str = "task_queue";
/// Label key under which the status code of a failed request is attached.
const KEY_STATUS_CODE: &str = "status_code";
175
/// Builds the `namespace` metric label with `ns` as its value.
pub(crate) fn namespace_kv(ns: String) -> MetricKeyValue {
    MetricKeyValue::new(KEY_NAMESPACE, ns)
}
179
/// Builds the `task_queue` metric label with `tq` as its value.
pub(crate) fn task_queue_kv(tq: String) -> MetricKeyValue {
    MetricKeyValue::new(KEY_TASK_QUEUE, tq)
}
183
/// Builds the `operation` metric label with `op` as its value.
pub(crate) fn svc_operation(op: String) -> MetricKeyValue {
    MetricKeyValue::new(KEY_SVC_METHOD, op)
}
187
/// Builds the `status_code` metric label from a gRPC status code, using its
/// SCREAMING_SNAKE_CASE name as the value.
pub(crate) fn status_code_kv(code: Code) -> MetricKeyValue {
    MetricKeyValue::new(KEY_STATUS_CODE, code_as_screaming_snake(&code))
}
191
/// Builds the `status_code` metric label with the fixed value `TRANSPORT_ERROR`.
fn transport_error_kv() -> MetricKeyValue {
    MetricKeyValue::new(KEY_STATUS_CODE, "TRANSPORT_ERROR")
}
195
/// Returns the SCREAMING_SNAKE_CASE name of a gRPC status code.
///
/// The match has no wildcard arm, so every `Code` variant must be listed explicitly.
fn code_as_screaming_snake(code: &Code) -> &'static str {
    match code {
        Code::Ok => "OK",
        Code::Cancelled => "CANCELLED",
        Code::Unknown => "UNKNOWN",
        Code::InvalidArgument => "INVALID_ARGUMENT",
        Code::DeadlineExceeded => "DEADLINE_EXCEEDED",
        Code::NotFound => "NOT_FOUND",
        Code::AlreadyExists => "ALREADY_EXISTS",
        Code::PermissionDenied => "PERMISSION_DENIED",
        Code::ResourceExhausted => "RESOURCE_EXHAUSTED",
        Code::FailedPrecondition => "FAILED_PRECONDITION",
        Code::Aborted => "ABORTED",
        Code::OutOfRange => "OUT_OF_RANGE",
        Code::Unimplemented => "UNIMPLEMENTED",
        Code::Internal => "INTERNAL",
        Code::Unavailable => "UNAVAILABLE",
        Code::DataLoss => "DATA_LOSS",
        Code::Unauthenticated => "UNAUTHENTICATED",
    }
}
218
/// gRPC service wrapper that forwards requests to an inner service and records request
/// metrics around each call.
#[derive(Debug, Clone)]
pub(crate) struct GrpcMetricSvc {
    /// The service that actually performs the call.
    pub(crate) inner: ChannelOrGrpcOverride,
    /// Metrics context cloned for every call; when `None`, no metrics are recorded.
    pub(crate) metrics: Option<MetricsContext>,
    /// When `true`, failures are recorded without a `status_code` label.
    pub(crate) disable_errcode_label: bool,
}
227
/// The inner service a [`GrpcMetricSvc`] delegates to.
#[derive(Clone)]
pub(crate) enum ChannelOrGrpcOverride {
    /// A tonic transport channel.
    Channel(Channel),
    /// A callback-based gRPC service used in place of a channel.
    GrpcOverride(callback_based::CallbackBasedGrpcService),
}
233
234impl fmt::Debug for ChannelOrGrpcOverride {
235 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
236 match self {
237 ChannelOrGrpcOverride::Channel(inner) => fmt::Debug::fmt(inner, f),
238 ChannelOrGrpcOverride::GrpcOverride(_) => f.write_str("<callback-based-grpc-service>"),
239 }
240 }
241}
242
243impl Service<http::Request<Body>> for GrpcMetricSvc {
245 type Response = http::Response<Body>;
246 type Error = Box<dyn std::error::Error + Send + Sync>;
247 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
248
249 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
250 match &mut self.inner {
251 ChannelOrGrpcOverride::Channel(inner) => inner.poll_ready(cx).map_err(Into::into),
252 ChannelOrGrpcOverride::GrpcOverride(inner) => inner.poll_ready(cx).map_err(Into::into),
253 }
254 }
255
256 fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
257 let metrics = self
258 .metrics
259 .clone()
260 .map(|mut m| {
261 if let Some(other_labels) = req.extensions_mut().remove::<AttachMetricLabels>() {
263 m.with_new_attrs(other_labels.labels);
264
265 if other_labels.normal_task_queue.is_some()
266 || other_labels.sticky_task_queue.is_some()
267 {
268 let task_queue_name = match m.meter.get_task_queue_label_strategy() {
269 TaskQueueLabelStrategy::UseNormal => other_labels.normal_task_queue,
270 TaskQueueLabelStrategy::UseNormalAndSticky => other_labels
271 .sticky_task_queue
272 .or(other_labels.normal_task_queue),
273 _ => other_labels.normal_task_queue,
274 };
275
276 if let Some(tq_name) = task_queue_name {
277 m.with_new_attrs([task_queue_kv(tq_name)]);
278 }
279 }
280 }
281 if let Some(ct) = req.extensions().get::<CallType>()
282 && ct.is_long()
283 {
284 m.set_is_long_poll();
285 }
286 m
287 })
288 .and_then(|mut metrics| {
289 req.uri().to_string().rsplit_once('/').map(|split_tup| {
291 let method_name = split_tup.1;
292 metrics.with_new_attrs([svc_operation(method_name.to_string())]);
293 metrics.svc_request();
294 metrics
295 })
296 });
297 let callfut = match &mut self.inner {
298 ChannelOrGrpcOverride::Channel(inner) => {
299 Either::Left(inner.call(req).map_err(Into::into))
300 }
301 ChannelOrGrpcOverride::GrpcOverride(inner) => {
302 Either::Right(inner.call(req).map_err(Into::into))
303 }
304 };
305 let errcode_label_disabled = self.disable_errcode_label;
306 async move {
307 let started = Instant::now();
308 let res = callfut.await;
309 if let Some(metrics) = metrics {
310 metrics.record_svc_req_latency(started.elapsed());
311 match res {
312 Ok(ref ok_res) => {
313 if let Some(number) = ok_res
314 .headers()
315 .get("grpc-status")
316 .and_then(|s| s.to_str().ok())
317 .and_then(|s| s.parse::<i32>().ok())
318 {
319 let code = Code::from(number);
320 if code != Code::Ok {
321 let code = if errcode_label_disabled {
322 None
323 } else {
324 Some(code)
325 };
326 metrics.svc_request_failed(code);
327 }
328 }
329 }
330 Err(_) => {
331 if !errcode_label_disabled {
335 metrics.svc_request_failed_transport();
336 } else {
337 metrics.svc_request_failed(None);
338 }
339 }
340 }
341 }
342 res
343 }
344 .boxed()
345 }
346}