1use std::collections::HashMap;
4
5use axum::body::Body;
6use axum::response::IntoResponse;
7use http::header::{self, HeaderMap, HeaderValue};
8use hyper::{Response, StatusCode};
9use serde::{Deserialize, Serialize, Serializer};
10
/// Boxed error trait object that is `Send + Sync`, for code paths that do not
/// need a concrete error type.
pub type GenericError = Box<dyn std::error::Error + Send + Sync + 'static>;
13
// Generates the `TwirpErrorCode` enum, its lookup methods, and one constructor
// function per code from a table of rows shaped like
// `(Variant, http_status_expr, code_name);`.
//
// Attributes written before a row (typically doc comments) are forwarded to the
// generated enum variant.
macro_rules! twirp_error_codes {
    (
        $(
            $(#[$attr:meta])*
            ($variant:ident, $status:expr, $name:ident);
        )+
    ) => {
        /// Error codes that a [`TwirpErrorResponse`] can carry.
        ///
        /// Variants deserialize from the snake_case form of their name.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        #[non_exhaustive]
        pub enum TwirpErrorCode {
            $(
                $(#[$attr])*
                $variant,
            )+
        }

        impl TwirpErrorCode {
            /// Returns the HTTP status paired with this error code.
            pub fn http_status_code(&self) -> StatusCode {
                match self {
                    $(
                        Self::$variant => $status,
                    )+
                }
            }

            /// Returns this error code's name as a static string.
            pub fn twirp_code(&self) -> &'static str {
                match self {
                    $(
                        Self::$variant => stringify!($name),
                    )+
                }
            }
        }

        $(
            /// Creates a [`TwirpErrorResponse`] with the matching error code,
            /// `msg` converted through [`ToString`], and empty metadata.
            pub fn $name<T: ToString>(msg: T) -> TwirpErrorResponse {
                let msg = msg.to_string();
                TwirpErrorResponse {
                    code: TwirpErrorCode::$variant,
                    msg,
                    meta: Default::default(),
                }
            }
        )+
    }
}
61
// One row per error code: (enum variant, HTTP status, code name).
// The code name doubles as the name of the generated constructor function and
// as the string returned by `TwirpErrorCode::twirp_code`.
twirp_error_codes! {
    /// Twirp code `canceled`; reported as HTTP 408 Request Timeout.
    (Canceled, StatusCode::REQUEST_TIMEOUT, canceled);
    /// Twirp code `unknown`; reported as HTTP 500 Internal Server Error.
    (Unknown, StatusCode::INTERNAL_SERVER_ERROR, unknown);
    /// Twirp code `invalid_argument`; reported as HTTP 400 Bad Request.
    (InvalidArgument, StatusCode::BAD_REQUEST, invalid_argument);
    /// Twirp code `malformed`; reported as HTTP 400 Bad Request.
    (Malformed, StatusCode::BAD_REQUEST, malformed);
    /// Twirp code `deadline_exceeded`; reported as HTTP 408 Request Timeout.
    (DeadlineExceeded, StatusCode::REQUEST_TIMEOUT, deadline_exceeded);
    /// Twirp code `not_found`; reported as HTTP 404 Not Found.
    (NotFound, StatusCode::NOT_FOUND, not_found);
    /// Twirp code `bad_route`; reported as HTTP 404 Not Found.
    (BadRoute, StatusCode::NOT_FOUND, bad_route);
    /// Twirp code `already_exists`; reported as HTTP 409 Conflict.
    (AlreadyExists, StatusCode::CONFLICT, already_exists);
    /// Twirp code `permission_denied`; reported as HTTP 403 Forbidden.
    (PermissionDenied, StatusCode::FORBIDDEN, permission_denied);
    /// Twirp code `unauthenticated`; reported as HTTP 401 Unauthorized.
    (Unauthenticated, StatusCode::UNAUTHORIZED, unauthenticated);
    /// Twirp code `resource_exhausted`; reported as HTTP 429 Too Many Requests.
    (ResourceExhausted, StatusCode::TOO_MANY_REQUESTS, resource_exhausted);
    /// Twirp code `failed_precondition`; reported as HTTP 412 Precondition Failed.
    (FailedPrecondition, StatusCode::PRECONDITION_FAILED, failed_precondition);
    /// Twirp code `aborted`; reported as HTTP 409 Conflict.
    (Aborted, StatusCode::CONFLICT, aborted);
    /// Twirp code `out_of_range`; reported as HTTP 400 Bad Request.
    (OutOfRange, StatusCode::BAD_REQUEST, out_of_range);
    /// Twirp code `unimplemented`; reported as HTTP 501 Not Implemented.
    (Unimplemented, StatusCode::NOT_IMPLEMENTED, unimplemented);
    /// Twirp code `internal`; reported as HTTP 500 Internal Server Error.
    (Internal, StatusCode::INTERNAL_SERVER_ERROR, internal);
    /// Twirp code `unavailable`; reported as HTTP 503 Service Unavailable.
    (Unavailable, StatusCode::SERVICE_UNAVAILABLE, unavailable);
    /// Twirp code `dataloss`; reported as HTTP 500 Internal Server Error.
    (Dataloss, StatusCode::INTERNAL_SERVER_ERROR, dataloss);
}
131
132impl Serialize for TwirpErrorCode {
133 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134 where
135 S: Serializer,
136 {
137 serializer.serialize_str(self.twirp_code())
138 }
139}
140
141#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
143pub struct TwirpErrorResponse {
144 pub code: TwirpErrorCode,
145 pub msg: String,
146 #[serde(skip_serializing_if = "HashMap::is_empty")]
147 #[serde(default)]
148 pub meta: HashMap<String, String>,
149}
150
151impl TwirpErrorResponse {
152 pub fn insert_meta(&mut self, key: String, value: String) -> Option<String> {
153 self.meta.insert(key, value)
154 }
155}
156
157impl IntoResponse for TwirpErrorResponse {
158 fn into_response(self) -> Response<Body> {
159 let mut headers = HeaderMap::new();
160 headers.insert(
161 header::CONTENT_TYPE,
162 HeaderValue::from_static("application/json"),
163 );
164
165 let json =
166 serde_json::to_string(&self).expect("JSON serialization of an error should not fail");
167
168 (self.code.http_status_code(), headers, json).into_response()
169 }
170}
171
#[cfg(test)]
mod test {
    use axum::response::IntoResponse;

    use crate::{TwirpErrorCode, TwirpErrorResponse};

    /// Every error code must report its expected code name and HTTP status.
    #[test]
    fn twirp_status_mapping() {
        assert_code(TwirpErrorCode::Canceled, "canceled", 408);
        assert_code(TwirpErrorCode::Unknown, "unknown", 500);
        assert_code(TwirpErrorCode::InvalidArgument, "invalid_argument", 400);
        assert_code(TwirpErrorCode::Malformed, "malformed", 400);
        assert_code(TwirpErrorCode::DeadlineExceeded, "deadline_exceeded", 408);
        assert_code(TwirpErrorCode::NotFound, "not_found", 404);
        assert_code(TwirpErrorCode::BadRoute, "bad_route", 404);
        assert_code(TwirpErrorCode::AlreadyExists, "already_exists", 409);
        assert_code(TwirpErrorCode::PermissionDenied, "permission_denied", 403);
        assert_code(TwirpErrorCode::Unauthenticated, "unauthenticated", 401);
        assert_code(TwirpErrorCode::ResourceExhausted, "resource_exhausted", 429);
        assert_code(TwirpErrorCode::FailedPrecondition, "failed_precondition", 412);
        assert_code(TwirpErrorCode::Aborted, "aborted", 409);
        assert_code(TwirpErrorCode::OutOfRange, "out_of_range", 400);
        assert_code(TwirpErrorCode::Unimplemented, "unimplemented", 501);
        assert_code(TwirpErrorCode::Internal, "internal", 500);
        assert_code(TwirpErrorCode::Unavailable, "unavailable", 503);
        assert_code(TwirpErrorCode::Dataloss, "dataloss", 500);
    }

    /// Checks both lookups of `code` against the expected name and HTTP status.
    fn assert_code(code: TwirpErrorCode, msg: &str, http: u16) {
        assert_eq!(
            code.http_status_code(),
            http,
            "unexpected http status code for {:?}",
            code
        );
        assert_eq!(
            code.twirp_code(),
            msg,
            "unexpected twirp code string for {:?}",
            code
        );
    }

    /// An error without metadata serializes without a `meta` key and round-trips.
    #[test]
    fn twirp_error_response_serialization() {
        let response = TwirpErrorResponse {
            code: TwirpErrorCode::DeadlineExceeded,
            msg: "test".to_string(),
            meta: Default::default(),
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains(r#""code":"deadline_exceeded""#));
        assert!(json.contains(r#""msg":"test""#));
        assert!(!json.contains(r#""meta""#));

        let parsed: TwirpErrorResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response, parsed);
    }

    /// Metadata inserted through `insert_meta` is serialized and round-trips.
    #[test]
    fn twirp_error_response_meta_round_trip() {
        let mut response = TwirpErrorResponse {
            code: TwirpErrorCode::NotFound,
            msg: "missing".to_string(),
            meta: Default::default(),
        };
        assert_eq!(response.insert_meta("id".to_string(), "1".to_string()), None);
        assert_eq!(
            response.insert_meta("id".to_string(), "42".to_string()),
            Some("1".to_string())
        );

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains(r#""meta":{"id":"42"}"#));

        let parsed: TwirpErrorResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response, parsed);
    }

    /// A payload that omits `meta` deserializes with an empty metadata map.
    #[test]
    fn twirp_error_response_deserializes_without_meta() {
        let parsed: TwirpErrorResponse =
            serde_json::from_str(r#"{"code":"internal","msg":"boom"}"#).unwrap();

        assert_eq!(parsed.code, TwirpErrorCode::Internal);
        assert_eq!(parsed.msg, "boom");
        assert!(parsed.meta.is_empty());
    }

    /// The HTTP response carries the code's status and a JSON content type.
    #[test]
    fn twirp_error_response_into_response() {
        let response = TwirpErrorResponse {
            code: TwirpErrorCode::Unavailable,
            msg: "down".to_string(),
            meta: Default::default(),
        }
        .into_response();

        assert_eq!(response.status(), 503);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
    }
}