1use std::{
2 collections::{HashMap, HashSet},
3 fmt::{self, Debug},
4 sync::Arc,
5};
6
7use anyhow::bail;
8use async_trait::async_trait;
9use authentication::check_signature_middleware;
10use axum::{
11 body::Body,
12 extract::{ws::Message, Path, Request, State, WebSocketUpgrade},
13 http::StatusCode,
14 middleware::{self, Next},
15 response::{IntoResponse, Response},
16 routing::{any, get, post},
17 Extension, Json, Router,
18};
19use futures::{SinkExt, StreamExt};
20use rand::Rng;
21use rusher_core::{ChannelName, ConnectionInfo, CustomEvent, ServerEvent, SocketId};
22use rusher_pubsub::{AnyBroker, Broker, Connection};
23use serde::Deserialize;
24use serde_json::{json, Value as JsonValue};
25use tokio::sync::mpsc;
26
27mod authentication;
28mod websocket;
29
30pub use axum::serve;
31use tower_http::trace::{DefaultOnResponse, TraceLayer};
32use tracing::{debug, error, info_span, Instrument, Level};
33use websocket::ConnectionProtocol;
34
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub struct App {
37 pub id: AppId,
38 secret: AppSecret,
39}
40
/// Newtype around the textual identifier of an application.
///
/// It is hashable and comparable so it can serve as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AppId(String);
43
44impl fmt::Display for AppId {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 write!(f, "{}", self.0)
47 }
48}
49
/// Newtype around the secret registered for an application.
///
/// `Debug` is implemented by hand so the secret value never ends up in logs
/// or panic messages; this also covers the derived `Debug` output of [`App`],
/// which contains an `AppSecret`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AppSecret(String);

impl fmt::Debug for AppSecret {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Deliberately omit the inner value.
        f.write_str("AppSecret(<redacted>)")
    }
}
52
53impl App {
54 pub fn new(id: impl Into<String>, secret: impl Into<String>) -> Self {
55 Self {
56 id: AppId(id.into()),
57 secret: AppSecret(secret.into()),
58 }
59 }
60}
61
62pub trait IntoAppRepository {
63 type AppRepository: AppRepository + 'static;
64 fn into_app_repository(self) -> Self::AppRepository;
65}
66
67impl<I: IntoIterator<Item = (App, AnyBroker)>> IntoAppRepository for I {
68 type AppRepository = HashMap<AppId, (App, AnyBroker)>;
69
70 fn into_app_repository(self) -> Self::AppRepository {
71 self.into_iter()
72 .map(|(app, broker)| (app.id.clone(), (app, broker)))
73 .collect()
74 }
75}
76
77#[async_trait]
78pub trait AppRepository: Send + Sync {
79 async fn secret_for_app(&self, app_id: &AppId) -> Option<AppSecret>;
80 async fn broker_for_app(&self, app_id: &AppId) -> Option<AnyBroker>;
81}
82
83#[async_trait]
84impl AppRepository for HashMap<AppId, (App, AnyBroker)> {
85 async fn secret_for_app(&self, app_id: &AppId) -> Option<AppSecret> {
86 self.get(app_id).map(|(app, _)| app.secret.clone())
87 }
88
89 async fn broker_for_app(&self, app_id: &AppId) -> Option<AnyBroker> {
90 self.get(app_id).map(|(_, broker)| broker).cloned()
91 }
92}
93
94pub fn app(app_repo: impl IntoAppRepository) -> Router {
95 let app_repo = app_repo.into_app_repository();
96 Router::new()
97 .route("/apps/:app/channels", get(list_channels))
98 .route("/apps/:app/events", post(publish))
99 .route_layer(middleware::from_fn(check_signature_middleware))
100 .route("/app/:app", any(handle_ws))
101 .layer(
102 TraceLayer::new_for_http()
103 .on_response(DefaultOnResponse::default().level(Level::INFO))
104 .make_span_with(|request: &Request<_>| {
105 info_span!(
106 "request",
107 uri = ?request.uri(),
108 method = ?request.method(),
109 )
110 }),
111 )
112 .route_layer(middleware::from_fn_with_state(
113 Arc::new(app_repo) as Arc<dyn AppRepository>,
114 broker_middleware,
115 ))
116}
117
118async fn broker_middleware(
119 State(app_repo): State<Arc<dyn AppRepository>>,
120 Path(app): Path<String>,
121 mut request: Request,
122 next: Next,
123) -> Response {
124 let app_id = AppId(app);
125 match (
126 app_repo.secret_for_app(&app_id).await,
127 app_repo.broker_for_app(&app_id).await,
128 ) {
129 (Some(secret), Some(broker)) => {
130 request.extensions_mut().insert(app_id.clone());
131 request.extensions_mut().insert(secret);
132 request.extensions_mut().insert(broker);
133 next.run(request)
134 .instrument(info_span!("app_request", app_id = app_id.0))
135 .await
136 }
137 _ => Response::builder()
138 .status(StatusCode::NOT_FOUND)
139 .body(Body::empty())
140 .unwrap(),
141 }
142}
143
144#[derive(Clone, Debug, Deserialize)]
145pub struct EventPayload {
146 pub name: String,
147 pub data: JsonValue,
148 pub channels: Option<HashSet<ChannelName>>,
149 pub channel: Option<ChannelName>,
150 pub socket_id: Option<SocketId>,
151}
152
153async fn publish(
154 Extension(broker): Extension<AnyBroker>,
155 Json(payload): Json<EventPayload>,
156) -> Result<Json<JsonValue>, StatusCode> {
157 let event = payload.name;
158 let data = payload.data;
159
160 let channels = match (payload.channel, payload.channels) {
161 (Some(channel), Some(mut channels)) => {
162 channels.insert(channel);
163 channels
164 }
165 (Some(channel), None) => HashSet::from_iter([channel]),
166 (None, Some(channels)) => channels,
167 _ => HashSet::new(),
168 };
169
170 for channel in channels {
171 let event = ServerEvent::ChannelEvent(CustomEvent {
172 event: event.clone(),
173 data: data.clone(),
174 channel: channel.clone(),
175 user_id: None,
176 });
177
178 broker
179 .publish(channel.as_ref(), event)
180 .await
181 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
182 }
183
184 Ok(Json(json!({ "ok": true })))
185}
186
187async fn list_channels(Extension(broker): Extension<AnyBroker>) -> impl IntoResponse {
188 let channels = broker
189 .subscriptions()
190 .await
191 .into_iter()
192 .map(|(channel, count)| {
193 (
194 channel,
195 json!({
196 "subscription_count": count,
197 "user_count": count,
198 }),
199 )
200 })
201 .collect::<HashMap<String, JsonValue>>();
202
203 Json(channels)
204}
205
206async fn handle_ws(
207 Extension(broker): Extension<AnyBroker>,
208 Extension(AppId(app_id)): Extension<AppId>,
209 Extension(AppSecret(secret)): Extension<AppSecret>,
210 ws: WebSocketUpgrade,
211) -> impl IntoResponse {
212 match broker.connect().await {
213 Ok(mut connection) => Ok(ws.on_upgrade(move |ws| async move {
214 let socket_id: SocketId = rand::thread_rng().gen();
215 let _span = info_span!("websocket", %app_id, %socket_id);
216
217 let (mut write_ws, mut read_ws) = ws.split();
218 let (tx, mut rx) = mpsc::channel(64);
219
220 let write_messages = async move {
221 while let Some(msg) = rx.recv().await {
222 if let Ok(msg) = serde_json::to_string(&msg) {
223 if let Err(err) = write_ws.send(Message::Text(msg)).await {
224 bail!(err)
225 }
226 }
227 }
228 anyhow::Ok(())
229 }.instrument(info_span!("websocket_connection_write", %app_id, %socket_id));
230
231 let mut proto = ConnectionProtocol {
232 tx: tx.clone(),
233 app_id: app_id.clone(),
234 secret,
235 socket_id: socket_id.clone(),
236 current_user_id: None,
237 };
238
239 let connection_established = ServerEvent::ConnectionEstablished {
240 data: ConnectionInfo {
241 socket_id: socket_id.clone(),
242 activity_timeout: 120,
243 },
244 };
245
246 let read_messages = async move {
247 tx.send(connection_established).await?;
248
249 loop {
250 tokio::select! {
251 Ok(msg) = connection.recv() => {
252 tx.send(msg).await?;
253 },
254
255 Some(Ok(msg)) = read_ws.next() => {
256 match msg {
257 Message::Text(text) => {
258 match serde_json::from_str(&text) {
259 Ok(msg) => {
260 if let Err(error) = proto.handle_message(&mut connection, msg).await {
261 error!(%error);
262 bail!(error)
263 }
264 },
265 Err(error) => {
266 debug!(msg = "could not decode message", %text, %error);
267 continue
268 },
269 };
270 }
271 _ => continue,
272 }
273 }
274
275 else => break,
276 }
277 }
278 anyhow::Ok(())
279 }.instrument(info_span!("websocket_connection_read", %app_id, %socket_id));
280
281 tokio::select! {
282 _ = write_messages => debug!("Writer finished"),
283 _ = read_messages => debug!("Reader finished"),
284 };
285
286 debug!("Client disconnected");
287 })),
288 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
289 }
290}