secure_boundary/
safe_types.rs1use crate::{
8 attack_signal::{BoundaryViolation, ViolationKind},
9 error::BoundaryRejection,
10};
11use serde::{Deserialize, Deserializer};
12use std::fmt;
13
/// Signals a `BoundaryViolation` of the given `kind`, tagged with `code`.
///
/// Every validator in this file calls this helper, so all boundary signals go
/// through one place. Rejecting validators call it just before returning their
/// error; `LdapSafeString` calls it after escaping special characters.
fn emit_violation(kind: ViolationKind, code: &'static str) {
    BoundaryViolation::new(kind, code).emit();
}
19
20#[derive(Clone, Debug, PartialEq, Eq, Hash)]
39pub struct SafePath(String);
40
41impl SafePath {
42 #[must_use]
44 pub fn as_inner(&self) -> &str {
45 &self.0
46 }
47
48 #[must_use]
50 pub fn into_inner(self) -> String {
51 self.0
52 }
53}
54
55impl TryFrom<&str> for SafePath {
56 type Error = BoundaryRejection;
57
58 fn try_from(s: &str) -> Result<Self, Self::Error> {
59 if s.contains('\0') {
60 emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
61 return Err(BoundaryRejection::PathTraversal);
62 }
63 if s.starts_with('/') || s.starts_with('\\') {
64 emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
65 return Err(BoundaryRejection::PathTraversal);
66 }
67 if s.contains("../")
68 || s.contains("..\\")
69 || s == ".."
70 || s.ends_with("/..")
71 || s.ends_with("\\..")
72 {
73 emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
74 return Err(BoundaryRejection::PathTraversal);
75 }
76 let lower = s.to_lowercase();
78 if lower.contains("%2e%2e")
79 || lower.contains("%2f")
80 || lower.contains("%5c")
81 || lower.contains("..%2f")
82 || lower.contains("..%5c")
83 {
84 emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
85 return Err(BoundaryRejection::PathTraversal);
86 }
87 Ok(Self(s.to_owned()))
88 }
89}
90
91impl<'de> Deserialize<'de> for SafePath {
92 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
93 let s = String::deserialize(d)?;
94 SafePath::try_from(s.as_str()).map_err(serde::de::Error::custom)
95 }
96}
97
98impl fmt::Display for SafePath {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.write_str(&self.0)
101 }
102}
103
104#[derive(Clone, Debug, PartialEq, Eq, Hash)]
119pub struct SafeFilename(String);
120
121impl SafeFilename {
122 #[must_use]
124 pub fn as_inner(&self) -> &str {
125 &self.0
126 }
127
128 #[must_use]
130 pub fn into_inner(self) -> String {
131 self.0
132 }
133}
134
135impl TryFrom<&str> for SafeFilename {
136 type Error = BoundaryRejection;
137
138 fn try_from(s: &str) -> Result<Self, Self::Error> {
139 let reject = || {
140 emit_violation(ViolationKind::SyntaxViolation, "invalid_filename");
141 BoundaryRejection::InjectionAttempt {
142 code: "invalid_filename",
143 }
144 };
145 if s.is_empty() {
146 return Err(reject());
147 }
148 if s.contains('\0') {
149 return Err(reject());
150 }
151 if s.contains('/') || s.contains('\\') {
152 return Err(reject());
153 }
154 if s == ".." || s.starts_with("../") || s.starts_with("..\\") {
155 return Err(reject());
156 }
157 if s.chars()
159 .any(|c| matches!(c, ';' | '|' | '&' | '`' | '$' | '>' | '<'))
160 {
161 return Err(reject());
162 }
163 Ok(Self(s.to_owned()))
164 }
165}
166
167impl<'de> Deserialize<'de> for SafeFilename {
168 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
169 let s = String::deserialize(d)?;
170 SafeFilename::try_from(s.as_str()).map_err(serde::de::Error::custom)
171 }
172}
173
174impl fmt::Display for SafeFilename {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.write_str(&self.0)
177 }
178}
179
180#[derive(Clone, Debug, PartialEq, Eq, Hash)]
195pub struct SafeCommandArg(String);
196
197impl SafeCommandArg {
198 #[must_use]
200 pub fn as_inner(&self) -> &str {
201 &self.0
202 }
203
204 #[must_use]
206 pub fn into_inner(self) -> String {
207 self.0
208 }
209}
210
211impl TryFrom<&str> for SafeCommandArg {
212 type Error = BoundaryRejection;
213
214 fn try_from(s: &str) -> Result<Self, Self::Error> {
215 let reject = || {
216 emit_violation(ViolationKind::SyntaxViolation, "command_injection");
217 BoundaryRejection::InjectionAttempt {
218 code: "command_injection",
219 }
220 };
221 if s.chars()
222 .any(|c| matches!(c, ';' | '|' | '&' | '`' | '$' | '>' | '<' | '\n' | '\r'))
223 {
224 return Err(reject());
225 }
226 Ok(Self(s.to_owned()))
227 }
228}
229
230impl<'de> Deserialize<'de> for SafeCommandArg {
231 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
232 let s = String::deserialize(d)?;
233 SafeCommandArg::try_from(s.as_str()).map_err(serde::de::Error::custom)
234 }
235}
236
237impl fmt::Display for SafeCommandArg {
238 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239 f.write_str(&self.0)
240 }
241}
242
243#[derive(Clone, Debug, PartialEq, Eq, Hash)]
303pub struct SafeUrl(String);
304
305impl SafeUrl {
306 #[must_use]
308 pub fn as_inner(&self) -> &str {
309 &self.0
310 }
311
312 #[must_use]
314 pub fn into_inner(self) -> String {
315 self.0
316 }
317}
318
319impl TryFrom<&str> for SafeUrl {
320 type Error = BoundaryRejection;
321
322 fn try_from(s: &str) -> Result<Self, Self::Error> {
323 let reject = || {
324 emit_violation(ViolationKind::SyntaxViolation, "ssrf_attempt");
325 BoundaryRejection::SsrfAttempt
326 };
327
328 let lower = s.to_lowercase();
329
330 let is_http = lower.starts_with("http://");
331 let is_https = lower.starts_with("https://");
332
333 if !is_http && !is_https {
334 return Err(reject());
335 }
336
337 let prefix_len = if is_https {
338 "https://".len()
339 } else {
340 "http://".len()
341 };
342 let rest = &s[prefix_len..];
343 let host_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
344 let authority = &rest[..host_end];
345 let host_with_port = authority
346 .rsplit_once('@')
347 .map_or(authority, |(_, host_port)| host_port);
348
349 let host = if host_with_port.starts_with('[') {
351 let bracket_end = host_with_port
353 .find(']')
354 .map(|i| i + 1)
355 .unwrap_or(host_with_port.len());
356 &host_with_port[..bracket_end]
357 } else {
358 match host_with_port.rfind(':') {
360 Some(pos)
361 if host_with_port[pos + 1..]
362 .chars()
363 .all(|c| c.is_ascii_digit()) =>
364 {
365 &host_with_port[..pos]
366 }
367 _ => host_with_port,
368 }
369 };
370
371 if is_private_ip(host) {
372 return Err(reject());
373 }
374
375 Ok(Self(s.to_owned()))
376 }
377}
378
379impl<'de> Deserialize<'de> for SafeUrl {
380 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
381 let s = String::deserialize(d)?;
382 SafeUrl::try_from(s.as_str()).map_err(serde::de::Error::custom)
383 }
384}
385
386impl fmt::Display for SafeUrl {
387 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388 f.write_str(&self.0)
389 }
390}
391
/// Returns `true` when `host` is an IP literal in a range that must not be
/// reachable through a user-supplied URL.
///
/// Accepted forms:
/// - `host` may be wrapped in brackets (`[::1]`);
/// - an IPv6 zone suffix (`fe80::1%eth0`, or URL-encoded `%25eth0`) is ignored
///   for the range check;
/// - IPv4 literals are parsed with the lenient WHATWG URL rules, so shorthand
///   spellings are classified by the address they denote: `127.1`,
///   `2130706433`, `0x7f.0.0.1`, `0177.0.0.1`, and a trailing `.`.
///
/// Anything that is not an IP literal (e.g. a DNS name) returns `false`.
fn is_private_ip(host: &str) -> bool {
    let host = host.trim_matches(|c| c == '[' || c == ']');

    if let Some(addr) = parse_ipv4_lenient(host) {
        return is_private_ipv4(addr);
    }
    let without_zone = host.split_once('%').map_or(host, |(addr, _zone)| addr);
    if let Ok(addr) = without_zone.parse::<std::net::Ipv6Addr>() {
        return is_private_ipv6(addr);
    }
    false
}

/// Parses `host` as an IPv4 address using the WHATWG URL "IPv4 parser" rules.
///
/// The input is one to four `.`-separated parts; an optional trailing `.` is
/// ignored. Each part is decimal, `0x`/`0X` hexadecimal, or `0`-prefixed octal.
/// Every part except the last is one byte, and the last part fills all remaining
/// bytes. Returns `None` when `host` is not such a literal.
fn parse_ipv4_lenient(host: &str) -> Option<std::net::Ipv4Addr> {
    let host = host.strip_suffix('.').unwrap_or(host);
    if host.is_empty() {
        return None;
    }
    let parts = host
        .split('.')
        .map(parse_ipv4_number)
        .collect::<Option<Vec<u64>>>()?;
    if parts.len() > 4 {
        return None;
    }
    let (&last, leading) = parts.split_last()?;
    if leading.iter().any(|&part| part > 0xff) {
        return None;
    }
    // With k leading one-byte parts, the last part covers the low (4 - k) bytes.
    let last_bits = 8 * (4 - leading.len());
    if last >= 1u64 << last_bits {
        return None;
    }
    let value = leading
        .iter()
        .enumerate()
        .fold(last, |acc, (i, &byte)| acc | (byte << (8 * (3 - i))));
    // `value` < 2^32 by the checks above, so the low four bytes hold the address.
    let b = value.to_be_bytes();
    Some(std::net::Ipv4Addr::new(b[4], b[5], b[6], b[7]))
}

/// Parses one IPv4 part.
///
/// A `0x`/`0X` prefix means hexadecimal, and a leading `0` followed by more
/// digits means octal; anything else is decimal. A bare `0x` is zero, as in the
/// WHATWG parser.
fn parse_ipv4_number(part: &str) -> Option<u64> {
    if part.is_empty() {
        return None;
    }
    let (digits, radix) = if let Some(hex) = part
        .strip_prefix("0x")
        .or_else(|| part.strip_prefix("0X"))
    {
        (hex, 16)
    } else if let Some(octal) = part.strip_prefix('0').filter(|rest| !rest.is_empty()) {
        (octal, 8)
    } else {
        (part, 10)
    };
    if digits.is_empty() {
        return Some(0);
    }
    // Check the digits explicitly: `from_str_radix` on its own also accepts a
    // leading '+'.
    if !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }
    u64::from_str_radix(digits, radix).ok()
}

/// Returns `true` for IPv4 addresses outside the publicly routable space.
fn is_private_ipv4(addr: std::net::Ipv4Addr) -> bool {
    let [first, second, _, _] = addr.octets();
    addr.is_private() // 10/8, 172.16/12, 192.168/16
        || addr.is_loopback() // 127/8
        || addr.is_link_local() // 169.254/16, includes cloud metadata endpoints
        || addr.is_multicast() // 224/4
        || first == 0 // 0/8 "this network", including 0.0.0.0
        || (first == 100 && (second & 0xc0) == 64) // 100.64/10 shared address space
        || (first == 198 && (second & 0xfe) == 18) // 198.18/15 benchmarking
        || first >= 240 // 240/4 reserved, including 255.255.255.255
}

/// Returns `true` for IPv6 addresses outside the publicly routable space.
fn is_private_ipv6(addr: std::net::Ipv6Addr) -> bool {
    if addr.is_loopback() || addr.is_unspecified() {
        return true;
    }
    // IPv4-mapped (::ffff:a.b.c.d) and IPv4-compatible (::a.b.c.d) addresses
    // reach the embedded IPv4 host, so apply the IPv4 rules to that host.
    if let Some(v4) = addr.to_ipv4() {
        return is_private_ipv4(v4);
    }
    let head = addr.segments()[0];
    (head & 0xfe00) == 0xfc00 // fc00::/7 unique local
        || (head & 0xffc0) == 0xfe80 // fe80::/10 link-local
        || (head & 0xffc0) == 0xfec0 // fec0::/10 deprecated site-local
        || addr.is_multicast() // ff00::/8
}
435
436#[derive(Clone, Debug, PartialEq, Eq, Hash)]
454pub struct SafeRedirectUrl(String);
455
456impl SafeRedirectUrl {
457 #[must_use]
459 pub fn as_inner(&self) -> &str {
460 &self.0
461 }
462
463 #[must_use]
465 pub fn into_inner(self) -> String {
466 self.0
467 }
468}
469
470impl TryFrom<&str> for SafeRedirectUrl {
471 type Error = BoundaryRejection;
472
473 fn try_from(s: &str) -> Result<Self, Self::Error> {
474 let reject = || {
475 emit_violation(ViolationKind::SyntaxViolation, "invalid_redirect");
476 BoundaryRejection::InjectionAttempt {
477 code: "invalid_redirect",
478 }
479 };
480 if !s.starts_with('/') || s.starts_with("//") {
482 return Err(reject());
483 }
484 if s.contains(':') {
486 return Err(reject());
487 }
488 Ok(Self(s.to_owned()))
489 }
490}
491
492impl<'de> Deserialize<'de> for SafeRedirectUrl {
493 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
494 let s = String::deserialize(d)?;
495 SafeRedirectUrl::try_from(s.as_str()).map_err(serde::de::Error::custom)
496 }
497}
498
499impl fmt::Display for SafeRedirectUrl {
500 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501 f.write_str(&self.0)
502 }
503}
504
505#[derive(Clone, Debug, PartialEq, Eq, Hash)]
523pub struct SqlIdentifier(String);
524
525impl SqlIdentifier {
526 #[must_use]
528 pub fn as_inner(&self) -> &str {
529 &self.0
530 }
531
532 #[must_use]
534 pub fn into_inner(self) -> String {
535 self.0
536 }
537}
538
539impl TryFrom<&str> for SqlIdentifier {
540 type Error = BoundaryRejection;
541
542 fn try_from(s: &str) -> Result<Self, Self::Error> {
543 let reject = || {
544 emit_violation(ViolationKind::SyntaxViolation, "invalid_sql_identifier");
545 BoundaryRejection::InjectionAttempt {
546 code: "invalid_sql_identifier",
547 }
548 };
549 if s.is_empty() {
550 return Err(reject());
551 }
552 if s.len() > 128 {
553 return Err(reject());
554 }
555 let mut chars = s.chars();
556 let first = chars.next().expect("non-empty string has a first char");
557 if !first.is_ascii_alphabetic() && first != '_' {
558 return Err(reject());
559 }
560 if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_') {
561 return Err(reject());
562 }
563 Ok(Self(s.to_owned()))
564 }
565}
566
567impl<'de> Deserialize<'de> for SqlIdentifier {
568 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
569 let s = String::deserialize(d)?;
570 SqlIdentifier::try_from(s.as_str()).map_err(serde::de::Error::custom)
571 }
572}
573
574impl fmt::Display for SqlIdentifier {
575 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
576 f.write_str(&self.0)
577 }
578}
579
580#[derive(Clone, Debug, PartialEq, Eq, Hash)]
597pub struct LdapSafeString(String);
598
599impl LdapSafeString {
600 #[must_use]
602 pub fn as_inner(&self) -> &str {
603 &self.0
604 }
605
606 #[must_use]
608 pub fn into_inner(self) -> String {
609 self.0
610 }
611}
612
613impl TryFrom<&str> for LdapSafeString {
614 type Error = BoundaryRejection;
615
616 fn try_from(s: &str) -> Result<Self, Self::Error> {
617 let mut escaped = String::with_capacity(s.len() * 2);
618 let mut had_special = false;
619
620 for c in s.chars() {
621 match c {
622 '\0' => {
623 escaped.push_str("\\00");
624 had_special = true;
625 }
626 '*' => {
627 escaped.push_str("\\2a");
628 had_special = true;
629 }
630 '(' => {
631 escaped.push_str("\\28");
632 had_special = true;
633 }
634 ')' => {
635 escaped.push_str("\\29");
636 had_special = true;
637 }
638 '\\' => {
639 escaped.push_str("\\5c");
640 had_special = true;
641 }
642 other => escaped.push(other),
643 }
644 }
645
646 if had_special {
647 emit_violation(ViolationKind::SyntaxViolation, "ldap_injection_chars");
648 }
649
650 Ok(Self(escaped))
651 }
652}
653
654impl<'de> Deserialize<'de> for LdapSafeString {
655 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
656 let s = String::deserialize(d)?;
657 LdapSafeString::try_from(s.as_str()).map_err(serde::de::Error::custom)
660 }
661}
662
663impl fmt::Display for LdapSafeString {
664 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
665 f.write_str(&self.0)
666 }
667}