1use std::sync::Arc;
31
32use aws_sdk_kms::operation::get_public_key::GetPublicKeyError;
33use aws_sdk_kms::operation::sign::SignError;
34use aws_sdk_kms::primitives::Blob;
35use aws_sdk_kms::types::{MessageType, SigningAlgorithmSpec};
36use secretx_core::{
37 SecretError, SecretUri, SigningAlgorithm, SigningBackend,
38};
39use sha2::{Digest as _, Sha256};
40
/// Backend identifier for this module.
///
/// It is the backend name that `from_parsed_uri` requires in `secretx:aws-kms:...`
/// URIs, and the `backend` label attached to every `SecretError` this module builds.
const BACKEND: &str = "aws-kms";
42
/// Signing backend that delegates signing to an asymmetric AWS KMS key.
///
/// Construct it from a `secretx:aws-kms:<key-id>[?algorithm=...]` URI via
/// [`AwsKmsBackend::from_uri`] or [`AwsKmsBackend::from_parsed_uri`].
#[derive(Debug)]
pub struct AwsKmsBackend {
    /// KMS client built once at construction from `aws_config::defaults(...)`.
    client: Arc<aws_sdk_kms::Client>,
    /// Key identifier taken from the URI path and passed verbatim to KMS as `key_id`.
    key_id: String,
    /// Signing algorithm chosen by the URI's `algorithm` parameter.
    algorithm: SigningAlgorithm,
}
56
57impl AwsKmsBackend {
58 pub fn from_uri(uri: &str) -> Result<Self, SecretError> {
70 Self::from_parsed_uri(&SecretUri::parse(uri)?)
71 }
72
73 pub fn from_parsed_uri(parsed: &SecretUri) -> Result<Self, SecretError> {
75 if parsed.backend() != BACKEND {
76 return Err(SecretError::InvalidUri(format!(
77 "expected backend `aws-kms`, got `{}`",
78 parsed.backend()
79 )));
80 }
81 if parsed.path().is_empty() {
82 return Err(SecretError::InvalidUri(
83 "aws-kms URI requires a key ID: secretx:aws-kms:<key-id>".into(),
84 ));
85 }
86
87 let algorithm = match parsed.param("algorithm") {
93 None | Some("ecdsa-p256") => SigningAlgorithm::EcdsaP256Sha256,
94 Some("rsa-pss-2048") => SigningAlgorithm::RsaPss2048Sha256,
95 Some(other) => {
96 return Err(SecretError::InvalidUri(format!(
97 "unknown algorithm `{other}`; supported: ecdsa-p256, rsa-pss-2048"
98 )));
99 }
100 };
101
102 let key_id = parsed.path().to_owned();
103
104 let client = secretx_core::run_on_new_thread(
105 || async {
106 let config = aws_config::defaults(aws_config::BehaviorVersion::latest()).load().await;
107 Ok(aws_sdk_kms::Client::new(&config))
108 },
109 "aws-kms",
110 )?;
111
112 Ok(Self {
113 client: Arc::new(client),
114 key_id,
115 algorithm,
116 })
117 }
118}
119
/// Reports whether a KMS service error code describes a temporary condition.
///
/// Codes in this list (throttling, KMS-internal failures, dependency timeouts,
/// temporarily unavailable keys) are mapped to `SecretError::Unavailable` by
/// `classify_kms_sdk_error`. Every other code is treated as a permanent backend
/// error. The comparison is exact and case-sensitive.
fn is_transient_kms_code(code: &str) -> bool {
    const TRANSIENT_CODES: [&str; 5] = [
        "ThrottlingException",
        "RequestThrottledException",
        "KMSInternalException",
        "DependencyTimeoutException",
        "KeyUnavailableException",
    ];
    TRANSIENT_CODES.contains(&code)
}
136
137fn classify_kms_sdk_error<E>(
142 sdk_err: aws_sdk_kms::error::SdkError<E>,
143 is_not_found: impl FnOnce(&E) -> bool,
144) -> SecretError
145where
146 E: std::error::Error + std::fmt::Display + aws_sdk_kms::error::ProvideErrorMetadata + Send + Sync + 'static,
147{
148 if let Some(svc) = sdk_err.as_service_error() {
149 if is_not_found(svc) {
150 return SecretError::NotFound;
151 }
152 let code = svc.meta().code().unwrap_or("");
153 if is_transient_kms_code(code) {
154 return SecretError::Unavailable {
155 backend: BACKEND,
156 source: svc.to_string().into(),
157 };
158 }
159 return SecretError::Backend {
160 backend: BACKEND,
161 source: svc.to_string().into(),
162 };
163 }
164 SecretError::Unavailable {
165 backend: BACKEND,
166 source: sdk_err.to_string().into(),
167 }
168}
169
170fn ecdsa_der_to_raw_p256(der: &[u8]) -> Result<Vec<u8>, SecretError> {
179 let parse_err = |msg: &'static str| SecretError::Backend {
180 backend: BACKEND,
181 source: format!("ECDSA DER signature parse failed: {msg}").into(),
182 };
183
184 let rest = der
186 .strip_prefix(&[0x30])
187 .ok_or_else(|| parse_err("expected SEQUENCE tag 0x30"))?;
188 let (seq_len, rest) = der_length(rest).ok_or_else(|| parse_err("invalid SEQUENCE length"))?;
189 if rest.len() < seq_len {
190 return Err(parse_err("SEQUENCE truncated"));
191 }
192 if rest.len() > seq_len {
193 return Err(parse_err("trailing bytes after SEQUENCE"));
194 }
195 let rest = &rest[..seq_len];
196
197 let (r, rest) = der_integer(rest).ok_or_else(|| parse_err("invalid INTEGER r"))?;
198 let (s, rest) = der_integer(rest).ok_or_else(|| parse_err("invalid INTEGER s"))?;
199 if !rest.is_empty() {
200 return Err(parse_err("trailing bytes inside SEQUENCE after INTEGER s"));
201 }
202
203 fn fixed32(n: &[u8]) -> Result<[u8; 32], &'static str> {
204 if n.len() > 32 {
205 return Err("integer component exceeds 32 bytes for P-256");
206 }
207 let mut out = [0u8; 32];
208 out[32 - n.len()..].copy_from_slice(n);
209 Ok(out)
210 }
211
212 let r32 = fixed32(r).map_err(parse_err)?;
213 let s32 = fixed32(s).map_err(parse_err)?;
214
215 Ok([r32, s32].concat())
216}
217
/// Decodes a DER length header from the front of `bytes`.
///
/// Returns `(length, rest)`, where `rest` is the input after the header. The
/// short form (`0x00..=0x7f`) and the long form with one or two length octets
/// are accepted. The following are rejected with `None`:
/// * empty input;
/// * the indefinite form (`0x80`) and long forms with more than two octets;
/// * long-form headers truncated before all length octets;
/// * non-minimal long forms: a value below `0x80`, or two octets for a value
///   that fits in one.
fn der_length(bytes: &[u8]) -> Option<(usize, &[u8])> {
    let (&header, tail) = bytes.split_first()?;

    // High bit clear: short form, the header byte is the length itself.
    if header & 0x80 == 0 {
        return Some((usize::from(header), tail));
    }

    // Long form: the low seven bits give the number of big-endian length octets.
    let octet_count = usize::from(header & 0x7f);
    if !(1..=2).contains(&octet_count) || tail.len() < octet_count {
        return None;
    }
    let (length_octets, remainder) = tail.split_at(octet_count);
    let length = length_octets
        .iter()
        .fold(0usize, |acc, &octet| (acc << 8) | usize::from(octet));

    let minimal = length >= 0x80 && !(octet_count == 2 && length < 0x100);
    if minimal {
        Some((length, remainder))
    } else {
        None
    }
}
246
247fn der_integer(bytes: &[u8]) -> Option<(&[u8], &[u8])> {
250 let (&tag, rest) = bytes.split_first()?;
251 if tag != 0x02 {
252 return None;
253 }
254 let (len, rest) = der_length(rest)?;
255 if rest.len() < len {
256 return None;
257 }
258 let (value, rest) = rest.split_at(len);
259 let value = value.strip_prefix(&[0x00]).unwrap_or(value);
262 Some((value, rest))
263}
264
265#[async_trait::async_trait]
268impl SigningBackend for AwsKmsBackend {
269 async fn sign(&self, message: &[u8]) -> Result<Vec<u8>, SecretError> {
270 let algo_spec = match self.algorithm {
271 SigningAlgorithm::EcdsaP256Sha256 => SigningAlgorithmSpec::EcdsaSha256,
272 SigningAlgorithm::RsaPss2048Sha256 => SigningAlgorithmSpec::RsassaPssSha256,
273 SigningAlgorithm::Ed25519 => {
274 return Err(SecretError::Backend {
278 backend: BACKEND,
279 source: "Ed25519 is not supported by AWS KMS; use ecdsa-p256 or rsa-pss-2048"
280 .into(),
281 });
282 }
283 _ => {
287 return Err(SecretError::Backend {
288 backend: BACKEND,
289 source: format!("algorithm {:?} is not supported by AWS KMS", self.algorithm)
290 .into(),
291 });
292 }
293 };
294
295 let digest: [u8; 32] = Sha256::digest(message).into();
301
302 let response = self
303 .client
304 .sign()
305 .key_id(&self.key_id)
306 .message(Blob::new(digest))
307 .message_type(MessageType::Digest)
308 .signing_algorithm(algo_spec)
309 .send()
310 .await
311 .map_err(|sdk_err| {
312 classify_kms_sdk_error(sdk_err, SignError::is_not_found_exception)
313 })?;
314
315 let sig_bytes = response
316 .signature
317 .ok_or_else(|| SecretError::Backend {
318 backend: BACKEND,
319 source: "KMS sign response contained no signature".into(),
320 })?
321 .into_inner();
322
323 match self.algorithm {
328 SigningAlgorithm::EcdsaP256Sha256 => ecdsa_der_to_raw_p256(&sig_bytes),
329 _ => Ok(sig_bytes),
330 }
331 }
332
333 async fn public_key_der(&self) -> Result<Vec<u8>, SecretError> {
334 let response = self
335 .client
336 .get_public_key()
337 .key_id(&self.key_id)
338 .send()
339 .await
340 .map_err(|sdk_err| {
341 classify_kms_sdk_error(sdk_err, GetPublicKeyError::is_not_found_exception)
342 })?;
343
344 Ok(response
345 .public_key
346 .ok_or_else(|| SecretError::Backend {
347 backend: BACKEND,
348 source: "KMS get_public_key response contained no public key".into(),
349 })?
350 .into_inner())
351 }
352
353 fn algorithm(&self) -> Result<SigningAlgorithm, SecretError> {
354 Ok(self.algorithm)
355 }
356}
357
358inventory::submit!(secretx_core::SigningBackendRegistration::new(
363 "aws-kms",
364 |uri: &secretx_core::SecretUri| {
365 let b = AwsKmsBackend::from_parsed_uri(uri)?;
366 Ok(Arc::new(b) as Arc<dyn secretx_core::SigningBackend>)
367 },
368));
369
#[cfg(test)]
mod tests {
    use super::*;

    /// A URI naming another backend is rejected with `InvalidUri`.
    #[test]
    fn from_uri_wrong_backend() {
        assert!(matches!(
            AwsKmsBackend::from_uri("secretx:aws-sm:some-key"),
            Err(SecretError::InvalidUri(_))
        ));
    }

    /// An empty key ID is rejected with `InvalidUri`.
    #[test]
    fn from_uri_missing_key_id() {
        assert!(matches!(
            AwsKmsBackend::from_uri("secretx:aws-kms:"),
            Err(SecretError::InvalidUri(_))
        ));
    }

    /// An unrecognised `algorithm` parameter is rejected with `InvalidUri`.
    #[test]
    fn from_uri_invalid_algorithm() {
        assert!(matches!(
            AwsKmsBackend::from_uri("secretx:aws-kms:alias/my-key?algorithm=elgamal"),
            Err(SecretError::InvalidUri(_))
        ));
    }

    /// ECDSA P-256 key for integration tests, read from
    /// `SECRETX_AWS_KMS_TEST_KEY_ID`. `None` makes those tests skip.
    fn integration_key_id() -> Option<String> {
        std::env::var("SECRETX_AWS_KMS_TEST_KEY_ID").ok()
    }

    /// RSA-2048 key for the RSA-PSS integration test, read from
    /// `SECRETX_AWS_KMS_TEST_RSA_KEY_ID`. `None` makes that test skip.
    fn integration_rsa_key_id() -> Option<String> {
        std::env::var("SECRETX_AWS_KMS_TEST_RSA_KEY_ID").ok()
    }

    /// Signs with a real KMS ECDSA key and verifies the raw signature locally with `p256`.
    #[tokio::test]
    async fn integration_sign_and_verify_ecdsa() {
        let Some(key_id) = integration_key_id() else {
            eprintln!("SECRETX_AWS_KMS_TEST_KEY_ID not set; skipping integration test");
            return;
        };

        let uri = format!("secretx:aws-kms:{key_id}?algorithm=ecdsa-p256");
        let backend = AwsKmsBackend::from_uri(&uri).expect("from_uri failed");
        assert_eq!(
            backend.algorithm().expect("algorithm"),
            SigningAlgorithm::EcdsaP256Sha256
        );

        let message = b"hello from secretx-aws-kms integration test";
        let sig_bytes = backend.sign(message).await.expect("sign failed");
        assert_eq!(
            sig_bytes.len(),
            64,
            "ECDSA P-256 raw signature must be 64 bytes"
        );

        let pub_der = backend
            .public_key_der()
            .await
            .expect("public_key_der failed");
        assert!(!pub_der.is_empty(), "public key DER must not be empty");

        use p256::ecdsa::{signature::Verifier, Signature, VerifyingKey};
        use p256::pkcs8::DecodePublicKey;
        let vk = VerifyingKey::from_public_key_der(&pub_der)
            .expect("P-256 VerifyingKey from DER failed");
        let sig = Signature::from_bytes(sig_bytes.as_slice().into())
            .expect("P-256 Signature decode failed");
        vk.verify(message, &sig)
            .expect("P-256 signature verification failed");
    }

    /// Signs with a real KMS RSA key and verifies the PSS signature locally with `rsa`.
    #[tokio::test]
    async fn integration_sign_and_verify_rsa_pss() {
        let Some(key_id) = integration_rsa_key_id() else {
            eprintln!("SECRETX_AWS_KMS_TEST_RSA_KEY_ID not set; skipping integration test");
            return;
        };

        let uri = format!("secretx:aws-kms:{key_id}?algorithm=rsa-pss-2048");
        let backend = AwsKmsBackend::from_uri(&uri).expect("from_uri failed");
        assert_eq!(
            backend.algorithm().expect("algorithm"),
            SigningAlgorithm::RsaPss2048Sha256
        );

        let message = b"hello from secretx-aws-kms rsa-pss integration test";
        let sig_bytes = backend.sign(message).await.expect("sign failed");
        assert_eq!(
            sig_bytes.len(),
            256,
            "RSA-2048 PSS signature must be 256 bytes"
        );

        let pub_der = backend
            .public_key_der()
            .await
            .expect("public_key_der failed");
        assert!(!pub_der.is_empty(), "public key DER must not be empty");

        use rsa::pkcs8::DecodePublicKey;
        use rsa::pss::VerifyingKey;
        use rsa::signature::Verifier;
        let pub_key = rsa::RsaPublicKey::from_public_key_der(&pub_der)
            .expect("RSA public key from DER failed");
        let vk = VerifyingKey::<sha2::Sha256>::new(pub_key);
        let sig = rsa::pss::Signature::try_from(sig_bytes.as_slice())
            .expect("RSA-PSS Signature decode failed");
        vk.verify(message, &sig)
            .expect("RSA-PSS signature verification failed");
    }

    /// Signing with a nonexistent alias maps to `SecretError::NotFound`.
    ///
    /// Gated on `SECRETX_AWS_KMS_TEST_KEY_ID` so it only runs where AWS access is configured.
    #[tokio::test]
    async fn integration_not_found() {
        let Some(_) = integration_key_id() else {
            eprintln!("SECRETX_AWS_KMS_TEST_KEY_ID not set; skipping integration test");
            return;
        };

        let uri = "secretx:aws-kms:alias/nonexistent-key-that-does-not-exist-secretx-test";
        let backend = AwsKmsBackend::from_uri(uri).expect("from_uri failed");
        let result = backend.sign(b"test").await;
        assert!(
            matches!(result, Err(SecretError::NotFound)),
            "expected NotFound for nonexistent key, got: {result:?}"
        );
    }

    /// A URI without `algorithm` selects ECDSA P-256.
    #[tokio::test]
    async fn integration_default_algorithm_is_ecdsa() {
        let Some(key_id) = integration_key_id() else {
            eprintln!("SECRETX_AWS_KMS_TEST_KEY_ID not set; skipping integration test");
            return;
        };

        let uri = format!("secretx:aws-kms:{key_id}");
        let backend = AwsKmsBackend::from_uri(&uri).expect("from_uri failed");
        assert_eq!(
            backend.algorithm().expect("algorithm"),
            SigningAlgorithm::EcdsaP256Sha256,
            "default algorithm must be EcdsaP256Sha256"
        );
    }

    /// Full-width 32-byte components with no sign padding are copied through unchanged.
    #[test]
    fn ecdsa_der_to_raw_p256_no_padding() {
        let mut der = Vec::new();
        let r: Vec<u8> = {
            let mut v = vec![0u8; 32];
            v[0] = 0x01;
            v
        };
        let s: Vec<u8> = {
            let mut v = vec![0u8; 32];
            v[0] = 0x02;
            v
        };
        // INTEGER tag + short-form length 32, then the value bytes.
        let mut int_r = vec![0x02u8, 32];
        int_r.extend_from_slice(&r);
        let mut int_s = vec![0x02u8, 32];
        int_s.extend_from_slice(&s);
        // SEQUENCE tag + short-form length covering both INTEGERs.
        der.push(0x30);
        der.push((int_r.len() + int_s.len()) as u8);
        der.extend_from_slice(&int_r);
        der.extend_from_slice(&int_s);

        let raw = ecdsa_der_to_raw_p256(&der).expect("should parse");
        assert_eq!(raw.len(), 64);
        assert_eq!(&raw[..32], r.as_slice());
        assert_eq!(&raw[32..], s.as_slice());
    }

    /// The 0x00 sign-padding byte before a high-bit component is stripped.
    #[test]
    fn ecdsa_der_to_raw_p256_with_sign_extension() {
        let mut der = Vec::new();
        let r: Vec<u8> = {
            let mut v = vec![0u8; 32];
            v[0] = 0xFF;
            v
        };
        let s: Vec<u8> = {
            let mut v = vec![0u8; 32];
            v[0] = 0x01;
            v
        };
        // r has its high bit set, so DER encodes it as 33 bytes with a 0x00 pad.
        let mut int_r = vec![0x02u8, 33];
        int_r.push(0x00);
        int_r.extend_from_slice(&r);
        let mut int_s = vec![0x02u8, 32];
        int_s.extend_from_slice(&s);
        der.push(0x30);
        der.push((int_r.len() + int_s.len()) as u8);
        der.extend_from_slice(&int_r);
        der.extend_from_slice(&int_s);

        let raw = ecdsa_der_to_raw_p256(&der).expect("should parse");
        assert_eq!(raw.len(), 64);
        assert_eq!(&raw[..32], r.as_slice());
        assert_eq!(&raw[32..], s.as_slice());
    }

    /// A SET tag (0x31) instead of SEQUENCE (0x30) is rejected.
    #[test]
    fn ecdsa_der_to_raw_p256_bad_tag_rejected() {
        let garbage = [0x31u8, 0x04, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01];
        assert!(matches!(
            ecdsa_der_to_raw_p256(&garbage),
            Err(SecretError::Backend { .. })
        ));
    }

    /// Components shorter than 32 bytes are left-padded with zeros.
    #[test]
    fn ecdsa_der_to_raw_p256_short_components_left_padded() {
        let der = [0x30u8, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02];
        let raw = ecdsa_der_to_raw_p256(&der).expect("should parse");
        let mut expected = [0u8; 64];
        expected[31] = 0x01;
        expected[63] = 0x02;
        assert_eq!(raw.as_slice(), &expected[..]);
    }

    /// A SEQUENCE whose declared length exceeds the available bytes is rejected.
    #[test]
    fn ecdsa_der_to_raw_p256_truncated_sequence_rejected() {
        let der = [0x30u8, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01];
        assert!(matches!(
            ecdsa_der_to_raw_p256(&der),
            Err(SecretError::Backend { .. })
        ));
    }

    /// Bytes after the SEQUENCE are rejected.
    #[test]
    fn ecdsa_der_to_raw_p256_trailing_bytes_rejected() {
        let der = [0x30u8, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02, 0x00];
        assert!(matches!(
            ecdsa_der_to_raw_p256(&der),
            Err(SecretError::Backend { .. })
        ));
    }

    /// A non-minimal long-form SEQUENCE length is rejected.
    #[test]
    fn ecdsa_der_to_raw_p256_non_minimal_length_rejected() {
        let der = [0x30u8, 0x81, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02];
        assert!(matches!(
            ecdsa_der_to_raw_p256(&der),
            Err(SecretError::Backend { .. })
        ));
    }

    /// A 33-byte component with a non-zero leading byte cannot fit P-256's 32 bytes.
    #[test]
    fn ecdsa_der_to_raw_p256_oversized_component_rejected() {
        let mut der = vec![0x30u8, 2 + 33 + 3, 0x02, 33, 0x01];
        der.extend_from_slice(&[0u8; 32]);
        der.extend_from_slice(&[0x02, 0x01, 0x01]);
        assert!(matches!(
            ecdsa_der_to_raw_p256(&der),
            Err(SecretError::Backend { .. })
        ));
    }

    /// Short-form and minimal one- and two-octet long-form lengths decode correctly.
    #[test]
    fn der_length_accepts_minimal_forms() {
        let tail: &[u8] = &[0xAA];
        assert_eq!(der_length(&[0x05, 0xAA]), Some((5, tail)));
        assert_eq!(der_length(&[0x81, 0x80, 0xAA]), Some((0x80, tail)));
        assert_eq!(der_length(&[0x82, 0x01, 0x00, 0xAA]), Some((0x100, tail)));
    }

    /// Indefinite, over-long, truncated and non-minimal length headers are rejected.
    #[test]
    fn der_length_rejects_malformed_headers() {
        assert_eq!(der_length(&[]), None);
        assert_eq!(der_length(&[0x80]), None);
        assert_eq!(der_length(&[0x83, 0x01, 0x00, 0x00]), None);
        assert_eq!(der_length(&[0x82, 0x01]), None);
        assert_eq!(der_length(&[0x81, 0x7f]), None);
        assert_eq!(der_length(&[0x82, 0x00, 0xff]), None);
    }
}