1use std::collections::BTreeMap;
11
12use http::StatusCode;
13use http::header::HeaderValue;
14use serde::Deserialize;
15use serde::Serialize;
16
17use crate::body::TakoBody;
18use crate::responder::Responder;
19use crate::types::Response;
20
/// The `application/problem+json` media type (RFC 9457 problem details).
///
/// Written as the `Content-Type` of every response produced from a [`Problem`].
pub const PROBLEM_JSON: &str = "application/problem+json";
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct Problem {
30 #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
32 pub r#type: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub title: Option<String>,
36 pub status: u16,
38 #[serde(skip_serializing_if = "Option::is_none")]
40 pub detail: Option<String>,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub instance: Option<String>,
44 #[serde(flatten)]
46 pub extensions: BTreeMap<String, serde_json::Value>,
47}
48
49impl Problem {
50 pub fn from_status(status: StatusCode) -> Self {
53 Self {
54 r#type: None,
55 title: status.canonical_reason().map(str::to_string),
56 status: status.as_u16(),
57 detail: None,
58 instance: None,
59 extensions: BTreeMap::new(),
60 }
61 }
62
63 pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
65 self.detail = Some(detail.into());
66 self
67 }
68
69 pub fn with_type(mut self, type_uri: impl Into<String>) -> Self {
71 self.r#type = Some(type_uri.into());
72 self
73 }
74
75 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
77 self.instance = Some(instance.into());
78 self
79 }
80
81 pub fn with_extension(
83 mut self,
84 key: impl Into<String>,
85 value: impl Into<serde_json::Value>,
86 ) -> Self {
87 self.extensions.insert(key.into(), value.into());
88 self
89 }
90}
91
92impl Responder for Problem {
93 fn into_response(self) -> Response {
94 let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
95 let body = serde_json::to_vec(&self).unwrap_or_else(|_| b"{}".to_vec());
96 let mut res = Response::new(TakoBody::from(body));
97 *res.status_mut() = status;
98 res.headers_mut().insert(
99 http::header::CONTENT_TYPE,
100 HeaderValue::from_static(PROBLEM_JSON),
101 );
102 res
103 }
104}
105
106pub fn default_problem_responder(response: Response) -> Response {
120 let status = response.status();
121
122 if !status.is_client_error() && !status.is_server_error() {
123 return response;
124 }
125
126 if let Some(ct) = response.headers().get(http::header::CONTENT_TYPE)
127 && let Ok(s) = ct.to_str()
128 {
129 let essence = s
130 .split(';')
131 .next()
132 .unwrap_or("")
133 .trim()
134 .to_ascii_lowercase();
135 if essence == "application/json" || essence == "application/problem+json" {
136 return response;
137 }
138 }
139
140 let problem = Problem::from_status(status);
141 problem.into_response()
142}
143
#[cfg(test)]
mod tests {
    use http::Response as HttpResponse;
    use http_body_util::BodyExt;

    use super::*;

    /// Drains a response body to a `String` on a throwaway single-threaded runtime.
    fn body_string(resp: Response) -> String {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(async {
                let bytes = resp.into_body().collect().await.unwrap().to_bytes();
                String::from_utf8(bytes.to_vec()).unwrap()
            })
    }

    /// Builds a response with `status`, an optional `Content-Type`, and `body`.
    fn response_with(
        status: StatusCode,
        content_type: Option<&'static str>,
        body: &'static str,
    ) -> Response {
        let mut resp = HttpResponse::new(TakoBody::from(body));
        *resp.status_mut() = status;
        if let Some(ct) = content_type {
            resp.headers_mut()
                .insert(http::header::CONTENT_TYPE, HeaderValue::from_static(ct));
        }
        resp
    }

    #[test]
    fn problem_from_status_uses_canonical_reason() {
        let p = Problem::from_status(StatusCode::NOT_FOUND);
        assert_eq!(p.status, 404);
        assert_eq!(p.title.as_deref(), Some("Not Found"));
        assert!(p.detail.is_none());
        assert!(p.r#type.is_none());
    }

    #[test]
    fn problem_with_detail_setter() {
        let p = Problem::from_status(StatusCode::BAD_REQUEST).with_detail("missing field 'name'");
        assert_eq!(p.detail.as_deref(), Some("missing field 'name'"));
    }

    #[test]
    fn problem_with_type_and_instance() {
        let p = Problem::from_status(StatusCode::CONFLICT)
            .with_type("https://example.com/probs/conflict")
            .with_instance("/orders/42");
        assert_eq!(
            p.r#type.as_deref(),
            Some("https://example.com/probs/conflict")
        );
        assert_eq!(p.instance.as_deref(), Some("/orders/42"));
    }

    #[test]
    fn problem_with_extension_round_trips_through_serde() {
        let p = Problem::from_status(StatusCode::UNPROCESSABLE_ENTITY)
            .with_extension("invalid_params", serde_json::json!(["email", "age"]));
        let body = serde_json::to_string(&p).unwrap();
        assert!(body.contains(r#""invalid_params":["email","age"]"#));
        assert!(body.contains(r#""status":422"#));
    }

    #[test]
    fn problem_deserializes_standard_and_extension_members() {
        let p: Problem = serde_json::from_str(
            r#"{"type":"https://example.com/probs/x","title":"Bad","status":400,"invalid_params":["email"]}"#,
        )
        .unwrap();
        assert_eq!(p.r#type.as_deref(), Some("https://example.com/probs/x"));
        assert_eq!(p.title.as_deref(), Some("Bad"));
        assert_eq!(p.status, 400);
        assert!(p.detail.is_none());
        assert!(p.instance.is_none());
        assert_eq!(
            p.extensions.get("invalid_params"),
            Some(&serde_json::json!(["email"]))
        );
    }

    #[test]
    fn problem_into_response_writes_problem_json_content_type() {
        let p = Problem::from_status(StatusCode::INTERNAL_SERVER_ERROR);
        let resp = p.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            &HeaderValue::from_static(PROBLEM_JSON),
        );
    }

    #[test]
    fn problem_into_response_serializes_canonical_fields() {
        let p = Problem::from_status(StatusCode::NOT_FOUND).with_detail("user 7 missing");
        let body = body_string(p.into_response());
        assert!(body.contains(r#""title":"Not Found""#));
        assert!(body.contains(r#""status":404"#));
        assert!(body.contains(r#""detail":"user 7 missing""#));
    }

    #[test]
    fn default_problem_responder_replaces_plain_response() {
        let resp = response_with(StatusCode::INTERNAL_SERVER_ERROR, Some("text/plain"), "oops");

        let upgraded = default_problem_responder(resp);
        assert_eq!(
            upgraded.headers().get(http::header::CONTENT_TYPE).unwrap(),
            &HeaderValue::from_static(PROBLEM_JSON),
        );
        let body = body_string(upgraded);
        assert!(body.contains(r#""status":500"#));
    }

    #[test]
    fn default_problem_responder_upgrades_response_without_content_type() {
        let resp = response_with(StatusCode::NOT_FOUND, None, "");

        let upgraded = default_problem_responder(resp);
        assert_eq!(upgraded.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            upgraded.headers().get(http::header::CONTENT_TYPE).unwrap(),
            &HeaderValue::from_static(PROBLEM_JSON),
        );
        let body = body_string(upgraded);
        assert!(body.contains(r#""title":"Not Found""#));
        assert!(body.contains(r#""status":404"#));
    }

    #[test]
    fn default_problem_responder_passes_through_success() {
        let resp = response_with(StatusCode::OK, Some("text/plain"), "ok");

        let unchanged = default_problem_responder(resp);
        assert_eq!(unchanged.status(), StatusCode::OK);
        assert_eq!(
            unchanged.headers().get(http::header::CONTENT_TYPE).unwrap(),
            &HeaderValue::from_static("text/plain"),
        );
        assert_eq!(body_string(unchanged), "ok");
    }

    #[test]
    fn default_problem_responder_passes_through_existing_json() {
        let resp = response_with(
            StatusCode::BAD_REQUEST,
            Some("application/json"),
            r#"{"err":"x"}"#,
        );

        let unchanged = default_problem_responder(resp);
        let body = body_string(unchanged);
        assert_eq!(body, r#"{"err":"x"}"#);
    }

    #[test]
    fn default_problem_responder_matches_json_with_params_case_insensitively() {
        let resp = response_with(
            StatusCode::UNPROCESSABLE_ENTITY,
            Some("Application/JSON; charset=utf-8"),
            r#"{"err":"y"}"#,
        );

        let unchanged = default_problem_responder(resp);
        assert_eq!(body_string(unchanged), r#"{"err":"y"}"#);
    }

    #[test]
    fn default_problem_responder_passes_through_problem_json() {
        let resp = response_with(StatusCode::IM_A_TEAPOT, Some(PROBLEM_JSON), r#"{"status":418}"#);

        let unchanged = default_problem_responder(resp);
        let body = body_string(unchanged);
        assert_eq!(body, r#"{"status":418}"#);
    }
}
262}