1use std::{
2 net::SocketAddr,
3 path::PathBuf,
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use http::header::AUTHORIZATION;
11use s2_common::encryption::S2_ENCRYPTION_KEY_HEADER;
12use slatedb::object_store;
13use tokio::time::Instant;
14use tower_http::{
15 cors::CorsLayer,
16 sensitive_headers::SetSensitiveRequestHeadersLayer,
17 trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
18};
19use tracing::info;
20
21use crate::{backend::Backend, handlers, init};
22
// TLS options for the lite server. At most one mode can be chosen:
// `tls_self` conflicts with both `tls_cert` and `tls_key`, and `tls_cert` and
// `tls_key` must be given together. With none of them set, `run` serves plain HTTP.
// (Plain `//` comments on purpose: `///` here would become clap `--help` text.)
#[derive(clap::Args, Debug, Clone)]
pub struct TlsConfig {
    // Serve HTTPS with a certificate generated at startup for
    // localhost / 127.0.0.1 / ::1. Clients must skip certificate verification.
    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
    pub tls_self: bool,

    // Path to a PEM certificate file, loaded via `RustlsConfig::from_pem_file`.
    // Requires `tls_key`.
    #[arg(long, requires = "tls_key")]
    pub tls_cert: Option<PathBuf>,

    // Path to the PEM private key matching `tls_cert`. Requires `tls_cert`.
    #[arg(long, requires = "tls_cert")]
    pub tls_key: Option<PathBuf>,
}
39
// Command-line arguments for running the lite server (consumed by `run`).
// (Plain `//` comments on purpose: `///` here would become clap `--help` text.)
#[derive(clap::Args, Debug, Clone)]
pub struct LiteArgs {
    // S3 bucket to store data in. If neither `bucket` nor `local_root` is set,
    // data is kept in memory only.
    #[arg(long)]
    pub bucket: Option<String>,

    // Local directory to store data in (created if missing).
    // Cannot be combined with `bucket`.
    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
    pub local_root: Option<PathBuf>,

    // Path handed to `slatedb::Db::builder` for the database inside the object store.
    #[arg(long, default_value = "")]
    pub path: String,

    // TLS mode selection; see `TlsConfig`.
    #[command(flatten)]
    pub tls: TlsConfig,

    // Port to listen on (bound on 0.0.0.0).
    // Defaults to 80 for HTTP and 443 for HTTPS.
    #[arg(long)]
    pub port: Option<u16>,

    // Skip the permissive CORS layer that is otherwise added to the router.
    #[arg(long)]
    pub no_cors: bool,

    // Optional init file (also read from `S2LITE_INIT_FILE`). It is loaded and
    // applied to the backend at startup, before the server starts listening.
    #[arg(long, env = "S2LITE_INIT_FILE")]
    pub init_file: Option<PathBuf>,

    // Passed to `Backend::new`. Presumably caps the bytes of appends in flight
    // at once — TODO confirm against the backend.
    #[arg(long, default_value = "128MiB")]
    pub append_inflight_bytes: ByteSize,
}
87
/// Object store that SlateDB runs on, chosen from the CLI flags.
#[derive(Debug, Clone)]
enum StoreType {
    /// S3 (or S3-compatible) bucket, identified by name.
    S3Bucket(String),
    /// Directory on the local filesystem.
    LocalFileSystem(PathBuf),
    /// Process memory; nothing survives a restart.
    InMemory,
}

impl StoreType {
    /// Flush interval used unless overridden through `SL8_`-prefixed settings.
    ///
    /// S3 gets a longer interval (50ms) than the local and in-memory stores (5ms).
    /// This is presumably to reduce the number of remote object-store writes.
    fn default_flush_interval(&self) -> Duration {
        match self {
            Self::S3Bucket(_) => Duration::from_millis(50),
            Self::LocalFileSystem(_) | Self::InMemory => Duration::from_millis(5),
        }
    }
}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq)]
105enum ServerProtocol {
106 Http,
107 Https { self_signed: bool },
108}
109
110impl ServerProtocol {
111 fn from_args(args: &LiteArgs) -> Self {
112 if args.tls.tls_self {
113 Self::Https { self_signed: true }
114 } else if args.tls.tls_cert.is_some() {
115 Self::Https { self_signed: false }
116 } else {
117 Self::Http
118 }
119 }
120
121 fn scheme(self) -> &'static str {
122 match self {
123 Self::Http => "http",
124 Self::Https { .. } => "https",
125 }
126 }
127
128 fn default_port(self) -> u16 {
129 match self {
130 Self::Http => 80,
131 Self::Https { .. } => 443,
132 }
133 }
134
135 fn requires_ssl_no_verify(self) -> bool {
136 matches!(self, Self::Https { self_signed: true })
137 }
138}
139
140fn cli_endpoint(protocol: ServerProtocol, port: u16) -> String {
141 format!("{}://localhost:{port}", protocol.scheme())
142}
143
144fn cli_env_hint(protocol: ServerProtocol, port: u16) -> String {
145 let endpoint = cli_endpoint(protocol, port);
146 let mut lines = vec![
147 "copy/paste into a new terminal to point the S2 CLI at this server:".to_string(),
148 format!("export S2_ACCOUNT_ENDPOINT={endpoint}"),
149 format!("export S2_BASIN_ENDPOINT={endpoint}"),
150 "export S2_ACCESS_TOKEN=ignored".to_string(),
151 ];
152
153 if protocol.requires_ssl_no_verify() {
154 lines.push("export S2_SSL_NO_VERIFY=1".to_string());
155 }
156
157 lines.join("\n")
158}
159
160pub async fn run(args: LiteArgs) -> eyre::Result<()> {
161 info!(?args);
162
163 let protocol = ServerProtocol::from_args(&args);
164 let port = args.port.unwrap_or_else(|| protocol.default_port());
165 let addr = format!("0.0.0.0:{port}");
166 let cli_hint = cli_env_hint(protocol, port);
167
168 let store_type = if let Some(bucket) = args.bucket {
169 StoreType::S3Bucket(bucket)
170 } else if let Some(local_root) = args.local_root {
171 StoreType::LocalFileSystem(local_root)
172 } else {
173 StoreType::InMemory
174 };
175
176 let object_store = init_object_store(&store_type).await?;
177
178 let db_settings = slatedb::Settings::from_env_with_default(
179 "SL8_",
180 slatedb::Settings {
181 flush_interval: Some(store_type.default_flush_interval()),
182 ..Default::default()
183 },
184 )?;
185
186 let manifest_poll_interval = db_settings.manifest_poll_interval;
187
188 let db = slatedb::Db::builder(args.path, object_store)
189 .with_settings(db_settings)
190 .build()
191 .await?;
192
193 info!(
194 ?manifest_poll_interval,
195 "sleeping to ensure prior instance fenced out"
196 );
197
198 tokio::time::sleep(manifest_poll_interval).await;
199
200 info!(%args.append_inflight_bytes, "starting backend");
201 let backend = Backend::new(db, args.append_inflight_bytes);
202 crate::backend::bgtasks::spawn(&backend);
203
204 if let Some(init_file) = &args.init_file {
205 let spec = init::load(init_file)?;
206 init::apply(&backend, spec).await?;
207 }
208
209 let mut app = handlers::router()
210 .with_state(backend)
211 .layer(
212 TraceLayer::new_for_http()
213 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
214 .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
215 .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
216 )
217 .layer(SetSensitiveRequestHeadersLayer::new([
218 AUTHORIZATION,
219 S2_ENCRYPTION_KEY_HEADER.clone(),
220 ]));
221
222 if !args.no_cors {
223 app = app.layer(CorsLayer::very_permissive());
224 }
225
226 let server_handle = axum_server::Handle::new();
227 tokio::spawn(shutdown_signal(server_handle.clone()));
228 match (
229 args.tls.tls_self,
230 args.tls.tls_cert.clone(),
231 args.tls.tls_key.clone(),
232 ) {
233 (false, Some(cert_path), Some(key_path)) => {
234 info!(
235 addr,
236 ?cert_path,
237 "starting https server with provided certificate"
238 );
239 let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
240 info!("{}", cli_hint);
241 axum_server::bind_rustls(addr.parse()?, rustls_config)
242 .handle(server_handle)
243 .serve(app.into_make_service())
244 .await?;
245 }
246 (true, None, None) => {
247 info!(
248 addr,
249 "starting https server with self-signed certificate, clients will need to use --insecure"
250 );
251 let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
252 "localhost".to_string(),
253 "127.0.0.1".to_string(),
254 "::1".to_string(),
255 ])?;
256 let rustls_config = RustlsConfig::from_pem(
257 cert.pem().into_bytes(),
258 signing_key.serialize_pem().into_bytes(),
259 )
260 .await?;
261 info!("{}", cli_hint);
262 axum_server::bind_rustls(addr.parse()?, rustls_config)
263 .handle(server_handle)
264 .serve(app.into_make_service())
265 .await?;
266 }
267 (false, None, None) => {
268 info!(addr, "starting plain http server");
269 info!("{}", cli_hint);
270 axum_server::bind(addr.parse()?)
271 .handle(server_handle)
272 .serve(app.into_make_service())
273 .await?;
274 }
275 _ => {
276 return Err(eyre::eyre!("Invalid TLS configuration"));
278 }
279 }
280
281 Ok(())
282}
283
284async fn init_object_store(
285 store_type: &StoreType,
286) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
287 Ok(match store_type {
288 StoreType::S3Bucket(bucket) => {
289 info!(bucket, "using s3 object store");
290 let mut builder =
291 object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
292 match (
293 std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
294 std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
295 std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
296 ) {
297 (endpoint, Some(key_id), Some(secret_key)) => {
298 info!(endpoint, key_id, "using static credentials from env vars");
299
300 if let Some(endpoint) = endpoint {
301 if endpoint.starts_with("http://") {
302 builder = builder.with_allow_http(true);
303 }
304 builder = builder.with_endpoint(endpoint);
305 }
306
307 builder = builder.with_credentials(Arc::new(
308 object_store::StaticCredentialProvider::new(
309 object_store::aws::AwsCredential {
310 key_id,
311 secret_key,
312 token: None,
313 },
314 ),
315 ));
316 }
317 _ => {
318 let aws_config =
319 aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
320 if let Some(region) = aws_config.region() {
321 info!(region = region.as_ref());
322 builder = builder.with_region(region.to_string());
323 }
324 if let Some(credentials_provider) = aws_config.credentials_provider() {
325 info!("using aws-config credentials provider");
326 builder = builder.with_credentials(Arc::new(S3CredentialProvider {
327 aws: credentials_provider.clone(),
328 cache: tokio::sync::Mutex::new(None),
329 }));
330 }
331 }
332 }
333 Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
334 }
335 StoreType::LocalFileSystem(local_root) => {
336 std::fs::create_dir_all(local_root)?;
337 info!(
338 root = %local_root.display(),
339 "using local filesystem object store"
340 );
341 Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
342 local_root,
343 )?)
344 }
345 StoreType::InMemory => {
346 info!("using in-memory object store");
347 Arc::new(object_store::memory::InMemory::new())
348 }
349 })
350}
351
352async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
353 let ctrl_c = async {
354 tokio::signal::ctrl_c().await.expect("ctrl-c");
355 };
356
357 #[cfg(unix)]
358 let term = async {
359 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
360 .expect("SIGTERM")
361 .recv()
362 .await;
363 };
364
365 #[cfg(not(unix))]
366 let term = std::future::pending::<()>();
367
368 tokio::select! {
369 _ = ctrl_c => {
370 info!("received Ctrl+C, starting graceful shutdown");
371 },
372 _ = term => {
373 info!("received SIGTERM, starting graceful shutdown");
374 },
375 }
376
377 handle.graceful_shutdown(Some(Duration::from_secs(10)));
378}
379
/// A credential fetched from the AWS provider, cached together with its expiry.
#[derive(Debug)]
struct CachedCredential {
    /// Credential in the form `object_store` consumes; handed out to callers via `Arc`.
    credential: Arc<object_store::aws::AwsCredential>,
    /// When the credential expires, if the provider reported an expiry.
    /// `None` is treated as never expiring (see `is_valid`).
    expiry: Option<SystemTime>,
}
385
386impl CachedCredential {
387 fn is_valid(&self) -> bool {
388 self.expiry
389 .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
390 }
391}
392
/// Adapts an AWS SDK credentials provider (from `aws-config`) to
/// `object_store::CredentialProvider`.
///
/// The last credential is cached until it is within its refresh margin.
#[derive(Debug)]
struct S3CredentialProvider {
    /// Underlying AWS SDK credentials provider (taken from `aws_config::load_defaults`).
    aws: aws_credential_types::provider::SharedCredentialsProvider,
    /// Most recently fetched credential. This is an async mutex so the lock can be
    /// held across the refresh `await`, letting concurrent callers wait for one fetch.
    cache: tokio::sync::Mutex<Option<CachedCredential>>,
}
398
399#[async_trait::async_trait]
400impl object_store::CredentialProvider for S3CredentialProvider {
401 type Credential = object_store::aws::AwsCredential;
402
403 async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
404 let mut cached = self.cache.lock().await;
405 if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
406 return Ok(cached.credential.clone());
407 }
408
409 use aws_credential_types::provider::ProvideCredentials as _;
410
411 let start = Instant::now();
412 let creds =
413 self.aws
414 .provide_credentials()
415 .await
416 .map_err(|e| object_store::Error::Generic {
417 store: "S3",
418 source: Box::new(e),
419 })?;
420 info!(
421 key_id = creds.access_key_id(),
422 expiry_s = creds
423 .expiry()
424 .and_then(|t| t.duration_since(SystemTime::now()).ok())
425 .map(|d| d.as_secs()),
426 elapsed_ms = start.elapsed().as_millis(),
427 "fetched credentials"
428 );
429 let credential = Arc::new(object_store::aws::AwsCredential {
430 key_id: creds.access_key_id().to_owned(),
431 secret_key: creds.secret_access_key().to_owned(),
432 token: creds.session_token().map(|s| s.to_owned()),
433 });
434 *cached = Some(CachedCredential {
435 credential: credential.clone(),
436 expiry: creds.expiry(),
437 });
438 Ok(credential)
439 }
440}
441
// Unit tests for the pure helpers in this module. They cover:
// - the CLI hints and protocol defaults;
// - the store flush intervals;
// - the credential refresh margin.
#[cfg(test)]
mod tests {
    use std::{
        sync::Arc,
        time::{Duration, SystemTime},
    };

    use super::{
        CachedCredential, ServerProtocol, StoreType, cli_endpoint, cli_env_hint, object_store,
    };

    #[test]
    fn cli_endpoint_uses_localhost_with_explicit_port() {
        assert_eq!(
            cli_endpoint(ServerProtocol::Http, 80),
            "http://localhost:80"
        );
        assert_eq!(
            cli_endpoint(ServerProtocol::Https { self_signed: false }, 443),
            "https://localhost:443"
        );
    }

    #[test]
    fn cli_env_hint_includes_exports_for_http() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Http, 8080),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=http://localhost:8080\n",
                "export S2_BASIN_ENDPOINT=http://localhost:8080\n",
                "export S2_ACCESS_TOKEN=ignored",
            )
        );
    }

    #[test]
    fn cli_env_hint_includes_ssl_no_verify_for_self_signed_tls() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Https { self_signed: true }, 8443),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=https://localhost:8443\n",
                "export S2_BASIN_ENDPOINT=https://localhost:8443\n",
                "export S2_ACCESS_TOKEN=ignored\n",
                "export S2_SSL_NO_VERIFY=1",
            )
        );
    }

    // A provided (verifiable) certificate must not tell clients to skip verification.
    #[test]
    fn cli_env_hint_omits_ssl_no_verify_for_provided_certificate() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Https { self_signed: false }, 443),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=https://localhost:443\n",
                "export S2_BASIN_ENDPOINT=https://localhost:443\n",
                "export S2_ACCESS_TOKEN=ignored",
            )
        );
    }

    #[test]
    fn default_port_follows_scheme() {
        assert_eq!(ServerProtocol::Http.default_port(), 80);
        assert_eq!(ServerProtocol::Https { self_signed: false }.default_port(), 443);
        assert_eq!(ServerProtocol::Https { self_signed: true }.default_port(), 443);
    }

    #[test]
    fn default_flush_interval_is_longer_for_s3() {
        assert_eq!(
            StoreType::S3Bucket("bucket".to_string()).default_flush_interval(),
            Duration::from_millis(50)
        );
        assert_eq!(
            StoreType::LocalFileSystem("/tmp/s2lite".into()).default_flush_interval(),
            Duration::from_millis(5)
        );
        assert_eq!(
            StoreType::InMemory.default_flush_interval(),
            Duration::from_millis(5)
        );
    }

    #[test]
    fn cached_credential_is_refreshed_before_expiry() {
        let cached = |expiry: Option<SystemTime>| CachedCredential {
            credential: Arc::new(object_store::aws::AwsCredential {
                key_id: "key".to_string(),
                secret_key: "secret".to_string(),
                token: None,
            }),
            expiry,
        };
        let now = SystemTime::now();
        assert!(cached(None).is_valid());
        assert!(cached(Some(now + Duration::from_secs(3600))).is_valid());
        assert!(!cached(Some(now + Duration::from_secs(30))).is_valid());
        assert!(!cached(Some(now - Duration::from_secs(1))).is_valid());
    }
}