1use std::{
2 net::SocketAddr,
3 path::PathBuf,
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use http::header::AUTHORIZATION;
11use s2_common::encryption::S2_ENCRYPTION_KEY_HEADER;
12use slatedb::object_store;
13use tokio::time::Instant;
14use tower_http::{
15 cors::CorsLayer,
16 sensitive_headers::SetSensitiveRequestHeadersLayer,
17 trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
18};
19use tracing::info;
20
21use crate::{backend::Backend, handlers, init};
22
23#[derive(clap::Args, Debug, Clone)]
24pub struct TlsConfig {
25 #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
27 pub tls_self: bool,
28
29 #[arg(long, requires = "tls_key")]
32 pub tls_cert: Option<PathBuf>,
33
34 #[arg(long, requires = "tls_cert")]
37 pub tls_key: Option<PathBuf>,
38}
39
40#[derive(clap::Args, Debug, Clone)]
41pub struct LiteArgs {
42 #[arg(long)]
46 pub bucket: Option<String>,
47
48 #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
52 pub local_root: Option<PathBuf>,
53
54 #[arg(long, default_value = "")]
56 pub path: String,
57
58 #[command(flatten)]
60 pub tls: TlsConfig,
61
62 #[arg(long)]
64 pub port: Option<u16>,
65
66 #[arg(long)]
73 pub no_cors: bool,
74
75 #[arg(long, env = "S2LITE_INIT_FILE")]
80 pub init_file: Option<PathBuf>,
81
82 #[arg(long, default_value = "128MiB")]
84 pub append_inflight_bytes: ByteSize,
85}
86
/// Which object store backs the database, chosen from the CLI arguments.
#[derive(Clone, Debug)]
enum StoreType {
    /// An S3 bucket, identified by name.
    S3Bucket(String),
    /// A directory on the local filesystem used as the store root.
    LocalFileSystem(PathBuf),
    /// A store that lives only in this process's memory.
    InMemory,
}
93
94impl StoreType {
95 fn default_flush_interval(&self) -> Duration {
96 Duration::from_millis(match self {
97 StoreType::S3Bucket(_) => 50,
98 StoreType::LocalFileSystem(_) | StoreType::InMemory => 5,
99 })
100 }
101}
102
/// Wire protocol the server listens with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ServerProtocol {
    /// Plain, unencrypted HTTP.
    Http,
    /// HTTPS; `self_signed` is true when the certificate was generated at
    /// startup rather than supplied by the user.
    Https { self_signed: bool },
}
108
109impl ServerProtocol {
110 fn from_args(args: &LiteArgs) -> Self {
111 if args.tls.tls_self {
112 Self::Https { self_signed: true }
113 } else if args.tls.tls_cert.is_some() {
114 Self::Https { self_signed: false }
115 } else {
116 Self::Http
117 }
118 }
119
120 fn scheme(self) -> &'static str {
121 match self {
122 Self::Http => "http",
123 Self::Https { .. } => "https",
124 }
125 }
126
127 fn default_port(self) -> u16 {
128 match self {
129 Self::Http => 80,
130 Self::Https { .. } => 443,
131 }
132 }
133
134 fn requires_ssl_no_verify(self) -> bool {
135 matches!(self, Self::Https { self_signed: true })
136 }
137}
138
139fn cli_endpoint(protocol: ServerProtocol, port: u16) -> String {
140 format!("{}://localhost:{port}", protocol.scheme())
141}
142
143fn cli_env_hint(protocol: ServerProtocol, port: u16) -> String {
144 let endpoint = cli_endpoint(protocol, port);
145 let mut lines = vec![
146 "copy/paste into a new terminal to point the S2 CLI at this server:".to_string(),
147 format!("export S2_ACCOUNT_ENDPOINT={endpoint}"),
148 format!("export S2_BASIN_ENDPOINT={endpoint}"),
149 "export S2_ACCESS_TOKEN=ignored".to_string(),
150 ];
151
152 if protocol.requires_ssl_no_verify() {
153 lines.push("export S2_SSL_NO_VERIFY=1".to_string());
154 }
155
156 lines.join("\n")
157}
158
159pub async fn run(args: LiteArgs) -> eyre::Result<()> {
160 info!(?args);
161
162 let protocol = ServerProtocol::from_args(&args);
163 let port = args.port.unwrap_or_else(|| protocol.default_port());
164 let addr = format!("0.0.0.0:{port}");
165 let cli_hint = cli_env_hint(protocol, port);
166
167 let store_type = if let Some(bucket) = args.bucket {
168 StoreType::S3Bucket(bucket)
169 } else if let Some(local_root) = args.local_root {
170 StoreType::LocalFileSystem(local_root)
171 } else {
172 StoreType::InMemory
173 };
174
175 let object_store = init_object_store(&store_type).await?;
176
177 let db_settings = slatedb::Settings::from_env_with_default(
178 "SL8_",
179 slatedb::Settings {
180 flush_interval: Some(store_type.default_flush_interval()),
181 ..Default::default()
182 },
183 )?;
184
185 let manifest_poll_interval = db_settings.manifest_poll_interval;
186
187 let db = slatedb::Db::builder(args.path, object_store)
188 .with_settings(db_settings)
189 .build()
190 .await?;
191
192 info!(
193 ?manifest_poll_interval,
194 "sleeping to ensure prior instance fenced out"
195 );
196
197 tokio::time::sleep(manifest_poll_interval).await;
198
199 info!(%args.append_inflight_bytes, "starting backend");
200 let backend = Backend::new(db, args.append_inflight_bytes);
201 crate::backend::bgtasks::spawn(&backend);
202
203 if let Some(init_file) = &args.init_file {
204 let spec = init::load(init_file)?;
205 init::apply(&backend, spec).await?;
206 }
207
208 let mut app = handlers::router()
209 .with_state(backend)
210 .layer(
211 TraceLayer::new_for_http()
212 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
213 .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
214 .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
215 )
216 .layer(SetSensitiveRequestHeadersLayer::new([
217 AUTHORIZATION,
218 S2_ENCRYPTION_KEY_HEADER.clone(),
219 ]));
220
221 if !args.no_cors {
222 app = app.layer(CorsLayer::very_permissive());
223 }
224
225 let server_handle = axum_server::Handle::new();
226 tokio::spawn(shutdown_signal(server_handle.clone()));
227 match (
228 args.tls.tls_self,
229 args.tls.tls_cert.clone(),
230 args.tls.tls_key.clone(),
231 ) {
232 (false, Some(cert_path), Some(key_path)) => {
233 info!(
234 addr,
235 ?cert_path,
236 "starting https server with provided certificate"
237 );
238 let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
239 info!("{}", cli_hint);
240 axum_server::bind_rustls(addr.parse()?, rustls_config)
241 .handle(server_handle)
242 .serve(app.into_make_service())
243 .await?;
244 }
245 (true, None, None) => {
246 info!(
247 addr,
248 "starting https server with self-signed certificate, clients will need to use --insecure"
249 );
250 let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
251 "localhost".to_string(),
252 "127.0.0.1".to_string(),
253 "::1".to_string(),
254 ])?;
255 let rustls_config = RustlsConfig::from_pem(
256 cert.pem().into_bytes(),
257 signing_key.serialize_pem().into_bytes(),
258 )
259 .await?;
260 info!("{}", cli_hint);
261 axum_server::bind_rustls(addr.parse()?, rustls_config)
262 .handle(server_handle)
263 .serve(app.into_make_service())
264 .await?;
265 }
266 (false, None, None) => {
267 info!(addr, "starting plain http server");
268 info!("{}", cli_hint);
269 axum_server::bind(addr.parse()?)
270 .handle(server_handle)
271 .serve(app.into_make_service())
272 .await?;
273 }
274 _ => {
275 return Err(eyre::eyre!("Invalid TLS configuration"));
277 }
278 }
279
280 Ok(())
281}
282
283async fn init_object_store(
284 store_type: &StoreType,
285) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
286 Ok(match store_type {
287 StoreType::S3Bucket(bucket) => {
288 info!(bucket, "using s3 object store");
289 let mut builder =
290 object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
291 match (
292 std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
293 std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
294 std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
295 ) {
296 (endpoint, Some(key_id), Some(secret_key)) => {
297 info!(endpoint, key_id, "using static credentials from env vars");
298
299 if let Some(endpoint) = endpoint {
300 if endpoint.starts_with("http://") {
301 builder = builder.with_allow_http(true);
302 }
303 builder = builder.with_endpoint(endpoint);
304 }
305
306 builder = builder.with_credentials(Arc::new(
307 object_store::StaticCredentialProvider::new(
308 object_store::aws::AwsCredential {
309 key_id,
310 secret_key,
311 token: None,
312 },
313 ),
314 ));
315 }
316 _ => {
317 let aws_config =
318 aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
319 if let Some(region) = aws_config.region() {
320 info!(region = region.as_ref());
321 builder = builder.with_region(region.to_string());
322 }
323 if let Some(credentials_provider) = aws_config.credentials_provider() {
324 info!("using aws-config credentials provider");
325 builder = builder.with_credentials(Arc::new(S3CredentialProvider {
326 aws: credentials_provider.clone(),
327 cache: tokio::sync::Mutex::new(None),
328 }));
329 }
330 }
331 }
332 Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
333 }
334 StoreType::LocalFileSystem(local_root) => {
335 std::fs::create_dir_all(local_root)?;
336 info!(
337 root = %local_root.display(),
338 "using local filesystem object store"
339 );
340 Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
341 local_root,
342 )?)
343 }
344 StoreType::InMemory => {
345 info!("using in-memory object store");
346 Arc::new(object_store::memory::InMemory::new())
347 }
348 })
349}
350
351async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
352 let ctrl_c = async {
353 tokio::signal::ctrl_c().await.expect("ctrl-c");
354 };
355
356 #[cfg(unix)]
357 let term = async {
358 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
359 .expect("SIGTERM")
360 .recv()
361 .await;
362 };
363
364 #[cfg(not(unix))]
365 let term = std::future::pending::<()>();
366
367 tokio::select! {
368 _ = ctrl_c => {
369 info!("received Ctrl+C, starting graceful shutdown");
370 },
371 _ = term => {
372 info!("received SIGTERM, starting graceful shutdown");
373 },
374 }
375
376 handle.graceful_shutdown(Some(Duration::from_secs(10)));
377}
378
/// A credential obtained from the aws-config provider, remembered together
/// with its expiry. `expiry == None` means the provider reported no expiry.
#[derive(Debug)]
struct CachedCredential {
    /// Credential handed out to object_store; shared via `Arc` so cache hits
    /// do not copy the secret.
    credential: Arc<object_store::aws::AwsCredential>,
    /// Expiry reported by the provider, if any.
    expiry: Option<SystemTime>,
}
384
385impl CachedCredential {
386 fn is_valid(&self) -> bool {
387 self.expiry
388 .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
389 }
390}
391
/// Adapts an aws-config credentials provider to object_store's
/// `CredentialProvider`, caching the most recent credential until
/// `CachedCredential::is_valid` reports it stale.
#[derive(Debug)]
struct S3CredentialProvider {
    /// Upstream provider from `aws_config::load_defaults`.
    aws: aws_credential_types::provider::SharedCredentialsProvider,
    /// Last fetched credential. An async mutex, because the lock is held
    /// across the upstream fetch.
    cache: tokio::sync::Mutex<Option<CachedCredential>>,
}
397
398#[async_trait::async_trait]
399impl object_store::CredentialProvider for S3CredentialProvider {
400 type Credential = object_store::aws::AwsCredential;
401
402 async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
403 let mut cached = self.cache.lock().await;
404 if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
405 return Ok(cached.credential.clone());
406 }
407
408 use aws_credential_types::provider::ProvideCredentials as _;
409
410 let start = Instant::now();
411 let creds =
412 self.aws
413 .provide_credentials()
414 .await
415 .map_err(|e| object_store::Error::Generic {
416 store: "S3",
417 source: Box::new(e),
418 })?;
419 info!(
420 key_id = creds.access_key_id(),
421 expiry_s = creds
422 .expiry()
423 .and_then(|t| t.duration_since(SystemTime::now()).ok())
424 .map(|d| d.as_secs()),
425 elapsed_ms = start.elapsed().as_millis(),
426 "fetched credentials"
427 );
428 let credential = Arc::new(object_store::aws::AwsCredential {
429 key_id: creds.access_key_id().to_owned(),
430 secret_key: creds.secret_access_key().to_owned(),
431 token: creds.session_token().map(|s| s.to_owned()),
432 });
433 *cached = Some(CachedCredential {
434 credential: credential.clone(),
435 expiry: creds.expiry(),
436 });
437 Ok(credential)
438 }
439}
440
#[cfg(test)]
mod tests {
    //! Unit tests for the CLI hint helpers and protocol defaults.

    use super::{ServerProtocol, cli_endpoint, cli_env_hint};

    #[test]
    fn cli_endpoint_uses_localhost_with_explicit_port() {
        assert_eq!(
            cli_endpoint(ServerProtocol::Http, 80),
            "http://localhost:80"
        );
        assert_eq!(
            cli_endpoint(ServerProtocol::Https { self_signed: false }, 443),
            "https://localhost:443"
        );
    }

    #[test]
    fn protocol_defaults_match_scheme() {
        assert_eq!(ServerProtocol::Http.default_port(), 80);
        assert_eq!(ServerProtocol::Https { self_signed: false }.default_port(), 443);
        assert_eq!(ServerProtocol::Https { self_signed: true }.default_port(), 443);
        assert!(!ServerProtocol::Http.requires_ssl_no_verify());
        assert!(!ServerProtocol::Https { self_signed: false }.requires_ssl_no_verify());
        assert!(ServerProtocol::Https { self_signed: true }.requires_ssl_no_verify());
    }

    #[test]
    fn cli_env_hint_includes_exports_for_http() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Http, 8080),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=http://localhost:8080\n",
                "export S2_BASIN_ENDPOINT=http://localhost:8080\n",
                "export S2_ACCESS_TOKEN=ignored",
            )
        );
    }

    #[test]
    fn cli_env_hint_omits_ssl_no_verify_for_provided_certificate() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Https { self_signed: false }, 443),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=https://localhost:443\n",
                "export S2_BASIN_ENDPOINT=https://localhost:443\n",
                "export S2_ACCESS_TOKEN=ignored",
            )
        );
    }

    #[test]
    fn cli_env_hint_includes_ssl_no_verify_for_self_signed_tls() {
        assert_eq!(
            cli_env_hint(ServerProtocol::Https { self_signed: true }, 8443),
            concat!(
                "copy/paste into a new terminal to point the S2 CLI at this server:\n",
                "export S2_ACCOUNT_ENDPOINT=https://localhost:8443\n",
                "export S2_BASIN_ENDPOINT=https://localhost:8443\n",
                "export S2_ACCESS_TOKEN=ignored\n",
                "export S2_SSL_NO_VERIFY=1",
            )
        );
    }
}