1use core::str::FromStr;
4use std::sync::Arc;
5
6use base64ct::{Base64, Decoder, Encoding};
7use http::{HeaderName, HeaderValue};
8use secrecy::{ExposeSecret, SecretBox, SecretString, zeroize::Zeroize};
9use strum::{Display, EnumString};
10
11use crate::http::ParseableHeader;
12
13pub static S2_ENCRYPTION_KEY_HEADER: HeaderName = HeaderName::from_static("s2-encryption-key");
14
15const MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN: usize = 44;
17
18type EncodedKeyMaterial = Arc<SecretString>;
19type DecodedKey<const N: usize> = Arc<SecretBox<[u8; N]>>;
20
21#[derive(
23 Debug,
24 Clone,
25 Copy,
26 PartialEq,
27 Eq,
28 Hash,
29 serde::Serialize,
30 serde::Deserialize,
31 Display,
32 EnumString,
33)]
34#[strum(ascii_case_insensitive)]
35#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
36pub enum EncryptionAlgorithm {
37 #[strum(serialize = "aegis-256")]
39 #[serde(rename = "aegis-256")]
40 #[cfg_attr(feature = "clap", value(name = "aegis-256"))]
41 Aegis256,
42 #[strum(serialize = "aes-256-gcm")]
44 #[serde(rename = "aes-256-gcm")]
45 #[cfg_attr(feature = "clap", value(name = "aes-256-gcm"))]
46 Aes256Gcm,
47}
48
49#[derive(Debug, Clone)]
51pub struct EncryptionKey(EncodedKeyMaterial);
52
53impl EncryptionKey {
54 pub fn new<const N: usize>(key: [u8; N]) -> Self {
55 Self(Arc::new(Base64::encode_string(&key).into()))
56 }
57
58 pub(crate) fn expose_secret(&self) -> &str {
59 self.0.expose_secret()
60 }
61
62 pub fn to_header_value(&self) -> HeaderValue {
63 let mut value = HeaderValue::from_bytes(self.expose_secret().as_bytes())
64 .expect("encryption key header value should be ASCII");
65 value.set_sensitive(true);
66 value
67 }
68}
69
70#[derive(Debug, Clone)]
72pub struct DecodedEncryptionKey<const N: usize>(DecodedKey<N>);
73
74impl<const N: usize> DecodedEncryptionKey<N> {
75 pub fn new(key: [u8; N]) -> Self {
76 Self(Arc::new(SecretBox::new(Box::new(key))))
77 }
78
79 pub(crate) fn expose_secret(&self) -> &[u8; N] {
80 self.0.expose_secret()
81 }
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
85#[error("invalid encryption key: key material length {0} is out of range")]
86pub struct EncryptionKeyLengthError(usize);
87
88#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
89pub enum EncryptionSpecResolutionError {
90 #[error("missing encryption key for stream cipher '{cipher}'")]
91 MissingKey { cipher: EncryptionAlgorithm },
92 #[error("invalid encryption key for stream cipher '{cipher}': invalid base64")]
93 InvalidBase64 { cipher: EncryptionAlgorithm },
94 #[error("invalid encryption key length for stream cipher '{cipher}': {length}")]
95 InvalidKeyLength {
96 cipher: EncryptionAlgorithm,
97 length: usize,
98 },
99}
100
101#[rustfmt::skip]
103#[derive(Debug, Clone, Default)]
104pub enum EncryptionSpec {
105 #[default]
106 Plain,
107 Aegis256(DecodedEncryptionKey<32>),
108 Aes256Gcm(DecodedEncryptionKey<32>),
109}
110
111impl EncryptionSpec {
112 pub fn resolve(
113 cipher: Option<EncryptionAlgorithm>,
114 key: Option<EncryptionKey>,
115 ) -> Result<Self, EncryptionSpecResolutionError> {
116 match (cipher, key) {
117 (None, _) => Ok(Self::Plain),
118 (Some(cipher @ EncryptionAlgorithm::Aegis256), Some(key)) => {
119 Ok(Self::Aegis256(resolve_key(cipher, key)?))
120 }
121 (Some(cipher @ EncryptionAlgorithm::Aes256Gcm), Some(key)) => {
122 Ok(Self::Aes256Gcm(resolve_key(cipher, key)?))
123 }
124 (Some(cipher), None) => Err(EncryptionSpecResolutionError::MissingKey { cipher }),
125 }
126 }
127
128 pub fn aegis256(key: [u8; 32]) -> Self {
129 Self::Aegis256(DecodedEncryptionKey::new(key))
130 }
131
132 pub fn aes256_gcm(key: [u8; 32]) -> Self {
133 Self::Aes256Gcm(DecodedEncryptionKey::new(key))
134 }
135}
136
137impl FromStr for EncryptionKey {
138 type Err = EncryptionKeyLengthError;
139
140 fn from_str(s: &str) -> Result<Self, Self::Err> {
141 let trimmed = s.trim();
142 if (1..=MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN).contains(&trimmed.len()) {
143 Ok(Self(Arc::new(trimmed.to_owned().into())))
144 } else {
145 Err(EncryptionKeyLengthError(trimmed.len()))
146 }
147 }
148}
149
150impl ParseableHeader for EncryptionKey {
151 fn name() -> &'static HeaderName {
152 &S2_ENCRYPTION_KEY_HEADER
153 }
154}
155
156fn resolve_key<const N: usize>(
157 cipher: EncryptionAlgorithm,
158 key: EncryptionKey,
159) -> Result<DecodedEncryptionKey<N>, EncryptionSpecResolutionError> {
160 let mut decoder = Decoder::<Base64>::new(key.expose_secret().as_bytes())
161 .map_err(|_| EncryptionSpecResolutionError::InvalidBase64 { cipher })?;
162 let mut key_material = Box::new([0u8; N]);
163 match decoder.decode(key_material.as_mut()) {
164 Ok(_) if decoder.is_finished() => {
165 Ok(DecodedEncryptionKey(Arc::new(SecretBox::new(key_material))))
166 }
167 Ok(_) => {
168 let length = N
169 .checked_add(decoder.remaining_len())
170 .expect("decoded key length should fit usize");
171 key_material.as_mut().zeroize();
172 Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
173 }
174 Err(base64ct::Error::InvalidEncoding) => {
175 key_material.as_mut().zeroize();
176 Err(EncryptionSpecResolutionError::InvalidBase64 { cipher })
177 }
178 Err(base64ct::Error::InvalidLength) => {
179 let length = decoder.remaining_len();
180 key_material.as_mut().zeroize();
181 Err(EncryptionSpecResolutionError::InvalidKeyLength { cipher, length })
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    const KEY_B64: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA=";
    const KEY_BYTES: [u8; 32] = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
        26, 27, 28, 29, 30, 31, 32,
    ];

    /// Resolves `key` for `cipher`, as a caller that explicitly requested encryption would.
    fn resolve_encrypted(
        cipher: EncryptionAlgorithm,
        key: EncryptionKey,
    ) -> Result<EncryptionSpec, EncryptionSpecResolutionError> {
        EncryptionSpec::resolve(Some(cipher), Some(key))
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256, "aegis-256")]
    #[case(EncryptionAlgorithm::Aes256Gcm, "aes-256-gcm")]
    fn algorithm_names_display_and_parse_case_insensitively(
        #[case] cipher: EncryptionAlgorithm,
        #[case] name: &str,
    ) {
        assert_eq!(cipher.to_string(), name);
        assert_eq!(name.parse::<EncryptionAlgorithm>().unwrap(), cipher);
        assert_eq!(
            name.to_ascii_uppercase()
                .parse::<EncryptionAlgorithm>()
                .unwrap(),
            cipher
        );
    }

    #[test]
    fn key_header_value_roundtrips_and_is_sensitive() {
        let value = EncryptionKey::new(KEY_BYTES).to_header_value();
        assert_eq!(value.to_str().unwrap(), KEY_B64);
        assert!(value.is_sensitive());

        let parsed = value.to_str().unwrap().parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);
    }

    #[test]
    fn key_material_is_redacted_in_debug_output() {
        let key = EncryptionKey::new(KEY_BYTES);
        assert!(!format!("{key:?}").contains(KEY_B64));

        let spec = EncryptionSpec::aegis256(KEY_BYTES);
        assert!(!format!("{spec:?}").contains("1, 2, 3, 4"));
    }

    #[test]
    fn encryption_key_parsing_trims_and_enforces_bounds() {
        let parsed = format!(" {KEY_B64}\n").parse::<EncryptionKey>().unwrap();
        assert_eq!(parsed.to_header_value().to_str().unwrap(), KEY_B64);

        assert_eq!(
            " ".parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(0)
        );

        let too_long = "A".repeat(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1);
        assert_eq!(
            too_long.parse::<EncryptionKey>().unwrap_err(),
            EncryptionKeyLengthError(MAX_ENCRYPTION_KEY_HEADER_VALUE_LEN + 1)
        );
    }

    #[test]
    fn default_spec_is_plain() {
        assert!(matches!(EncryptionSpec::default(), EncryptionSpec::Plain));
    }

    #[test]
    fn resolve_plain_ignores_supplied_key() {
        let encryption = EncryptionSpec::resolve(None, Some("!!!!".parse().unwrap())).unwrap();
        assert!(matches!(encryption, EncryptionSpec::Plain));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_requires_key(#[case] cipher: EncryptionAlgorithm) {
        let err = EncryptionSpec::resolve(Some(cipher), None).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::MissingKey { cipher });
    }

    #[test]
    fn resolution_errors_name_the_cipher() {
        let err = EncryptionSpecResolutionError::MissingKey {
            cipher: EncryptionAlgorithm::Aes256Gcm,
        };
        assert_eq!(
            err.to_string(),
            "missing encryption key for stream cipher 'aes-256-gcm'"
        );
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_decodes_key_for_each_algorithm(#[case] cipher: EncryptionAlgorithm) {
        let encryption = resolve_encrypted(cipher, EncryptionKey::new(KEY_BYTES)).unwrap();

        match (cipher, encryption) {
            (EncryptionAlgorithm::Aegis256, EncryptionSpec::Aegis256(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            (EncryptionAlgorithm::Aes256Gcm, EncryptionSpec::Aes256Gcm(key)) => {
                assert_eq!(key.expose_secret(), &KEY_BYTES);
            }
            _ => panic!("resolved encryption spec did not match requested algorithm"),
        }
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_invalid_base64(#[case] cipher: EncryptionAlgorithm) {
        let err = resolve_encrypted(cipher, "!!!!".parse().unwrap()).unwrap_err();
        assert_eq!(err, EncryptionSpecResolutionError::InvalidBase64 { cipher });
    }

    // Both ciphers share `resolve_key::<32>`, but each resolution arm is exercised so a
    // future per-cipher key size cannot silently skip the length check.
    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn resolve_encrypted_rejects_non_32_byte_keys(#[case] cipher: EncryptionAlgorithm) {
        let short_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 4])).unwrap_err();
        assert_eq!(
            short_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 4 }
        );

        let long_err = resolve_encrypted(cipher, EncryptionKey::new([0x42; 33])).unwrap_err();
        assert_eq!(
            long_err,
            EncryptionSpecResolutionError::InvalidKeyLength { cipher, length: 33 }
        );
    }
}