sos_server/
server.rs

1use crate::{
2    config::{self, TlsConfig},
3    handlers::{account, api, home, websocket::WebSocketAccount},
4    Backend, Result, ServerConfig, SslConfig, StorageConfig,
5};
6use axum::{
7    extract::Extension,
8    http::{
9        header::{AUTHORIZATION, CONTENT_TYPE},
10        HeaderValue, Method,
11    },
12    middleware,
13    response::{IntoResponse, Json},
14    routing::{get, post, put},
15    Router,
16};
17use axum_server::{tls_rustls::RustlsConfig, Handle};
18use colored::Colorize;
19use futures::StreamExt;
20use sos_core::{AccountId, UtcDateTime};
21use std::{
22    collections::{HashMap, HashSet},
23    net::SocketAddr,
24    path::PathBuf,
25    sync::Arc,
26};
27use tokio::sync::{Mutex, RwLock, RwLockReadGuard};
28use tower_http::{
29    cors::CorsLayer,
30    trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
31};
32use tracing::Level;
33
34#[cfg(feature = "acme")]
35use tokio_rustls_acme::{caches::DirCache, AcmeConfig};
36
37#[cfg(feature = "listen")]
38use super::handlers::websocket::upgrade;
39
40use sos_core::ExternalFile;
41
42#[cfg(feature = "pairing")]
43use super::handlers::relay::{upgrade as relay_upgrade, RelayState};
44
45/// Server state.
46pub struct State {
47    /// The server configuration.
48    pub config: ServerConfig,
49    /// Map of websocket channels by account identifier.
50    pub(crate) sockets: HashMap<AccountId, WebSocketAccount>,
51}
52
53impl State {
54    /// Create new server state.
55    pub fn new(config: ServerConfig) -> Self {
56        Self {
57            config,
58            sockets: Default::default(),
59        }
60    }
61}
62
63/// State for the server.
64pub type ServerState = Arc<RwLock<State>>;
65
66/// State for the server backend.
67pub type ServerBackend = Arc<RwLock<Backend>>;
68
69/// Transfer operations in progress.
70pub type TransferOperations = HashSet<ExternalFile>;
71
72/// State for the file transfer operations.
73pub type ServerTransfer = Arc<RwLock<TransferOperations>>;
74
/// Web server implementation.
///
/// The type carries no fields of its own; the state and backend
/// it serves are passed to [`Server::start`].
pub struct Server {}
77
78impl Server {
79    /// Create a new server.
80    ///
81    /// Path should be the directory where the backend
82    /// will store account files; if a server is already
83    /// running and has a lock on the directory this will
84    /// block until the lock is released.
85    pub async fn new() -> Result<Self> {
86        Ok(Self {})
87    }
88
89    /// Start the server.
90    pub async fn start(
91        &self,
92        state: ServerState,
93        backend: ServerBackend,
94        handle: Handle,
95    ) -> Result<()> {
96        let reader = state.read().await;
97        let origins = Server::read_origins(&reader)?;
98        let ssl = reader.config.net.ssl.clone();
99        let addr = reader.config.bind_address().clone();
100        drop(reader);
101
102        match ssl {
103            Some(SslConfig::Tls(tls)) => {
104                self.run_tls(addr, state, backend, handle, origins, tls)
105                    .await
106            }
107            #[cfg(feature = "acme")]
108            Some(SslConfig::Acme(acme)) => {
109                self.run_acme(addr, state, backend, handle, origins, acme)
110                    .await
111            }
112            None => self.run(addr, state, backend, handle, origins).await,
113        }
114    }
115
116    /// Start the server running on HTTPS.
117    async fn run_tls(
118        &self,
119        addr: SocketAddr,
120        state: ServerState,
121        backend: ServerBackend,
122        handle: Handle,
123        origins: Vec<HeaderValue>,
124        tls: TlsConfig,
125    ) -> Result<()> {
126        let storage = {
127            let state = state.read().await;
128            let backend = backend.read().await;
129            (
130                state.config.storage.clone(),
131                backend.paths().documents_dir().to_owned(),
132            )
133        };
134
135        let tls = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
136        let app = Server::router(Arc::clone(&state), backend, origins)?;
137
138        self.startup_message(state, &addr, true, storage).await;
139
140        axum_server::bind_rustls(addr, tls)
141            .handle(handle)
142            .serve(app.into_make_service())
143            .await?;
144        Ok(())
145    }
146
147    /// Start the server running on HTTPS using ACME.
148    #[cfg(feature = "acme")]
149    async fn run_acme(
150        &self,
151        addr: SocketAddr,
152        state: ServerState,
153        backend: ServerBackend,
154        handle: Handle,
155        origins: Vec<HeaderValue>,
156        acme: config::AcmeConfig,
157    ) -> Result<()> {
158        let storage = {
159            let state = state.read().await;
160            let backend = backend.read().await;
161            (
162                state.config.storage.clone(),
163                backend.paths().documents_dir().to_owned(),
164            )
165        };
166
167        let mut acme_state = AcmeConfig::new(acme.domains)
168            .contact(acme.email.iter().map(|e| format!("mailto:{}", e)))
169            .cache_option(Some(DirCache::new(acme.cache)))
170            .directory_lets_encrypt(acme.production)
171            .state();
172
173        let app = Server::router(Arc::clone(&state), backend, origins)?;
174
175        self.startup_message(state, &addr, true, storage).await;
176
177        let rustls_config = rustls::ServerConfig::builder()
178            .with_no_client_auth()
179            .with_cert_resolver(acme_state.resolver());
180        let acceptor = acme_state.axum_acceptor(Arc::new(rustls_config));
181
182        tokio::spawn(async move {
183            loop {
184                match acme_state.next().await.unwrap() {
185                    Ok(res) => tracing::info!(result = ?res, "acme"),
186                    Err(err) => tracing::error!(error = ?err, "acme"),
187                }
188            }
189        });
190
191        axum_server::bind(addr)
192            .acceptor(acceptor)
193            .handle(handle)
194            .serve(app.into_make_service())
195            .await?;
196
197        Ok(())
198    }
199
200    /// Start the server running on HTTP.
201    async fn run(
202        &self,
203        addr: SocketAddr,
204        state: ServerState,
205        backend: ServerBackend,
206        handle: Handle,
207        origins: Vec<HeaderValue>,
208    ) -> Result<()> {
209        let storage = {
210            let state = state.read().await;
211            let backend = backend.read().await;
212            (
213                state.config.storage.clone(),
214                backend.paths().documents_dir().to_owned(),
215            )
216        };
217
218        let app = Server::router(Arc::clone(&state), backend, origins)?;
219        self.startup_message(state, &addr, false, storage).await;
220
221        axum_server::bind(addr)
222            .handle(handle)
223            .serve(app.into_make_service())
224            .await?;
225        Ok(())
226    }
227
228    async fn startup_message(
229        &self,
230        state: ServerState,
231        addr: &SocketAddr,
232        tls: bool,
233        storage: (StorageConfig, PathBuf),
234    ) {
235        let now = UtcDateTime::now().to_rfc3339().unwrap();
236
237        let mut columns = vec![
238            ("Started", now),
239            ("Listen", addr.to_string()),
240            ("TLS enabled", tls.to_string()),
241            ("Directory", storage.1.display().to_string()),
242        ];
243
244        if let Some(db_file) = &storage.0.database_uri {
245            columns.push(("Database", db_file.as_uri_string()));
246        }
247
248        let max_length = columns.iter().map(|s| s.0.len()).max().unwrap();
249        let col_size = max_length + 4;
250        for (key, value) in columns {
251            let padding = col_size - key.len();
252            println!("{}{}{}", key, " ".repeat(padding), value.yellow());
253        }
254
255        {
256            let reader = state.read().await;
257            if let Some(access) = &reader.config.access {
258                if let Some(allow) = &access.allow {
259                    for address in allow {
260                        println!(
261                            "Allow          {}",
262                            address.to_string().green()
263                        );
264                    }
265                }
266                if let Some(deny) = &access.deny {
267                    for address in deny {
268                        println!(
269                            "Deny           {}",
270                            address.to_string().red()
271                        );
272                    }
273                }
274            }
275        }
276    }
277
278    fn read_origins(
279        reader: &RwLockReadGuard<'_, State>,
280    ) -> Result<Vec<HeaderValue>> {
281        let mut origins = Vec::new();
282        let cors = reader.config.net.cors.as_ref();
283        if let Some(cors) = cors {
284            for url in cors.origins.iter() {
285                origins.push(HeaderValue::from_str(
286                    url.as_str().trim_end_matches('/'),
287                )?);
288            }
289        }
290        Ok(origins)
291    }
292
293    fn router(
294        state: ServerState,
295        backend: ServerBackend,
296        origins: Vec<HeaderValue>,
297    ) -> Result<Router> {
298        let cors = CorsLayer::new()
299            .allow_methods(vec![
300                Method::GET,
301                Method::POST,
302                Method::PUT,
303                Method::PATCH,
304                Method::DELETE,
305            ])
306            .allow_credentials(true)
307            .allow_headers(vec![AUTHORIZATION, CONTENT_TYPE])
308            .expose_headers(vec![])
309            .allow_origin(origins);
310
311        let v1 = {
312            let mut router = Router::new()
313                .route("/", get(api))
314                .route("/docs", get(apidocs))
315                .route("/docs/", get(apidocs))
316                .route("/docs/openapi.json", get(openapi))
317                .route(
318                    "/sync/account",
319                    put(account::create_account)
320                        .post(account::update_account)
321                        .patch(account::sync_account)
322                        .get(account::fetch_account)
323                        .head(account::account_exists)
324                        .delete(account::delete_account),
325                )
326                .route("/sync/account/status", get(account::sync_status))
327                .route(
328                    "/sync/account/events",
329                    get(account::event_scan)
330                        .post(account::event_diff)
331                        .patch(account::event_patch),
332                );
333
334            {
335                use super::handlers::files::{self, file_operation_lock};
336                router = router
337                    .route("/sync/files", post(files::compare_files))
338                    .route(
339                        "/sync/file/{vault_id}/{secret_id}/{file_name}",
340                        put(files::receive_file)
341                            .post(files::move_file)
342                            .get(files::send_file)
343                            .delete(files::delete_file)
344                            .route_layer(middleware::from_fn(
345                                file_operation_lock,
346                            )),
347                    );
348            }
349
350            #[cfg(feature = "listen")]
351            {
352                use super::handlers::connections;
353                router = router
354                    .route("/sync/connections", get(connections))
355                    .route("/sync/changes", get(upgrade));
356            }
357
358            #[cfg(feature = "pairing")]
359            {
360                router = router.route("/relay", get(relay_upgrade));
361            }
362
363            router
364        };
365
366        let mut v1 = v1.layer(cors).layer(
367            TraceLayer::new_for_http()
368                .on_request(DefaultOnRequest::new().level(Level::TRACE))
369                .on_response(DefaultOnResponse::new().level(Level::TRACE)),
370        );
371
372        #[cfg(feature = "pairing")]
373        {
374            let relay: RelayState = Arc::new(Mutex::new(HashMap::new()));
375            v1 = v1.layer(Extension(relay));
376        }
377
378        v1 = v1.layer(Extension(backend)).layer(Extension(state));
379
380        {
381            let file_operations: ServerTransfer =
382                Arc::new(RwLock::new(HashSet::new()));
383            v1 = v1.layer(Extension(file_operations));
384        }
385
386        #[allow(unused_mut)]
387        let mut app = Router::new()
388            .route("/", get(home))
389            .nest_service("/api/v1", v1);
390
391        #[cfg(feature = "prometheus")]
392        {
393            let (prometheus_layer, metric_handle) =
394                axum_prometheus::PrometheusMetricLayerBuilder::new()
395                    .with_default_metrics()
396                    .enable_response_body_size(true)
397                    .build_pair();
398
399            app = app
400                .route(
401                    "/metrics",
402                    get(|| async move { metric_handle.render() }),
403                )
404                .layer(prometheus_layer);
405        }
406
407        Ok(app)
408    }
409}
410
411/// Get OpenAPI JSON definition.
412#[utoipa::path(
413    get,
414    path = "/docs/openapi.json",
415    responses(
416        (
417            status = StatusCode::OK,
418            description = "OpenAPI definition",
419        ),
420    ),
421)]
422pub async fn openapi() -> impl IntoResponse {
423    let value = crate::api_docs::openapi();
424    Json(serde_json::json!(&value))
425}
426
427/// OpenAPI documentation.
428#[utoipa::path(
429    get,
430    path = "/docs",
431    responses(
432        (
433            status = StatusCode::OK,
434            description = "Render OpenAPI documentation",
435        ),
436    ),
437)]
438pub async fn apidocs() -> impl IntoResponse {
439    use utoipa_rapidoc::RapiDoc;
440    let rapidoc = RapiDoc::new("/api/v1/docs/openapi.json");
441    let html = rapidoc.to_html();
442    ([(CONTENT_TYPE, "text/html")], html)
443}