1use crate::types::PushSubscription;
24use base64::engine::general_purpose::URL_SAFE_NO_PAD;
25use base64::Engine as _;
26use jsonwebtoken::{Algorithm, EncodingKey, Header};
27use p256::ecdsa::SigningKey;
28use p256::pkcs8::{DecodePrivateKey, EncodePrivateKey};
29use serde::{Deserialize, Serialize};
30use std::time::{SystemTime, UNIX_EPOCH};
31use thiserror::Error;
32
33#[derive(Debug, Error)]
35pub enum WebPushError {
36 #[error("push endpoint returned 410 Gone")]
38 Gone,
39
40 #[error("HTTP transport error: {0}")]
42 Http(#[from] reqwest::Error),
43
44 #[error("VAPID JWT signing error: {0}")]
46 JwtSigning(#[from] jsonwebtoken::errors::Error),
47
48 #[error("VAPID key error: {0}")]
50 KeyError(String),
51
52 #[error("push endpoint returned unexpected status {0}")]
54 UnexpectedStatus(u16),
55}
56
57#[derive(Debug, Serialize, Deserialize)]
59struct VapidClaims {
60 aud: String,
62 sub: String,
64 exp: u64,
66}
67
68#[derive(Clone)]
72pub struct WebPushClient {
73 http: reqwest::Client,
74 vapid_key: std::sync::Arc<SigningKey>,
76 vapid_pubkey_base64url: String,
78 admin_sub: String,
80}
81
82impl std::fmt::Debug for WebPushClient {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct("WebPushClient")
85 .field("vapid_pubkey_base64url", &self.vapid_pubkey_base64url)
86 .field("admin_sub", &self.admin_sub)
87 .finish_non_exhaustive()
88 }
89}
90
91impl WebPushClient {
92 pub fn new(vapid_pem: Option<&[u8]>, admin_email: &str) -> Result<Self, WebPushError> {
96 let signing_key = match vapid_pem {
97 Some(pem_bytes) => {
98 let pem_str = std::str::from_utf8(pem_bytes)
99 .map_err(|e| WebPushError::KeyError(format!("PEM is not valid UTF-8: {e}")))?;
100 SigningKey::from_pkcs8_pem(pem_str)
101 .map_err(|e| WebPushError::KeyError(format!("Failed to load VAPID key: {e}")))?
102 }
103 None => {
104 let mut rng_buf = [0u8; 32];
106 getrandom::fill(&mut rng_buf)
107 .map_err(|e| WebPushError::KeyError(format!("RNG failure: {e}")))?;
108 SigningKey::from_slice(&rng_buf)
110 .map_err(|e| WebPushError::KeyError(format!("Key generation failed: {e}")))?
111 }
112 };
113
114 let pubkey_bytes = p256::ecdsa::VerifyingKey::from(&signing_key)
115 .to_encoded_point(false)
116 .as_bytes()
117 .to_vec();
118 let vapid_pubkey_base64url = URL_SAFE_NO_PAD.encode(&pubkey_bytes);
119
120 let admin_sub = if admin_email.contains('@') {
121 format!("mailto:{admin_email}")
122 } else {
123 admin_email.to_string()
124 };
125
126 Ok(Self {
127 http: reqwest::Client::builder()
128 .timeout(std::time::Duration::from_secs(30))
129 .build()
130 .map_err(WebPushError::Http)?,
131 vapid_key: std::sync::Arc::new(signing_key),
132 vapid_pubkey_base64url,
133 admin_sub,
134 })
135 }
136
137 pub fn new_with_persistence(
144 key_path: Option<&std::path::Path>,
145 admin_email: &str,
146 ) -> Result<Self, WebPushError> {
147 match key_path {
148 None => Self::new(None, admin_email),
149 Some(path) if path.exists() => {
150 let pem = std::fs::read(path).map_err(|e| {
151 WebPushError::KeyError(format!("Cannot read VAPID key file: {e}"))
152 })?;
153 Self::new(Some(&pem), admin_email)
154 }
155 Some(path) => {
156 let client = Self::new(None, admin_email)?;
158 let pem = client
159 .vapid_key
160 .to_pkcs8_pem(Default::default())
161 .map_err(|e| {
162 WebPushError::KeyError(format!("PEM serialization failed: {e}"))
163 })?;
164 if let Some(parent) = path.parent() {
165 std::fs::create_dir_all(parent).map_err(|e| {
166 WebPushError::KeyError(format!("Cannot create VAPID key dir: {e}"))
167 })?;
168 }
169 std::fs::write(path, pem.as_bytes())
170 .map_err(|e| WebPushError::KeyError(format!("Cannot write VAPID key: {e}")))?;
171 Ok(client)
172 }
173 }
174 }
175
176 pub fn vapid_pubkey_base64url(&self) -> &str {
182 &self.vapid_pubkey_base64url
183 }
184
185 pub(crate) fn build_vapid_jwt(&self, endpoint_origin: &str) -> Result<String, WebPushError> {
189 let now = SystemTime::now()
190 .duration_since(UNIX_EPOCH)
191 .map_err(|e| WebPushError::KeyError(format!("System clock error: {e}")))?;
192 let exp = now.as_secs() + 86_400; let claims = VapidClaims {
195 aud: endpoint_origin.to_string(),
196 sub: self.admin_sub.clone(),
197 exp,
198 };
199
200 let header = Header::new(Algorithm::ES256);
201
202 let der = self
204 .vapid_key
205 .to_pkcs8_der()
206 .map_err(|e| WebPushError::KeyError(format!("Key DER export failed: {e}")))?;
207 let encoding_key = EncodingKey::from_ec_der(der.as_bytes());
208
209 let token = jsonwebtoken::encode(&header, &claims, &encoding_key)?;
210 Ok(token)
211 }
212
213 pub async fn send(
219 &self,
220 subscription: &PushSubscription,
221 payload: &[u8],
222 ) -> Result<(), WebPushError> {
223 let endpoint_origin = extract_origin(&subscription.url).ok_or_else(|| {
226 WebPushError::KeyError(format!(
227 "Cannot determine origin from URL: {}",
228 subscription.url
229 ))
230 })?;
231
232 let jwt = self.build_vapid_jwt(&endpoint_origin)?;
233
234 let authorization = format!("vapid t={},k={}", jwt, self.vapid_pubkey_base64url);
236
237 const TTL_SECONDS: u32 = 86_400;
239
240 let body = if subscription.keys.is_none() || payload.is_empty() {
244 bytes::Bytes::new()
245 } else {
246 bytes::Bytes::new()
248 };
249
250 let response = self
251 .http
252 .post(&subscription.url)
253 .header("Authorization", authorization)
254 .header("TTL", TTL_SECONDS.to_string())
255 .header("Content-Type", "application/octet-stream")
256 .body(body)
257 .send()
258 .await?;
259
260 let status = response.status().as_u16();
261 match status {
262 200..=299 => Ok(()),
263 410 => Err(WebPushError::Gone),
264 other => Err(WebPushError::UnexpectedStatus(other)),
265 }
266 }
267}
268
/// Return the origin (`scheme://authority`) of an absolute URL, for use as
/// the VAPID `aud` claim.
///
/// The authority ends at the first `/`, `?` or `#`. Any `user:pass@`
/// userinfo is dropped, because it is not part of an origin. Returns `None`
/// when the URL has no `://` or the scheme or host is empty.
///
/// This is a lightweight parser. It neither lowercases nor strips default
/// ports.
fn extract_origin(url: &str) -> Option<String> {
    let (scheme, after_scheme) = url.split_once("://")?;
    let authority = after_scheme
        .find(&['/', '?', '#'][..])
        .map_or(after_scheme, |end| &after_scheme[..end]);
    // Userinfo is everything up to the last '@' of the authority.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port);
    if scheme.is_empty() || host_port.is_empty() {
        return None;
    }
    Some(format!("{scheme}://{host_port}"))
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Client with a throwaway key and a plain e-mail contact.
    fn ephemeral_client() -> WebPushClient {
        WebPushClient::new(None, "admin@example.com").unwrap()
    }

    #[test]
    fn test_extract_origin_https() {
        let origin = extract_origin("https://push.example.com/v1/subscriptions/abc123");
        assert_eq!(origin.as_deref(), Some("https://push.example.com"));
    }

    #[test]
    fn test_extract_origin_with_port() {
        let origin = extract_origin("https://push.example.com:8443/endpoint");
        assert_eq!(origin.as_deref(), Some("https://push.example.com:8443"));
    }

    #[test]
    fn test_vapid_client_ephemeral_key() {
        let client = ephemeral_client();
        assert!(!client.vapid_pubkey_base64url().is_empty());
        assert!(client.admin_sub.starts_with("mailto:"));
    }

    #[test]
    fn test_build_vapid_jwt() {
        let token = ephemeral_client()
            .build_vapid_jwt("https://push.example.com")
            .unwrap();
        // A compact JWS is exactly three dot-separated segments.
        assert_eq!(
            token.split('.').count(),
            3,
            "JWT must have header.payload.signature"
        );
    }
}