Skip to main content

s2_lite/
server.rs

1use std::{
2    net::SocketAddr,
3    path::PathBuf,
4    sync::Arc,
5    time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use slatedb::object_store;
11use tokio::time::Instant;
12use tower_http::trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer};
13use tracing::info;
14
15use crate::{backend::Backend, handlers};
16
17#[derive(clap::Args, Debug, Clone)]
18pub struct TlsConfig {
19    /// Use a self-signed certificate for TLS
20    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
21    pub tls_self: bool,
22
23    /// Path to the TLS certificate file (e.g., cert.pem)
24    /// Must be used together with --tls-key
25    #[arg(long, requires = "tls_key")]
26    pub tls_cert: Option<PathBuf>,
27
28    /// Path to the private key file (e.g., key.pem)
29    /// Must be used together with --tls-cert
30    #[arg(long, requires = "tls_cert")]
31    pub tls_key: Option<PathBuf>,
32}
33
34#[derive(clap::Args, Debug, Clone)]
35pub struct LiteArgs {
36    /// Name of the S3 bucket to back the database.
37    ///
38    /// If not specified, in-memory storage is used unless --local-root is set.
39    #[arg(long)]
40    pub bucket: Option<String>,
41
42    /// Root directory to back the database on the local filesystem.
43    ///
44    /// Conflicts with --bucket.
45    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
46    pub local_root: Option<PathBuf>,
47
48    /// Base path on object storage.
49    #[arg(long, default_value = "")]
50    pub path: String,
51
52    /// TLS configuration (defaults to plain HTTP if not specified).
53    #[command(flatten)]
54    pub tls: TlsConfig,
55
56    /// Port to listen on [default: 443 if HTTPS configured, otherwise 80 for HTTP]
57    #[arg(long)]
58    pub port: Option<u16>,
59}
60
/// Which backing object store the database lives on, as selected from the CLI flags.
#[derive(Debug, Clone)]
enum StoreType {
    /// An S3 bucket, identified by name.
    S3Bucket(String),
    /// A root directory on the local filesystem.
    LocalFileSystem(PathBuf),
    /// An in-process store; nothing is persisted.
    InMemory,
}

impl StoreType {
    /// The SlateDB flush interval used unless overridden by settings from the environment:
    /// 50 ms for S3, 5 ms for the local-filesystem and in-memory stores.
    fn default_flush_interval(&self) -> Duration {
        match self {
            Self::S3Bucket(_) => Duration::from_millis(50),
            Self::LocalFileSystem(_) | Self::InMemory => Duration::from_millis(5),
        }
    }
}
76
77pub async fn run(args: LiteArgs) -> eyre::Result<()> {
78    info!(?args);
79
80    let addr = {
81        let port = args.port.unwrap_or_else(|| {
82            if args.tls.tls_self || args.tls.tls_cert.is_some() {
83                443
84            } else {
85                80
86            }
87        });
88        format!("0.0.0.0:{port}")
89    };
90
91    let store_type = if let Some(bucket) = args.bucket {
92        StoreType::S3Bucket(bucket)
93    } else if let Some(local_root) = args.local_root {
94        StoreType::LocalFileSystem(local_root)
95    } else {
96        StoreType::InMemory
97    };
98
99    let object_store = init_object_store(&store_type).await?;
100
101    let db_settings = slatedb::Settings::from_env_with_default(
102        "SL8_",
103        slatedb::Settings {
104            flush_interval: Some(store_type.default_flush_interval()),
105            ..Default::default()
106        },
107    )?;
108
109    let manifest_poll_interval = db_settings.manifest_poll_interval;
110
111    let append_inflight_max = if std::env::var("S2LITE_PIPELINE")
112        .is_ok_and(|v| v.eq_ignore_ascii_case("true") || v == "1")
113    {
114        info!("pipelining enabled on append sessions up to 25MiB");
115        ByteSize::mib(25)
116    } else {
117        info!("pipelining disabled");
118        ByteSize::b(1)
119    };
120
121    let db = slatedb::Db::builder(args.path, object_store)
122        .with_settings(db_settings)
123        .build()
124        .await?;
125
126    info!(
127        ?manifest_poll_interval,
128        "sleeping to ensure prior instance fenced out"
129    );
130
131    tokio::time::sleep(manifest_poll_interval).await;
132
133    let app = handlers::router()
134        .with_state(Backend::new(db, append_inflight_max))
135        .layer(
136            TraceLayer::new_for_http()
137                .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
138                .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
139                .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
140        );
141
142    let server_handle = axum_server::Handle::new();
143    tokio::spawn(shutdown_signal(server_handle.clone()));
144    match (
145        args.tls.tls_self,
146        args.tls.tls_cert.clone(),
147        args.tls.tls_key.clone(),
148    ) {
149        (false, Some(cert_path), Some(key_path)) => {
150            info!(
151                addr,
152                ?cert_path,
153                "starting https server with provided certificate"
154            );
155            let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
156            axum_server::bind_rustls(addr.parse()?, rustls_config)
157                .handle(server_handle)
158                .serve(app.into_make_service())
159                .await?;
160        }
161        (true, None, None) => {
162            info!(
163                addr,
164                "starting https server with self-signed certificate, clients will need to use --insecure"
165            );
166            let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
167                "localhost".to_string(),
168                "127.0.0.1".to_string(),
169                "::1".to_string(),
170            ])?;
171            let rustls_config = RustlsConfig::from_pem(
172                cert.pem().into_bytes(),
173                signing_key.serialize_pem().into_bytes(),
174            )
175            .await?;
176            axum_server::bind_rustls(addr.parse()?, rustls_config)
177                .handle(server_handle)
178                .serve(app.into_make_service())
179                .await?;
180        }
181        (false, None, None) => {
182            info!(addr, "starting plain http server");
183            axum_server::bind(addr.parse()?)
184                .handle(server_handle)
185                .serve(app.into_make_service())
186                .await?;
187        }
188        _ => {
189            // This shouldn't happen due to clap validation...
190            return Err(eyre::eyre!("Invalid TLS configuration"));
191        }
192    }
193
194    Ok(())
195}
196
197async fn init_object_store(
198    store_type: &StoreType,
199) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
200    Ok(match store_type {
201        StoreType::S3Bucket(bucket) => {
202            info!(bucket, "using s3 object store");
203            let mut builder =
204                object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
205            match (
206                std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
207                std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
208                std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
209            ) {
210                (endpoint, Some(key_id), Some(secret_key)) => {
211                    info!(key_id, "using static credentials from env vars");
212                    if let Some(endpoint) = endpoint {
213                        builder = builder.with_endpoint(endpoint);
214                    }
215                    builder = builder.with_credentials(Arc::new(
216                        object_store::StaticCredentialProvider::new(
217                            object_store::aws::AwsCredential {
218                                key_id,
219                                secret_key,
220                                token: None,
221                            },
222                        ),
223                    ));
224                }
225                _ => {
226                    let aws_config =
227                        aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
228                    if let Some(region) = aws_config.region() {
229                        info!(region = region.as_ref());
230                        builder = builder.with_region(region.to_string());
231                    }
232                    if let Some(credentials_provider) = aws_config.credentials_provider() {
233                        info!("using aws-config credentials provider");
234                        builder = builder.with_credentials(Arc::new(S3CredentialProvider {
235                            aws: credentials_provider.clone(),
236                            cache: tokio::sync::Mutex::new(None),
237                        }));
238                    }
239                }
240            }
241            Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
242        }
243        StoreType::LocalFileSystem(local_root) => {
244            std::fs::create_dir_all(local_root)?;
245            info!(
246                root = %local_root.display(),
247                "using local filesystem object store"
248            );
249            Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
250                local_root,
251            )?)
252        }
253        StoreType::InMemory => {
254            info!("using in-memory object store");
255            Arc::new(object_store::memory::InMemory::new())
256        }
257    })
258}
259
260async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
261    let ctrl_c = async {
262        tokio::signal::ctrl_c().await.expect("ctrl-c");
263    };
264
265    #[cfg(unix)]
266    let term = async {
267        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
268            .expect("SIGTERM")
269            .recv()
270            .await;
271    };
272
273    #[cfg(not(unix))]
274    let term = std::future::pending::<()>();
275
276    tokio::select! {
277        _ = ctrl_c => {
278            info!("received Ctrl+C, starting graceful shutdown");
279        },
280        _ = term => {
281            info!("received SIGTERM, starting graceful shutdown");
282        },
283    }
284
285    handle.graceful_shutdown(Some(Duration::from_secs(10)));
286}
287
288#[derive(Debug)]
289struct CachedCredential {
290    credential: Arc<object_store::aws::AwsCredential>,
291    expiry: Option<SystemTime>,
292}
293
294impl CachedCredential {
295    fn is_valid(&self) -> bool {
296        self.expiry
297            .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
298    }
299}
300
301#[derive(Debug)]
302struct S3CredentialProvider {
303    aws: aws_credential_types::provider::SharedCredentialsProvider,
304    cache: tokio::sync::Mutex<Option<CachedCredential>>,
305}
306
307#[async_trait::async_trait]
308impl object_store::CredentialProvider for S3CredentialProvider {
309    type Credential = object_store::aws::AwsCredential;
310
311    async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
312        let mut cached = self.cache.lock().await;
313        if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
314            return Ok(cached.credential.clone());
315        }
316
317        use aws_credential_types::provider::ProvideCredentials as _;
318
319        let start = Instant::now();
320        let creds =
321            self.aws
322                .provide_credentials()
323                .await
324                .map_err(|e| object_store::Error::Generic {
325                    store: "S3",
326                    source: Box::new(e),
327                })?;
328        info!(
329            key_id = creds.access_key_id(),
330            expiry_s = creds
331                .expiry()
332                .and_then(|t| t.duration_since(SystemTime::now()).ok())
333                .map(|d| d.as_secs()),
334            elapsed_ms = start.elapsed().as_millis(),
335            "fetched credentials"
336        );
337        let credential = Arc::new(object_store::aws::AwsCredential {
338            key_id: creds.access_key_id().to_owned(),
339            secret_key: creds.secret_access_key().to_owned(),
340            token: creds.session_token().map(|s| s.to_owned()),
341        });
342        *cached = Some(CachedCredential {
343            credential: credential.clone(),
344            expiry: creds.expiry(),
345        });
346        Ok(credential)
347    }
348}