Skip to main content

tower_http/classify/
grpc_errors_as_failures.rs

1use super::{ClassifiedResponse, ClassifyEos, ClassifyResponse, SharedClassifier};
2use bitflags::bitflags;
3use http::{HeaderMap, Response};
4use percent_encoding::percent_decode;
5use std::{fmt, num::NonZeroI32};
6
/// gRPC status codes.
///
/// These variants match the [gRPC status codes]. Each discriminant is the integer value that
/// appears in the `grpc-status` header or trailer.
///
/// [gRPC status codes]: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(i32)]
#[non_exhaustive]
pub enum GrpcCode {
    /// The operation completed successfully.
    Ok = 0,
    /// The operation was cancelled.
    Cancelled = 1,
    /// Unknown error.
    Unknown = 2,
    /// Client specified an invalid argument.
    InvalidArgument = 3,
    /// Deadline expired before operation could complete.
    DeadlineExceeded = 4,
    /// Some requested entity was not found.
    NotFound = 5,
    /// Some entity that we attempted to create already exists.
    AlreadyExists = 6,
    /// The caller does not have permission to execute the specified operation.
    PermissionDenied = 7,
    /// Some resource has been exhausted.
    ResourceExhausted = 8,
    /// The system is not in a state required for the operation's execution.
    FailedPrecondition = 9,
    /// The operation was aborted.
    Aborted = 10,
    /// Operation was attempted past the valid range.
    OutOfRange = 11,
    /// Operation is not implemented or not supported.
    Unimplemented = 12,
    /// Internal error.
    Internal = 13,
    /// The service is currently unavailable.
    Unavailable = 14,
    /// Unrecoverable data loss or corruption.
    DataLoss = 15,
    /// The request does not have valid authentication credentials.
    Unauthenticated = 16,
}
51
52impl GrpcCode {
53    pub(crate) fn into_bitmask(self) -> GrpcCodeBitmask {
54        match self {
55            Self::Ok => GrpcCodeBitmask::OK,
56            Self::Cancelled => GrpcCodeBitmask::CANCELLED,
57            Self::Unknown => GrpcCodeBitmask::UNKNOWN,
58            Self::InvalidArgument => GrpcCodeBitmask::INVALID_ARGUMENT,
59            Self::DeadlineExceeded => GrpcCodeBitmask::DEADLINE_EXCEEDED,
60            Self::NotFound => GrpcCodeBitmask::NOT_FOUND,
61            Self::AlreadyExists => GrpcCodeBitmask::ALREADY_EXISTS,
62            Self::PermissionDenied => GrpcCodeBitmask::PERMISSION_DENIED,
63            Self::ResourceExhausted => GrpcCodeBitmask::RESOURCE_EXHAUSTED,
64            Self::FailedPrecondition => GrpcCodeBitmask::FAILED_PRECONDITION,
65            Self::Aborted => GrpcCodeBitmask::ABORTED,
66            Self::OutOfRange => GrpcCodeBitmask::OUT_OF_RANGE,
67            Self::Unimplemented => GrpcCodeBitmask::UNIMPLEMENTED,
68            Self::Internal => GrpcCodeBitmask::INTERNAL,
69            Self::Unavailable => GrpcCodeBitmask::UNAVAILABLE,
70            Self::DataLoss => GrpcCodeBitmask::DATA_LOSS,
71            Self::Unauthenticated => GrpcCodeBitmask::UNAUTHENTICATED,
72        }
73    }
74
75    fn from_i32(code: i32) -> Option<GrpcCode> {
76        match code {
77            0 => Some(GrpcCode::Ok),
78            1 => Some(GrpcCode::Cancelled),
79            2 => Some(GrpcCode::Unknown),
80            3 => Some(GrpcCode::InvalidArgument),
81            4 => Some(GrpcCode::DeadlineExceeded),
82            5 => Some(GrpcCode::NotFound),
83            6 => Some(GrpcCode::AlreadyExists),
84            7 => Some(GrpcCode::PermissionDenied),
85            8 => Some(GrpcCode::ResourceExhausted),
86            9 => Some(GrpcCode::FailedPrecondition),
87            10 => Some(GrpcCode::Aborted),
88            11 => Some(GrpcCode::OutOfRange),
89            12 => Some(GrpcCode::Unimplemented),
90            13 => Some(GrpcCode::Internal),
91            14 => Some(GrpcCode::Unavailable),
92            15 => Some(GrpcCode::DataLoss),
93            16 => Some(GrpcCode::Unauthenticated),
94            _ => None,
95        }
96    }
97}
98
99/// Converts an `i32` gRPC status code into a [`GrpcCode`].
100///
101/// Unrecognized codes (outside 0-16) map to [`GrpcCode::Unknown`].
102impl From<i32> for GrpcCode {
103    fn from(value: i32) -> Self {
104        match value {
105            0 => GrpcCode::Ok,
106            1 => GrpcCode::Cancelled,
107            2 => GrpcCode::Unknown,
108            3 => GrpcCode::InvalidArgument,
109            4 => GrpcCode::DeadlineExceeded,
110            5 => GrpcCode::NotFound,
111            6 => GrpcCode::AlreadyExists,
112            7 => GrpcCode::PermissionDenied,
113            8 => GrpcCode::ResourceExhausted,
114            9 => GrpcCode::FailedPrecondition,
115            10 => GrpcCode::Aborted,
116            11 => GrpcCode::OutOfRange,
117            12 => GrpcCode::Unimplemented,
118            13 => GrpcCode::Internal,
119            14 => GrpcCode::Unavailable,
120            15 => GrpcCode::DataLoss,
121            16 => GrpcCode::Unauthenticated,
122
123            _ => GrpcCode::Unknown,
124        }
125    }
126}
127
128impl From<NonZeroI32> for GrpcCode {
129    fn from(value: NonZeroI32) -> Self {
130        GrpcCode::from(value.get())
131    }
132}
133
134bitflags! {
135    #[derive(Debug, Clone, Copy)]
136    pub(crate) struct GrpcCodeBitmask: u32 {
137        const OK                  = 0b00000000000000001;
138        const CANCELLED           = 0b00000000000000010;
139        const UNKNOWN             = 0b00000000000000100;
140        const INVALID_ARGUMENT    = 0b00000000000001000;
141        const DEADLINE_EXCEEDED   = 0b00000000000010000;
142        const NOT_FOUND           = 0b00000000000100000;
143        const ALREADY_EXISTS      = 0b00000000001000000;
144        const PERMISSION_DENIED   = 0b00000000010000000;
145        const RESOURCE_EXHAUSTED  = 0b00000000100000000;
146        const FAILED_PRECONDITION = 0b00000001000000000;
147        const ABORTED             = 0b00000010000000000;
148        const OUT_OF_RANGE        = 0b00000100000000000;
149        const UNIMPLEMENTED       = 0b00001000000000000;
150        const INTERNAL            = 0b00010000000000000;
151        const UNAVAILABLE         = 0b00100000000000000;
152        const DATA_LOSS           = 0b01000000000000000;
153        const UNAUTHENTICATED     = 0b10000000000000000;
154    }
155}
156
157impl From<GrpcCode> for GrpcCodeBitmask {
158    fn from(code: GrpcCode) -> Self {
159        match code {
160            GrpcCode::Ok => GrpcCodeBitmask::OK,
161            GrpcCode::Cancelled => GrpcCodeBitmask::CANCELLED,
162            GrpcCode::Unknown => GrpcCodeBitmask::UNKNOWN,
163            GrpcCode::InvalidArgument => GrpcCodeBitmask::INVALID_ARGUMENT,
164            GrpcCode::DeadlineExceeded => GrpcCodeBitmask::DEADLINE_EXCEEDED,
165            GrpcCode::NotFound => GrpcCodeBitmask::NOT_FOUND,
166            GrpcCode::AlreadyExists => GrpcCodeBitmask::ALREADY_EXISTS,
167            GrpcCode::PermissionDenied => GrpcCodeBitmask::PERMISSION_DENIED,
168            GrpcCode::ResourceExhausted => GrpcCodeBitmask::RESOURCE_EXHAUSTED,
169            GrpcCode::FailedPrecondition => GrpcCodeBitmask::FAILED_PRECONDITION,
170            GrpcCode::Aborted => GrpcCodeBitmask::ABORTED,
171            GrpcCode::OutOfRange => GrpcCodeBitmask::OUT_OF_RANGE,
172            GrpcCode::Unimplemented => GrpcCodeBitmask::UNIMPLEMENTED,
173            GrpcCode::Internal => GrpcCodeBitmask::INTERNAL,
174            GrpcCode::Unavailable => GrpcCodeBitmask::UNAVAILABLE,
175            GrpcCode::DataLoss => GrpcCodeBitmask::DATA_LOSS,
176            GrpcCode::Unauthenticated => GrpcCodeBitmask::UNAUTHENTICATED,
177        }
178    }
179}
180
181/// Response classifier for gRPC responses.
182///
183/// gRPC doesn't use normal HTTP statuses for indicating success or failure but instead a special
184/// header that might appear in a trailer.
185///
186/// Responses are considered successful if
187///
188/// - `grpc-status` header value contains a success value.
189/// - `grpc-status` header is missing.
190/// - `grpc-status` header value isn't a valid `String`.
191/// - `grpc-status` header value can't parsed into an `i32`.
192///
193/// All others are considered failures.
194#[derive(Debug, Clone)]
195pub struct GrpcErrorsAsFailures {
196    success_codes: GrpcCodeBitmask,
197}
198
199impl Default for GrpcErrorsAsFailures {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
205impl GrpcErrorsAsFailures {
206    /// Create a new [`GrpcErrorsAsFailures`].
207    pub fn new() -> Self {
208        Self {
209            success_codes: GrpcCodeBitmask::OK,
210        }
211    }
212
213    /// Change which gRPC codes are considered success.
214    ///
215    /// Defaults to only considering `Ok` as success.
216    ///
217    /// `Ok` will always be considered a success.
218    ///
219    /// # Example
220    ///
221    /// Servers might not want to consider `Invalid Argument` or `Not Found` as failures since
222    /// thats likely the clients fault:
223    ///
224    /// ```rust
225    /// use tower_http::classify::{GrpcErrorsAsFailures, GrpcCode};
226    ///
227    /// let classifier = GrpcErrorsAsFailures::new()
228    ///     .with_success(GrpcCode::InvalidArgument)
229    ///     .with_success(GrpcCode::NotFound);
230    /// ```
231    pub fn with_success(mut self, code: GrpcCode) -> Self {
232        self.success_codes |= code.into_bitmask();
233        self
234    }
235
236    /// Returns a [`MakeClassifier`](super::MakeClassifier) that produces `GrpcErrorsAsFailures`.
237    ///
238    /// This is a convenience function that simply calls `SharedClassifier::new`.
239    pub fn make_classifier() -> SharedClassifier<Self> {
240        SharedClassifier::new(Self::new())
241    }
242}
243
244impl ClassifyResponse for GrpcErrorsAsFailures {
245    type FailureClass = GrpcFailureClass;
246    type ClassifyEos = GrpcEosErrorsAsFailures;
247
248    fn classify_response<B>(
249        self,
250        res: &Response<B>,
251    ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
252        match classify_grpc_metadata(res.headers(), self.success_codes) {
253            ParsedGrpcStatus::Success | ParsedGrpcStatus::HeaderNotGrpcCode => {
254                ClassifiedResponse::Ready(Ok(()))
255            }
256            ParsedGrpcStatus::NonSuccess(status) => {
257                ClassifiedResponse::Ready(Err(GrpcFailureClass::Status(status)))
258            }
259            ParsedGrpcStatus::GrpcStatusHeaderMissing => {
260                ClassifiedResponse::RequiresEos(GrpcEosErrorsAsFailures {
261                    success_codes: self.success_codes,
262                })
263            }
264        }
265    }
266
267    fn classify_error<E>(self, error: &E) -> Self::FailureClass
268    where
269        E: fmt::Display + 'static,
270    {
271        GrpcFailureClass::Error(error.to_string())
272    }
273}
274
275/// The [`ClassifyEos`] for [`GrpcErrorsAsFailures`].
276#[derive(Debug, Clone)]
277pub struct GrpcEosErrorsAsFailures {
278    success_codes: GrpcCodeBitmask,
279}
280
281impl ClassifyEos for GrpcEosErrorsAsFailures {
282    type FailureClass = GrpcFailureClass;
283
284    fn classify_eos(self, trailers: Option<&HeaderMap>) -> Result<(), Self::FailureClass> {
285        if let Some(trailers) = trailers {
286            match classify_grpc_metadata(trailers, self.success_codes) {
287                ParsedGrpcStatus::Success
288                | ParsedGrpcStatus::GrpcStatusHeaderMissing
289                | ParsedGrpcStatus::HeaderNotGrpcCode => Ok(()),
290                ParsedGrpcStatus::NonSuccess(status) => Err(GrpcFailureClass::Status(status)),
291            }
292        } else {
293            Ok(())
294        }
295    }
296
297    fn classify_error<E>(self, error: &E) -> Self::FailureClass
298    where
299        E: fmt::Display + 'static,
300    {
301        GrpcFailureClass::Error(error.to_string())
302    }
303}
304
305/// The failure class for [`GrpcErrorsAsFailures`].
306#[derive(Debug)]
307#[non_exhaustive]
308pub enum GrpcFailureClass {
309    /// A gRPC response was classified as a failure with the corresponding status.
310    Status(GrpcStatus),
311    /// A gRPC response was classified as an error with the corresponding error description.
312    Error(String),
313}
314
315impl fmt::Display for GrpcFailureClass {
316    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
317        match self {
318            Self::Status(status) => {
319                write!(f, "Status: {}", status)
320            }
321            Self::Error(error) => write!(f, "Error: {}", error),
322        }
323    }
324}
325
326impl std::error::Error for GrpcFailureClass {}
327
328pub(crate) fn classify_grpc_metadata(
329    headers: &HeaderMap,
330    success_codes: GrpcCodeBitmask,
331) -> ParsedGrpcStatus {
332    macro_rules! or_else {
333        ($expr:expr, $other:ident) => {
334            if let Some(value) = $expr {
335                value
336            } else {
337                return ParsedGrpcStatus::$other;
338            }
339        };
340    }
341
342    let code_header = or_else!(headers.get("grpc-status"), GrpcStatusHeaderMissing);
343    let code_value: i32 = or_else!(
344        code_header.to_str().ok().and_then(|s| s.parse().ok()),
345        HeaderNotGrpcCode
346    );
347    let grpc_code = GrpcCode::from_i32(code_value);
348
349    if let Some(code) = grpc_code {
350        if success_codes.contains(GrpcCodeBitmask::from(code)) {
351            return ParsedGrpcStatus::Success;
352        }
353    }
354
355    let message = headers.get("grpc-message").map(|header| {
356        percent_decode(header.as_bytes())
357            .decode_utf8_lossy()
358            .into_owned()
359    });
360
361    ParsedGrpcStatus::NonSuccess(GrpcStatus {
362        code: grpc_code,
363        code_raw: code_value,
364        message,
365    })
366}
367
368/// A gRPC status extracted from response headers/trailers.
369#[derive(Debug, PartialEq, Eq)]
370pub struct GrpcStatus {
371    code: Option<GrpcCode>,
372    code_raw: i32,
373    message: Option<String>,
374}
375
376impl GrpcStatus {
377    /// Returns the status code as a [`GrpcCode`], or `None` if the code is not recognized.
378    pub fn code(&self) -> Option<GrpcCode> {
379        self.code
380    }
381
382    /// Returns the raw integer status code.
383    pub fn code_raw(&self) -> i32 {
384        self.code_raw
385    }
386
387    /// Returns the percent-decoded gRPC error message, if present.
388    pub fn message(&self) -> Option<&str> {
389        self.message.as_deref()
390    }
391}
392
393impl fmt::Display for GrpcStatus {
394    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
395        match self.code {
396            Some(code) => write!(f, "{:?}", code)?,
397            None => write!(f, "Code({})", self.code_raw)?,
398        }
399        if let Some(message) = self.message.as_ref() {
400            write!(f, ": {}", message)?;
401        }
402        Ok(())
403    }
404}
405
406#[derive(Debug, PartialEq, Eq)]
407pub(crate) enum ParsedGrpcStatus {
408    Success,
409    NonSuccess(GrpcStatus),
410    GrpcStatusHeaderMissing,
411    // this is treated as `Success` but kept separate for clarity
412    HeaderNotGrpcCode,
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    // Generates a test that builds a header map with the given `grpc-status` (and, when the
    // message is non-empty, a `grpc-message`), runs `classify_grpc_metadata`, and compares the
    // result. An empty `message` means "don't insert a `grpc-message` header at all".
    macro_rules! classify_grpc_metadata_test {
        (
            name: $name:ident,
            status: $status:expr,
            success_flags: $success_flags:expr,
            expected: $expected:expr,
        ) => {
            classify_grpc_metadata_test!(
                name: $name,
                status: $status,
                message: "",
                success_flags: $success_flags,
                expected: $expected,
            );
        };
        (
            name: $name:ident,
            status: $status:expr,
            message: $message:expr,
            success_flags: $success_flags:expr,
            expected: $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let mut headers = HeaderMap::new();
                headers.insert("grpc-status", $status.parse().unwrap());
                if !$message.is_empty() {
                    headers.insert("grpc-message", $message.parse().unwrap());
                }
                let status = classify_grpc_metadata(&headers, $success_flags);
                assert_eq!(status, $expected);
            }
        };
    }

    classify_grpc_metadata_test! {
        name: basic_ok,
        status: "0",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::Success,
    }

    classify_grpc_metadata_test! {
        name: basic_error,
        status: "1",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(GrpcStatus{
            code: Some(GrpcCode::Cancelled),
            code_raw: 1,
            message: None,
        }),
    }

    classify_grpc_metadata_test! {
        name: two_success_codes_first_matches,
        status: "0",
        success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT,
        expected: ParsedGrpcStatus::Success,
    }

    classify_grpc_metadata_test! {
        name: two_success_codes_second_matches,
        status: "3",
        success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT,
        expected: ParsedGrpcStatus::Success,
    }

    classify_grpc_metadata_test! {
        name: two_success_codes_none_matches,
        status: "16",
        message: "mock message",
        success_flags: GrpcCodeBitmask::OK | GrpcCodeBitmask::INVALID_ARGUMENT,
        expected: ParsedGrpcStatus::NonSuccess(GrpcStatus{
            code: Some(GrpcCode::Unauthenticated),
            code_raw: 16,
            message: Some("mock message".to_string()),
        }),
    }

    classify_grpc_metadata_test! {
        name: percent_encoded_message,
        status: "2",
        message: "hello%20world",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(GrpcStatus{
            code: Some(GrpcCode::Unknown),
            code_raw: 2,
            message: Some("hello world".to_string()),
        }),
    }

    classify_grpc_metadata_test! {
        name: invalid_percent_encoding,
        status: "13",
        message: "bad%2Gencode",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(GrpcStatus{
            code: Some(GrpcCode::Internal),
            code_raw: 13,
            message: Some("bad%2Gencode".to_string()),
        }),
    }

    // The macro skips inserting `grpc-message` when the message is empty, so this covers an
    // absent header; see `empty_grpc_message_header` for a present-but-empty one.
    classify_grpc_metadata_test! {
        name: missing_grpc_message,
        status: "5",
        message: "",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(GrpcStatus{
            code: Some(GrpcCode::NotFound),
            code_raw: 5,
            message: None,
        }),
    }

    classify_grpc_metadata_test! {
        name: unknown_status_code_above_16,
        status: "99",
        message: "custom error",
        success_flags: GrpcCodeBitmask::OK,
        expected: ParsedGrpcStatus::NonSuccess(GrpcStatus{
            code: None,
            code_raw: 99,
            message: Some("custom error".to_string()),
        }),
    }

    #[test]
    fn empty_grpc_message_header() {
        let mut headers = HeaderMap::new();
        headers.insert("grpc-status", "5".parse().unwrap());
        // Insert the header with an empty value so the "present but empty" path is exercised.
        headers.insert("grpc-message", "".parse().unwrap());
        let status = classify_grpc_metadata(&headers, GrpcCodeBitmask::OK);
        assert_eq!(
            status,
            ParsedGrpcStatus::NonSuccess(GrpcStatus {
                code: Some(GrpcCode::NotFound),
                code_raw: 5,
                message: Some(String::new()),
            })
        );
    }

    #[test]
    fn missing_grpc_status_header() {
        let headers = HeaderMap::new();
        assert_eq!(
            classify_grpc_metadata(&headers, GrpcCodeBitmask::OK),
            ParsedGrpcStatus::GrpcStatusHeaderMissing
        );
    }

    #[test]
    fn non_numeric_grpc_status() {
        let mut headers = HeaderMap::new();
        headers.insert("grpc-status", "not-a-number".parse().unwrap());
        assert_eq!(
            classify_grpc_metadata(&headers, GrpcCodeBitmask::OK),
            ParsedGrpcStatus::HeaderNotGrpcCode
        );
    }

    #[test]
    fn invalid_utf8_after_percent_decode() {
        let mut headers = HeaderMap::new();
        headers.insert("grpc-status", "2".parse().unwrap());
        // %80 is an invalid UTF-8 start byte; lossy decode replaces it with U+FFFD
        headers.insert("grpc-message", "bad%80byte".parse().unwrap());
        let status = classify_grpc_metadata(&headers, GrpcCodeBitmask::OK);
        assert_eq!(
            status,
            ParsedGrpcStatus::NonSuccess(GrpcStatus {
                code: Some(GrpcCode::Unknown),
                code_raw: 2,
                message: Some("bad\u{FFFD}byte".to_string()),
            })
        );
    }

    #[test]
    fn valid_utf8_percent_encoded() {
        let mut headers = HeaderMap::new();
        headers.insert("grpc-status", "3".parse().unwrap());
        // %C3%A9 is the percent-encoded form of 'é' (U+00E9) in UTF-8
        headers.insert("grpc-message", "caf%C3%A9".parse().unwrap());
        let status = classify_grpc_metadata(&headers, GrpcCodeBitmask::OK);
        assert_eq!(
            status,
            ParsedGrpcStatus::NonSuccess(GrpcStatus {
                code: Some(GrpcCode::InvalidArgument),
                code_raw: 3,
                message: Some("café".to_string()),
            })
        );
    }

    #[test]
    fn grpc_ok_classified_as_success() {
        use http::Response;

        let res = Response::builder()
            .header("grpc-status", "0")
            .body(())
            .unwrap();

        let classifier = GrpcErrorsAsFailures::new();
        let result = classifier.classify_response(&res);
        assert!(matches!(result, ClassifiedResponse::Ready(Ok(()))));
    }

    #[test]
    fn with_success_treats_configured_code_as_success() {
        let res = Response::builder()
            .header("grpc-status", "5")
            .body(())
            .unwrap();

        let classifier = GrpcErrorsAsFailures::new().with_success(GrpcCode::NotFound);
        assert!(matches!(
            classifier.classify_response(&res),
            ClassifiedResponse::Ready(Ok(()))
        ));
    }

    #[test]
    fn default_classifier_reports_non_ok_status_as_failure() {
        let res = Response::builder()
            .header("grpc-status", "5")
            .header("grpc-message", "missing%20thing")
            .body(())
            .unwrap();

        match GrpcErrorsAsFailures::new().classify_response(&res) {
            ClassifiedResponse::Ready(Err(GrpcFailureClass::Status(status))) => {
                assert_eq!(status.to_string(), "NotFound: missing thing");
            }
            _ => panic!("expected an immediate failure classification"),
        }
    }

    #[test]
    fn missing_status_in_headers_defers_to_trailers() {
        let res = Response::builder().body(()).unwrap();
        let eos = match GrpcErrorsAsFailures::new().classify_response(&res) {
            ClassifiedResponse::RequiresEos(eos) => eos,
            _ => panic!("expected classification to require end-of-stream"),
        };

        let mut trailers = HeaderMap::new();
        trailers.insert("grpc-status", "14".parse().unwrap());
        match eos.clone().classify_eos(Some(&trailers)) {
            Err(GrpcFailureClass::Status(status)) => {
                assert_eq!(status.code(), Some(GrpcCode::Unavailable));
                assert_eq!(status.code_raw(), 14);
                assert_eq!(status.message(), None);
            }
            other => panic!("unexpected classification: {:?}", other),
        }

        // A stream that ends without any trailers is treated as a success.
        assert!(eos.classify_eos(None).is_ok());
    }

    #[test]
    fn grpc_status_and_failure_display() {
        let known = GrpcStatus {
            code: Some(GrpcCode::Internal),
            code_raw: 13,
            message: None,
        };
        assert_eq!(known.to_string(), "Internal");

        let unknown = GrpcStatus {
            code: None,
            code_raw: 99,
            message: Some("boom".to_string()),
        };
        assert_eq!(unknown.to_string(), "Code(99): boom");
        assert_eq!(
            GrpcFailureClass::Status(unknown).to_string(),
            "Status: Code(99): boom"
        );
        assert_eq!(
            GrpcFailureClass::Error("io".to_string()).to_string(),
            "Error: io"
        );
    }

    #[test]
    fn grpc_code_from_i32_known_codes() {
        assert!(matches!(GrpcCode::from(0), GrpcCode::Ok));
        assert!(matches!(GrpcCode::from(1), GrpcCode::Cancelled));
        assert!(matches!(GrpcCode::from(4), GrpcCode::DeadlineExceeded));
        assert!(matches!(GrpcCode::from(13), GrpcCode::Internal));
        assert!(matches!(GrpcCode::from(16), GrpcCode::Unauthenticated));
    }

    #[test]
    fn grpc_code_from_i32_unknown_codes() {
        assert!(matches!(GrpcCode::from(17), GrpcCode::Unknown));
        assert!(matches!(GrpcCode::from(-1), GrpcCode::Unknown));
        assert!(matches!(GrpcCode::from(9999), GrpcCode::Unknown));
    }

    #[test]
    fn grpc_code_from_non_zero_i32() {
        let code = NonZeroI32::new(7).unwrap();
        assert!(matches!(GrpcCode::from(code), GrpcCode::PermissionDenied));

        let code = NonZeroI32::new(99).unwrap();
        assert!(matches!(GrpcCode::from(code), GrpcCode::Unknown));
    }
}
618}