1use crate::encoding::Strategy;
2use wafrift_types::injection_context::{ContextualEncodeError, InjectionContext};
3
4pub fn encode_in_context(
5 payload: &[u8],
6 strategy: Strategy,
7 context: InjectionContext,
8) -> Result<String, ContextualEncodeError> {
9 let max_size = match context {
10 InjectionContext::JsonString => 4 * 1024 * 1024,
11 InjectionContext::JsonNumber => 1024,
12 InjectionContext::XmlAttribute => 1024 * 1024,
13 InjectionContext::XmlCdata => 8 * 1024 * 1024,
14 InjectionContext::HeaderValue => 8 * 1024,
15 InjectionContext::CookieValue => 4 * 1024,
16 InjectionContext::MultipartFileName => 256,
17 _ => 8 * 1024 * 1024,
18 };
19
20 if payload.len() > max_size {
21 return Err(ContextualEncodeError::PayloadTooLarge {
22 context,
23 size: payload.len(),
24 max: max_size,
25 });
26 }
27
28 let base = match crate::encoding::encode(payload, strategy) {
29 Ok(s) => s,
30 Err(e) => {
31 return Err(match e {
32 crate::error::EncodeError::InvalidUtf8 => {
33 ContextualEncodeError::InvalidUtf8 { offset: 0 }
34 }
35 crate::error::EncodeError::PayloadTooLarge { max, actual } => {
36 ContextualEncodeError::PayloadTooLarge {
37 context,
38 size: actual,
39 max,
40 }
41 }
42 crate::error::EncodeError::LayeredOutputTooLarge { max, actual } => {
43 ContextualEncodeError::PayloadTooLarge {
44 context,
45 size: actual,
46 max,
47 }
48 }
49 crate::error::EncodeError::InvalidContext {
50 strategy: s,
51 context: _,
52 } => ContextualEncodeError::ContextIncompatible {
53 strategy: s.into(),
54 context,
55 reason: "strategy invalid for context".into(),
56 },
57 crate::error::EncodeError::InvalidConfig(msg) => {
58 ContextualEncodeError::ContextIncompatible {
59 strategy: "config".into(),
60 context,
61 reason: msg,
62 }
63 }
64 });
65 }
66 };
67
68 escape_for_context(&base, context)
69}
70
71pub fn escape_for_context(
72 input: &str,
73 context: InjectionContext,
74) -> Result<String, ContextualEncodeError> {
75 let escaped = match context {
76 InjectionContext::JsonString => {
77 let mut s = String::with_capacity(input.len() + 10);
78 for c in input.chars() {
79 match c {
80 '\\' => s.push_str("\\\\"),
81 '"' => s.push_str("\\\""),
82 '\n' => s.push_str("\\n"),
83 '\r' => s.push_str("\\r"),
84 '\t' => s.push_str("\\t"),
85 '\x00'..='\x1f' => s.push_str(&format!("\\u{:04x}", c as u32)),
86 _ => s.push(c),
87 }
88 }
89 s
90 }
91 InjectionContext::JsonNumber => {
92 if input.chars().any(|c| {
93 !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
94 }) {
95 return Err(ContextualEncodeError::ContextIncompatible {
96 strategy: "escape".into(),
97 context,
98 reason: "not a valid JSON number".into(),
99 });
100 }
101 input.to_string()
102 }
103 InjectionContext::XmlAttribute => {
104 if input.contains('\x00') {
105 return Err(ContextualEncodeError::ContextIncompatible {
106 strategy: "escape".into(),
107 context,
108 reason: "null byte in xml attribute".into(),
109 });
110 }
111 input
112 .replace('&', "&")
113 .replace('"', """)
114 .replace('<', "<")
115 .replace('>', ">")
116 }
117 InjectionContext::XmlCdata => {
118 if input.contains("]]>") {
119 return Err(ContextualEncodeError::ContextIncompatible {
120 strategy: "escape".into(),
121 context,
122 reason: "CDATA cannot contain ]]>".into(),
123 });
124 }
125 input.to_string()
126 }
127 InjectionContext::XmlText => input
128 .replace('&', "&")
129 .replace('<', "<")
130 .replace('>', ">"),
131 InjectionContext::HtmlAttribute => input
132 .replace('&', "&")
133 .replace('"', """)
134 .replace('\'', "'")
135 .replace('<', "<"),
136 InjectionContext::HtmlText => input.replace('&', "&").replace('<', "<"),
137 InjectionContext::UrlQuery => urlencoding::encode(input).to_string(),
138 InjectionContext::UrlPath => urlencoding::encode(input).to_string().replace("%2F", "/"),
139 InjectionContext::UrlFragment => urlencoding::encode(input).to_string(),
140 InjectionContext::HeaderValue => {
141 if input.contains('\r') || input.contains('\n') {
142 return Err(ContextualEncodeError::ContextIncompatible {
143 strategy: "escape".into(),
144 context,
145 reason: "CR/LF in header value".into(),
146 });
147 }
148 if input.contains('\x00') {
149 return Err(ContextualEncodeError::ContextIncompatible {
150 strategy: "escape".into(),
151 context,
152 reason: "null byte in header value".into(),
153 });
154 }
155 input.to_string()
156 }
157 InjectionContext::CookieValue => input
158 .replace(';', "%3B")
159 .replace('=', "%3D")
160 .replace('\x00', "%00")
161 .replace('\r', "%0D")
162 .replace('\n', "%0A"),
163 InjectionContext::MultipartField => {
164 if input.contains('\r') || input.contains('\n') {
165 return Err(ContextualEncodeError::ContextIncompatible {
166 strategy: "escape".into(),
167 context,
168 reason: "CR/LF would break multipart structure".into(),
169 });
170 }
171 input.to_string()
172 }
173 InjectionContext::MultipartFileName => {
174 if input.contains('"') {
175 return Err(ContextualEncodeError::ContextIncompatible {
176 strategy: "escape".into(),
177 context,
178 reason: "quote in filename".into(),
179 });
180 }
181 if input.contains('\r') || input.contains('\n') {
182 return Err(ContextualEncodeError::ContextIncompatible {
183 strategy: "escape".into(),
184 context,
185 reason: "CR/LF in filename".into(),
186 });
187 }
188 input.to_string()
189 }
190 InjectionContext::PlainBody => input.to_string(),
191 _ => input.to_string(),
192 };
193 validate_in_context(&escaped, context)?;
194 Ok(escaped)
195}
196
197pub fn validate_in_context(
198 payload: &str,
199 context: InjectionContext,
200) -> Result<(), ContextualEncodeError> {
201 match context {
202 InjectionContext::JsonString => {
203 let mut chars = payload.chars().peekable();
204 while let Some(c) = chars.next() {
205 if c == '"' {
206 return Err(ContextualEncodeError::ContextIncompatible {
207 strategy: "validate".into(),
208 context,
209 reason: "unescaped double quote in JSON string".into(),
210 });
211 }
212 if c == '\\' {
213 let escaped = chars.next();
214 match escaped {
215 Some('\\') | Some('"') | Some('n') | Some('r') | Some('t') | Some('b')
216 | Some('f') | Some('/') => {}
217 Some('u') => {
218 for _ in 0..4 {
220 match chars.next() {
221 Some(c) if c.is_ascii_hexdigit() => {}
222 _ => {
223 return Err(ContextualEncodeError::ContextIncompatible {
224 strategy: "validate".into(),
225 context,
226 reason: "invalid Unicode escape in JSON string".into(),
227 });
228 }
229 }
230 }
231 }
232 Some(other) => {
233 return Err(ContextualEncodeError::ContextIncompatible {
234 strategy: "validate".into(),
235 context,
236 reason: format!("invalid JSON escape sequence: \\{other}"),
237 });
238 }
239 None => {
240 return Err(ContextualEncodeError::ContextIncompatible {
241 strategy: "validate".into(),
242 context,
243 reason: "trailing backslash in JSON string".into(),
244 });
245 }
246 }
247 }
248 }
249 }
250 InjectionContext::XmlAttribute => {
251 let mut chars = payload.chars();
252 while let Some(c) = chars.next() {
253 if c == '"' {
254 return Err(ContextualEncodeError::ContextIncompatible {
255 strategy: "validate".into(),
256 context,
257 reason: "unescaped double quote in XML attribute".into(),
258 });
259 }
260 if c == '&' {
261 let remainder: String = chars.by_ref().take(6).collect();
263 if !remainder.starts_with("quot;")
264 && !remainder.starts_with("amp;")
265 && !remainder.starts_with("lt;")
266 && !remainder.starts_with("gt;")
267 {
268 }
272 }
273 }
274 }
275 InjectionContext::PlainBody => {
279 }
281 InjectionContext::XmlCdata => {
282 }
285 InjectionContext::XmlText => {
286 }
289 InjectionContext::HtmlAttribute => {
290 }
293 InjectionContext::HtmlText => {
294 }
297 InjectionContext::UrlQuery | InjectionContext::UrlPath | InjectionContext::UrlFragment => {
298 }
301 InjectionContext::HeaderValue => {
302 }
305 InjectionContext::CookieValue => {
306 }
309 InjectionContext::MultipartField | InjectionContext::MultipartFileName => {
310 }
313 _ => {}
316 }
317 Ok(())
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::Strategy;

    #[test]
    fn encode_error_mapping_invalid_utf8() {
        // A lone continuation byte is not valid UTF-8, so the encoder must reject it.
        let result = encode_in_context(
            b"\x80",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        );
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("invalid") || err.to_string().contains("UTF-8"),
            "error should mention invalid UTF-8, got: {}",
            err
        );
    }

    #[test]
    fn json_string_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn json_string_validates_valid_escapes() {
        assert!(validate_in_context("hello\\nworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\tworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\\world", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\"world", InjectionContext::JsonString).is_ok());
    }

    #[test]
    fn json_string_validates_unicode_escape() {
        assert!(validate_in_context("\\u00e4", InjectionContext::JsonString).is_ok());
        // Non-hex digit inside the escape.
        let err = validate_in_context("\\u00g4", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
        // Escape truncated before four hex digits.
        let err = validate_in_context("\\u00", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
    }

    #[test]
    fn json_string_validates_invalid_escape() {
        let err = validate_in_context("\\x", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid JSON escape"));
    }

    #[test]
    fn json_string_validates_trailing_backslash() {
        let err = validate_in_context("hello\\", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("trailing backslash"));
    }

    #[test]
    fn xml_attribute_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::XmlAttribute).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn xml_attribute_allows_escaped_quote() {
        assert!(validate_in_context("hello&quot;world", InjectionContext::XmlAttribute).is_ok());
    }

    #[test]
    fn xml_attribute_escapes_markup() {
        let out = escape_for_context("a&b\"<c>", InjectionContext::XmlAttribute).unwrap();
        assert_eq!(out, "a&amp;b&quot;&lt;c&gt;");
    }

    #[test]
    fn header_value_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::HeaderValue,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn cookie_value_escapes_crlf() {
        let out = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::CookieValue,
        )
        .unwrap();
        assert!(out.contains("%0D") && out.contains("%0A"));
    }

    #[test]
    fn multipart_field_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::MultipartField,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn html_attribute_escapes_ampersand() {
        let out = encode_in_context(
            b"a&b",
            Strategy::CaseAlternation,
            InjectionContext::HtmlAttribute,
        )
        .unwrap();
        assert!(out.contains("&amp;"));
    }

    #[test]
    fn url_query_escapes_space() {
        let out = encode_in_context(
            b"hello world",
            Strategy::CaseAlternation,
            InjectionContext::UrlQuery,
        )
        .unwrap();
        assert!(!out.contains(' '));
    }

    #[test]
    fn url_path_preserves_slash() {
        let out = encode_in_context(
            b"/api/v1",
            Strategy::CaseAlternation,
            InjectionContext::UrlPath,
        )
        .unwrap();
        assert!(out.contains('/'));
    }

    #[test]
    fn plain_body_no_structural_escaping() {
        // PlainBody applies only the strategy; markup characters are left untouched.
        let out = encode_in_context(
            b"<script>",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        )
        .unwrap();
        assert_eq!(out, "<ScRiPt>");
    }

    #[test]
    fn max_size_enforced() {
        let big = vec![b'a'; 8 * 1024 * 1024 + 1];
        let err = encode_in_context(&big, Strategy::CaseAlternation, InjectionContext::PlainBody)
            .unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn xml_cdata_rejects_termination_sequence() {
        let err = encode_in_context(
            b"hello]]>world",
            Strategy::CaseAlternation,
            InjectionContext::XmlCdata,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CDATA"));
    }

    #[test]
    fn multipart_filename_rejects_quote() {
        let err = encode_in_context(
            b"file\"name.txt",
            Strategy::CaseAlternation,
            InjectionContext::MultipartFileName,
        )
        .unwrap_err();
        assert!(err.to_string().contains("quote"));
    }

    #[test]
    fn json_number_rejects_non_numeric() {
        let err = encode_in_context(
            b"abc",
            Strategy::CaseAlternation,
            InjectionContext::JsonNumber,
        )
        .unwrap_err();
        assert!(err.to_string().contains("not a valid JSON number"));
    }

    #[test]
    fn empty_payload_valid_in_all_contexts() {
        for ctx in [
            InjectionContext::PlainBody,
            InjectionContext::JsonString,
            InjectionContext::XmlAttribute,
            InjectionContext::HeaderValue,
            InjectionContext::CookieValue,
        ] {
            assert!(
                encode_in_context(b"", Strategy::UrlEncode, ctx).is_ok(),
                "empty payload should be valid in {ctx:?}"
            );
        }
    }
}
530}