Skip to main content

rusmes_jmap/methods/
push_subscription.rs

1//! JMAP `PushSubscription/get` and `PushSubscription/set` handlers.
2//!
//! Implements RFC 8620 §7.2 (PushSubscription): clients register a push endpoint URL; the server
4//! delivers a verification push, and thereafter fans out `StateChange` events to
5//! all verified subscriptions whose `types` list matches the changed data type.
6//!
7//! # Registry architecture
8//!
9//! The push registry lives in a `OnceLock<Arc<PushState>>` so the dispatch
10//! function (which has no state parameter) can access it without threading the
11//! handle through every call site.  Call [`init_push_state`] once at server
12//! startup before dispatching any JMAP requests.
13
14use crate::types::{JmapSetError, Principal, PushKeys, PushSubscription};
15use crate::web_push::{WebPushClient, WebPushError};
16use base64::Engine as _;
17use dashmap::DashMap;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::{Arc, OnceLock};
21
22// ──────────────────────────────────────────────────────────────────────────────
23// Global push state
24// ──────────────────────────────────────────────────────────────────────────────
25
26/// Shared push state accessible from the stateless dispatch function.
27pub struct PushState {
28    /// Map of subscription ID → subscription.
29    pub registry: Arc<DashMap<String, PushSubscription>>,
30    /// WebPush HTTP client with loaded VAPID key.
31    pub client: Arc<WebPushClient>,
32}
33
34static PUSH_STATE: OnceLock<Arc<PushState>> = OnceLock::new();
35
36/// Install the global push state.
37///
38/// Must be called once at server startup before any `PushSubscription/*`
39/// method can be dispatched.  Subsequent calls are no-ops (the first-writer
40/// wins, matching the `global_metrics()` pattern).
41pub fn init_push_state(state: Arc<PushState>) {
42    let _ = PUSH_STATE.set(state);
43}
44
45/// Retrieve the global push state, or `None` if it has not been initialised.
46pub fn push_state() -> Option<&'static Arc<PushState>> {
47    PUSH_STATE.get()
48}
49
50/// Registry type alias.
51pub type PushRegistry = Arc<DashMap<String, PushSubscription>>;
52
53// ──────────────────────────────────────────────────────────────────────────────
54// Request / response types
55// ──────────────────────────────────────────────────────────────────────────────
56
57/// `PushSubscription/get` request (RFC 8620 §5.1).
58#[derive(Debug, Clone, Deserialize)]
59#[serde(rename_all = "camelCase")]
60pub struct PushSubscriptionGetRequest {
61    /// Optional list of subscription IDs to retrieve.  `None` means "all".
62    #[serde(default)]
63    pub ids: Option<Vec<String>>,
64}
65
66/// `PushSubscription/get` response.
67#[derive(Debug, Clone, Serialize)]
68#[serde(rename_all = "camelCase")]
69pub struct PushSubscriptionGetResponse {
70    pub list: Vec<PushSubscriptionView>,
71    pub not_found: Vec<String>,
72}
73
74/// The RFC 8620 §5.1 view of a `PushSubscription` returned to the client.
75///
76/// Fields marked `#[serde(skip)]` on the internal struct are re-exposed only
77/// where RFC 8620 says they should appear in API responses.
78#[derive(Debug, Clone, Serialize)]
79#[serde(rename_all = "camelCase")]
80pub struct PushSubscriptionView {
81    pub id: String,
82    pub device_client_id: String,
83    pub url: String,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub keys: Option<PushKeys>,
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub expires: Option<chrono::DateTime<chrono::Utc>>,
88    pub types: Vec<String>,
89}
90
91impl From<&PushSubscription> for PushSubscriptionView {
92    fn from(s: &PushSubscription) -> Self {
93        Self {
94            id: s.id.clone(),
95            device_client_id: s.device_client_id.clone(),
96            url: s.url.clone(),
97            keys: s.keys.clone(),
98            expires: s.expires,
99            types: s.types.clone(),
100        }
101    }
102}
103
104/// `PushSubscription/set` request.
105#[derive(Debug, Clone, Deserialize)]
106#[serde(rename_all = "camelCase")]
107pub struct PushSubscriptionSetRequest {
108    #[serde(default)]
109    pub create: Option<HashMap<String, PushSubscriptionCreate>>,
110    #[serde(default)]
111    pub update: Option<HashMap<String, PushSubscriptionUpdate>>,
112    #[serde(default)]
113    pub destroy: Option<Vec<String>>,
114}
115
116/// Fields accepted when creating a new push subscription.
117#[derive(Debug, Clone, Deserialize)]
118#[serde(rename_all = "camelCase")]
119pub struct PushSubscriptionCreate {
120    pub device_client_id: String,
121    pub url: String,
122    #[serde(default)]
123    pub keys: Option<PushKeys>,
124    #[serde(default)]
125    pub expires: Option<chrono::DateTime<chrono::Utc>>,
126    #[serde(default)]
127    pub types: Vec<String>,
128}
129
130/// Fields that may be patched on an existing push subscription.
131#[derive(Debug, Clone, Deserialize)]
132#[serde(rename_all = "camelCase")]
133pub struct PushSubscriptionUpdate {
134    /// Supply the server-issued code to transition the subscription to `verified`.
135    #[serde(default)]
136    pub verification_code: Option<String>,
137    /// Replace the monitored type list.
138    #[serde(default)]
139    pub types: Option<Vec<String>>,
140    /// Update the expiry timestamp.
141    #[serde(default)]
142    pub expires: Option<chrono::DateTime<chrono::Utc>>,
143}
144
145/// `PushSubscription/set` response.
146#[derive(Debug, Clone, Serialize, Default)]
147#[serde(rename_all = "camelCase")]
148pub struct PushSubscriptionSetResponse {
149    #[serde(skip_serializing_if = "Option::is_none")]
150    pub created: Option<HashMap<String, PushSubscriptionCreated>>,
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub updated: Option<HashMap<String, Option<serde_json::Value>>>,
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub destroyed: Option<Vec<String>>,
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub not_created: Option<HashMap<String, JmapSetError>>,
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub not_updated: Option<HashMap<String, JmapSetError>>,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub not_destroyed: Option<HashMap<String, JmapSetError>>,
161}
162
163/// Minimal object returned for a newly created subscription.
164///
165/// Note: RFC 8620 §5.1 specifies that the server sends the `verificationCode`
166/// out-of-band to the push endpoint URL; it is included here in the creation
167/// response so that test fixtures can retrieve it without inspecting the
168/// mock HTTP server.  Production clients should obtain it from the push
169/// delivery.
170#[derive(Debug, Clone, Serialize)]
171#[serde(rename_all = "camelCase")]
172pub struct PushSubscriptionCreated {
173    pub id: String,
174    /// The code that must be echoed back via `PushSubscription/set:update` to
175    /// verify the subscription.
176    pub verification_code: String,
177}
178
179// ──────────────────────────────────────────────────────────────────────────────
180// Handlers
181// ──────────────────────────────────────────────────────────────────────────────
182
183/// Handle `PushSubscription/get`.
184pub async fn push_subscription_get(
185    request: PushSubscriptionGetRequest,
186    principal: &Principal,
187) -> anyhow::Result<PushSubscriptionGetResponse> {
188    let state = match push_state() {
189        Some(s) => s,
190        None => {
191            // Push not initialised — return empty list.
192            return Ok(PushSubscriptionGetResponse {
193                list: vec![],
194                not_found: vec![],
195            });
196        }
197    };
198
199    let mut list = Vec::new();
200    let mut not_found = Vec::new();
201
202    match request.ids {
203        None => {
204            // Return all subscriptions owned by this principal.
205            for entry in state.registry.iter() {
206                if entry.value().principal_id == principal.account_id {
207                    list.push(PushSubscriptionView::from(entry.value()));
208                }
209            }
210        }
211        Some(ids) => {
212            for id in ids {
213                match state.registry.get(&id) {
214                    Some(entry) if entry.value().principal_id == principal.account_id => {
215                        list.push(PushSubscriptionView::from(entry.value()));
216                    }
217                    Some(_) => {
218                        // Exists but owned by someone else — treat as not found
219                        // (do not reveal existence of foreign subscriptions).
220                        not_found.push(id);
221                    }
222                    None => {
223                        not_found.push(id);
224                    }
225                }
226            }
227        }
228    }
229
230    Ok(PushSubscriptionGetResponse { list, not_found })
231}
232
233/// Handle `PushSubscription/set`.
234pub async fn push_subscription_set(
235    request: PushSubscriptionSetRequest,
236    principal: &Principal,
237) -> anyhow::Result<PushSubscriptionSetResponse> {
238    let state = match push_state() {
239        Some(s) => s,
240        None => {
241            return Err(anyhow::anyhow!(
242                "Push subsystem not initialised; call init_push_state() at server startup"
243            ));
244        }
245    };
246
247    let mut response = PushSubscriptionSetResponse::default();
248
249    // ── Create ────────────────────────────────────────────────────────────────
250    if let Some(creates) = request.create {
251        let mut created = HashMap::new();
252        let mut not_created = HashMap::new();
253
254        for (client_id, create) in creates {
255            match create_subscription(state, create, principal).await {
256                Ok(result) => {
257                    created.insert(client_id, result);
258                }
259                Err(e) => {
260                    not_created.insert(
261                        client_id,
262                        JmapSetError {
263                            error_type: "serverFail".to_string(),
264                            description: Some(e.to_string()),
265                        },
266                    );
267                }
268            }
269        }
270
271        if !created.is_empty() {
272            response.created = Some(created);
273        }
274        if !not_created.is_empty() {
275            response.not_created = Some(not_created);
276        }
277    }
278
279    // ── Update ────────────────────────────────────────────────────────────────
280    if let Some(updates) = request.update {
281        let mut updated = HashMap::new();
282        let mut not_updated = HashMap::new();
283
284        for (id, patch) in updates {
285            match update_subscription(state, &id, patch, principal) {
286                Ok(()) => {
287                    updated.insert(id, None);
288                }
289                Err(e) => {
290                    not_updated.insert(
291                        id,
292                        JmapSetError {
293                            error_type: "serverFail".to_string(),
294                            description: Some(e.to_string()),
295                        },
296                    );
297                }
298            }
299        }
300
301        if !updated.is_empty() {
302            response.updated = Some(updated);
303        }
304        if !not_updated.is_empty() {
305            response.not_updated = Some(not_updated);
306        }
307    }
308
309    // ── Destroy ───────────────────────────────────────────────────────────────
310    if let Some(destroy_ids) = request.destroy {
311        let mut destroyed = Vec::new();
312        let mut not_destroyed = HashMap::new();
313
314        for id in destroy_ids {
315            match destroy_subscription(state, &id, principal) {
316                Ok(()) => {
317                    destroyed.push(id);
318                }
319                Err(e) => {
320                    not_destroyed.insert(
321                        id,
322                        JmapSetError {
323                            error_type: "serverFail".to_string(),
324                            description: Some(e.to_string()),
325                        },
326                    );
327                }
328            }
329        }
330
331        if !destroyed.is_empty() {
332            response.destroyed = Some(destroyed);
333        }
334        if !not_destroyed.is_empty() {
335            response.not_destroyed = Some(not_destroyed);
336        }
337    }
338
339    Ok(response)
340}
341
342/// Testable variant of [`push_subscription_set`] that accepts an explicit
343/// `PushState` rather than reading from the `OnceLock`.
344///
345/// Use this in integration tests to avoid `OnceLock` contention across
346/// parallel test processes.
347pub async fn push_subscription_set_with_state(
348    request: PushSubscriptionSetRequest,
349    principal: &Principal,
350    state: &Arc<PushState>,
351) -> anyhow::Result<PushSubscriptionSetResponse> {
352    let mut response = PushSubscriptionSetResponse::default();
353
354    // ── Create ────────────────────────────────────────────────────────────────
355    if let Some(creates) = request.create {
356        let mut created = HashMap::new();
357        let mut not_created = HashMap::new();
358
359        for (client_id, create) in creates {
360            match create_subscription(state, create, principal).await {
361                Ok(result) => {
362                    created.insert(client_id, result);
363                }
364                Err(e) => {
365                    not_created.insert(
366                        client_id,
367                        JmapSetError {
368                            error_type: "serverFail".to_string(),
369                            description: Some(e.to_string()),
370                        },
371                    );
372                }
373            }
374        }
375
376        if !created.is_empty() {
377            response.created = Some(created);
378        }
379        if !not_created.is_empty() {
380            response.not_created = Some(not_created);
381        }
382    }
383
384    // ── Update ────────────────────────────────────────────────────────────────
385    if let Some(updates) = request.update {
386        let mut updated = HashMap::new();
387        let mut not_updated = HashMap::new();
388
389        for (id, patch) in updates {
390            match update_subscription(state, &id, patch, principal) {
391                Ok(()) => {
392                    updated.insert(id, None);
393                }
394                Err(e) => {
395                    not_updated.insert(
396                        id,
397                        JmapSetError {
398                            error_type: "serverFail".to_string(),
399                            description: Some(e.to_string()),
400                        },
401                    );
402                }
403            }
404        }
405
406        if !updated.is_empty() {
407            response.updated = Some(updated);
408        }
409        if !not_updated.is_empty() {
410            response.not_updated = Some(not_updated);
411        }
412    }
413
414    // ── Destroy ───────────────────────────────────────────────────────────────
415    if let Some(destroy_ids) = request.destroy {
416        let mut destroyed = Vec::new();
417        let mut not_destroyed = HashMap::new();
418
419        for id in destroy_ids {
420            match destroy_subscription(state, &id, principal) {
421                Ok(()) => {
422                    destroyed.push(id);
423                }
424                Err(e) => {
425                    not_destroyed.insert(
426                        id,
427                        JmapSetError {
428                            error_type: "serverFail".to_string(),
429                            description: Some(e.to_string()),
430                        },
431                    );
432                }
433            }
434        }
435
436        if !destroyed.is_empty() {
437            response.destroyed = Some(destroyed);
438        }
439        if !not_destroyed.is_empty() {
440            response.not_destroyed = Some(not_destroyed);
441        }
442    }
443
444    Ok(response)
445}
446
447// ──────────────────────────────────────────────────────────────────────────────
448// Internal helpers
449// ──────────────────────────────────────────────────────────────────────────────
450
451/// Generate a 32-byte random verification code, base64url-encoded.
452fn generate_verification_code() -> Result<String, anyhow::Error> {
453    let mut buf = [0u8; 32];
454    getrandom::fill(&mut buf)
455        .map_err(|e| anyhow::anyhow!("RNG failure during verification code generation: {e}"))?;
456    Ok(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(buf))
457}
458
459/// Validate that `url` is an acceptable push endpoint URL.
460///
461/// In production (`cfg(not(feature = "test-push-http"))`): only HTTPS is
462/// accepted per RFC 8030 §5.2.
463///
464/// When the `test-push-http` Cargo feature is enabled: plain HTTP is also
465/// accepted so that `wiremock` mock servers (which start on loopback
466/// without TLS) can be used as push endpoints in integration tests.
467fn validate_push_url(url: &str) -> Result<(), anyhow::Error> {
468    if url.starts_with("https://") {
469        return Ok(());
470    }
471    #[cfg(feature = "test-push-http")]
472    if url.starts_with("http://") {
473        return Ok(());
474    }
475    Err(anyhow::anyhow!(
476        "Push subscription URL must use HTTPS, got: {url}"
477    ))
478}
479
480async fn create_subscription(
481    state: &PushState,
482    create: PushSubscriptionCreate,
483    principal: &Principal,
484) -> anyhow::Result<PushSubscriptionCreated> {
485    validate_push_url(&create.url)?;
486
487    let id = uuid::Uuid::new_v4().to_string();
488    let verification_code = generate_verification_code()?;
489
490    let sub = PushSubscription {
491        id: id.clone(),
492        device_client_id: create.device_client_id,
493        url: create.url,
494        keys: create.keys,
495        verification_code: Some(verification_code.clone()),
496        expires: create.expires,
497        types: create.types,
498        verified: false,
499        principal_id: principal.account_id.clone(),
500    };
501
502    // Attempt to send the verification push.  A failure here is returned as a
503    // `serverFail`; the subscription is NOT stored on failure because there is
504    // no way to deliver the verification code.
505    match state.client.send(&sub, b"").await {
506        Ok(()) => {}
507        Err(WebPushError::Gone) => {
508            return Err(anyhow::anyhow!(
509                "Push endpoint returned 410 Gone during verification"
510            ));
511        }
512        Err(e) => {
513            return Err(anyhow::anyhow!("Failed to send verification push: {e}"));
514        }
515    }
516
517    state.registry.insert(id.clone(), sub);
518
519    Ok(PushSubscriptionCreated {
520        id,
521        verification_code,
522    })
523}
524
525fn update_subscription(
526    state: &PushState,
527    id: &str,
528    patch: PushSubscriptionUpdate,
529    principal: &Principal,
530) -> anyhow::Result<()> {
531    let mut entry = state
532        .registry
533        .get_mut(id)
534        .ok_or_else(|| anyhow::anyhow!("Subscription not found: {id}"))?;
535
536    if entry.value().principal_id != principal.account_id {
537        return Err(anyhow::anyhow!(
538            "Subscription {id} not owned by this principal"
539        ));
540    }
541
542    // Verification code check.
543    if let Some(code) = patch.verification_code {
544        if entry.value().verification_code.as_deref() == Some(code.as_str()) {
545            entry.value_mut().verified = true;
546        } else {
547            return Err(anyhow::anyhow!(
548                "Verification code mismatch for subscription {id}"
549            ));
550        }
551    }
552
553    if let Some(types) = patch.types {
554        entry.value_mut().types = types;
555    }
556    if let Some(expires) = patch.expires {
557        entry.value_mut().expires = Some(expires);
558    }
559
560    Ok(())
561}
562
563fn destroy_subscription(state: &PushState, id: &str, principal: &Principal) -> anyhow::Result<()> {
564    // Do ownership check inside a limited scope so the read guard is dropped
565    // before we call `remove()`.  Holding a DashMap read guard while calling
566    // `remove()` on the same shard deadlocks.
567    let owned = {
568        match state.registry.get(id) {
569            None => return Err(anyhow::anyhow!("Subscription not found: {id}")),
570            Some(entry) => entry.value().principal_id == principal.account_id,
571        }
572        // guard drops here
573    };
574
575    if !owned {
576        return Err(anyhow::anyhow!(
577            "Subscription {id} not owned by this principal"
578        ));
579    }
580
581    state.registry.remove(id);
582    Ok(())
583}