Skip to main content

rusmes_jmap/
web_push.rs

//! WebPush client for RFC 8030 push + RFC 8292 VAPID.
2//!
3//! # Implementation decision
4//!
5//! Dependencies chosen: `p256` (ECDSA key generation/loading), `jsonwebtoken`
6//! (ES256 JWT signing — already a workspace dep with `rust_crypto` feature
7//! which pulls in `p256`), `reqwest` (HTTP POST), `base64` (URL-safe encoding).
8//!
9//! The `web-push` crate was NOT used: it pulls in `openssl-sys` via `openssl`
10//! as a transitive dependency on some targets, violating the COOLJAPAN Pure
11//! Rust policy.  All chosen crates are 100% Pure Rust.
12//!
13//! # RFC 8291 encryption
14//!
//! Message encryption (RFC 8291 — ECDH + HKDF + AES-128-GCM) is **deferred**.
//! Every push is currently sent as an unencrypted "tickle" (zero-byte body),
//! even when the subscription includes `keys`: any payload handed to the
//! client is discarded, but the user agent is still woken up.  RFC 8291
//! encryption will be implemented in a follow-up slice once the key-agreement
//! primitives are fully stabilised in the workspace.  The
//! `PushSubscription.keys` field is preserved in storage so the feature can be
//! enabled without breaking existing subscriptions.
22
23use crate::types::PushSubscription;
24use base64::engine::general_purpose::URL_SAFE_NO_PAD;
25use base64::Engine as _;
26use jsonwebtoken::{Algorithm, EncodingKey, Header};
27use p256::ecdsa::SigningKey;
28use p256::pkcs8::{DecodePrivateKey, EncodePrivateKey};
29use serde::{Deserialize, Serialize};
30use std::time::{SystemTime, UNIX_EPOCH};
31use thiserror::Error;
32
/// Errors that can occur during WebPush operations.
#[derive(Debug, Error)]
pub enum WebPushError {
    /// The push endpoint returned HTTP 410 Gone — subscription should be removed.
    #[error("push endpoint returned 410 Gone")]
    Gone,

    /// HTTP transport error (connection refused, timeout, etc.), or failure to
    /// build the underlying HTTP client.
    #[error("HTTP transport error: {0}")]
    Http(#[from] reqwest::Error),

    /// JWT signing failed.
    #[error("VAPID JWT signing error: {0}")]
    JwtSigning(#[from] jsonwebtoken::errors::Error),

    /// Catch-all for local (non-HTTP) failures.
    ///
    /// Despite the name, this also covers more than key handling: the VAPID
    /// private key could not be loaded, generated, serialized or written to
    /// disk; the system clock reads earlier than the UNIX epoch; or the
    /// subscription URL has no usable origin for the VAPID `aud` claim.
    #[error("VAPID key error: {0}")]
    KeyError(String),

    /// The push endpoint returned an unexpected non-2xx status (not 410).
    #[error("push endpoint returned unexpected status {0}")]
    UnexpectedStatus(u16),
}
56
/// VAPID JWT claims (RFC 8292 §2).
///
/// Serialized as the payload of the ES256 JWT built by
/// `WebPushClient::build_vapid_jwt`.
#[derive(Debug, Serialize, Deserialize)]
struct VapidClaims {
    /// Audience: origin of the push endpoint URL.
    aud: String,
    /// Subject: `mailto:` URI or URL for the push provider to contact.
    sub: String,
    /// Expiry: UNIX timestamp in seconds since the epoch.
    exp: u64,
}
67
/// WebPush client.
///
/// Cloning is fairly cheap but not free: the signing key is shared through an
/// `Arc` (and `reqwest::Client` is documented as reference-counted
/// internally), while the two `String` fields are copied on every clone.
#[derive(Clone)]
pub struct WebPushClient {
    http: reqwest::Client,
    /// ECDSA signing key (P-256) used for VAPID JWTs.
    vapid_key: std::sync::Arc<SigningKey>,
    /// Base64url-encoded uncompressed public key point, sent as the `k=`
    /// parameter of the `Authorization: vapid` header.
    vapid_pubkey_base64url: String,
    /// `mailto:` or HTTPS subject for VAPID `sub` claim.
    admin_sub: String,
}
81
82impl std::fmt::Debug for WebPushClient {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct("WebPushClient")
85            .field("vapid_pubkey_base64url", &self.vapid_pubkey_base64url)
86            .field("admin_sub", &self.admin_sub)
87            .finish_non_exhaustive()
88    }
89}
90
91impl WebPushClient {
92    /// Construct from an existing PEM-encoded P-256 EC private key.
93    ///
94    /// If `vapid_pem` is `None`, a fresh key is generated in memory.
95    pub fn new(vapid_pem: Option<&[u8]>, admin_email: &str) -> Result<Self, WebPushError> {
96        let signing_key = match vapid_pem {
97            Some(pem_bytes) => {
98                let pem_str = std::str::from_utf8(pem_bytes)
99                    .map_err(|e| WebPushError::KeyError(format!("PEM is not valid UTF-8: {e}")))?;
100                SigningKey::from_pkcs8_pem(pem_str)
101                    .map_err(|e| WebPushError::KeyError(format!("Failed to load VAPID key: {e}")))?
102            }
103            None => {
104                // Generate an ephemeral key.
105                let mut rng_buf = [0u8; 32];
106                getrandom::fill(&mut rng_buf)
107                    .map_err(|e| WebPushError::KeyError(format!("RNG failure: {e}")))?;
108                // Use the raw bytes as a scalar (this is deterministic from the entropy).
109                SigningKey::from_slice(&rng_buf)
110                    .map_err(|e| WebPushError::KeyError(format!("Key generation failed: {e}")))?
111            }
112        };
113
114        let pubkey_bytes = p256::ecdsa::VerifyingKey::from(&signing_key)
115            .to_encoded_point(false)
116            .as_bytes()
117            .to_vec();
118        let vapid_pubkey_base64url = URL_SAFE_NO_PAD.encode(&pubkey_bytes);
119
120        let admin_sub = if admin_email.contains('@') {
121            format!("mailto:{admin_email}")
122        } else {
123            admin_email.to_string()
124        };
125
126        Ok(Self {
127            http: reqwest::Client::builder()
128                .timeout(std::time::Duration::from_secs(30))
129                .build()
130                .map_err(WebPushError::Http)?,
131            vapid_key: std::sync::Arc::new(signing_key),
132            vapid_pubkey_base64url,
133            admin_sub,
134        })
135    }
136
137    /// Load or generate the VAPID key, optionally persisting it.
138    ///
139    /// - If `key_path` is `Some` and the file exists: loads the PEM.
140    /// - If `key_path` is `Some` and the file does NOT exist: generates a new
141    ///   key and writes it to the path so subsequent restarts use the same key.
142    /// - If `key_path` is `None`: generates an ephemeral in-memory key.
143    pub fn new_with_persistence(
144        key_path: Option<&std::path::Path>,
145        admin_email: &str,
146    ) -> Result<Self, WebPushError> {
147        match key_path {
148            None => Self::new(None, admin_email),
149            Some(path) if path.exists() => {
150                let pem = std::fs::read(path).map_err(|e| {
151                    WebPushError::KeyError(format!("Cannot read VAPID key file: {e}"))
152                })?;
153                Self::new(Some(&pem), admin_email)
154            }
155            Some(path) => {
156                // Generate and persist.
157                let client = Self::new(None, admin_email)?;
158                let pem = client
159                    .vapid_key
160                    .to_pkcs8_pem(Default::default())
161                    .map_err(|e| {
162                        WebPushError::KeyError(format!("PEM serialization failed: {e}"))
163                    })?;
164                if let Some(parent) = path.parent() {
165                    std::fs::create_dir_all(parent).map_err(|e| {
166                        WebPushError::KeyError(format!("Cannot create VAPID key dir: {e}"))
167                    })?;
168                }
169                std::fs::write(path, pem.as_bytes())
170                    .map_err(|e| WebPushError::KeyError(format!("Cannot write VAPID key: {e}")))?;
171                Ok(client)
172            }
173        }
174    }
175
176    /// Return the base64url-encoded uncompressed VAPID public key.
177    ///
178    /// This is the value that must be shared with push endpoints so they can
179    /// verify VAPID JWTs.  Exposed for tests and for generating the VAPID public
180    /// key header in `Crypto-Key`.
181    pub fn vapid_pubkey_base64url(&self) -> &str {
182        &self.vapid_pubkey_base64url
183    }
184
185    /// Build a VAPID JWT for the given push endpoint origin.
186    ///
187    /// The JWT is valid for 24 hours from now (RFC 8444 §3).
188    pub(crate) fn build_vapid_jwt(&self, endpoint_origin: &str) -> Result<String, WebPushError> {
189        let now = SystemTime::now()
190            .duration_since(UNIX_EPOCH)
191            .map_err(|e| WebPushError::KeyError(format!("System clock error: {e}")))?;
192        let exp = now.as_secs() + 86_400; // 24 hours
193
194        let claims = VapidClaims {
195            aud: endpoint_origin.to_string(),
196            sub: self.admin_sub.clone(),
197            exp,
198        };
199
200        let header = Header::new(Algorithm::ES256);
201
202        // Convert the p256 SigningKey to DER for jsonwebtoken.
203        let der = self
204            .vapid_key
205            .to_pkcs8_der()
206            .map_err(|e| WebPushError::KeyError(format!("Key DER export failed: {e}")))?;
207        let encoding_key = EncodingKey::from_ec_der(der.as_bytes());
208
209        let token = jsonwebtoken::encode(&header, &claims, &encoding_key)?;
210        Ok(token)
211    }
212
213    /// Send a WebPush message to `subscription`.
214    ///
215    /// When `subscription.keys` is `None` (or RFC 8291 encryption is not yet
216    /// implemented) the body is left empty (tickle semantics).  The VAPID JWT
217    /// is always included in the `Authorization` header per RFC 8444.
218    pub async fn send(
219        &self,
220        subscription: &PushSubscription,
221        payload: &[u8],
222    ) -> Result<(), WebPushError> {
223        // Derive the origin (scheme + host [+ port]) from the subscription URL
224        // for the VAPID `aud` claim.
225        let endpoint_origin = extract_origin(&subscription.url).ok_or_else(|| {
226            WebPushError::KeyError(format!(
227                "Cannot determine origin from URL: {}",
228                subscription.url
229            ))
230        })?;
231
232        let jwt = self.build_vapid_jwt(&endpoint_origin)?;
233
234        // VAPID authorization: `vapid t=<token>,k=<pubkey>` (RFC 8292 §2).
235        let authorization = format!("vapid t={},k={}", jwt, self.vapid_pubkey_base64url);
236
237        // RFC 8030 §5.2: TTL header is required.
238        const TTL_SECONDS: u32 = 86_400;
239
240        // RFC 8291: when keys are absent we send a tickle (empty body).
241        // The payload argument is accepted but ignored for now; see module-level
242        // note on RFC 8291 encryption deferral.
243        let body = if subscription.keys.is_none() || payload.is_empty() {
244            bytes::Bytes::new()
245        } else {
246            // Encryption is deferred — fall back to tickle.
247            bytes::Bytes::new()
248        };
249
250        let response = self
251            .http
252            .post(&subscription.url)
253            .header("Authorization", authorization)
254            .header("TTL", TTL_SECONDS.to_string())
255            .header("Content-Type", "application/octet-stream")
256            .body(body)
257            .send()
258            .await?;
259
260        let status = response.status().as_u16();
261        match status {
262            200..=299 => Ok(()),
263            410 => Err(WebPushError::Gone),
264            other => Err(WebPushError::UnexpectedStatus(other)),
265        }
266    }
267}
268
/// Extract the `scheme://host[:port]` origin from a URL string.
///
/// The result is normalised the way browsers serialise origins (RFC 6454):
/// scheme and host are lowercased, userinfo (`user:pass@`), path, query and
/// fragment are dropped, and the default port (`:443` for https, `:80` for
/// http) is omitted.  This matters because the string is used as the VAPID
/// `aud` claim, which push services compare against their own origin.
///
/// Returns `None` if the URL has no `://` separator, an empty scheme, or no host.
/// This is a deliberately small parser that avoids pulling in the `url` crate.
fn extract_origin(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;
    if scheme.is_empty() {
        return None;
    }

    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Userinfo is not part of the origin; `@` cannot appear in the host itself.
    let host_port = match authority.rsplit_once('@') {
        Some((_, host_port)) => host_port,
        None => authority,
    };

    let scheme = scheme.to_ascii_lowercase();
    let mut host_port = host_port.to_ascii_lowercase();

    let default_port = match scheme.as_str() {
        "https" => Some(":443"),
        "http" => Some(":80"),
        _ => None,
    };
    if let Some(port) = default_port {
        if host_port.ends_with(port) {
            host_port.truncate(host_port.len() - port.len());
        }
    }

    // An empty host (e.g. `https:///path` or `https://:8080`) has no origin.
    if host_port.is_empty() || host_port.starts_with(':') {
        return None;
    }
    Some(format!("{scheme}://{host_port}"))
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared check: `url` must reduce to exactly `expected_origin`.
    fn assert_origin(url: &str, expected_origin: &str) {
        let actual = extract_origin(url);
        assert_eq!(actual, Some(expected_origin.to_string()));
    }

    #[test]
    fn test_extract_origin_https() {
        assert_origin(
            "https://push.example.com/v1/subscriptions/abc123",
            "https://push.example.com",
        );
    }

    #[test]
    fn test_extract_origin_with_port() {
        assert_origin(
            "https://push.example.com:8443/endpoint",
            "https://push.example.com:8443",
        );
    }

    #[test]
    fn test_vapid_client_ephemeral_key() {
        let client = WebPushClient::new(None, "admin@example.com")
            .expect("ephemeral VAPID client should build");
        let pubkey = client.vapid_pubkey_base64url();
        assert!(!pubkey.is_empty());
        assert!(client.admin_sub.starts_with("mailto:"));
    }

    #[test]
    fn test_build_vapid_jwt() {
        let client = WebPushClient::new(None, "admin@example.com")
            .expect("ephemeral VAPID client should build");
        let token = client
            .build_vapid_jwt("https://push.example.com")
            .expect("JWT should sign");
        // A compact JWS is header.payload.signature.
        let segments = token.split('.').count();
        assert_eq!(segments, 3, "JWT must have header.payload.signature");
    }
}