solana_zk_sdk/encryption/
auth_encryption.rs

1//! Authenticated encryption implementation.
2//!
3//! This module is a simple wrapper of the `Aes128GcmSiv` implementation
4//! specialized for SPL Token2022 program where the plaintext is always a `u64`
5//! number.
6#[cfg(target_arch = "wasm32")]
7use wasm_bindgen::prelude::*;
8use {
9    crate::{
10        encryption::{AE_CIPHERTEXT_LEN, AE_KEY_LEN},
11        errors::AuthenticatedEncryptionError,
12    },
13    aes_gcm_siv::{
14        aead::{Aead, KeyInit},
15        Aes128GcmSiv,
16    },
17    base64::{prelude::BASE64_STANDARD, Engine},
18    rand::{rngs::OsRng, Rng},
19    std::{convert::TryInto, fmt},
20    zeroize::Zeroize,
21};
// Currently, `wasm_bindgen` exports not only the types and functions defined in the current
// crate, but also all types and functions exported for wasm targets in all of its dependencies
// (https://github.com/rustwasm/wasm-bindgen/issues/3759). We specifically exclude some of the
// dependencies that will cause unnecessary bloat to the wasm binary.
26#[cfg(not(target_arch = "wasm32"))]
27use {
28    sha3::Digest,
29    sha3::Sha3_512,
30    solana_derivation_path::DerivationPath,
31    solana_seed_derivable::SeedDerivable,
32    solana_seed_phrase::generate_seed_from_seed_phrase_and_passphrase,
33    solana_signature::Signature,
34    solana_signer::{EncodableKey, Signer, SignerError},
35    std::{
36        error,
37        io::{Read, Write},
38    },
39    subtle::ConstantTimeEq,
40};
41
/// Byte length of the nonce component of an authenticated encryption ciphertext (96 bits).
const NONCE_LEN: usize = 96 / 8;

/// Byte length of the ciphertext component of an authenticated encryption ciphertext.
///
/// The encrypted plaintext is always the 8 bytes of a `u64`; the remaining 16 bytes are the
/// overhead added by the authenticated encryption scheme.
const CIPHERTEXT_LEN: usize = std::mem::size_of::<u64>() + 16;
47
48struct AuthenticatedEncryption;
49impl AuthenticatedEncryption {
50    /// Generates an authenticated encryption key.
51    ///
52    /// This function is randomized. It internally samples a 128-bit key using `OsRng`.
53    fn keygen() -> AeKey {
54        AeKey(OsRng.gen::<[u8; AE_KEY_LEN]>())
55    }
56
57    /// On input of an authenticated encryption key and an amount, the function returns a
58    /// corresponding authenticated encryption ciphertext.
59    fn encrypt(key: &AeKey, balance: u64) -> AeCiphertext {
60        let mut plaintext = balance.to_le_bytes();
61        let nonce: Nonce = OsRng.gen::<[u8; NONCE_LEN]>();
62
63        // The balance and the nonce have fixed length and therefore, encryption should not fail.
64        let ciphertext = Aes128GcmSiv::new(&key.0.into())
65            .encrypt(&nonce.into(), plaintext.as_ref())
66            .expect("authenticated encryption");
67
68        plaintext.zeroize();
69
70        AeCiphertext {
71            nonce,
72            ciphertext: ciphertext.try_into().unwrap(),
73        }
74    }
75
76    /// On input of an authenticated encryption key and a ciphertext, the function returns the
77    /// originally encrypted amount.
78    fn decrypt(key: &AeKey, ciphertext: &AeCiphertext) -> Option<u64> {
79        let plaintext = Aes128GcmSiv::new(&key.0.into())
80            .decrypt(&ciphertext.nonce.into(), ciphertext.ciphertext.as_ref());
81
82        if let Ok(plaintext) = plaintext {
83            let amount_bytes: [u8; 8] = plaintext.try_into().unwrap();
84            Some(u64::from_le_bytes(amount_bytes))
85        } else {
86            None
87        }
88    }
89}
90
91#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
92#[derive(Clone, Debug, Zeroize, Eq, PartialEq)]
93pub struct AeKey([u8; AE_KEY_LEN]);
94#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
95impl AeKey {
96    /// Generates a random authenticated encryption key.
97    ///
98    /// This function is randomized. It internally samples a 128-bit key using `OsRng`.
99    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(js_name = newRand))]
100    pub fn new_rand() -> Self {
101        AuthenticatedEncryption::keygen()
102    }
103
104    /// Encrypts an amount under the authenticated encryption key.
105    pub fn encrypt(&self, amount: u64) -> AeCiphertext {
106        AuthenticatedEncryption::encrypt(self, amount)
107    }
108
109    pub fn decrypt(&self, ciphertext: &AeCiphertext) -> Option<u64> {
110        AuthenticatedEncryption::decrypt(self, ciphertext)
111    }
112}
113
114#[cfg(not(target_arch = "wasm32"))]
115impl AeKey {
116    /// Deterministically derives an authenticated encryption key from a Solana signer and a public
117    /// seed.
118    ///
119    /// This function exists for applications where a user may not wish to maintain a Solana signer
120    /// and an authenticated encryption key separately. Instead, a user can derive the ElGamal
121    /// keypair on-the-fly whenever encrytion/decryption is needed.
122    pub fn new_from_signer(
123        signer: &dyn Signer,
124        public_seed: &[u8],
125    ) -> Result<Self, Box<dyn error::Error>> {
126        let seed = Self::seed_from_signer(signer, public_seed)?;
127        Self::from_seed(&seed)
128    }
129
130    /// Derive a seed from a Solana signer used to generate an authenticated encryption key.
131    ///
132    /// The seed is derived as the hash of the signature of a public seed.
133    pub fn seed_from_signer(
134        signer: &dyn Signer,
135        public_seed: &[u8],
136    ) -> Result<Vec<u8>, SignerError> {
137        // TODO: This function uses a non-standard KDF and should be refactored.
138        // See: https://github.com/solana-program/zk-elgamal-proof/issues/35
139        let message = [b"AeKey", public_seed].concat();
140        let signature = signer.try_sign_message(&message)?;
141
142        // Some `Signer` implementations return the default signature, which is not suitable for
143        // use as key material
144        if bool::from(signature.as_ref().ct_eq(Signature::default().as_ref())) {
145            return Err(SignerError::Custom("Rejecting default signature".into()));
146        }
147
148        Ok(Self::seed_from_signature(&signature))
149    }
150
151    /// Derive an authenticated encryption key from a signature.
152    pub fn new_from_signature(signature: &Signature) -> Result<Self, Box<dyn error::Error>> {
153        let seed = Self::seed_from_signature(signature);
154        Self::from_seed(&seed)
155    }
156
157    /// Derive a seed from a signature used to generate an authenticated encryption key.
158    pub fn seed_from_signature(signature: &Signature) -> Vec<u8> {
159        let mut hasher = Sha3_512::new();
160        hasher.update(signature);
161        let result = hasher.finalize();
162
163        result.to_vec()
164    }
165}
166
167#[cfg(not(target_arch = "wasm32"))]
168impl EncodableKey for AeKey {
169    fn read<R: Read>(reader: &mut R) -> Result<Self, Box<dyn error::Error>> {
170        let bytes: [u8; AE_KEY_LEN] = serde_json::from_reader(reader)?;
171        Ok(Self(bytes))
172    }
173
174    fn write<W: Write>(&self, writer: &mut W) -> Result<String, Box<dyn error::Error>> {
175        let bytes = self.0;
176        let json = serde_json::to_string(&bytes.to_vec())?;
177        writer.write_all(&json.clone().into_bytes())?;
178        Ok(json)
179    }
180}
181
182#[cfg(not(target_arch = "wasm32"))]
183impl SeedDerivable for AeKey {
184    fn from_seed(seed: &[u8]) -> Result<Self, Box<dyn error::Error>> {
185        const MINIMUM_SEED_LEN: usize = AE_KEY_LEN;
186        const MAXIMUM_SEED_LEN: usize = 65535;
187
188        if seed.len() < MINIMUM_SEED_LEN {
189            return Err(AuthenticatedEncryptionError::SeedLengthTooShort.into());
190        }
191        if seed.len() > MAXIMUM_SEED_LEN {
192            return Err(AuthenticatedEncryptionError::SeedLengthTooLong.into());
193        }
194
195        let mut hasher = Sha3_512::new();
196        hasher.update(seed);
197        let result = hasher.finalize();
198
199        Ok(Self(result[..AE_KEY_LEN].try_into()?))
200    }
201
202    fn from_seed_and_derivation_path(
203        _seed: &[u8],
204        _derivation_path: Option<DerivationPath>,
205    ) -> Result<Self, Box<dyn error::Error>> {
206        Err(AuthenticatedEncryptionError::DerivationMethodNotSupported.into())
207    }
208
209    fn from_seed_phrase_and_passphrase(
210        seed_phrase: &str,
211        passphrase: &str,
212    ) -> Result<Self, Box<dyn error::Error>> {
213        Self::from_seed(&generate_seed_from_seed_phrase_and_passphrase(
214            seed_phrase,
215            passphrase,
216        ))
217    }
218}
219
220impl From<[u8; AE_KEY_LEN]> for AeKey {
221    fn from(bytes: [u8; AE_KEY_LEN]) -> Self {
222        Self(bytes)
223    }
224}
225
226impl From<AeKey> for [u8; AE_KEY_LEN] {
227    fn from(key: AeKey) -> Self {
228        key.0
229    }
230}
231
232impl TryFrom<&[u8]> for AeKey {
233    type Error = AuthenticatedEncryptionError;
234    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
235        if bytes.len() != AE_KEY_LEN {
236            return Err(AuthenticatedEncryptionError::Deserialization);
237        }
238        bytes
239            .try_into()
240            .map(Self)
241            .map_err(|_| AuthenticatedEncryptionError::Deserialization)
242    }
243}
244
245/// For the purpose of encrypting balances for the spl token accounts, the nonce and ciphertext
246/// sizes should always be fixed.
247type Nonce = [u8; NONCE_LEN];
248type Ciphertext = [u8; CIPHERTEXT_LEN];
249
250/// Authenticated encryption nonce and ciphertext
251#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
252#[derive(Clone, Copy, Debug, Default)]
253pub struct AeCiphertext {
254    nonce: Nonce,
255    ciphertext: Ciphertext,
256}
257impl AeCiphertext {
258    pub fn decrypt(&self, key: &AeKey) -> Option<u64> {
259        AuthenticatedEncryption::decrypt(key, self)
260    }
261
262    pub fn to_bytes(&self) -> [u8; AE_CIPHERTEXT_LEN] {
263        let mut buf = [0_u8; AE_CIPHERTEXT_LEN];
264        buf[..NONCE_LEN].copy_from_slice(&self.nonce);
265        buf[NONCE_LEN..].copy_from_slice(&self.ciphertext);
266        buf
267    }
268
269    pub fn from_bytes(bytes: &[u8]) -> Option<AeCiphertext> {
270        if bytes.len() != AE_CIPHERTEXT_LEN {
271            return None;
272        }
273
274        let nonce = bytes[..NONCE_LEN].try_into().ok()?;
275        let ciphertext = bytes[NONCE_LEN..].try_into().ok()?;
276
277        Some(AeCiphertext { nonce, ciphertext })
278    }
279}
280
281impl fmt::Display for AeCiphertext {
282    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
283        write!(f, "{}", BASE64_STANDARD.encode(self.to_bytes()))
284    }
285}
286
#[cfg(test)]
mod tests {
    use {
        super::*, solana_keypair::Keypair, solana_pubkey::Pubkey,
        solana_signer::null_signer::NullSigner,
    };

    /// Encrypts a fixed amount under a fresh key, flips the lowest bit of the serialized byte at
    /// `index`, and returns the result of decrypting the tampered ciphertext under the same key.
    fn decrypt_after_bit_flip(index: usize) -> Option<u64> {
        let key = AeKey::new_rand();
        let amount = 99_u64;

        let mut tampered_bytes = key.encrypt(amount).to_bytes();
        tampered_bytes[index] ^= 1;

        AeCiphertext::from_bytes(&tampered_bytes)
            .unwrap()
            .decrypt(&key)
    }

    // Decrypting under the encrypting key returns the original amount.
    #[test]
    fn test_aes_encrypt_decrypt_correctness() {
        let key = AeKey::new_rand();
        let amount = 55;

        let decrypted_amount = key.encrypt(amount).decrypt(&key);

        assert_eq!(Some(amount), decrypted_amount);
    }

    // Distinct signers derive distinct keys; a signer returning the default signature is rejected.
    #[test]
    fn test_aes_new() {
        let public_seed = Pubkey::default();

        let key1 = AeKey::new_from_signer(&Keypair::new(), public_seed.as_ref()).unwrap();
        let key2 = AeKey::new_from_signer(&Keypair::new(), public_seed.as_ref()).unwrap();
        assert_ne!(key1.0, key2.0);

        let null_signer = NullSigner::new(&public_seed);
        assert!(AeKey::new_from_signer(&null_signer, public_seed.as_ref()).is_err());
    }

    // Seeds outside the accepted length range are rejected.
    #[test]
    fn test_aes_key_from_seed() {
        assert!(AeKey::from_seed(&[0_u8; 32]).is_ok());
        assert!(AeKey::from_seed(&[0_u8; 15]).is_err());
        assert!(AeKey::from_seed(&vec![0_u8; 65536]).is_err());
    }

    // A key survives the round trip through its raw byte array.
    #[test]
    fn test_aes_key_from() {
        let seed = [0_u8; 32];
        let expected = AeKey::from_seed(&seed).unwrap();
        let key_bytes = <[u8; AE_KEY_LEN]>::from(AeKey::from_seed(&seed).unwrap());

        assert_eq!(expected, AeKey::from(key_bytes));
    }

    // A key survives the round trip through a byte slice.
    #[test]
    fn test_aes_key_try_from() {
        let seed = [0_u8; 32];
        let expected = AeKey::from_seed(&seed).unwrap();
        let key_bytes = <[u8; AE_KEY_LEN]>::from(AeKey::from_seed(&seed).unwrap());

        assert_eq!(expected, AeKey::try_from(&key_bytes[..]).unwrap());
    }

    // Slices that are one byte too short or one byte too long are rejected.
    #[test]
    fn test_aes_key_try_from_error() {
        for len in [AE_KEY_LEN - 1, AE_KEY_LEN + 1] {
            let bytes = vec![0_u8; len];
            assert!(AeKey::try_from(bytes.as_slice()).is_err());
        }
    }

    #[test]
    fn test_tampered_ciphertext_fails_decryption() {
        // Index `NONCE_LEN` is the first byte of the actual ciphertext component.
        assert!(decrypt_after_bit_flip(NONCE_LEN).is_none());
    }

    #[test]
    fn test_tampered_nonce_fails_decryption() {
        // Index 0 is the first byte of the nonce.
        assert!(decrypt_after_bit_flip(0).is_none());
    }

    // Two encryptions of the same amount under the same key must differ.
    #[test]
    fn test_encryption_is_non_deterministic() {
        let key = AeKey::new_rand();
        let amount = 123_u64;

        let first = key.encrypt(amount).to_bytes();
        let second = key.encrypt(amount).to_bytes();

        assert_ne!(first, second);
    }
}