Skip to main content

s2_lite/
server.rs

1use std::{
2    net::SocketAddr,
3    path::PathBuf,
4    sync::Arc,
5    time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use slatedb::object_store;
11use tokio::time::Instant;
12use tower_http::{
13    cors::CorsLayer,
14    trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
15};
16use tracing::info;
17
18use crate::{backend::Backend, handlers, init};
19
20#[derive(clap::Args, Debug, Clone)]
21pub struct TlsConfig {
22    /// Use a self-signed certificate for TLS
23    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
24    pub tls_self: bool,
25
26    /// Path to the TLS certificate file (e.g., cert.pem)
27    /// Must be used together with --tls-key
28    #[arg(long, requires = "tls_key")]
29    pub tls_cert: Option<PathBuf>,
30
31    /// Path to the private key file (e.g., key.pem)
32    /// Must be used together with --tls-cert
33    #[arg(long, requires = "tls_cert")]
34    pub tls_key: Option<PathBuf>,
35}
36
// Command-line arguments consumed by `run`. Clap renders the `///` comments on the fields
// below as `--help` text, so maintainer-only notes are written as `//` comments.
#[derive(clap::Args, Debug, Clone)]
pub struct LiteArgs {
    /// Name of the S3 bucket to back the database.
    ///
    /// If not specified, in-memory storage is used unless --local-root is set.
    // Checked first when `run` picks the store; selects the S3 object store.
    #[arg(long)]
    pub bucket: Option<String>,

    /// Root directory to back the database on the local filesystem.
    ///
    /// Conflicts with --bucket.
    // Used only when `bucket` is None; the directory is created if it does not exist.
    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
    pub local_root: Option<PathBuf>,

    /// Base path on object storage.
    // Passed to `slatedb::Db::builder` as the database path.
    #[arg(long, default_value = "")]
    pub path: String,

    /// TLS configuration (defaults to plain HTTP if not specified).
    #[command(flatten)]
    pub tls: TlsConfig,

    /// Port to listen on [default: 443 if HTTPS configured, otherwise 80 for HTTP]
    // When None, `run` uses 443 if --tls-self or --tls-cert is set, else 80, and binds
    // on 0.0.0.0.
    #[arg(long)]
    pub port: Option<u16>,

    /// Disable permissive CORS headers.
    ///
    /// By default, Lite sends CORS headers that allow browser-based clients
    /// on any origin to connect (e.g. the S2 console). Pass this flag to
    /// suppress those headers for stricter deployments where browser access
    /// should be denied at the HTTP layer.
    // When false, `run` layers `CorsLayer::very_permissive()` onto the router.
    #[arg(long)]
    pub no_cors: bool,

    /// Path to a JSON file defining basins and streams to create at startup.
    ///
    /// Uses create-or-reconfigure semantics, so it is safe to run on repeated
    /// restarts. Can also be set via S2LITE_INIT_FILE environment variable.
    // Loaded with `init::load` and applied with `init::apply` after the backend and its
    // background tasks have been started.
    #[arg(long, env = "S2LITE_INIT_FILE")]
    pub init_file: Option<PathBuf>,
}
79
/// Backing storage for the SlateDB database, selected from the CLI flags in `run`.
#[derive(Clone, Debug)]
enum StoreType {
    /// An S3 bucket, identified by its name.
    S3Bucket(String),
    /// A root directory on the local filesystem.
    LocalFileSystem(PathBuf),
    /// An in-memory object store (nothing is persisted).
    InMemory,
}

impl StoreType {
    /// Flush interval used as the SlateDB default for this store:
    /// 50 ms for S3 and 5 ms for the local filesystem or in-memory stores.
    fn default_flush_interval(&self) -> Duration {
        let millis = match self {
            Self::S3Bucket(_) => 50,
            Self::InMemory | Self::LocalFileSystem(_) => 5,
        };
        Duration::from_millis(millis)
    }
}
95
96pub async fn run(args: LiteArgs) -> eyre::Result<()> {
97    info!(?args);
98
99    let addr = {
100        let port = args.port.unwrap_or_else(|| {
101            if args.tls.tls_self || args.tls.tls_cert.is_some() {
102                443
103            } else {
104                80
105            }
106        });
107        format!("0.0.0.0:{port}")
108    };
109
110    let store_type = if let Some(bucket) = args.bucket {
111        StoreType::S3Bucket(bucket)
112    } else if let Some(local_root) = args.local_root {
113        StoreType::LocalFileSystem(local_root)
114    } else {
115        StoreType::InMemory
116    };
117
118    let object_store = init_object_store(&store_type).await?;
119
120    let db_settings = slatedb::Settings::from_env_with_default(
121        "SL8_",
122        slatedb::Settings {
123            flush_interval: Some(store_type.default_flush_interval()),
124            ..Default::default()
125        },
126    )?;
127
128    let manifest_poll_interval = db_settings.manifest_poll_interval;
129
130    let append_inflight_max = if std::env::var("S2LITE_PIPELINE")
131        .is_ok_and(|v| v.eq_ignore_ascii_case("true") || v == "1")
132    {
133        info!("pipelining enabled on append sessions up to 25MiB");
134        ByteSize::mib(25)
135    } else {
136        info!("pipelining disabled");
137        ByteSize::b(1)
138    };
139
140    let db = slatedb::Db::builder(args.path, object_store)
141        .with_settings(db_settings)
142        .build()
143        .await?;
144
145    info!(
146        ?manifest_poll_interval,
147        "sleeping to ensure prior instance fenced out"
148    );
149
150    tokio::time::sleep(manifest_poll_interval).await;
151
152    let backend = Backend::new(db, append_inflight_max);
153    crate::backend::bgtasks::spawn(&backend);
154
155    if let Some(init_file) = &args.init_file {
156        let spec = init::load(init_file)?;
157        init::apply(&backend, spec).await?;
158    }
159
160    let mut app = handlers::router().with_state(backend).layer(
161        TraceLayer::new_for_http()
162            .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
163            .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
164            .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
165    );
166
167    if !args.no_cors {
168        app = app.layer(CorsLayer::very_permissive());
169    }
170
171    let server_handle = axum_server::Handle::new();
172    tokio::spawn(shutdown_signal(server_handle.clone()));
173    match (
174        args.tls.tls_self,
175        args.tls.tls_cert.clone(),
176        args.tls.tls_key.clone(),
177    ) {
178        (false, Some(cert_path), Some(key_path)) => {
179            info!(
180                addr,
181                ?cert_path,
182                "starting https server with provided certificate"
183            );
184            let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
185            axum_server::bind_rustls(addr.parse()?, rustls_config)
186                .handle(server_handle)
187                .serve(app.into_make_service())
188                .await?;
189        }
190        (true, None, None) => {
191            info!(
192                addr,
193                "starting https server with self-signed certificate, clients will need to use --insecure"
194            );
195            let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
196                "localhost".to_string(),
197                "127.0.0.1".to_string(),
198                "::1".to_string(),
199            ])?;
200            let rustls_config = RustlsConfig::from_pem(
201                cert.pem().into_bytes(),
202                signing_key.serialize_pem().into_bytes(),
203            )
204            .await?;
205            axum_server::bind_rustls(addr.parse()?, rustls_config)
206                .handle(server_handle)
207                .serve(app.into_make_service())
208                .await?;
209        }
210        (false, None, None) => {
211            info!(addr, "starting plain http server");
212            axum_server::bind(addr.parse()?)
213                .handle(server_handle)
214                .serve(app.into_make_service())
215                .await?;
216        }
217        _ => {
218            // This shouldn't happen due to clap validation...
219            return Err(eyre::eyre!("Invalid TLS configuration"));
220        }
221    }
222
223    Ok(())
224}
225
226async fn init_object_store(
227    store_type: &StoreType,
228) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
229    Ok(match store_type {
230        StoreType::S3Bucket(bucket) => {
231            info!(bucket, "using s3 object store");
232            let mut builder =
233                object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
234            match (
235                std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
236                std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
237                std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
238            ) {
239                (endpoint, Some(key_id), Some(secret_key)) => {
240                    info!(key_id, "using static credentials from env vars");
241                    if let Some(endpoint) = endpoint {
242                        builder = builder.with_endpoint(endpoint);
243                    }
244                    builder = builder.with_credentials(Arc::new(
245                        object_store::StaticCredentialProvider::new(
246                            object_store::aws::AwsCredential {
247                                key_id,
248                                secret_key,
249                                token: None,
250                            },
251                        ),
252                    ));
253                }
254                _ => {
255                    let aws_config =
256                        aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
257                    if let Some(region) = aws_config.region() {
258                        info!(region = region.as_ref());
259                        builder = builder.with_region(region.to_string());
260                    }
261                    if let Some(credentials_provider) = aws_config.credentials_provider() {
262                        info!("using aws-config credentials provider");
263                        builder = builder.with_credentials(Arc::new(S3CredentialProvider {
264                            aws: credentials_provider.clone(),
265                            cache: tokio::sync::Mutex::new(None),
266                        }));
267                    }
268                }
269            }
270            Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
271        }
272        StoreType::LocalFileSystem(local_root) => {
273            std::fs::create_dir_all(local_root)?;
274            info!(
275                root = %local_root.display(),
276                "using local filesystem object store"
277            );
278            Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
279                local_root,
280            )?)
281        }
282        StoreType::InMemory => {
283            info!("using in-memory object store");
284            Arc::new(object_store::memory::InMemory::new())
285        }
286    })
287}
288
289async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
290    let ctrl_c = async {
291        tokio::signal::ctrl_c().await.expect("ctrl-c");
292    };
293
294    #[cfg(unix)]
295    let term = async {
296        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
297            .expect("SIGTERM")
298            .recv()
299            .await;
300    };
301
302    #[cfg(not(unix))]
303    let term = std::future::pending::<()>();
304
305    tokio::select! {
306        _ = ctrl_c => {
307            info!("received Ctrl+C, starting graceful shutdown");
308        },
309        _ = term => {
310            info!("received SIGTERM, starting graceful shutdown");
311        },
312    }
313
314    handle.graceful_shutdown(Some(Duration::from_secs(10)));
315}
316
317#[derive(Debug)]
318struct CachedCredential {
319    credential: Arc<object_store::aws::AwsCredential>,
320    expiry: Option<SystemTime>,
321}
322
323impl CachedCredential {
324    fn is_valid(&self) -> bool {
325        self.expiry
326            .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
327    }
328}
329
330#[derive(Debug)]
331struct S3CredentialProvider {
332    aws: aws_credential_types::provider::SharedCredentialsProvider,
333    cache: tokio::sync::Mutex<Option<CachedCredential>>,
334}
335
336#[async_trait::async_trait]
337impl object_store::CredentialProvider for S3CredentialProvider {
338    type Credential = object_store::aws::AwsCredential;
339
340    async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
341        let mut cached = self.cache.lock().await;
342        if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
343            return Ok(cached.credential.clone());
344        }
345
346        use aws_credential_types::provider::ProvideCredentials as _;
347
348        let start = Instant::now();
349        let creds =
350            self.aws
351                .provide_credentials()
352                .await
353                .map_err(|e| object_store::Error::Generic {
354                    store: "S3",
355                    source: Box::new(e),
356                })?;
357        info!(
358            key_id = creds.access_key_id(),
359            expiry_s = creds
360                .expiry()
361                .and_then(|t| t.duration_since(SystemTime::now()).ok())
362                .map(|d| d.as_secs()),
363            elapsed_ms = start.elapsed().as_millis(),
364            "fetched credentials"
365        );
366        let credential = Arc::new(object_store::aws::AwsCredential {
367            key_id: creds.access_key_id().to_owned(),
368            secret_key: creds.secret_access_key().to_owned(),
369            token: creds.session_token().map(|s| s.to_owned()),
370        });
371        *cached = Some(CachedCredential {
372            credential: credential.clone(),
373            expiry: creds.expiry(),
374        });
375        Ok(credential)
376    }
377}