secret_manager/
aws_kms_encryptor.rs1use async_trait::async_trait;
2use aws_sdk_kms::Client as KmsClient;
3use aws_sdk_kms::primitives::Blob;
4use crate::encryptor::{Encrypted, EncryptorError, KeyEncryptor};
5
6#[derive(Clone)]
21pub struct KmsEncryptor {
22 client: KmsClient,
23 key_id: String,
24 version: u8,
25}
26
27impl KmsEncryptor {
28 pub fn new(client: KmsClient, key_id: impl Into<String>, version: u8) -> Self {
34 Self { client, key_id: key_id.into(), version }
35 }
36}
37
38#[async_trait]
39impl KeyEncryptor for KmsEncryptor {
40 async fn encrypt(&self, plaintext: &[u8]) -> Result<Encrypted, EncryptorError> {
41 let resp = self.client
42 .encrypt()
43 .key_id(&self.key_id)
44 .plaintext(Blob::new(plaintext))
45 .send()
46 .await
47 .map_err(|e| EncryptorError::Kms(Box::new(e)))?;
48
49 Ok(Encrypted {
50 ciphertext: resp.ciphertext_blob.unwrap().into_inner(),
51 nonce: None, key_version: self.version,
53 })
54 }
55
56 async fn decrypt(&self, encrypted: &Encrypted) -> Result<Vec<u8>, EncryptorError> {
57 let resp = self.client
58 .decrypt()
59 .ciphertext_blob(Blob::new(encrypted.ciphertext.clone()))
60 .send()
61 .await
62 .map_err(|e| EncryptorError::Kms(Box::new(e)))?;
63
64 Ok(resp.plaintext.unwrap().into_inner())
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_kms::types::KeyUsageType;
    use test_containers_util::moto_container::get_aws_config;

    /// Creates a fresh ENCRYPT_DECRYPT key in the moto KMS container and wraps
    /// it in a `KmsEncryptor` tagged with `version`.
    async fn make_encryptor(version: u8) -> KmsEncryptor {
        let aws_config = get_aws_config("moto-kms").await;
        let kms = KmsClient::new(&aws_config);

        let created = kms
            .create_key()
            .key_usage(KeyUsageType::EncryptDecrypt)
            .send()
            .await
            .expect("create_key failed");
        let key_id = created.key_metadata().unwrap().key_id().to_string();

        KmsEncryptor::new(kms, key_id, version)
    }

    /// Decrypting what was just encrypted yields the original bytes.
    #[tokio::test(flavor = "multi_thread")]
    async fn encrypt_decrypt_roundtrip() {
        let encryptor = make_encryptor(1).await;
        let secret: &[u8] = b"my secret key bytes";

        let sealed = encryptor.encrypt(secret).await.unwrap();
        let opened = encryptor.decrypt(&sealed).await.unwrap();

        assert_eq!(opened, secret);
    }

    /// KMS payloads never carry a caller-side nonce.
    #[tokio::test(flavor = "multi_thread")]
    async fn encrypted_payload_has_no_nonce() {
        let encryptor = make_encryptor(42).await;
        let sealed = encryptor.encrypt(b"some bytes").await.unwrap();
        assert!(sealed.nonce.is_none(), "KMS manages its own IV — nonce must be None");
    }

    /// The configured version is stamped onto every payload.
    #[tokio::test(flavor = "multi_thread")]
    async fn encrypted_payload_carries_correct_key_version() {
        let encryptor = make_encryptor(7).await;
        let sealed = encryptor.encrypt(b"some bytes").await.unwrap();
        assert_eq!(sealed.key_version, 7);
    }

    /// Two encryptions of the same input must not produce identical ciphertext.
    #[tokio::test(flavor = "multi_thread")]
    async fn same_plaintext_produces_different_ciphertext() {
        let encryptor = make_encryptor(1).await;
        let input: &[u8] = b"determinism test";

        let first = encryptor.encrypt(input).await.unwrap();
        let second = encryptor.encrypt(input).await.unwrap();

        assert_ne!(first.ciphertext, second.ciphertext, "KMS should produce different ciphertext per call");
    }
}