1use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::Encoding;
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, zeroize::Zeroizing};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
/// Name of the HTTP header that carries an [`EncryptionSpec`]: `s2-encryption`.
pub static S2_ENCRYPTION_HEADER: HeaderName = HeaderName::from_static("s2-encryption");

/// Shared `N`-byte secret key.
///
/// The `Arc` lets key handles be cloned without copying the key bytes. The `SecretBox`
/// means the bytes are only reachable through `expose_secret` (per the `secrecy` crate
/// docs it also redacts `Debug` output and zeroizes on drop).
type EncryptionKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
16
/// Cipher used by an encrypted [`EncryptionSpec`].
///
/// `Display` emits the canonical lowercase name (`aegis-256` / `aes-256-gcm`), which is the
/// same spelling written into header values. `FromStr` accepts those names ASCII
/// case-insensitively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, EnumString)]
#[strum(ascii_case_insensitive)]
pub enum EncryptionAlgorithm {
    /// AEGIS-256; textual form `aegis-256`.
    #[strum(serialize = "aegis-256")]
    Aegis256,
    /// AES-256-GCM; textual form `aes-256-gcm`.
    #[strum(serialize = "aes-256-gcm")]
    Aes256Gcm,
}
28
/// Encryption mode without any key material: `plain` or one of the supported algorithms.
///
/// The same lowercase names are used everywhere the mode is rendered or parsed:
/// - `Display` / `FromStr`, where parsing is ASCII case-insensitive;
/// - serde;
/// - with the `clap` feature, command-line values.
///
/// `EnumSetType` makes the mode usable as an `enumset::EnumSet` element.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    serde::Serialize,
    serde::Deserialize,
    Display,
    EnumString,
    enumset::EnumSetType,
)]
#[strum(ascii_case_insensitive)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[enumset(no_super_impls)]
pub enum EncryptionMode {
    // No encryption. Variant notes in this enum are plain `//` comments so they are not
    // picked up as clap help text.
    #[strum(serialize = "plain")]
    #[serde(rename = "plain")]
    Plain,
    // AEGIS-256 encryption.
    #[strum(serialize = "aegis-256")]
    #[serde(rename = "aegis-256")]
    #[cfg_attr(feature = "clap", value(name = "aegis-256"))]
    Aegis256,
    // AES-256-GCM encryption.
    #[strum(serialize = "aes-256-gcm")]
    #[serde(rename = "aes-256-gcm")]
    #[cfg_attr(feature = "clap", value(name = "aes-256-gcm"))]
    Aes256Gcm,
}
59
60impl From<EncryptionAlgorithm> for EncryptionMode {
61 fn from(value: EncryptionAlgorithm) -> Self {
62 match value {
63 EncryptionAlgorithm::Aegis256 => Self::Aegis256,
64 EncryptionAlgorithm::Aes256Gcm => Self::Aes256Gcm,
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
70pub struct Aegis256Key(EncryptionKey<32>);
71
72impl Aegis256Key {
73 pub fn new(key: [u8; 32]) -> Self {
74 Self(Arc::new(SecretBox::new(Box::new(key))))
75 }
76
77 pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
78 parse_encryption_key::<32>(key_b64).map(Self)
79 }
80
81 pub(crate) fn secret(&self) -> &[u8; 32] {
82 self.0.as_ref().expose_secret()
83 }
84}
85
86#[derive(Debug, Clone)]
87pub struct Aes256GcmKey(EncryptionKey<32>);
88
89impl Aes256GcmKey {
90 pub fn new(key: [u8; 32]) -> Self {
91 Self(Arc::new(SecretBox::new(Box::new(key))))
92 }
93
94 pub fn from_base64(key_b64: &str) -> Result<Self, EncryptionSpecError> {
95 parse_encryption_key::<32>(key_b64).map(Self)
96 }
97
98 pub(crate) fn secret(&self) -> &[u8; 32] {
99 self.0.as_ref().expose_secret()
100 }
101}
102
/// Reasons an `s2-encryption` header value (or key string) failed to parse.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum EncryptionSpecError {
    /// More than one `;` separator was present.
    #[error("Invalid encryption spec: expected '<mode>; <key>' or 'plain'")]
    InvalidSyntax,
    /// The part before the first `;` was empty after trimming.
    #[error("Invalid encryption spec: missing encryption mode")]
    MissingMode,
    /// The mode text did not match any known mode.
    ///
    /// NOTE(review): `mode` is the caller's text verbatim. If a key is sent without the `;`
    /// separator, the whole value (key included) ends up in this error and its message.
    /// Check whether this error is logged or returned to clients.
    #[error(
        "Invalid encryption spec: unknown encryption mode {mode:?}; expected 'plain', 'aegis-256', or 'aes-256-gcm'"
    )]
    UnknownMode { mode: String },
    /// A non-empty key was supplied with mode `plain`.
    #[error("Invalid encryption spec: key is not allowed when mode is 'plain'")]
    UnexpectedKeyForPlain,
    /// An encrypted mode was given without a (non-empty) key.
    #[error("Invalid encryption spec: missing key for '{mode}'")]
    MissingKey { mode: EncryptionMode },
    /// The key text could not be decoded as base64.
    #[error("Invalid encryption spec: key is not valid base64")]
    InvalidKeyBase64,
    /// The key decoded successfully but to the wrong number of bytes.
    #[error("Invalid encryption spec: key must be exactly {expected} bytes, got {actual} bytes")]
    InvalidKeyLength { expected: usize, actual: usize },
}
122
/// Parsed `s2-encryption` header: either no encryption, or an algorithm together with its key.
///
/// Defaults to [`EncryptionSpec::Plain`].
#[derive(Debug, Clone, Default)]
pub enum EncryptionSpec {
    /// No encryption; header value `plain`.
    #[default]
    Plain,
    /// AEGIS-256 with the given key.
    Aegis256(Aegis256Key),
    /// AES-256-GCM with the given key.
    Aes256Gcm(Aes256GcmKey),
}
130
131impl EncryptionSpec {
132 pub fn aegis256(key: [u8; 32]) -> Self {
133 Self::Aegis256(Aegis256Key::new(key))
134 }
135
136 pub fn aes256_gcm(key: [u8; 32]) -> Self {
137 Self::Aes256Gcm(Aes256GcmKey::new(key))
138 }
139
140 pub fn mode(&self) -> EncryptionMode {
141 match self {
142 Self::Plain => EncryptionMode::Plain,
143 Self::Aegis256(_) => EncryptionMode::Aegis256,
144 Self::Aes256Gcm(_) => EncryptionMode::Aes256Gcm,
145 }
146 }
147
148 pub fn to_header_value(&self) -> HeaderValue {
149 let mut value = match self {
150 Self::Plain => HeaderValue::from_static("plain"),
151 Self::Aegis256(key) => {
152 header_value_for_key(EncryptionAlgorithm::Aegis256, key.secret())
153 }
154 Self::Aes256Gcm(key) => {
155 header_value_for_key(EncryptionAlgorithm::Aes256Gcm, key.secret())
156 }
157 };
158 value.set_sensitive(true);
159 value
160 }
161}
162
163impl FromStr for EncryptionSpec {
164 type Err = EncryptionSpecError;
165
166 fn from_str(s: &str) -> Result<Self, Self::Err> {
167 let s = s.trim();
168 let mut parts = s.splitn(3, ';');
169 let mode_str = parts.next().unwrap_or_default().trim();
170 let key_b64 = parts.next().map(str::trim);
171 if parts.next().is_some() {
172 return Err(EncryptionSpecError::InvalidSyntax);
173 }
174
175 if mode_str.is_empty() {
176 return Err(EncryptionSpecError::MissingMode);
177 }
178
179 let key_b64 = key_b64.filter(|key| !key.is_empty());
180 match (parse_mode(mode_str)?, key_b64) {
181 (EncryptionMode::Plain, None) => Ok(Self::Plain),
182 (EncryptionMode::Plain, Some(_)) => Err(EncryptionSpecError::UnexpectedKeyForPlain),
183 (EncryptionMode::Aegis256, Some(key_b64)) => {
184 Ok(Self::Aegis256(Aegis256Key::from_base64(key_b64)?))
185 }
186 (EncryptionMode::Aegis256, None) => Err(EncryptionSpecError::MissingKey {
187 mode: EncryptionMode::Aegis256,
188 }),
189 (EncryptionMode::Aes256Gcm, Some(key_b64)) => {
190 Ok(Self::Aes256Gcm(Aes256GcmKey::from_base64(key_b64)?))
191 }
192 (EncryptionMode::Aes256Gcm, None) => Err(EncryptionSpecError::MissingKey {
193 mode: EncryptionMode::Aes256Gcm,
194 }),
195 }
196 }
197}
198
/// Binds [`EncryptionSpec`] to the `s2-encryption` header for the crate's `ParseableHeader`
/// trait. The trait presumably parses values through the `FromStr` impl; confirm in
/// `crate::http`.
impl ParseableHeader for EncryptionSpec {
    fn name() -> &'static HeaderName {
        &S2_ENCRYPTION_HEADER
    }
}
204
205fn parse_encryption_key<const N: usize>(
206 key_b64: &str,
207) -> Result<EncryptionKey<N>, EncryptionSpecError> {
208 use base64ct::{Base64, Encoding};
209 use secrecy::zeroize::Zeroize;
210
211 let mut key = Box::new([0u8; N]);
212 let decoded = match Base64::decode(key_b64, key.as_mut()) {
213 Ok(decoded) => decoded,
214 Err(_) => {
215 key.as_mut().zeroize();
216 return Err(EncryptionSpecError::InvalidKeyBase64);
217 }
218 };
219
220 if decoded.len() != N {
221 let len = decoded.len();
222 key.as_mut().zeroize();
223 return Err(EncryptionSpecError::InvalidKeyLength {
224 expected: N,
225 actual: len,
226 });
227 }
228
229 Ok(Arc::new(SecretBox::new(key)))
230}
231
232fn header_value_for_key(algorithm: EncryptionAlgorithm, key: &[u8; 32]) -> HeaderValue {
233 let algorithm = algorithm.to_string();
234 let encoded_len = base64ct::Base64::encoded_len(key);
235 let mut value = Zeroizing::new(vec![0u8; algorithm.len() + 2 + encoded_len]);
236 value[..algorithm.len()].copy_from_slice(algorithm.as_bytes());
237 value[algorithm.len()..algorithm.len() + 2].copy_from_slice(b"; ");
238 base64ct::Base64::encode(key, &mut value[algorithm.len() + 2..])
239 .expect("base64 output length should match buffer");
240
241 HeaderValue::from_bytes(&value).expect("encryption header value should be ASCII")
242}
243
244fn parse_mode(mode_str: &str) -> Result<EncryptionMode, EncryptionSpecError> {
245 mode_str
246 .parse::<EncryptionMode>()
247 .map_err(|_| EncryptionSpecError::UnknownMode {
248 mode: mode_str.to_owned(),
249 })
250}
251
#[cfg(test)]
mod tests {
    use http::header::HeaderValue;
    use rstest::rstest;

    use super::*;

    // KEY_B64 is the standard, padded base64 encoding of KEY_BYTES (the bytes 1..=32).
    // Keep the two in sync.
    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Asserts that `spec` is the encrypted variant matching `algorithm` and that its key
    /// bytes equal `expected`. Panics with a descriptive message otherwise.
    fn assert_encrypted_spec(
        spec: EncryptionSpec,
        algorithm: EncryptionAlgorithm,
        expected: &[u8; 32],
    ) {
        match (algorithm, spec) {
            (EncryptionAlgorithm::Aegis256, EncryptionSpec::Aegis256(key)) => {
                assert_eq!(key.secret(), expected)
            }
            (EncryptionAlgorithm::Aes256Gcm, EncryptionSpec::Aes256Gcm(key)) => {
                assert_eq!(key.secret(), expected)
            }
            (_, EncryptionSpec::Plain) => panic!("expected encrypted spec"),
            // Remaining case: encrypted, but with the other algorithm.
            (expected_algorithm, actual_spec) => {
                panic!("expected {expected_algorithm:?}, got {actual_spec:?}")
            }
        }
    }

    /// Asserts that parsing `header` fails with exactly `expected`.
    fn assert_invalid_parse(header: &str, expected: EncryptionSpecError) {
        let result = header.parse::<EncryptionSpec>();
        match result {
            Err(actual) => assert_eq!(actual, expected),
            Ok(actual) => panic!("expected invalid spec for {header:?}, got {actual:?}"),
        }
    }

    // Both algorithms parse, and mode names are matched case-insensitively.
    #[rstest]
    #[case("aegis-256", EncryptionAlgorithm::Aegis256)]
    #[case("aes-256-gcm", EncryptionAlgorithm::Aes256Gcm)]
    #[case("AEGIS-256", EncryptionAlgorithm::Aegis256)]
    #[case("AES-256-GCM", EncryptionAlgorithm::Aes256Gcm)]
    fn parse_header_valid_encrypted(
        #[case] algorithm: &str,
        #[case] expected: EncryptionAlgorithm,
    ) {
        let spec = format!("{algorithm}; {KEY_B64}")
            .parse::<EncryptionSpec>()
            .unwrap();
        assert_encrypted_spec(spec, expected, &KEY_BYTES);
    }

    // Whitespace around the whole value, the mode, and the key is ignored.
    #[test]
    fn parse_header_aes_with_whitespace() {
        let spec = format!(" aes-256-gcm ; {KEY_B64} ")
            .parse::<EncryptionSpec>()
            .unwrap();
        assert_encrypted_spec(spec, EncryptionAlgorithm::Aes256Gcm, &KEY_BYTES);
    }

    // `plain` is case-insensitive, and an empty key after `;` counts as no key.
    #[rstest]
    #[case("plain")]
    #[case("PLAIN")]
    #[case("plain; ")]
    fn parse_header_plain_variants(#[case] header: &str) {
        let spec = header.parse::<EncryptionSpec>().unwrap();
        assert!(matches!(spec, EncryptionSpec::Plain));
    }

    // `EncryptionSpec::mode` maps each variant to the mode of the same name.
    #[test]
    fn spec_mode_matches_variant() {
        assert_eq!(EncryptionSpec::Plain.mode(), EncryptionMode::Plain);
        assert_eq!(
            EncryptionSpec::aegis256(KEY_BYTES).mode(),
            EncryptionMode::Aegis256
        );
        assert_eq!(
            EncryptionSpec::aes256_gcm(KEY_BYTES).mode(),
            EncryptionMode::Aes256Gcm
        );
    }

    // serde (JSON) names match the header spellings in both directions.
    #[rstest]
    #[case(EncryptionMode::Plain, "\"plain\"")]
    #[case(EncryptionMode::Aegis256, "\"aegis-256\"")]
    #[case(EncryptionMode::Aes256Gcm, "\"aes-256-gcm\"")]
    fn mode_serde_roundtrip(#[case] mode: EncryptionMode, #[case] expected: &str) {
        let serialized = serde_json::to_string(&mode).unwrap();
        assert_eq!(serialized, expected);
        let deserialized: EncryptionMode = serde_json::from_str(expected).unwrap();
        assert_eq!(deserialized, mode);
    }

    // `Display` output matches the spelling used in header values.
    #[rstest]
    #[case(EncryptionMode::Plain, "plain")]
    #[case(EncryptionMode::Aegis256, "aegis-256")]
    #[case(EncryptionMode::Aes256Gcm, "aes-256-gcm")]
    fn mode_display_matches_spec(#[case] mode: EncryptionMode, #[case] expected: &str) {
        assert_eq!(mode.to_string(), expected);
    }

    // Each malformed input maps to one specific error variant. Cases, in order:
    // - empty input;
    // - encrypted mode without a key;
    // - a second `;`;
    // - valid base64 that decodes to only 4 bytes (0xDEADBEEF);
    // - characters outside the base64 alphabet;
    // - an unrecognised mode;
    // - a key supplied with `plain`.
    #[rstest]
    #[case("", EncryptionSpecError::MissingMode)]
    #[case(
        "aegis-256",
        EncryptionSpecError::MissingKey {
            mode: EncryptionMode::Aegis256
        }
    )]
    #[case(
        "aegis-256; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=; extra",
        EncryptionSpecError::InvalidSyntax
    )]
    #[case(
        "aegis-256; 3q2+7w==",
        EncryptionSpecError::InvalidKeyLength {
            expected: 32,
            actual: 4
        }
    )]
    #[case(
        "aegis-256; not-valid-base64!!!",
        EncryptionSpecError::InvalidKeyBase64
    )]
    #[case(
        "bogus; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnknownMode {
            mode: "bogus".to_owned()
        }
    )]
    #[case(
        "plain; AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=",
        EncryptionSpecError::UnexpectedKeyForPlain
    )]
    fn parse_header_invalid_cases(#[case] header: &str, #[case] expected: EncryptionSpecError) {
        assert_invalid_parse(header, expected);
    }

    // Encrypted header values are flagged sensitive so the HTTP stack can treat them specially.
    #[test]
    fn header_value_is_sensitive() {
        let value = EncryptionSpec::aegis256([7; 32]).to_header_value();
        assert!(value.is_sensitive());
        assert_ne!(value, HeaderValue::from_static("plain"));
    }

    // `plain` renders as exactly "plain" (still marked sensitive) and parses back to Plain.
    #[test]
    fn plain_header_value_roundtrips() {
        let value = EncryptionSpec::Plain.to_header_value();
        assert_eq!(value.to_str().unwrap(), "plain");
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionSpec>().unwrap();
        assert!(matches!(parsed, EncryptionSpec::Plain));
    }

    // Encrypted specs render as "<algorithm>; <base64 key>" and parse back to the same key.
    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn encrypted_header_value_roundtrips(#[case] algorithm: EncryptionAlgorithm) {
        let value = match algorithm {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(KEY_BYTES),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(KEY_BYTES),
        }
        .to_header_value();
        assert_eq!(value.to_str().unwrap(), format!("{algorithm}; {KEY_B64}"));
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionSpec>().unwrap();
        assert_encrypted_spec(parsed, algorithm, &KEY_BYTES);
    }
}