1use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::{Base64, Decoder, Encoding};
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, SecretString, zeroize::Zeroize};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
/// Name of the HTTP header (`s2-encryption-key`) that carries the Base64-encoded
/// [`EncryptionKey`].
pub static S2_ENCRYPTION_KEY_HEADER: HeaderName = HeaderName::from_static("s2-encryption-key");
14
/// Upper bound, in bytes, on the trimmed key text accepted by `EncryptionKey::from_str`.
///
/// This is the padded Base64 length of a 32-byte key: `ceil(32 / 3) * 4 == 44`.
const MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN: usize = (32 + 2) / 3 * 4;
17
18type EncodedKeyMaterial = Arc<SecretString>;
19type DecodedKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
20
21#[derive(
23 Debug,
24 Clone,
25 Copy,
26 PartialEq,
27 Eq,
28 Hash,
29 serde::Serialize,
30 serde::Deserialize,
31 Display,
32 EnumString,
33)]
34#[strum(ascii_case_insensitive)]
35#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
36pub enum EncryptionAlgorithm {
37 #[strum(serialize = "aegis-256")]
39 #[serde(rename = "aegis-256")]
40 #[cfg_attr(feature = "clap", value(name = "aegis-256"))]
41 Aegis256,
42 #[strum(serialize = "aes-256-gcm")]
44 #[serde(rename = "aes-256-gcm")]
45 #[cfg_attr(feature = "clap", value(name = "aes-256-gcm"))]
46 Aes256Gcm,
47}
48
49#[derive(Debug, Clone)]
51pub struct EncryptionKey(EncodedKeyMaterial);
52
53impl EncryptionKey {
54 pub fn new<const N: usize>(key: [u8; N]) -> Self {
55 Self(Arc::new(Base64::encode_string(&key).into()))
56 }
57
58 pub(crate) fn expose_secret(&self) -> &str {
59 self.0.expose_secret()
60 }
61
62 pub fn to_header_value(&self) -> HeaderValue {
63 let mut value = HeaderValue::from_bytes(self.expose_secret().as_bytes())
64 .expect("encryption key header value should be ASCII");
65 value.set_sensitive(true);
66 value
67 }
68}
69
70#[derive(Debug, Clone)]
72pub struct DecodedEncryptionKey<const N: usize>(DecodedKey<N>);
73
74impl<const N: usize> DecodedEncryptionKey<N> {
75 pub fn new(key: [u8; N]) -> Self {
76 Self(Arc::new(SecretBox::new(Box::new(key))))
77 }
78}
79
80impl<const N: usize> ExposeSecret<[u8; N]> for DecodedEncryptionKey<N> {
81 fn expose_secret(&self) -> &[u8; N] {
82 self.0.expose_secret()
83 }
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
87#[error("invalid encryption key: key material length {0} is out of range")]
88pub struct EncryptionKeyLengthError(usize);
89
90#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
91pub enum EncryptionSpecResolutionError {
92 #[error("missing encryption key for stream cipher '{cipher}'")]
93 MissingKey { cipher: EncryptionAlgorithm },
94 #[error("invalid encryption key for stream cipher '{cipher}': invalid base64")]
95 InvalidBase64 { cipher: EncryptionAlgorithm },
96 #[error("invalid encryption key length for stream cipher '{cipher}': {length}")]
97 InvalidKeyLength {
98 cipher: EncryptionAlgorithm,
99 length: usize,
100 },
101}
102
103#[rustfmt::skip]
105#[derive(Debug, Clone, Default)]
106pub enum EncryptionSpec {
107 #[default]
108 Plain,
109 Aegis256(DecodedEncryptionKey<32>),
110 Aes256Gcm(DecodedEncryptionKey<32>),
111}
112
113impl EncryptionSpec {
114 pub fn resolve(
115 cipher: Option<EncryptionAlgorithm>,
116 key: Option<EncryptionKey>,
117 ) -> Result<Self, EncryptionSpecResolutionError> {
118 match (cipher, key) {
119 (None, _) => Ok(Self::Plain),
120 (Some(cipher @ EncryptionAlgorithm::Aegis256), Some(key)) => {
121 Ok(Self::Aegis256(resolve_key(cipher, key)?))
122 }
123 (Some(cipher @ EncryptionAlgorithm::Aes256Gcm), Some(key)) => {
124 Ok(Self::Aes256Gcm(resolve_key(cipher, key)?))
125 }
126 (Some(cipher), None) => Err(EncryptionSpecResolutionError::MissingKey { cipher }),
127 }
128 }
129
130 pub fn aegis256(key: [u8; 32]) -> Self {
131 Self::Aegis256(DecodedEncryptionKey::new(key))
132 }
133
134 pub fn aes256_gcm(key: [u8; 32]) -> Self {
135 Self::Aes256Gcm(DecodedEncryptionKey::new(key))
136 }
137}
138
139impl FromStr for EncryptionKey {
140 type Err = EncryptionKeyLengthError;
141
142 fn from_str(s: &str) -> Result<Self, Self::Err> {
143 let trimmed = s.trim();
144 if (1..=MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN).contains(&trimmed.len()) {
145 Ok(Self(Arc::new(trimmed.to_owned().into())))
146 } else {
147 Err(EncryptionKeyLengthError(trimmed.len()))
148 }
149 }
150}
151
152impl ParseableHeader for EncryptionKey {
153 fn name() -> &'static HeaderName {
154 &S2_ENCRYPTION_KEY_HEADER
155 }
156}
157
/// Decodes a Base64 [`EncryptionKey`] into exactly `N` raw key bytes for `cipher`.
///
/// The bytes are decoded directly into a heap buffer, which becomes the returned
/// [`DecodedEncryptionKey`] on success. A plain `Box<[u8; N]>` does not clear itself, so
/// every error path zeroizes that buffer before dropping it.
///
/// # Errors
///
/// - [`EncryptionSpecResolutionError::InvalidBase64`] if the Base64 decoder cannot be
///   constructed over the key text, or decoding reports an invalid encoding.
/// - [`EncryptionSpecResolutionError::InvalidKeyLength`] if the key decodes to fewer or
///   more than `N` bytes. `length` reports the decoded length.
///
/// # Panics
///
/// Panics only if `N` plus the decoder's remaining length overflows `usize`, which the
/// `expect` treats as impossible.
fn resolve_key<const N: usize>(
    cipher: EncryptionAlgorithm,
    key: EncryptionKey,
) -> Result<DecodedEncryptionKey<N>, EncryptionSpecResolutionError> {
    // Decoder construction is fallible; any failure there is reported as invalid Base64.
    let mut decoder = Decoder::<Base64>::new(key.expose_secret().as_bytes())
        .map_err(|_| EncryptionSpecResolutionError::InvalidBase64 { cipher })?;
    let mut key_material = Box::new([0u8; N]);
    match decoder.decode(key_material.as_mut()) {
        // All N bytes were filled and no input is left over, so the key is exactly N bytes.
        Ok(_) if decoder.is_finished() => {
            Ok(DecodedEncryptionKey(Arc::new(SecretBox::new(key_material))))
        }
        // N bytes were filled but input remains, so the key is too long. The reported
        // length is N plus whatever the decoder still has left to decode.
        Ok(_) => {
            let length = N
                .checked_add(decoder.remaining_len())
                .expect("decoded key length should fit usize");
            key_material.as_mut().zeroize();
            Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
        }
        Err(base64ct::Error::InvalidEncoding) => {
            key_material.as_mut().zeroize();
            Err(EncryptionSpecResolutionError::InvalidBase64 { cipher })
        }
        // Presumably the input ran out before N bytes were filled, i.e. the key is too short.
        // NOTE(review): this relies on `remaining_len()` still reporting the full decoded
        // length after a failed `decode`. The `length: 4` unit test depends on that
        // behaviour; re-check it when upgrading base64ct.
        Err(base64ct::Error::InvalidLength) => {
            let length = decoder.remaining_len();
            key_material.as_mut().zeroize();
            Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
        }
    }
}
187
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Standard padded Base64 encoding of [`KEY_BYTES`].
    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    /// A 32-byte key with distinct, easily recognisable byte values.
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Resolves `key` for `cipher`; shorthand for the encrypted (non-plain) path.
    fn resolve_encrypted(
        cipher: EncryptionAlgorithm,
        key: EncryptionKey,
    ) -> Result<EncryptionSpec, EncryptionSpecResolutionError> {
        EncryptionSpec::resolve(Some(cipher), Some(key))
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256, "aegis-256")]
    #[case(EncryptionAlgorithm::Aes256Gcm, "aes-256-gcm")]
    fn algorithm_names_roundtrip_case_insensitively(
        #[case] cipher: EncryptionAlgorithm,
        #[case] name: &str,
    ) {
        assert_eq!(cipher.to_string(), name);
        assert_eq!(name.parse::<EncryptionAlgorithm>().unwrap(), cipher);
        assert_eq!(
            name.to_ascii_uppercase()
                .parse::<EncryptionAlgorithm>()
                .unwrap(),
            cipher
        );
    }

    #[test]
    fn key_header_value_roundtrips_and_is_sensitive() {
        let value = EncryptionKey::new(KEY_BYTES).to_header_value();
        assert_eq!(value.to_str().unwrap(), KEY_B64);
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);
    }

    #[test]
    fn debug_output_redacts_key_material() {
        let encoded = format!("{:?}", EncryptionKey::new(KEY_BYTES));
        assert!(!encoded.contains(KEY_B64));

        let decoded = format!("{:?}", EncryptionSpec::aegis256(KEY_BYTES));
        assert!(!decoded.contains("1, 2, 3, 4"));
    }

    #[test]
    fn encryption_key_parsing_trims_and_enforces_bounds() {
        let parsed = format!(" {KEY_B64}\n").parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);

        assert_eq!(
            " ".parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(0)
        );

        let too_long = "A".repeat(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1);
        assert_eq!(
            too_long.parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1)
        );
    }

    #[test]
    fn resolve_plain_ignores_supplied_key() {
        let encryption = EncryptionSpec::resolve(None, Some("!!!!".parse().unwrap())).unwrap();
        assert!(matches!(encryption, EncryptionSpec::Plain));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_requires_key(#[case] cipher: EncryptionAlgorithm) {
        let err = EncryptionSpec::resolve(Some(cipher), None).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::MissingKey { cipher });
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_decodes_key_for_each_algorithm(#[case] cipher: EncryptionAlgorithm) {
        let encryption = resolve_encrypted(cipher, EncryptionKey::new(KEY_BYTES)).unwrap();

        match (cipher, encryption) {
            (EncryptionAlgorithm::Aegis256, EncryptionSpec::Aegis256(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            (EncryptionAlgorithm::Aes256Gcm, EncryptionSpec::Aes256Gcm(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            _ => panic!("resolved encryption spec did not match requested algorithm"),
        }
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_invalid_base64(#[case] cipher: EncryptionAlgorithm) {
        let err = resolve_encrypted(cipher, "!!!!".parse().unwrap()).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::InvalidBase64 { cipher });
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_non_32_byte_keys(#[case] cipher: EncryptionAlgorithm) {
        let short_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 4])).unwrap_err();
        assert_eq!(
            short_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 4 }
        );

        let long_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 33])).unwrap_err();
        assert_eq!(
            long_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 33 }
        );
    }
}