use rustls::crypto::aws_lc_rs::{default_provider, kx_group};
use rustls::crypto::{
ActiveKeyExchange, CompletedKeyExchange, CryptoProvider, SharedSecret, SupportedKxGroup,
};
use rustls::{Error, NamedGroup, PeerMisbehaved};
use aws_lc_rs::kem;
use aws_lc_rs::unstable::kem::{get_algorithm, AlgorithmId};
pub fn provider() -> CryptoProvider {
let mut parent = default_provider();
parent
.kx_groups
.insert(0, &X25519Kyber768Draft00);
parent
}
#[derive(Debug)]
pub struct X25519Kyber768Draft00;
impl SupportedKxGroup for X25519Kyber768Draft00 {
fn start(&self) -> Result<Box<dyn ActiveKeyExchange>, Error> {
let x25519 = kx_group::X25519.start()?;
let kyber = kem::DecapsulationKey::generate(kyber768_r3())
.map_err(|_| Error::FailedToGetRandomBytes)?;
let kyber_pub = kyber
.encapsulation_key()
.map_err(|_| Error::FailedToGetRandomBytes)?;
let mut combined_pub_key = Vec::with_capacity(COMBINED_PUBKEY_LEN);
combined_pub_key.extend_from_slice(x25519.pub_key());
combined_pub_key.extend_from_slice(kyber_pub.key_bytes().unwrap().as_ref());
Ok(Box::new(Active {
x25519,
decap_key: Box::new(kyber),
combined_pub_key,
}))
}
fn start_and_complete(&self, client_share: &[u8]) -> Result<CompletedKeyExchange, Error> {
let share = match ReceivedShare::new(client_share) {
Some(share) => share,
None => return Err(INVALID_KEY_SHARE),
};
let x25519 = kx_group::X25519.start_and_complete(share.x25519)?;
let (kyber_share, kyber_secret) = kem::EncapsulationKey::new(kyber768_r3(), share.kyber)
.map_err(|_| INVALID_KEY_SHARE)
.and_then(|pk| {
pk.encapsulate()
.map_err(|_| INVALID_KEY_SHARE)
})?;
let combined_secret = CombinedSecret::combine(x25519.secret, kyber_secret);
let combined_share = CombinedShare::combine(&x25519.pub_key, kyber_share);
Ok(CompletedKeyExchange {
group: self.name(),
pub_key: combined_share.0,
secret: SharedSecret::from(&combined_secret.0[..]),
})
}
fn name(&self) -> NamedGroup {
NAMED_GROUP
}
}
/// In-progress hybrid exchange created by `X25519Kyber768Draft00::start`.
struct Active {
    /// The X25519 half of the exchange.
    x25519: Box<dyn ActiveKeyExchange>,
    /// Kyber768 private key, used to decapsulate the peer's ciphertext.
    decap_key: Box<kem::DecapsulationKey<AlgorithmId>>,
    /// Our key share: `x25519_pub || kyber_pub`.
    combined_pub_key: Vec<u8>,
}

impl ActiveKeyExchange for Active {
    /// Finishes the exchange using the peer's `x25519_pub || kyber_ciphertext`.
    ///
    /// # Errors
    /// - `INVALID_KEY_SHARE` if the input has the wrong length or the Kyber
    ///   ciphertext fails to decapsulate.
    /// - Errors from the X25519 half are propagated unchanged.
    fn complete(self: Box<Self>, peer_pub_key: &[u8]) -> Result<SharedSecret, Error> {
        let parsed = ReceivedCiphertext::new(peer_pub_key).ok_or(INVALID_KEY_SHARE)?;
        let Active {
            x25519, decap_key, ..
        } = *self;

        // The X25519 half is completed first; its error short-circuits
        // before any Kyber work happens.
        let classical = x25519.complete(parsed.x25519)?;
        let post_quantum = decap_key
            .decapsulate(parsed.kyber.into())
            .map_err(|_| INVALID_KEY_SHARE)?;

        let joined = CombinedSecret::combine(classical, post_quantum);
        Ok(SharedSecret::from(&joined.0[..]))
    }

    /// Our combined key share (`x25519_pub || kyber_pub`).
    fn pub_key(&self) -> &[u8] {
        self.combined_pub_key.as_slice()
    }

    /// The TLS code point of the hybrid group.
    fn group(&self) -> NamedGroup {
        NAMED_GROUP
    }
}
/// Borrowed view of a peer's hybrid key share, split into its two parts.
struct ReceivedShare<'a> {
    /// Leading `X25519_LEN` bytes.
    x25519: &'a [u8],
    /// Remaining bytes (the Kyber768 public key).
    kyber: &'a [u8],
}

impl<'a> ReceivedShare<'a> {
    /// Splits `buf` after the first `X25519_LEN` bytes.
    ///
    /// Returns `None` unless `buf` is exactly `COMBINED_PUBKEY_LEN` bytes long.
    fn new(buf: &'a [u8]) -> Option<ReceivedShare<'a>> {
        let well_sized = buf.len() == COMBINED_PUBKEY_LEN;
        well_sized.then(|| {
            let (head, tail) = buf.split_at(X25519_LEN);
            ReceivedShare {
                x25519: head,
                kyber: tail,
            }
        })
    }
}
struct ReceivedCiphertext<'a> {
x25519: &'a [u8],
kyber: &'a [u8],
}
impl<'a> ReceivedCiphertext<'a> {
fn new(buf: &'a [u8]) -> Option<ReceivedCiphertext<'a>> {
if buf.len() != COMBINED_CIPHERTEXT_LEN {
return None;
}
let (x25519, kyber) = buf.split_at(X25519_LEN);
Some(ReceivedCiphertext { x25519, kyber })
}
}
/// Hybrid shared secret: `x25519_secret || kyber_secret`.
///
/// The buffer is overwritten with zeros on drop so the concatenated key
/// material does not stay in memory after it has been copied out. This is
/// best-effort: copies made when the value is moved are not covered.
struct CombinedSecret([u8; COMBINED_SHARED_SECRET_LEN]);

impl CombinedSecret {
    /// Concatenates the two shared secrets, X25519 first.
    ///
    /// # Panics
    /// If the X25519 secret is not `X25519_LEN` bytes, or the Kyber secret
    /// does not exactly fill the rest of the buffer.
    fn combine(x25519: SharedSecret, kyber: kem::SharedSecret) -> Self {
        let mut out = CombinedSecret([0u8; COMBINED_SHARED_SECRET_LEN]);
        let (classical, post_quantum) = out.0.split_at_mut(X25519_LEN);
        classical.copy_from_slice(x25519.secret_bytes());
        post_quantum.copy_from_slice(kyber.as_ref());
        out
    }
}

impl Drop for CombinedSecret {
    fn drop(&mut self) {
        for byte in self.0.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive `&mut u8` into
            // `self.0`. The volatile write stops the compiler from removing
            // the wipe as a dead store.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        // Keep the wipe from being reordered past the end of the drop.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
/// Hybrid key share sent back to the client: `x25519_pub || kyber_ciphertext`.
struct CombinedShare(Vec<u8>);

impl CombinedShare {
    /// Concatenates the X25519 public key and the Kyber768 ciphertext.
    ///
    /// # Panics
    /// If `x25519` is not `X25519_LEN` bytes, or the ciphertext does not
    /// exactly fill the rest of the `COMBINED_CIPHERTEXT_LEN`-byte buffer.
    fn combine(x25519: &[u8], kyber: kem::Ciphertext) -> Self {
        let mut bytes = vec![0u8; COMBINED_CIPHERTEXT_LEN];
        {
            let (classical, post_quantum) = bytes.split_at_mut(X25519_LEN);
            classical.copy_from_slice(x25519);
            post_quantum.copy_from_slice(kyber.as_ref());
        }
        CombinedShare(bytes)
    }
}
/// Looks up aws-lc-rs's Kyber768 round-3 KEM algorithm.
///
/// # Panics
/// If `get_algorithm` returns `None` for `AlgorithmId::Kyber768_R3`.
fn kyber768_r3() -> &'static kem::Algorithm<AlgorithmId> {
    let id = AlgorithmId::Kyber768_R3;
    let found = get_algorithm(id);
    found.expect("Kyber768_R3 not available")
}
/// Error returned whenever the peer's key share or ciphertext is rejected.
const INVALID_KEY_SHARE: Error = Error::PeerMisbehaved(PeerMisbehaved::InvalidKeyShare);
/// TLS code point 0x6399, which has no named `NamedGroup` variant.
const NAMED_GROUP: NamedGroup = NamedGroup::Unknown(0x63_99);
/// Length in bytes of the X25519 part of every combined value
/// (public key and shared secret).
const X25519_LEN: usize = 32;
/// Length in bytes of a Kyber768 encapsulation (public) key.
const KYBER_PUBKEY_LEN: usize = 1184;
/// Length in bytes of a Kyber768 ciphertext.
const KYBER_CIPHERTEXT_LEN: usize = 1088;
/// Length in bytes of a Kyber768 shared secret.
const KYBER_SHARED_SECRET_LEN: usize = 32;
/// Client key share: `x25519_pub || kyber_pub`.
const COMBINED_PUBKEY_LEN: usize = X25519_LEN + KYBER_PUBKEY_LEN;
/// Server key share: `x25519_pub || kyber_ciphertext`.
const COMBINED_CIPHERTEXT_LEN: usize = X25519_LEN + KYBER_CIPHERTEXT_LEN;
/// Combined secret: `x25519_secret || kyber_secret`.
const COMBINED_SHARED_SECRET_LEN: usize = X25519_LEN + KYBER_SHARED_SECRET_LEN;