// secure_boundary/safe_types.rs

//! Type-safe input wrappers that reject dangerous values at construction time.
//!
//! Every type validates in both [`TryFrom<&str>`] and [`serde::Deserialize`].
//! Invalid input emits a [`BoundaryViolation`] security event before rejection
//! (`LdapSafeString` is the exception: it never rejects, it escapes special
//! characters and emits the event instead).
//! No `Deref` is implemented — callers must use `as_inner()` / `into_inner()`.

7use crate::{
8    attack_signal::{BoundaryViolation, ViolationKind},
9    error::BoundaryRejection,
10};
11use serde::{Deserialize, Deserializer};
12use std::fmt;
13
14// ── Internal helpers ──────────────────────────────────────────────────────────
15
16fn emit_violation(kind: ViolationKind, code: &'static str) {
17    BoundaryViolation::new(kind, code).emit();
18}
19
// ── SafePath ──────────────────────────────────────────────────────────────────

/// A validated relative file-system path.
///
/// Rejects directory traversal (`../`, `..\`), absolute paths, null bytes,
/// and percent-encoded traversal sequences.
///
/// # Examples
///
/// ```
/// use secure_boundary::safe_types::SafePath;
///
/// let path = SafePath::try_from("uploads/photo.jpg").unwrap();
/// assert_eq!(path.as_inner(), "uploads/photo.jpg");
///
/// // Traversal attempts are rejected.
/// assert!(SafePath::try_from("../../etc/passwd").is_err());
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafePath(String);

impl SafePath {
    /// Borrows the validated path as a string slice.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }

    /// Unwraps the validated path and hands ownership of the string to the caller.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(path) = self;
        path
    }
}

55impl TryFrom<&str> for SafePath {
56    type Error = BoundaryRejection;
57
58    fn try_from(s: &str) -> Result<Self, Self::Error> {
59        if s.contains('\0') {
60            emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
61            return Err(BoundaryRejection::PathTraversal);
62        }
63        if s.starts_with('/') || s.starts_with('\\') {
64            emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
65            return Err(BoundaryRejection::PathTraversal);
66        }
67        if s.contains("../")
68            || s.contains("..\\")
69            || s == ".."
70            || s.ends_with("/..")
71            || s.ends_with("\\..")
72        {
73            emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
74            return Err(BoundaryRejection::PathTraversal);
75        }
76        // Reject percent-encoded traversal sequences
77        let lower = s.to_lowercase();
78        if lower.contains("%2e%2e")
79            || lower.contains("%2f")
80            || lower.contains("%5c")
81            || lower.contains("..%2f")
82            || lower.contains("..%5c")
83        {
84            emit_violation(ViolationKind::SyntaxViolation, "path_traversal");
85            return Err(BoundaryRejection::PathTraversal);
86        }
87        Ok(Self(s.to_owned()))
88    }
89}
90
91impl<'de> Deserialize<'de> for SafePath {
92    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
93        let s = String::deserialize(d)?;
94        SafePath::try_from(s.as_str()).map_err(serde::de::Error::custom)
95    }
96}
97
98impl fmt::Display for SafePath {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        f.write_str(&self.0)
101    }
102}
103
// ── SafeFilename ──────────────────────────────────────────────────────────────

/// A validated filename (no path separators, shell metacharacters, or traversal).
///
/// # Examples
///
/// ```
/// use secure_boundary::safe_types::SafeFilename;
///
/// let name = SafeFilename::try_from("report.pdf").unwrap();
/// assert_eq!(name.as_inner(), "report.pdf");
///
/// assert!(SafeFilename::try_from("../evil").is_err());
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeFilename(String);

impl SafeFilename {
    /// Borrows the validated filename.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        &*self.0
    }

    /// Moves the validated filename out of the wrapper.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(name) = self;
        name
    }
}

135impl TryFrom<&str> for SafeFilename {
136    type Error = BoundaryRejection;
137
138    fn try_from(s: &str) -> Result<Self, Self::Error> {
139        let reject = || {
140            emit_violation(ViolationKind::SyntaxViolation, "invalid_filename");
141            BoundaryRejection::InjectionAttempt {
142                code: "invalid_filename",
143            }
144        };
145        if s.is_empty() {
146            return Err(reject());
147        }
148        if s.contains('\0') {
149            return Err(reject());
150        }
151        if s.contains('/') || s.contains('\\') {
152            return Err(reject());
153        }
154        if s == ".." || s.starts_with("../") || s.starts_with("..\\") {
155            return Err(reject());
156        }
157        // Shell metacharacters
158        if s.chars()
159            .any(|c| matches!(c, ';' | '|' | '&' | '`' | '$' | '>' | '<'))
160        {
161            return Err(reject());
162        }
163        Ok(Self(s.to_owned()))
164    }
165}
166
167impl<'de> Deserialize<'de> for SafeFilename {
168    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
169        let s = String::deserialize(d)?;
170        SafeFilename::try_from(s.as_str()).map_err(serde::de::Error::custom)
171    }
172}
173
174impl fmt::Display for SafeFilename {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        f.write_str(&self.0)
177    }
178}
179
// ── SafeCommandArg ────────────────────────────────────────────────────────────

/// A validated command-line argument with shell injection characters rejected.
///
/// # Examples
///
/// ```
/// use secure_boundary::safe_types::SafeCommandArg;
///
/// let arg = SafeCommandArg::try_from("backup-2024").unwrap();
/// assert_eq!(arg.as_inner(), "backup-2024");
///
/// assert!(SafeCommandArg::try_from("file; rm -rf /").is_err());
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeCommandArg(String);

impl SafeCommandArg {
    /// Borrows the validated argument.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }

    /// Moves the validated argument out of the wrapper.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(arg) = self;
        arg
    }
}

211impl TryFrom<&str> for SafeCommandArg {
212    type Error = BoundaryRejection;
213
214    fn try_from(s: &str) -> Result<Self, Self::Error> {
215        let reject = || {
216            emit_violation(ViolationKind::SyntaxViolation, "command_injection");
217            BoundaryRejection::InjectionAttempt {
218                code: "command_injection",
219            }
220        };
221        if s.chars()
222            .any(|c| matches!(c, ';' | '|' | '&' | '`' | '$' | '>' | '<' | '\n' | '\r'))
223        {
224            return Err(reject());
225        }
226        Ok(Self(s.to_owned()))
227    }
228}
229
230impl<'de> Deserialize<'de> for SafeCommandArg {
231    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
232        let s = String::deserialize(d)?;
233        SafeCommandArg::try_from(s.as_str()).map_err(serde::de::Error::custom)
234    }
235}
236
237impl fmt::Display for SafeCommandArg {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        f.write_str(&self.0)
240    }
241}
242
// ── SafeUrl ───────────────────────────────────────────────────────────────────

/// A validated URL that rejects dangerous schemes and URLs resolving to
/// network ranges that enable server-side request forgery.
///
/// # Allowed schemes
/// Only `http` and `https`. Rejects `file://`, `gopher://`, `javascript:`,
/// `data:`, and any other non-http(s) scheme.
///
/// # Blocked host ranges (SSRF prevention)
///
/// Every URL whose host string parses as one of the following IP families
/// is rejected with [`BoundaryRejection::SsrfAttempt`]:
///
/// | CIDR | What it is | Why it's blocked |
/// |---|---|---|
/// | `10.0.0.0/8` | RFC 1918 private | Classic LAN SSRF |
/// | `172.16.0.0/12` | RFC 1918 private | Classic LAN SSRF |
/// | `192.168.0.0/16` | RFC 1918 private | Classic LAN SSRF |
/// | `169.254.0.0/16` | IPv4 link-local | AWS IMDS (`169.254.169.254`) — credential exfiltration |
/// | `127.0.0.0/8` | IPv4 loopback | Bypass to localhost services |
/// | `224.0.0.0/4` | IPv4 multicast | Lateral-movement response surface |
/// | `0.0.0.0/32` | IPv4 unspecified | Stack-internal vulnerabilities |
/// | `fc00::/7` | IPv6 Unique Local Address | Analogue of RFC 1918 on IPv6 |
/// | `fe80::/10` | IPv6 link-local | IPv6 analogue of IMDS attack vector |
/// | `::1/128` | IPv6 loopback | Bypass to localhost services on IPv6 |
/// | `ff00::/8` | IPv6 multicast | Same as IPv4 multicast, on IPv6 |
/// | `::/128` | IPv6 unspecified | Stack-internal vulnerabilities |
///
/// The blocked set is variant-analysis-tested — each CIDR has a named
/// regression test in `sg_gate_a_safeurl_cidrs.rs`, so removing a single
/// line from the internal classifier fails a specific, named test.
///
/// DNS rebinding is **not** prevented by `SafeUrl` alone: validation only
/// inspects the host *string*, and a hostname may still resolve to one of the
/// ranges above. If you resolve and connect, re-check the resolved address
/// against the same ranges, or pin to a specific resolver policy. The internal
/// classifier (`is_private_ip`) is private to this module, so callers must
/// apply an equivalent check themselves.
///
/// # Examples
///
/// ```
/// use secure_boundary::safe_types::SafeUrl;
///
/// // Public URL — accepted.
/// let url = SafeUrl::try_from("https://example.com/api").unwrap();
/// assert_eq!(url.as_inner(), "https://example.com/api");
///
/// // Loopback — rejected.
/// assert!(SafeUrl::try_from("http://127.0.0.1/admin").is_err());
///
/// // AWS IMDS — rejected.
/// assert!(SafeUrl::try_from("http://169.254.169.254/latest/meta-data").is_err());
///
/// // IPv6 link-local — rejected.
/// assert!(SafeUrl::try_from("http://[fe80::1]/").is_err());
///
/// // Dangerous scheme — rejected.
/// assert!(SafeUrl::try_from("javascript:alert(1)").is_err());
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeUrl(String);

impl SafeUrl {
    /// Returns a reference to the inner URL string.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        &self.0
    }

    /// Consumes the wrapper and returns the inner URL string.
    #[must_use]
    pub fn into_inner(self) -> String {
        self.0
    }
}

319impl TryFrom<&str> for SafeUrl {
320    type Error = BoundaryRejection;
321
322    fn try_from(s: &str) -> Result<Self, Self::Error> {
323        let reject = || {
324            emit_violation(ViolationKind::SyntaxViolation, "ssrf_attempt");
325            BoundaryRejection::SsrfAttempt
326        };
327
328        let lower = s.to_lowercase();
329
330        let is_http = lower.starts_with("http://");
331        let is_https = lower.starts_with("https://");
332
333        if !is_http && !is_https {
334            return Err(reject());
335        }
336
337        let prefix_len = if is_https {
338            "https://".len()
339        } else {
340            "http://".len()
341        };
342        let rest = &s[prefix_len..];
343        let host_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
344        let authority = &rest[..host_end];
345        let host_with_port = authority
346            .rsplit_once('@')
347            .map_or(authority, |(_, host_port)| host_port);
348
349        // Strip IPv6 brackets or port suffix
350        let host = if host_with_port.starts_with('[') {
351            // IPv6: [::1] or [::1]:8080
352            let bracket_end = host_with_port
353                .find(']')
354                .map(|i| i + 1)
355                .unwrap_or(host_with_port.len());
356            &host_with_port[..bracket_end]
357        } else {
358            // IPv4 or hostname: strip port
359            match host_with_port.rfind(':') {
360                Some(pos)
361                    if host_with_port[pos + 1..]
362                        .chars()
363                        .all(|c| c.is_ascii_digit()) =>
364                {
365                    &host_with_port[..pos]
366                }
367                _ => host_with_port,
368            }
369        };
370
371        if is_private_ip(host) {
372            return Err(reject());
373        }
374
375        Ok(Self(s.to_owned()))
376    }
377}
378
379impl<'de> Deserialize<'de> for SafeUrl {
380    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
381        let s = String::deserialize(d)?;
382        SafeUrl::try_from(s.as_str()).map_err(serde::de::Error::custom)
383    }
384}
385
386impl fmt::Display for SafeUrl {
387    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388        f.write_str(&self.0)
389    }
390}
391
/// Returns `true` when `host` (a URL host string, with or without IPv6
/// brackets) designates an address that `SafeUrl` refuses.
///
/// The host is interpreted the way WHATWG URL parsers (browsers, the `url`
/// crate) interpret it before connecting, so alternative spellings of a
/// blocked address are caught too:
/// - percent-escapes are decoded once (`127.0.0.%31`);
/// - fullwidth forms, circled Latin letters and ideographic full stops are
///   folded to ASCII, and ASCII letters are lowercased (`１２７。０．０．１`);
/// - IPv4 accepts 1–4 dot-separated parts in decimal, octal (`0177`) or hex
///   (`0x7f`), with a single trailing dot allowed (`2130706433`, `127.1`,
///   `127.0.0.1.`);
/// - IPv4-mapped IPv6 addresses are judged by their IPv4 part
///   (`[::ffff:127.0.0.1]`);
/// - `localhost` and `*.localhost` count as loopback (RFC 6761).
///
/// A host containing any non-ASCII numeral (e.g. `①`) is reported as private.
/// IDNA may map such numerals to ASCII digits, and this function does not
/// implement the full IDNA mapping, so it fails closed instead.
fn is_private_ip(host: &str) -> bool {
    // Strip IPv6 brackets for parsing
    let host = host.trim_matches(|c: char| c == '[' || c == ']');
    let host = normalize_host(host);

    if host.chars().any(|c| !c.is_ascii() && c.is_numeric()) {
        return true;
    }
    if let Ok(addr) = host.parse::<std::net::Ipv6Addr>() {
        return is_private_ipv6(addr);
    }
    let name = host.strip_suffix('.').unwrap_or(host.as_str());
    if name == "localhost" || name.ends_with(".localhost") {
        return true;
    }
    parse_whatwg_ipv4(&host).map_or(false, is_private_ipv4)
}

/// Percent-decodes `host` once, then folds fullwidth ASCII variants
/// (U+FF01–U+FF5E), circled Latin letters (Ⓐ–Ⓩ, ⓐ–ⓩ) and the ideographic full
/// stops U+3002 / U+FF61 to ASCII, and lowercases ASCII letters.
fn normalize_host(host: &str) -> String {
    decode_host_escapes(host)
        .chars()
        .map(|c| {
            let folded = match c {
                '\u{FF01}'..='\u{FF5E}' => char::from_u32(u32::from(c) - 0xFEE0),
                '\u{24B6}'..='\u{24CF}' => char::from_u32(u32::from(c) - 0x24B6 + u32::from('a')),
                '\u{24D0}'..='\u{24E9}' => char::from_u32(u32::from(c) - 0x24D0 + u32::from('a')),
                '\u{3002}' | '\u{FF61}' => Some('.'),
                _ => None,
            };
            folded.unwrap_or(c).to_ascii_lowercase()
        })
        .collect()
}

/// Decodes every well-formed `%XX` escape in `host` once. Malformed escapes
/// are kept verbatim; invalid UTF-8 becomes U+FFFD.
fn decode_host_escapes(host: &str) -> String {
    let bytes = host.as_bytes();
    let hex = |i: usize| bytes.get(i).and_then(|&b| char::from(b).to_digit(16));
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match (bytes[i], hex(i + 1), hex(i + 2)) {
            (b'%', Some(hi), Some(lo)) => {
                // Both nibbles are < 16, so the combined value always fits in a byte.
                out.push((hi * 16 + lo) as u8);
                i += 3;
            }
            (byte, _, _) => {
                out.push(byte);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Parses `host` with the WHATWG IPv4 rules: 1–4 dot-separated numbers, each
/// decimal, `0x` hex or leading-zero octal. Every number but the last must be
/// ≤ 255, and the last fills the remaining bytes. Returns `None` if `host` is
/// not an IPv4 address under those rules.
fn parse_whatwg_ipv4(host: &str) -> Option<std::net::Ipv4Addr> {
    let host = host.strip_suffix('.').unwrap_or(host);
    let parts: Vec<&str> = host.split('.').collect();
    if parts.len() > 4 {
        return None;
    }
    let numbers = parts
        .iter()
        .map(|part| parse_ipv4_number(part))
        .collect::<Option<Vec<u64>>>()?;
    let (&last, leading) = numbers.split_last()?;
    if leading.iter().any(|&n| n > 255) {
        return None;
    }
    // `numbers.len()` is 1..=4, so the shift is 8..=32 bits.
    if last >= 1u64 << (8 * (5 - numbers.len())) {
        return None;
    }
    let value = leading
        .iter()
        .enumerate()
        .fold(last, |acc, (i, &n)| acc + (n << (8 * (3 - i))));
    u32::try_from(value).ok().map(std::net::Ipv4Addr::from)
}

/// Parses one WHATWG IPv4 part: `0x`/`0X` prefix means hex, a leading `0`
/// (with more digits) means octal, otherwise decimal; `0x` alone is 0.
fn parse_ipv4_number(part: &str) -> Option<u64> {
    if part.is_empty() {
        return None;
    }
    let (digits, radix) = if let Some(hex) = part.strip_prefix("0x").or_else(|| part.strip_prefix("0X")) {
        (hex, 16)
    } else if part.len() > 1 && part.starts_with('0') {
        (&part[1..], 8)
    } else {
        (part, 10)
    };
    if digits.is_empty() {
        return Some(0);
    }
    // `from_str_radix` would also accept a leading `+`; WHATWG does not.
    if !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }
    u64::from_str_radix(digits, radix).ok()
}

/// Returns `true` for IPv4 addresses in the blocked ranges.
fn is_private_ipv4(addr: std::net::Ipv4Addr) -> bool {
    matches!(
        addr.octets(),
        // 127.0.0.0/8 loopback
        [127, ..]
        // 10.0.0.0/8
        | [10, ..]
        // 172.16.0.0/12
        | [172, 16..=31, ..]
        // 192.168.0.0/16
        | [192, 168, ..]
        // 169.254.0.0/16 link-local (AWS IMDS 169.254.169.254 and neighbours)
        | [169, 254, ..]
        // 224.0.0.0/4 IPv4 multicast (lateral-movement response surface)
        | [224..=239, ..]
        // 0.0.0.0/32 unspecified
        | [0, 0, 0, 0]
    )
}

/// Returns `true` for IPv6 addresses in the blocked ranges. IPv4-mapped
/// addresses (`::ffff:a.b.c.d`) are judged by their embedded IPv4 address,
/// because dual-stack sockets deliver them to that IPv4 host.
fn is_private_ipv6(addr: std::net::Ipv6Addr) -> bool {
    let seg = addr.segments();
    if let [0, 0, 0, 0, 0, 0xffff, hi, lo] = seg {
        let v4 = std::net::Ipv4Addr::from((u32::from(hi) << 16) | u32::from(lo));
        return is_private_ipv4(v4);
    }
    // ::1/128 loopback
    addr.is_loopback()
    // ::/128 unspecified
    || addr.is_unspecified()
    // fc00::/7 unique local
    || (seg[0] & 0xfe00) == 0xfc00
    // fe80::/10 link-local (IPv6 analogue of the IMDS attack vector)
    || (seg[0] & 0xffc0) == 0xfe80
    // ff00::/8 multicast
    || (seg[0] & 0xff00) == 0xff00
}

// ── SafeRedirectUrl ───────────────────────────────────────────────────────────

/// A validated redirect URL that only allows relative paths (no open redirect).
///
/// Must start with `/` but not `//`, and must not contain a scheme colon.
///
/// # Examples
///
/// ```
/// use secure_boundary::safe_types::SafeRedirectUrl;
///
/// let url = SafeRedirectUrl::try_from("/dashboard").unwrap();
/// assert_eq!(url.as_inner(), "/dashboard");
///
/// // External URLs are rejected (open redirect prevention).
/// assert!(SafeRedirectUrl::try_from("//evil.com").is_err());
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SafeRedirectUrl(String);

impl SafeRedirectUrl {
    /// Borrows the validated redirect target.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }

    /// Moves the validated redirect target out of the wrapper.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(target) = self;
        target
    }
}

470impl TryFrom<&str> for SafeRedirectUrl {
471    type Error = BoundaryRejection;
472
473    fn try_from(s: &str) -> Result<Self, Self::Error> {
474        let reject = || {
475            emit_violation(ViolationKind::SyntaxViolation, "invalid_redirect");
476            BoundaryRejection::InjectionAttempt {
477                code: "invalid_redirect",
478            }
479        };
480        // Must be a relative path starting with /
481        if !s.starts_with('/') || s.starts_with("//") {
482            return Err(reject());
483        }
484        // No scheme separator allowed
485        if s.contains(':') {
486            return Err(reject());
487        }
488        Ok(Self(s.to_owned()))
489    }
490}
491
492impl<'de> Deserialize<'de> for SafeRedirectUrl {
493    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
494        let s = String::deserialize(d)?;
495        SafeRedirectUrl::try_from(s.as_str()).map_err(serde::de::Error::custom)
496    }
497}
498
499impl fmt::Display for SafeRedirectUrl {
500    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501        f.write_str(&self.0)
502    }
503}
504
// ── SqlIdentifier ─────────────────────────────────────────────────────────────

/// A validated SQL identifier (alphanumeric + underscore, max 128 chars).
///
/// Rejects anything that is not a valid SQL identifier: must start with a letter
/// or underscore, must contain only `[A-Za-z0-9_]`, maximum 128 characters.
///
/// # Examples
///
/// ```
/// use secure_boundary::safe_types::SqlIdentifier;
///
/// let id = SqlIdentifier::try_from("users_table").unwrap();
/// assert_eq!(id.as_inner(), "users_table");
///
/// assert!(SqlIdentifier::try_from("DROP TABLE;").is_err());
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SqlIdentifier(String);

impl SqlIdentifier {
    /// Borrows the validated identifier.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        &*self.0
    }

    /// Moves the validated identifier out of the wrapper.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(ident) = self;
        ident
    }
}

539impl TryFrom<&str> for SqlIdentifier {
540    type Error = BoundaryRejection;
541
542    fn try_from(s: &str) -> Result<Self, Self::Error> {
543        let reject = || {
544            emit_violation(ViolationKind::SyntaxViolation, "invalid_sql_identifier");
545            BoundaryRejection::InjectionAttempt {
546                code: "invalid_sql_identifier",
547            }
548        };
549        if s.is_empty() {
550            return Err(reject());
551        }
552        if s.len() > 128 {
553            return Err(reject());
554        }
555        let mut chars = s.chars();
556        let first = chars.next().expect("non-empty string has a first char");
557        if !first.is_ascii_alphabetic() && first != '_' {
558            return Err(reject());
559        }
560        if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_') {
561            return Err(reject());
562        }
563        Ok(Self(s.to_owned()))
564    }
565}
566
567impl<'de> Deserialize<'de> for SqlIdentifier {
568    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
569        let s = String::deserialize(d)?;
570        SqlIdentifier::try_from(s.as_str()).map_err(serde::de::Error::custom)
571    }
572}
573
574impl fmt::Display for SqlIdentifier {
575    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
576        f.write_str(&self.0)
577    }
578}
579
// ── LdapSafeString ────────────────────────────────────────────────────────────

/// An LDAP-safe string with RFC 4515 escaping applied to special characters.
///
/// Special characters (`*`, `(`, `)`, `\`, NUL) are escaped to their `\xx`
/// hex form. The construction always succeeds; a [`BoundaryViolation`] event is
/// emitted when escaping was necessary (potential injection signal).
///
/// # Examples
///
/// ```
/// use secure_boundary::safe_types::LdapSafeString;
///
/// let s = LdapSafeString::try_from("user*(admin)").unwrap();
/// assert_eq!(s.as_inner(), "user\\2a\\28admin\\29");
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LdapSafeString(String);

impl LdapSafeString {
    /// Borrows the RFC 4515-escaped text.
    #[must_use]
    pub fn as_inner(&self) -> &str {
        self.0.as_str()
    }

    /// Moves the RFC 4515-escaped text out of the wrapper.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(escaped) = self;
        escaped
    }
}

613impl TryFrom<&str> for LdapSafeString {
614    type Error = BoundaryRejection;
615
616    fn try_from(s: &str) -> Result<Self, Self::Error> {
617        let mut escaped = String::with_capacity(s.len() * 2);
618        let mut had_special = false;
619
620        for c in s.chars() {
621            match c {
622                '\0' => {
623                    escaped.push_str("\\00");
624                    had_special = true;
625                }
626                '*' => {
627                    escaped.push_str("\\2a");
628                    had_special = true;
629                }
630                '(' => {
631                    escaped.push_str("\\28");
632                    had_special = true;
633                }
634                ')' => {
635                    escaped.push_str("\\29");
636                    had_special = true;
637                }
638                '\\' => {
639                    escaped.push_str("\\5c");
640                    had_special = true;
641                }
642                other => escaped.push(other),
643            }
644        }
645
646        if had_special {
647            emit_violation(ViolationKind::SyntaxViolation, "ldap_injection_chars");
648        }
649
650        Ok(Self(escaped))
651    }
652}
653
654impl<'de> Deserialize<'de> for LdapSafeString {
655    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
656        let s = String::deserialize(d)?;
657        // TryFrom never returns Err for LdapSafeString, but we still use the
658        // unified pattern for consistency.
659        LdapSafeString::try_from(s.as_str()).map_err(serde::de::Error::custom)
660    }
661}
662
663impl fmt::Display for LdapSafeString {
664    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
665        f.write_str(&self.0)
666    }
667}