1use super::{ClassifiedResponse, ClassifyEos, ClassifyResponse, SharedClassifier};
2use bitflags::bitflags;
3use http::{HeaderMap, Response};
4use std::{fmt, num::NonZeroI32};
5
/// gRPC status codes.
///
/// Each variant corresponds to one numeric `grpc-status` value; the number is
/// noted on the variant and is the one accepted by `GrpcCode::from(i32)`.
///
/// The type is a plain value: it can be copied, compared for equality and
/// used as a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GrpcCode {
    /// Status code `0`.
    Ok,
    /// Status code `1`.
    Cancelled,
    /// Status code `2`. Also what every unrecognised number converts to.
    Unknown,
    /// Status code `3`.
    InvalidArgument,
    /// Status code `4`.
    DeadlineExceeded,
    /// Status code `5`.
    NotFound,
    /// Status code `6`.
    AlreadyExists,
    /// Status code `7`.
    PermissionDenied,
    /// Status code `8`.
    ResourceExhausted,
    /// Status code `9`.
    FailedPrecondition,
    /// Status code `10`.
    Aborted,
    /// Status code `11`.
    OutOfRange,
    /// Status code `12`.
    Unimplemented,
    /// Status code `13`.
    Internal,
    /// Status code `14`.
    Unavailable,
    /// Status code `15`.
    DataLoss,
    /// Status code `16`.
    Unauthenticated,
}
48
49impl GrpcCode {
50 pub(crate) fn into_bitmask(self) -> GrpcCodeBitmask {
51 match self {
52 Self::Ok => GrpcCodeBitmask::OK,
53 Self::Cancelled => GrpcCodeBitmask::CANCELLED,
54 Self::Unknown => GrpcCodeBitmask::UNKNOWN,
55 Self::InvalidArgument => GrpcCodeBitmask::INVALID_ARGUMENT,
56 Self::DeadlineExceeded => GrpcCodeBitmask::DEADLINE_EXCEEDED,
57 Self::NotFound => GrpcCodeBitmask::NOT_FOUND,
58 Self::AlreadyExists => GrpcCodeBitmask::ALREADY_EXISTS,
59 Self::PermissionDenied => GrpcCodeBitmask::PERMISSION_DENIED,
60 Self::ResourceExhausted => GrpcCodeBitmask::RESOURCE_EXHAUSTED,
61 Self::FailedPrecondition => GrpcCodeBitmask::FAILED_PRECONDITION,
62 Self::Aborted => GrpcCodeBitmask::ABORTED,
63 Self::OutOfRange => GrpcCodeBitmask::OUT_OF_RANGE,
64 Self::Unimplemented => GrpcCodeBitmask::UNIMPLEMENTED,
65 Self::Internal => GrpcCodeBitmask::INTERNAL,
66 Self::Unavailable => GrpcCodeBitmask::UNAVAILABLE,
67 Self::DataLoss => GrpcCodeBitmask::DATA_LOSS,
68 Self::Unauthenticated => GrpcCodeBitmask::UNAUTHENTICATED,
69 }
70 }
71}
72
73impl From<i32> for GrpcCode {
77 fn from(value: i32) -> Self {
78 match value {
79 0 => GrpcCode::Ok,
80 1 => GrpcCode::Cancelled,
81 2 => GrpcCode::Unknown,
82 3 => GrpcCode::InvalidArgument,
83 4 => GrpcCode::DeadlineExceeded,
84 5 => GrpcCode::NotFound,
85 6 => GrpcCode::AlreadyExists,
86 7 => GrpcCode::PermissionDenied,
87 8 => GrpcCode::ResourceExhausted,
88 9 => GrpcCode::FailedPrecondition,
89 10 => GrpcCode::Aborted,
90 11 => GrpcCode::OutOfRange,
91 12 => GrpcCode::Unimplemented,
92 13 => GrpcCode::Internal,
93 14 => GrpcCode::Unavailable,
94 15 => GrpcCode::DataLoss,
95 16 => GrpcCode::Unauthenticated,
96
97 _ => GrpcCode::Unknown,
98 }
99 }
100}
101
102impl From<NonZeroI32> for GrpcCode {
103 fn from(value: NonZeroI32) -> Self {
104 GrpcCode::from(value.get())
105 }
106}
107
108bitflags! {
109 #[derive(Debug, Clone, Copy)]
110 pub(crate) struct GrpcCodeBitmask: u32 {
111 const OK = 0b00000000000000001;
112 const CANCELLED = 0b00000000000000010;
113 const UNKNOWN = 0b00000000000000100;
114 const INVALID_ARGUMENT = 0b00000000000001000;
115 const DEADLINE_EXCEEDED = 0b00000000000010000;
116 const NOT_FOUND = 0b00000000000100000;
117 const ALREADY_EXISTS = 0b00000000001000000;
118 const PERMISSION_DENIED = 0b00000000010000000;
119 const RESOURCE_EXHAUSTED = 0b00000000100000000;
120 const FAILED_PRECONDITION = 0b00000001000000000;
121 const ABORTED = 0b00000010000000000;
122 const OUT_OF_RANGE = 0b00000100000000000;
123 const UNIMPLEMENTED = 0b00001000000000000;
124 const INTERNAL = 0b00010000000000000;
125 const UNAVAILABLE = 0b00100000000000000;
126 const DATA_LOSS = 0b01000000000000000;
127 const UNAUTHENTICATED = 0b10000000000000000;
128 }
129}
130
131impl GrpcCodeBitmask {
132 fn try_from_u32(code: u32) -> Option<Self> {
133 match code {
134 0 => Some(Self::OK),
135 1 => Some(Self::CANCELLED),
136 2 => Some(Self::UNKNOWN),
137 3 => Some(Self::INVALID_ARGUMENT),
138 4 => Some(Self::DEADLINE_EXCEEDED),
139 5 => Some(Self::NOT_FOUND),
140 6 => Some(Self::ALREADY_EXISTS),
141 7 => Some(Self::PERMISSION_DENIED),
142 8 => Some(Self::RESOURCE_EXHAUSTED),
143 9 => Some(Self::FAILED_PRECONDITION),
144 10 => Some(Self::ABORTED),
145 11 => Some(Self::OUT_OF_RANGE),
146 12 => Some(Self::UNIMPLEMENTED),
147 13 => Some(Self::INTERNAL),
148 14 => Some(Self::UNAVAILABLE),
149 15 => Some(Self::DATA_LOSS),
150 16 => Some(Self::UNAUTHENTICATED),
151 _ => None,
152 }
153 }
154}
155
/// Response classifier for gRPC that judges a response by its `grpc-status`
/// code: codes outside the configured set of success codes are reported as
/// failures.
///
/// The set starts out as just `OK` (see `GrpcErrorsAsFailures::new`) and can
/// be widened with `GrpcErrorsAsFailures::with_success`.
#[derive(Debug, Clone)]
pub struct GrpcErrorsAsFailures {
    /// Flags for every status code that is treated as success.
    success_codes: GrpcCodeBitmask,
}
173
174impl Default for GrpcErrorsAsFailures {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180impl GrpcErrorsAsFailures {
181 pub fn new() -> Self {
183 Self {
184 success_codes: GrpcCodeBitmask::OK,
185 }
186 }
187
188 pub fn with_success(mut self, code: GrpcCode) -> Self {
207 self.success_codes |= code.into_bitmask();
208 self
209 }
210
211 pub fn make_classifier() -> SharedClassifier<Self> {
215 SharedClassifier::new(Self::new())
216 }
217}
218
219impl ClassifyResponse for GrpcErrorsAsFailures {
220 type FailureClass = GrpcFailureClass;
221 type ClassifyEos = GrpcEosErrorsAsFailures;
222
223 fn classify_response<B>(
224 self,
225 res: &Response<B>,
226 ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
227 match classify_grpc_metadata(res.headers(), self.success_codes) {
228 ParsedGrpcStatus::Success
229 | ParsedGrpcStatus::HeaderNotString
230 | ParsedGrpcStatus::HeaderNotInt => ClassifiedResponse::Ready(Ok(())),
231 ParsedGrpcStatus::NonSuccess(status) => {
232 ClassifiedResponse::Ready(Err(GrpcFailureClass::Code(status)))
233 }
234 ParsedGrpcStatus::GrpcStatusHeaderMissing => {
235 ClassifiedResponse::RequiresEos(GrpcEosErrorsAsFailures {
236 success_codes: self.success_codes,
237 })
238 }
239 }
240 }
241
242 fn classify_error<E>(self, error: &E) -> Self::FailureClass
243 where
244 E: fmt::Display + 'static,
245 {
246 GrpcFailureClass::Error(error.to_string())
247 }
248}
249
/// End-of-stream classifier handed out by `GrpcErrorsAsFailures` when the
/// response headers carry no `grpc-status`; it looks for the status in the
/// trailers instead.
#[derive(Debug, Clone)]
pub struct GrpcEosErrorsAsFailures {
    /// Flags for every status code that is treated as success.
    success_codes: GrpcCodeBitmask,
}
255
256impl ClassifyEos for GrpcEosErrorsAsFailures {
257 type FailureClass = GrpcFailureClass;
258
259 fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> {
260 if let Some(trailers) = trailers {
261 match classify_grpc_metadata(trailers, self.success_codes) {
262 ParsedGrpcStatus::Success
263 | ParsedGrpcStatus::GrpcStatusHeaderMissing
264 | ParsedGrpcStatus::HeaderNotString
265 | ParsedGrpcStatus::HeaderNotInt => Ok(()),
266 ParsedGrpcStatus::NonSuccess(status) => Err(GrpcFailureClass::Code(status)),
267 }
268 } else {
269 Ok(())
270 }
271 }
272
273 fn classify_error<E>(self, error: &E) -> Self::FailureClass
274 where
275 E: fmt::Display + 'static,
276 {
277 GrpcFailureClass::Error(error.to_string())
278 }
279}
280
/// The failure class produced by the gRPC classifiers in this module.
#[derive(Debug)]
pub enum GrpcFailureClass {
    /// The `grpc-status` code was not one of the classifier's success codes.
    Code(std::num::NonZeroI32),
    /// An error was classified; holds that error's `Display` output.
    Error(String),
}
289
290impl fmt::Display for GrpcFailureClass {
291 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292 match self {
293 Self::Code(code) => write!(f, "Code: {}", code),
294 Self::Error(error) => write!(f, "Error: {}", error),
295 }
296 }
297}
298
299pub(crate) fn classify_grpc_metadata(
300 headers: &HeaderMap,
301 success_codes: GrpcCodeBitmask,
302) -> ParsedGrpcStatus {
303 macro_rules! or_else {
304 ($expr:expr, $other:ident) => {
305 if let Some(value) = $expr {
306 value
307 } else {
308 return ParsedGrpcStatus::$other;
309 }
310 };
311 }
312
313 let status = or_else!(headers.get("grpc-status"), GrpcStatusHeaderMissing);
314 let status = or_else!(status.to_str().ok(), HeaderNotString);
315 let status = or_else!(status.parse::<i32>().ok(), HeaderNotInt);
316
317 if GrpcCodeBitmask::try_from_u32(status as _)
318 .filter(|code| success_codes.contains(*code))
319 .is_some()
320 {
321 ParsedGrpcStatus::Success
322 } else {
323 ParsedGrpcStatus::NonSuccess(NonZeroI32::new(status).unwrap())
324 }
325}
326
/// Outcome of looking up and interpreting the `grpc-status` entry of a header map.
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum ParsedGrpcStatus {
    /// The status was one of the accepted success codes.
    Success,
    /// The status was not an accepted success code; carries the parsed value.
    NonSuccess(NonZeroI32),
    /// The header map had no `grpc-status` entry.
    GrpcStatusHeaderMissing,
    /// The `grpc-status` value could not be read as a string.
    HeaderNotString,
    /// The `grpc-status` value was a string but did not parse as an `i32`.
    HeaderNotInt,
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a test that puts `$status` into a `grpc-status` header,
    /// classifies it against `$success_flags` and compares the outcome with
    /// `$expected`.
    macro_rules! classify_grpc_metadata_test {
        (
            name: $name:ident,
            status: $status:expr,
            success_flags: $success_flags:expr,
            expected: $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let mut headers = HeaderMap::new();
                headers.insert("grpc-status", $status.parse().unwrap());
                let status = classify_grpc_metadata(&headers, $success_flags);
                assert_eq!(status, $expected);
            }
        };
    }

    classify_grpc_metadata_test! {
        name: basic_ok,
        status: "0",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::Success,
    }

    classify_grpc_metadata_test! {
        name: basic_error,
        status: "1",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(1).unwrap()),
    }

    classify_grpc_metadata_test! {
        name: two_success_codes_first_matches,
        status: "0",
        success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT,
        expected: ParsedGrpcStatus::Success,
    }

    classify_grpc_metadata_test! {
        name: two_success_codes_second_matches,
        status: "3",
        success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT,
        expected: ParsedGrpcStatus::Success,
    }

    classify_grpc_metadata_test! {
        name: two_success_codes_none_matches,
        status: "16",
        success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT,
        expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(16).unwrap()),
    }

    // Numbers that are not recognised status codes are still reported with
    // their parsed value.
    classify_grpc_metadata_test! {
        name: unrecognised_code_is_non_success,
        status: "17",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(17).unwrap()),
    }

    classify_grpc_metadata_test! {
        name: negative_code_is_non_success,
        status: "-1",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(NonZeroI32::new(-1).unwrap()),
    }

    classify_grpc_metadata_test! {
        name: non_numeric_status_is_not_int,
        status: "abc",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::HeaderNotInt,
    }

    #[test]
    fn missing_status_header() {
        let headers = HeaderMap::new();
        let status = classify_grpc_metadata(&headers, GrpcCodeBitmask::OK);
        assert_eq!(status, ParsedGrpcStatus::GrpcStatusHeaderMissing);
    }

    #[test]
    fn non_ascii_status_is_not_string() {
        let mut headers = HeaderMap::new();
        // A byte that is a legal header value but cannot be read back as a string.
        headers.insert(
            "grpc-status",
            http::HeaderValue::from_bytes(b"\xFA").unwrap(),
        );
        let status = classify_grpc_metadata(&headers, GrpcCodeBitmask::OK);
        assert_eq!(status, ParsedGrpcStatus::HeaderNotString);
    }

    #[test]
    fn grpc_code_from_i32_known_codes() {
        assert!(matches!(GrpcCode::from(0), GrpcCode::Ok));
        assert!(matches!(GrpcCode::from(1), GrpcCode::Cancelled));
        assert!(matches!(GrpcCode::from(4), GrpcCode::DeadlineExceeded));
        assert!(matches!(GrpcCode::from(13), GrpcCode::Internal));
        assert!(matches!(GrpcCode::from(16), GrpcCode::Unauthenticated));
    }

    #[test]
    fn grpc_code_from_i32_unknown_codes() {
        assert!(matches!(GrpcCode::from(17), GrpcCode::Unknown));
        assert!(matches!(GrpcCode::from(-1), GrpcCode::Unknown));
        assert!(matches!(GrpcCode::from(9999), GrpcCode::Unknown));
    }

    #[test]
    fn grpc_code_from_non_zero_i32() {
        let code = NonZeroI32::new(7).unwrap();
        assert!(matches!(GrpcCode::from(code), GrpcCode::PermissionDenied));

        let code = NonZeroI32::new(99).unwrap();
        assert!(matches!(GrpcCode::from(code), GrpcCode::Unknown));
    }
}