solana_zk_sdk/encryption/
auth_encryption.rs1#[cfg(target_arch = "wasm32")]
7use wasm_bindgen::prelude::*;
8use {
9 crate::{
10 encryption::{AE_CIPHERTEXT_LEN, AE_KEY_LEN},
11 errors::AuthenticatedEncryptionError,
12 },
13 aes_gcm_siv::{
14 aead::{Aead, KeyInit},
15 Aes128GcmSiv,
16 },
17 base64::{prelude::BASE64_STANDARD, Engine},
18 rand::{rngs::OsRng, Rng},
19 std::{convert::TryInto, fmt},
20 zeroize::Zeroize,
21};
22#[cfg(not(target_arch = "wasm32"))]
27use {
28 sha3::Digest,
29 sha3::Sha3_512,
30 solana_derivation_path::DerivationPath,
31 solana_seed_derivable::SeedDerivable,
32 solana_seed_phrase::generate_seed_from_seed_phrase_and_passphrase,
33 solana_signature::Signature,
34 solana_signer::{EncodableKey, Signer, SignerError},
35 std::{
36 error,
37 io::{Read, Write},
38 },
39 subtle::ConstantTimeEq,
40};
41
/// Byte length of the random nonce sampled for every AES-GCM-SIV encryption.
const NONCE_LEN: usize = 12;

/// Byte length of an encrypted balance: the 8-byte little-endian `u64`
/// plaintext followed by the 16-byte AES-GCM-SIV authentication tag.
const CIPHERTEXT_LEN: usize = std::mem::size_of::<u64>() + 16;
47
48struct AuthenticatedEncryption;
49impl AuthenticatedEncryption {
50 fn keygen() -> AeKey {
54 AeKey(OsRng.gen::<[u8; AE_KEY_LEN]>())
55 }
56
57 fn encrypt(key: &AeKey, balance: u64) -> AeCiphertext {
60 let mut plaintext = balance.to_le_bytes();
61 let nonce: Nonce = OsRng.gen::<[u8; NONCE_LEN]>();
62
63 let ciphertext = Aes128GcmSiv::new(&key.0.into())
65 .encrypt(&nonce.into(), plaintext.as_ref())
66 .expect("authenticated encryption");
67
68 plaintext.zeroize();
69
70 AeCiphertext {
71 nonce,
72 ciphertext: ciphertext.try_into().unwrap(),
73 }
74 }
75
76 fn decrypt(key: &AeKey, ciphertext: &AeCiphertext) -> Option<u64> {
79 let plaintext = Aes128GcmSiv::new(&key.0.into())
80 .decrypt(&ciphertext.nonce.into(), ciphertext.ciphertext.as_ref());
81
82 if let Ok(plaintext) = plaintext {
83 let amount_bytes: [u8; 8] = plaintext.try_into().unwrap();
84 Some(u64::from_le_bytes(amount_bytes))
85 } else {
86 None
87 }
88 }
89}
90
91#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
92#[derive(Clone, Debug, Zeroize, Eq, PartialEq)]
93pub struct AeKey([u8; AE_KEY_LEN]);
94#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
95impl AeKey {
96 #[cfg_attr(target_arch = "wasm32", wasm_bindgen(js_name = newRand))]
100 pub fn new_rand() -> Self {
101 AuthenticatedEncryption::keygen()
102 }
103
104 pub fn encrypt(&self, amount: u64) -> AeCiphertext {
106 AuthenticatedEncryption::encrypt(self, amount)
107 }
108
109 pub fn decrypt(&self, ciphertext: &AeCiphertext) -> Option<u64> {
110 AuthenticatedEncryption::decrypt(self, ciphertext)
111 }
112}
113
114#[cfg(not(target_arch = "wasm32"))]
115impl AeKey {
116 pub fn new_from_signer(
123 signer: &dyn Signer,
124 public_seed: &[u8],
125 ) -> Result<Self, Box<dyn error::Error>> {
126 let seed = Self::seed_from_signer(signer, public_seed)?;
127 Self::from_seed(&seed)
128 }
129
130 pub fn seed_from_signer(
134 signer: &dyn Signer,
135 public_seed: &[u8],
136 ) -> Result<Vec<u8>, SignerError> {
137 let message = [b"AeKey", public_seed].concat();
140 let signature = signer.try_sign_message(&message)?;
141
142 if bool::from(signature.as_ref().ct_eq(Signature::default().as_ref())) {
145 return Err(SignerError::Custom("Rejecting default signature".into()));
146 }
147
148 Ok(Self::seed_from_signature(&signature))
149 }
150
151 pub fn new_from_signature(signature: &Signature) -> Result<Self, Box<dyn error::Error>> {
153 let seed = Self::seed_from_signature(signature);
154 Self::from_seed(&seed)
155 }
156
157 pub fn seed_from_signature(signature: &Signature) -> Vec<u8> {
159 let mut hasher = Sha3_512::new();
160 hasher.update(signature);
161 let result = hasher.finalize();
162
163 result.to_vec()
164 }
165}
166
167#[cfg(not(target_arch = "wasm32"))]
168impl EncodableKey for AeKey {
169 fn read<R: Read>(reader: &mut R) -> Result<Self, Box<dyn error::Error>> {
170 let bytes: [u8; AE_KEY_LEN] = serde_json::from_reader(reader)?;
171 Ok(Self(bytes))
172 }
173
174 fn write<W: Write>(&self, writer: &mut W) -> Result<String, Box<dyn error::Error>> {
175 let bytes = self.0;
176 let json = serde_json::to_string(&bytes.to_vec())?;
177 writer.write_all(&json.clone().into_bytes())?;
178 Ok(json)
179 }
180}
181
182#[cfg(not(target_arch = "wasm32"))]
183impl SeedDerivable for AeKey {
184 fn from_seed(seed: &[u8]) -> Result<Self, Box<dyn error::Error>> {
185 const MINIMUM_SEED_LEN: usize = AE_KEY_LEN;
186 const MAXIMUM_SEED_LEN: usize = 65535;
187
188 if seed.len() < MINIMUM_SEED_LEN {
189 return Err(AuthenticatedEncryptionError::SeedLengthTooShort.into());
190 }
191 if seed.len() > MAXIMUM_SEED_LEN {
192 return Err(AuthenticatedEncryptionError::SeedLengthTooLong.into());
193 }
194
195 let mut hasher = Sha3_512::new();
196 hasher.update(seed);
197 let result = hasher.finalize();
198
199 Ok(Self(result[..AE_KEY_LEN].try_into()?))
200 }
201
202 fn from_seed_and_derivation_path(
203 _seed: &[u8],
204 _derivation_path: Option<DerivationPath>,
205 ) -> Result<Self, Box<dyn error::Error>> {
206 Err(AuthenticatedEncryptionError::DerivationMethodNotSupported.into())
207 }
208
209 fn from_seed_phrase_and_passphrase(
210 seed_phrase: &str,
211 passphrase: &str,
212 ) -> Result<Self, Box<dyn error::Error>> {
213 Self::from_seed(&generate_seed_from_seed_phrase_and_passphrase(
214 seed_phrase,
215 passphrase,
216 ))
217 }
218}
219
220impl From<[u8; AE_KEY_LEN]> for AeKey {
221 fn from(bytes: [u8; AE_KEY_LEN]) -> Self {
222 Self(bytes)
223 }
224}
225
226impl From<AeKey> for [u8; AE_KEY_LEN] {
227 fn from(key: AeKey) -> Self {
228 key.0
229 }
230}
231
232impl TryFrom<&[u8]> for AeKey {
233 type Error = AuthenticatedEncryptionError;
234 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
235 if bytes.len() != AE_KEY_LEN {
236 return Err(AuthenticatedEncryptionError::Deserialization);
237 }
238 bytes
239 .try_into()
240 .map(Self)
241 .map_err(|_| AuthenticatedEncryptionError::Deserialization)
242 }
243}
244
245type Nonce = [u8; NONCE_LEN];
248type Ciphertext = [u8; CIPHERTEXT_LEN];
249
250#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
252#[derive(Clone, Copy, Debug, Default)]
253pub struct AeCiphertext {
254 nonce: Nonce,
255 ciphertext: Ciphertext,
256}
257impl AeCiphertext {
258 pub fn decrypt(&self, key: &AeKey) -> Option<u64> {
259 AuthenticatedEncryption::decrypt(key, self)
260 }
261
262 pub fn to_bytes(&self) -> [u8; AE_CIPHERTEXT_LEN] {
263 let mut buf = [0_u8; AE_CIPHERTEXT_LEN];
264 buf[..NONCE_LEN].copy_from_slice(&self.nonce);
265 buf[NONCE_LEN..].copy_from_slice(&self.ciphertext);
266 buf
267 }
268
269 pub fn from_bytes(bytes: &[u8]) -> Option<AeCiphertext> {
270 if bytes.len() != AE_CIPHERTEXT_LEN {
271 return None;
272 }
273
274 let nonce = bytes[..NONCE_LEN].try_into().ok()?;
275 let ciphertext = bytes[NONCE_LEN..].try_into().ok()?;
276
277 Some(AeCiphertext { nonce, ciphertext })
278 }
279}
280
281impl fmt::Display for AeCiphertext {
282 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
283 write!(f, "{}", BASE64_STANDARD.encode(self.to_bytes()))
284 }
285}
286
#[cfg(test)]
mod tests {
    use {
        super::*, solana_keypair::Keypair, solana_pubkey::Pubkey,
        solana_signer::null_signer::NullSigner,
    };

    #[test]
    fn test_aes_encrypt_decrypt_correctness() {
        let key = AeKey::new_rand();
        let amount = 55;

        let ciphertext = key.encrypt(amount);
        let decrypted_amount = ciphertext.decrypt(&key).unwrap();

        assert_eq!(amount, decrypted_amount);
    }

    #[test]
    fn test_aes_encrypt_decrypt_boundary_amounts() {
        let key = AeKey::new_rand();
        for amount in [0_u64, 1, u64::MAX] {
            let ciphertext = key.encrypt(amount);
            assert_eq!(ciphertext.decrypt(&key), Some(amount));
        }
    }

    #[test]
    fn test_aes_new() {
        let keypair1 = Keypair::new();
        let keypair2 = Keypair::new();

        assert_ne!(
            AeKey::new_from_signer(&keypair1, Pubkey::default().as_ref())
                .unwrap()
                .0,
            AeKey::new_from_signer(&keypair2, Pubkey::default().as_ref())
                .unwrap()
                .0,
        );

        // A signer that only ever produces the default signature must be rejected.
        let null_signer = NullSigner::new(&Pubkey::default());
        assert!(AeKey::new_from_signer(&null_signer, Pubkey::default().as_ref()).is_err());
    }

    #[test]
    fn test_aes_key_from_seed() {
        let good_seed = vec![0; 32];
        assert!(AeKey::from_seed(&good_seed).is_ok());

        let too_short_seed = vec![0; 15];
        assert!(AeKey::from_seed(&too_short_seed).is_err());

        let too_long_seed = vec![0; 65536];
        assert!(AeKey::from_seed(&too_long_seed).is_err());
    }

    #[test]
    fn test_aes_key_from() {
        let key = AeKey::from_seed(&[0; 32]).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&[0; 32]).unwrap().into();

        assert_eq!(key, AeKey::from(key_bytes));
    }

    #[test]
    fn test_aes_key_try_from() {
        let key = AeKey::from_seed(&[0; 32]).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&[0; 32]).unwrap().into();

        assert_eq!(key, AeKey::try_from(key_bytes.as_slice()).unwrap());
    }

    #[test]
    fn test_aes_key_try_from_error() {
        let too_short_bytes = vec![0_u8; AE_KEY_LEN - 1];
        assert!(AeKey::try_from(too_short_bytes.as_slice()).is_err());

        let too_many_bytes = vec![0_u8; AE_KEY_LEN + 1];
        assert!(AeKey::try_from(too_many_bytes.as_slice()).is_err());
    }

    #[test]
    fn test_encodable_key_roundtrip() {
        let key = AeKey::new_rand();
        let mut buffer: Vec<u8> = Vec::new();

        let json = EncodableKey::write(&key, &mut buffer).unwrap();
        // The returned JSON must be exactly what was written.
        assert_eq!(json.as_bytes(), buffer.as_slice());

        let restored = <AeKey as EncodableKey>::read(&mut buffer.as_slice()).unwrap();
        assert_eq!(key, restored);
    }

    #[test]
    fn test_ciphertext_bytes_roundtrip() {
        let key = AeKey::new_rand();
        let ciphertext = key.encrypt(77);
        let bytes = ciphertext.to_bytes();

        let parsed = AeCiphertext::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.to_bytes(), bytes);
        assert_eq!(parsed.decrypt(&key), Some(77));
    }

    #[test]
    fn test_ciphertext_from_bytes_rejects_wrong_length() {
        assert!(AeCiphertext::from_bytes(&[]).is_none());
        assert!(AeCiphertext::from_bytes(&[0_u8; AE_CIPHERTEXT_LEN - 1]).is_none());
        assert!(AeCiphertext::from_bytes(&[0_u8; AE_CIPHERTEXT_LEN + 1]).is_none());
    }

    #[test]
    fn test_ciphertext_display_is_base64_of_bytes() {
        let key = AeKey::new_rand();
        let ciphertext = key.encrypt(5);

        assert_eq!(
            ciphertext.to_string(),
            BASE64_STANDARD.encode(ciphertext.to_bytes())
        );
    }

    #[test]
    fn test_wrong_key_fails_decryption() {
        let key = AeKey::new_rand();
        let other_key = AeKey::new_rand();

        let ciphertext = key.encrypt(42);
        assert!(ciphertext.decrypt(&other_key).is_none());
    }

    #[test]
    fn test_tampered_ciphertext_fails_decryption() {
        let key = AeKey::new_rand();
        let amount = 99_u64;

        let ciphertext = key.encrypt(amount);
        let mut tampered_bytes = ciphertext.to_bytes();

        // Flip one bit in the first byte after the nonce.
        tampered_bytes[NONCE_LEN] ^= 1;

        let tampered_ciphertext = AeCiphertext::from_bytes(&tampered_bytes).unwrap();
        assert!(tampered_ciphertext.decrypt(&key).is_none());
    }

    #[test]
    fn test_tampered_nonce_fails_decryption() {
        let key = AeKey::new_rand();
        let amount = 99_u64;

        let ciphertext = key.encrypt(amount);
        let mut tampered_bytes = ciphertext.to_bytes();

        // Flip one bit in the nonce.
        tampered_bytes[0] ^= 1;

        let tampered_ciphertext = AeCiphertext::from_bytes(&tampered_bytes).unwrap();
        assert!(tampered_ciphertext.decrypt(&key).is_none());
    }

    #[test]
    fn test_encryption_is_non_deterministic() {
        let key = AeKey::new_rand();
        let amount = 123_u64;

        let ciphertext1 = key.encrypt(amount);
        let ciphertext2 = key.encrypt(amount);

        assert_ne!(ciphertext1.to_bytes(), ciphertext2.to_bytes());
    }
}