1use crate::events::sinks::device_tokens::{DeviceTokenStore, Platform};
25use crate::events::sinks::in_app::NotificationStore;
26use crate::events::sinks::preferences::{NotificationPreferencesStore, UserPreferences};
27use axum::extract::{Path, Query, State};
28use axum::http::StatusCode;
29use axum::response::IntoResponse;
30use axum::{
31 Json, Router,
32 routing::{delete, get, post},
33};
34use serde::Deserialize;
35use serde_json::json;
36use std::sync::Arc;
37use uuid::Uuid;
38
/// Shared state handed to every handler registered by `notification_routes`.
///
/// Each store sits behind an `Arc`, so cloning this struct clones the handles
/// rather than the stores themselves.
#[derive(Clone)]
pub struct NotificationState {
    /// Store used by the `/notifications/{user_id}` list, count, read and delete handlers.
    pub notification_store: Arc<NotificationStore>,
    /// Store used by the preferences, mute and unmute handlers.
    pub preferences_store: Arc<NotificationPreferencesStore>,
    /// Store used by the `/device-tokens/...` handlers.
    pub device_token_store: Arc<DeviceTokenStore>,
}
48
49pub fn notification_routes(state: NotificationState) -> Router {
53 Router::new()
54 .route("/notifications/{user_id}", get(list_notifications))
56 .route("/notifications/{user_id}/unread-count", get(unread_count))
57 .route("/notifications/{user_id}/read", post(mark_as_read))
58 .route("/notifications/{user_id}/read-all", post(mark_all_as_read))
59 .route(
60 "/notifications/{user_id}/{notification_id}",
61 delete(delete_notification),
62 )
63 .route(
65 "/notifications/{user_id}/preferences",
66 get(get_preferences).put(update_preferences),
67 )
68 .route("/notifications/{user_id}/mute", post(mute_user))
69 .route("/notifications/{user_id}/unmute", post(unmute_user))
70 .route(
72 "/device-tokens/{user_id}",
73 get(list_device_tokens).post(register_device_token),
74 )
75 .route(
76 "/device-tokens/{user_id}/{token}",
77 delete(unregister_device_token),
78 )
79 .with_state(state)
80}
81
/// Optional query-string parameters accepted by `list_notifications`.
#[derive(Debug, Deserialize)]
pub struct PaginationParams {
    /// Requested page size; `list_notifications` uses 20 when absent and caps it at 100.
    pub limit: Option<usize>,
    /// Offset forwarded to the store's `list_by_user`; `list_notifications` uses 0 when absent.
    pub offset: Option<usize>,
}
92
93async fn list_notifications(
97 State(state): State<NotificationState>,
98 Path(user_id): Path<String>,
99 Query(params): Query<PaginationParams>,
100) -> impl IntoResponse {
101 let limit = params.limit.unwrap_or(20).min(100);
102 let offset = params.offset.unwrap_or(0);
103
104 let notifications = state
105 .notification_store
106 .list_by_user(&user_id, limit, offset)
107 .await;
108
109 let total = state.notification_store.total_count(&user_id).await;
110 let unread = state.notification_store.unread_count(&user_id).await;
111
112 Json(json!({
113 "notifications": notifications,
114 "total": total,
115 "unread": unread,
116 "limit": limit,
117 "offset": offset,
118 }))
119}
120
121async fn unread_count(
123 State(state): State<NotificationState>,
124 Path(user_id): Path<String>,
125) -> impl IntoResponse {
126 let count = state.notification_store.unread_count(&user_id).await;
127 Json(json!({ "unread_count": count }))
128}
129
/// JSON body accepted by `mark_as_read` (`POST /notifications/{user_id}/read`).
#[derive(Debug, Deserialize)]
pub struct MarkAsReadRequest {
    /// IDs of the notifications to mark as read.
    pub ids: Vec<Uuid>,
}
136
137async fn mark_as_read(
139 State(state): State<NotificationState>,
140 Path(user_id): Path<String>,
141 Json(body): Json<MarkAsReadRequest>,
142) -> impl IntoResponse {
143 let marked = state
144 .notification_store
145 .mark_as_read(&body.ids, Some(&user_id))
146 .await;
147
148 Json(json!({ "marked": marked }))
149}
150
151async fn mark_all_as_read(
153 State(state): State<NotificationState>,
154 Path(user_id): Path<String>,
155) -> impl IntoResponse {
156 let marked = state.notification_store.mark_all_as_read(&user_id).await;
157
158 Json(json!({ "marked": marked }))
159}
160
161async fn delete_notification(
163 State(state): State<NotificationState>,
164 Path((_user_id, notification_id)): Path<(String, Uuid)>,
165) -> impl IntoResponse {
166 let deleted = state.notification_store.delete(¬ification_id).await;
167
168 if deleted {
169 (StatusCode::OK, Json(json!({ "deleted": true })))
170 } else {
171 (
172 StatusCode::NOT_FOUND,
173 Json(json!({ "error": "notification not found" })),
174 )
175 }
176}
177
178async fn get_preferences(
182 State(state): State<NotificationState>,
183 Path(user_id): Path<String>,
184) -> impl IntoResponse {
185 let prefs = state.preferences_store.get(&user_id).await;
186 Json(json!({ "preferences": prefs }))
187}
188
189async fn update_preferences(
191 State(state): State<NotificationState>,
192 Path(user_id): Path<String>,
193 Json(prefs): Json<UserPreferences>,
194) -> impl IntoResponse {
195 state
196 .preferences_store
197 .update(&user_id, prefs.clone())
198 .await;
199 Json(json!({ "preferences": prefs }))
200}
201
202async fn mute_user(
204 State(state): State<NotificationState>,
205 Path(user_id): Path<String>,
206) -> impl IntoResponse {
207 state.preferences_store.mute(&user_id).await;
208 Json(json!({ "muted": true }))
209}
210
211async fn unmute_user(
213 State(state): State<NotificationState>,
214 Path(user_id): Path<String>,
215) -> impl IntoResponse {
216 state.preferences_store.unmute(&user_id).await;
217 Json(json!({ "muted": false }))
218}
219
220async fn list_device_tokens(
224 State(state): State<NotificationState>,
225 Path(user_id): Path<String>,
226) -> impl IntoResponse {
227 let tokens = state.device_token_store.get_tokens(&user_id).await;
228 Json(json!({ "tokens": tokens }))
229}
230
/// JSON body accepted by `register_device_token` (`POST /device-tokens/{user_id}`).
#[derive(Debug, Deserialize)]
pub struct RegisterTokenRequest {
    /// Device token to register for the user.
    pub token: String,
    /// Platform the token belongs to.
    pub platform: Platform,
}
239
240async fn register_device_token(
242 State(state): State<NotificationState>,
243 Path(user_id): Path<String>,
244 Json(body): Json<RegisterTokenRequest>,
245) -> impl IntoResponse {
246 state
247 .device_token_store
248 .register(&user_id, body.token, body.platform)
249 .await;
250
251 (StatusCode::CREATED, Json(json!({ "registered": true })))
252}
253
254async fn unregister_device_token(
256 State(state): State<NotificationState>,
257 Path((user_id, token)): Path<(String, String)>,
258) -> impl IntoResponse {
259 let removed = state.device_token_store.unregister(&user_id, &token).await;
260
261 if removed {
262 (StatusCode::OK, Json(json!({ "unregistered": true })))
263 } else {
264 (
265 StatusCode::NOT_FOUND,
266 Json(json!({ "error": "token not found" })),
267 )
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use serde_json::Value;
    use tower::ServiceExt;

    /// Upper bound on the number of response-body bytes `json_body` will buffer.
    const MAX_BODY_BYTES: usize = 1024 * 64;

    /// Builds a state whose three stores are freshly constructed.
    fn test_state() -> NotificationState {
        let notification_store = Arc::new(NotificationStore::new());
        let preferences_store = Arc::new(NotificationPreferencesStore::new());
        let device_token_store = Arc::new(DeviceTokenStore::new());

        NotificationState {
            notification_store,
            preferences_store,
            device_token_store,
        }
    }

    /// Router backed by a fresh, otherwise unreachable state.
    fn test_router() -> Router {
        notification_routes(test_state())
    }

    /// Unread notification addressed to `user-A` with the given id and title.
    fn unread_notification(
        id: Uuid,
        title: &str,
    ) -> crate::events::sinks::in_app::StoredNotification {
        crate::events::sinks::in_app::StoredNotification {
            id,
            recipient_id: "user-A".to_string(),
            notification_type: "test".to_string(),
            title: title.to_string(),
            body: String::new(),
            data: serde_json::Value::Null,
            read: false,
            created_at: chrono::Utc::now(),
        }
    }

    /// Request with the given method and URI and an empty body.
    fn bare_request(method: &str, uri: &str) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Request carrying `payload` as an `application/json` body.
    fn json_request(method: &str, uri: &str, payload: &Value) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_string(payload).unwrap()))
            .unwrap()
    }

    /// Drives one request through a clone of `router` and returns the response.
    async fn send(router: &Router, request: Request<Body>) -> axum::response::Response {
        router.clone().oneshot(request).await.unwrap()
    }

    /// Reads the whole response body and parses it as JSON.
    async fn json_body(response: axum::response::Response) -> Value {
        let bytes = axum::body::to_bytes(response.into_body(), MAX_BODY_BYTES)
            .await
            .expect("body should read");
        serde_json::from_slice(&bytes).expect("body should be valid JSON")
    }

    #[tokio::test]
    async fn test_list_notifications_empty() {
        let router = test_router();

        let response = send(&router, bare_request("GET", "/notifications/user-A")).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["total"], 0);
        assert_eq!(json["unread"], 0);
        assert!(json["notifications"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_list_and_unread_count() {
        let state = test_state();
        let router = notification_routes(state.clone());

        for i in 0..3 {
            let title = format!("Notif {i}");
            state
                .notification_store
                .insert(unread_notification(Uuid::new_v4(), &title))
                .await;
        }

        let response = send(&router, bare_request("GET", "/notifications/user-A")).await;
        let json = json_body(response).await;
        assert_eq!(json["total"], 3);
        assert_eq!(json["unread"], 3);

        let response = send(
            &router,
            bare_request("GET", "/notifications/user-A/unread-count"),
        )
        .await;
        let json = json_body(response).await;
        assert_eq!(json["unread_count"], 3);
    }

    #[tokio::test]
    async fn test_mark_as_read() {
        let state = test_state();
        let router = notification_routes(state.clone());

        let id = Uuid::new_v4();
        state
            .notification_store
            .insert(unread_notification(id, "Test"))
            .await;

        let payload = json!({ "ids": [id] });
        let response = send(
            &router,
            json_request("POST", "/notifications/user-A/read", &payload),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["marked"], 1);
        assert_eq!(state.notification_store.unread_count("user-A").await, 0);
    }

    #[tokio::test]
    async fn test_mark_all_as_read() {
        let state = test_state();
        let router = notification_routes(state.clone());

        for _ in 0..3 {
            state
                .notification_store
                .insert(unread_notification(Uuid::new_v4(), "Test"))
                .await;
        }

        let response = send(
            &router,
            bare_request("POST", "/notifications/user-A/read-all"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["marked"], 3);
    }

    #[tokio::test]
    async fn test_delete_notification() {
        let state = test_state();
        let router = notification_routes(state.clone());

        let id = Uuid::new_v4();
        state
            .notification_store
            .insert(unread_notification(id, "To delete"))
            .await;

        let uri = format!("/notifications/user-A/{id}");

        // First delete removes the notification.
        let response = send(&router, bare_request("DELETE", &uri)).await;
        assert_eq!(response.status(), StatusCode::OK);

        // Repeating the same delete finds nothing.
        let response = send(&router, bare_request("DELETE", &uri)).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_get_preferences_default() {
        let router = test_router();

        let response = send(
            &router,
            bare_request("GET", "/notifications/user-A/preferences"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        let preferences = &json["preferences"];
        assert_eq!(preferences["muted"], false);
        assert!(preferences["disabled_types"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_preferences() {
        let state = test_state();
        let router = notification_routes(state.clone());

        let payload = json!({
            "disabled_types": ["new_like"],
            "muted": false
        });
        let response = send(
            &router,
            json_request("PUT", "/notifications/user-A/preferences", &payload),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);

        let prefs = &state.preferences_store;
        assert!(!prefs.is_enabled("user-A", "new_like").await);
        assert!(prefs.is_enabled("user-A", "new_follower").await);
    }

    #[tokio::test]
    async fn test_mute_unmute() {
        let state = test_state();
        let router = notification_routes(state.clone());
        let prefs = &state.preferences_store;

        let response = send(&router, bare_request("POST", "/notifications/user-A/mute")).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(!prefs.is_enabled("user-A", "anything").await);

        let response = send(
            &router,
            bare_request("POST", "/notifications/user-A/unmute"),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(prefs.is_enabled("user-A", "anything").await);
    }

    #[tokio::test]
    async fn test_register_and_list_device_tokens() {
        let state = test_state();
        let router = notification_routes(state.clone());

        let payload = json!({
            "token": "ExponentPushToken[xxx]",
            "platform": "ios"
        });
        let response = send(
            &router,
            json_request("POST", "/device-tokens/user-A", &payload),
        )
        .await;
        assert_eq!(response.status(), StatusCode::CREATED);

        let response = send(&router, bare_request("GET", "/device-tokens/user-A")).await;
        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        let tokens = json["tokens"].as_array().unwrap();
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0]["token"], "ExponentPushToken[xxx]");
        assert_eq!(tokens[0]["platform"], "ios");
    }

    #[tokio::test]
    async fn test_unregister_device_token() {
        let state = test_state();
        let router = notification_routes(state.clone());

        state
            .device_token_store
            .register("user-A", "token-1".to_string(), Platform::Ios)
            .await;

        let uri = "/device-tokens/user-A/token-1";

        // First delete removes the token.
        let response = send(&router, bare_request("DELETE", uri)).await;
        assert_eq!(response.status(), StatusCode::OK);

        // Repeating the same delete finds nothing.
        let response = send(&router, bare_request("DELETE", uri)).await;
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }
}