strong_box/
static_strong_box.rs1use chacha20poly1305::{
2 ChaCha20Poly1305, KeyInit as _,
3 aead::{Aead as _, Payload},
4};
5use rand::{RngCore, rng};
6use std::{collections::HashMap, fmt::Debug};
7
8use super::{Error, Key, KeyId, StrongBox};
9
10#[derive(Clone, Debug)]
98pub struct StaticStrongBox {
99 encryption_key: Key,
100 encryption_key_id: KeyId,
101 decryption_keys: HashMap<KeyId, Key>,
102}
103
104impl StaticStrongBox {
105 #[tracing::instrument(level = "debug", skip(enc_key, dec_keys))]
107 pub fn new(
108 enc_key: impl Into<Key>,
109 dec_keys: impl IntoIterator<Item = impl Into<Key>>,
110 ) -> Self {
111 let mut key_map: HashMap<KeyId, Key> = HashMap::default();
112
113 for key in dec_keys.into_iter() {
114 let key = key.into();
115 let key_id = super::key_id(&key);
116 tracing::debug!(%key_id, "Including decryption key");
117 key_map.insert(key_id, key);
118 }
119
120 let enc_key = enc_key.into();
121 let enc_key_id = super::key_id(&enc_key);
122 tracing::debug!("Encryption key is {enc_key_id}");
123
124 Self {
125 encryption_key_id: enc_key_id,
126 encryption_key: enc_key,
127 decryption_keys: key_map,
128 }
129 }
130
131 pub(crate) fn decrypt_ciphertext(
132 &self,
133 ciphertext: &Ciphertext,
134 ctx: &[u8],
135 ) -> Result<Vec<u8>, Error> {
136 if let Some(key) = self.decryption_keys.get(&ciphertext.key_id) {
137 tracing::debug!(key_id=%ciphertext.key_id, "Decrypting");
138
139 let mut aad = Vec::<u8>::new();
140 aad.extend_from_slice(ctx.as_ref());
141 aad.extend_from_slice(ciphertext.key_id.as_bytes());
142 aad.extend_from_slice(&ciphertext.nonce);
143
144 let cipher = ChaCha20Poly1305::new(key.expose_secret().into());
145 let payload = Payload {
146 msg: &ciphertext.ciphertext,
147 aad: &aad,
148 };
149 cipher
150 .decrypt((&ciphertext.nonce[..]).into(), payload)
151 .map_err(|_| Error::Decryption)
152 } else {
153 tracing::debug!(key_id=%ciphertext.key_id, "Decryption key not found");
154 Err(Error::Decryption)
155 }
156 }
157}
158
159impl StrongBox for StaticStrongBox {
160 #[tracing::instrument(level = "debug", skip(plaintext))]
161 fn encrypt(
162 &self,
163 plaintext: impl AsRef<[u8]>,
164 ctx: impl AsRef<[u8]> + Debug,
165 ) -> Result<Vec<u8>, Error> {
166 let cipher = ChaCha20Poly1305::new((self.encryption_key.expose_secret()).into());
167 let mut rng = rng();
168 let mut nonce = [0u8; 12];
169 rng.fill_bytes(&mut nonce);
170
171 let mut aad = Vec::<u8>::new();
172 aad.extend_from_slice(ctx.as_ref());
173 aad.extend_from_slice(self.encryption_key_id.as_bytes());
174 aad.extend_from_slice(&nonce);
175
176 let ciphertext = cipher
177 .encrypt(
178 (&nonce).into(),
179 Payload {
180 msg: plaintext.as_ref(),
181 aad: &aad,
182 },
183 )
184 .map_err(|_| Error::Encryption)?;
185 tracing::debug!(key_id=%self.encryption_key_id, "Encrypting");
186
187 Ciphertext::new(self.encryption_key_id, nonce, ciphertext).to_bytes()
188 }
189
190 #[tracing::instrument(level = "debug", skip(ciphertext))]
191 fn decrypt(
192 &self,
193 ciphertext: impl AsRef<[u8]>,
194 ctx: impl AsRef<[u8]> + Debug,
195 ) -> Result<Vec<u8>, Error> {
196 let ciphertext = Ciphertext::try_from(ciphertext.as_ref())?;
197
198 self.decrypt_ciphertext(&ciphertext, ctx.as_ref())
199 }
200}
201
/// Three-byte marker written at the very start of every serialised ciphertext
/// and checked before any further parsing is attempted.
const CIPHERTEXT_MAGIC: [u8; 3] = *b"\xb1\xb8\xf5";
204
205#[derive(Clone, Debug)]
206pub(crate) struct Ciphertext {
207 pub(crate) key_id: KeyId,
208 pub(crate) nonce: [u8; 12],
209 pub(crate) ciphertext: Vec<u8>,
210}
211
212impl Ciphertext {
213 pub(crate) fn new(key_id: KeyId, nonce: [u8; 12], ciphertext: Vec<u8>) -> Self {
214 Self {
215 key_id,
216 nonce,
217 ciphertext,
218 }
219 }
220
221 pub(crate) fn to_bytes(&self) -> Result<Vec<u8>, Error> {
222 use ciborium_ll::{Encoder, Header};
223
224 let mut v: Vec<u8> = Vec::new();
225
226 v.extend_from_slice(&CIPHERTEXT_MAGIC);
227
228 let mut enc = Encoder::from(&mut v);
229 enc.push(Header::Array(Some(3)))
230 .map_err(|e| Error::ciphertext_encoding("key_id", e))?;
231 self.key_id.encode(&mut enc)?;
232 enc.bytes(&self.nonce, None)
233 .map_err(|e| Error::ciphertext_encoding("nonce", e))?;
234 enc.bytes(&self.ciphertext, None)
235 .map_err(|e| Error::ciphertext_encoding("ciphertext", e))?;
236
237 tracing::debug!(
238 nonce = self
239 .nonce
240 .iter()
241 .map(|i| format!("{i:02x}"))
242 .collect::<Vec<_>>()
243 .join(""),
244 ct = self
245 .ciphertext
246 .iter()
247 .map(|i| format!("{i:02x}"))
248 .collect::<Vec<_>>()
249 .join(""),
250 "{}",
251 v.iter()
252 .map(|i| format!("{i:02x}"))
253 .collect::<Vec<_>>()
254 .join("")
255 );
256 Ok(v)
257 }
258}
259
260impl TryFrom<&[u8]> for Ciphertext {
261 type Error = Error;
262
263 fn try_from(b: &[u8]) -> Result<Self, Self::Error> {
264 use ciborium_ll::{Decoder, Header};
265
266 if b.len() < 21 {
267 return Err(Error::invalid_ciphertext("too short"));
268 }
269
270 if b[0..3] != CIPHERTEXT_MAGIC {
271 tracing::debug!(magic=?CIPHERTEXT_MAGIC, actual=?b[0..3]);
272 return Err(Error::invalid_ciphertext("incorrect magic"));
273 }
274
275 let mut dec = Decoder::from(&b[3..]);
276
277 let Header::Array(Some(3)) = dec
278 .pull()
279 .map_err(|e| Error::ciphertext_decoding("array", e))?
280 else {
281 return Err(Error::invalid_ciphertext("expected array"));
282 };
283
284 let key_id = KeyId::decode(&mut dec)?;
285
286 let Header::Bytes(len) = dec
288 .pull()
289 .map_err(|e| Error::ciphertext_decoding("nonce header", e))?
290 else {
291 return Err(Error::invalid_ciphertext("expected nonce"));
292 };
293
294 let mut segments = dec.bytes(len);
295
296 let Ok(Some(mut segment)) = segments.pull() else {
297 return Err(Error::invalid_ciphertext("bad nonce"));
298 };
299
300 let mut buf = [0u8; 1024];
301 let mut nonce = [0u8; 12];
302
303 if let Some(chunk) = segment
304 .pull(&mut buf[..])
305 .map_err(|e| Error::ciphertext_decoding("nonce", e))?
306 {
307 nonce[..].copy_from_slice(chunk);
309 } else {
310 return Err(Error::invalid_ciphertext("short nonce"));
311 }
312
313 let Header::Bytes(len) = dec
315 .pull()
316 .map_err(|e| Error::ciphertext_decoding("ciphertext header", e))?
317 else {
318 return Err(Error::invalid_ciphertext("expected ciphertext"));
319 };
320
321 let mut segments = dec.bytes(len);
322
323 let Ok(Some(mut segment)) = segments.pull() else {
324 return Err(Error::invalid_ciphertext("bad ciphertext"));
325 };
326
327 let mut ciphertext: Vec<u8> = Vec::new();
328
329 while let Some(chunk) = segment
330 .pull(&mut buf[..])
331 .map_err(|e| Error::ciphertext_decoding("ciphertext", e))?
332 {
333 ciphertext.extend_from_slice(chunk);
334 }
335
336 Ok(Self {
337 key_id,
338 nonce,
339 ciphertext,
340 })
341 }
342}