Skip to main content

s2_common/
encryption.rs

//! Encryption algorithm selection, key material parsing, and request header handling.
2
3use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::{Base64, Decoder, Encoding};
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, SecretString, zeroize::Zeroize};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
13pub static S2_ENCRYPTION_KEY_HEADER: HeaderName = HeaderName::from_static("s2-encryption-key");
14
15// 32 bytes in Base 64
16const MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN: usize = 44;
17
18type EncodedKeyMaterial = Arc<SecretString>;
19type DecodedKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
20
21/// Encryption algorithm.
22#[derive(
23    Debug,
24    Clone,
25    Copy,
26    PartialEq,
27    Eq,
28    Hash,
29    serde::Serialize,
30    serde::Deserialize,
31    Display,
32    EnumString,
33)]
34#[strum(ascii_case_insensitive)]
35#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
36pub enum EncryptionAlgorithm {
37    /// AEGIS-256
38    #[strum(serialize = "aegis-256")]
39    #[serde(rename = "aegis-256")]
40    #[cfg_attr(feature = "clap", value(name = "aegis-256"))]
41    Aegis256,
42    /// AES-256-GCM
43    #[strum(serialize = "aes-256-gcm")]
44    #[serde(rename = "aes-256-gcm")]
45    #[cfg_attr(feature = "clap", value(name = "aes-256-gcm"))]
46    Aes256Gcm,
47}
48
49/// Encryption key material for append/read operations.
50#[derive(Debug, Clone)]
51pub struct EncryptionKey(EncodedKeyMaterial);
52
53impl EncryptionKey {
54    pub fn new<const N: usize>(key: [u8; N]) -> Self {
55        Self(Arc::new(Base64::encode_string(&key).into()))
56    }
57
58    pub(crate) fn expose_secret(&self) -> &str {
59        self.0.expose_secret()
60    }
61
62    pub fn to_header_value(&self) -> HeaderValue {
63        let mut value = HeaderValue::from_bytes(self.expose_secret().as_bytes())
64            .expect("encryption key header value should be ASCII");
65        value.set_sensitive(true);
66        value
67    }
68}
69
70/// Decoded fixed-size encryption key material.
71#[derive(Debug, Clone)]
72pub struct DecodedEncryptionKey<const N: usize>(DecodedKey<N>);
73
74impl<const N: usize> DecodedEncryptionKey<N> {
75    pub fn new(key: [u8; N]) -> Self {
76        Self(Arc::new(SecretBox::new(Box::new(key))))
77    }
78}
79
80impl<const N: usize> ExposeSecret<[u8; N]> for DecodedEncryptionKey<N> {
81    fn expose_secret(&self) -> &[u8; N] {
82        self.0.expose_secret()
83    }
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
87#[error("invalid encryption key: key material length {0} is out of range")]
88pub struct EncryptionKeyLengthError(usize);
89
90#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
91pub enum EncryptionSpecResolutionError {
92    #[error("missing encryption key for stream cipher '{cipher}'")]
93    MissingKey { cipher: EncryptionAlgorithm },
94    #[error("invalid encryption key for stream cipher '{cipher}': invalid base64")]
95    InvalidBase64 { cipher: EncryptionAlgorithm },
96    #[error("invalid encryption key length for stream cipher '{cipher}': {length}")]
97    InvalidKeyLength {
98        cipher: EncryptionAlgorithm,
99        length: usize,
100    },
101}
102
103/// Resolved encryption spec after combining stream metadata with the encryption key material, if any.
104#[rustfmt::skip]
105#[derive(Debug, Clone, Default)]
106pub enum EncryptionSpec {
107    #[default]
108    Plain,
109    Aegis256(DecodedEncryptionKey<32>),
110    Aes256Gcm(DecodedEncryptionKey<32>),
111}
112
113impl EncryptionSpec {
114    pub fn resolve(
115        cipher: Option<EncryptionAlgorithm>,
116        key: Option<EncryptionKey>,
117    ) -> Result<Self, EncryptionSpecResolutionError> {
118        match (cipher, key) {
119            (None, _) => Ok(Self::Plain),
120            (Some(cipher @ EncryptionAlgorithm::Aegis256), Some(key)) => {
121                Ok(Self::Aegis256(resolve_key(cipher, key)?))
122            }
123            (Some(cipher @ EncryptionAlgorithm::Aes256Gcm), Some(key)) => {
124                Ok(Self::Aes256Gcm(resolve_key(cipher, key)?))
125            }
126            (Some(cipher), None) => Err(EncryptionSpecResolutionError::MissingKey { cipher }),
127        }
128    }
129
130    pub fn aegis256(key: [u8; 32]) -> Self {
131        Self::Aegis256(DecodedEncryptionKey::new(key))
132    }
133
134    pub fn aes256_gcm(key: [u8; 32]) -> Self {
135        Self::Aes256Gcm(DecodedEncryptionKey::new(key))
136    }
137}
138
139impl FromStr for EncryptionKey {
140    type Err = EncryptionKeyLengthError;
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        let trimmed = s.trim();
144        if (1..=MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN).contains(&trimmed.len()) {
145            Ok(Self(Arc::new(trimmed.to_owned().into())))
146        } else {
147            Err(EncryptionKeyLengthError(trimmed.len()))
148        }
149    }
150}
151
/// Associates [`EncryptionKey`] with the `s2-encryption-key` header
/// ([`S2_ENCRYPTION_KEY_HEADER`]).
///
/// NOTE(review): the actual header-parsing logic lives in `crate::http::ParseableHeader`,
/// which is not shown here — confirm there how the value is converted (presumably via
/// `FromStr`).
impl ParseableHeader for EncryptionKey {
    /// Returns the header name associated with this type.
    fn name() -> &'static HeaderName {
        &S2_ENCRYPTION_KEY_HEADER
    }
}
157
/// Decodes Base64 key text into exactly `N` raw key bytes for `cipher`.
///
/// `cipher` is only used to label the returned errors.
///
/// # Errors
///
/// - [`EncryptionSpecResolutionError::InvalidBase64`] if the decoder cannot be constructed
///   from the key text, or decoding reports `base64ct::Error::InvalidEncoding`.
/// - [`EncryptionSpecResolutionError::InvalidKeyLength`] if input is left over after `N`
///   bytes have been decoded (key too long), or decoding reports
///   `base64ct::Error::InvalidLength` (e.g. key too short).
///
/// On every error path the output buffer is zeroized before being dropped, because a plain
/// `Box<[u8; N]>` is not wiped on drop and may hold partially decoded key bytes.
fn resolve_key<const N: usize>(
    cipher: EncryptionAlgorithm,
    key: EncryptionKey,
) -> Result<DecodedEncryptionKey<N>, EncryptionSpecResolutionError> {
    let mut decoder = Decoder::<Base64>::new(key.expose_secret().as_bytes())
        .map_err(|_| EncryptionSpecResolutionError::InvalidBase64 { cipher })?;
    // Decode straight into a heap buffer that is handed to `SecretBox` on success.
    let mut key_material = Box::new([0u8; N]);
    match decoder.decode(key_material.as_mut()) {
        // Exactly `N` bytes and no input left: the key has the right size.
        Ok(_) if decoder.is_finished() => {
            Ok(DecodedEncryptionKey(Arc::new(SecretBox::new(key_material))))
        }
        // `N` bytes filled but input remains: report the total as `N` plus what is pending.
        Ok(_) => {
            let length = N
                .checked_add(decoder.remaining_len())
                .expect("decoded key length should fit usize");
            key_material.as_mut().zeroize();
            Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
        }
        Err(base64ct::Error::InvalidEncoding) => {
            key_material.as_mut().zeroize();
            Err(EncryptionSpecResolutionError::InvalidBase64 { cipher })
        }
        // NOTE(review): reports `remaining_len()` as the key length; the `length: 4` test
        // case relies on base64ct leaving it at the full decoded length when the input is
        // too short — confirm against the pinned base64ct version.
        Err(base64ct::Error::InvalidLength) => {
            let length = decoder.remaining_len();
            key_material.as_mut().zeroize();
            Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
        }
    }
}
187
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Padded Base64 encoding of [`KEY_BYTES`].
    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Resolves `key` as if the stream were configured with `cipher`.
    fn resolve_encrypted(
        cipher: EncryptionAlgorithm,
        key: EncryptionKey,
    ) -> Result<EncryptionSpec, EncryptionSpecResolutionError> {
        EncryptionSpec::resolve(Some(cipher), Some(key))
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256, "aegis-256")]
    #[case(EncryptionAlgorithm::Aes256Gcm, "aes-256-gcm")]
    fn algorithm_string_forms_roundtrip(#[case] cipher: EncryptionAlgorithm, #[case] name: &str) {
        assert_eq!(cipher.to_string(), name);
        assert_eq!(name.parse::<EncryptionAlgorithm>().unwrap(), cipher);
        // Parsing is ASCII case-insensitive.
        assert_eq!(
            name.to_ascii_uppercase().parse::<EncryptionAlgorithm>().unwrap(),
            cipher
        );
    }

    #[test]
    fn algorithm_parsing_rejects_unknown_names() {
        assert!("aes-128-gcm".parse::<EncryptionAlgorithm>().is_err());
        assert!("".parse::<EncryptionAlgorithm>().is_err());
    }

    #[test]
    fn key_header_value_roundtrips_and_is_sensitive() {
        let value = EncryptionKey::new(KEY_BYTES).to_header_value();
        assert_eq!(value.to_str().unwrap(), KEY_B64);
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);
    }

    #[test]
    fn encryption_key_parsing_trims_and_enforces_bounds() {
        let parsed = format!("  {KEY_B64}\n").parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);

        assert_eq!(
            "   ".parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(0)
        );

        let too_long = "A".repeat(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1);
        assert_eq!(
            too_long.parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1)
        );
    }

    #[test]
    fn resolve_plain_ignores_supplied_key() {
        let encryption = EncryptionSpec::resolve(None, Some("!!!!".parse().unwrap())).unwrap();
        assert!(matches!(encryption, EncryptionSpec::Plain));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_requires_key(#[case] cipher: EncryptionAlgorithm) {
        let err = EncryptionSpec::resolve(Some(cipher), None).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::MissingKey { cipher });
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_decodes_key_for_each_algorithm(#[case] cipher: EncryptionAlgorithm) {
        let encryption = resolve_encrypted(cipher, EncryptionKey::new(KEY_BYTES)).unwrap();

        match (cipher, encryption) {
            (EncryptionAlgorithm::Aegis256, EncryptionSpec::Aegis256(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            (EncryptionAlgorithm::Aes256Gcm, EncryptionSpec::Aes256Gcm(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            _ => panic!("resolved encryption spec did not match requested algorithm"),
        }
    }

    #[test]
    fn direct_constructors_keep_key_bytes() {
        match EncryptionSpec::aegis256(KEY_BYTES) {
            EncryptionSpec::Aegis256(key) => assert_eq!(key.expose_secret(), &KEY_BYTES),
            other => panic!("unexpected spec: {other:?}"),
        }
        match EncryptionSpec::aes256_gcm(KEY_BYTES) {
            EncryptionSpec::Aes256Gcm(key) => assert_eq!(key.expose_secret(), &KEY_BYTES),
            other => panic!("unexpected spec: {other:?}"),
        }
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_invalid_base64(#[case] cipher: EncryptionAlgorithm) {
        let err = resolve_encrypted(cipher, "!!!!".parse().unwrap()).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::InvalidBase64 { cipher });
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_non_32_byte_keys(#[case] cipher: EncryptionAlgorithm) {
        let short_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 4])).unwrap_err();
        assert_eq!(
            short_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 4 }
        );

        let long_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 33])).unwrap_err();
        assert_eq!(
            long_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 33 }
        );
    }

    #[test]
    fn resolution_errors_name_the_cipher() {
        let err = EncryptionSpecResolutionError::MissingKey {
            cipher: EncryptionAlgorithm::Aes256Gcm,
        };
        assert_eq!(
            err.to_string(),
            "missing encryption key for stream cipher 'aes-256-gcm'"
        );
    }
}