1use crate::types::{JmapSetError, Principal, PushKeys, PushSubscription};
15use crate::web_push::{WebPushClient, WebPushError};
16use base64::Engine as _;
17use dashmap::DashMap;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::{Arc, OnceLock};
21
/// Shared state backing the JMAP push-subscription handlers.
pub struct PushState {
    /// Every registered subscription, keyed by its server-assigned subscription id.
    /// Entries are inserted by `create_subscription` and removed by `destroy_subscription`.
    pub registry: Arc<DashMap<String, PushSubscription>>,
    /// Client used to deliver push messages to subscription endpoints
    /// (for example the verification push sent when a subscription is created).
    pub client: Arc<WebPushClient>,
}
33
34static PUSH_STATE: OnceLock<Arc<PushState>> = OnceLock::new();
35
36pub fn init_push_state(state: Arc<PushState>) {
42 let _ = PUSH_STATE.set(state);
43}
44
45pub fn push_state() -> Option<&'static Arc<PushState>> {
47 PUSH_STATE.get()
48}
49
/// Shared handle to a subscription registry; the same type as [`PushState::registry`].
pub type PushRegistry = Arc<DashMap<String, PushSubscription>>;
52
/// Arguments of a JMAP `PushSubscription/get` call.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionGetRequest {
    /// Subscription ids to fetch. `None` (property absent or `null`) means
    /// "every subscription owned by the calling principal".
    #[serde(default)]
    pub ids: Option<Vec<String>>,
}
65
/// Result of a JMAP `PushSubscription/get` call.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionGetResponse {
    /// Subscriptions that exist and are owned by the calling principal.
    pub list: Vec<PushSubscriptionView>,
    /// Requested ids that do not exist or belong to another principal
    /// (serialized as `notFound`).
    pub not_found: Vec<String>,
}
73
/// Client-visible projection of a stored [`PushSubscription`].
///
/// Built through `From<&PushSubscription>`. The verification code, the verified flag
/// and the owning principal id are not part of this type, so they are never
/// serialized to clients.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionView {
    pub id: String,
    pub device_client_id: String,
    pub url: String,
    /// Omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keys: Option<PushKeys>,
    /// Omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires: Option<chrono::DateTime<chrono::Utc>>,
    pub types: Vec<String>,
}
90
91impl From<&PushSubscription> for PushSubscriptionView {
92 fn from(s: &PushSubscription) -> Self {
93 Self {
94 id: s.id.clone(),
95 device_client_id: s.device_client_id.clone(),
96 url: s.url.clone(),
97 keys: s.keys.clone(),
98 expires: s.expires,
99 types: s.types.clone(),
100 }
101 }
102}
103
/// Arguments of a JMAP `PushSubscription/set` call. Every section is optional.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionSetRequest {
    /// Subscriptions to create, keyed by the client-chosen creation id that is echoed
    /// back in `created` / `notCreated`.
    #[serde(default)]
    pub create: Option<HashMap<String, PushSubscriptionCreate>>,
    /// Patches to apply, keyed by existing subscription id.
    #[serde(default)]
    pub update: Option<HashMap<String, PushSubscriptionUpdate>>,
    /// Ids of subscriptions to remove.
    #[serde(default)]
    pub destroy: Option<Vec<String>>,
}
115
/// Properties a client supplies when creating a push subscription.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionCreate {
    pub device_client_id: String,
    /// Push endpoint. It is checked by `validate_push_url` before any push is sent.
    pub url: String,
    /// Optional `keys` object, stored unchanged on the subscription
    /// (presumably the endpoint's payload-encryption keys; see `PushKeys`).
    #[serde(default)]
    pub keys: Option<PushKeys>,
    /// Optional expiry time, stored as given.
    #[serde(default)]
    pub expires: Option<chrono::DateTime<chrono::Utc>>,
    /// Type names the client is interested in; an empty list when omitted.
    /// NOTE(review): how the delivery path interprets an empty list (no types vs. all
    /// types) is not visible here. Confirm.
    #[serde(default)]
    pub types: Vec<String>,
}
129
/// Patch applied by the `update` section of `PushSubscription/set`.
/// Properties that are absent leave the stored value unchanged.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionUpdate {
    /// Code delivered by the verification push. If it equals the stored code, the
    /// subscription is marked verified. Otherwise the whole patch is rejected.
    #[serde(default)]
    pub verification_code: Option<String>,
    /// Replacement list of type names.
    #[serde(default)]
    pub types: Option<Vec<String>>,
    /// New expiry time. An explicit `null` is indistinguishable from "absent" here, so
    /// this field cannot clear an existing expiry.
    #[serde(default)]
    pub expires: Option<chrono::DateTime<chrono::Utc>>,
}
144
/// Result of a JMAP `PushSubscription/set` call.
///
/// Every field is omitted from the JSON output when `None`. The handlers in this module
/// only fill in a field when it has at least one entry.
#[derive(Debug, Clone, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionSetResponse {
    /// Successfully created subscriptions, keyed by the client's creation id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created: Option<HashMap<String, PushSubscriptionCreated>>,
    /// Successfully updated subscription ids. Values are always `None` (JSON `null`),
    /// meaning no server-modified properties are reported back.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated: Option<HashMap<String, Option<serde_json::Value>>>,
    /// Ids of subscriptions that were removed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub destroyed: Option<Vec<String>>,
    /// Failed creates, keyed by the client's creation id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub not_created: Option<HashMap<String, JmapSetError>>,
    /// Failed updates, keyed by subscription id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub not_updated: Option<HashMap<String, JmapSetError>>,
    /// Failed destroys, keyed by subscription id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub not_destroyed: Option<HashMap<String, JmapSetError>>,
}
162
/// One entry of `created` in a `PushSubscription/set` response.
///
/// NOTE(review): RFC 8620 §7.2.2 intends the verification code to reach the client only
/// through the push endpoint, so that completing verification proves control of that
/// endpoint. Returning the code here lets a client verify a URL it does not control.
/// Confirm this is intentional (e.g. test-only) before relying on it.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PushSubscriptionCreated {
    /// Server-assigned subscription id (a UUIDv4 string).
    pub id: String,
    /// Verification code generated for the new subscription.
    pub verification_code: String,
}
178
179pub async fn push_subscription_get(
185 request: PushSubscriptionGetRequest,
186 principal: &Principal,
187) -> anyhow::Result<PushSubscriptionGetResponse> {
188 let state = match push_state() {
189 Some(s) => s,
190 None => {
191 return Ok(PushSubscriptionGetResponse {
193 list: vec![],
194 not_found: vec![],
195 });
196 }
197 };
198
199 let mut list = Vec::new();
200 let mut not_found = Vec::new();
201
202 match request.ids {
203 None => {
204 for entry in state.registry.iter() {
206 if entry.value().principal_id == principal.account_id {
207 list.push(PushSubscriptionView::from(entry.value()));
208 }
209 }
210 }
211 Some(ids) => {
212 for id in ids {
213 match state.registry.get(&id) {
214 Some(entry) if entry.value().principal_id == principal.account_id => {
215 list.push(PushSubscriptionView::from(entry.value()));
216 }
217 Some(_) => {
218 not_found.push(id);
221 }
222 None => {
223 not_found.push(id);
224 }
225 }
226 }
227 }
228 }
229
230 Ok(PushSubscriptionGetResponse { list, not_found })
231}
232
233pub async fn push_subscription_set(
235 request: PushSubscriptionSetRequest,
236 principal: &Principal,
237) -> anyhow::Result<PushSubscriptionSetResponse> {
238 let state = match push_state() {
239 Some(s) => s,
240 None => {
241 return Err(anyhow::anyhow!(
242 "Push subsystem not initialised; call init_push_state() at server startup"
243 ));
244 }
245 };
246
247 let mut response = PushSubscriptionSetResponse::default();
248
249 if let Some(creates) = request.create {
251 let mut created = HashMap::new();
252 let mut not_created = HashMap::new();
253
254 for (client_id, create) in creates {
255 match create_subscription(state, create, principal).await {
256 Ok(result) => {
257 created.insert(client_id, result);
258 }
259 Err(e) => {
260 not_created.insert(
261 client_id,
262 JmapSetError {
263 error_type: "serverFail".to_string(),
264 description: Some(e.to_string()),
265 },
266 );
267 }
268 }
269 }
270
271 if !created.is_empty() {
272 response.created = Some(created);
273 }
274 if !not_created.is_empty() {
275 response.not_created = Some(not_created);
276 }
277 }
278
279 if let Some(updates) = request.update {
281 let mut updated = HashMap::new();
282 let mut not_updated = HashMap::new();
283
284 for (id, patch) in updates {
285 match update_subscription(state, &id, patch, principal) {
286 Ok(()) => {
287 updated.insert(id, None);
288 }
289 Err(e) => {
290 not_updated.insert(
291 id,
292 JmapSetError {
293 error_type: "serverFail".to_string(),
294 description: Some(e.to_string()),
295 },
296 );
297 }
298 }
299 }
300
301 if !updated.is_empty() {
302 response.updated = Some(updated);
303 }
304 if !not_updated.is_empty() {
305 response.not_updated = Some(not_updated);
306 }
307 }
308
309 if let Some(destroy_ids) = request.destroy {
311 let mut destroyed = Vec::new();
312 let mut not_destroyed = HashMap::new();
313
314 for id in destroy_ids {
315 match destroy_subscription(state, &id, principal) {
316 Ok(()) => {
317 destroyed.push(id);
318 }
319 Err(e) => {
320 not_destroyed.insert(
321 id,
322 JmapSetError {
323 error_type: "serverFail".to_string(),
324 description: Some(e.to_string()),
325 },
326 );
327 }
328 }
329 }
330
331 if !destroyed.is_empty() {
332 response.destroyed = Some(destroyed);
333 }
334 if !not_destroyed.is_empty() {
335 response.not_destroyed = Some(not_destroyed);
336 }
337 }
338
339 Ok(response)
340}
341
342pub async fn push_subscription_set_with_state(
348 request: PushSubscriptionSetRequest,
349 principal: &Principal,
350 state: &Arc<PushState>,
351) -> anyhow::Result<PushSubscriptionSetResponse> {
352 let mut response = PushSubscriptionSetResponse::default();
353
354 if let Some(creates) = request.create {
356 let mut created = HashMap::new();
357 let mut not_created = HashMap::new();
358
359 for (client_id, create) in creates {
360 match create_subscription(state, create, principal).await {
361 Ok(result) => {
362 created.insert(client_id, result);
363 }
364 Err(e) => {
365 not_created.insert(
366 client_id,
367 JmapSetError {
368 error_type: "serverFail".to_string(),
369 description: Some(e.to_string()),
370 },
371 );
372 }
373 }
374 }
375
376 if !created.is_empty() {
377 response.created = Some(created);
378 }
379 if !not_created.is_empty() {
380 response.not_created = Some(not_created);
381 }
382 }
383
384 if let Some(updates) = request.update {
386 let mut updated = HashMap::new();
387 let mut not_updated = HashMap::new();
388
389 for (id, patch) in updates {
390 match update_subscription(state, &id, patch, principal) {
391 Ok(()) => {
392 updated.insert(id, None);
393 }
394 Err(e) => {
395 not_updated.insert(
396 id,
397 JmapSetError {
398 error_type: "serverFail".to_string(),
399 description: Some(e.to_string()),
400 },
401 );
402 }
403 }
404 }
405
406 if !updated.is_empty() {
407 response.updated = Some(updated);
408 }
409 if !not_updated.is_empty() {
410 response.not_updated = Some(not_updated);
411 }
412 }
413
414 if let Some(destroy_ids) = request.destroy {
416 let mut destroyed = Vec::new();
417 let mut not_destroyed = HashMap::new();
418
419 for id in destroy_ids {
420 match destroy_subscription(state, &id, principal) {
421 Ok(()) => {
422 destroyed.push(id);
423 }
424 Err(e) => {
425 not_destroyed.insert(
426 id,
427 JmapSetError {
428 error_type: "serverFail".to_string(),
429 description: Some(e.to_string()),
430 },
431 );
432 }
433 }
434 }
435
436 if !destroyed.is_empty() {
437 response.destroyed = Some(destroyed);
438 }
439 if !not_destroyed.is_empty() {
440 response.not_destroyed = Some(not_destroyed);
441 }
442 }
443
444 Ok(response)
445}
446
447fn generate_verification_code() -> Result<String, anyhow::Error> {
453 let mut buf = [0u8; 32];
454 getrandom::fill(&mut buf)
455 .map_err(|e| anyhow::anyhow!("RNG failure during verification code generation: {e}"))?;
456 Ok(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(buf))
457}
458
459fn validate_push_url(url: &str) -> Result<(), anyhow::Error> {
468 if url.starts_with("https://") {
469 return Ok(());
470 }
471 #[cfg(feature = "test-push-http")]
472 if url.starts_with("http://") {
473 return Ok(());
474 }
475 Err(anyhow::anyhow!(
476 "Push subscription URL must use HTTPS, got: {url}"
477 ))
478}
479
480async fn create_subscription(
481 state: &PushState,
482 create: PushSubscriptionCreate,
483 principal: &Principal,
484) -> anyhow::Result<PushSubscriptionCreated> {
485 validate_push_url(&create.url)?;
486
487 let id = uuid::Uuid::new_v4().to_string();
488 let verification_code = generate_verification_code()?;
489
490 let sub = PushSubscription {
491 id: id.clone(),
492 device_client_id: create.device_client_id,
493 url: create.url,
494 keys: create.keys,
495 verification_code: Some(verification_code.clone()),
496 expires: create.expires,
497 types: create.types,
498 verified: false,
499 principal_id: principal.account_id.clone(),
500 };
501
502 match state.client.send(&sub, b"").await {
506 Ok(()) => {}
507 Err(WebPushError::Gone) => {
508 return Err(anyhow::anyhow!(
509 "Push endpoint returned 410 Gone during verification"
510 ));
511 }
512 Err(e) => {
513 return Err(anyhow::anyhow!("Failed to send verification push: {e}"));
514 }
515 }
516
517 state.registry.insert(id.clone(), sub);
518
519 Ok(PushSubscriptionCreated {
520 id,
521 verification_code,
522 })
523}
524
525fn update_subscription(
526 state: &PushState,
527 id: &str,
528 patch: PushSubscriptionUpdate,
529 principal: &Principal,
530) -> anyhow::Result<()> {
531 let mut entry = state
532 .registry
533 .get_mut(id)
534 .ok_or_else(|| anyhow::anyhow!("Subscription not found: {id}"))?;
535
536 if entry.value().principal_id != principal.account_id {
537 return Err(anyhow::anyhow!(
538 "Subscription {id} not owned by this principal"
539 ));
540 }
541
542 if let Some(code) = patch.verification_code {
544 if entry.value().verification_code.as_deref() == Some(code.as_str()) {
545 entry.value_mut().verified = true;
546 } else {
547 return Err(anyhow::anyhow!(
548 "Verification code mismatch for subscription {id}"
549 ));
550 }
551 }
552
553 if let Some(types) = patch.types {
554 entry.value_mut().types = types;
555 }
556 if let Some(expires) = patch.expires {
557 entry.value_mut().expires = Some(expires);
558 }
559
560 Ok(())
561}
562
563fn destroy_subscription(state: &PushState, id: &str, principal: &Principal) -> anyhow::Result<()> {
564 let owned = {
568 match state.registry.get(id) {
569 None => return Err(anyhow::anyhow!("Subscription not found: {id}")),
570 Some(entry) => entry.value().principal_id == principal.account_id,
571 }
572 };
574
575 if !owned {
576 return Err(anyhow::anyhow!(
577 "Subscription {id} not owned by this principal"
578 ));
579 }
580
581 state.registry.remove(id);
582 Ok(())
583}