1use crate::encoding::Strategy;
2use wafrift_types::injection_context::{ContextualEncodeError, InjectionContext};
3
4pub fn encode_in_context(
5 payload: &[u8],
6 strategy: Strategy,
7 context: InjectionContext,
8) -> Result<String, ContextualEncodeError> {
9 let max_size = match context {
10 InjectionContext::JsonString => 4 * 1024 * 1024,
11 InjectionContext::JsonNumber => 1024,
12 InjectionContext::XmlAttribute => 1024 * 1024,
13 InjectionContext::XmlCdata => 8 * 1024 * 1024,
14 InjectionContext::HeaderValue => 8 * 1024,
15 InjectionContext::CookieValue => 4 * 1024,
16 InjectionContext::MultipartFileName => 256,
17 _ => 8 * 1024 * 1024,
18 };
19
20 if payload.len() > max_size {
21 return Err(ContextualEncodeError::PayloadTooLarge {
22 context,
23 size: payload.len(),
24 max: max_size,
25 });
26 }
27
28 let base = match crate::encoding::encode(payload, strategy) {
29 Ok(s) => s,
30 Err(e) => {
31 return Err(match e {
32 crate::error::EncodeError::InvalidUtf8 => {
33 ContextualEncodeError::InvalidUtf8 { offset: 0 }
34 }
35 crate::error::EncodeError::PayloadTooLarge { max, actual } => {
36 ContextualEncodeError::PayloadTooLarge {
37 context,
38 size: actual,
39 max,
40 }
41 }
42 crate::error::EncodeError::LayeredOutputTooLarge { max, actual } => {
43 ContextualEncodeError::PayloadTooLarge {
44 context,
45 size: actual,
46 max,
47 }
48 }
49 crate::error::EncodeError::InvalidContext {
50 strategy: s,
51 context: _,
52 } => ContextualEncodeError::ContextIncompatible {
53 strategy: s.into(),
54 context,
55 reason: "strategy invalid for context".into(),
56 },
57 crate::error::EncodeError::InvalidConfig(msg) => {
58 ContextualEncodeError::ContextIncompatible {
59 strategy: "config".into(),
60 context,
61 reason: msg,
62 }
63 }
64 });
65 }
66 };
67
68 escape_for_context(&base, context)
69}
70
71pub fn escape_for_context(
72 input: &str,
73 context: InjectionContext,
74) -> Result<String, ContextualEncodeError> {
75 let escaped = match context {
76 InjectionContext::JsonString => {
77 let mut s = String::with_capacity(input.len() + 10);
78 for c in input.chars() {
79 match c {
80 '\\' => s.push_str("\\\\"),
81 '"' => s.push_str("\\\""),
82 '\n' => s.push_str("\\n"),
83 '\r' => s.push_str("\\r"),
84 '\t' => s.push_str("\\t"),
85 '\x00'..='\x1f' => s.push_str(&format!("\\u{:04x}", c as u32)),
86 '\u{2028}' => s.push_str("\\u2028"),
95 '\u{2029}' => s.push_str("\\u2029"),
96 _ => s.push(c),
97 }
98 }
99 s
100 }
101 InjectionContext::JsonNumber => {
102 if input.chars().any(|c| {
103 !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
104 }) {
105 return Err(ContextualEncodeError::ContextIncompatible {
106 strategy: "escape".into(),
107 context,
108 reason: "not a valid JSON number".into(),
109 });
110 }
111 input.to_string()
112 }
113 InjectionContext::XmlAttribute => {
114 if input.contains('\x00') {
115 return Err(ContextualEncodeError::ContextIncompatible {
116 strategy: "escape".into(),
117 context,
118 reason: "null byte in xml attribute".into(),
119 });
120 }
121 input
125 .replace('&', "&")
126 .replace('"', """)
127 .replace('\'', "'")
128 .replace('<', "<")
129 .replace('>', ">")
130 }
131 InjectionContext::XmlCdata => {
132 if input.contains("]]>") {
133 return Err(ContextualEncodeError::ContextIncompatible {
134 strategy: "escape".into(),
135 context,
136 reason: "CDATA cannot contain ]]>".into(),
137 });
138 }
139 input.to_string()
140 }
141 InjectionContext::XmlText => input
142 .replace('&', "&")
143 .replace('<', "<")
144 .replace('>', ">"),
145 InjectionContext::HtmlAttribute => input
146 .replace('&', "&")
147 .replace('"', """)
148 .replace('\'', "'")
149 .replace('<', "<"),
150 InjectionContext::HtmlText => input.replace('&', "&").replace('<', "<"),
151 InjectionContext::UrlQuery => urlencoding::encode(input).to_string(),
152 InjectionContext::UrlPath => urlencoding::encode(input).to_string().replace("%2F", "/"),
153 InjectionContext::UrlFragment => urlencoding::encode(input).to_string(),
154 InjectionContext::HeaderValue => {
155 if input.contains('\r') || input.contains('\n') {
156 return Err(ContextualEncodeError::ContextIncompatible {
157 strategy: "escape".into(),
158 context,
159 reason: "CR/LF in header value".into(),
160 });
161 }
162 if input.contains('\x00') {
163 return Err(ContextualEncodeError::ContextIncompatible {
164 strategy: "escape".into(),
165 context,
166 reason: "null byte in header value".into(),
167 });
168 }
169 input.to_string()
170 }
171 InjectionContext::CookieValue => input
172 .replace(';', "%3B")
178 .replace('=', "%3D")
179 .replace(' ', "%20")
180 .replace(',', "%2C")
181 .replace('"', "%22")
182 .replace('\\', "%5C")
183 .replace('\x00', "%00")
184 .replace('\r', "%0D")
185 .replace('\n', "%0A"),
186 InjectionContext::MultipartField => {
187 if input.contains('\r') || input.contains('\n') {
188 return Err(ContextualEncodeError::ContextIncompatible {
189 strategy: "escape".into(),
190 context,
191 reason: "CR/LF would break multipart structure".into(),
192 });
193 }
194 input.to_string()
195 }
196 InjectionContext::MultipartFileName => {
197 if input.contains('"') {
198 return Err(ContextualEncodeError::ContextIncompatible {
199 strategy: "escape".into(),
200 context,
201 reason: "quote in filename".into(),
202 });
203 }
204 if input.contains('\r') || input.contains('\n') {
205 return Err(ContextualEncodeError::ContextIncompatible {
206 strategy: "escape".into(),
207 context,
208 reason: "CR/LF in filename".into(),
209 });
210 }
211 input.to_string()
212 }
213 InjectionContext::PlainBody => input.to_string(),
214 _ => input.to_string(),
215 };
216 validate_in_context(&escaped, context)?;
217 Ok(escaped)
218}
219
220pub fn validate_in_context(
221 payload: &str,
222 context: InjectionContext,
223) -> Result<(), ContextualEncodeError> {
224 match context {
225 InjectionContext::JsonString => {
226 let mut chars = payload.chars().peekable();
227 while let Some(c) = chars.next() {
228 if c == '"' {
229 return Err(ContextualEncodeError::ContextIncompatible {
230 strategy: "validate".into(),
231 context,
232 reason: "unescaped double quote in JSON string".into(),
233 });
234 }
235 if c == '\\' {
236 let escaped = chars.next();
237 match escaped {
238 Some('\\') | Some('"') | Some('n') | Some('r') | Some('t') | Some('b')
239 | Some('f') | Some('/') => {}
240 Some('u') => {
241 for _ in 0..4 {
243 match chars.next() {
244 Some(c) if c.is_ascii_hexdigit() => {}
245 _ => {
246 return Err(ContextualEncodeError::ContextIncompatible {
247 strategy: "validate".into(),
248 context,
249 reason: "invalid Unicode escape in JSON string".into(),
250 });
251 }
252 }
253 }
254 }
255 Some(other) => {
256 return Err(ContextualEncodeError::ContextIncompatible {
257 strategy: "validate".into(),
258 context,
259 reason: format!("invalid JSON escape sequence: \\{other}"),
260 });
261 }
262 None => {
263 return Err(ContextualEncodeError::ContextIncompatible {
264 strategy: "validate".into(),
265 context,
266 reason: "trailing backslash in JSON string".into(),
267 });
268 }
269 }
270 }
271 }
272 }
273 InjectionContext::XmlAttribute => {
274 let mut chars = payload.chars();
275 while let Some(c) = chars.next() {
276 if c == '"' {
277 return Err(ContextualEncodeError::ContextIncompatible {
278 strategy: "validate".into(),
279 context,
280 reason: "unescaped double quote in XML attribute".into(),
281 });
282 }
283 if c == '&' {
284 let remainder: String = chars.by_ref().take(6).collect();
286 if !remainder.starts_with("quot;")
287 && !remainder.starts_with("amp;")
288 && !remainder.starts_with("lt;")
289 && !remainder.starts_with("gt;")
290 {
291 }
295 }
296 }
297 }
298 InjectionContext::PlainBody => {
302 }
304 InjectionContext::XmlCdata
305 if payload.contains("]]>") => {
306 return Err(ContextualEncodeError::ContextIncompatible {
307 strategy: "validate".into(),
308 context,
309 reason: "CDATA payload contains `]]>` (unterminated section)".into(),
310 });
311 }
312 InjectionContext::XmlText => {
313 if payload.contains('<') {
314 return Err(ContextualEncodeError::ContextIncompatible {
315 strategy: "validate".into(),
316 context,
317 reason: "XML text payload contains unescaped `<`".into(),
318 });
319 }
320 reject_unescaped_ampersand(payload, context)?;
321 }
322 InjectionContext::HtmlAttribute => {
323 if payload.contains('<') {
324 return Err(ContextualEncodeError::ContextIncompatible {
325 strategy: "validate".into(),
326 context,
327 reason: "HTML attribute contains unescaped `<` — would close the attribute"
328 .into(),
329 });
330 }
331 if payload.contains('"') {
332 return Err(ContextualEncodeError::ContextIncompatible {
333 strategy: "validate".into(),
334 context,
335 reason: "HTML attribute contains unescaped `\"` — attribute breakout".into(),
336 });
337 }
338 if payload.contains('\'') {
339 return Err(ContextualEncodeError::ContextIncompatible {
340 strategy: "validate".into(),
341 context,
342 reason: "HTML attribute contains unescaped `'` — single-quoted attr breakout"
343 .into(),
344 });
345 }
346 reject_unescaped_ampersand(payload, context)?;
347 }
348 InjectionContext::HtmlText => {
349 if payload.contains('<') {
350 return Err(ContextualEncodeError::ContextIncompatible {
351 strategy: "validate".into(),
352 context,
353 reason: "HTML text contains unescaped `<` — would start a tag".into(),
354 });
355 }
356 reject_unescaped_ampersand(payload, context)?;
357 }
358 InjectionContext::UrlQuery | InjectionContext::UrlPath | InjectionContext::UrlFragment => {
359 }
362 InjectionContext::HeaderValue => {
363 }
366 InjectionContext::CookieValue => {
367 }
370 InjectionContext::MultipartField | InjectionContext::MultipartFileName => {
371 }
374 _ => {}
377 }
378 Ok(())
379}
380
381fn reject_unescaped_ampersand(
389 payload: &str,
390 context: InjectionContext,
391) -> Result<(), ContextualEncodeError> {
392 let bytes = payload.as_bytes();
393 let mut i = 0;
394 while i < bytes.len() {
395 if bytes[i] != b'&' {
396 i += 1;
397 continue;
398 }
399 let mut j = i + 1;
403 let max = (i + 12).min(bytes.len());
404 let mut saw_semicolon = false;
405 let mut valid_shape = true;
406 let first = bytes.get(j).copied();
407 if first == Some(b'#') {
408 j += 1;
409 let hex = bytes.get(j).copied() == Some(b'x') || bytes.get(j).copied() == Some(b'X');
410 if hex {
411 j += 1;
412 }
413 let mut digit_count = 0;
414 while j < max {
415 let b = bytes[j];
416 if b == b';' {
417 saw_semicolon = true;
418 j += 1;
419 break;
420 }
421 let ok = if hex { b.is_ascii_hexdigit() } else { b.is_ascii_digit() };
422 if !ok {
423 valid_shape = false;
424 break;
425 }
426 digit_count += 1;
427 j += 1;
428 }
429 if digit_count == 0 {
430 valid_shape = false;
431 }
432 } else if let Some(b) = first {
433 if !b.is_ascii_alphabetic() {
434 valid_shape = false;
435 } else {
436 while j < max {
437 let b = bytes[j];
438 if b == b';' {
439 saw_semicolon = true;
440 j += 1;
441 break;
442 }
443 if !b.is_ascii_alphanumeric() {
444 valid_shape = false;
445 break;
446 }
447 j += 1;
448 }
449 }
450 } else {
451 valid_shape = false;
452 }
453 if !valid_shape || !saw_semicolon {
454 return Err(ContextualEncodeError::ContextIncompatible {
455 strategy: "validate".into(),
456 context,
457 reason: format!("unescaped `&` at byte {i} (no entity reference follows)"),
458 });
459 }
460 i = j;
461 }
462 Ok(())
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::Strategy;

    #[test]
    fn encode_error_mapping_invalid_utf8() {
        // 0x80 on its own is not valid UTF-8.
        let result = encode_in_context(
            b"\x80",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        );
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("invalid") || err.to_string().contains("UTF-8"),
            "error should mention invalid UTF-8, got: {}",
            err
        );
    }

    #[test]
    fn json_string_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn json_string_validates_valid_escapes() {
        assert!(validate_in_context("hello\\nworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\tworld", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\\world", InjectionContext::JsonString).is_ok());
        assert!(validate_in_context("hello\\\"world", InjectionContext::JsonString).is_ok());
    }

    #[test]
    fn json_string_validates_unicode_escape() {
        assert!(validate_in_context("\\u00e4", InjectionContext::JsonString).is_ok());
        let err = validate_in_context("\\u00g4", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
        let err = validate_in_context("\\u00", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid Unicode escape"));
    }

    #[test]
    fn json_string_validates_invalid_escape() {
        let err = validate_in_context("\\x", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("invalid JSON escape"));
    }

    #[test]
    fn json_string_validates_trailing_backslash() {
        let err = validate_in_context("hello\\", InjectionContext::JsonString).unwrap_err();
        assert!(err.to_string().contains("trailing backslash"));
    }

    #[test]
    fn json_string_escapes_control_characters() {
        let out = encode_in_context(
            b"1\x01",
            Strategy::CaseAlternation,
            InjectionContext::JsonString,
        )
        .unwrap();
        assert!(out.contains("\\u0001"), "got: {out}");
    }

    #[test]
    fn xml_attribute_validates_unescaped_quote() {
        let err = validate_in_context("hello\"world", InjectionContext::XmlAttribute).unwrap_err();
        assert!(err.to_string().contains("unescaped double quote"));
    }

    #[test]
    fn xml_attribute_allows_escaped_quote() {
        assert!(validate_in_context("hello&quot;world", InjectionContext::XmlAttribute).is_ok());
    }

    #[test]
    fn xml_attribute_escapes_markup_characters() {
        let out = encode_in_context(
            b"<x>",
            Strategy::CaseAlternation,
            InjectionContext::XmlAttribute,
        )
        .unwrap();
        assert!(out.contains("&lt;") && out.contains("&gt;"), "got: {out}");
        assert!(!out.contains('<') && !out.contains('>'), "got: {out}");
    }

    #[test]
    fn header_value_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::HeaderValue,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn cookie_value_escapes_crlf() {
        let out = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::CookieValue,
        )
        .unwrap();
        assert!(out.contains("%0D") && out.contains("%0A"));
    }

    #[test]
    fn multipart_field_validates_crlf() {
        let err = encode_in_context(
            b"hello\r\nworld",
            Strategy::CaseAlternation,
            InjectionContext::MultipartField,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CR/LF"));
    }

    #[test]
    fn html_attribute_escapes_ampersand() {
        let out = encode_in_context(
            b"a&b",
            Strategy::CaseAlternation,
            InjectionContext::HtmlAttribute,
        )
        .unwrap();
        assert!(out.contains("&amp;"), "got: {out}");
    }

    #[test]
    fn url_query_escapes_space() {
        let out = encode_in_context(
            b"hello world",
            Strategy::CaseAlternation,
            InjectionContext::UrlQuery,
        )
        .unwrap();
        assert!(!out.contains(' '));
    }

    #[test]
    fn url_path_preserves_slash() {
        let out = encode_in_context(
            b"/api/v1",
            Strategy::CaseAlternation,
            InjectionContext::UrlPath,
        )
        .unwrap();
        assert!(out.contains('/'));
    }

    #[test]
    fn plain_body_no_structural_escaping() {
        let out = encode_in_context(
            b"<script>",
            Strategy::CaseAlternation,
            InjectionContext::PlainBody,
        )
        .unwrap();
        assert_eq!(out, "<ScRiPt>");
    }

    #[test]
    fn max_size_enforced() {
        let big = vec![b'a'; 8 * 1024 * 1024 + 1];
        let err = encode_in_context(&big, Strategy::CaseAlternation, InjectionContext::PlainBody)
            .unwrap_err();
        assert!(err.to_string().contains("too large"));
    }

    #[test]
    fn xml_cdata_rejects_termination_sequence() {
        let err = encode_in_context(
            b"hello]]>world",
            Strategy::CaseAlternation,
            InjectionContext::XmlCdata,
        )
        .unwrap_err();
        assert!(err.to_string().contains("CDATA"));
    }

    #[test]
    fn multipart_filename_rejects_quote() {
        let err = encode_in_context(
            b"file\"name.txt",
            Strategy::CaseAlternation,
            InjectionContext::MultipartFileName,
        )
        .unwrap_err();
        assert!(err.to_string().contains("quote"));
    }

    #[test]
    fn json_number_rejects_non_numeric() {
        let err = encode_in_context(
            b"abc",
            Strategy::CaseAlternation,
            InjectionContext::JsonNumber,
        )
        .unwrap_err();
        assert!(err.to_string().contains("not a valid JSON number"));
    }

    #[test]
    fn empty_payload_valid_in_all_contexts() {
        for ctx in [
            InjectionContext::PlainBody,
            InjectionContext::JsonString,
            InjectionContext::XmlAttribute,
            InjectionContext::HeaderValue,
            InjectionContext::CookieValue,
        ] {
            assert!(
                encode_in_context(b"", Strategy::UrlEncode, ctx).is_ok(),
                "empty payload should be valid in {ctx:?}"
            );
        }
    }
}