tower_request_guard/
response.rs1use crate::violation::Violation;
2use http::Response;
3
/// Escapes `s` so it can be placed between the quotes of a JSON string literal.
///
/// Double quotes, backslashes, newlines, carriage returns and tabs are replaced
/// by their two-character escapes. Any other character for which
/// [`char::is_control`] is true is written as a lowercase `\uXXXX` escape.
/// Everything else is copied through unchanged.
pub(crate) fn escape_json_string(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut out, c| {
        // Characters that have a dedicated short escape sequence.
        let short_escape = match c {
            '"' => Some(r#"\""#),
            '\\' => Some(r#"\\"#),
            '\n' => Some(r#"\n"#),
            '\r' => Some(r#"\r"#),
            '\t' => Some(r#"\t"#),
            _ => None,
        };

        if let Some(sequence) = short_escape {
            out.push_str(sequence);
        } else if c.is_control() {
            // Remaining control characters have no short form; use the
            // numeric escape with the code point padded to four hex digits.
            out.push_str(&format!("\\u{:04x}", c as u32));
        } else {
            out.push(c);
        }
        out
    })
}
23
24pub fn violation_response(violation: &Violation) -> Response<String> {
26 let status = violation.status_code();
27 let body = violation_json_body(violation);
28
29 Response::builder()
30 .status(status)
31 .header("Content-Type", "application/json")
32 .body(body)
33 .unwrap()
34}
35
36fn violation_json_body(violation: &Violation) -> String {
37 match violation {
38 Violation::BodyTooLarge { max, received } => {
39 format!(
40 r#"{{"error":"payload too large","violation":"body_too_large","max":{},"received":{}}}"#,
41 max, received
42 )
43 }
44 Violation::RequestTimeout { timeout_ms } => {
45 format!(
46 r#"{{"error":"request timeout","violation":"request_timeout","timeout_ms":{}}}"#,
47 timeout_ms
48 )
49 }
50 Violation::InvalidContentType { received, allowed } => {
51 let received_escaped = escape_json_string(received);
52 let allowed_json: Vec<String> = allowed
53 .iter()
54 .map(|a| format!(r#""{}""#, escape_json_string(a)))
55 .collect();
56 format!(
57 r#"{{"error":"unsupported content type","violation":"invalid_content_type","received":"{}","allowed":[{}]}}"#,
58 received_escaped,
59 allowed_json.join(",")
60 )
61 }
62 Violation::MissingHeader { header } => {
63 format!(
64 r#"{{"error":"missing required header","violation":"missing_header","header":"{}"}}"#,
65 escape_json_string(header)
66 )
67 }
68 Violation::JsonTooDeep {
69 max_depth,
70 found_depth,
71 } => {
72 format!(
73 r#"{{"error":"json depth exceeded","violation":"json_too_deep","max_depth":{},"found_depth":{}}}"#,
74 max_depth, found_depth
75 )
76 }
77 Violation::InvalidJson { detail } => {
78 format!(
79 r#"{{"error":"invalid json","violation":"invalid_json","detail":"{}"}}"#,
80 escape_json_string(detail)
81 )
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::violation::Violation;
    use http::StatusCode;

    /// Renders `violation` and returns the response status together with its body.
    fn status_and_body(violation: &Violation) -> (StatusCode, String) {
        let response = violation_response(violation);
        (response.status(), response.into_body())
    }

    #[test]
    fn escape_json_string_handles_special_chars() {
        // (input, expected escaped output)
        let cases = [
            (r#"hello "world""#, r#"hello \"world\""#),
            ("back\\slash", r#"back\\slash"#),
            ("new\nline", r#"new\nline"#),
            ("tab\there", r#"tab\there"#),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(escape_json_string(input), *expected);
        }
    }

    #[test]
    fn escape_json_string_passes_through_clean_input() {
        for clean in ["application/json", "Authorization"].iter() {
            assert_eq!(escape_json_string(clean), *clean);
        }
    }

    #[test]
    fn body_too_large_response() {
        let violation = Violation::BodyTooLarge {
            max: 1024,
            received: 2048,
        };
        let (status, body) = status_and_body(&violation);
        assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
        assert!(body.contains(r#""violation":"body_too_large""#));
        assert!(body.contains(r#""max":1024"#));
        assert!(body.contains(r#""received":2048"#));
    }

    #[test]
    fn missing_header_response() {
        let violation = Violation::MissingHeader {
            header: "Authorization".into(),
        };
        let (status, body) = status_and_body(&violation);
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert!(body.contains(r#""violation":"missing_header""#));
        assert!(body.contains(r#""header":"Authorization""#));
    }

    #[test]
    fn invalid_content_type_response() {
        let violation = Violation::InvalidContentType {
            received: "text/xml".into(),
            allowed: vec!["application/json".into(), "multipart/form-data".into()],
        };
        let (status, body) = status_and_body(&violation);
        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
        assert!(body.contains(r#""violation":"invalid_content_type""#));
        assert!(body.contains(r#""received":"text/xml""#));
        assert!(body.contains(r#""allowed":["application/json","multipart/form-data"]"#));
    }

    #[test]
    fn timeout_response() {
        let violation = Violation::RequestTimeout { timeout_ms: 30000 };
        let (status, body) = status_and_body(&violation);
        assert_eq!(status, StatusCode::GATEWAY_TIMEOUT);
        assert!(body.contains(r#""violation":"request_timeout""#));
        assert!(body.contains(r#""timeout_ms":30000"#));
    }

    #[test]
    fn json_too_deep_response() {
        let violation = Violation::JsonTooDeep {
            max_depth: 32,
            found_depth: 128,
        };
        let (status, body) = status_and_body(&violation);
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert!(body.contains(r#""violation":"json_too_deep""#));
        assert!(body.contains(r#""max_depth":32"#));
        assert!(body.contains(r#""found_depth":128"#));
    }

    #[test]
    fn invalid_json_response() {
        let violation = Violation::InvalidJson {
            detail: "unexpected EOF".into(),
        };
        let (status, body) = status_and_body(&violation);
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert!(body.contains(r#""violation":"invalid_json""#));
        assert!(body.contains(r#""detail":"unexpected EOF""#));
    }

    #[test]
    fn response_escapes_untrusted_input() {
        // A quote inside the header name must not terminate the JSON string.
        let violation = Violation::MissingHeader {
            header: r#"X-Bad"Header"#.into(),
        };
        let body = violation_response(&violation).into_body();
        assert!(body.contains(r#""header":"X-Bad\"Header""#));
    }
}