1use crate::encoding::Strategy;
2use wafrift_types::injection_context::{ContextualEncodeError, InjectionContext};
3
4pub fn encode_in_context(
5 payload: &[u8],
6 strategy: Strategy,
7 context: InjectionContext,
8) -> Result<String, ContextualEncodeError> {
9 let max_size = match context {
10 InjectionContext::JsonString => 4 * 1024 * 1024,
11 InjectionContext::JsonNumber => 1024,
12 InjectionContext::XmlAttribute => 1024 * 1024,
13 InjectionContext::XmlCdata => 8 * 1024 * 1024,
14 InjectionContext::HeaderValue => 8 * 1024,
15 InjectionContext::CookieValue => 4 * 1024,
16 InjectionContext::MultipartFileName => 256,
17 _ => 8 * 1024 * 1024,
18 };
19
20 if payload.len() > max_size {
21 return Err(ContextualEncodeError::PayloadTooLarge {
22 context,
23 size: payload.len(),
24 max: max_size,
25 });
26 }
27
28 let base = match crate::encoding::encode(payload, strategy) {
29 Ok(s) => s,
30 Err(e) => {
31 return Err(match e {
32 crate::error::EncodeError::InvalidUtf8 => {
33 ContextualEncodeError::InvalidUtf8 { offset: 0 }
34 }
35 crate::error::EncodeError::PayloadTooLarge { max, actual } => {
36 ContextualEncodeError::PayloadTooLarge {
37 context,
38 size: actual,
39 max,
40 }
41 }
42 crate::error::EncodeError::LayeredOutputTooLarge { max, actual } => {
43 ContextualEncodeError::PayloadTooLarge {
44 context,
45 size: actual,
46 max,
47 }
48 }
49 crate::error::EncodeError::InvalidContext {
50 strategy: s,
51 context: _,
52 } => ContextualEncodeError::ContextIncompatible {
53 strategy: s.into(),
54 context,
55 reason: "strategy invalid for context".into(),
56 },
57 crate::error::EncodeError::InvalidConfig(msg) => {
58 ContextualEncodeError::ContextIncompatible {
59 strategy: "config".into(),
60 context,
61 reason: msg,
62 }
63 }
64 });
65 }
66 };
67
68 escape_for_context(&base, context)
69}
70
71pub fn escape_for_context(
72 input: &str,
73 context: InjectionContext,
74) -> Result<String, ContextualEncodeError> {
75 let escaped = match context {
76 InjectionContext::JsonString => {
77 let mut s = String::with_capacity(input.len() + 10);
78 for c in input.chars() {
79 match c {
80 '\\' => s.push_str("\\\\"),
81 '"' => s.push_str("\\\""),
82 '\n' => s.push_str("\\n"),
83 '\r' => s.push_str("\\r"),
84 '\t' => s.push_str("\\t"),
85 '\x00'..='\x1f' => s.push_str(&format!("\\u{:04x}", c as u32)),
86 '\u{2028}' => s.push_str("\\u2028"),
95 '\u{2029}' => s.push_str("\\u2029"),
96 _ => s.push(c),
97 }
98 }
99 s
100 }
101 InjectionContext::JsonNumber => {
102 if input.chars().any(|c| {
103 !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
104 }) {
105 return Err(ContextualEncodeError::ContextIncompatible {
106 strategy: "escape".into(),
107 context,
108 reason: "not a valid JSON number".into(),
109 });
110 }
111 input.to_string()
112 }
113 InjectionContext::XmlAttribute => {
114 if input.contains('\x00') {
115 return Err(ContextualEncodeError::ContextIncompatible {
116 strategy: "escape".into(),
117 context,
118 reason: "null byte in xml attribute".into(),
119 });
120 }
121 input
125 .replace('&', "&")
126 .replace('"', """)
127 .replace('\'', "'")
128 .replace('<', "<")
129 .replace('>', ">")
130 }
131 InjectionContext::XmlCdata => {
132 if input.contains("]]>") {
133 return Err(ContextualEncodeError::ContextIncompatible {
134 strategy: "escape".into(),
135 context,
136 reason: "CDATA cannot contain ]]>".into(),
137 });
138 }
139 input.to_string()
140 }
141 InjectionContext::XmlText => input
142 .replace('&', "&")
143 .replace('<', "<")
144 .replace('>', ">"),
145 InjectionContext::HtmlAttribute => input
146 .replace('&', "&")
147 .replace('"', """)
148 .replace('\'', "'")
149 .replace('<', "<"),
150 InjectionContext::HtmlText => input.replace('&', "&").replace('<', "<"),
151 InjectionContext::UrlQuery => urlencoding::encode(input).to_string(),
152 InjectionContext::UrlPath => urlencoding::encode(input).to_string().replace("%2F", "/"),
153 InjectionContext::UrlFragment => urlencoding::encode(input).to_string(),
154 InjectionContext::HeaderValue => {
155 if input.contains('\r') || input.contains('\n') {
156 return Err(ContextualEncodeError::ContextIncompatible {
157 strategy: "escape".into(),
158 context,
159 reason: "CR/LF in header value".into(),
160 });
161 }
162 if input.contains('\x00') {
163 return Err(ContextualEncodeError::ContextIncompatible {
164 strategy: "escape".into(),
165 context,
166 reason: "null byte in header value".into(),
167 });
168 }
169 input.to_string()
170 }
171 InjectionContext::CookieValue => input
172 .replace(';', "%3B")
178 .replace('=', "%3D")
179 .replace(' ', "%20")
180 .replace(',', "%2C")
181 .replace('"', "%22")
182 .replace('\\', "%5C")
183 .replace('\x00', "%00")
184 .replace('\r', "%0D")
185 .replace('\n', "%0A"),
186 InjectionContext::MultipartField => {
187 if input.contains('\r') || input.contains('\n') {
188 return Err(ContextualEncodeError::ContextIncompatible {
189 strategy: "escape".into(),
190 context,
191 reason: "CR/LF would break multipart structure".into(),
192 });
193 }
194 input.to_string()
195 }
196 InjectionContext::MultipartFileName => {
197 if input.contains('"') {
198 return Err(ContextualEncodeError::ContextIncompatible {
199 strategy: "escape".into(),
200 context,
201 reason: "quote in filename".into(),
202 });
203 }
204 if input.contains('\r') || input.contains('\n') {
205 return Err(ContextualEncodeError::ContextIncompatible {
206 strategy: "escape".into(),
207 context,
208 reason: "CR/LF in filename".into(),
209 });
210 }
211 input.to_string()
212 }
213 InjectionContext::PlainBody => input.to_string(),
214 _ => input.to_string(),
215 };
216 validate_in_context(&escaped, context)?;
217 Ok(escaped)
218}
219
220pub fn validate_in_context(
221 payload: &str,
222 context: InjectionContext,
223) -> Result<(), ContextualEncodeError> {
224 match context {
225 InjectionContext::JsonString => {
226 let mut chars = payload.chars().peekable();
227 while let Some(c) = chars.next() {
228 if c == '"' {
229 return Err(ContextualEncodeError::ContextIncompatible {
230 strategy: "validate".into(),
231 context,
232 reason: "unescaped double quote in JSON string".into(),
233 });
234 }
235 if c == '\\' {
236 let escaped = chars.next();
237 match escaped {
238 Some('\\' | '"' | 'n' | 'r' | 't' | 'b' | 'f' | '/') => {}
239 Some('u') => {
240 for _ in 0..4 {
242 match chars.next() {
243 Some(c) if c.is_ascii_hexdigit() => {}
244 _ => {
245 return Err(ContextualEncodeError::ContextIncompatible {
246 strategy: "validate".into(),
247 context,
248 reason: "invalid Unicode escape in JSON string".into(),
249 });
250 }
251 }
252 }
253 }
254 Some(other) => {
255 return Err(ContextualEncodeError::ContextIncompatible {
256 strategy: "validate".into(),
257 context,
258 reason: format!("invalid JSON escape sequence: \\{other}"),
259 });
260 }
261 None => {
262 return Err(ContextualEncodeError::ContextIncompatible {
263 strategy: "validate".into(),
264 context,
265 reason: "trailing backslash in JSON string".into(),
266 });
267 }
268 }
269 }
270 }
271 }
272 InjectionContext::XmlAttribute => {
273 let mut chars = payload.chars();
274 while let Some(c) = chars.next() {
275 if c == '"' {
276 return Err(ContextualEncodeError::ContextIncompatible {
277 strategy: "validate".into(),
278 context,
279 reason: "unescaped double quote in XML attribute".into(),
280 });
281 }
282 if c == '&' {
283 let remainder: String = chars.by_ref().take(6).collect();
285 if !remainder.starts_with("quot;")
286 && !remainder.starts_with("amp;")
287 && !remainder.starts_with("lt;")
288 && !remainder.starts_with("gt;")
289 {
290 }
294 }
295 }
296 }
297 InjectionContext::PlainBody => {
301 }
303 InjectionContext::XmlCdata
304 if payload.contains("]]>") => {
305 return Err(ContextualEncodeError::ContextIncompatible {
306 strategy: "validate".into(),
307 context,
308 reason: "CDATA payload contains `]]>` (unterminated section)".into(),
309 });
310 }
311 InjectionContext::XmlText => {
312 if payload.contains('<') {
313 return Err(ContextualEncodeError::ContextIncompatible {
314 strategy: "validate".into(),
315 context,
316 reason: "XML text payload contains unescaped `<`".into(),
317 });
318 }
319 reject_unescaped_ampersand(payload, context)?;
320 }
321 InjectionContext::HtmlAttribute => {
322 if payload.contains('<') {
323 return Err(ContextualEncodeError::ContextIncompatible {
324 strategy: "validate".into(),
325 context,
326 reason: "HTML attribute contains unescaped `<` — would close the attribute"
327 .into(),
328 });
329 }
330 if payload.contains('"') {
331 return Err(ContextualEncodeError::ContextIncompatible {
332 strategy: "validate".into(),
333 context,
334 reason: "HTML attribute contains unescaped `\"` — attribute breakout".into(),
335 });
336 }
337 if payload.contains('\'') {
338 return Err(ContextualEncodeError::ContextIncompatible {
339 strategy: "validate".into(),
340 context,
341 reason: "HTML attribute contains unescaped `'` — single-quoted attr breakout"
342 .into(),
343 });
344 }
345 reject_unescaped_ampersand(payload, context)?;
346 }
347 InjectionContext::HtmlText => {
348 if payload.contains('<') {
349 return Err(ContextualEncodeError::ContextIncompatible {
350 strategy: "validate".into(),
351 context,
352 reason: "HTML text contains unescaped `<` — would start a tag".into(),
353 });
354 }
355 reject_unescaped_ampersand(payload, context)?;
356 }
357 InjectionContext::UrlQuery | InjectionContext::UrlPath | InjectionContext::UrlFragment => {
358 }
361 InjectionContext::HeaderValue => {
362 }
365 InjectionContext::CookieValue => {
366 }
369 InjectionContext::MultipartField | InjectionContext::MultipartFileName => {
370 }
373 _ => {}
376 }
377 Ok(())
378}
379
380fn reject_unescaped_ampersand(
388 payload: &str,
389 context: InjectionContext,
390) -> Result<(), ContextualEncodeError> {
391 let bytes = payload.as_bytes();
392 let mut i = 0;
393 while i < bytes.len() {
394 if bytes[i] != b'&' {
395 i += 1;
396 continue;
397 }
398 let mut j = i + 1;
402 let max = (i + 12).min(bytes.len());
403 let mut saw_semicolon = false;
404 let mut valid_shape = true;
405 let first = bytes.get(j).copied();
406 if first == Some(b'#') {
407 j += 1;
408 let hex = bytes.get(j).copied() == Some(b'x') || bytes.get(j).copied() == Some(b'X');
409 if hex {
410 j += 1;
411 }
412 let mut digit_count = 0;
413 while j < max {
414 let b = bytes[j];
415 if b == b';' {
416 saw_semicolon = true;
417 j += 1;
418 break;
419 }
420 let ok = if hex { b.is_ascii_hexdigit() } else { b.is_ascii_digit() };
421 if !ok {
422 valid_shape = false;
423 break;
424 }
425 digit_count += 1;
426 j += 1;
427 }
428 if digit_count == 0 {
429 valid_shape = false;
430 }
431 } else if let Some(b) = first {
432 if b.is_ascii_alphabetic() {
433 while j < max {
434 let b = bytes[j];
435 if b == b';' {
436 saw_semicolon = true;
437 j += 1;
438 break;
439 }
440 if !b.is_ascii_alphanumeric() {
441 valid_shape = false;
442 break;
443 }
444 j += 1;
445 }
446 } else {
447 valid_shape = false;
448 }
449 } else {
450 valid_shape = false;
451 }
452 if !valid_shape || !saw_semicolon {
453 return Err(ContextualEncodeError::ContextIncompatible {
454 strategy: "validate".into(),
455 context,
456 reason: format!("unescaped `&` at byte {i} (no entity reference follows)"),
457 });
458 }
459 i = j;
460 }
461 Ok(())
462}
463
#[cfg(test)]
mod tests {
    //! Unit tests for context-aware encoding, escaping and validation.

    use super::*;
    use crate::encoding::Strategy;

    /// A lone continuation byte is not valid UTF-8; the encoder error must be
    /// surfaced as an invalid-UTF-8 contextual error.
    #[test]
    fn encode_error_mapping_invalid_utf8() {
        let result = encode_in_context(
            b"\x80",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        );
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("invalid") || err.to_string().contains("UTF-8"),
            "error should mention invalid UTF-8, got: {err}"
        );
    }

    #[test]
    fn json_string_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn json_string_validates_valid_escapes() {
        assert!(validate_in_context("hello\\nworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\tworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\\world", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\"world", InjectionContext::JsonString).is_ok());
    }

    #[test]
    fn json_string_validates_unicode_escape() {
        assert!(validate_in_context("\\u00e4", InjectionContext::JsonString).is_ok());
        // Non-hex digit inside the escape.
        let err = validate_in_context("\\u00g4", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
        // Escape cut short by the end of the payload.
        let err = validate_in_context("\\u00", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
    }

    #[test]
    fn json_string_validates_invalid_escape() {
        let err = validate_in_context("\\x", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid JSON escape"));
    }

    #[test]
    fn json_string_validates_trailing_backslash() {
        let err = validate_in_context("hello\\", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("trailing backslash"));
    }

    #[test]
    fn xml_attribute_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::XmlAttribute).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn xml_attribute_allows_escaped_quote() {
        assert!(validate_in_context("hello&quot;world", InjectionContext::XmlAttribute).is_ok());
    }

    #[test]
    fn header_value_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::HeaderValue,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn cookie_value_escapes_crlf() {
        let out = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::CookieValue,
        )
        .unwrap();
        assert!(out.contains("%0D") && out.contains("%0A"));
    }

    #[test]
    fn multipart_field_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::MultipartField,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn html_attribute_escapes_ampersand() {
        let out = encode_in_context(
            b"a&b",
            Strategy::CaseAlternation,
            InjectionContext::HtmlAttribute,
        )
        .unwrap();
        // Must be the entity, not merely a surviving raw `&`.
        assert!(out.contains("&amp;"));
    }

    #[test]
    fn url_query_escapes_space() {
        let out = encode_in_context(
            b"hello world",
            Strategy::CaseAlternation,
            InjectionContext::UrlQuery,
        )
        .unwrap();
        assert!(!out.contains(' '));
    }

    #[test]
    fn url_path_preserves_slash() {
        let out = encode_in_context(
            b"/api/v1",
            Strategy::CaseAlternation,
            InjectionContext::UrlPath,
        )
        .unwrap();
        assert!(out.contains('/'));
    }

    #[test]
    fn plain_body_no_structural_escaping() {
        let out = encode_in_context(
            b"<script>",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        )
        .unwrap();
        assert_eq!(out, "<ScRiPt>");
    }

    #[test]
    fn max_size_enforced() {
        let big = vec![b'a'; 8 * 1024 * 1024 + 1];
        let err = encode_in_context(&big, Strategy::CaseAlternation, InjectionContext::PlainBody)
            .unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn xml_cdata_rejects_termination_sequence() {
        let err = encode_in_context(
            b"hello]]>world",
            Strategy::CaseAlternation,
            InjectionContext::XmlCdata,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CDATA"));
    }

    #[test]
    fn multipart_filename_rejects_quote() {
        let err = encode_in_context(
            b"file\"name.txt",
            Strategy::CaseAlternation,
            InjectionContext::MultipartFileName,
        )
        .unwrap_err();
        assert!(err.to_string().contains("quote"));
    }

    #[test]
    fn json_number_rejects_non_numeric() {
        let err = encode_in_context(
            b"abc",
            Strategy::CaseAlternation,
            InjectionContext::JsonNumber,
        )
        .unwrap_err();
        assert!(err.to_string().contains("not a valid JSON number"));
    }

    #[test]
    fn empty_payload_valid_in_all_contexts() {
        for ctx in [
            InjectionContext::PlainBody,
            InjectionContext::JsonString,
            InjectionContext::XmlAttribute,
            InjectionContext::HeaderValue,
            InjectionContext::CookieValue,
        ] {
            assert!(
                encode_in_context(b"", Strategy::UrlEncode, ctx).is_ok(),
                "empty payload should be valid in {ctx:?}"
            );
        }
    }
}