vitaminc_kms/
lib.rs

1#![doc = include_str!("../README.md")]
2use crate::private::ValidMacSize;
3use aws_sdk_kms::{primitives::Blob, Client, Config};
4use thiserror::Error;
5use vitaminc_async_traits::{AsyncFixedOutput, AsyncFixedOutputReset};
6use vitaminc_protected::{AsProtectedRef, Controlled, Protected, ProtectedRef};
7use vitaminc_traits::Update;
8use zeroize::Zeroize;
9
10/// A `Mac` implementation that uses AWS KMS to generate HMACs of `N` bytes.
11/// Valid sizes are 28, 32, 48, and 64 bytes.
12///
13/// These corespond to the following algorithms:
14/// - 28 bytes: HMAC-SHA224
15/// - 32 bytes: HMAC-SHA256
16/// - 48 bytes: HMAC-SHA384
17/// - 64 bytes: HMAC-SHA512
18///
19pub struct AwsKmsHmac<const N: usize> {
20    client: Client,
21    key_id: String,
22    // TODO: Consider using heapless::Vec
23    input: Protected<Vec<u8>>,
24}
25
26#[derive(Debug, Error)]
27pub enum Error {
28    #[error(transparent)]
29    AwsSdk(#[from] aws_sdk_kms::Error),
30}
31
32impl<const N: usize> AwsKmsHmac<N>
33where
34    Self: private::ValidMacSize<N>,
35{
36    pub fn new(config: Config, key_id: impl Into<String>) -> Self {
37        Self {
38            client: Client::from_conf(config),
39            key_id: key_id.into(),
40            input: Protected::new(Vec::new()),
41        }
42    }
43
44    async fn generate_mac(&self) -> Result<Blob, Error> {
45        self.client
46            .generate_mac()
47            .key_id(&self.key_id)
48            .mac_algorithm(Self::spec())
49            // TODO: Prefer not to unwrap - async map for Paranoid?
50            .message(Blob::new(self.input.clone().risky_unwrap()))
51            .send()
52            .await
53            .map(|response| response.mac.unwrap())
54            .map_err(aws_sdk_kms::Error::from)
55            .map_err(Error::AwsSdk)
56    }
57}
58
/// Named type to represent _non-sensitive_ data that is passed to the `update` method.
/// Using a specific type allows us to reason about the input type and its sensitivity.
///
/// Because the wrapped value is a non-sensitive `&'static str`, the type is freely
/// copyable, comparable and printable.
/// TODO: This probably should be part of the `vitaminc_traits` crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Info(pub &'static str);
63
64impl<const N: usize, T> Update<&Protected<T>> for AwsKmsHmac<N>
65where
66    T: AsRef<[u8]> + Zeroize,
67{
68    fn update(&mut self, data: &Protected<T>) {
69        let pref: ProtectedRef<T> = data.as_protected_ref();
70        self.input.update_with_ref(pref, |input, data| {
71            input.extend(data.as_ref());
72        });
73    }
74}
75
76impl<const N: usize, T> Update<Protected<T>> for AwsKmsHmac<N>
77where
78    T: AsRef<[u8]> + Zeroize,
79{
80    fn update(&mut self, data: Protected<T>) {
81        self.input.update_with(data, |input, mut data| {
82            input.extend_from_slice(data.as_ref());
83            data.zeroize();
84        });
85    }
86}
87
88impl<const N: usize> Update<Info> for AwsKmsHmac<N> {
89    fn update(&mut self, data: Info) {
90        let pref: ProtectedRef<[u8]> = data.0.as_protected_ref();
91        self.input.update_with_ref(pref, |input, data| {
92            input.extend(data);
93        });
94    }
95}
96
97impl<'r, const N: usize> Update<ProtectedRef<'r, [u8]>> for AwsKmsHmac<N> {
98    fn update(&mut self, pref: ProtectedRef<[u8]>) {
99        self.input.update_with_ref(pref, |input, data| {
100            input.extend(data);
101        });
102    }
103}
104
105impl<const N: usize> AsyncFixedOutput<N, Protected<[u8; N]>> for AwsKmsHmac<N>
106where
107    Self: private::ValidMacSize<N>,
108{
109    type Error = Error;
110
111    async fn try_finalize_into(self, out: &mut Protected<[u8; N]>) -> Result<(), Self::Error> {
112        let output = self.generate_mac().await?;
113        let response = Protected::new(output.into_inner());
114
115        out.update_with(response, |out, data| {
116            out.copy_from_slice(data.as_ref());
117        });
118
119        Ok(())
120    }
121}
122
123impl<const N: usize, C> AsyncFixedOutputReset<N, C> for AwsKmsHmac<N>
124where
125    Self: private::ValidMacSize<N>,
126    C: Controlled<Inner = [u8; N]>,
127{
128    type Error = Error;
129
130    async fn try_finalize_into_reset(&mut self, out: &mut C) -> Result<(), Self::Error> {
131        let output = self.generate_mac().await?;
132        let response = Protected::new(output.into_inner());
133
134        out.update_with(response, |out, data| {
135            out.copy_from_slice(data.as_ref());
136        });
137
138        self.input.update(|input| input.clear());
139
140        Ok(())
141    }
142}
143
144mod private {
145    use aws_sdk_kms::types::MacAlgorithmSpec;
146
147    pub trait ValidMacSize<const N: usize> {
148        fn spec() -> MacAlgorithmSpec;
149    }
150
151    impl ValidMacSize<28> for super::AwsKmsHmac<28> {
152        fn spec() -> MacAlgorithmSpec {
153            MacAlgorithmSpec::HmacSha224
154        }
155    }
156
157    impl ValidMacSize<32> for super::AwsKmsHmac<32> {
158        fn spec() -> MacAlgorithmSpec {
159            MacAlgorithmSpec::HmacSha256
160        }
161    }
162
163    impl ValidMacSize<48> for super::AwsKmsHmac<48> {
164        fn spec() -> MacAlgorithmSpec {
165            MacAlgorithmSpec::HmacSha384
166        }
167    }
168
169    impl ValidMacSize<64> for super::AwsKmsHmac<64> {
170        fn spec() -> MacAlgorithmSpec {
171            MacAlgorithmSpec::HmacSha512
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use crate::{AwsKmsHmac, Info};
    use aws_sdk_kms::{
        types::{KeySpec, KeyUsageType},
        Client, Config,
    };
    use vitaminc_async_traits::AsyncFixedOutput;
    use vitaminc_protected::{Controlled, Protected};
    use vitaminc_traits::Update;

    /// Key id for the tests that only inspect buffered input and never call KMS.
    const OFFLINE_KEY_ID: &str = "0cce5331-13a6-437f-a477-1c8988667281";

    /// Builds a KMS config that targets a local endpoint with static dummy credentials.
    fn get_config() -> Config {
        use aws_config::{BehaviorVersion, Region};

        let credentials =
            aws_sdk_kms::config::Credentials::new("fake", "fake", None, None, "test");

        aws_sdk_kms::config::Builder::default()
            .behavior_version(BehaviorVersion::v2025_08_07())
            .region(Region::new("us-east-1"))
            .credentials_provider(credentials)
            .endpoint_url("http://localhost:4566")
            .build()
    }

    /// Creates a fresh `GenerateVerifyMac` key with the given spec and returns its id.
    async fn get_key_id(
        client: &Client,
        keyspec: KeySpec,
    ) -> Result<String, Box<dyn std::error::Error>> {
        let created = client
            .create_key()
            .key_usage(KeyUsageType::GenerateVerifyMac)
            .key_spec(keyspec)
            .send()
            .await?;

        let key_id = created.key_metadata().unwrap().key_id().to_owned();
        Ok(key_id)
    }

    #[tokio::test]
    async fn test_update() -> Result<(), Box<dyn std::error::Error>> {
        let mut hmac = AwsKmsHmac::<32>::new(get_config(), OFFLINE_KEY_ID);
        hmac.update(&Protected::new(vec![0, 1]));
        hmac.update(&Protected::new(vec![2, 3]));
        hmac.update(Info("test"));

        // "test" contributes its bytes 116, 101, 115, 116.
        let expected: Vec<u8> = vec![0, 1, 2, 3, 116, 101, 115, 116];
        assert_eq!(hmac.input.risky_unwrap(), expected);

        Ok(())
    }

    #[tokio::test]
    async fn test_chain() -> Result<(), Box<dyn std::error::Error>> {
        // TODO: Test all the variants
        // Also doctest with invalid sizes
        let hmac = AwsKmsHmac::<32>::new(get_config(), OFFLINE_KEY_ID)
            .chain(&Protected::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]))
            .chain(&Protected::new(vec![11, 12]));

        let expected: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 11, 12];
        assert_eq!(hmac.input.risky_unwrap(), expected);

        Ok(())
    }

    #[tokio::test]
    async fn test_finalize() -> Result<(), Box<dyn std::error::Error>> {
        let admin = Client::from_conf(get_config());
        let key_id = get_key_id(&admin, KeySpec::Hmac512).await?;

        AwsKmsHmac::<64>::new(get_config(), key_id)
            .chain(&Protected::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]))
            .try_finalize_fixed()
            .await?;

        Ok(())
    }
}