1use crate::error::{Result, VoltError};
12use aes::cipher::{block_padding::Pkcs7, BlockDecryptMut, BlockEncryptMut, KeyIvInit};
13use aes_gcm::{
14 aead::{Aead, KeyInit},
15 Aes256Gcm, Nonce,
16};
17use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
18use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
19use hkdf::Hkdf;
20use rand::rngs::OsRng;
21use sha2::Sha256;
22use x25519_dalek::{EphemeralSecret, PublicKey as X25519PublicKey, StaticSecret};
23
/// Length in bytes of an AES-256 key.
pub const AES_KEY_LENGTH: usize = 256 / 8;
/// Length in bytes of the AES-GCM nonce (96 bits).
pub const AES_NONCE_LENGTH: usize = 96 / 8;
/// Length in bytes of an AES-CBC initialisation vector (one 128-bit AES block).
pub const AES_CBC_IV_LENGTH: usize = 128 / 8;

/// HKDF salt for relay key derivation.
///
/// The string is fed to HKDF as its ASCII bytes; it is not hex-decoded.
pub const RELAY_HKDF_SALT: &str = "a06e10d13fa4445a";
/// HKDF `info` context string for relay key derivation.
pub const RELAY_HKDF_INFO: &str = "tdx-volt-encryption-key-derivation";
35
/// AES-256 in CBC mode, encryption direction (used by `aes_cbc_encrypt`).
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
/// AES-256 in CBC mode, decryption direction (used by `aes_cbc_decrypt`).
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
39
40#[derive(Clone)]
42pub struct SigningKeyPair {
43 signing_key: SigningKey,
44}
45
46impl SigningKeyPair {
47 pub fn generate() -> Self {
49 let signing_key = SigningKey::generate(&mut OsRng);
50 Self { signing_key }
51 }
52
53 pub fn from_pem(pem: &str) -> Result<Self> {
55 let pem_contents = pem
56 .lines()
57 .filter(|line| !line.starts_with("-----"))
58 .collect::<String>();
59
60 let key_bytes = BASE64
61 .decode(&pem_contents)
62 .map_err(|e| VoltError::key(format!("Invalid PEM encoding: {}", e)))?;
63
64 let secret_bytes = if key_bytes.len() == 32 {
66 key_bytes
67 } else if key_bytes.len() > 32 {
68 key_bytes[key_bytes.len() - 32..].to_vec()
70 } else {
71 return Err(VoltError::key("Invalid key length"));
72 };
73
74 let secret_array: [u8; 32] = secret_bytes
75 .try_into()
76 .map_err(|_| VoltError::key("Invalid key length"))?;
77
78 let signing_key = SigningKey::from_bytes(&secret_array);
79 Ok(Self { signing_key })
80 }
81
82 pub fn private_key_pem(&self) -> String {
84 let encoded = BASE64.encode(self.signing_key.to_bytes());
85 format!(
86 "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----",
87 encoded
88 )
89 }
90
91 pub fn private_key_pkcs8_pem(&self) -> String {
93 let pkcs8_prefix: [u8; 16] = [
105 0x30, 0x2e, 0x02, 0x01, 0x00, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x04, 0x22, 0x04, 0x20, ];
112
113 let mut pkcs8_der = Vec::with_capacity(48);
114 pkcs8_der.extend_from_slice(&pkcs8_prefix);
115 pkcs8_der.extend_from_slice(self.signing_key.as_bytes());
116
117 let encoded = BASE64.encode(&pkcs8_der);
118 format!(
119 "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----",
120 encoded
121 )
122 }
123
124 pub fn public_key_pem(&self) -> String {
126 let encoded = BASE64.encode(self.signing_key.verifying_key().to_bytes());
127 format!(
128 "-----BEGIN PUBLIC KEY-----\n{}\n-----END PUBLIC KEY-----",
129 encoded
130 )
131 }
132
133 pub fn verifying_key(&self) -> VerifyingKey {
135 self.signing_key.verifying_key()
136 }
137
138 pub fn sign(&self, data: &[u8]) -> Vec<u8> {
140 let signature = self.signing_key.sign(data);
141 signature.to_bytes().to_vec()
142 }
143
144 pub fn sign_base64(&self, data: &[u8]) -> String {
146 BASE64.encode(self.sign(data))
147 }
148
149 pub fn secret_bytes(&self) -> &[u8; 32] {
151 self.signing_key.as_bytes()
152 }
153}
154
155pub fn verify_signature(
157 public_key: &VerifyingKey,
158 message: &[u8],
159 signature: &[u8],
160) -> Result<bool> {
161 let sig_array: [u8; 64] = signature
162 .try_into()
163 .map_err(|_| VoltError::crypto("Invalid signature length"))?;
164 let signature = Signature::from_bytes(&sig_array);
165
166 Ok(public_key.verify(message, &signature).is_ok())
167}
168
169pub fn public_key_from_pem(pem: &str) -> Result<VerifyingKey> {
171 let pem_contents = pem
172 .lines()
173 .filter(|line| !line.starts_with("-----"))
174 .collect::<String>();
175
176 let key_bytes = BASE64
177 .decode(&pem_contents)
178 .map_err(|e| VoltError::key(format!("Invalid PEM encoding: {}", e)))?;
179
180 let public_bytes = if key_bytes.len() == 32 {
182 key_bytes
183 } else if key_bytes.len() > 32 {
184 key_bytes[key_bytes.len() - 32..].to_vec()
186 } else {
187 return Err(VoltError::key("Invalid public key length"));
188 };
189
190 let key_array: [u8; 32] = public_bytes
191 .try_into()
192 .map_err(|_| VoltError::key("Invalid key length"))?;
193
194 VerifyingKey::from_bytes(&key_array)
195 .map_err(|e| VoltError::key(format!("Invalid public key: {}", e)))
196}
197
198pub fn fingerprint_from_key(public_key: &VerifyingKey) -> String {
200 use ring::digest::{digest, SHA256};
201 let hash = digest(&SHA256, public_key.as_bytes());
202 BASE64.encode(hash.as_ref())
203}
204
205pub struct KeyExchange {
207 secret: Option<EphemeralSecret>,
208 public_key: X25519PublicKey,
209}
210
211impl std::fmt::Debug for KeyExchange {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 f.debug_struct("KeyExchange")
214 .field("public_key", &self.public_key_bytes())
215 .finish()
216 }
217}
218
219impl Clone for KeyExchange {
220 fn clone(&self) -> Self {
221 Self::new()
223 }
224}
225
226impl KeyExchange {
227 pub fn new() -> Self {
229 let secret = EphemeralSecret::random_from_rng(OsRng);
230 let public_key = X25519PublicKey::from(&secret);
231 Self {
232 secret: Some(secret),
233 public_key,
234 }
235 }
236
237 pub fn public_key_bytes(&self) -> [u8; 32] {
239 *self.public_key.as_bytes()
240 }
241
242 pub fn public_key_pem(&self) -> String {
244 let encoded = BASE64.encode(self.public_key.as_bytes());
245 format!(
246 "-----BEGIN PUBLIC KEY-----\n{}\n-----END PUBLIC KEY-----",
247 encoded
248 )
249 }
250
251 pub fn derive_shared_key(&mut self, peer_public_key: &[u8; 32]) -> Result<[u8; 32]> {
254 let secret = self
255 .secret
256 .take()
257 .ok_or_else(|| VoltError::crypto("Key exchange already consumed"))?;
258 let peer_key = X25519PublicKey::from(*peer_public_key);
259 let shared_secret = secret.diffie_hellman(&peer_key);
260 Ok(*shared_secret.as_bytes())
261 }
262}
263
264impl Default for KeyExchange {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
270pub struct AesKey {
272 key: [u8; AES_KEY_LENGTH],
273 iv: [u8; AES_NONCE_LENGTH],
274}
275
276impl AesKey {
277 pub fn generate() -> Self {
279 let mut key = [0u8; AES_KEY_LENGTH];
280 let mut iv = [0u8; AES_NONCE_LENGTH];
281
282 use rand::RngCore;
283 OsRng.fill_bytes(&mut key);
284 OsRng.fill_bytes(&mut iv);
285
286 Self { key, iv }
287 }
288
289 pub fn from_bytes(key: [u8; AES_KEY_LENGTH], iv: [u8; AES_NONCE_LENGTH]) -> Self {
291 Self { key, iv }
292 }
293
294 pub fn iv(&self) -> &[u8; AES_NONCE_LENGTH] {
296 &self.iv
297 }
298
299 pub fn key(&self) -> &[u8; AES_KEY_LENGTH] {
301 &self.key
302 }
303}
304
305pub fn aes_encrypt(
307 key: &[u8; AES_KEY_LENGTH],
308 iv: &[u8; AES_NONCE_LENGTH],
309 plaintext: &[u8],
310) -> Result<Vec<u8>> {
311 let cipher =
312 Aes256Gcm::new_from_slice(key).map_err(|e| VoltError::EncryptionError(e.to_string()))?;
313
314 let nonce = Nonce::from_slice(iv);
315
316 cipher
317 .encrypt(nonce, plaintext)
318 .map_err(|e| VoltError::EncryptionError(e.to_string()))
319}
320
321pub fn aes_decrypt(
323 key: &[u8; AES_KEY_LENGTH],
324 iv: &[u8; AES_NONCE_LENGTH],
325 ciphertext: &[u8],
326) -> Result<Vec<u8>> {
327 let cipher =
328 Aes256Gcm::new_from_slice(key).map_err(|e| VoltError::DecryptionError(e.to_string()))?;
329
330 let nonce = Nonce::from_slice(iv);
331
332 cipher
333 .decrypt(nonce, ciphertext)
334 .map_err(|e| VoltError::DecryptionError(e.to_string()))
335}
336
337pub fn random_bytes(len: usize) -> Vec<u8> {
339 let mut bytes = vec![0u8; len];
340 use rand::RngCore;
341 OsRng.fill_bytes(&mut bytes);
342 bytes
343}
344
345pub fn to_base64(data: &[u8]) -> String {
347 BASE64.encode(data)
348}
349
350pub fn from_base64(data: &str) -> Result<Vec<u8>> {
352 BASE64.decode(data).map_err(VoltError::from)
353}
354
/// Returns the base64 body of a PEM document with its boundary lines removed.
///
/// Every line is trimmed of surrounding whitespace, so CRLF endings, indentation and
/// trailing spaces are tolerated. Blank lines and lines starting with `-----` (the
/// `BEGIN`/`END` markers) are dropped. The remaining lines are joined with no separator,
/// so the result can be passed directly to a base64 decoder.
pub fn strip_pem_headers(pem: &str) -> String {
    pem.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with("-----"))
        .collect()
}
361
362pub fn format_pem(data: &[u8], label: &str) -> String {
364 let encoded = BASE64.encode(data);
365 format!(
366 "-----BEGIN {}-----\n{}\n-----END {}-----",
367 label, encoded, label
368 )
369}
370
371pub fn aes_cbc_encrypt(
379 key: &[u8; AES_KEY_LENGTH],
380 iv: &[u8; AES_CBC_IV_LENGTH],
381 plaintext: &[u8],
382) -> Result<Vec<u8>> {
383 let block_size = 16;
385 let padded_len = ((plaintext.len() / block_size) + 1) * block_size;
386 let mut buffer = vec![0u8; padded_len];
387 buffer[..plaintext.len()].copy_from_slice(plaintext);
388
389 let cipher = Aes256CbcEnc::new_from_slices(key, iv)
390 .map_err(|e| VoltError::EncryptionError(format!("Invalid key/IV: {}", e)))?;
391
392 let ciphertext = cipher
393 .encrypt_padded_mut::<Pkcs7>(&mut buffer, plaintext.len())
394 .map_err(|e| VoltError::EncryptionError(format!("Encryption failed: {:?}", e)))?;
395
396 Ok(ciphertext.to_vec())
397}
398
399pub fn aes_cbc_decrypt(
403 key: &[u8; AES_KEY_LENGTH],
404 iv: &[u8; AES_CBC_IV_LENGTH],
405 ciphertext: &[u8],
406) -> Result<Vec<u8>> {
407 let mut buffer = ciphertext.to_vec();
408
409 let cipher = Aes256CbcDec::new_from_slices(key, iv)
410 .map_err(|e| VoltError::DecryptionError(format!("Invalid key/IV: {}", e)))?;
411
412 let plaintext = cipher
413 .decrypt_padded_mut::<Pkcs7>(&mut buffer)
414 .map_err(|e| VoltError::DecryptionError(format!("Decryption failed: {:?}", e)))?;
415
416 Ok(plaintext.to_vec())
417}
418
419pub fn random_iv() -> [u8; AES_CBC_IV_LENGTH] {
421 let mut iv = [0u8; AES_CBC_IV_LENGTH];
422 use rand::RngCore;
423 OsRng.fill_bytes(&mut iv);
424 iv
425}
426
427pub fn derive_relay_key(shared_secret: &[u8; 32]) -> Result<[u8; AES_KEY_LENGTH]> {
436 let hk = Hkdf::<Sha256>::new(Some(RELAY_HKDF_SALT.as_bytes()), shared_secret);
437
438 let mut derived_key = [0u8; AES_KEY_LENGTH];
439 hk.expand(RELAY_HKDF_INFO.as_bytes(), &mut derived_key)
440 .map_err(|e| VoltError::crypto(format!("HKDF expansion failed: {}", e)))?;
441
442 Ok(derived_key)
443}
444
445pub struct StaticKeyExchange {
454 secret: StaticSecret,
455 public_key: X25519PublicKey,
456}
457
458impl std::fmt::Debug for StaticKeyExchange {
459 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460 f.debug_struct("StaticKeyExchange")
461 .field("public_key", &self.public_key_bytes())
462 .finish()
463 }
464}
465
466impl StaticKeyExchange {
467 pub fn new() -> Self {
469 let secret = StaticSecret::random_from_rng(OsRng);
470 let public_key = X25519PublicKey::from(&secret);
471 Self { secret, public_key }
472 }
473
474 pub fn from_secret(secret_bytes: [u8; 32]) -> Self {
476 let secret = StaticSecret::from(secret_bytes);
477 let public_key = X25519PublicKey::from(&secret);
478 Self { secret, public_key }
479 }
480
481 pub fn public_key_bytes(&self) -> [u8; 32] {
483 *self.public_key.as_bytes()
484 }
485
486 pub fn public_key_base64(&self) -> String {
488 BASE64.encode(self.public_key.as_bytes())
489 }
490
491 pub fn derive_shared_secret(&self, peer_public_key: &[u8; 32]) -> [u8; 32] {
494 let peer_key = X25519PublicKey::from(*peer_public_key);
495 let shared_secret = self.secret.diffie_hellman(&peer_key);
496 *shared_secret.as_bytes()
497 }
498
499 pub fn derive_relay_encryption_key(
502 &self,
503 peer_public_key: &[u8; 32],
504 ) -> Result<[u8; AES_KEY_LENGTH]> {
505 let shared_secret = self.derive_shared_secret(peer_public_key);
506 derive_relay_key(&shared_secret)
507 }
508}
509
510impl Default for StaticKeyExchange {
511 fn default() -> Self {
512 Self::new()
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_signing_key_generation() {
        let key = SigningKeyPair::generate();
        let private_pem = key.private_key_pem();
        let public_pem = key.public_key_pem();

        assert!(private_pem.contains("BEGIN PRIVATE KEY"));
        assert!(public_pem.contains("BEGIN PUBLIC KEY"));
    }

    #[test]
    fn test_raw_private_key_pem_round_trip() {
        let key = SigningKeyPair::generate();
        let restored = SigningKeyPair::from_pem(&key.private_key_pem()).unwrap();
        assert_eq!(key.secret_bytes(), restored.secret_bytes());
    }

    #[test]
    fn test_pkcs8_private_key_pem_round_trip() {
        let key = SigningKeyPair::generate();
        let restored = SigningKeyPair::from_pem(&key.private_key_pkcs8_pem()).unwrap();
        assert_eq!(key.secret_bytes(), restored.secret_bytes());
    }

    #[test]
    fn test_public_key_pem_round_trip() {
        let key = SigningKeyPair::generate();
        let parsed = public_key_from_pem(&key.public_key_pem()).unwrap();
        assert_eq!(parsed.to_bytes(), key.verifying_key().to_bytes());
    }

    #[test]
    fn test_sign_and_verify() {
        let key = SigningKeyPair::generate();
        let message = b"Hello, World!";

        let signature = key.sign(message);
        let verified = verify_signature(&key.verifying_key(), message, &signature).unwrap();

        assert!(verified);
    }

    #[test]
    fn test_verify_rejects_bad_signature_and_length() {
        let key = SigningKeyPair::generate();
        let mut signature = key.sign(b"message");
        signature[0] ^= 0x01;

        assert!(!verify_signature(&key.verifying_key(), b"message", &signature).unwrap());
        assert!(verify_signature(&key.verifying_key(), b"message", &signature[..63]).is_err());
    }

    #[test]
    fn test_aes_encrypt_decrypt() {
        let aes_key = AesKey::generate();
        let plaintext = b"Secret message!";

        let ciphertext = aes_encrypt(aes_key.key(), aes_key.iv(), plaintext).unwrap();
        let decrypted = aes_decrypt(aes_key.key(), aes_key.iv(), &ciphertext).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_aes_gcm_detects_tampering() {
        let aes_key = AesKey::generate();
        let mut ciphertext = aes_encrypt(aes_key.key(), aes_key.iv(), b"payload").unwrap();
        ciphertext[0] ^= 0x80;

        assert!(aes_decrypt(aes_key.key(), aes_key.iv(), &ciphertext).is_err());
    }

    #[test]
    fn test_key_exchange() {
        let mut alice = KeyExchange::new();
        let mut bob = KeyExchange::new();

        let alice_public = alice.public_key_bytes();
        let bob_public = bob.public_key_bytes();

        let alice_shared = alice.derive_shared_key(&bob_public).unwrap();
        let bob_shared = bob.derive_shared_key(&alice_public).unwrap();

        assert_eq!(alice_shared, bob_shared);
    }

    #[test]
    fn test_key_exchange_is_single_use() {
        let mut alice = KeyExchange::new();
        let bob = KeyExchange::new();
        let bob_public = bob.public_key_bytes();

        assert!(alice.derive_shared_key(&bob_public).is_ok());
        assert!(alice.derive_shared_key(&bob_public).is_err());
    }

    #[test]
    fn test_aes_cbc_encrypt_decrypt() {
        let key = [0u8; AES_KEY_LENGTH];
        let iv = [0u8; AES_CBC_IV_LENGTH];
        let plaintext = b"Secret message for relay!";

        let ciphertext = aes_cbc_encrypt(&key, &iv, plaintext).unwrap();
        let decrypted = aes_cbc_decrypt(&key, &iv, &ciphertext).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_aes_cbc_always_pads() {
        let key = [7u8; AES_KEY_LENGTH];
        let iv = [9u8; AES_CBC_IV_LENGTH];

        assert_eq!(aes_cbc_encrypt(&key, &iv, b"").unwrap().len(), 16);
        assert_eq!(aes_cbc_encrypt(&key, &iv, &[0u8; 16]).unwrap().len(), 32);
    }

    #[test]
    fn test_aes_cbc_with_random_iv() {
        let key = random_bytes(AES_KEY_LENGTH).try_into().unwrap();
        let iv = random_iv();
        let plaintext = b"Another secret message with random IV!";

        let ciphertext = aes_cbc_encrypt(&key, &iv, plaintext).unwrap();
        let decrypted = aes_cbc_decrypt(&key, &iv, &ciphertext).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_format_and_strip_pem_round_trip() {
        let data = b"round trip";
        let pem = format_pem(data, "TEST");

        assert!(pem.starts_with("-----BEGIN TEST-----"));
        assert_eq!(from_base64(&strip_pem_headers(&pem)).unwrap(), data.to_vec());
    }

    #[test]
    fn test_static_key_exchange() {
        let alice = StaticKeyExchange::new();
        let bob = StaticKeyExchange::new();

        let alice_public = alice.public_key_bytes();
        let bob_public = bob.public_key_bytes();

        let alice_shared1 = alice.derive_shared_secret(&bob_public);
        let alice_shared2 = alice.derive_shared_secret(&bob_public);
        let bob_shared = bob.derive_shared_secret(&alice_public);

        assert_eq!(alice_shared1, alice_shared2);
        assert_eq!(alice_shared1, bob_shared);
    }

    #[test]
    fn test_derive_relay_key() {
        let alice = StaticKeyExchange::new();
        let bob = StaticKeyExchange::new();

        let alice_public = alice.public_key_bytes();
        let bob_public = bob.public_key_bytes();

        let alice_key = alice.derive_relay_encryption_key(&bob_public).unwrap();
        let bob_key = bob.derive_relay_encryption_key(&alice_public).unwrap();

        assert_eq!(alice_key, bob_key);
        assert_eq!(alice_key.len(), AES_KEY_LENGTH);
    }

    #[test]
    fn test_full_relay_encryption_flow() {
        let client = StaticKeyExchange::new();
        let server = StaticKeyExchange::new();

        let client_public = client.public_key_bytes();
        let server_public = server.public_key_bytes();

        let client_key = client.derive_relay_encryption_key(&server_public).unwrap();
        let server_key = server.derive_relay_encryption_key(&client_public).unwrap();

        assert_eq!(client_key, server_key);

        let iv = random_iv();
        let message = b"Hello from client!";
        let encrypted = aes_cbc_encrypt(&client_key, &iv, message).unwrap();

        let decrypted = aes_cbc_decrypt(&server_key, &iv, &encrypted).unwrap();
        assert_eq!(message.as_slice(), decrypted.as_slice());
    }
}