Skip to main content

s2_lite/
server.rs

1use std::{
2    net::SocketAddr,
3    path::PathBuf,
4    sync::Arc,
5    time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use http::header::AUTHORIZATION;
11use s2_common::encryption::S2_ENCRYPTION_KEY_HEADER;
12use slatedb::object_store;
13use tokio::time::Instant;
14use tower_http::{
15    cors::CorsLayer,
16    sensitive_headers::SetSensitiveRequestHeadersLayer,
17    trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
18};
19use tracing::info;
20
21use crate::{backend::Backend, handlers, init};
22
23#[derive(clap::Args, Debug, Clone)]
24pub struct TlsConfig {
25    /// Use a self-signed certificate for TLS
26    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
27    pub tls_self: bool,
28
29    /// Path to the TLS certificate file (e.g., cert.pem)
30    /// Must be used together with --tls-key
31    #[arg(long, requires = "tls_key")]
32    pub tls_cert: Option<PathBuf>,
33
34    /// Path to the private key file (e.g., key.pem)
35    /// Must be used together with --tls-cert
36    #[arg(long, requires = "tls_cert")]
37    pub tls_key: Option<PathBuf>,
38}
39
40#[derive(clap::Args, Debug, Clone)]
41pub struct LiteArgs {
42    /// Name of the S3 bucket to back the database.
43    ///
44    /// If not specified, in-memory storage is used unless --local-root is set.
45    #[arg(long)]
46    pub bucket: Option<String>,
47
48    /// Root directory to back the database on the local filesystem.
49    ///
50    /// Conflicts with --bucket.
51    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
52    pub local_root: Option<PathBuf>,
53
54    /// Base path on object storage.
55    #[arg(long, default_value = "")]
56    pub path: String,
57
58    /// TLS configuration (defaults to plain HTTP if not specified).
59    #[command(flatten)]
60    pub tls: TlsConfig,
61
62    /// Port to listen on [default: 443 if HTTPS configured, otherwise 80 for HTTP]
63    #[arg(long)]
64    pub port: Option<u16>,
65
66    /// Disable permissive CORS headers.
67    ///
68    /// By default, Lite sends CORS headers that allow browser-based clients
69    /// on any origin to connect (e.g. the S2 console). Pass this flag to
70    /// suppress those headers for stricter deployments where browser access
71    /// should be denied at the HTTP layer.
72    #[arg(long)]
73    pub no_cors: bool,
74
75    /// Path to a JSON file defining basins and streams to create at startup.
76    ///
77    /// Creates missing resources and updates existing configs to match the file,
78    /// so it is safe to run on repeated restarts. Can also be set via
79    /// S2LITE_INIT_FILE environment variable.
80    #[arg(long, env = "S2LITE_INIT_FILE")]
81    pub init_file: Option<PathBuf>,
82
83    /// Maximum in-flight append metered bytes across all streams before admission blocks.
84    #[arg(long, default_value = "128MiB")]
85    pub append_inflight_bytes: ByteSize,
86}
87
/// The kind of object store that backs the database.
#[derive(Debug, Clone)]
enum StoreType {
    /// An S3 bucket, identified by its name.
    S3Bucket(String),
    /// A directory on the local filesystem.
    LocalFileSystem(PathBuf),
    /// A process-local, in-memory store.
    InMemory,
}

impl StoreType {
    /// Default SlateDB flush interval to use for this store type.
    ///
    /// S3 gets 50 ms; local-filesystem and in-memory stores get 5 ms.
    /// (Presumably the longer S3 interval trades latency for fewer remote
    /// writes — NOTE(review): confirm the rationale.)
    fn default_flush_interval(&self) -> Duration {
        match self {
            Self::S3Bucket(_) => Duration::from_millis(50),
            Self::LocalFileSystem(_) | Self::InMemory => Duration::from_millis(5),
        }
    }
}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq)]
105enum ServerProtocol {
106    Http,
107    Https { self_signed: bool },
108}
109
110impl ServerProtocol {
111    fn from_args(args: &LiteArgs) -> Self {
112        if args.tls.tls_self {
113            Self::Https { self_signed: true }
114        } else if args.tls.tls_cert.is_some() {
115            Self::Https { self_signed: false }
116        } else {
117            Self::Http
118        }
119    }
120
121    fn scheme(self) -> &'static str {
122        match self {
123            Self::Http => "http",
124            Self::Https { .. } => "https",
125        }
126    }
127
128    fn default_port(self) -> u16 {
129        match self {
130            Self::Http => 80,
131            Self::Https { .. } => 443,
132        }
133    }
134
135    fn requires_ssl_no_verify(self) -> bool {
136        matches!(self, Self::Https { self_signed: true })
137    }
138}
139
140fn cli_endpoint(protocol: ServerProtocol, port: u16) -> String {
141    format!("{}://localhost:{port}", protocol.scheme())
142}
143
144fn cli_env_hint(protocol: ServerProtocol, port: u16) -> String {
145    let endpoint = cli_endpoint(protocol, port);
146    let mut lines = vec![
147        "copy/paste into a new terminal to point the S2 CLI at this server:".to_string(),
148        format!("export S2_ACCOUNT_ENDPOINT={endpoint}"),
149        format!("export S2_BASIN_ENDPOINT={endpoint}"),
150        "export S2_ACCESS_TOKEN=ignored".to_string(),
151    ];
152
153    if protocol.requires_ssl_no_verify() {
154        lines.push("export S2_SSL_NO_VERIFY=1".to_string());
155    }
156
157    lines.join("\n")
158}
159
160pub async fn run(args: LiteArgs) -> eyre::Result<()> {
161    info!(?args);
162
163    let protocol = ServerProtocol::from_args(&args);
164    let port = args.port.unwrap_or_else(|| protocol.default_port());
165    let addr = format!("0.0.0.0:{port}");
166    let cli_hint = cli_env_hint(protocol, port);
167
168    let store_type = if let Some(bucket) = args.bucket {
169        StoreType::S3Bucket(bucket)
170    } else if let Some(local_root) = args.local_root {
171        StoreType::LocalFileSystem(local_root)
172    } else {
173        StoreType::InMemory
174    };
175
176    let object_store = init_object_store(&store_type).await?;
177
178    let db_settings = slatedb::Settings::from_env_with_default(
179        "SL8_",
180        slatedb::Settings {
181            flush_interval: Some(store_type.default_flush_interval()),
182            ..Default::default()
183        },
184    )?;
185
186    let manifest_poll_interval = db_settings.manifest_poll_interval;
187
188    let db = slatedb::Db::builder(args.path, object_store)
189        .with_settings(db_settings)
190        .build()
191        .await?;
192
193    info!(
194        ?manifest_poll_interval,
195        "sleeping to ensure prior instance fenced out"
196    );
197
198    tokio::time::sleep(manifest_poll_interval).await;
199
200    info!(%args.append_inflight_bytes, "starting backend");
201    let backend = Backend::new(db, args.append_inflight_bytes);
202    crate::backend::bgtasks::spawn(&backend);
203
204    if let Some(init_file) = &args.init_file {
205        let spec = init::load(init_file)?;
206        init::apply(&backend, spec).await?;
207    }
208
209    let mut app = handlers::router()
210        .with_state(backend)
211        .layer(
212            TraceLayer::new_for_http()
213                .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
214                .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
215                .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
216        )
217        .layer(SetSensitiveRequestHeadersLayer::new([
218            AUTHORIZATION,
219            S2_ENCRYPTION_KEY_HEADER.clone(),
220        ]));
221
222    if !args.no_cors {
223        app = app.layer(CorsLayer::very_permissive());
224    }
225
226    let server_handle = axum_server::Handle::new();
227    tokio::spawn(shutdown_signal(server_handle.clone()));
228    match (
229        args.tls.tls_self,
230        args.tls.tls_cert.clone(),
231        args.tls.tls_key.clone(),
232    ) {
233        (false, Some(cert_path), Some(key_path)) => {
234            info!(
235                addr,
236                ?cert_path,
237                "starting https server with provided certificate"
238            );
239            let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
240            info!("{}", cli_hint);
241            axum_server::bind_rustls(addr.parse()?, rustls_config)
242                .handle(server_handle)
243                .serve(app.into_make_service())
244                .await?;
245        }
246        (true, None, None) => {
247            info!(
248                addr,
249                "starting https server with self-signed certificate, clients will need to use --insecure"
250            );
251            let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
252                "localhost".to_string(),
253                "127.0.0.1".to_string(),
254                "::1".to_string(),
255            ])?;
256            let rustls_config = RustlsConfig::from_pem(
257                cert.pem().into_bytes(),
258                signing_key.serialize_pem().into_bytes(),
259            )
260            .await?;
261            info!("{}", cli_hint);
262            axum_server::bind_rustls(addr.parse()?, rustls_config)
263                .handle(server_handle)
264                .serve(app.into_make_service())
265                .await?;
266        }
267        (false, None, None) => {
268            info!(addr, "starting plain http server");
269            info!("{}", cli_hint);
270            axum_server::bind(addr.parse()?)
271                .handle(server_handle)
272                .serve(app.into_make_service())
273                .await?;
274        }
275        _ => {
276            // This shouldn't happen due to clap validation...
277            return Err(eyre::eyre!("Invalid TLS configuration"));
278        }
279    }
280
281    Ok(())
282}
283
284async fn init_object_store(
285    store_type: &StoreType,
286) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
287    Ok(match store_type {
288        StoreType::S3Bucket(bucket) => {
289            info!(bucket, "using s3 object store");
290            let mut builder =
291                object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
292
293            if let Some(endpoint) =
294                std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok())
295            {
296                if endpoint.starts_with("http://") {
297                    builder = builder.with_allow_http(true);
298                }
299                builder = builder.with_endpoint(endpoint);
300            }
301
302            match (
303                std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
304                std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
305            ) {
306                (Some(key_id), Some(secret_key)) => {
307                    info!(key_id, "using static credentials from env vars");
308
309                    let token =
310                        std::env::var_os("AWS_SESSION_TOKEN").and_then(|s| s.into_string().ok());
311                    builder = builder.with_credentials(Arc::new(
312                        object_store::StaticCredentialProvider::new(
313                            object_store::aws::AwsCredential {
314                                key_id,
315                                secret_key,
316                                token,
317                            },
318                        ),
319                    ));
320                }
321                _ => {
322                    let aws_config =
323                        aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
324                    if let Some(region) = aws_config.region() {
325                        info!(region = region.as_ref());
326                        builder = builder.with_region(region.to_string());
327                    }
328                    if let Some(credentials_provider) = aws_config.credentials_provider() {
329                        info!("using aws-config credentials provider");
330                        builder = builder.with_credentials(Arc::new(S3CredentialProvider {
331                            aws: credentials_provider.clone(),
332                            cache: tokio::sync::Mutex::new(None),
333                        }));
334                    }
335                }
336            }
337            Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
338        }
339        StoreType::LocalFileSystem(local_root) => {
340            std::fs::create_dir_all(local_root)?;
341            info!(
342                root = %local_root.display(),
343                "using local filesystem object store"
344            );
345            Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
346                local_root,
347            )?)
348        }
349        StoreType::InMemory => {
350            info!("using in-memory object store");
351            Arc::new(object_store::memory::InMemory::new())
352        }
353    })
354}
355
356async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
357    let ctrl_c = async {
358        tokio::signal::ctrl_c().await.expect("ctrl-c");
359    };
360
361    #[cfg(unix)]
362    let term = async {
363        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
364            .expect("SIGTERM")
365            .recv()
366            .await;
367    };
368
369    #[cfg(not(unix))]
370    let term = std::future::pending::<()>();
371
372    tokio::select! {
373        _ = ctrl_c => {
374            info!("received Ctrl+C, starting graceful shutdown");
375        },
376        _ = term => {
377            info!("received SIGTERM, starting graceful shutdown");
378        },
379    }
380
381    handle.graceful_shutdown(Some(Duration::from_secs(10)));
382}
383
384#[derive(Debug)]
385struct CachedCredential {
386    credential: Arc<object_store::aws::AwsCredential>,
387    expiry: Option<SystemTime>,
388}
389
390impl CachedCredential {
391    fn is_valid(&self) -> bool {
392        self.expiry
393            .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
394    }
395}
396
397#[derive(Debug)]
398struct S3CredentialProvider {
399    aws: aws_credential_types::provider::SharedCredentialsProvider,
400    cache: tokio::sync::Mutex<Option<CachedCredential>>,
401}
402
403#[async_trait::async_trait]
404impl object_store::CredentialProvider for S3CredentialProvider {
405    type Credential = object_store::aws::AwsCredential;
406
407    async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
408        let mut cached = self.cache.lock().await;
409        if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
410            return Ok(cached.credential.clone());
411        }
412
413        use aws_credential_types::provider::ProvideCredentials as _;
414
415        let start = Instant::now();
416        let creds =
417            self.aws
418                .provide_credentials()
419                .await
420                .map_err(|e| object_store::Error::Generic {
421                    store: "S3",
422                    source: Box::new(e),
423                })?;
424        info!(
425            key_id = creds.access_key_id(),
426            expiry_s = creds
427                .expiry()
428                .and_then(|t| t.duration_since(SystemTime::now()).ok())
429                .map(|d| d.as_secs()),
430            elapsed_ms = start.elapsed().as_millis(),
431            "fetched credentials"
432        );
433        let credential = Arc::new(object_store::aws::AwsCredential {
434            key_id: creds.access_key_id().to_owned(),
435            secret_key: creds.secret_access_key().to_owned(),
436            token: creds.session_token().map(|s| s.to_owned()),
437        });
438        *cached = Some(CachedCredential {
439            credential: credential.clone(),
440            expiry: creds.expiry(),
441        });
442        Ok(credential)
443    }
444}
445
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{ServerProtocol, StoreType, cli_endpoint, cli_env_hint};

    #[test]
    fn cli_endpoint_uses_localhost_with_explicit_port() {
        assert_eq!(
            cli_endpoint(ServerProtocol::Http, 80),
            "http://localhost:80"
        );
        assert_eq!(
            cli_endpoint(ServerProtocol::Https { self_signed: false }, 443),
            "https://localhost:443"
        );
    }

    #[test]
    fn cli_env_hint_includes_exports_for_http() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Http, 8080),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=http://localhost:8080\n",
                "export S2_BASIN_ENDPOINT=http://localhost:8080\n",
                "export S2_ACCESS_TOKEN=ignored",
            )
        );
    }

    #[test]
    fn cli_env_hint_includes_ssl_no_verify_for_self_signed_tls() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Https { self_signed: true }, 8443),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=https://localhost:8443\n",
                "export S2_BASIN_ENDPOINT=https://localhost:8443\n",
                "export S2_ACCESS_TOKEN=ignored\n",
                "export S2_SSL_NO_VERIFY=1",
            )
        );
    }

    #[test]
    fn cli_env_hint_omits_ssl_no_verify_for_provided_certificate() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Https { self_signed: false }, 443),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=https://localhost:443\n",
                "export S2_BASIN_ENDPOINT=https://localhost:443\n",
                "export S2_ACCESS_TOKEN=ignored",
            )
        );
    }

    #[test]
    fn default_port_depends_on_protocol() {
        assert_eq!(ServerProtocol::Http.default_port(), 80);
        assert_eq!(
            ServerProtocol::Https { self_signed: false }.default_port(),
            443
        );
        assert_eq!(
            ServerProtocol::Https { self_signed: true }.default_port(),
            443
        );
    }

    #[test]
    fn default_flush_interval_depends_on_store_type() {
        assert_eq!(
            StoreType::S3Bucket("bucket".to_string()).default_flush_interval(),
            Duration::from_millis(50)
        );
        assert_eq!(
            StoreType::LocalFileSystem("/tmp/s2lite".into()).default_flush_interval(),
            Duration::from_millis(5)
        );
        assert_eq!(
            StoreType::InMemory.default_flush_interval(),
            Duration::from_millis(5)
        );
    }
}