#[cfg(feature = "serialization")]
mod serde_;
#[cfg(test)]
mod tests;
#[cfg(feature = "vapid")]
mod vapid;
#[cfg(feature = "vapid")]
pub use jwt_simple;
pub use p256;
use aes_gcm::aead::{
generic_array::{typenum::U16, GenericArray},
rand_core::RngCore,
OsRng,
};
use hkdf::Hkdf;
use http::{self, header, Request, Uri};
use p256::elliptic_curve::sec1::ToEncodedPoint;
use sha2::Sha256;
use std::time::Duration;
/// Errors produced by this crate while encrypting, decrypting or building push requests.
#[derive(Debug)]
pub enum Error {
    /// Failure reported by the `ece_native` encrypted-content-encoding layer.
    ECE(ece_native::Error),
    /// Failure reported by a request extension (the `AddHeaders` hook used by `build`).
    Extension(Box<dyn std::error::Error + Send + Sync + 'static>),
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    /// Renders the error as `<layer>: <inner message>`, e.g. `ece: ...`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (layer, inner): (&str, &dyn std::fmt::Display) = match self {
            Self::ECE(ece) => ("ece", ece),
            Self::Extension(ext) => ("extension", ext),
        };
        write!(f, "{}: {}", layer, inner)
    }
}
/// A 16-byte authentication secret shared with the user agent.
///
/// It is used as the HKDF salt when deriving the content-encryption input keying
/// material (see `compute_ikm`).
pub type Auth = GenericArray<u8, U16>;
/// Builder for encrypted Web Push `POST` requests aimed at one subscription endpoint.
///
/// `A` is the header extension applied by `build`. It defaults to `()`, which adds no
/// headers; with the `vapid` feature, `with_vapid` switches it to a VAPID authorization.
#[derive(Clone, Debug)]
pub struct WebPushBuilder<A = ()> {
    /// URI the built request is sent to.
    endpoint: Uri,
    /// Sent as the `TTL` header, truncated to whole seconds.
    valid_duration: Duration,
    /// The user agent's public key that payloads are encrypted for.
    ua_public: p256::PublicKey,
    /// The user agent's authentication secret (HKDF salt during key derivation).
    ua_auth: Auth,
    // Extension state consumed by `AddHeaders::add_headers`. The `()` impl ignores it, so
    // without the `vapid` feature nothing reads this field — hence the dead_code allowance.
    #[cfg_attr(not(feature = "vapid"), allow(dead_code))]
    http_auth: A,
}
impl WebPushBuilder {
    /// Creates a builder for the subscription at `endpoint`.
    ///
    /// Payloads will be encrypted for the user agent identified by `ua_public` / `ua_auth`.
    /// The valid duration starts at 12 hours, and no extension headers are added.
    pub fn new(endpoint: Uri, ua_public: p256::PublicKey, ua_auth: Auth) -> Self {
        const DEFAULT_VALID_SECS: u64 = 12 * 60 * 60;
        Self {
            endpoint,
            valid_duration: Duration::from_secs(DEFAULT_VALID_SECS),
            ua_public,
            ua_auth,
            http_auth: (),
        }
    }

    /// Replaces the valid duration, which `build` sends as the `TTL` header in whole seconds.
    pub fn with_valid_duration(mut self, valid_duration: Duration) -> Self {
        self.valid_duration = valid_duration;
        self
    }

    /// Converts this builder into one that signs requests with the given VAPID key pair.
    ///
    /// `contact` is handed to `vapid::VapidAuthorization::new` together with `vapid_kp`.
    /// All other settings carry over unchanged.
    #[cfg(feature = "vapid")]
    pub fn with_vapid<'a>(
        self,
        vapid_kp: &'a jwt_simple::algorithms::ES256KeyPair,
        contact: &'a str,
    ) -> WebPushBuilder<vapid::VapidAuthorization<'a>> {
        let Self {
            endpoint,
            valid_duration,
            ua_public,
            ua_auth,
            http_auth: (),
        } = self;
        WebPushBuilder {
            endpoint,
            valid_duration,
            ua_public,
            ua_auth,
            http_auth: vapid::VapidAuthorization::new(vapid_kp, contact),
        }
    }
}
/// Extension hook through which a [`WebPushBuilder`] contributes extra HTTP headers to the
/// request produced by `build`.
#[doc(hidden)]
pub trait AddHeaders: Sized {
    /// Failure type of [`AddHeaders::add_headers`]; `build` boxes it into `Error::Extension`.
    type Error: Into<Box<dyn std::error::Error + Sync + Send + 'static>>;

    /// Returns `builder` with this extension's headers added, reading settings from `this`.
    fn add_headers(
        this: &WebPushBuilder<Self>,
        builder: http::request::Builder,
    ) -> Result<http::request::Builder, Self::Error>;
}

/// The no-op extension: passes the request builder through untouched and cannot fail.
impl AddHeaders for () {
    type Error = std::convert::Infallible;

    fn add_headers(
        _: &WebPushBuilder<Self>,
        request: http::request::Builder,
    ) -> Result<http::request::Builder, Self::Error> {
        Result::Ok(request)
    }
}
impl<A: AddHeaders> WebPushBuilder<A> {
pub fn build<T: Into<Vec<u8>>>(&self, body: T) -> Result<Request<Vec<u8>>, Error> {
let body = body.into();
let payload = encrypt(body, &self.ua_public, &self.ua_auth)?;
let builder = Request::builder()
.uri(self.endpoint.clone())
.method(http::method::Method::POST)
.header("TTL", self.valid_duration.as_secs())
.header(header::CONTENT_ENCODING, "aes128gcm")
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(header::CONTENT_LENGTH, payload.len());
let builder =
AddHeaders::add_headers(self, builder).map_err(|it| Error::Extension(it.into()))?;
Ok(builder
.body(payload)
.expect("builder arguments are always well-defined"))
}
}
pub fn encrypt(
message: Vec<u8>,
ua_public: &p256::PublicKey,
ua_auth: &Auth,
) -> Result<Vec<u8>, Error> {
let mut salt = [0u8; 16];
OsRng.fill_bytes(&mut salt);
let as_secret = p256::SecretKey::random(&mut OsRng);
encrypt_predictably(salt, message, &as_secret, ua_public, ua_auth).map_err(Error::ECE)
}
/// Deterministic core of [`encrypt`]: salt and sender secret are supplied by the caller.
///
/// The whole message is emitted as a single record whose header key id is the sender's
/// uncompressed public key.
fn encrypt_predictably(
    salt: [u8; 16],
    message: Vec<u8>,
    as_secret: &p256::SecretKey,
    ua_public: &p256::PublicKey,
    ua_auth: &Auth,
) -> Result<Vec<u8>, ece_native::Error> {
    // One record must hold the plaintext plus 17 bytes of overhead (presumably the 16-byte
    // AES-GCM tag and the 1-byte padding delimiter of RFC 8188). The size must fit the
    // record-size field expected by `ece_native::encrypt`.
    let record_size = (message.len() + 17)
        .try_into()
        .map_err(|_| ece_native::Error::RecordLengthInvalid)?;

    let as_public_key = as_secret.public_key();
    let ecdh_secret =
        p256::ecdh::diffie_hellman(as_secret.to_nonzero_scalar(), ua_public.as_affine());
    let ikm = compute_ikm(ua_auth, &ecdh_secret, ua_public, &as_public_key);
    let header_keyid = as_public_key.as_affine().to_encoded_point(false);

    ece_native::encrypt(
        ikm,
        salt,
        header_keyid,
        Some(message).into_iter(),
        record_size,
    )
}
pub fn decrypt(
encrypted_message: Vec<u8>,
as_secret: &p256::SecretKey,
ua_auth: &Auth,
) -> Result<Vec<u8>, Error> {
let keyid = view_keyid(&encrypted_message).map_err(Error::ECE)?;
let ua_public = p256::PublicKey::from_sec1_bytes(keyid)
.map_err(|_| ece_native::Error::Aes128Gcm)
.map_err(Error::ECE)?;
let shared = p256::ecdh::diffie_hellman(as_secret.to_nonzero_scalar(), ua_public.as_affine());
let ikm = compute_ikm(ua_auth, &shared, &as_secret.public_key(), &ua_public);
ece_native::decrypt(ikm, encrypted_message).map_err(Error::ECE)
}
/// Derives the 32-byte input keying material for the content encoding.
///
/// The derivation is HKDF-SHA256 with:
/// - salt `auth`;
/// - input key material the raw ECDH shared secret;
/// - info `"WebPush: info" || 0x00 || ua_public || as_public`, where both public keys are
///   SEC1-uncompressed points.
fn compute_ikm(
    auth: &Auth,
    shared: &p256::ecdh::SharedSecret,
    ua_public: &p256::PublicKey,
    as_public: &p256::PublicKey,
) -> [u8; 32] {
    // The label and its 0x00 separator as one constant.
    const INFO_LABEL: &[u8] = b"WebPush: info\0";
    let ua_point = ua_public.as_affine().to_encoded_point(false);
    let as_point = as_public.as_affine().to_encoded_point(false);
    let info = [INFO_LABEL, ua_point.as_bytes(), as_point.as_bytes()].concat();

    let prk = Hkdf::<Sha256>::new(Some(auth), shared.raw_secret_bytes().as_ref());
    let mut ikm = [0u8; 32];
    prk.expand(&info, &mut ikm)
        .expect("okm length is always 32 bytes, cannot be too large");
    ikm
}
fn view_keyid(encrypted_message: &[u8]) -> Result<&[u8], ece_native::Error> {
if encrypted_message.len() < 21 {
return Err(ece_native::Error::HeaderLengthInvalid);
}
let idlen: usize = encrypted_message[20].into();
if encrypted_message[21..].len() < idlen {
return Err(ece_native::Error::KeyIdLengthInvalid);
}
Ok(&encrypted_message[21..21 + idlen])
}