Skip to main content

s2_lite/
server.rs

1use std::{
2    net::SocketAddr,
3    path::PathBuf,
4    sync::Arc,
5    time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use slatedb::object_store;
11use tokio::time::Instant;
12use tower_http::trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer};
13use tracing::info;
14
15use crate::{backend::Backend, handlers};
16
17#[derive(clap::Args, Debug, Clone)]
18pub struct TlsConfig {
19    /// Use a self-signed certificate for TLS
20    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
21    pub tls_self: bool,
22
23    /// Path to the TLS certificate file (e.g., cert.pem)
24    /// Must be used together with --tls-key
25    #[arg(long, requires = "tls_key")]
26    pub tls_cert: Option<PathBuf>,
27
28    /// Path to the private key file (e.g., key.pem)
29    /// Must be used together with --tls-cert
30    #[arg(long, requires = "tls_cert")]
31    pub tls_key: Option<PathBuf>,
32}
33
34#[derive(clap::Args, Debug, Clone)]
35pub struct LiteArgs {
36    /// Name of the S3 bucket to back the database.
37    ///
38    /// If not specified, in-memory storage is used unless --local-root is set.
39    #[arg(long)]
40    pub bucket: Option<String>,
41
42    /// Root directory to back the database on the local filesystem.
43    ///
44    /// Conflicts with --bucket.
45    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
46    pub local_root: Option<PathBuf>,
47
48    /// Base path on object storage.
49    #[arg(long, default_value = "")]
50    pub path: String,
51
52    /// TLS configuration (defaults to plain HTTP if not specified).
53    #[command(flatten)]
54    pub tls: TlsConfig,
55
56    /// Port to listen on [default: 443 if HTTPS configured, otherwise 80 for HTTP]
57    #[arg(long)]
58    pub port: Option<u16>,
59}
60
/// Backing storage chosen for the database.
#[derive(Debug, Clone)]
enum StoreType {
    /// An S3 bucket, identified by its name.
    S3Bucket(String),
    /// A root directory on the local filesystem.
    LocalFileSystem(PathBuf),
    /// An in-process, in-memory object store.
    InMemory,
}

impl StoreType {
    /// Default `flush_interval` to seed the database settings with for this store:
    /// 50 ms for S3, 5 ms for local-filesystem and in-memory stores.
    ///
    /// NOTE(review): the longer S3 interval presumably trades latency for fewer, larger
    /// object writes — confirm with the original author.
    fn default_flush_interval(&self) -> Duration {
        match self {
            Self::S3Bucket(_) => Duration::from_millis(50),
            Self::LocalFileSystem(_) | Self::InMemory => Duration::from_millis(5),
        }
    }
}
76
77pub async fn run(args: LiteArgs) -> eyre::Result<()> {
78    info!(?args);
79
80    let addr = {
81        let port = args.port.unwrap_or_else(|| {
82            if args.tls.tls_self || args.tls.tls_cert.is_some() {
83                443
84            } else {
85                80
86            }
87        });
88        format!("0.0.0.0:{port}")
89    };
90
91    let store_type = if let Some(bucket) = args.bucket {
92        StoreType::S3Bucket(bucket)
93    } else if let Some(local_root) = args.local_root {
94        StoreType::LocalFileSystem(local_root)
95    } else {
96        StoreType::InMemory
97    };
98
99    let object_store = init_object_store(&store_type).await?;
100
101    let db_settings = slatedb::Settings::from_env_with_default(
102        "SL8_",
103        slatedb::Settings {
104            flush_interval: Some(store_type.default_flush_interval()),
105            ..Default::default()
106        },
107    )?;
108
109    let manifest_poll_interval = db_settings.manifest_poll_interval;
110
111    let append_inflight_max = if std::env::var("S2LITE_PIPELINE")
112        .is_ok_and(|v| v.eq_ignore_ascii_case("true") || v == "1")
113    {
114        info!("pipelining enabled on append sessions up to 25MiB");
115        ByteSize::mib(25)
116    } else {
117        info!("pipelining disabled");
118        ByteSize::b(1)
119    };
120
121    let db = slatedb::Db::builder(args.path, object_store)
122        .with_settings(db_settings)
123        .build()
124        .await?;
125
126    info!(
127        ?manifest_poll_interval,
128        "sleeping to ensure prior instance fenced out"
129    );
130
131    tokio::time::sleep(manifest_poll_interval).await;
132
133    let backend = Backend::new(db, append_inflight_max);
134    crate::backend::bgtasks::spawn(&backend);
135
136    let app = handlers::router().with_state(backend).layer(
137        TraceLayer::new_for_http()
138            .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
139            .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
140            .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
141    );
142
143    let server_handle = axum_server::Handle::new();
144    tokio::spawn(shutdown_signal(server_handle.clone()));
145    match (
146        args.tls.tls_self,
147        args.tls.tls_cert.clone(),
148        args.tls.tls_key.clone(),
149    ) {
150        (false, Some(cert_path), Some(key_path)) => {
151            info!(
152                addr,
153                ?cert_path,
154                "starting https server with provided certificate"
155            );
156            let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
157            axum_server::bind_rustls(addr.parse()?, rustls_config)
158                .handle(server_handle)
159                .serve(app.into_make_service())
160                .await?;
161        }
162        (true, None, None) => {
163            info!(
164                addr,
165                "starting https server with self-signed certificate, clients will need to use --insecure"
166            );
167            let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
168                "localhost".to_string(),
169                "127.0.0.1".to_string(),
170                "::1".to_string(),
171            ])?;
172            let rustls_config = RustlsConfig::from_pem(
173                cert.pem().into_bytes(),
174                signing_key.serialize_pem().into_bytes(),
175            )
176            .await?;
177            axum_server::bind_rustls(addr.parse()?, rustls_config)
178                .handle(server_handle)
179                .serve(app.into_make_service())
180                .await?;
181        }
182        (false, None, None) => {
183            info!(addr, "starting plain http server");
184            axum_server::bind(addr.parse()?)
185                .handle(server_handle)
186                .serve(app.into_make_service())
187                .await?;
188        }
189        _ => {
190            // This shouldn't happen due to clap validation...
191            return Err(eyre::eyre!("Invalid TLS configuration"));
192        }
193    }
194
195    Ok(())
196}
197
198async fn init_object_store(
199    store_type: &StoreType,
200) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
201    Ok(match store_type {
202        StoreType::S3Bucket(bucket) => {
203            info!(bucket, "using s3 object store");
204            let mut builder =
205                object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
206            match (
207                std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
208                std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
209                std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
210            ) {
211                (endpoint, Some(key_id), Some(secret_key)) => {
212                    info!(key_id, "using static credentials from env vars");
213                    if let Some(endpoint) = endpoint {
214                        builder = builder.with_endpoint(endpoint);
215                    }
216                    builder = builder.with_credentials(Arc::new(
217                        object_store::StaticCredentialProvider::new(
218                            object_store::aws::AwsCredential {
219                                key_id,
220                                secret_key,
221                                token: None,
222                            },
223                        ),
224                    ));
225                }
226                _ => {
227                    let aws_config =
228                        aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
229                    if let Some(region) = aws_config.region() {
230                        info!(region = region.as_ref());
231                        builder = builder.with_region(region.to_string());
232                    }
233                    if let Some(credentials_provider) = aws_config.credentials_provider() {
234                        info!("using aws-config credentials provider");
235                        builder = builder.with_credentials(Arc::new(S3CredentialProvider {
236                            aws: credentials_provider.clone(),
237                            cache: tokio::sync::Mutex::new(None),
238                        }));
239                    }
240                }
241            }
242            Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
243        }
244        StoreType::LocalFileSystem(local_root) => {
245            std::fs::create_dir_all(local_root)?;
246            info!(
247                root = %local_root.display(),
248                "using local filesystem object store"
249            );
250            Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
251                local_root,
252            )?)
253        }
254        StoreType::InMemory => {
255            info!("using in-memory object store");
256            Arc::new(object_store::memory::InMemory::new())
257        }
258    })
259}
260
261async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
262    let ctrl_c = async {
263        tokio::signal::ctrl_c().await.expect("ctrl-c");
264    };
265
266    #[cfg(unix)]
267    let term = async {
268        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
269            .expect("SIGTERM")
270            .recv()
271            .await;
272    };
273
274    #[cfg(not(unix))]
275    let term = std::future::pending::<()>();
276
277    tokio::select! {
278        _ = ctrl_c => {
279            info!("received Ctrl+C, starting graceful shutdown");
280        },
281        _ = term => {
282            info!("received SIGTERM, starting graceful shutdown");
283        },
284    }
285
286    handle.graceful_shutdown(Some(Duration::from_secs(10)));
287}
288
289#[derive(Debug)]
290struct CachedCredential {
291    credential: Arc<object_store::aws::AwsCredential>,
292    expiry: Option<SystemTime>,
293}
294
295impl CachedCredential {
296    fn is_valid(&self) -> bool {
297        self.expiry
298            .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
299    }
300}
301
302#[derive(Debug)]
303struct S3CredentialProvider {
304    aws: aws_credential_types::provider::SharedCredentialsProvider,
305    cache: tokio::sync::Mutex<Option<CachedCredential>>,
306}
307
308#[async_trait::async_trait]
309impl object_store::CredentialProvider for S3CredentialProvider {
310    type Credential = object_store::aws::AwsCredential;
311
312    async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
313        let mut cached = self.cache.lock().await;
314        if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
315            return Ok(cached.credential.clone());
316        }
317
318        use aws_credential_types::provider::ProvideCredentials as _;
319
320        let start = Instant::now();
321        let creds =
322            self.aws
323                .provide_credentials()
324                .await
325                .map_err(|e| object_store::Error::Generic {
326                    store: "S3",
327                    source: Box::new(e),
328                })?;
329        info!(
330            key_id = creds.access_key_id(),
331            expiry_s = creds
332                .expiry()
333                .and_then(|t| t.duration_since(SystemTime::now()).ok())
334                .map(|d| d.as_secs()),
335            elapsed_ms = start.elapsed().as_millis(),
336            "fetched credentials"
337        );
338        let credential = Arc::new(object_store::aws::AwsCredential {
339            key_id: creds.access_key_id().to_owned(),
340            secret_key: creds.secret_access_key().to_owned(),
341            token: creds.session_token().map(|s| s.to_owned()),
342        });
343        *cached = Some(CachedCredential {
344            credential: credential.clone(),
345            expiry: creds.expiry(),
346        });
347        Ok(credential)
348    }
349}