1use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::Encoding;
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, zeroize::Zeroizing};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
/// Name of the `s2-encryption` HTTP header, whose value is an [`EncryptionSpec`]
/// (`plain` or `<mode>; <base64 key>`).
pub static S2_ENCRYPTION_HEADER: HeaderName = HeaderName::from_static("s2-encryption");
14
/// Shared, heap-allocated secret key of exactly `N` bytes.
///
/// The [`Arc`] lets clones of the key wrappers share one boxed secret instead of
/// duplicating the key material.
type EncryptionKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
16
/// An encryption algorithm, i.e. every [`EncryptionMode`] except `plain`.
///
/// `Display` and `FromStr` use the header spellings `aegis-256` and
/// `aes-256-gcm`; parsing is ASCII case-insensitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, EnumString)]
#[strum(ascii_case_insensitive)]
pub enum EncryptionAlgorithm {
    /// AEGIS-256 (header spelling `aegis-256`).
    #[strum(serialize = "aegis-256")]
    Aegis256,
    /// AES-256-GCM (header spelling `aes-256-gcm`).
    #[strum(serialize = "aes-256-gcm")]
    Aes256Gcm,
}
28
29#[derive(
31 Debug,
32 Clone,
33 Copy,
34 PartialEq,
35 Eq,
36 Hash,
37 serde::Serialize,
38 serde::Deserialize,
39 Display,
40 EnumString,
41 enumset::EnumSetType,
42)]
43#[strum(ascii_case_insensitive)]
44#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
45#[enumset(no_super_impls)]
46#[serde(rename_all = "kebab-case")]
47pub enum EncryptionMode {
48 #[strum(serialize = "plain")]
49 Plain,
50 #[strum(serialize = "aegis-256")]
51 Aegis256,
52 #[strum(serialize = "aes-256-gcm")]
53 Aes256Gcm,
54}
55
56impl From<EncryptionAlgorithm> for EncryptionMode {
57 fn from(value: EncryptionAlgorithm) -> Self {
58 match value {
59 EncryptionAlgorithm::Aegis256 => Self::Aegis256,
60 EncryptionAlgorithm::Aes256Gcm => Self::Aes256Gcm,
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
66pub struct Aegis256Key(EncryptionKey<32>);
67
68impl Aegis256Key {
69 pub fn new(key: [u8; 32]) -> Self {
70 Self(Arc::new(SecretBox::new(Box::new(key))))
71 }
72
73 pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
74 parse_encryption_key::<32>(key_b64).map(Self)
75 }
76
77 pub(crate) fn secret(&self) -> &[u8; 32] {
78 self.0.as_ref().expose_secret()
79 }
80}
81
82#[derive(Debug, Clone)]
83pub struct Aes256GcmKey(EncryptionKey<32>);
84
85impl Aes256GcmKey {
86 pub fn new(key: [u8; 32]) -> Self {
87 Self(Arc::new(SecretBox::new(Box::new(key))))
88 }
89
90 pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
91 parse_encryption_key::<32>(key_b64).map(Self)
92 }
93
94 pub(crate) fn secret(&self) -> &[u8; 32] {
95 self.0.as_ref().expose_secret()
96 }
97}
98
/// Errors produced while parsing an [`EncryptionSpec`] or decoding its key.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum EncryptionSpecError {
    /// The spec contained more than one `;` separator.
    #[error("Invalid encryption spec: expected '<mode>; <key>' or 'plain'")]
    InvalidSyntax,
    /// The mode field was empty (after trimming whitespace).
    #[error("Invalid encryption spec: missing encryption mode")]
    MissingMode,
    /// The mode field did not name a known [`EncryptionMode`]; carries the input.
    #[error(
        "Invalid encryption spec: unknown encryption mode {mode:?}; expected 'plain', 'aegis-256', or 'aes-256-gcm'"
    )]
    UnknownMode { mode: String },
    /// A non-empty key was supplied together with the `plain` mode.
    #[error("Invalid encryption spec: key is not allowed when mode is 'plain'")]
    UnexpectedKeyForPlain,
    /// An encrypting mode was given without a (non-empty) key.
    #[error("Invalid encryption spec: missing key for '{mode}'")]
    MissingKey { mode: EncryptionMode },
    /// Base64 decoding of the key failed.
    #[error("Invalid encryption spec: key is not valid base64")]
    InvalidKeyBase64,
    /// The key decoded to a byte length other than the one the algorithm needs.
    #[error("Invalid encryption spec: key must be exactly {expected} bytes, got {actual} bytes")]
    InvalidKeyLength { expected: usize, actual: usize },
}
118
/// Parsed value of the `s2-encryption` header: either `plain`, or an algorithm
/// together with its key.
///
/// The [`Default`] is [`EncryptionSpec::Plain`].
#[derive(Debug, Clone, Default)]
pub enum EncryptionSpec {
    /// No encryption.
    #[default]
    Plain,
    /// AEGIS-256 with the given key.
    Aegis256(Aegis256Key),
    /// AES-256-GCM with the given key.
    Aes256Gcm(Aes256GcmKey),
}
126
127impl EncryptionSpec {
128 pub fn aegis256(key: [u8; 32]) -> Self {
129 Self::Aegis256(Aegis256Key::new(key))
130 }
131
132 pub fn aes256_gcm(key: [u8; 32]) -> Self {
133 Self::Aes256Gcm(Aes256GcmKey::new(key))
134 }
135
136 pub fn mode(&self) -> EncryptionMode {
137 match self {
138 Self::Plain => EncryptionMode::Plain,
139 Self::Aegis256(_) => EncryptionMode::Aegis256,
140 Self::Aes256Gcm(_) => EncryptionMode::Aes256Gcm,
141 }
142 }
143
144 pub fn to_header_value(&self) -> HeaderValue {
145 let mut value = match self {
146 Self::Plain => HeaderValue::from_static("plain"),
147 Self::Aegis256(key) => {
148 header_value_for_key(EncryptionAlgorithm::Aegis256, key.secret())
149 }
150 Self::Aes256Gcm(key) => {
151 header_value_for_key(EncryptionAlgorithm::Aes256Gcm, key.secret())
152 }
153 };
154 value.set_sensitive(true);
155 value
156 }
157}
158
159impl FromStr for EncryptionSpec {
160 type Err = EncryptionSpecError;
161
162 fn from_str(s: &str) -> Result<Self, Self::Err> {
163 let s = s.trim();
164 let mut parts = s.splitn(3, ';');
165 let mode_str = parts.next().unwrap_or_default().trim();
166 let key_b64 = parts.next().map(str::trim);
167 if parts.next().is_some() {
168 return Err(EncryptionSpecError::InvalidSyntax);
169 }
170
171 if mode_str.is_empty() {
172 return Err(EncryptionSpecError::MissingMode);
173 }
174
175 let key_b64 = key_b64.filter(|key| !key.is_empty());
176 match (parse_mode(mode_str)?, key_b64) {
177 (EncryptionMode::Plain, None) => Ok(Self::Plain),
178 (EncryptionMode::Plain, Some(_)) => Err(EncryptionSpecError::UnexpectedKeyForPlain),
179 (EncryptionMode::Aegis256, Some(key_b64)) => {
180 Ok(Self::Aegis256(Aegis256Key::from_base64(key_b64)?))
181 }
182 (EncryptionMode::Aegis256, None) => Err(EncryptionSpecError::MissingKey {
183 mode: EncryptionMode::Aegis256,
184 }),
185 (EncryptionMode::Aes256Gcm, Some(key_b64)) => {
186 Ok(Self::Aes256Gcm(Aes256GcmKey::from_base64(key_b64)?))
187 }
188 (EncryptionMode::Aes256Gcm, None) => Err(EncryptionSpecError::MissingKey {
189 mode: EncryptionMode::Aes256Gcm,
190 }),
191 }
192 }
193}
194
/// Associates [`EncryptionSpec`] with the `s2-encryption` header for the crate's
/// [`ParseableHeader`] helpers.
impl ParseableHeader for EncryptionSpec {
    /// The header this type is read from: [`S2_ENCRYPTION_HEADER`].
    fn name() -> &'static HeaderName {
        &S2_ENCRYPTION_HEADER
    }
}
200
201fn parse_encryption_key<const N: usize>(
202 key_b64: &str,
203) -> Result<EncryptionKey<N>, EncryptionSpecError> {
204 use base64ct::{Base64, Encoding};
205 use secrecy::zeroize::Zeroize;
206
207 let mut key = Box::new([0u8; N]);
208 let decoded = match Base64::decode(key_b64, key.as_mut()) {
209 Ok(decoded) => decoded,
210 Err(_) => {
211 key.as_mut().zeroize();
212 return Err(EncryptionSpecError::InvalidKeyBase64);
213 }
214 };
215
216 if decoded.len() != N {
217 let len = decoded.len();
218 key.as_mut().zeroize();
219 return Err(EncryptionSpecError::InvalidKeyLength {
220 expected: N,
221 actual: len,
222 });
223 }
224
225 Ok(Arc::new(SecretBox::new(key)))
226}
227
228fn header_value_for_key(algorithm: EncryptionAlgorithm, key: &[u8; 32]) -> HeaderValue {
229 let algorithm = algorithm.to_string();
230 let encoded_len = base64ct::Base64::encoded_len(key);
231 let mut value = Zeroizing::new(vec![0u8; algorithm.len() + 2 + encoded_len]);
232 value[..algorithm.len()].copy_from_slice(algorithm.as_bytes());
233 value[algorithm.len()..algorithm.len() + 2].copy_from_slice(b"; ");
234 base64ct::Base64::encode(key, &mut value[algorithm.len() + 2..])
235 .expect("base64 output length should match buffer");
236
237 HeaderValue::from_bytes(&value).expect("encryption header value should be ASCII")
238}
239
240fn parse_mode(mode_str: &str) -> Result<EncryptionMode, EncryptionSpecError> {
241 mode_str
242 .parse::<EncryptionMode>()
243 .map_err(|_| EncryptionSpecError::UnknownMode {
244 mode: mode_str.to_owned(),
245 })
246}
247
// Unit tests for parsing and encoding `s2-encryption` header values.
#[cfg(test)]
mod tests {
    use http::header::HeaderValue;
    use rstest::rstest;

    use super::*;

    // Fixture key: standard padded base64 encoding of `KEY_BYTES`.
    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    // Fixture key bytes 1..=32; the decoding of `KEY_B64`.
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Asserts that `spec` is the encrypted variant for `algorithm` and that its
    /// key equals `expected`; panics with a descriptive message otherwise.
    fn assert_encrypted_spec(
        spec: EncryptionSpec,
        algorithm: EncryptionAlgorithm,
        expected: &[u8; 32],
    ) {
        match (algorithm, spec) {
            (EncryptionAlgorithm::Aegis256, EncryptionSpec::Aegis256(key)) => {
                assert_eq!(key.secret(), expected)
            }
            (EncryptionAlgorithm::Aes256Gcm, EncryptionSpec::Aes256Gcm(key)) => {
                assert_eq!(key.secret(), expected)
            }
            (_, EncryptionSpec::Plain) => panic!("expected encrypted spec"),
            // Encrypted spec, but for the other algorithm.
            (expected_algorithm, actual_spec) => {
                panic!("expected {expected_algorithm:?}, got {actual_spec:?}")
            }
        }
    }

    /// Asserts that parsing `header` fails with exactly `expected`.
    fn assert_invalid_parse(header: &str, expected: EncryptionSpecError) {
        let result = header.parse::<EncryptionSpec>();
        match result {
            Err(actual) => assert_eq!(actual, expected),
            Ok(actual) => panic!("expected invalid spec for {header:?}, got {actual:?}"),
        }
    }

    // Both algorithms parse, and mode names are matched case-insensitively.
    #[rstest]
    #[case("aegis-256", EncryptionAlgorithm::Aegis256)]
    #[case("aes-256-gcm", EncryptionAlgorithm::Aes256Gcm)]
    #[case("AEGIS-256", EncryptionAlgorithm::Aegis256)]
    #[case("AES-256-GCM", EncryptionAlgorithm::Aes256Gcm)]
    fn parse_header_valid_encrypted(
        #[case] algorithm: &str,
        #[case] expected: EncryptionAlgorithm,
    ) {
        let spec = format!("{algorithm}; {KEY_B64}")
            .parse::<EncryptionSpec>()
            .unwrap();
        assert_encrypted_spec(spec, expected, &KEY_BYTES);
    }

    // Whitespace around the whole spec and around each field is ignored.
    #[test]
    fn parse_header_aes_with_whitespace() {
        let spec = format!(" aes-256-gcm ; {KEY_B64} ")
            .parse::<EncryptionSpec>()
            .unwrap();
        assert_encrypted_spec(spec, EncryptionAlgorithm::Aes256Gcm, &KEY_BYTES);
    }

    // `plain` is case-insensitive, and an empty key field after it is allowed.
    #[rstest]
    #[case("plain")]
    #[case("PLAIN")]
    #[case("plain; ")]
    fn parse_header_plain_variants(#[case] header: &str) {
        let spec = header.parse::<EncryptionSpec>().unwrap();
        assert!(matches!(spec, EncryptionSpec::Plain));
    }

    // `mode()` reports the mode that corresponds to each variant.
    #[test]
    fn spec_mode_matches_variant() {
        assert_eq!(EncryptionSpec::Plain.mode(), EncryptionMode::Plain);
        assert_eq!(
            EncryptionSpec::aegis256(KEY_BYTES).mode(),
            EncryptionMode::Aegis256
        );
        assert_eq!(
            EncryptionSpec::aes256_gcm(KEY_BYTES).mode(),
            EncryptionMode::Aes256Gcm
        );
    }

    // Each malformed input maps to its specific error variant.
    #[rstest]
    #[case("", EncryptionSpecError::MissingMode)]
    #[case(
        "aegis-256",
        EncryptionSpecError::MissingKey {
            mode: EncryptionMode::Aegis256
        }
    )]
    #[case(
        "aegis-256; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=; extra",
        EncryptionSpecError::InvalidSyntax
    )]
    // Valid base64 that decodes to only 4 bytes.
    #[case(
        "aegis-256; 3q2+7w==",
        EncryptionSpecError::InvalidKeyLength {
            expected: 32,
            actual: 4
        }
    )]
    #[case(
        "aegis-256; not-valid-base64!!!",
        EncryptionSpecError::InvalidKeyBase64
    )]
    #[case(
        "bogus; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnknownMode {
            mode: "bogus".to_owned()
        }
    )]
    #[case(
        "plain; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnexpectedKeyForPlain
    )]
    fn parse_header_invalid_cases(#[case] header: &str, #[case] expected: EncryptionSpecError) {
        assert_invalid_parse(header, expected);
    }

    // Encrypted header values are flagged sensitive and are not `plain`.
    #[test]
    fn header_value_is_sensitive() {
        let value = EncryptionSpec::aegis256([7; 32]).to_header_value();
        assert!(value.is_sensitive());
        assert_ne!(value, HeaderValue::from_static("plain"));
    }

    // `plain` encodes as the literal "plain" (still sensitive) and parses back.
    #[test]
    fn plain_header_value_roundtrips() {
        let value = EncryptionSpec::Plain.to_header_value();
        assert_eq!(value.to_str().unwrap(), "plain");
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionSpec>().unwrap();
        assert!(matches!(parsed, EncryptionSpec::Plain));
    }

    // Encrypted specs encode as "<algorithm>; <base64 key>" and parse back to
    // the same algorithm and key.
    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn encrypted_header_value_roundtrips(#[case] algorithm: EncryptionAlgorithm) {
        let value = match algorithm {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(KEY_BYTES),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(KEY_BYTES),
        }
        .to_header_value();
        assert_eq!(value.to_str().unwrap(), format!("{algorithm}; {KEY_B64}"));
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionSpec>().unwrap();
        assert_encrypted_spec(parsed, algorithm, &KEY_BYTES);
    }
}