Skip to main content

s2_common/
encryption.rs

1//! Encryption spec parsing, header parsing, and key parsing.
2
3use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::Encoding;
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, zeroize::Zeroizing};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
/// Name of the HTTP header that carries an [`EncryptionSpec`] (`s2-encryption`).
pub static S2_ENCRYPTION_HEADER: HeaderName = HeaderName::from_static("s2-encryption");

/// An `N`-byte key wrapped in a [`SecretBox`] behind an [`Arc`], so clones share one copy of
/// the key bytes instead of duplicating them.
type EncryptionKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
16
17/// Encryption algorithm.
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, EnumString)]
19#[strum(ascii_case_insensitive)]
20pub enum EncryptionAlgorithm {
21    /// AEGIS-256
22    #[strum(serialize = "aegis-256")]
23    Aegis256,
24    /// AES-256-GCM
25    #[strum(serialize = "aes-256-gcm")]
26    Aes256Gcm,
27}
28
/// Encryption mode, including plaintext.
///
/// `Display` and `FromStr` (strum) use the names `plain`, `aegis-256` and `aes-256-gcm`;
/// parsing ignores ASCII case.
// NOTE(review): serde's `rename_all = "kebab-case"` derives names from the variant identifiers,
// which would presumably yield `aegis256` / `aes256-gcm` rather than the strum names above; the
// optional clap `ValueEnum` derive also picks its own value names. Confirm the serde/CLI
// spellings are intended to differ from `Display`/`FromStr`.
// NOTE(review): `#[enumset(no_super_impls)]` presumably stops `EnumSetType` from providing
// Clone/Copy/PartialEq/Eq, which is why they are derived explicitly here — confirm.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    serde::Serialize,
    serde::Deserialize,
    Display,
    EnumString,
    enumset::EnumSetType,
)]
#[strum(ascii_case_insensitive)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[enumset(no_super_impls)]
#[serde(rename_all = "kebab-case")]
pub enum EncryptionMode {
    /// No encryption (plaintext).
    #[strum(serialize = "plain")]
    Plain,
    /// AEGIS-256.
    #[strum(serialize = "aegis-256")]
    Aegis256,
    /// AES-256-GCM.
    #[strum(serialize = "aes-256-gcm")]
    Aes256Gcm,
}
55
56impl From<EncryptionAlgorithm> for EncryptionMode {
57    fn from(value: EncryptionAlgorithm) -> Self {
58        match value {
59            EncryptionAlgorithm::Aegis256 => Self::Aegis256,
60            EncryptionAlgorithm::Aes256Gcm => Self::Aes256Gcm,
61        }
62    }
63}
64
65#[derive(Debug, Clone)]
66pub struct Aegis256Key(EncryptionKey<32>);
67
68impl Aegis256Key {
69    pub fn new(key: [u8; 32]) -> Self {
70        Self(Arc::new(SecretBox::new(Box::new(key))))
71    }
72
73    pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
74        parse_encryption_key::<32>(key_b64).map(Self)
75    }
76
77    pub(crate) fn secret(&self) -> &[u8; 32] {
78        self.0.as_ref().expose_secret()
79    }
80}
81
82#[derive(Debug, Clone)]
83pub struct Aes256GcmKey(EncryptionKey<32>);
84
85impl Aes256GcmKey {
86    pub fn new(key: [u8; 32]) -> Self {
87        Self(Arc::new(SecretBox::new(Box::new(key))))
88    }
89
90    pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
91        parse_encryption_key::<32>(key_b64).map(Self)
92    }
93
94    pub(crate) fn secret(&self) -> &[u8; 32] {
95        self.0.as_ref().expose_secret()
96    }
97}
98
/// Errors produced while parsing an encryption spec or decoding its key.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum EncryptionSpecError {
    /// The spec has more than two `;`-separated parts.
    #[error("Invalid encryption spec: expected '<mode>; <key>' or 'plain'")]
    InvalidSyntax,
    /// The mode part of the spec is empty.
    #[error("Invalid encryption spec: missing encryption mode")]
    MissingMode,
    /// The mode part does not name a known [`EncryptionMode`]; `mode` holds the rejected text.
    #[error(
        "Invalid encryption spec: unknown encryption mode {mode:?}; expected 'plain', 'aegis-256', or 'aes-256-gcm'"
    )]
    UnknownMode { mode: String },
    /// A non-empty key was supplied together with the `plain` mode.
    #[error("Invalid encryption spec: key is not allowed when mode is 'plain'")]
    UnexpectedKeyForPlain,
    /// An encrypting mode was given without a (non-empty) key.
    #[error("Invalid encryption spec: missing key for '{mode}'")]
    MissingKey { mode: EncryptionMode },
    /// The base64 decoder rejected the key.
    #[error("Invalid encryption spec: key is not valid base64")]
    InvalidKeyBase64,
    /// The key decoded, but to a different number of bytes than the algorithm requires.
    #[error("Invalid encryption spec: key must be exactly {expected} bytes, got {actual} bytes")]
    InvalidKeyLength { expected: usize, actual: usize },
}
118
/// Parsed encryption configuration: plaintext, or an algorithm together with its key.
///
/// The default is [`EncryptionSpec::Plain`].
#[derive(Debug, Clone, Default)]
pub enum EncryptionSpec {
    /// No encryption.
    #[default]
    Plain,
    /// AEGIS-256 with its 32-byte key.
    Aegis256(Aegis256Key),
    /// AES-256-GCM with its 32-byte key.
    Aes256Gcm(Aes256GcmKey),
}
126
127impl EncryptionSpec {
128    pub fn aegis256(key: [u8; 32]) -> Self {
129        Self::Aegis256(Aegis256Key::new(key))
130    }
131
132    pub fn aes256_gcm(key: [u8; 32]) -> Self {
133        Self::Aes256Gcm(Aes256GcmKey::new(key))
134    }
135
136    pub fn mode(&self) -> EncryptionMode {
137        match self {
138            Self::Plain => EncryptionMode::Plain,
139            Self::Aegis256(_) => EncryptionMode::Aegis256,
140            Self::Aes256Gcm(_) => EncryptionMode::Aes256Gcm,
141        }
142    }
143
144    pub fn to_header_value(&self) -> HeaderValue {
145        let mut value = match self {
146            Self::Plain => HeaderValue::from_static("plain"),
147            Self::Aegis256(key) => {
148                header_value_for_key(EncryptionAlgorithm::Aegis256, key.secret())
149            }
150            Self::Aes256Gcm(key) => {
151                header_value_for_key(EncryptionAlgorithm::Aes256Gcm, key.secret())
152            }
153        };
154        value.set_sensitive(true);
155        value
156    }
157}
158
159impl FromStr for EncryptionSpec {
160    type Err = EncryptionSpecError;
161
162    fn from_str(s: &str) -> Result<Self, Self::Err> {
163        let s = s.trim();
164        let mut parts = s.splitn(3, ';');
165        let mode_str = parts.next().unwrap_or_default().trim();
166        let key_b64 = parts.next().map(str::trim);
167        if parts.next().is_some() {
168            return Err(EncryptionSpecError::InvalidSyntax);
169        }
170
171        if mode_str.is_empty() {
172            return Err(EncryptionSpecError::MissingMode);
173        }
174
175        let key_b64 = key_b64.filter(|key| !key.is_empty());
176        match (parse_mode(mode_str)?, key_b64) {
177            (EncryptionMode::Plain, None) => Ok(Self::Plain),
178            (EncryptionMode::Plain, Some(_)) => Err(EncryptionSpecError::UnexpectedKeyForPlain),
179            (EncryptionMode::Aegis256, Some(key_b64)) => {
180                Ok(Self::Aegis256(Aegis256Key::from_base64(key_b64)?))
181            }
182            (EncryptionMode::Aegis256, None) => Err(EncryptionSpecError::MissingKey {
183                mode: EncryptionMode::Aegis256,
184            }),
185            (EncryptionMode::Aes256Gcm, Some(key_b64)) => {
186                Ok(Self::Aes256Gcm(Aes256GcmKey::from_base64(key_b64)?))
187            }
188            (EncryptionMode::Aes256Gcm, None) => Err(EncryptionSpecError::MissingKey {
189                mode: EncryptionMode::Aes256Gcm,
190            }),
191        }
192    }
193}
194
/// Associates [`EncryptionSpec`] with the [`S2_ENCRYPTION_HEADER`] (`s2-encryption`) header.
// NOTE(review): `ParseableHeader` is defined in `crate::http`; presumably it parses the header
// value through this type's `FromStr` impl — confirm there.
impl ParseableHeader for EncryptionSpec {
    fn name() -> &'static HeaderName {
        &S2_ENCRYPTION_HEADER
    }
}
200
201fn parse_encryption_key<const N: usize>(
202    key_b64: &str,
203) -> Result<EncryptionKey<N>, EncryptionSpecError> {
204    use base64ct::{Base64, Encoding};
205    use secrecy::zeroize::Zeroize;
206
207    let mut key = Box::new([0u8; N]);
208    let decoded = match Base64::decode(key_b64, key.as_mut()) {
209        Ok(decoded) => decoded,
210        Err(_) => {
211            key.as_mut().zeroize();
212            return Err(EncryptionSpecError::InvalidKeyBase64);
213        }
214    };
215
216    if decoded.len() != N {
217        let len = decoded.len();
218        key.as_mut().zeroize();
219        return Err(EncryptionSpecError::InvalidKeyLength {
220            expected: N,
221            actual: len,
222        });
223    }
224
225    Ok(Arc::new(SecretBox::new(key)))
226}
227
228fn header_value_for_key(algorithm: EncryptionAlgorithm, key: &[u8; 32]) -> HeaderValue {
229    let algorithm = algorithm.to_string();
230    let encoded_len = base64ct::Base64::encoded_len(key);
231    let mut value = Zeroizing::new(vec![0u8; algorithm.len() + 2 + encoded_len]);
232    value[..algorithm.len()].copy_from_slice(algorithm.as_bytes());
233    value[algorithm.len()..algorithm.len() + 2].copy_from_slice(b"; ");
234    base64ct::Base64::encode(key, &mut value[algorithm.len() + 2..])
235        .expect("base64 output length should match buffer");
236
237    HeaderValue::from_bytes(&value).expect("encryption header value should be ASCII")
238}
239
240fn parse_mode(mode_str: &str) -> Result<EncryptionMode, EncryptionSpecError> {
241    mode_str
242        .parse::<EncryptionMode>()
243        .map_err(|_| EncryptionSpecError::UnknownMode {
244            mode: mode_str.to_owned(),
245        })
246}
247
#[cfg(test)]
mod tests {
    use http::header::HeaderValue;
    use rstest::rstest;

    use super::*;

    // Padded base64 of `KEY_BYTES`; `encrypted_header_value_roundtrips` checks the two agree.
    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    // 32-byte test key: the bytes 1 through 32.
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Asserts that `spec` is the encrypted variant for `algorithm` and carries `expected` as
    /// its key. Panics with a descriptive message otherwise.
    fn assert_encrypted_spec(
        spec: EncryptionSpec,
        algorithm: EncryptionAlgorithm,
        expected: &[u8; 32],
    ) {
        match (algorithm, spec) {
            (EncryptionAlgorithm::Aegis256, EncryptionSpec::Aegis256(key)) => {
                assert_eq!(key.secret(), expected)
            }
            (EncryptionAlgorithm::Aes256Gcm, EncryptionSpec::Aes256Gcm(key)) => {
                assert_eq!(key.secret(), expected)
            }
            (_, EncryptionSpec::Plain) => panic!("expected encrypted spec"),
            // Encrypted, but with the other algorithm.
            (expected_algorithm, actual_spec) => {
                panic!("expected {expected_algorithm:?}, got {actual_spec:?}")
            }
        }
    }

    /// Asserts that parsing `header` fails with exactly `expected`.
    fn assert_invalid_parse(header: &str, expected: EncryptionSpecError) {
        let result = header.parse::<EncryptionSpec>();
        match result {
            Err(actual) => assert_eq!(actual, expected),
            Ok(actual) => panic!("expected invalid spec for {header:?}, got {actual:?}"),
        }
    }

    // Encrypted mode names parse regardless of ASCII case.
    #[rstest]
    #[case("aegis-256", EncryptionAlgorithm::Aegis256)]
    #[case("aes-256-gcm", EncryptionAlgorithm::Aes256Gcm)]
    #[case("AEGIS-256", EncryptionAlgorithm::Aegis256)]
    #[case("AES-256-GCM", EncryptionAlgorithm::Aes256Gcm)]
    fn parse_header_valid_encrypted(
        #[case] algorithm: &str,
        #[case] expected: EncryptionAlgorithm,
    ) {
        let spec = format!("{algorithm}; {KEY_B64}")
            .parse::<EncryptionSpec>()
            .unwrap();
        assert_encrypted_spec(spec, expected, &KEY_BYTES);
    }

    // Whitespace around the value, the mode, and the key is ignored.
    #[test]
    fn parse_header_aes_with_whitespace() {
        let spec = format!(" aes-256-gcm ; {KEY_B64} ")
            .parse::<EncryptionSpec>()
            .unwrap();
        assert_encrypted_spec(spec, EncryptionAlgorithm::Aes256Gcm, &KEY_BYTES);
    }

    // `plain` is case-insensitive and tolerates an empty key after `;`.
    #[rstest]
    #[case("plain")]
    #[case("PLAIN")]
    #[case("plain; ")]
    fn parse_header_plain_variants(#[case] header: &str) {
        let spec = header.parse::<EncryptionSpec>().unwrap();
        assert!(matches!(spec, EncryptionSpec::Plain));
    }

    // `mode()` reports the mode matching each constructor.
    #[test]
    fn spec_mode_matches_variant() {
        assert_eq!(EncryptionSpec::Plain.mode(), EncryptionMode::Plain);
        assert_eq!(
            EncryptionSpec::aegis256(KEY_BYTES).mode(),
            EncryptionMode::Aegis256
        );
        assert_eq!(
            EncryptionSpec::aes256_gcm(KEY_BYTES).mode(),
            EncryptionMode::Aes256Gcm
        );
    }

    // Each malformed input maps to one specific error variant.
    #[rstest]
    #[case("", EncryptionSpecError::MissingMode)]
    #[case(
        "aegis-256",
        EncryptionSpecError::MissingKey {
            mode: EncryptionMode::Aegis256
        }
    )]
    #[case(
        "aegis-256; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=; extra",
        EncryptionSpecError::InvalidSyntax
    )]
    #[case(
        "aegis-256; 3q2+7w==",
        EncryptionSpecError::InvalidKeyLength {
            expected: 32,
            actual: 4
        }
    )]
    #[case(
        "aegis-256; not-valid-base64!!!",
        EncryptionSpecError::InvalidKeyBase64
    )]
    #[case(
        "bogus; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnknownMode {
            mode: "bogus".to_owned()
        }
    )]
    #[case(
        "plain; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnexpectedKeyForPlain
    )]
    fn parse_header_invalid_cases(#[case] header: &str, #[case] expected: EncryptionSpecError) {
        assert_invalid_parse(header, expected);
    }

    // Encrypted header values are flagged sensitive and differ from the plain value.
    #[test]
    fn header_value_is_sensitive() {
        let value = EncryptionSpec::aegis256([7; 32]).to_header_value();
        assert!(value.is_sensitive());
        assert_ne!(value, HeaderValue::from_static("plain"));
    }

    // `Plain` renders as "plain" (still flagged sensitive) and parses back to `Plain`.
    #[test]
    fn plain_header_value_roundtrips() {
        let value = EncryptionSpec::Plain.to_header_value();
        assert_eq!(value.to_str().unwrap(), "plain");
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionSpec>().unwrap();
        assert!(matches!(parsed, EncryptionSpec::Plain));
    }

    // Encrypted specs render as `<algorithm>; <base64 key>` and parse back to the same key.
    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn encrypted_header_value_roundtrips(#[case] algorithm: EncryptionAlgorithm) {
        let value = match algorithm {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(KEY_BYTES),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(KEY_BYTES),
        }
        .to_header_value();
        assert_eq!(value.to_str().unwrap(), format!("{algorithm}; {KEY_B64}"));
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionSpec>().unwrap();
        assert_encrypted_spec(parsed, algorithm, &KEY_BYTES);
    }
}