Skip to main content

secret_manager/
aws_kms_encryptor.rs

1use async_trait::async_trait;
2use aws_sdk_kms::Client as KmsClient;
3use aws_sdk_kms::primitives::Blob;
4use crate::encryptor::{Encrypted, EncryptorError, KeyEncryptor};
5
6/// [`KeyEncryptor`] backed by AWS Key Management Service (KMS).
7///
8/// Each [`encrypt`](KeyEncryptor::encrypt) call invokes `kms:Encrypt` on the configured CMK.
9/// KMS manages its own IV internally, so [`Encrypted::nonce`] is always `None` for ciphertexts
10/// produced by this encryptor.
11///
12/// [`decrypt`](KeyEncryptor::decrypt) invokes `kms:Decrypt`; the `key_id` is embedded in the
13/// KMS ciphertext blob and does not need to be supplied again.
14///
15/// # Key versioning
16///
17/// `version` is a caller-controlled label stored in [`Encrypted::key_version`].  Increment it
18/// whenever you rotate the KMS CMK so that syncers can tell which CMK to use for a given
19/// ciphertext at decryption time (if you have multiple encryptors in rotation).
20#[derive(Clone)]
21pub struct KmsEncryptor {
22    client: KmsClient,
23    key_id: String,
24    version: u8,
25}
26
27impl KmsEncryptor {
28    /// Create a new `KmsEncryptor`.
29    ///
30    /// - `client` — pre-configured [`KmsClient`]; region and credentials come from the SDK config
31    /// - `key_id` — ARN, alias, or key ID of the KMS symmetric CMK to use
32    /// - `version` — stored in [`Encrypted::key_version`]; use 0 if you have a single CMK
33    pub fn new(client: KmsClient, key_id: impl Into<String>, version: u8) -> Self {
34        Self { client, key_id: key_id.into(), version }
35    }
36}
37
38#[async_trait]
39impl KeyEncryptor for KmsEncryptor {
40    async fn encrypt(&self, plaintext: &[u8]) -> Result<Encrypted, EncryptorError> {
41        let resp = self.client
42            .encrypt()
43            .key_id(&self.key_id)
44            .plaintext(Blob::new(plaintext))
45            .send()
46            .await
47            .map_err(|e| EncryptorError::Kms(Box::new(e)))?;
48
49        Ok(Encrypted {
50            ciphertext: resp.ciphertext_blob.unwrap().into_inner(),
51            nonce: None, // KMS manages its own IVs internally
52            key_version: self.version,
53        })
54    }
55
56    async fn decrypt(&self, encrypted: &Encrypted) -> Result<Vec<u8>, EncryptorError> {
57        let resp = self.client
58            .decrypt()
59            .ciphertext_blob(Blob::new(encrypted.ciphertext.clone()))
60            .send()
61            .await
62            .map_err(|e| EncryptorError::Kms(Box::new(e)))?;
63
64        Ok(resp.plaintext.unwrap().into_inner())
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_kms::types::KeyUsageType;
    use test_containers_util::moto_container::get_aws_config;

    /// Creates a fresh symmetric encrypt/decrypt key in the moto KMS endpoint (SDK config from
    /// `get_aws_config("moto-kms")`) and wraps it in a `KmsEncryptor` labelled `version`.
    async fn make_encryptor(version: u8) -> KmsEncryptor {
        let sdk_config = get_aws_config("moto-kms").await;
        let kms = KmsClient::new(&sdk_config);

        let created = kms
            .create_key()
            .key_usage(KeyUsageType::EncryptDecrypt)
            .send()
            .await
            .expect("create_key failed");
        let metadata = created.key_metadata().unwrap();
        let key_id = metadata.key_id().to_string();

        KmsEncryptor::new(kms, key_id, version)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn encrypt_decrypt_roundtrip() {
        let encryptor = make_encryptor(1).await;
        let secret: &[u8] = b"my secret key bytes";

        let sealed = encryptor.encrypt(secret).await.unwrap();
        let opened = encryptor.decrypt(&sealed).await.unwrap();

        assert_eq!(opened, secret);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn encrypted_payload_has_no_nonce() {
        let encryptor = make_encryptor(42).await;
        let sealed = encryptor.encrypt(b"some bytes").await.unwrap();
        assert!(sealed.nonce.is_none(), "KMS manages its own IV — nonce must be None");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn encrypted_payload_carries_correct_key_version() {
        let expected_version = 7;
        let encryptor = make_encryptor(expected_version).await;
        let sealed = encryptor.encrypt(b"some bytes").await.unwrap();
        assert_eq!(sealed.key_version, expected_version);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn same_plaintext_produces_different_ciphertext() {
        let encryptor = make_encryptor(1).await;
        let input = b"determinism test";

        // Encrypt the identical input twice and compare the resulting blobs.
        let mut blobs = Vec::with_capacity(2);
        for _ in 0..2 {
            blobs.push(encryptor.encrypt(input).await.unwrap().ciphertext);
        }

        assert_ne!(blobs[0], blobs[1], "KMS should produce different ciphertext per call");
    }
}
117        let a = enc.encrypt(plaintext).await.unwrap();
118        let b = enc.encrypt(plaintext).await.unwrap();
119        assert_ne!(a.ciphertext, b.ciphertext, "KMS should produce different ciphertext per call");
120    }
121}