1use crate::{
2 config::{self, TlsConfig},
3 handlers::{account, api, home, websocket::WebSocketAccount},
4 Backend, Result, ServerConfig, SslConfig, StorageConfig,
5};
6use axum::{
7 extract::Extension,
8 http::{
9 header::{AUTHORIZATION, CONTENT_TYPE},
10 HeaderValue, Method,
11 },
12 middleware,
13 response::{IntoResponse, Json},
14 routing::{get, post, put},
15 Router,
16};
17use axum_server::{tls_rustls::RustlsConfig, Handle};
18use colored::Colorize;
19use futures::StreamExt;
20use sos_core::{AccountId, UtcDateTime};
21use std::{
22 collections::{HashMap, HashSet},
23 net::SocketAddr,
24 path::PathBuf,
25 sync::Arc,
26};
27use tokio::sync::{Mutex, RwLock, RwLockReadGuard};
28use tower_http::{
29 cors::CorsLayer,
30 trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
31};
32use tracing::Level;
33
34#[cfg(feature = "acme")]
35use tokio_rustls_acme::{caches::DirCache, AcmeConfig};
36
37#[cfg(feature = "listen")]
38use super::handlers::websocket::upgrade;
39
40use sos_core::ExternalFile;
41
42#[cfg(feature = "pairing")]
43use super::handlers::relay::{upgrade as relay_upgrade, RelayState};
44
45pub struct State {
47 pub config: ServerConfig,
49 pub(crate) sockets: HashMap<AccountId, WebSocketAccount>,
51}
52
53impl State {
54 pub fn new(config: ServerConfig) -> Self {
56 Self {
57 config,
58 sockets: Default::default(),
59 }
60 }
61}
62
/// [`State`] shared with handlers behind an async (tokio) read-write lock.
pub type ServerState = Arc<RwLock<State>>;

/// [`Backend`] shared with handlers behind an async (tokio) read-write lock.
pub type ServerBackend = Arc<RwLock<Backend>>;

/// Set of external files.
///
/// NOTE(review): presumably the files with a file operation currently in
/// flight (see the `file_operation_lock` middleware) — confirm.
pub type TransferOperations = HashSet<ExternalFile>;

/// Shared, lockable [`TransferOperations`] set exposed to handlers.
pub type ServerTransfer = Arc<RwLock<TransferOperations>>;
74
75pub struct Server {}
77
78impl Server {
79 pub async fn new() -> Result<Self> {
81 Ok(Self {})
82 }
83
84 pub async fn start(
86 &self,
87 state: ServerState,
88 backend: ServerBackend,
89 handle: Handle,
90 ) -> Result<()> {
91 let reader = state.read().await;
92 let origins = Server::read_origins(&reader)?;
93 let ssl = reader.config.net.ssl.clone();
94 let addr = reader.config.bind_address().clone();
95 drop(reader);
96
97 match ssl {
98 Some(SslConfig::Tls(tls)) => {
99 self.run_tls(addr, state, backend, handle, origins, tls)
100 .await
101 }
102 #[cfg(feature = "acme")]
103 Some(SslConfig::Acme(acme)) => {
104 self.run_acme(addr, state, backend, handle, origins, acme)
105 .await
106 }
107 None => self.run(addr, state, backend, handle, origins).await,
108 }
109 }
110
111 async fn run_tls(
113 &self,
114 addr: SocketAddr,
115 state: ServerState,
116 backend: ServerBackend,
117 handle: Handle,
118 origins: Vec<HeaderValue>,
119 tls: TlsConfig,
120 ) -> Result<()> {
121 let storage = {
122 let state = state.read().await;
123 let backend = backend.read().await;
124 (
125 state.config.storage.clone(),
126 backend.paths().documents_dir().to_owned(),
127 )
128 };
129
130 let tls = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
131 let app = Server::router(Arc::clone(&state), backend, origins)?;
132
133 self.startup_message(state, &addr, true, storage).await;
134
135 axum_server::bind_rustls(addr, tls)
136 .handle(handle)
137 .serve(app.into_make_service())
138 .await?;
139 Ok(())
140 }
141
142 #[cfg(feature = "acme")]
144 async fn run_acme(
145 &self,
146 addr: SocketAddr,
147 state: ServerState,
148 backend: ServerBackend,
149 handle: Handle,
150 origins: Vec<HeaderValue>,
151 acme: config::AcmeConfig,
152 ) -> Result<()> {
153 let storage = {
154 let state = state.read().await;
155 let backend = backend.read().await;
156 (
157 state.config.storage.clone(),
158 backend.paths().documents_dir().to_owned(),
159 )
160 };
161
162 let mut acme_state = AcmeConfig::new(acme.domains)
163 .contact(acme.email.iter().map(|e| format!("mailto:{}", e)))
164 .cache_option(Some(DirCache::new(acme.cache)))
165 .directory_lets_encrypt(acme.production)
166 .state();
167
168 let app = Server::router(Arc::clone(&state), backend, origins)?;
169
170 self.startup_message(state, &addr, true, storage).await;
171
172 let rustls_config = rustls::ServerConfig::builder()
173 .with_no_client_auth()
174 .with_cert_resolver(acme_state.resolver());
175 let acceptor = acme_state.axum_acceptor(Arc::new(rustls_config));
176
177 tokio::spawn(async move {
178 loop {
179 match acme_state.next().await.unwrap() {
180 Ok(res) => tracing::info!(result = ?res, "acme"),
181 Err(err) => tracing::error!(error = ?err, "acme"),
182 }
183 }
184 });
185
186 axum_server::bind(addr)
187 .acceptor(acceptor)
188 .handle(handle)
189 .serve(app.into_make_service())
190 .await?;
191
192 Ok(())
193 }
194
195 async fn run(
197 &self,
198 addr: SocketAddr,
199 state: ServerState,
200 backend: ServerBackend,
201 handle: Handle,
202 origins: Vec<HeaderValue>,
203 ) -> Result<()> {
204 let storage = {
205 let state = state.read().await;
206 let backend = backend.read().await;
207 (
208 state.config.storage.clone(),
209 backend.paths().documents_dir().to_owned(),
210 )
211 };
212
213 let app = Server::router(Arc::clone(&state), backend, origins)?;
214 self.startup_message(state, &addr, false, storage).await;
215
216 axum_server::bind(addr)
217 .handle(handle)
218 .serve(app.into_make_service())
219 .await?;
220 Ok(())
221 }
222
223 async fn startup_message(
224 &self,
225 state: ServerState,
226 addr: &SocketAddr,
227 tls: bool,
228 storage: (StorageConfig, PathBuf),
229 ) {
230 let now = UtcDateTime::now().to_rfc3339().unwrap();
231
232 let mut columns = vec![
233 ("Started", now),
234 ("Listen", addr.to_string()),
235 ("TLS enabled", tls.to_string()),
236 ("Directory", storage.1.display().to_string()),
237 ];
238
239 if let Some(db_file) = &storage.0.database_uri {
240 columns.push(("Database", db_file.as_uri_string()));
241 }
242
243 let max_length = columns.iter().map(|s| s.0.len()).max().unwrap();
244 let col_size = max_length + 4;
245 for (key, value) in columns {
246 let padding = col_size - key.len();
247 println!("{}{}{}", key, " ".repeat(padding), value.yellow());
248 }
249
250 {
251 let reader = state.read().await;
252 if let Some(access) = &reader.config.access {
253 if let Some(allow) = &access.allow {
254 for address in allow {
255 println!(
256 "Allow {}",
257 address.to_string().green()
258 );
259 }
260 }
261 if let Some(deny) = &access.deny {
262 for address in deny {
263 println!(
264 "Deny {}",
265 address.to_string().red()
266 );
267 }
268 }
269 }
270 }
271 }
272
273 fn read_origins(
274 reader: &RwLockReadGuard<'_, State>,
275 ) -> Result<Vec<HeaderValue>> {
276 let mut origins = Vec::new();
277 let cors = reader.config.net.cors.as_ref();
278 if let Some(cors) = cors {
279 for url in cors.origins.iter() {
280 origins.push(HeaderValue::from_str(
281 url.as_str().trim_end_matches('/'),
282 )?);
283 }
284 }
285 Ok(origins)
286 }
287
288 fn router(
289 state: ServerState,
290 backend: ServerBackend,
291 origins: Vec<HeaderValue>,
292 ) -> Result<Router> {
293 let cors = CorsLayer::new()
294 .allow_methods(vec![
295 Method::GET,
296 Method::POST,
297 Method::PUT,
298 Method::PATCH,
299 Method::DELETE,
300 ])
301 .allow_credentials(true)
302 .allow_headers(vec![AUTHORIZATION, CONTENT_TYPE])
303 .expose_headers(vec![])
304 .allow_origin(origins);
305
306 let v1 = {
307 let mut router = Router::new()
308 .route("/", get(api))
309 .route("/docs", get(apidocs))
310 .route("/docs/", get(apidocs))
311 .route("/docs/openapi.json", get(openapi))
312 .route(
313 "/sync/account",
314 put(account::create_account)
315 .post(account::update_account)
316 .patch(account::sync_account)
317 .get(account::fetch_account)
318 .head(account::account_exists)
319 .delete(account::delete_account),
320 )
321 .route("/sync/account/status", get(account::sync_status))
322 .route(
323 "/sync/account/events",
324 get(account::event_scan)
325 .post(account::event_diff)
326 .patch(account::event_patch),
327 );
328
329 {
330 use super::handlers::files::{self, file_operation_lock};
331 router = router
332 .route("/sync/files", post(files::compare_files))
333 .route(
334 "/sync/file/{vault_id}/{secret_id}/{file_name}",
335 put(files::receive_file)
336 .post(files::move_file)
337 .get(files::send_file)
338 .delete(files::delete_file)
339 .route_layer(middleware::from_fn(
340 file_operation_lock,
341 )),
342 );
343 }
344
345 #[cfg(feature = "listen")]
346 {
347 use super::handlers::connections;
348 router = router
349 .route("/sync/connections", get(connections))
350 .route("/sync/changes", get(upgrade));
351 }
352
353 #[cfg(feature = "pairing")]
354 {
355 router = router.route("/relay", get(relay_upgrade));
356 }
357
358 router
359 };
360
361 let mut v1 = v1.layer(cors).layer(
362 TraceLayer::new_for_http()
363 .on_request(DefaultOnRequest::new().level(Level::TRACE))
364 .on_response(DefaultOnResponse::new().level(Level::TRACE)),
365 );
366
367 #[cfg(feature = "pairing")]
368 {
369 let relay: RelayState = Arc::new(Mutex::new(HashMap::new()));
370 v1 = v1.layer(Extension(relay));
371 }
372
373 v1 = v1.layer(Extension(backend)).layer(Extension(state));
374
375 {
376 let file_operations: ServerTransfer =
377 Arc::new(RwLock::new(HashSet::new()));
378 v1 = v1.layer(Extension(file_operations));
379 }
380
381 #[allow(unused_mut)]
382 let mut app = Router::new()
383 .route("/", get(home))
384 .nest_service("/api/v1", v1);
385
386 #[cfg(feature = "prometheus")]
387 {
388 let (prometheus_layer, metric_handle) =
389 axum_prometheus::PrometheusMetricLayerBuilder::new()
390 .with_default_metrics()
391 .enable_response_body_size(true)
392 .build_pair();
393
394 app = app
395 .route(
396 "/metrics",
397 get(|| async move { metric_handle.render() }),
398 )
399 .layer(prometheus_layer);
400 }
401
402 Ok(app)
403 }
404}
405
406#[utoipa::path(
408 get,
409 path = "/docs/openapi.json",
410 responses(
411 (
412 status = StatusCode::OK,
413 description = "OpenAPI definition",
414 ),
415 ),
416)]
417pub async fn openapi() -> impl IntoResponse {
418 let value = crate::api_docs::openapi();
419 Json(serde_json::json!(&value))
420}
421
422#[utoipa::path(
424 get,
425 path = "/docs",
426 responses(
427 (
428 status = StatusCode::OK,
429 description = "Render OpenAPI documentation",
430 ),
431 ),
432)]
433pub async fn apidocs() -> impl IntoResponse {
434 use utoipa_rapidoc::RapiDoc;
435 let rapidoc = RapiDoc::new("/api/v1/docs/openapi.json");
436 let html = rapidoc.to_html();
437 ([(CONTENT_TYPE, "text/html")], html)
438}