Skip to main content

s2_common/
encryption.rs

1//! Encryption algorithm, key material parsing, and request header handling.
2
3use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::{Base64, Decoder, Encoding};
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, SecretString, zeroize::Zeroize};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
13pub static S2_ENCRYPTION_KEY_HEADER: HeaderName = HeaderName::from_static("s2-encryption-key");
14
15// 32 bytes in Base 64
16const MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN: usize = 44;
17
18type EncodedKeyMaterial = Arc<SecretString>;
19type DecodedKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
20
21/// Encryption algorithm.
22#[derive(
23    Debug,
24    Clone,
25    Copy,
26    PartialEq,
27    Eq,
28    Hash,
29    serde::Serialize,
30    serde::Deserialize,
31    Display,
32    EnumString,
33)]
34#[strum(ascii_case_insensitive)]
35#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
36pub enum EncryptionAlgorithm {
37    /// AEGIS-256
38    #[strum(serialize = "aegis-256")]
39    #[serde(rename = "aegis-256")]
40    #[cfg_attr(feature = "clap", value(name = "aegis-256"))]
41    Aegis256,
42    /// AES-256-GCM
43    #[strum(serialize = "aes-256-gcm")]
44    #[serde(rename = "aes-256-gcm")]
45    #[cfg_attr(feature = "clap", value(name = "aes-256-gcm"))]
46    Aes256Gcm,
47}
48
49/// Encryption key material for append/read operations.
50#[derive(Debug, Clone)]
51pub struct EncryptionKey(EncodedKeyMaterial);
52
53impl EncryptionKey {
54    pub fn new<const N: usize>(key: [u8; N]) -> Self {
55        Self(Arc::new(Base64::encode_string(&key).into()))
56    }
57
58    pub(crate) fn expose_secret(&self) -> &str {
59        self.0.expose_secret()
60    }
61
62    pub fn to_header_value(&self) -> HeaderValue {
63        let mut value = HeaderValue::from_bytes(self.expose_secret().as_bytes())
64            .expect("encryption key header value should be ASCII");
65        value.set_sensitive(true);
66        value
67    }
68}
69
70/// Decoded fixed-size encryption key material.
71#[derive(Debug, Clone)]
72pub struct DecodedEncryptionKey<const N: usize>(DecodedKey<N>);
73
74impl<const N: usize> DecodedEncryptionKey<N> {
75    pub fn new(key: [u8; N]) -> Self {
76        Self(Arc::new(SecretBox::new(Box::new(key))))
77    }
78
79    pub(crate) fn expose_secret(&self) -> &[u8; N] {
80        self.0.expose_secret()
81    }
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
85#[error("invalid encryption key: key material length {0} is out of range")]
86pub struct EncryptionKeyLengthError(usize);
87
88#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
89pub enum EncryptionSpecResolutionError {
90    #[error("missing encryption key for stream cipher '{cipher}'")]
91    MissingKey { cipher: EncryptionAlgorithm },
92    #[error("invalid encryption key for stream cipher '{cipher}': invalid base64")]
93    InvalidBase64 { cipher: EncryptionAlgorithm },
94    #[error("invalid encryption key length for stream cipher '{cipher}': {length}")]
95    InvalidKeyLength {
96        cipher: EncryptionAlgorithm,
97        length: usize,
98    },
99}
100
101/// Resolved encryption spec after combining stream metadata with the encryption key material, if any.
102#[rustfmt::skip]
103#[derive(Debug, Clone, Default)]
104pub enum EncryptionSpec {
105    #[default]
106    Plain,
107    Aegis256(DecodedEncryptionKey<32>),
108    Aes256Gcm(DecodedEncryptionKey<32>),
109}
110
111impl EncryptionSpec {
112    pub fn resolve(
113        cipher: Option<EncryptionAlgorithm>,
114        key: Option<EncryptionKey>,
115    ) -> Result<Self, EncryptionSpecResolutionError> {
116        match (cipher, key) {
117            (None, _) => Ok(Self::Plain),
118            (Some(cipher @ EncryptionAlgorithm::Aegis256), Some(key)) => {
119                Ok(Self::Aegis256(resolve_key(cipher, key)?))
120            }
121            (Some(cipher @ EncryptionAlgorithm::Aes256Gcm), Some(key)) => {
122                Ok(Self::Aes256Gcm(resolve_key(cipher, key)?))
123            }
124            (Some(cipher), None) => Err(EncryptionSpecResolutionError::MissingKey { cipher }),
125        }
126    }
127
128    pub fn aegis256(key: [u8; 32]) -> Self {
129        Self::Aegis256(DecodedEncryptionKey::new(key))
130    }
131
132    pub fn aes256_gcm(key: [u8; 32]) -> Self {
133        Self::Aes256Gcm(DecodedEncryptionKey::new(key))
134    }
135}
136
137impl FromStr for EncryptionKey {
138    type Err = EncryptionKeyLengthError;
139
140    fn from_str(s: &str) -> Result<Self, Self::Err> {
141        let trimmed = s.trim();
142        if (1..=MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN).contains(&trimmed.len()) {
143            Ok(Self(Arc::new(trimmed.to_owned().into())))
144        } else {
145            Err(EncryptionKeyLengthError(trimmed.len()))
146        }
147    }
148}
149
150impl ParseableHeader for EncryptionKey {
151    fn name() -> &'static HeaderName {
152        &S2_ENCRYPTION_KEY_HEADER
153    }
154}
155
156fn resolve_key<const N: usize>(
157    cipher: EncryptionAlgorithm,
158    key: EncryptionKey,
159) -> Result<DecodedEncryptionKey<N>, EncryptionSpecResolutionError> {
160    let mut decoder = Decoder::<Base64>::new(key.expose_secret().as_bytes())
161        .map_err(|_| EncryptionSpecResolutionError::InvalidBase64 { cipher })?;
162    let mut key_material = Box::new([0u8; N]);
163    match decoder.decode(key_material.as_mut()) {
164        Ok(_) if decoder.is_finished() => {
165            Ok(DecodedEncryptionKey(Arc::new(SecretBox::new(key_material))))
166        }
167        Ok(_) => {
168            let length = N
169                .checked_add(decoder.remaining_len())
170                .expect("decoded key length should fit usize");
171            key_material.as_mut().zeroize();
172            Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
173        }
174        Err(base64ct::Error::InvalidEncoding) => {
175            key_material.as_mut().zeroize();
176            Err(EncryptionSpecResolutionError::InvalidBase64 { cipher })
177        }
178        Err(base64ct::Error::InvalidLength) => {
179            let length = decoder.remaining_len();
180            key_material.as_mut().zeroize();
181            Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Padded Base64 encoding of [`KEY_BYTES`].
    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    /// A 32-byte key consisting of the values 1 through 32.
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Resolves a spec with both a cipher and a key present.
    fn resolve_encrypted(
        cipher: EncryptionAlgorithm,
        key: EncryptionKey,
    ) -> Result<EncryptionSpec, EncryptionSpecResolutionError> {
        EncryptionSpec::resolve(Some(cipher), Some(key))
    }

    #[test]
    fn key_header_value_roundtrips_and_is_sensitive() {
        let value = EncryptionKey::new(KEY_BYTES).to_header_value();
        assert_eq!(value.to_str().unwrap(), KEY_B64);
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);
    }

    #[test]
    fn encryption_key_parsing_trims_and_enforces_bounds() {
        let parsed = format!("  {KEY_B64}\n").parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);

        assert_eq!(
            "   ".parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(0)
        );

        let too_long = "A".repeat(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1);
        assert_eq!(
            too_long.parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1)
        );
    }

    #[test]
    fn resolve_plain_ignores_supplied_key() {
        // The key text is not valid Base64; resolution must still succeed because it is never decoded.
        let encryption = EncryptionSpec::resolve(None, Some("!!!!".parse().unwrap())).unwrap();
        assert!(matches!(encryption, EncryptionSpec::Plain));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_requires_key(#[case] cipher: EncryptionAlgorithm) {
        let err = EncryptionSpec::resolve(Some(cipher), None).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::MissingKey { cipher });
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_decodes_key_for_each_algorithm(#[case] cipher: EncryptionAlgorithm) {
        let encryption = resolve_encrypted(cipher, EncryptionKey::new(KEY_BYTES)).unwrap();

        match (cipher, encryption) {
            (EncryptionAlgorithm::Aegis256, EncryptionSpec::Aegis256(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            (EncryptionAlgorithm::Aes256Gcm, EncryptionSpec::Aes256Gcm(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            _ => panic!("resolved encryption spec did not match requested algorithm"),
        }
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_invalid_base64(#[case] cipher: EncryptionAlgorithm) {
        let err = resolve_encrypted(cipher, "!!!!".parse().unwrap()).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::InvalidBase64 { cipher });
    }

    // Parameterized over both algorithms, like the other resolution tests, so that a
    // length check missing from one variant cannot go unnoticed.
    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_non_32_byte_keys(#[case] cipher: EncryptionAlgorithm) {
        let short_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 4])).unwrap_err();
        assert_eq!(
            short_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 4 }
        );

        let long_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 33])).unwrap_err();
        assert_eq!(
            long_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 33 }
        );
    }
}