Skip to main content

use_security_header/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3#![allow(clippy::module_name_repetitions)]
4
5use core::{fmt, str::FromStr};
6use std::error::Error;
7
/// Error returned when a security header name is invalid.
///
/// Header-name validation checks its rules in a fixed order (emptiness,
/// ASCII-only, token characters); the variant names the first rule broken.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SecurityHeaderNameError {
    /// The name was empty.
    Empty,
    /// The name contained at least one non-ASCII character.
    NonAscii,
    /// The name contained an ASCII byte that is not an HTTP token character
    /// (for example a space or `:`).
    InvalidCharacter,
}

impl fmt::Display for SecurityHeaderNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Empty => "security header name cannot be empty",
            Self::NonAscii => "security header name must be ASCII",
            Self::InvalidCharacter => "security header name contains an invalid character",
        };
        f.write_str(message)
    }
}

impl Error for SecurityHeaderNameError {}
29
/// Error returned when a security header label cannot be parsed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SecurityHeaderParseError {
    /// The input was empty once surrounding whitespace was removed.
    Empty,
    /// The input did not match any known label.
    Unknown,
}

impl fmt::Display for SecurityHeaderParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Empty => "security header label cannot be empty",
            Self::Unknown => "unknown security header label",
        })
    }
}

impl Error for SecurityHeaderParseError {}
47
48/// A validated HTTP security header name.
49#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
50pub struct SecurityHeaderName(String);
51
52impl SecurityHeaderName {
53    /// Creates a security header name from an HTTP token-shaped string.
54    pub fn new(input: impl AsRef<str>) -> Result<Self, SecurityHeaderNameError> {
55        let trimmed = input.as_ref().trim();
56        validate_header_name(trimmed)?;
57        Ok(Self(trimmed.to_owned()))
58    }
59
60    /// Returns the stored header name.
61    #[must_use]
62    pub fn as_str(&self) -> &str {
63        &self.0
64    }
65}
66
67impl fmt::Display for SecurityHeaderName {
68    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
69        formatter.write_str(self.as_str())
70    }
71}
72
73impl FromStr for SecurityHeaderName {
74    type Err = SecurityHeaderNameError;
75
76    fn from_str(input: &str) -> Result<Self, Self::Err> {
77        Self::new(input)
78    }
79}
80
81impl TryFrom<&str> for SecurityHeaderName {
82    type Error = SecurityHeaderNameError;
83
84    fn try_from(value: &str) -> Result<Self, Self::Error> {
85        Self::new(value)
86    }
87}
88
/// Implements the shared label API for a fieldless enum.
///
/// For each `Variant => "label"` pair this generates:
/// - an inherent `as_str` returning the label,
/// - `Display`, which writes exactly that label,
/// - `FromStr` (error type `SecurityHeaderParseError`), which trims the input,
///   reports `Empty` for blank input, and otherwise matches the labels while
///   ignoring ASCII case, reporting `Unknown` when nothing matches.
///
/// Labels are expected to be distinct; if two matched, the first listed wins.
macro_rules! label_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        impl $name {
            /// Returns the stable label.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label,)+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = SecurityHeaderParseError;

            fn from_str(input: &str) -> Result<Self, Self::Err> {
                let trimmed = input.trim();
                if trimmed.is_empty() {
                    return Err(SecurityHeaderParseError::Empty);
                }
                // Compare in place instead of allocating a lowercased copy of
                // the input. This also keeps a label that contains uppercase
                // letters reachable; lowercasing only the input could never
                // match such a label.
                $(
                    if trimmed.eq_ignore_ascii_case($label) {
                        return Ok(Self::$variant);
                    }
                )+
                Err(SecurityHeaderParseError::Unknown)
            }
        }
    };
}
124
/// Security header categories.
///
/// Each variant corresponds to one HTTP header; see
/// [`SecurityHeaderKind::header_name`] for its canonical spelling.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so reordering the
/// variants changes how kinds compare and sort.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SecurityHeaderKind {
    /// `Content-Security-Policy`.
    ContentSecurityPolicy,
    /// `Strict-Transport-Security`.
    StrictTransportSecurity,
    /// `X-Content-Type-Options`.
    XContentTypeOptions,
    /// `X-Frame-Options`.
    XFrameOptions,
    /// `Referrer-Policy`.
    ReferrerPolicy,
    /// `Permissions-Policy`.
    PermissionsPolicy,
    /// `Cross-Origin-Opener-Policy`.
    CrossOriginOpenerPolicy,
    /// `Cross-Origin-Resource-Policy`.
    CrossOriginResourcePolicy,
    /// `Cross-Origin-Embedder-Policy`.
    CrossOriginEmbedderPolicy,
    /// `Cache-Control`.
    CacheControl,
}
139
140impl SecurityHeaderKind {
141    /// Returns the canonical HTTP header name.
142    #[must_use]
143    pub const fn header_name(self) -> &'static str {
144        match self {
145            Self::ContentSecurityPolicy => "Content-Security-Policy",
146            Self::StrictTransportSecurity => "Strict-Transport-Security",
147            Self::XContentTypeOptions => "X-Content-Type-Options",
148            Self::XFrameOptions => "X-Frame-Options",
149            Self::ReferrerPolicy => "Referrer-Policy",
150            Self::PermissionsPolicy => "Permissions-Policy",
151            Self::CrossOriginOpenerPolicy => "Cross-Origin-Opener-Policy",
152            Self::CrossOriginResourcePolicy => "Cross-Origin-Resource-Policy",
153            Self::CrossOriginEmbedderPolicy => "Cross-Origin-Embedder-Policy",
154            Self::CacheControl => "Cache-Control",
155        }
156    }
157}
158
// Lowercase labels used by `as_str`, `Display` and `FromStr`. Each label is the
// lowercased form of the matching `header_name`, so `FromStr` (which trims the
// input and ignores ASCII case) also accepts canonical header names such as
// `Content-Security-Policy`.
label_enum!(SecurityHeaderKind {
    ContentSecurityPolicy => "content-security-policy",
    StrictTransportSecurity => "strict-transport-security",
    XContentTypeOptions => "x-content-type-options",
    XFrameOptions => "x-frame-options",
    ReferrerPolicy => "referrer-policy",
    PermissionsPolicy => "permissions-policy",
    CrossOriginOpenerPolicy => "cross-origin-opener-policy",
    CrossOriginResourcePolicy => "cross-origin-resource-policy",
    CrossOriginEmbedderPolicy => "cross-origin-embedder-policy",
    CacheControl => "cache-control",
});
171
/// Content Security Policy directive labels.
///
/// The derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ContentSecurityPolicyDirective {
    /// `default-src`.
    DefaultSrc,
    /// `script-src`.
    ScriptSrc,
    /// `style-src`.
    StyleSrc,
    /// `img-src`.
    ImgSrc,
    /// `connect-src`.
    ConnectSrc,
    /// `frame-ancestors`.
    FrameAncestors,
    /// `base-uri`.
    BaseUri,
    /// `form-action`.
    FormAction,
    /// `upgrade-insecure-requests`.
    UpgradeInsecureRequests,
    /// The literal label `other`.
    ///
    /// `FromStr` yields this variant only for the input `other`; an
    /// unrecognised directive name is rejected with
    /// [`SecurityHeaderParseError::Unknown`] instead.
    Other,
}

// `FromStr` trims the input and matches these labels ignoring ASCII case.
label_enum!(ContentSecurityPolicyDirective {
    DefaultSrc => "default-src",
    ScriptSrc => "script-src",
    StyleSrc => "style-src",
    ImgSrc => "img-src",
    ConnectSrc => "connect-src",
    FrameAncestors => "frame-ancestors",
    BaseUri => "base-uri",
    FormAction => "form-action",
    UpgradeInsecureRequests => "upgrade-insecure-requests",
    Other => "other",
});
199
/// Referrer policy labels.
///
/// The derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ReferrerPolicyKind {
    /// `no-referrer`.
    NoReferrer,
    /// `no-referrer-when-downgrade`.
    NoReferrerWhenDowngrade,
    /// `origin`.
    Origin,
    /// `origin-when-cross-origin`.
    OriginWhenCrossOrigin,
    /// `same-origin`.
    SameOrigin,
    /// `strict-origin`.
    StrictOrigin,
    /// `strict-origin-when-cross-origin`.
    StrictOriginWhenCrossOrigin,
    /// `unsafe-url`.
    UnsafeUrl,
}

// `FromStr` trims the input and matches these labels ignoring ASCII case.
label_enum!(ReferrerPolicyKind {
    NoReferrer => "no-referrer",
    NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
    Origin => "origin",
    OriginWhenCrossOrigin => "origin-when-cross-origin",
    SameOrigin => "same-origin",
    StrictOrigin => "strict-origin",
    StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
    UnsafeUrl => "unsafe-url",
});
223
/// X-Frame-Options labels.
///
/// Labels are lowercase (`deny`, `sameorigin`). `FromStr` ignores ASCII case,
/// so `DENY` and `SAMEORIGIN` parse too, but `Display` always writes the
/// lowercase label.
// NOTE(review): confirm the lowercase `Display` output is the spelling wanted
// wherever this is written out as a header value.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum FrameOptionsKind {
    /// `deny`.
    Deny,
    /// `sameorigin`.
    SameOrigin,
}

label_enum!(FrameOptionsKind {
    Deny => "deny",
    SameOrigin => "sameorigin",
});
235
/// Strict-Transport-Security directive labels.
///
/// Labels are all lowercase. `FromStr` ignores ASCII case, so
/// `includeSubDomains` also parses, while `Display` writes `includesubdomains`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum TransportSecurityDirective {
    /// `max-age`.
    MaxAge,
    /// `includesubdomains` (accepted in any ASCII case).
    IncludeSubDomains,
    /// `preload`.
    Preload,
}

label_enum!(TransportSecurityDirective {
    MaxAge => "max-age",
    IncludeSubDomains => "includesubdomains",
    Preload => "preload",
});
249
/// CORS policy labels.
///
/// The derived `PartialOrd`/`Ord` follow declaration order.
// NOTE(review): these labels (`deny-all`, `allow-list`, ...) look like this
// crate's own policy names rather than literal header values; confirm how
// callers map them onto actual CORS headers.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum CorsPolicyKind {
    /// Label `deny-all`.
    DenyAll,
    /// Label `same-origin`.
    SameOrigin,
    /// Label `allow-list`.
    AllowList,
    /// Label `allow-all`.
    AllowAll,
}

// `FromStr` trims the input and matches these labels ignoring ASCII case.
label_enum!(CorsPolicyKind {
    DenyAll => "deny-all",
    SameOrigin => "same-origin",
    AllowList => "allow-list",
    AllowAll => "allow-all",
});
265
/// Permissions policy directive labels.
///
/// The derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum PermissionsPolicyDirective {
    /// `geolocation`.
    Geolocation,
    /// `camera`.
    Camera,
    /// `microphone`.
    Microphone,
    /// `payment`.
    Payment,
    /// `usb`.
    Usb,
    /// `fullscreen`.
    Fullscreen,
    /// The literal label `other`.
    ///
    /// `FromStr` yields this variant only for the input `other`; an
    /// unrecognised directive name is rejected with
    /// [`SecurityHeaderParseError::Unknown`] instead.
    Other,
}

// `FromStr` trims the input and matches these labels ignoring ASCII case.
label_enum!(PermissionsPolicyDirective {
    Geolocation => "geolocation",
    Camera => "camera",
    Microphone => "microphone",
    Payment => "payment",
    Usb => "usb",
    Fullscreen => "fullscreen",
    Other => "other",
});
287
288fn validate_header_name(value: &str) -> Result<(), SecurityHeaderNameError> {
289    if value.is_empty() {
290        return Err(SecurityHeaderNameError::Empty);
291    }
292    if !value.is_ascii() {
293        return Err(SecurityHeaderNameError::NonAscii);
294    }
295    if value.bytes().all(is_token_byte) {
296        Ok(())
297    } else {
298        Err(SecurityHeaderNameError::InvalidCharacter)
299    }
300}
301
/// Reports whether `byte` is an HTTP token character (`tchar`).
///
/// Accepts ASCII digits and letters plus the punctuation
/// `` !#$%&'*+-.^_`|~ ``. Everything else is rejected, including space, `:`,
/// control characters and all non-ASCII bytes.
const fn is_token_byte(byte: u8) -> bool {
    matches!(
        byte,
        b'0'..=b'9'
            | b'A'..=b'Z'
            | b'a'..=b'z'
            | b'!'
            | b'#'
            | b'$'
            | b'%'
            | b'&'
            | b'\''
            | b'*'
            | b'+'
            | b'-'
            | b'.'
            | b'^'
            | b'_'
            | b'`'
            | b'|'
            | b'~'
    )
}
322
#[cfg(test)]
mod tests {
    use super::{
        ContentSecurityPolicyDirective, FrameOptionsKind, ReferrerPolicyKind, SecurityHeaderKind,
        SecurityHeaderName, SecurityHeaderNameError, SecurityHeaderParseError,
        TransportSecurityDirective,
    };

    /// Every `SecurityHeaderKind` variant, for round-trip checks.
    const ALL_KINDS: [SecurityHeaderKind; 10] = [
        SecurityHeaderKind::ContentSecurityPolicy,
        SecurityHeaderKind::StrictTransportSecurity,
        SecurityHeaderKind::XContentTypeOptions,
        SecurityHeaderKind::XFrameOptions,
        SecurityHeaderKind::ReferrerPolicy,
        SecurityHeaderKind::PermissionsPolicy,
        SecurityHeaderKind::CrossOriginOpenerPolicy,
        SecurityHeaderKind::CrossOriginResourcePolicy,
        SecurityHeaderKind::CrossOriginEmbedderPolicy,
        SecurityHeaderKind::CacheControl,
    ];

    #[test]
    fn validates_header_names() {
        let name = SecurityHeaderName::new("Content-Security-Policy").expect("header name");

        assert_eq!(name.as_str(), "Content-Security-Policy");
        assert_eq!(
            SecurityHeaderName::new(" "),
            Err(SecurityHeaderNameError::Empty)
        );
        assert_eq!(
            SecurityHeaderName::new("Bad Header"),
            Err(SecurityHeaderNameError::InvalidCharacter)
        );
    }

    #[test]
    fn header_names_are_trimmed_and_reject_non_ascii() {
        assert_eq!(
            SecurityHeaderName::new("  X-Frame-Options\t").map(|name| name.to_string()),
            Ok("X-Frame-Options".to_owned())
        );
        assert_eq!(
            SecurityHeaderName::new("Caf\u{e9}"),
            Err(SecurityHeaderNameError::NonAscii)
        );
        assert_eq!(
            SecurityHeaderName::new("Header:"),
            Err(SecurityHeaderNameError::InvalidCharacter)
        );
    }

    #[test]
    fn header_name_conversions_agree() {
        let from_new = SecurityHeaderName::new("Referrer-Policy");

        assert_eq!("Referrer-Policy".parse::<SecurityHeaderName>(), from_new);
        assert_eq!(SecurityHeaderName::try_from("Referrer-Policy"), from_new);
    }

    #[test]
    fn parses_and_displays_labels() {
        assert_eq!(
            "script-src"
                .parse::<ContentSecurityPolicyDirective>()
                .expect("directive"),
            ContentSecurityPolicyDirective::ScriptSrc
        );
        assert_eq!(
            ReferrerPolicyKind::StrictOriginWhenCrossOrigin.to_string(),
            "strict-origin-when-cross-origin"
        );
    }

    #[test]
    fn label_parsing_is_trimmed_and_case_insensitive() {
        assert_eq!(
            " SAMEORIGIN ".parse::<FrameOptionsKind>(),
            Ok(FrameOptionsKind::SameOrigin)
        );
        assert_eq!(
            "includeSubDomains".parse::<TransportSecurityDirective>(),
            Ok(TransportSecurityDirective::IncludeSubDomains)
        );
        assert_eq!(
            TransportSecurityDirective::IncludeSubDomains.to_string(),
            "includesubdomains"
        );
    }

    #[test]
    fn label_parsing_reports_empty_and_unknown() {
        assert_eq!(
            "   ".parse::<ReferrerPolicyKind>(),
            Err(SecurityHeaderParseError::Empty)
        );
        assert_eq!(
            "no-such-policy".parse::<ReferrerPolicyKind>(),
            Err(SecurityHeaderParseError::Unknown)
        );
    }

    #[test]
    fn exposes_canonical_header_name() {
        assert_eq!(
            SecurityHeaderKind::StrictTransportSecurity.header_name(),
            "Strict-Transport-Security"
        );
    }

    #[test]
    fn canonical_header_names_round_trip_through_labels() {
        for kind in ALL_KINDS {
            let header = kind.header_name();
            assert_eq!(header.parse::<SecurityHeaderKind>(), Ok(kind));
            assert!(header.eq_ignore_ascii_case(kind.as_str()));
            assert!(SecurityHeaderName::new(header).is_ok());
        }
    }
}