1#![doc = include_str!("../README.md")]
2use crate::private::ValidMacSize;
3use aws_sdk_kms::{primitives::Blob, Client, Config};
4use thiserror::Error;
5use vitaminc_async_traits::{AsyncFixedOutput, AsyncFixedOutputReset};
6use vitaminc_protected::{AsProtectedRef, Controlled, Protected, ProtectedRef};
7use vitaminc_traits::Update;
8use zeroize::Zeroize;
9
10pub struct AwsKmsHmac<const N: usize> {
20 client: Client,
21 key_id: String,
22 input: Protected<Vec<u8>>,
24}
25
26#[derive(Debug, Error)]
27pub enum Error {
28 #[error(transparent)]
29 AwsSdk(#[from] aws_sdk_kms::Error),
30}
31
32impl<const N: usize> AwsKmsHmac<N>
33where
34 Self: private::ValidMacSize<N>,
35{
36 pub fn new(config: Config, key_id: impl Into<String>) -> Self {
37 Self {
38 client: Client::from_conf(config),
39 key_id: key_id.into(),
40 input: Protected::new(Vec::new()),
41 }
42 }
43
44 async fn generate_mac(&self) -> Result<Blob, Error> {
45 self.client
46 .generate_mac()
47 .key_id(&self.key_id)
48 .mac_algorithm(Self::spec())
49 .message(Blob::new(self.input.clone().risky_unwrap()))
51 .send()
52 .await
53 .map(|response| response.mac.unwrap())
54 .map_err(aws_sdk_kms::Error::from)
55 .map_err(Error::AwsSdk)
56 }
57}
58
/// A static string whose UTF-8 bytes can be appended to the MAC input
/// (see `Update<Info>` for `AwsKmsHmac`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Info(pub &'static str);
63
64impl<const N: usize, T> Update<&Protected<T>> for AwsKmsHmac<N>
65where
66 T: AsRef<[u8]> + Zeroize,
67{
68 fn update(&mut self, data: &Protected<T>) {
69 let pref: ProtectedRef<T> = data.as_protected_ref();
70 self.input.update_with_ref(pref, |input, data| {
71 input.extend(data.as_ref());
72 });
73 }
74}
75
76impl<const N: usize, T> Update<Protected<T>> for AwsKmsHmac<N>
77where
78 T: AsRef<[u8]> + Zeroize,
79{
80 fn update(&mut self, data: Protected<T>) {
81 self.input.update_with(data, |input, mut data| {
82 input.extend_from_slice(data.as_ref());
83 data.zeroize();
84 });
85 }
86}
87
88impl<const N: usize> Update<Info> for AwsKmsHmac<N> {
89 fn update(&mut self, data: Info) {
90 let pref: ProtectedRef<[u8]> = data.0.as_protected_ref();
91 self.input.update_with_ref(pref, |input, data| {
92 input.extend(data);
93 });
94 }
95}
96
97impl<'r, const N: usize> Update<ProtectedRef<'r, [u8]>> for AwsKmsHmac<N> {
98 fn update(&mut self, pref: ProtectedRef<[u8]>) {
99 self.input.update_with_ref(pref, |input, data| {
100 input.extend(data);
101 });
102 }
103}
104
105impl<const N: usize> AsyncFixedOutput<N, Protected<[u8; N]>> for AwsKmsHmac<N>
106where
107 Self: private::ValidMacSize<N>,
108{
109 type Error = Error;
110
111 async fn try_finalize_into(self, out: &mut Protected<[u8; N]>) -> Result<(), Self::Error> {
112 let output = self.generate_mac().await?;
113 let response = Protected::new(output.into_inner());
114
115 out.update_with(response, |out, data| {
116 out.copy_from_slice(data.as_ref());
117 });
118
119 Ok(())
120 }
121}
122
123impl<const N: usize, C> AsyncFixedOutputReset<N, C> for AwsKmsHmac<N>
124where
125 Self: private::ValidMacSize<N>,
126 C: Controlled<Inner = [u8; N]>,
127{
128 type Error = Error;
129
130 async fn try_finalize_into_reset(&mut self, out: &mut C) -> Result<(), Self::Error> {
131 let output = self.generate_mac().await?;
132 let response = Protected::new(output.into_inner());
133
134 out.update_with(response, |out, data| {
135 out.copy_from_slice(data.as_ref());
136 });
137
138 self.input.update(|input| input.clear());
139
140 Ok(())
141 }
142}
143
144mod private {
145 use aws_sdk_kms::types::MacAlgorithmSpec;
146
147 pub trait ValidMacSize<const N: usize> {
148 fn spec() -> MacAlgorithmSpec;
149 }
150
151 impl ValidMacSize<28> for super::AwsKmsHmac<28> {
152 fn spec() -> MacAlgorithmSpec {
153 MacAlgorithmSpec::HmacSha224
154 }
155 }
156
157 impl ValidMacSize<32> for super::AwsKmsHmac<32> {
158 fn spec() -> MacAlgorithmSpec {
159 MacAlgorithmSpec::HmacSha256
160 }
161 }
162
163 impl ValidMacSize<48> for super::AwsKmsHmac<48> {
164 fn spec() -> MacAlgorithmSpec {
165 MacAlgorithmSpec::HmacSha384
166 }
167 }
168
169 impl ValidMacSize<64> for super::AwsKmsHmac<64> {
170 fn spec() -> MacAlgorithmSpec {
171 MacAlgorithmSpec::HmacSha512
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use crate::{AwsKmsHmac, Info};
    use aws_sdk_kms::{
        types::{KeySpec, KeyUsageType},
        Client, Config,
    };
    use vitaminc_async_traits::AsyncFixedOutput;
    use vitaminc_protected::{Controlled, Protected};
    use vitaminc_traits::Update;

    /// Key id for tests that only buffer input and never send a KMS request.
    const OFFLINE_KEY_ID: &str = "0cce5331-13a6-437f-a477-1c8988667281";

    /// KMS configuration aimed at a local endpoint with placeholder credentials.
    ///
    /// NOTE(review): port 4566 suggests a LocalStack instance; `test_finalize`
    /// needs that endpoint to be reachable.
    fn get_config() -> Config {
        use aws_config::{BehaviorVersion, Region};

        aws_sdk_kms::config::Builder::default()
            .behavior_version(BehaviorVersion::v2025_08_07())
            .region(Region::new("us-east-1"))
            .credentials_provider(aws_sdk_kms::config::Credentials::new(
                "fake", "fake", None, None, "test",
            ))
            .endpoint_url("http://localhost:4566")
            .build()
    }

    /// Creates a fresh `GenerateVerifyMac` key with `keyspec` and returns its id.
    async fn get_key_id(
        client: &Client,
        keyspec: KeySpec,
    ) -> Result<String, Box<dyn std::error::Error>> {
        let created = client
            .create_key()
            .key_usage(KeyUsageType::GenerateVerifyMac)
            .key_spec(keyspec)
            .send()
            .await?;

        let metadata = created.key_metadata().unwrap();
        Ok(metadata.key_id().to_owned())
    }

    /// `update` concatenates protected chunks and `Info` bytes in call order.
    #[tokio::test]
    async fn test_update() -> Result<(), Box<dyn std::error::Error>> {
        let mut hmac = AwsKmsHmac::<32>::new(get_config(), OFFLINE_KEY_ID);

        hmac.update(&Protected::new(vec![0, 1]));
        hmac.update(&Protected::new(vec![2, 3]));
        hmac.update(Info("test"));

        // "test" contributes its UTF-8 bytes 116, 101, 115, 116.
        let expected: Vec<u8> = vec![0, 1, 2, 3, 116, 101, 115, 116];
        assert_eq!(hmac.input.risky_unwrap(), expected);

        Ok(())
    }

    /// `chain` appends like `update` while passing the builder by value.
    #[tokio::test]
    async fn test_chain() -> Result<(), Box<dyn std::error::Error>> {
        let hmac = AwsKmsHmac::<32>::new(get_config(), OFFLINE_KEY_ID)
            .chain(&Protected::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]))
            .chain(&Protected::new(vec![11, 12]));

        let expected: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 11, 12];
        assert_eq!(hmac.input.risky_unwrap(), expected);

        Ok(())
    }

    /// End-to-end: create an HMAC-512 key, then finalize a MAC against it.
    #[tokio::test]
    async fn test_finalize() -> Result<(), Box<dyn std::error::Error>> {
        let client = Client::from_conf(get_config());
        let key_id = get_key_id(&client, KeySpec::Hmac512).await?;

        AwsKmsHmac::<64>::new(get_config(), key_id)
            .chain(&Protected::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]))
            .try_finalize_fixed()
            .await?;

        Ok(())
    }
}