1#[cfg(feature = "serialization")]
53mod serde_;
54#[cfg(test)]
55mod tests;
56#[cfg(feature = "vapid")]
57mod vapid;
58
59#[cfg(feature = "vapid")]
60pub use jwt_simple;
61pub use p256;
62
63use aes_gcm::aead::{
64 generic_array::{typenum::U16, GenericArray},
65 rand_core::RngCore,
66 OsRng,
67};
68use hkdf::Hkdf;
69use http::{self, header, Request, Uri};
70use p256::elliptic_curve::sec1::ToEncodedPoint;
71use sha2::Sha256;
72use std::time::Duration;
73
74#[derive(Debug)]
76pub enum Error {
77 ECE(ece_native::Error),
79 Extension(Box<dyn std::error::Error + Send + Sync + 'static>),
81}
82
83impl std::error::Error for Error {}
84
85impl std::fmt::Display for Error {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match self {
88 Error::ECE(ece) => write!(f, "ece: {}", ece),
89 Error::Extension(ext) => write!(f, "extension: {}", ext),
90 }
91 }
92}
93
94pub type Auth = GenericArray<u8, U16>;
96
97#[derive(Clone, Debug)]
99pub struct WebPushBuilder<A = ()> {
100 endpoint: Uri,
101 valid_duration: Duration,
102 ua_public: p256::PublicKey,
103 ua_auth: Auth,
104 #[cfg_attr(not(feature = "vapid"), allow(dead_code))]
105 http_auth: A,
106}
107
108impl WebPushBuilder {
109 pub fn new(endpoint: Uri, ua_public: p256::PublicKey, ua_auth: Auth) -> Self {
118 Self {
119 endpoint,
120 ua_public,
121 ua_auth,
122 valid_duration: Duration::from_secs(12 * 60 * 60),
123 http_auth: (),
124 }
125 }
126
127 pub fn with_valid_duration(self, valid_duration: Duration) -> Self {
129 let mut this = self;
130 this.valid_duration = valid_duration;
131 this
132 }
133
134 #[cfg(feature = "vapid")]
136 pub fn with_vapid<'a>(
137 self,
138 vapid_kp: &'a jwt_simple::algorithms::ES256KeyPair,
139 contact: &'a str,
140 ) -> WebPushBuilder<vapid::VapidAuthorization<'a>> {
141 WebPushBuilder {
142 endpoint: self.endpoint,
143 valid_duration: self.valid_duration,
144 ua_public: self.ua_public,
145 ua_auth: self.ua_auth,
146 http_auth: vapid::VapidAuthorization::new(vapid_kp, contact),
147 }
148 }
149}
150
151#[doc(hidden)]
152pub trait AddHeaders: Sized {
153 type Error: Into<Box<dyn std::error::Error + Sync + Send + 'static>>;
154
155 fn add_headers(
156 this: &WebPushBuilder<Self>,
157 builder: http::request::Builder,
158 ) -> Result<http::request::Builder, Self::Error>;
159}
160
161impl AddHeaders for () {
162 type Error = std::convert::Infallible;
163
164 fn add_headers(
165 _this: &WebPushBuilder<Self>,
166 builder: http::request::Builder,
167 ) -> Result<http::request::Builder, Self::Error> {
168 Ok(builder)
169 }
170}
171
172impl<A: AddHeaders> WebPushBuilder<A> {
173 pub fn build<T: Into<Vec<u8>>>(&self, body: T) -> Result<Request<Vec<u8>>, Error> {
176 let body = body.into();
177
178 let payload = encrypt(body, &self.ua_public, &self.ua_auth)?;
179 let builder = Request::builder()
180 .uri(self.endpoint.clone())
181 .method(http::method::Method::POST)
182 .header("TTL", self.valid_duration.as_secs())
183 .header(header::CONTENT_ENCODING, "aes128gcm")
184 .header(header::CONTENT_TYPE, "application/octet-stream")
185 .header(header::CONTENT_LENGTH, payload.len());
186
187 let builder =
188 AddHeaders::add_headers(self, builder).map_err(|it| Error::Extension(it.into()))?;
189
190 Ok(builder
191 .body(payload)
192 .expect("builder arguments are always well-defined"))
193 }
194}
195
196pub fn encrypt(
198 message: Vec<u8>,
199 ua_public: &p256::PublicKey,
200 ua_auth: &Auth,
201) -> Result<Vec<u8>, Error> {
202 let mut salt = [0u8; 16];
203 OsRng.fill_bytes(&mut salt);
204 let as_secret = p256::SecretKey::random(&mut OsRng);
205 encrypt_predictably(salt, message, &as_secret, ua_public, ua_auth).map_err(Error::ECE)
206}
207
208fn encrypt_predictably(
209 salt: [u8; 16],
210 message: Vec<u8>,
211 as_secret: &p256::SecretKey,
212 ua_public: &p256::PublicKey,
213 ua_auth: &Auth,
214) -> Result<Vec<u8>, ece_native::Error> {
215 let as_public = as_secret.public_key();
216 let shared = p256::ecdh::diffie_hellman(as_secret.to_nonzero_scalar(), ua_public.as_affine());
217
218 let ikm = compute_ikm(ua_auth, &shared, ua_public, &as_public);
219 let keyid = as_public.as_affine().to_encoded_point(false);
220 let encrypted_record_length = (message.len() + 17)
221 .try_into()
222 .map_err(|_| ece_native::Error::RecordLengthInvalid)?;
223
224 ece_native::encrypt(
225 ikm,
226 salt,
227 keyid,
228 Some(message).into_iter(),
229 encrypted_record_length,
230 )
231}
232
233pub fn decrypt(
235 encrypted_message: Vec<u8>,
236 as_secret: &p256::SecretKey,
237 ua_auth: &Auth,
238) -> Result<Vec<u8>, Error> {
239 let keyid = view_keyid(&encrypted_message).map_err(Error::ECE)?;
240 let ua_public = p256::PublicKey::from_sec1_bytes(keyid)
241 .map_err(|_| ece_native::Error::Aes128Gcm)
242 .map_err(Error::ECE)?;
243 let shared = p256::ecdh::diffie_hellman(as_secret.to_nonzero_scalar(), ua_public.as_affine());
244
245 let ikm = compute_ikm(ua_auth, &shared, &as_secret.public_key(), &ua_public);
246
247 ece_native::decrypt(ikm, encrypted_message).map_err(Error::ECE)
248}
249
250fn compute_ikm(
251 auth: &Auth,
252 shared: &p256::ecdh::SharedSecret,
253 ua_public: &p256::PublicKey,
254 as_public: &p256::PublicKey,
255) -> [u8; 32] {
256 let mut info = Vec::new();
257 info.extend_from_slice(&b"WebPush: info"[..]);
258 info.push(0u8);
259 info.extend_from_slice(ua_public.as_affine().to_encoded_point(false).as_bytes());
260 info.extend_from_slice(as_public.as_affine().to_encoded_point(false).as_bytes());
261
262 let mut okm = [0u8; 32];
263 let hk = Hkdf::<Sha256>::new(Some(auth), shared.raw_secret_bytes().as_ref());
264 hk.expand(&info, &mut okm)
265 .expect("okm length is always 32 bytes, cannot be too large");
266
267 okm
268}
269
270fn view_keyid(encrypted_message: &[u8]) -> Result<&[u8], ece_native::Error> {
271 if encrypted_message.len() < 21 {
272 return Err(ece_native::Error::HeaderLengthInvalid);
273 }
274
275 let idlen: usize = encrypted_message[20].into();
276 if encrypted_message[21..].len() < idlen {
277 return Err(ece_native::Error::KeyIdLengthInvalid);
278 }
279
280 Ok(&encrypted_message[21..21 + idlen])
281}