1use std::{
2 net::SocketAddr,
3 path::PathBuf,
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use axum_server::tls_rustls::RustlsConfig;
9use bytesize::ByteSize;
10use slatedb::object_store;
11use tokio::time::Instant;
12use tower_http::trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer};
13use tracing::info;
14
15use crate::{backend::Backend, handlers};
16
// TLS-related command-line options.
//
// `run` accepts three combinations: `tls_self` alone (self-signed HTTPS),
// `tls_cert` together with `tls_key` (HTTPS with the provided files), or none
// of them (plain HTTP). The clap constraints below reject the other mixes.
#[derive(clap::Args, Debug, Clone)]
pub struct TlsConfig {
    /// Serve HTTPS using a self-signed certificate generated at startup.
    #[arg(long, conflicts_with_all = ["tls_cert", "tls_key"])]
    pub tls_self: bool,

    /// Path to a PEM certificate file; must be given together with `--tls-key`.
    #[arg(long, requires = "tls_key")]
    pub tls_cert: Option<PathBuf>,

    /// Path to a PEM private key file; must be given together with `--tls-cert`.
    #[arg(long, requires = "tls_cert")]
    pub tls_key: Option<PathBuf>,
}
33
// Command-line arguments consumed by `run`.
//
// The object store is chosen from `bucket` / `local_root`; when neither is
// given, `run` falls back to an in-memory store.
#[derive(clap::Args, Debug, Clone)]
pub struct LiteArgs {
    /// Name of the S3 bucket to use as the object store.
    #[arg(long)]
    pub bucket: Option<String>,

    /// Root directory for a local-filesystem object store (created if missing).
    #[arg(long, value_name = "DIR", conflicts_with = "bucket")]
    pub local_root: Option<PathBuf>,

    /// Database path handed to the SlateDB builder; defaults to the empty string.
    #[arg(long, default_value = "")]
    pub path: String,

    // TLS options, flattened into this argument set.
    #[command(flatten)]
    pub tls: TlsConfig,

    /// Port to listen on; when omitted, 443 is used if TLS is enabled and 80 otherwise.
    #[arg(long)]
    pub port: Option<u16>,
}
60
/// Which object store backs the database, as resolved from the CLI arguments.
#[derive(Debug, Clone)]
enum StoreType {
    /// Amazon S3 (or compatible) bucket, identified by bucket name.
    S3Bucket(String),
    /// Local filesystem store rooted at the given directory.
    LocalFileSystem(PathBuf),
    /// Process-local in-memory store.
    InMemory,
}
67
68impl StoreType {
69 fn default_flush_interval(&self) -> Duration {
70 Duration::from_millis(match self {
71 StoreType::S3Bucket(_) => 50,
72 StoreType::LocalFileSystem(_) | StoreType::InMemory => 5,
73 })
74 }
75}
76
77pub async fn run(args: LiteArgs) -> eyre::Result<()> {
78 info!(?args);
79
80 let addr = {
81 let port = args.port.unwrap_or_else(|| {
82 if args.tls.tls_self || args.tls.tls_cert.is_some() {
83 443
84 } else {
85 80
86 }
87 });
88 format!("0.0.0.0:{port}")
89 };
90
91 let store_type = if let Some(bucket) = args.bucket {
92 StoreType::S3Bucket(bucket)
93 } else if let Some(local_root) = args.local_root {
94 StoreType::LocalFileSystem(local_root)
95 } else {
96 StoreType::InMemory
97 };
98
99 let object_store = init_object_store(&store_type).await?;
100
101 let db_settings = slatedb::Settings::from_env_with_default(
102 "SL8_",
103 slatedb::Settings {
104 flush_interval: Some(store_type.default_flush_interval()),
105 ..Default::default()
106 },
107 )?;
108
109 let manifest_poll_interval = db_settings.manifest_poll_interval;
110
111 let append_inflight_max = if std::env::var("S2LITE_PIPELINE")
112 .is_ok_and(|v| v.eq_ignore_ascii_case("true") || v == "1")
113 {
114 info!("pipelining enabled on append sessions up to 25MiB");
115 ByteSize::mib(25)
116 } else {
117 info!("pipelining disabled");
118 ByteSize::b(1)
119 };
120
121 let db = slatedb::Db::builder(args.path, object_store)
122 .with_settings(db_settings)
123 .build()
124 .await?;
125
126 info!(
127 ?manifest_poll_interval,
128 "sleeping to ensure prior instance fenced out"
129 );
130
131 tokio::time::sleep(manifest_poll_interval).await;
132
133 let backend = Backend::new(db, append_inflight_max);
134 crate::backend::bgtasks::spawn(&backend);
135
136 let app = handlers::router().with_state(backend).layer(
137 TraceLayer::new_for_http()
138 .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO))
139 .on_request(DefaultOnRequest::new().level(tracing::Level::DEBUG))
140 .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
141 );
142
143 let server_handle = axum_server::Handle::new();
144 tokio::spawn(shutdown_signal(server_handle.clone()));
145 match (
146 args.tls.tls_self,
147 args.tls.tls_cert.clone(),
148 args.tls.tls_key.clone(),
149 ) {
150 (false, Some(cert_path), Some(key_path)) => {
151 info!(
152 addr,
153 ?cert_path,
154 "starting https server with provided certificate"
155 );
156 let rustls_config = RustlsConfig::from_pem_file(cert_path, key_path).await?;
157 axum_server::bind_rustls(addr.parse()?, rustls_config)
158 .handle(server_handle)
159 .serve(app.into_make_service())
160 .await?;
161 }
162 (true, None, None) => {
163 info!(
164 addr,
165 "starting https server with self-signed certificate, clients will need to use --insecure"
166 );
167 let rcgen::CertifiedKey { cert, signing_key } = rcgen::generate_simple_self_signed([
168 "localhost".to_string(),
169 "127.0.0.1".to_string(),
170 "::1".to_string(),
171 ])?;
172 let rustls_config = RustlsConfig::from_pem(
173 cert.pem().into_bytes(),
174 signing_key.serialize_pem().into_bytes(),
175 )
176 .await?;
177 axum_server::bind_rustls(addr.parse()?, rustls_config)
178 .handle(server_handle)
179 .serve(app.into_make_service())
180 .await?;
181 }
182 (false, None, None) => {
183 info!(addr, "starting plain http server");
184 axum_server::bind(addr.parse()?)
185 .handle(server_handle)
186 .serve(app.into_make_service())
187 .await?;
188 }
189 _ => {
190 return Err(eyre::eyre!("Invalid TLS configuration"));
192 }
193 }
194
195 Ok(())
196}
197
198async fn init_object_store(
199 store_type: &StoreType,
200) -> eyre::Result<Arc<dyn object_store::ObjectStore>> {
201 Ok(match store_type {
202 StoreType::S3Bucket(bucket) => {
203 info!(bucket, "using s3 object store");
204 let mut builder =
205 object_store::aws::AmazonS3Builder::from_env().with_bucket_name(bucket);
206 match (
207 std::env::var_os("AWS_ENDPOINT_URL_S3").and_then(|s| s.into_string().ok()),
208 std::env::var_os("AWS_ACCESS_KEY_ID").and_then(|s| s.into_string().ok()),
209 std::env::var_os("AWS_SECRET_ACCESS_KEY").and_then(|s| s.into_string().ok()),
210 ) {
211 (endpoint, Some(key_id), Some(secret_key)) => {
212 info!(key_id, "using static credentials from env vars");
213 if let Some(endpoint) = endpoint {
214 builder = builder.with_endpoint(endpoint);
215 }
216 builder = builder.with_credentials(Arc::new(
217 object_store::StaticCredentialProvider::new(
218 object_store::aws::AwsCredential {
219 key_id,
220 secret_key,
221 token: None,
222 },
223 ),
224 ));
225 }
226 _ => {
227 let aws_config =
228 aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
229 if let Some(region) = aws_config.region() {
230 info!(region = region.as_ref());
231 builder = builder.with_region(region.to_string());
232 }
233 if let Some(credentials_provider) = aws_config.credentials_provider() {
234 info!("using aws-config credentials provider");
235 builder = builder.with_credentials(Arc::new(S3CredentialProvider {
236 aws: credentials_provider.clone(),
237 cache: tokio::sync::Mutex::new(None),
238 }));
239 }
240 }
241 }
242 Arc::new(builder.build()?) as Arc<dyn object_store::ObjectStore>
243 }
244 StoreType::LocalFileSystem(local_root) => {
245 std::fs::create_dir_all(local_root)?;
246 info!(
247 root = %local_root.display(),
248 "using local filesystem object store"
249 );
250 Arc::new(object_store::local::LocalFileSystem::new_with_prefix(
251 local_root,
252 )?)
253 }
254 StoreType::InMemory => {
255 info!("using in-memory object store");
256 Arc::new(object_store::memory::InMemory::new())
257 }
258 })
259}
260
261async fn shutdown_signal(handle: axum_server::Handle<SocketAddr>) {
262 let ctrl_c = async {
263 tokio::signal::ctrl_c().await.expect("ctrl-c");
264 };
265
266 #[cfg(unix)]
267 let term = async {
268 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
269 .expect("SIGTERM")
270 .recv()
271 .await;
272 };
273
274 #[cfg(not(unix))]
275 let term = std::future::pending::<()>();
276
277 tokio::select! {
278 _ = ctrl_c => {
279 info!("received Ctrl+C, starting graceful shutdown");
280 },
281 _ = term => {
282 info!("received SIGTERM, starting graceful shutdown");
283 },
284 }
285
286 handle.graceful_shutdown(Some(Duration::from_secs(10)));
287}
288
/// A credential fetched from the AWS provider together with its expiry,
/// as kept in the `S3CredentialProvider` cache.
#[derive(Debug)]
struct CachedCredential {
    /// The credential in the form `object_store` expects.
    credential: Arc<object_store::aws::AwsCredential>,
    /// When the credential expires, if the provider reported an expiry.
    expiry: Option<SystemTime>,
}
294
295impl CachedCredential {
296 fn is_valid(&self) -> bool {
297 self.expiry
298 .is_none_or(|exp| exp > SystemTime::now() + Duration::from_secs(60))
299 }
300}
301
/// Adapter exposing an `aws-config` credentials provider through
/// `object_store::CredentialProvider`, caching the last fetched credential.
#[derive(Debug)]
struct S3CredentialProvider {
    /// Underlying AWS SDK credentials provider that is queried on a cache miss.
    aws: aws_credential_types::provider::SharedCredentialsProvider,
    /// Most recently fetched credential, if any; the async mutex is held
    /// across the fetch in `get_credential`.
    cache: tokio::sync::Mutex<Option<CachedCredential>>,
}
307
308#[async_trait::async_trait]
309impl object_store::CredentialProvider for S3CredentialProvider {
310 type Credential = object_store::aws::AwsCredential;
311
312 async fn get_credential(&self) -> object_store::Result<Arc<object_store::aws::AwsCredential>> {
313 let mut cached = self.cache.lock().await;
314 if let Some(cached) = cached.as_ref().filter(|c| c.is_valid()) {
315 return Ok(cached.credential.clone());
316 }
317
318 use aws_credential_types::provider::ProvideCredentials as _;
319
320 let start = Instant::now();
321 let creds =
322 self.aws
323 .provide_credentials()
324 .await
325 .map_err(|e| object_store::Error::Generic {
326 store: "S3",
327 source: Box::new(e),
328 })?;
329 info!(
330 key_id = creds.access_key_id(),
331 expiry_s = creds
332 .expiry()
333 .and_then(|t| t.duration_since(SystemTime::now()).ok())
334 .map(|d| d.as_secs()),
335 elapsed_ms = start.elapsed().as_millis(),
336 "fetched credentials"
337 );
338 let credential = Arc::new(object_store::aws::AwsCredential {
339 key_id: creds.access_key_id().to_owned(),
340 secret_key: creds.secret_access_key().to_owned(),
341 token: creds.session_token().map(|s| s.to_owned()),
342 });
343 *cached = Some(CachedCredential {
344 credential: credential.clone(),
345 expiry: creds.expiry(),
346 });
347 Ok(credential)
348 }
349}