sos_server/
server.rs

1use crate::{
2    config::{self, TlsConfig},
3    handlers::{account, api, home, websocket::WebSocketAccount},
4    Backend, Result, ServerConfig, SslConfig, StorageConfig,
5};
6use axum::{
7    extract::Extension,
8    http::{
9        header::{AUTHORIZATION, CONTENT_TYPE},
10        HeaderValue, Method,
11    },
12    middleware,
13    response::{IntoResponse, Json},
14    routing::{get, post, put},
15    Router,
16};
17use axum_server::{tls_rustls::RustlsConfig, Handle};
18use colored::Colorize;
19use futures::StreamExt;
20use sos_core::{AccountId, UtcDateTime};
21use std::{
22    collections::{HashMap, HashSet},
23    net::SocketAddr,
24    path::PathBuf,
25    sync::Arc,
26};
27use tokio::sync::{Mutex, RwLock, RwLockReadGuard};
28use tower_http::{
29    cors::CorsLayer,
30    trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
31};
32use tracing::Level;
33
34#[cfg(feature = "acme")]
35use tokio_rustls_acme::{caches::DirCache, AcmeConfig};
36
37#[cfg(feature = "listen")]
38use super::handlers::websocket::upgrade;
39
40use sos_core::ExternalFile;
41
42#[cfg(feature = "pairing")]
43use super::handlers::relay::{upgrade as relay_upgrade, RelayState};
44
45/// Server state.
46pub struct State {
47    /// The server configuration.
48    pub config: ServerConfig,
49    /// Map of websocket channels by account identifier.
50    pub(crate) sockets: HashMap<AccountId, WebSocketAccount>,
51}
52
53impl State {
54    /// Create new server state.
55    pub fn new(config: ServerConfig) -> Self {
56        Self {
57            config,
58            sockets: Default::default(),
59        }
60    }
61}
62
63/// State for the server.
64pub type ServerState = Arc<RwLock<State>>;
65
66/// State for the server backend.
67pub type ServerBackend = Arc<RwLock<Backend>>;
68
69/// Transfer operations in progress.
70pub type TransferOperations = HashSet<ExternalFile>;
71
72/// State for the file transfer operations.
73pub type ServerTransfer = Arc<RwLock<TransferOperations>>;
74
/// The web server.
///
/// This type carries no data of its own; the state and backend are
/// supplied when the server is started.
pub struct Server {}
77
78impl Server {
79    /// Create a new server.
80    pub async fn new() -> Result<Self> {
81        Ok(Self {})
82    }
83
84    /// Start the server.
85    pub async fn start(
86        &self,
87        state: ServerState,
88        backend: ServerBackend,
89        handle: Handle,
90    ) -> Result<()> {
91        let reader = state.read().await;
92        let origins = Server::read_origins(&reader)?;
93        let ssl = reader.config.net.ssl.clone();
94        let addr = reader.config.bind_address().clone();
95        drop(reader);
96
97        match ssl {
98            Some(SslConfig::Tls(tls)) => {
99                self.run_tls(addr, state, backend, handle, origins, tls)
100                    .await
101            }
102            #[cfg(feature = "acme")]
103            Some(SslConfig::Acme(acme)) => {
104                self.run_acme(addr, state, backend, handle, origins, acme)
105                    .await
106            }
107            None => self.run(addr, state, backend, handle, origins).await,
108        }
109    }
110
111    /// Start the server running on HTTPS.
112    async fn run_tls(
113        &self,
114        addr: SocketAddr,
115        state: ServerState,
116        backend: ServerBackend,
117        handle: Handle,
118        origins: Vec<HeaderValue>,
119        tls: TlsConfig,
120    ) -> Result<()> {
121        let storage = {
122            let state = state.read().await;
123            let backend = backend.read().await;
124            (
125                state.config.storage.clone(),
126                backend.paths().documents_dir().to_owned(),
127            )
128        };
129
130        let tls = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
131        let app = Server::router(Arc::clone(&state), backend, origins)?;
132
133        self.startup_message(state, &addr, true, storage).await;
134
135        axum_server::bind_rustls(addr, tls)
136            .handle(handle)
137            .serve(app.into_make_service())
138            .await?;
139        Ok(())
140    }
141
142    /// Start the server running on HTTPS using ACME.
143    #[cfg(feature = "acme")]
144    async fn run_acme(
145        &self,
146        addr: SocketAddr,
147        state: ServerState,
148        backend: ServerBackend,
149        handle: Handle,
150        origins: Vec<HeaderValue>,
151        acme: config::AcmeConfig,
152    ) -> Result<()> {
153        let storage = {
154            let state = state.read().await;
155            let backend = backend.read().await;
156            (
157                state.config.storage.clone(),
158                backend.paths().documents_dir().to_owned(),
159            )
160        };
161
162        let mut acme_state = AcmeConfig::new(acme.domains)
163            .contact(acme.email.iter().map(|e| format!("mailto:{}", e)))
164            .cache_option(Some(DirCache::new(acme.cache)))
165            .directory_lets_encrypt(acme.production)
166            .state();
167
168        let app = Server::router(Arc::clone(&state), backend, origins)?;
169
170        self.startup_message(state, &addr, true, storage).await;
171
172        let rustls_config = rustls::ServerConfig::builder()
173            .with_no_client_auth()
174            .with_cert_resolver(acme_state.resolver());
175        let acceptor = acme_state.axum_acceptor(Arc::new(rustls_config));
176
177        tokio::spawn(async move {
178            loop {
179                match acme_state.next().await.unwrap() {
180                    Ok(res) => tracing::info!(result = ?res, "acme"),
181                    Err(err) => tracing::error!(error = ?err, "acme"),
182                }
183            }
184        });
185
186        axum_server::bind(addr)
187            .acceptor(acceptor)
188            .handle(handle)
189            .serve(app.into_make_service())
190            .await?;
191
192        Ok(())
193    }
194
195    /// Start the server running on HTTP.
196    async fn run(
197        &self,
198        addr: SocketAddr,
199        state: ServerState,
200        backend: ServerBackend,
201        handle: Handle,
202        origins: Vec<HeaderValue>,
203    ) -> Result<()> {
204        let storage = {
205            let state = state.read().await;
206            let backend = backend.read().await;
207            (
208                state.config.storage.clone(),
209                backend.paths().documents_dir().to_owned(),
210            )
211        };
212
213        let app = Server::router(Arc::clone(&state), backend, origins)?;
214        self.startup_message(state, &addr, false, storage).await;
215
216        axum_server::bind(addr)
217            .handle(handle)
218            .serve(app.into_make_service())
219            .await?;
220        Ok(())
221    }
222
223    async fn startup_message(
224        &self,
225        state: ServerState,
226        addr: &SocketAddr,
227        tls: bool,
228        storage: (StorageConfig, PathBuf),
229    ) {
230        let now = UtcDateTime::now().to_rfc3339().unwrap();
231
232        let mut columns = vec![
233            ("Started", now),
234            ("Listen", addr.to_string()),
235            ("TLS enabled", tls.to_string()),
236            ("Directory", storage.1.display().to_string()),
237        ];
238
239        if let Some(db_file) = &storage.0.database_uri {
240            columns.push(("Database", db_file.as_uri_string()));
241        }
242
243        let max_length = columns.iter().map(|s| s.0.len()).max().unwrap();
244        let col_size = max_length + 4;
245        for (key, value) in columns {
246            let padding = col_size - key.len();
247            println!("{}{}{}", key, " ".repeat(padding), value.yellow());
248        }
249
250        {
251            let reader = state.read().await;
252            if let Some(access) = &reader.config.access {
253                if let Some(allow) = &access.allow {
254                    for address in allow {
255                        println!(
256                            "Allow          {}",
257                            address.to_string().green()
258                        );
259                    }
260                }
261                if let Some(deny) = &access.deny {
262                    for address in deny {
263                        println!(
264                            "Deny           {}",
265                            address.to_string().red()
266                        );
267                    }
268                }
269            }
270        }
271    }
272
273    fn read_origins(
274        reader: &RwLockReadGuard<'_, State>,
275    ) -> Result<Vec<HeaderValue>> {
276        let mut origins = Vec::new();
277        let cors = reader.config.net.cors.as_ref();
278        if let Some(cors) = cors {
279            for url in cors.origins.iter() {
280                origins.push(HeaderValue::from_str(
281                    url.as_str().trim_end_matches('/'),
282                )?);
283            }
284        }
285        Ok(origins)
286    }
287
288    fn router(
289        state: ServerState,
290        backend: ServerBackend,
291        origins: Vec<HeaderValue>,
292    ) -> Result<Router> {
293        let cors = CorsLayer::new()
294            .allow_methods(vec![
295                Method::GET,
296                Method::POST,
297                Method::PUT,
298                Method::PATCH,
299                Method::DELETE,
300            ])
301            .allow_credentials(true)
302            .allow_headers(vec![AUTHORIZATION, CONTENT_TYPE])
303            .expose_headers(vec![])
304            .allow_origin(origins);
305
306        let v1 = {
307            let mut router = Router::new()
308                .route("/", get(api))
309                .route("/docs", get(apidocs))
310                .route("/docs/", get(apidocs))
311                .route("/docs/openapi.json", get(openapi))
312                .route(
313                    "/sync/account",
314                    put(account::create_account)
315                        .post(account::update_account)
316                        .patch(account::sync_account)
317                        .get(account::fetch_account)
318                        .head(account::account_exists)
319                        .delete(account::delete_account),
320                )
321                .route("/sync/account/status", get(account::sync_status))
322                .route(
323                    "/sync/account/events",
324                    get(account::event_scan)
325                        .post(account::event_diff)
326                        .patch(account::event_patch),
327                );
328
329            {
330                use super::handlers::files::{self, file_operation_lock};
331                router = router
332                    .route("/sync/files", post(files::compare_files))
333                    .route(
334                        "/sync/file/{vault_id}/{secret_id}/{file_name}",
335                        put(files::receive_file)
336                            .post(files::move_file)
337                            .get(files::send_file)
338                            .delete(files::delete_file)
339                            .route_layer(middleware::from_fn(
340                                file_operation_lock,
341                            )),
342                    );
343            }
344
345            #[cfg(feature = "listen")]
346            {
347                use super::handlers::connections;
348                router = router
349                    .route("/sync/connections", get(connections))
350                    .route("/sync/changes", get(upgrade));
351            }
352
353            #[cfg(feature = "pairing")]
354            {
355                router = router.route("/relay", get(relay_upgrade));
356            }
357
358            router
359        };
360
361        let mut v1 = v1.layer(cors).layer(
362            TraceLayer::new_for_http()
363                .on_request(DefaultOnRequest::new().level(Level::TRACE))
364                .on_response(DefaultOnResponse::new().level(Level::TRACE)),
365        );
366
367        #[cfg(feature = "pairing")]
368        {
369            let relay: RelayState = Arc::new(Mutex::new(HashMap::new()));
370            v1 = v1.layer(Extension(relay));
371        }
372
373        v1 = v1.layer(Extension(backend)).layer(Extension(state));
374
375        {
376            let file_operations: ServerTransfer =
377                Arc::new(RwLock::new(HashSet::new()));
378            v1 = v1.layer(Extension(file_operations));
379        }
380
381        #[allow(unused_mut)]
382        let mut app = Router::new()
383            .route("/", get(home))
384            .nest_service("/api/v1", v1);
385
386        #[cfg(feature = "prometheus")]
387        {
388            let (prometheus_layer, metric_handle) =
389                axum_prometheus::PrometheusMetricLayerBuilder::new()
390                    .with_default_metrics()
391                    .enable_response_body_size(true)
392                    .build_pair();
393
394            app = app
395                .route(
396                    "/metrics",
397                    get(|| async move { metric_handle.render() }),
398                )
399                .layer(prometheus_layer);
400        }
401
402        Ok(app)
403    }
404}
405
406/// Get OpenAPI JSON definition.
407#[utoipa::path(
408    get,
409    path = "/docs/openapi.json",
410    responses(
411        (
412            status = StatusCode::OK,
413            description = "OpenAPI definition",
414        ),
415    ),
416)]
417pub async fn openapi() -> impl IntoResponse {
418    let value = crate::api_docs::openapi();
419    Json(serde_json::json!(&value))
420}
421
422/// OpenAPI documentation.
423#[utoipa::path(
424    get,
425    path = "/docs",
426    responses(
427        (
428            status = StatusCode::OK,
429            description = "Render OpenAPI documentation",
430        ),
431    ),
432)]
433pub async fn apidocs() -> impl IntoResponse {
434    use utoipa_rapidoc::RapiDoc;
435    let rapidoc = RapiDoc::new("/api/v1/docs/openapi.json");
436    let html = rapidoc.to_html();
437    ([(CONTENT_TYPE, "text/html")], html)
438}